diff --git a/main.go b/main.go index f310684..14c2425 100644 --- a/main.go +++ b/main.go @@ -81,98 +81,85 @@ func main() { func authMiddleware(allowedGroups []string) gin.HandlerFunc { return func(c *gin.Context) { var groups []string - var JwksURL = os.Getenv("jwksurl") + JwksURL := os.Getenv("jwksurl") tokenString := c.GetHeader("Authorization") if tokenString != "" { tokenString = strings.TrimPrefix(tokenString, "Bearer ") + } else { + tokenString = c.GetHeader("X-authentik-jwt") + } - ctx, cancel := context.WithCancel(context.Background()) + if tokenString == "" { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "message": "No Token in header", + }) + return + } - options := keyfunc.Options{ - Ctx: ctx, - RefreshErrorHandler: func(err error) { - log.Printf("There was an error with the jwt.Keyfunc\nError: %s", err.Error()) - }, - RefreshInterval: time.Hour, - RefreshRateLimit: time.Minute * 5, - RefreshTimeout: time.Second * 10, - RefreshUnknownKID: true, - } + ctx, cancel := context.WithCancel(context.Background()) - jwks, err := keyfunc.Get(JwksURL, options) - if err != nil { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "message": "Failed to create JWKS: " + err.Error(), - }) - cancel() - jwks.EndBackground() - return - } - - token, err := jwt.Parse(tokenString, jwks.Keyfunc) - if err != nil { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "message": err.Error(), - }) - cancel() - jwks.EndBackground() - return - } - - if !token.Valid { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "message": "Invalid Token: " + err.Error(), - }) - cancel() - jwks.EndBackground() - return - } + options := keyfunc.Options{ + Ctx: ctx, + RefreshErrorHandler: func(err error) { + log.Printf("There was an error with the jwt.Keyfunc\nError: %s", err.Error()) + }, + RefreshInterval: time.Hour, + RefreshRateLimit: time.Minute * 5, + RefreshTimeout: time.Second * 10, + RefreshUnknownKID: true, + } + jwks, err := keyfunc.Get(JwksURL, options) + if err != nil { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "message": "Failed to create JWKS: " + err.Error(), + }) cancel() jwks.EndBackground() + return + } - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "message": "Invalid authorization token claims", - }) - return - } + token, err := jwt.Parse(tokenString, jwks.Keyfunc) + if err != nil { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "message": err.Error(), + }) + cancel() + jwks.EndBackground() + return + } - groupsClaim, ok := claims["groups"].([]interface{}) - if !ok { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "message": "Missing or invalid groups claim in the authorization token", - }) - return - } + if !token.Valid { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "message": "Invalid Token: " + err.Error(), + }) + cancel() + jwks.EndBackground() + return + } - for _, group := range groupsClaim { - if groupName, ok := group.(string); ok { - groups = append(groups, groupName) - } - } - } else { - if groupsenv != "" { - groups = strings.Split(groupsenv, ",") - } else { - groupsHeader := c.GetHeader("X-Authentik-Groups") + cancel() + jwks.EndBackground() - requestHeaders := c.Request.Header - for key, values := range requestHeaders { - for _, value := range values { - println(key + ": " + value) - } - } + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "message": "Invalid authorization token claims", + }) + return + } - // Dump response headers - responseHeaders := c.Writer.Header() - for key, values := range responseHeaders { - for _, value := range values { - println(key + ": " + value) - } - } - groups = strings.Split(groupsHeader, "|") + groupsClaim, ok := claims["groups"].([]interface{}) + if !ok { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "message": "Missing or invalid groups claim in the authorization token", + }) + return + } + + for _, group := range groupsClaim { + if groupName, ok := group.(string); ok { + groups = append(groups, groupName) } } @@ -197,7 +184,6 @@ func authMiddleware(allowedGroups []string) gin.HandlerFunc { return } - // Call the next handler c.Next() } }