diff --git a/main.go b/main.go index 65b87e0..31a4fee 100644 --- a/main.go +++ b/main.go @@ -85,82 +85,98 @@ func authMiddleware(allowedGroups []string) gin.HandlerFunc { tokenString := c.GetHeader("Authorization") if tokenString != "" { tokenString = strings.TrimPrefix(tokenString, "Bearer ") - } else { - tokenString = c.GetHeader("X-authentik-jwt") - } - if tokenString == "" { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "message": "No Token in header", - }) - return - } + ctx, cancel := context.WithCancel(context.Background()) - ctx, cancel := context.WithCancel(context.Background()) - - options := keyfunc.Options{ - Ctx: ctx, - RefreshErrorHandler: func(err error) { - log.Printf("There was an error with the jwt.Keyfunc\nError: %s", err.Error()) - }, - RefreshInterval: time.Hour, - RefreshRateLimit: time.Minute * 5, - RefreshTimeout: time.Second * 10, - RefreshUnknownKID: true, - } - - jwks, err := keyfunc.Get(JwksURL, options) - if err != nil { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "message": "Failed to create JWKS: " + err.Error(), - }) - cancel() - jwks.EndBackground() - return - } - - token, err := jwt.Parse(tokenString, jwks.Keyfunc) - if err != nil { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "message": err.Error(), - }) - cancel() - jwks.EndBackground() - return - } - - if !token.Valid { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "message": "Invalid Token: " + err.Error(), - }) - cancel() - jwks.EndBackground() - return - } - - cancel() - jwks.EndBackground() - - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "message": "Invalid authorization token claims", - }) - return - } - - groupsClaim, ok := claims["groups"].([]interface{}) - if !ok { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "message": "Missing or invalid groups claim in the authorization token", - }) - return - } - - for _, group := range groupsClaim { - if groupName, ok := group.(string); ok { - groups = append(groups, groupName) + options := keyfunc.Options{ + Ctx: ctx, + RefreshErrorHandler: func(err error) { + log.Printf("There was an error with the jwt.Keyfunc\nError: %s", err.Error()) + }, + RefreshInterval: time.Hour, + RefreshRateLimit: time.Minute * 5, + RefreshTimeout: time.Second * 10, + RefreshUnknownKID: true, } + + jwks, err := keyfunc.Get(JwksURL, options) + if err != nil { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "message": "Failed to create JWKS: " + err.Error(), + }) + cancel() + jwks.EndBackground() + return + } + + token, err := jwt.Parse(tokenString, jwks.Keyfunc) + if err != nil { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "message": err.Error(), + }) + cancel() + jwks.EndBackground() + return + } + + if !token.Valid { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "message": "Invalid Token: " + err.Error(), + }) + cancel() + jwks.EndBackground() + return + } + + cancel() + jwks.EndBackground() + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "message": "Invalid authorization token claims", + }) + return + } + + groupsClaim, ok := claims["groups"].([]interface{}) + if !ok { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "message": "Missing or invalid groups claim in the authorization token", + }) + return + } + + for _, group := range groupsClaim { + if groupName, ok := group.(string); ok { + groups = append(groups, groupName) + } + } + } else { + groupsHeader := c.GetHeader("X-Authentik-Groups") + + if groupsHeader == "" { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "message": "No token or groups detected", + }) + return + } + + requestHeaders := c.Request.Header + for key, values := range requestHeaders { + for _, value := range values { + println(key + ": " + value) + } + } + + // Dump response headers + responseHeaders := c.Writer.Header() + for key, values := range responseHeaders { + for _, value := range values { + println(key + ": " + value) + } + } + groups = strings.Split(groupsHeader, "|") } isAllowed := false