From 8e22a3ac054ac6659ad113337cbb77926ec26b24 Mon Sep 17 00:00:00 2001 From: LandaMm Date: Sun, 25 May 2025 16:52:10 +0200 Subject: [PATCH] fix: cfg access --- cmd/hspguard/api/api.go | 6 ++-- internal/auth/routes.go | 5 ++-- internal/middleware/auth.go | 56 +++++++++++++++++++++---------------- 3 files changed, 39 insertions(+), 28 deletions(-) diff --git a/cmd/hspguard/api/api.go b/cmd/hspguard/api/api.go index b80abaa..4fddfed 100644 --- a/cmd/hspguard/api/api.go +++ b/cmd/hspguard/api/api.go @@ -30,6 +30,7 @@ func NewAPIServer(addr string, db *repository.Queries, minio *storage.FileStorag addr: addr, repo: db, storage: minio, + cfg: cfg, } } @@ -44,12 +45,13 @@ func (s *APIServer) Run() error { oauthHandler := oauth.NewOAuthHandler(s.repo, s.cfg) router.Route("/api/v1", func(r chi.Router) { - r.Use(imiddleware.WithSkipper(imiddleware.AuthMiddleware(s.cfg), "/api/v1/login", "/api/v1/register", "/api/v1/oauth/token")) + am := imiddleware.New(s.cfg) + r.Use(imiddleware.WithSkipper(am.Runner, "/api/v1/login", "/api/v1/register", "/api/v1/oauth/token")) userHandler := user.NewUserHandler(s.repo, s.storage) userHandler.RegisterRoutes(r) - authHandler := auth.NewAuthHandler(s.repo) + authHandler := auth.NewAuthHandler(s.repo, s.cfg) authHandler.RegisterRoutes(r) oauthHandler.RegisterRoutes(r) diff --git a/internal/auth/routes.go b/internal/auth/routes.go index 56f849b..b29a01a 100644 --- a/internal/auth/routes.go +++ b/internal/auth/routes.go @@ -21,9 +21,10 @@ type AuthHandler struct { cfg *config.AppConfig } -func NewAuthHandler(repo *repository.Queries) *AuthHandler { +func NewAuthHandler(repo *repository.Queries, cfg *config.AppConfig) *AuthHandler { return &AuthHandler{ - repo: repo, + repo, + cfg, } } diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index dc69989..513631c 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -12,34 +12,42 @@ import ( "gitea.local/admin/hspguard/internal/web" ) -func AuthMiddleware(cfg *config.AppConfig) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - web.Error(w, "unauthorized", http.StatusUnauthorized) - return - } +type AuthMiddleware struct { + cfg *config.AppConfig +} - parts := strings.Split(authHeader, "Bearer ") - if len(parts) != 2 { - web.Error(w, "invalid auth header format", http.StatusUnauthorized) - return - } - - tokenStr := parts[1] - token, userClaims, err := auth.VerifyToken(tokenStr, cfg.Jwt.PublicKey) - if err != nil || !token.Valid { - http.Error(w, fmt.Sprintf("invalid token: %v", err), http.StatusUnauthorized) - return - } - - ctx := context.WithValue(r.Context(), types.UserIdKey, userClaims.Subject) - next.ServeHTTP(w, r.WithContext(ctx)) - }) +func New(cfg *config.AppConfig) *AuthMiddleware { + return &AuthMiddleware{ + cfg, } } +func (m *AuthMiddleware) Runner(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + web.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + parts := strings.Split(authHeader, "Bearer ") + if len(parts) != 2 { + web.Error(w, "invalid auth header format", http.StatusUnauthorized) + return + } + + tokenStr := parts[1] + token, userClaims, err := auth.VerifyToken(tokenStr, m.cfg.Jwt.PublicKey) + if err != nil || !token.Valid { + http.Error(w, fmt.Sprintf("invalid token: %v", err), http.StatusUnauthorized) + return + } + + ctx := context.WithValue(r.Context(), types.UserIdKey, userClaims.Subject) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + func WithSkipper(mw func(http.Handler) http.Handler, excludedPaths ...string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {