Compare commits

...

13 Commits

16 changed files with 347 additions and 83 deletions

View File

@ -44,7 +44,7 @@ func (s *APIServer) Run() error {
// staticDir := http.Dir(filepath.Join(workDir, "static"))
// FileServer(router, "/static", staticDir)
oauthHandler := oauth.NewOAuthHandler(s.repo, s.cfg)
oauthHandler := oauth.NewOAuthHandler(s.repo, s.cache, s.cfg)
router.Route("/api/v1", func(r chi.Router) {
userHandler := user.NewUserHandler(s.repo, s.storage, s.cfg)

View File

@ -7,6 +7,7 @@ import (
"strings"
"time"
"gitea.local/admin/hspguard/internal/types"
"gitea.local/admin/hspguard/internal/util"
"gitea.local/admin/hspguard/internal/web"
"github.com/google/uuid"
@ -26,7 +27,9 @@ func (h *AuthHandler) refreshToken(w http.ResponseWriter, r *http.Request) {
}
tokenStr := parts[1]
token, userClaims, err := util.VerifyToken(tokenStr, h.cfg.Jwt.PublicKey)
var userClaims types.UserClaims
token, err := util.VerifyToken(tokenStr, h.cfg.Jwt.PublicKey, &userClaims)
if err != nil || !token.Valid {
http.Error(w, fmt.Sprintf("invalid token: %v", err), http.StatusUnauthorized)
return

48
internal/cache/mod.go vendored
View File

@ -1,6 +1,8 @@
package cache
import (
"encoding/json"
"fmt"
"log"
"time"
@ -27,10 +29,56 @@ func NewClient(cfg *config.AppConfig) *Client {
}
}
type OAuthCode struct {
ClientID string `json:"client_id"`
UserID string `json:"user_id"`
Nonce string `json:"nonce"`
}
type SaveAuthCodeParams struct {
AuthCode string
UserID string
ClientID string
Nonce string
}
func (c *Client) Set(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd {
return c.rClient.Set(ctx, key, value, expiration)
}
func (c *Client) SaveAuthCode(ctx context.Context, params *SaveAuthCodeParams) error {
code := OAuthCode{
ClientID: params.ClientID,
UserID: params.UserID,
Nonce: params.Nonce,
}
row, err := json.Marshal(&code)
if err != nil {
return err
}
return c.Set(ctx, fmt.Sprintf("oauth.%s", params.AuthCode), string(row), 5*time.Minute).Err()
}
func (c *Client) GetAuthCode(ctx context.Context, authCode string) (*OAuthCode, error) {
row, err := c.Get(ctx, fmt.Sprintf("oauth.%s", authCode)).Result()
if err != nil {
return nil, err
}
if len(row) == 0 {
return nil, fmt.Errorf("no auth params found under %s", authCode)
}
var parsed OAuthCode
if err := json.Unmarshal([]byte(row), &parsed); err != nil {
return nil, err
}
return &parsed, nil
}
func (c *Client) Get(ctx context.Context, key string) *redis.StringCmd {
return c.rClient.Get(ctx, key)
}

View File

@ -37,7 +37,9 @@ func (m *AuthMiddleware) Runner(next http.Handler) http.Handler {
}
tokenStr := parts[1]
token, userClaims, err := util.VerifyToken(tokenStr, m.cfg.Jwt.PublicKey)
var userClaims types.UserClaims
token, err := util.VerifyToken(tokenStr, m.cfg.Jwt.PublicKey, &userClaims)
if err != nil || !token.Valid {
web.Error(w, fmt.Sprintf("invalid token: %v", err), http.StatusUnauthorized)
return

View File

@ -61,5 +61,14 @@ func (h *OAuthHandler) AuthorizeClient(w http.ResponseWriter, r *http.Request) {
}
}
if !slices.Contains(client.RedirectUris, redirectUri) {
uri := fmt.Sprintf("%s?error=invalid_request&error_description=Redirect+URI+is+not+allowed", redirectUri)
if state != "" {
uri += "&state=" + state
}
http.Redirect(w, r, uri, http.StatusFound)
return
}
http.Redirect(w, r, fmt.Sprintf("/auth?%s", r.URL.Query().Encode()), http.StatusFound)
}

View File

@ -1,10 +1,13 @@
package oauth
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"log"
"net/http"
"gitea.local/admin/hspguard/internal/cache"
"gitea.local/admin/hspguard/internal/util"
"gitea.local/admin/hspguard/internal/web"
"github.com/google/uuid"
@ -24,7 +27,8 @@ func (h *OAuthHandler) getAuthCode(w http.ResponseWriter, r *http.Request) {
}
type Request struct {
Nonce string `json:"nonce"`
Nonce string `json:"nonce"`
ClientID string `json:"client_id"`
}
var req Request
@ -35,7 +39,29 @@ func (h *OAuthHandler) getAuthCode(w http.ResponseWriter, r *http.Request) {
return
}
// TODO: Create real authorization code
buf := make([]byte, 32)
_, err = rand.Read(buf)
if err != nil {
log.Println("ERR: Failed to generate auth code:", err)
web.Error(w, "failed to create authorization code", http.StatusInternalServerError)
return
}
authCode := base64.RawURLEncoding.EncodeToString(buf)
params := cache.SaveAuthCodeParams{
AuthCode: authCode,
UserID: user.ID.String(),
ClientID: req.ClientID,
Nonce: req.Nonce,
}
log.Printf("DEBUG: Saving auth code session with params: %#v\n", params)
if err := h.cache.SaveAuthCode(r.Context(), &params); err != nil {
log.Println("ERR: Failed to save auth code in redis:", err)
web.Error(w, "failed to generate auth code", http.StatusInternalServerError)
return
}
type Response struct {
Code string `json:"code"`
@ -46,7 +72,7 @@ func (h *OAuthHandler) getAuthCode(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := encoder.Encode(Response{
Code: fmt.Sprintf("%s,%s", user.ID.String(), req.Nonce),
Code: authCode,
}); err != nil {
web.Error(w, "failed to encode response", http.StatusInternalServerError)
}

View File

@ -1,6 +1,7 @@
package oauth
import (
"gitea.local/admin/hspguard/internal/cache"
"gitea.local/admin/hspguard/internal/config"
imiddleware "gitea.local/admin/hspguard/internal/middleware"
"gitea.local/admin/hspguard/internal/repository"
@ -8,13 +9,15 @@ import (
)
type OAuthHandler struct {
repo *repository.Queries
cfg *config.AppConfig
repo *repository.Queries
cache *cache.Client
cfg *config.AppConfig
}
func NewOAuthHandler(repo *repository.Queries, cfg *config.AppConfig) *OAuthHandler {
func NewOAuthHandler(repo *repository.Queries, cache *cache.Client, cfg *config.AppConfig) *OAuthHandler {
return &OAuthHandler{
repo,
cache,
cfg,
}
}

View File

@ -9,6 +9,7 @@ import (
"strings"
"time"
"gitea.local/admin/hspguard/internal/repository"
"gitea.local/admin/hspguard/internal/types"
"gitea.local/admin/hspguard/internal/util"
"gitea.local/admin/hspguard/internal/web"
@ -16,6 +17,102 @@ import (
"github.com/google/uuid"
)
type ApiToken struct {
Token string
Expiration float64
}
type ApiTokens struct {
ID ApiToken
Access ApiToken
Refresh ApiToken
}
func (h *OAuthHandler) signApiTokens(user *repository.User, apiService *repository.ApiService, nonce *string) (*ApiTokens, error) {
accessExpiresIn := 15 * time.Minute
accessExpiresAt := time.Now().Add(accessExpiresIn)
accessClaims := types.ApiClaims{
Permissions: []string{},
RegisteredClaims: jwt.RegisteredClaims{
Issuer: h.cfg.Uri,
Subject: apiService.ClientID,
Audience: jwt.ClaimStrings{apiService.ClientID},
IssuedAt: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(accessExpiresAt),
},
}
access, err := util.SignJwtToken(accessClaims, h.cfg.Jwt.PrivateKey)
if err != nil {
return nil, err
}
var roles = []string{"user"}
if user.IsAdmin {
roles = append(roles, "admin")
}
idExpiresIn := 15 * time.Minute
idExpiresAt := time.Now().Add(idExpiresIn)
idClaims := types.IdTokenClaims{
Email: user.Email,
EmailVerified: user.EmailVerified,
Name: user.FullName,
Picture: user.ProfilePicture,
Nonce: nonce,
Roles: roles,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: h.cfg.Uri,
Subject: user.ID.String(),
Audience: jwt.ClaimStrings{apiService.ClientID},
IssuedAt: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(idExpiresAt),
},
}
idToken, err := util.SignJwtToken(idClaims, h.cfg.Jwt.PrivateKey)
if err != nil {
return nil, err
}
refreshExpiresIn := 24 * time.Hour
refreshExpiresAt := time.Now().Add(refreshExpiresIn)
refreshClaims := types.ApiRefreshClaims{
UserID: user.ID.String(),
RegisteredClaims: jwt.RegisteredClaims{
Issuer: h.cfg.Uri,
Subject: apiService.ClientID,
Audience: jwt.ClaimStrings{apiService.ClientID},
IssuedAt: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(refreshExpiresAt),
},
}
refresh, err := util.SignJwtToken(refreshClaims, h.cfg.Jwt.PrivateKey)
if err != nil {
return nil, err
}
return &ApiTokens{
ID: ApiToken{
Token: idToken,
Expiration: idExpiresIn.Seconds(),
},
Access: ApiToken{
Token: access,
Expiration: accessExpiresIn.Seconds(),
},
Refresh: ApiToken{
Token: refresh,
Expiration: refreshExpiresIn.Seconds(),
},
}, nil
}
func (h *OAuthHandler) tokenEndpoint(w http.ResponseWriter, r *http.Request) {
log.Println("[OAUTH] New request to token endpoint")
@ -65,65 +162,121 @@ func (h *OAuthHandler) tokenEndpoint(w http.ResponseWriter, r *http.Request) {
fmt.Printf("Code received: %s\n", code)
// TODO: Verify code from another db table
nonce := strings.Split(code, ",")[1]
session, err := h.cache.GetAuthCode(r.Context(), code)
if err != nil {
log.Printf("ERR: Failed to find session under the code %s: %v\n", code, err)
web.Error(w, "no session found under this auth code", http.StatusNotFound)
return
}
userId := strings.Split(code, ",")[0]
log.Printf("DEBUG: Fetched code session: %#v\n", session)
user, err := h.repo.FindUserId(r.Context(), uuid.MustParse(userId))
apiService, err := h.repo.GetApiServiceCID(r.Context(), session.ClientID)
if err != nil {
log.Printf("ERR: Could not find API service with client %s: %v\n", session.ClientID, err)
web.Error(w, "service is not registered", http.StatusForbidden)
return
}
if session.ClientID != clientId {
web.Error(w, "invalid auth", http.StatusUnauthorized)
return
}
user, err := h.repo.FindUserId(r.Context(), uuid.MustParse(session.UserID))
if err != nil {
web.Error(w, "requested user not found", http.StatusNotFound)
return
}
var roles = []string{"user"}
if user.IsAdmin {
roles = append(roles, "admin")
}
claims := types.ApiClaims{
Email: user.Email,
// TODO:
EmailVerified: true,
Name: user.FullName,
Picture: user.ProfilePicture,
Nonce: nonce,
Roles: roles,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: h.cfg.Uri,
// TODO: use dedicated API id that is in local DB and bind to user there
Subject: user.ID.String(),
Audience: jwt.ClaimStrings{clientId},
IssuedAt: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(15 * time.Minute)),
},
}
idToken, err := util.SignJwtToken(claims, h.cfg.Jwt.PrivateKey)
tokens, err := h.signApiTokens(&user, &apiService, &session.Nonce)
if err != nil {
web.Error(w, "failed to sign id token", http.StatusInternalServerError)
log.Println("ERR: Failed to sign api tokens:", err)
web.Error(w, "failed to sign tokens", http.StatusInternalServerError)
return
}
type Response struct {
IdToken string `json:"id_token"`
TokenType string `json:"token_type"`
AccessToken string `json:"access_token"`
Email string `json:"email"`
// TODO: add expires_in, refresh_token, scope (RFC 8693 $2)
IdToken string `json:"id_token"`
TokenType string `json:"token_type"`
AccessToken string `json:"access_token"`
Email string `json:"email"`
RefreshToken string `json:"refresh_token"`
ExpiresIn float64 `json:"expires_in"`
// TODO: add scope (RFC 8693 $2)
}
response := Response{
IdToken: idToken,
TokenType: "Bearer",
// FIXME:
AccessToken: idToken,
Email: user.Email,
IdToken: tokens.ID.Token,
TokenType: "Bearer",
AccessToken: tokens.Access.Token,
RefreshToken: tokens.Refresh.Token,
ExpiresIn: tokens.Access.Expiration,
Email: user.Email,
}
log.Printf("sending following response: %#v\n", response)
w.Header().Set("Content-Type", "application/json")
encoder := json.NewEncoder(w)
if err := encoder.Encode(response); err != nil {
web.Error(w, "failed to encode response", http.StatusInternalServerError)
}
case "refresh_token":
refreshToken := r.FormValue("refresh_token")
var claims types.ApiRefreshClaims
token, err := util.VerifyToken(refreshToken, h.cfg.Jwt.PublicKey, &claims)
if err != nil || !token.Valid {
http.Error(w, fmt.Sprintf("invalid token: %v", err), http.StatusUnauthorized)
return
}
expire, err := claims.GetExpirationTime()
if err != nil {
web.Error(w, "failed to retrieve enough info from the token", http.StatusInternalServerError)
return
}
if time.Now().After(expire.Time) {
web.Error(w, "token is expired", http.StatusUnauthorized)
return
}
userID, err := uuid.Parse(claims.UserID)
if err != nil {
web.Error(w, "invalid user credentials in refresh token", http.StatusBadRequest)
return
}
user, err := h.repo.FindUserId(r.Context(), userID)
apiService, err := h.repo.GetApiServiceCID(r.Context(), claims.Subject)
if err != nil {
web.Error(w, "api service is not registered", http.StatusUnauthorized)
return
}
tokens, err := h.signApiTokens(&user, &apiService, nil)
type Response struct {
IdToken string `json:"id_token"`
TokenType string `json:"token_type"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn float64 `json:"expires_in"`
}
response := Response{
IdToken: tokens.ID.Token,
TokenType: "Bearer",
AccessToken: tokens.Access.Token,
RefreshToken: tokens.Refresh.Token,
ExpiresIn: tokens.Access.Expiration,
}
log.Printf("DEBUG: refresh - sending following response: %#v\n", response)
w.Header().Set("Content-Type", "application/json")
encoder := json.NewEncoder(w)
if err := encoder.Encode(response); err != nil {

View File

@ -8,13 +8,26 @@ type UserClaims struct {
jwt.RegisteredClaims
}
type ApiClaims struct {
type IdTokenClaims struct {
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Name string `json:"name"`
Picture *string `json:"picture"`
Nonce string `json:"nonce"`
Nonce *string `json:"nonce"`
Roles []string `json:"roles"`
// TODO: add given_name, family_name, locale...
jwt.RegisteredClaims
}
type ApiClaims struct {
// FIXME: correct permissions
Permissions []string `json:"permissions"`
jwt.RegisteredClaims
// Subject = ClientID
}
type ApiRefreshClaims struct {
UserID string `json:"user_id"`
jwt.RegisteredClaims
// Subject = ClientID
}

View File

@ -6,7 +6,6 @@ import (
"encoding/base64"
"fmt"
"gitea.local/admin/hspguard/internal/types"
"github.com/golang-jwt/jwt/v5"
)
@ -57,13 +56,12 @@ func SignJwtToken(claims jwt.Claims, key string) (string, error) {
return s, nil
}
func VerifyToken(token string, key string) (*jwt.Token, *types.UserClaims, error) {
func VerifyToken(token string, key string, claims jwt.Claims) (*jwt.Token, error) {
publicKey, err := ParseBase64PublicKey(key)
if err != nil {
return nil, nil, err
return nil, err
}
claims := &types.UserClaims{}
parsed, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
@ -72,12 +70,12 @@ func VerifyToken(token string, key string) (*jwt.Token, *types.UserClaims, error
})
if err != nil {
return nil, nil, fmt.Errorf("invalid token: %w", err)
return nil, fmt.Errorf("invalid token: %w", err)
}
if !parsed.Valid {
return nil, nil, fmt.Errorf("token is not valid")
return nil, fmt.Errorf("token is not valid")
}
return parsed, claims, nil
return parsed, nil
}

View File

@ -4,10 +4,14 @@ export interface CodeResponse {
code: string;
}
export const codeApi = async (accessToken: string, nonce: string) => {
export const codeApi = async (
accessToken: string,
nonce: string,
clientId: string,
) => {
const response = await axios.post(
"/api/v1/oauth/code",
{ nonce },
{ nonce, client_id: clientId },
{
headers: {
"Content-Type": "application/json",

View File

@ -54,7 +54,11 @@ const VerificationLayout: FC = () => {
}
}, [navigate, redirect, step]);
if (step === "email" && !location.pathname.startsWith("/verify/email")) {
if (
step === "email" &&
!location.pathname.startsWith("/verify/email") &&
location.pathname.replace(/\/$/i, "") !== "/verify"
) {
return <Navigate to="/verify/email" />;
}

View File

@ -2,7 +2,7 @@ import { Button } from "@/components/ui/button";
import Avatar from "@/feature/Avatar";
import { useAuth } from "@/store/auth";
import { useVerify } from "@/store/verify";
import type { FC } from "react";
import { type FC } from "react";
const VerifyReviewPage: FC = () => {
const profile = useAuth((s) => s.profile);
@ -38,7 +38,7 @@ const VerifyReviewPage: FC = () => {
disabled={verifying}
onClick={finishVerify}
>
Back Home
Finish
</Button>
</div>
);

View File

@ -1,12 +1,9 @@
import { Button } from "@/components/ui/button";
import { useAuth } from "@/store/auth";
import { ArrowRight } from "lucide-react";
import type { FC } from "react";
import { Link } from "react-router";
const VerifyStartPage: FC = () => {
const profile = useAuth((s) => s.profile);
return (
<div className="flex flex-col items-center justify-center gap-5 w-full h-screen px-4 sm:px-0 sm:max-w-xl sm:h-auto text-center">
<img src="/icon.png" className="w-16 h-16" alt="icon" />

View File

@ -38,9 +38,9 @@ export const useOAuth = create<OAuthState>((set, get) => ({
},
selectSession: async (token) => {
const { active, redirectURI, nonce, state } = get();
const { active, redirectURI, nonce, state, clientID } = get();
if (active && redirectURI) {
const codeResponse = await codeApi(token, nonce);
const codeResponse = await codeApi(token, nonce, clientID);
const params = new URLSearchParams({
code: codeResponse.code,

View File

@ -13,28 +13,25 @@ export type VerifyStep = "email" | "avatar" | "review";
export interface IVerifyState {
step: VerifyStep | null | false;
redirect: string | null;
loadStep: (profile: UserProfile) => void;
requesting: boolean;
requested: boolean;
requestOTP: () => Promise<void>;
confirming: boolean;
confirmOTP: (req: ConfirmEmailRequest) => Promise<void>;
uploading: boolean;
uploadAvatar: (image: File) => Promise<void>;
verifying: boolean;
verify: () => Promise<void>;
setRedirect: (redirect: string) => void;
}
export const useVerify = create<IVerifyState>((set) => ({
export interface IVerifyActions {
loadStep: (profile: UserProfile) => void;
requestOTP: () => Promise<void>;
confirmOTP: (req: ConfirmEmailRequest) => Promise<void>;
uploadAvatar: (image: File) => Promise<void>;
verify: () => Promise<void>;
setRedirect: (redirect: string) => void;
reset: () => void;
}
const initialState: IVerifyState = {
step: null,
redirect: null,
@ -43,6 +40,12 @@ export const useVerify = create<IVerifyState>((set) => ({
confirming: false,
uploading: false,
verifying: false,
};
export const useVerify = create<IVerifyState & IVerifyActions>((set, get) => ({
...initialState,
reset: () => set(initialState),
loadStep: (profile) => {
if (!profile.email_verified) {
@ -116,6 +119,7 @@ export const useVerify = create<IVerifyState>((set) => ({
console.log("ERR: Failed to finish verification:", err);
} finally {
set({ verifying: false });
get().reset();
}
},
}));