346 lines
10 KiB
Go
346 lines
10 KiB
Go
package oauth
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"math"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"gitea.local/admin/hspguard/internal/repository"
|
|
"gitea.local/admin/hspguard/internal/types"
|
|
"gitea.local/admin/hspguard/internal/util"
|
|
"gitea.local/admin/hspguard/internal/web"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
func (h *OAuthHandler) signApiTokens(user *repository.User, apiService *repository.ApiService, nonce *string) (*types.SignedToken, *types.SignedToken, *types.SignedToken, error) {
|
|
accessExpiresIn := 15 * time.Minute
|
|
accessExpiresAt := time.Now().Add(accessExpiresIn)
|
|
accessJTI := uuid.New()
|
|
|
|
accessClaims := types.ApiClaims{
|
|
Permissions: []string{},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Issuer: h.cfg.Uri,
|
|
Subject: apiService.ClientID,
|
|
Audience: jwt.ClaimStrings{apiService.ClientID},
|
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
ExpiresAt: jwt.NewNumericDate(accessExpiresAt),
|
|
ID: accessJTI.String(),
|
|
},
|
|
}
|
|
|
|
access, err := util.SignJwtToken(accessClaims, h.cfg.Jwt.PrivateKey)
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
|
|
var roles = []string{"user"}
|
|
|
|
if user.IsAdmin {
|
|
roles = append(roles, "admin")
|
|
}
|
|
|
|
idExpiresIn := 15 * time.Minute
|
|
idExpiresAt := time.Now().Add(idExpiresIn)
|
|
idJTI := uuid.New()
|
|
|
|
idClaims := types.IdTokenClaims{
|
|
Email: user.Email,
|
|
EmailVerified: user.EmailVerified,
|
|
Name: user.FullName,
|
|
Picture: user.ProfilePicture,
|
|
Nonce: nonce,
|
|
Roles: roles,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Issuer: h.cfg.Uri,
|
|
Subject: user.ID.String(),
|
|
Audience: jwt.ClaimStrings{apiService.ClientID},
|
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
ExpiresAt: jwt.NewNumericDate(idExpiresAt),
|
|
ID: idJTI.String(),
|
|
},
|
|
}
|
|
|
|
idToken, err := util.SignJwtToken(idClaims, h.cfg.Jwt.PrivateKey)
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
|
|
refreshExpiresIn := 24 * time.Hour
|
|
refreshExpiresAt := time.Now().Add(refreshExpiresIn)
|
|
refreshJTI := uuid.New()
|
|
|
|
refreshClaims := types.ApiRefreshClaims{
|
|
UserID: user.ID.String(),
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Issuer: h.cfg.Uri,
|
|
Subject: apiService.ClientID,
|
|
Audience: jwt.ClaimStrings{apiService.ClientID},
|
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
ExpiresAt: jwt.NewNumericDate(refreshExpiresAt),
|
|
ID: refreshJTI.String(),
|
|
},
|
|
}
|
|
|
|
refresh, err := util.SignJwtToken(refreshClaims, h.cfg.Jwt.PrivateKey)
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
|
|
return types.NewSignedToken(idToken, idExpiresAt, idJTI), types.NewSignedToken(access, accessExpiresAt, accessJTI), types.NewSignedToken(refresh, refreshExpiresAt, refreshJTI), nil
|
|
}
|
|
|
|
func (h *OAuthHandler) tokenEndpoint(w http.ResponseWriter, r *http.Request) {
|
|
log.Println("[OAUTH] New request to token endpoint")
|
|
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" || !strings.HasPrefix(authHeader, "Basic ") {
|
|
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Decode credentials
|
|
payload, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(authHeader, "Basic "))
|
|
if err != nil {
|
|
http.Error(w, "Invalid auth encoding", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
var clientId string
|
|
var clientSecret string
|
|
|
|
parts := strings.SplitN(string(payload), ":", 2)
|
|
if len(parts) != 2 {
|
|
http.Error(w, "Unauthorized", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
clientId = parts[0]
|
|
clientSecret = parts[1]
|
|
|
|
log.Printf("Some client is trying to exchange code with id: %s and secret: %s\n", clientId, clientSecret)
|
|
|
|
// Parse the form data
|
|
err = r.ParseForm()
|
|
if err != nil {
|
|
http.Error(w, "Failed to parse form", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
grantType := r.FormValue("grant_type")
|
|
|
|
log.Println("DEBUG: Verifying target oauth client before proceeding...")
|
|
|
|
if _, err := h.verifyOAuthClient(r.Context(), &VerifyOAuthClientParams{
|
|
ClientID: clientId,
|
|
RedirectURI: nil,
|
|
State: "",
|
|
Scopes: nil,
|
|
}); err != nil {
|
|
web.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
switch grantType {
|
|
case "authorization_code":
|
|
redirectUri := r.FormValue("redirect_uri")
|
|
log.Printf("Redirect URI is %s\n", redirectUri)
|
|
|
|
code := r.FormValue("code")
|
|
|
|
fmt.Printf("Code received: %s\n", code)
|
|
|
|
codeSession, err := h.cache.GetAuthCode(r.Context(), code)
|
|
if err != nil {
|
|
log.Printf("ERR: Failed to find session under the code %s: %v\n", code, err)
|
|
web.Error(w, "no session found under this auth code", http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
log.Printf("DEBUG: Fetched code session: %#v\n", codeSession)
|
|
|
|
apiService, err := h.repo.GetApiServiceCID(r.Context(), codeSession.ClientID)
|
|
if err != nil {
|
|
log.Printf("ERR: Could not find API service with client %s: %v\n", codeSession.ClientID, err)
|
|
web.Error(w, "service is not registered", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
if codeSession.ClientID != clientId {
|
|
web.Error(w, "invalid auth", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
user, err := h.repo.FindUserId(r.Context(), uuid.MustParse(codeSession.UserID))
|
|
if err != nil {
|
|
web.Error(w, "requested user not found", http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
id, access, refresh, err := h.signApiTokens(&user, &apiService, &codeSession.Nonce)
|
|
if err != nil {
|
|
log.Println("ERR: Failed to sign api tokens:", err)
|
|
web.Error(w, "failed to sign tokens", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
log.Printf("DEBUG: Created api tokens: %v\n\n%v\n\n%v\n", id.ID.String(), access.ID.String(), refresh.ID.String())
|
|
|
|
userId, err := uuid.Parse(codeSession.UserID)
|
|
if err != nil {
|
|
log.Printf("ERR: Failed to parse user '%s' uuid: %v\n", codeSession.UserID, err)
|
|
web.Error(w, "failed to sign tokens", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
ipAddr := util.GetClientIP(r)
|
|
ua := r.UserAgent()
|
|
|
|
session, err := h.repo.CreateServiceSession(r.Context(), repository.CreateServiceSessionParams{
|
|
ServiceID: apiService.ID,
|
|
ClientID: apiService.ClientID,
|
|
UserID: &userId,
|
|
ExpiresAt: &refresh.ExpiresAt,
|
|
LastActive: nil,
|
|
IpAddress: &ipAddr,
|
|
UserAgent: &ua,
|
|
AccessTokenID: &access.ID,
|
|
RefreshTokenID: &refresh.ID,
|
|
})
|
|
if err != nil {
|
|
log.Printf("ERR: Failed to create new service session: %v\n", err)
|
|
web.Error(w, "failed to create session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
log.Printf("INFO: Service session created for '%s' client_id with '%s' id\n", apiService.ClientID, session.ID.String())
|
|
|
|
type Response struct {
|
|
IdToken string `json:"id_token"`
|
|
TokenType string `json:"token_type"`
|
|
AccessToken string `json:"access_token"`
|
|
Email string `json:"email"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
ExpiresIn float64 `json:"expires_in"`
|
|
// TODO: add scope (RFC 8693 $2)
|
|
}
|
|
|
|
response := Response{
|
|
IdToken: id.Token,
|
|
TokenType: "Bearer",
|
|
AccessToken: access.Token,
|
|
RefreshToken: refresh.Token,
|
|
ExpiresIn: math.Ceil(access.ExpiresAt.Sub(time.Now()).Seconds()),
|
|
Email: user.Email,
|
|
}
|
|
|
|
log.Printf("sending following response: %#v\n", response)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
encoder := json.NewEncoder(w)
|
|
if err := encoder.Encode(response); err != nil {
|
|
web.Error(w, "failed to encode response", http.StatusInternalServerError)
|
|
}
|
|
case "refresh_token":
|
|
refreshToken := r.FormValue("refresh_token")
|
|
|
|
var claims types.ApiRefreshClaims
|
|
|
|
token, err := util.VerifyToken(refreshToken, h.cfg.Jwt.PublicKey, &claims)
|
|
if err != nil || !token.Valid {
|
|
http.Error(w, fmt.Sprintf("invalid token: %v", err), http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
expire, err := claims.GetExpirationTime()
|
|
if err != nil {
|
|
web.Error(w, "failed to retrieve enough info from the token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if time.Now().After(expire.Time) {
|
|
web.Error(w, "token is expired", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
refreshJTI, err := uuid.Parse(claims.ID)
|
|
if err != nil {
|
|
log.Printf("ERR: Failed to parse refresh token JTI as uuid: %v\n", err)
|
|
web.Error(w, "failed to refresh token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
session, err := h.repo.GetServiceSessionByRefreshJTI(r.Context(), &refreshJTI)
|
|
if err != nil {
|
|
log.Printf("ERR: Failed to find session by '%s' refresh jti: %v\n", refreshJTI.String(), err)
|
|
web.Error(w, "session invalid", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
if !session.IsActive {
|
|
log.Printf("INFO: Session with id '%s' is not active", session.ID.String())
|
|
web.Error(w, "session ended", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
userID, err := uuid.Parse(claims.UserID)
|
|
if err != nil {
|
|
web.Error(w, "invalid user credentials in refresh token", http.StatusBadRequest)
|
|
return
|
|
}
|
|
user, err := h.repo.FindUserId(r.Context(), userID)
|
|
|
|
apiService, err := h.repo.GetApiServiceCID(r.Context(), claims.Subject)
|
|
if err != nil {
|
|
web.Error(w, "api service is not registered", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
id, access, refresh, err := h.signApiTokens(&user, &apiService, nil)
|
|
|
|
if err := h.repo.UpdateServiceSessionTokens(r.Context(), repository.UpdateServiceSessionTokensParams{
|
|
ID: session.ID,
|
|
AccessTokenID: &access.ID,
|
|
RefreshTokenID: &refresh.ID,
|
|
ExpiresAt: &refresh.ExpiresAt,
|
|
}); err != nil {
|
|
log.Printf("ERR: Failed to update service session with '%s' id: %v\n", session.ID.String(), err)
|
|
web.Error(w, "failed to update session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
type Response struct {
|
|
IdToken string `json:"id_token"`
|
|
TokenType string `json:"token_type"`
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
ExpiresIn float64 `json:"expires_in"`
|
|
}
|
|
|
|
response := Response{
|
|
IdToken: id.Token,
|
|
TokenType: "Bearer",
|
|
AccessToken: access.Token,
|
|
RefreshToken: refresh.Token,
|
|
ExpiresIn: math.Ceil(access.ExpiresAt.Sub(time.Now()).Seconds()),
|
|
}
|
|
|
|
log.Printf("DEBUG: refresh - sending following response: %#v\n", response)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
encoder := json.NewEncoder(w)
|
|
if err := encoder.Encode(response); err != nil {
|
|
web.Error(w, "failed to encode response", http.StatusInternalServerError)
|
|
}
|
|
default:
|
|
web.Error(w, "unsupported grant type", http.StatusBadRequest)
|
|
}
|
|
}
|