From d50bd6c4f5eaa66ce1fe2eaaeb8668a79ce26f29 Mon Sep 17 00:00:00 2001 From: LandaMm Date: Mon, 19 May 2025 16:36:19 +0200 Subject: [PATCH] feat: auth middleware --- cmd/hspguard/api/api.go | 9 ++++++- internal/middleware/auth.go | 53 +++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 internal/middleware/auth.go diff --git a/cmd/hspguard/api/api.go b/cmd/hspguard/api/api.go index 3e5f684..813f545 100644 --- a/cmd/hspguard/api/api.go +++ b/cmd/hspguard/api/api.go @@ -7,6 +7,8 @@ import ( "os" "path/filepath" + "gitea.local/admin/hspguard/internal/auth" + imiddleware "gitea.local/admin/hspguard/internal/middleware" "gitea.local/admin/hspguard/internal/repository" "gitea.local/admin/hspguard/internal/user" "github.com/go-chi/chi/v5" @@ -34,8 +36,13 @@ func (s *APIServer) Run() error { FileServer(router, "/static", staticDir) router.Route("/api/v1", func(r chi.Router) { + r.Use(imiddleware.WithSkipper(imiddleware.AuthMiddleware, "/api/v1/login", "/api/v1/register")) + userHandler := user.NewUserHandler(s.repo) userHandler.RegisterRoutes(router, r) + + authHandler := auth.NewAuthHandler(s.repo) + authHandler.RegisterRoutes(router, r) }) // Handle unknown routes @@ -48,7 +55,7 @@ func (s *APIServer) Run() error { router.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusMethodNotAllowed) - fmt.Fprint(w, `{"error": "405 - method not allowed"}`) + _, _ = fmt.Fprint(w, `{"error": "405 - method not allowed"}`) }) log.Println("Listening on", s.addr) diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go new file mode 100644 index 0000000..619d4ab --- /dev/null +++ b/internal/middleware/auth.go @@ -0,0 +1,53 @@ +package middleware + +import ( + "context" + "fmt" + "net/http" + "strings" + + "gitea.local/admin/hspguard/internal/auth" + "gitea.local/admin/hspguard/internal/types" + "gitea.local/admin/hspguard/internal/web" +) + +func AuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + web.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + parts := strings.Split(authHeader, "Bearer ") + if len(parts) != 2 { + web.Error(w, "invalid auth header format", http.StatusUnauthorized) + return + } + + tokenStr := parts[1] + token, userClaims, err := auth.VerifyToken(tokenStr) + if err != nil || !token.Valid { + http.Error(w, fmt.Sprintf("invalid token: %v", err), http.StatusUnauthorized) + return + } + + ctx := context.WithValue(r.Context(), types.UserIdKey, userClaims.UserID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func WithSkipper(mw func(http.Handler) http.Handler, excludedPaths ...string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for _, path := range excludedPaths { + if strings.HasPrefix(r.URL.Path, path) { + next.ServeHTTP(w, r) + return + } + } + mw(next).ServeHTTP(w, r) + }) + } +} +