diff --git a/internal/oauth/authorize.go b/internal/oauth/authorize.go index 9896cfd..38bce3f 100644 --- a/internal/oauth/authorize.go +++ b/internal/oauth/authorize.go @@ -3,6 +3,7 @@ package oauth import ( "fmt" "net/http" + "strings" "gitea.local/admin/hspguard/internal/web" ) @@ -27,10 +28,13 @@ func (h *OAuthHandler) AuthorizeClient(w http.ResponseWriter, r *http.Request) { return } - if uri, err := h.verifyOAuthClient(w, r, &VerifyOAuthClientParams{ + scopes := strings.Split(strings.TrimSpace(r.URL.Query().Get("scope")), " ") + + if uri, err := h.verifyOAuthClient(r.Context(), &VerifyOAuthClientParams{ ClientID: clientId, - RedirectURI: redirectUri, + RedirectURI: &redirectUri, State: state, + Scopes: &scopes, }); err != nil { http.Redirect(w, r, uri, http.StatusFound) return diff --git a/internal/oauth/client.go b/internal/oauth/client.go index 80227c8..14c660e 100644 --- a/internal/oauth/client.go +++ b/internal/oauth/client.go @@ -1,22 +1,23 @@ package oauth import ( + "context" "fmt" - "net/http" "slices" "strings" ) type VerifyOAuthClientParams struct { - ClientID string `json:"client_id"` - RedirectURI string `json:"redirect_uri"` - State string `json:"state"` + ClientID string `json:"client_id"` + RedirectURI *string `json:"redirect_uri"` + State string `json:"state"` + Scopes *[]string `json:"scopes"` } -func (h *OAuthHandler) verifyOAuthClient(w http.ResponseWriter, r *http.Request, params *VerifyOAuthClientParams) (string, error) { - client, err := h.repo.GetApiServiceCID(r.Context(), params.ClientID) +func (h *OAuthHandler) verifyOAuthClient(ctx context.Context, params *VerifyOAuthClientParams) (string, error) { + client, err := h.repo.GetApiServiceCID(ctx, params.ClientID) if err != nil { - uri := fmt.Sprintf("%s?error=access_denied&error_description=Service+not+authorized", params.RedirectURI) + uri := fmt.Sprintf("%s?error=access_denied&error_description=Service+not+authorized", *params.RedirectURI) if params.State != "" { uri += "&state=" + params.State } @@ -24,32 +25,33 @@ func (h *OAuthHandler) verifyOAuthClient(w http.ResponseWriter, r *http.Request, } if !client.IsActive { - uri := fmt.Sprintf("%s?error=temporarily_unavailable&error_description=Service+not+active", params.RedirectURI) + uri := fmt.Sprintf("%s?error=temporarily_unavailable&error_description=Service+not+active", *params.RedirectURI) if params.State != "" { uri += "&state=" + params.State } return uri, fmt.Errorf("target oauth service with client id '%s' is not available", client.ClientID) } - scopes := strings.SplitSeq(strings.TrimSpace(r.URL.Query().Get("scope")), " ") - - for scope := range scopes { - if !slices.Contains(client.Scopes, scope) { - uri := fmt.Sprintf("%s?error=invalid_scope&error_description=Scope+%s+is+not+allowed", params.RedirectURI, strings.ReplaceAll(scope, " ", "+")) - if params.State != "" { - uri += "&state=" + params.State + if params.Scopes != nil { + for _, scope := range *params.Scopes { + if !slices.Contains(client.Scopes, scope) { + uri := fmt.Sprintf("%s?error=invalid_scope&error_description=Scope+%s+is+not+allowed", *params.RedirectURI, strings.ReplaceAll(scope, " ", "+")) + if params.State != "" { + uri += "&state=" + params.State + } + return uri, fmt.Errorf("unallowed scope '%s' requested", scope) } - return uri, fmt.Errorf("unallowed scope '%s' requested", scope) } } - if !slices.Contains(client.RedirectUris, params.RedirectURI) { - uri := fmt.Sprintf("%s?error=invalid_request&error_description=Redirect+URI+is+not+allowed", params.RedirectURI) - if params.State != "" { - uri += "&state=" + params.State + if params.RedirectURI != nil { + if !slices.Contains(client.RedirectUris, *params.RedirectURI) { + uri := fmt.Sprintf("%s?error=invalid_request&error_description=Redirect+URI+is+not+allowed", *params.RedirectURI) + if params.State != "" { + uri += "&state=" + params.State + } + return uri, fmt.Errorf("redirect uri '%s' is unallowed", *params.RedirectURI) } - http.Redirect(w, r, uri, http.StatusFound) - return uri, fmt.Errorf("redirect uri '%s' is unallowed", params.RedirectURI) } return "", nil diff --git a/internal/oauth/code.go b/internal/oauth/code.go index 4d47a74..f175bbc 100644 --- a/internal/oauth/code.go +++ b/internal/oauth/code.go @@ -39,10 +39,11 @@ func (h *OAuthHandler) getAuthCode(w http.ResponseWriter, r *http.Request) { return } - if _, err := h.verifyOAuthClient(w, r, &VerifyOAuthClientParams{ - ClientID: "", - RedirectURI: "", + if _, err := h.verifyOAuthClient(r.Context(), &VerifyOAuthClientParams{ + ClientID: req.ClientID, + RedirectURI: nil, State: "", + Scopes: nil, }); err != nil { web.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/internal/oauth/token.go b/internal/oauth/token.go index 215126b..d11e16c 100644 --- a/internal/oauth/token.go +++ b/internal/oauth/token.go @@ -152,12 +152,24 @@ func (h *OAuthHandler) tokenEndpoint(w http.ResponseWriter, r *http.Request) { } grantType := r.FormValue("grant_type") - redirectUri := r.FormValue("redirect_uri") - log.Printf("Redirect URI is %s\n", redirectUri) + log.Println("DEBUG: Verifying target oauth client before proceeding...") + + if _, err := h.verifyOAuthClient(r.Context(), &VerifyOAuthClientParams{ + ClientID: clientId, + RedirectURI: nil, + State: "", + Scopes: nil, + }); err != nil { + web.Error(w, err.Error(), http.StatusInternalServerError) + return + } switch grantType { case "authorization_code": + redirectUri := r.FormValue("redirect_uri") + log.Printf("Redirect URI is %s\n", redirectUri) + code := r.FormValue("code") fmt.Printf("Code received: %s\n", code)