diff --git a/main.go b/main.go index 22b49ab..1e379f3 100644 --- a/main.go +++ b/main.go @@ -45,13 +45,13 @@ func main() { } openai := v1.Group("/openai") { - //openai.Use(authMiddleware()) + openai.Use(authMiddleware("test")) openai.GET("general", c.GeneralOpenAI) openai.GET("travelagent", c.TravelAgentOpenAI) } unraid := v1.Group("/unraid") { - //unraid.Use(authMiddleware()) + unraid.Use(authMiddleware("grafana")) unraid.GET("powerusage", c.UnraidPowerUsage) } } @@ -59,34 +59,34 @@ func main() { r.Run(":8080") } -func authMiddleware() gin.HandlerFunc { +func authMiddleware(allowedGroups []string) gin.HandlerFunc { return func(c *gin.Context) { - // Get the authorization header from the request - authHeader := c.GetHeader("Authorization") + // Get the user groups from the request headers + groupsHeader := c.GetHeader("X-Forwarded-Groups") - // Check if the authorization header is missing or doesn't start with "Bearer" - if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { + // Split the groups header value into individual groups + groups := strings.Split(groupsHeader, ",") + + // Check if the user belongs to any of the allowed groups + isAllowed := false + for _, allowedGroup := range allowedGroups { + for _, group := range groups { + if group == allowedGroup { + isAllowed = true + break + } + } + if isAllowed { + break + } + } + + // If the user is not in any of the allowed groups, respond with unauthorized access + if !isAllowed { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"message": "Unauthorized access"}) return } - // Extract the token from the authorization header - tokenString := strings.TrimPrefix(authHeader, "Bearer ") - - // Parse the token and validate its signature - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - return []byte(os.Getenv("jwtToken")), nil - }) - - // Check if there was an error parsing the token or if it is not valid - if err != nil || !token.Valid { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"message": "Unauthorized access"}) - return - } - - // Add the token to the request context - c.Set("token", token) - // Call the next handler c.Next() }