package oauth import ( "encoding/base64" "encoding/json" "fmt" "log" "math" "net/http" "strings" "time" "gitea.local/admin/hspguard/internal/repository" "gitea.local/admin/hspguard/internal/types" "gitea.local/admin/hspguard/internal/util" "gitea.local/admin/hspguard/internal/web" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" ) func (h *OAuthHandler) signApiTokens(user *repository.User, apiService *repository.ApiService, nonce *string) (*types.SignedToken, *types.SignedToken, *types.SignedToken, error) { accessExpiresIn := 15 * time.Minute accessExpiresAt := time.Now().Add(accessExpiresIn) accessJTI := uuid.New() accessClaims := types.ApiClaims{ Permissions: []string{}, RegisteredClaims: jwt.RegisteredClaims{ Issuer: h.cfg.Uri, Subject: apiService.ClientID, Audience: jwt.ClaimStrings{apiService.ClientID}, IssuedAt: jwt.NewNumericDate(time.Now()), ExpiresAt: jwt.NewNumericDate(accessExpiresAt), ID: accessJTI.String(), }, } access, err := util.SignJwtToken(accessClaims, h.cfg.Jwt.PrivateKey) if err != nil { return nil, nil, nil, err } var roles = []string{"user"} if user.IsAdmin { roles = append(roles, "admin") } idExpiresIn := 15 * time.Minute idExpiresAt := time.Now().Add(idExpiresIn) idJTI := uuid.New() idClaims := types.IdTokenClaims{ Email: user.Email, EmailVerified: user.EmailVerified, Name: user.FullName, Picture: user.ProfilePicture, Nonce: nonce, Roles: roles, RegisteredClaims: jwt.RegisteredClaims{ Issuer: h.cfg.Uri, Subject: user.ID.String(), Audience: jwt.ClaimStrings{apiService.ClientID}, IssuedAt: jwt.NewNumericDate(time.Now()), ExpiresAt: jwt.NewNumericDate(idExpiresAt), ID: idJTI.String(), }, } idToken, err := util.SignJwtToken(idClaims, h.cfg.Jwt.PrivateKey) if err != nil { return nil, nil, nil, err } refreshExpiresIn := 24 * time.Hour refreshExpiresAt := time.Now().Add(refreshExpiresIn) refreshJTI := uuid.New() refreshClaims := types.ApiRefreshClaims{ UserID: user.ID.String(), RegisteredClaims: jwt.RegisteredClaims{ Issuer: h.cfg.Uri, Subject: apiService.ClientID, Audience: jwt.ClaimStrings{apiService.ClientID}, IssuedAt: jwt.NewNumericDate(time.Now()), ExpiresAt: jwt.NewNumericDate(refreshExpiresAt), ID: refreshJTI.String(), }, } refresh, err := util.SignJwtToken(refreshClaims, h.cfg.Jwt.PrivateKey) if err != nil { return nil, nil, nil, err } return types.NewSignedToken(idToken, idExpiresAt, idJTI), types.NewSignedToken(access, accessExpiresAt, accessJTI), types.NewSignedToken(refresh, refreshExpiresAt, refreshJTI), nil } func (h *OAuthHandler) tokenEndpoint(w http.ResponseWriter, r *http.Request) { log.Println("[OAUTH] New request to token endpoint") authHeader := r.Header.Get("Authorization") if authHeader == "" || !strings.HasPrefix(authHeader, "Basic ") { w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) http.Error(w, "Unauthorized", http.StatusUnauthorized) return } // Decode credentials payload, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(authHeader, "Basic ")) if err != nil { http.Error(w, "Invalid auth encoding", http.StatusBadRequest) return } var clientId string var clientSecret string parts := strings.SplitN(string(payload), ":", 2) if len(parts) != 2 { http.Error(w, "Unauthorized", http.StatusForbidden) return } clientId = parts[0] clientSecret = parts[1] log.Printf("Some client is trying to exchange code with id: %s and secret: %s\n", clientId, clientSecret) // Parse the form data err = r.ParseForm() if err != nil { http.Error(w, "Failed to parse form", http.StatusBadRequest) return } grantType := r.FormValue("grant_type") log.Println("DEBUG: Verifying target oauth client before proceeding...") if _, err := h.verifyOAuthClient(r.Context(), &VerifyOAuthClientParams{ ClientID: clientId, RedirectURI: nil, State: "", Scopes: nil, }); err != nil { web.Error(w, err.Error(), http.StatusInternalServerError) return } switch grantType { case "authorization_code": redirectUri := r.FormValue("redirect_uri") log.Printf("Redirect URI is %s\n", redirectUri) code := r.FormValue("code") fmt.Printf("Code received: %s\n", code) codeSession, err := h.cache.GetAuthCode(r.Context(), code) if err != nil { log.Printf("ERR: Failed to find session under the code %s: %v\n", code, err) web.Error(w, "no session found under this auth code", http.StatusNotFound) return } log.Printf("DEBUG: Fetched code session: %#v\n", codeSession) apiService, err := h.repo.GetApiServiceCID(r.Context(), codeSession.ClientID) if err != nil { log.Printf("ERR: Could not find API service with client %s: %v\n", codeSession.ClientID, err) web.Error(w, "service is not registered", http.StatusForbidden) return } if codeSession.ClientID != clientId { web.Error(w, "invalid auth", http.StatusUnauthorized) return } user, err := h.repo.FindUserId(r.Context(), uuid.MustParse(codeSession.UserID)) if err != nil { web.Error(w, "requested user not found", http.StatusNotFound) return } id, access, refresh, err := h.signApiTokens(&user, &apiService, &codeSession.Nonce) if err != nil { log.Println("ERR: Failed to sign api tokens:", err) web.Error(w, "failed to sign tokens", http.StatusInternalServerError) return } log.Printf("DEBUG: Created api tokens: %v\n\n%v\n\n%v\n", id.ID.String(), access.ID.String(), refresh.ID.String()) userId, err := uuid.Parse(codeSession.UserID) if err != nil { log.Printf("ERR: Failed to parse user '%s' uuid: %v\n", codeSession.UserID, err) web.Error(w, "failed to sign tokens", http.StatusInternalServerError) return } ipAddr := util.GetClientIP(r) ua := r.UserAgent() session, err := h.repo.CreateServiceSession(r.Context(), repository.CreateServiceSessionParams{ ServiceID: apiService.ID, ClientID: apiService.ClientID, UserID: &userId, ExpiresAt: &refresh.ExpiresAt, LastActive: nil, IpAddress: &ipAddr, UserAgent: &ua, AccessTokenID: &access.ID, RefreshTokenID: &refresh.ID, }) if err != nil { log.Printf("ERR: Failed to create new service session: %v\n", err) web.Error(w, "failed to create session", http.StatusInternalServerError) return } log.Printf("INFO: Service session created for '%s' client_id with '%s' id\n", apiService.ClientID, session.ID.String()) type Response struct { IdToken string `json:"id_token"` TokenType string `json:"token_type"` AccessToken string `json:"access_token"` Email string `json:"email"` RefreshToken string `json:"refresh_token"` ExpiresIn float64 `json:"expires_in"` // TODO: add scope (RFC 8693 $2) } response := Response{ IdToken: id.Token, TokenType: "Bearer", AccessToken: access.Token, RefreshToken: refresh.Token, ExpiresIn: math.Ceil(access.ExpiresAt.Sub(time.Now()).Seconds()), Email: user.Email, } log.Printf("sending following response: %#v\n", response) w.Header().Set("Content-Type", "application/json") encoder := json.NewEncoder(w) if err := encoder.Encode(response); err != nil { web.Error(w, "failed to encode response", http.StatusInternalServerError) } case "refresh_token": refreshToken := r.FormValue("refresh_token") var claims types.ApiRefreshClaims token, err := util.VerifyToken(refreshToken, h.cfg.Jwt.PublicKey, &claims) if err != nil || !token.Valid { http.Error(w, fmt.Sprintf("invalid token: %v", err), http.StatusUnauthorized) return } expire, err := claims.GetExpirationTime() if err != nil { web.Error(w, "failed to retrieve enough info from the token", http.StatusInternalServerError) return } if time.Now().After(expire.Time) { web.Error(w, "token is expired", http.StatusUnauthorized) return } refreshJTI, err := uuid.Parse(claims.ID) if err != nil { log.Printf("ERR: Failed to parse refresh token JTI as uuid: %v\n", err) web.Error(w, "failed to refresh token", http.StatusInternalServerError) return } session, err := h.repo.GetServiceSessionByRefreshJTI(r.Context(), &refreshJTI) if err != nil { log.Printf("ERR: Failed to find session by '%s' refresh jti: %v\n", refreshJTI.String(), err) web.Error(w, "session invalid", http.StatusUnauthorized) return } if !session.IsActive { log.Printf("INFO: Session with id '%s' is not active", session.ID.String()) web.Error(w, "session ended", http.StatusUnauthorized) return } userID, err := uuid.Parse(claims.UserID) if err != nil { web.Error(w, "invalid user credentials in refresh token", http.StatusBadRequest) return } user, err := h.repo.FindUserId(r.Context(), userID) apiService, err := h.repo.GetApiServiceCID(r.Context(), claims.Subject) if err != nil { web.Error(w, "api service is not registered", http.StatusUnauthorized) return } id, access, refresh, err := h.signApiTokens(&user, &apiService, nil) if err := h.repo.UpdateServiceSessionTokens(r.Context(), repository.UpdateServiceSessionTokensParams{ ID: session.ID, AccessTokenID: &access.ID, RefreshTokenID: &refresh.ID, ExpiresAt: &refresh.ExpiresAt, }); err != nil { log.Printf("ERR: Failed to update service session with '%s' id: %v\n", session.ID.String(), err) web.Error(w, "failed to update session", http.StatusInternalServerError) return } type Response struct { IdToken string `json:"id_token"` TokenType string `json:"token_type"` AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresIn float64 `json:"expires_in"` } response := Response{ IdToken: id.Token, TokenType: "Bearer", AccessToken: access.Token, RefreshToken: refresh.Token, ExpiresIn: math.Ceil(access.ExpiresAt.Sub(time.Now()).Seconds()), } log.Printf("DEBUG: refresh - sending following response: %#v\n", response) w.Header().Set("Content-Type", "application/json") encoder := json.NewEncoder(w) if err := encoder.Encode(response); err != nil { web.Error(w, "failed to encode response", http.StatusInternalServerError) } default: web.Error(w, "unsupported grant type", http.StatusBadRequest) } }