Compare commits

...

8 Commits

14 changed files with 265 additions and 60 deletions

View File

@ -21,7 +21,7 @@ func New(repo *repository.Queries, cfg *config.AppConfig) *AdminHandler {
func (h *AdminHandler) RegisterRoutes(router chi.Router) {
router.Route("/admin", func(r chi.Router) {
authMiddleware := imiddleware.NewAuthMiddleware(h.cfg)
authMiddleware := imiddleware.NewAuthMiddleware(h.cfg, h.repo)
adminMiddleware := imiddleware.NewAdminMiddleware(h.repo)
r.Use(authMiddleware.Runner, adminMiddleware.Runner)
@ -37,6 +37,8 @@ func (h *AdminHandler) RegisterRoutes(router chi.Router) {
r.Get("/users/{id}", h.GetUser)
r.Get("/user-sessions", h.GetUserSessions)
r.Patch("/user-sessions/revoke/{id}", h.RevokeUserSession)
r.Get("/service-sessions", h.GetServiceSessions)
})

View File

@ -3,17 +3,20 @@ package admin
import (
"encoding/json"
"log"
"math"
"net/http"
"strconv"
"gitea.local/admin/hspguard/internal/repository"
"gitea.local/admin/hspguard/internal/types"
"gitea.local/admin/hspguard/internal/web"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
)
type GetSessionsParams struct {
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
PageSize int `json:"size"`
Page int `json:"page"`
// TODO: More filtering possibilities like onlyActive, expired, not-expired etc.
}
@ -22,17 +25,22 @@ func (h *AdminHandler) GetUserSessions(w http.ResponseWriter, r *http.Request) {
params := GetSessionsParams{}
if limit, err := strconv.Atoi(q.Get("limit")); err == nil {
params.Limit = int32(limit)
if pageSize, err := strconv.Atoi(q.Get("size")); err == nil {
params.PageSize = pageSize
} else {
params.PageSize = 15
}
if offset, err := strconv.Atoi(q.Get("offset")); err == nil {
params.Offset = int32(offset)
if page, err := strconv.Atoi(q.Get("page")); err == nil {
params.Page = page
} else {
web.Error(w, "page is required", http.StatusBadRequest)
return
}
sessions, err := h.repo.GetUserSessions(r.Context(), repository.GetUserSessionsParams{
Limit: params.Limit,
Offset: params.Offset,
Limit: int32(params.PageSize),
Offset: int32(params.Page-1) * int32(params.PageSize),
})
if err != nil {
log.Println("ERR: Failed to read user sessions from db:", err)
@ -40,35 +48,77 @@ func (h *AdminHandler) GetUserSessions(w http.ResponseWriter, r *http.Request) {
return
}
var mapped []*types.UserSessionDTO
totalSessions, err := h.repo.GetUserSessionsCount(r.Context())
if err != nil {
log.Println("ERR: Failed to get total count of user sessions:", err)
web.Error(w, "failed to retrieve sessions", http.StatusInternalServerError)
return
}
mapped := make([]*types.UserSessionDTO, 0)
for _, session := range sessions {
mapped = append(mapped, types.NewUserSessionDTO(&session))
}
if err := json.NewEncoder(w).Encode(mapped); err != nil {
type Response struct {
Items []*types.UserSessionDTO `json:"items"`
Page int `json:"page"`
TotalPages int `json:"total_pages"`
}
response := Response{
Items: mapped,
Page: params.Page,
TotalPages: int(math.Ceil(float64(totalSessions) / float64(params.PageSize))),
}
if err := json.NewEncoder(w).Encode(response); err != nil {
log.Println("ERR: Failed to encode sessions in response:", err)
web.Error(w, "failed to encode sessions", http.StatusInternalServerError)
return
}
}
func (h *AdminHandler) RevokeUserSession(w http.ResponseWriter, r *http.Request) {
sessionId := chi.URLParam(r, "id")
parsed, err := uuid.Parse(sessionId)
if err != nil {
web.Error(w, "provided service id is not valid", http.StatusBadRequest)
return
}
if err := h.repo.RevokeUserSession(r.Context(), parsed); err != nil {
log.Println("ERR: Failed to revoke user session:", err)
web.Error(w, "failed to revoke user session", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("{\"success\":true}"))
}
func (h *AdminHandler) GetServiceSessions(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
params := GetSessionsParams{}
if limit, err := strconv.Atoi(q.Get("limit")); err == nil {
params.Limit = int32(limit)
if pageSize, err := strconv.Atoi(q.Get("size")); err == nil {
params.PageSize = pageSize
} else {
params.PageSize = 15
}
if offset, err := strconv.Atoi(q.Get("offset")); err == nil {
params.Offset = int32(offset)
if page, err := strconv.Atoi(q.Get("page")); err == nil {
params.Page = page
} else {
web.Error(w, "page is required", http.StatusBadRequest)
return
}
sessions, err := h.repo.GetServiceSessions(r.Context(), repository.GetServiceSessionsParams{
Limit: params.Limit,
Offset: params.Offset,
Limit: int32(params.PageSize),
Offset: int32(params.Page-1) * int32(params.PageSize),
})
if err != nil {
log.Println("ERR: Failed to read api sessions from db:", err)
@ -76,13 +126,32 @@ func (h *AdminHandler) GetServiceSessions(w http.ResponseWriter, r *http.Request
return
}
var mapped []*types.ServiceSessionDTO
totalSessions, err := h.repo.GetServiceSessionsCount(r.Context())
if err != nil {
log.Println("ERR: Failed to get total count of service sessions:", err)
web.Error(w, "failed to retrieve sessions", http.StatusInternalServerError)
return
}
mapped := make([]*types.ServiceSessionDTO, 0)
for _, session := range sessions {
mapped = append(mapped, types.NewServiceSessionDTO(&session))
}
if err := json.NewEncoder(w).Encode(sessions); err != nil {
type Response struct {
Items []*types.ServiceSessionDTO `json:"items"`
Page int `json:"page"`
TotalPages int `json:"total_pages"`
}
response := Response{
Items: mapped,
Page: params.Page,
TotalPages: int(math.Ceil(float64(totalSessions) / float64(params.PageSize))),
}
if err := json.NewEncoder(w).Encode(response); err != nil {
log.Println("ERR: Failed to encode sessions in response:", err)
web.Error(w, "failed to encode sessions", http.StatusInternalServerError)
}

View File

@ -89,7 +89,7 @@ func NewAuthHandler(repo *repository.Queries, cache *cache.Client, cfg *config.A
func (h *AuthHandler) RegisterRoutes(api chi.Router) {
api.Route("/auth", func(r chi.Router) {
r.Group(func(protected chi.Router) {
authMiddleware := imiddleware.NewAuthMiddleware(h.cfg)
authMiddleware := imiddleware.NewAuthMiddleware(h.cfg, h.repo)
protected.Use(authMiddleware.Runner)
protected.Get("/profile", h.getProfile)

View File

@ -3,22 +3,27 @@ package middleware
import (
"context"
"fmt"
"log"
"net/http"
"strings"
"gitea.local/admin/hspguard/internal/config"
"gitea.local/admin/hspguard/internal/repository"
"gitea.local/admin/hspguard/internal/types"
"gitea.local/admin/hspguard/internal/util"
"gitea.local/admin/hspguard/internal/web"
"github.com/google/uuid"
)
type AuthMiddleware struct {
cfg *config.AppConfig
repo *repository.Queries
}
func NewAuthMiddleware(cfg *config.AppConfig) *AuthMiddleware {
func NewAuthMiddleware(cfg *config.AppConfig, repo *repository.Queries) *AuthMiddleware {
return &AuthMiddleware{
cfg,
repo,
}
}
@ -45,6 +50,26 @@ func (m *AuthMiddleware) Runner(next http.Handler) http.Handler {
return
}
// TODO: redis caching
parsed, err := uuid.Parse(userClaims.ID)
if err != nil {
log.Printf("ERR: Failed to parse token JTI '%s': %v\n", userClaims.ID, err)
web.Error(w, "failed to get session", http.StatusUnauthorized)
return
}
session, err := m.repo.GetUserSessionByAccessJTI(r.Context(), &parsed)
if err != nil {
log.Printf("ERR: Failed to find session with '%s' JTI: %v\n", parsed.String(), err)
web.Error(w, "no session found", http.StatusUnauthorized)
return
}
if !session.IsActive {
log.Printf("INFO: Inactive session trying to authorize: %s\n", session.AccessTokenID)
web.Error(w, "no session found", http.StatusUnauthorized)
return
}
ctx := context.WithValue(r.Context(), types.UserIdKey, userClaims.Subject)
ctx = context.WithValue(ctx, types.JTIKey, userClaims.ID)
next.ServeHTTP(w, r.WithContext(ctx))

View File

@ -25,7 +25,7 @@ func NewOAuthHandler(repo *repository.Queries, cache *cache.Client, cfg *config.
func (h *OAuthHandler) RegisterRoutes(router chi.Router) {
router.Route("/oauth", func(r chi.Router) {
r.Group(func(protected chi.Router) {
authMiddleware := imiddleware.NewAuthMiddleware(h.cfg)
authMiddleware := imiddleware.NewAuthMiddleware(h.cfg, h.repo)
protected.Use(authMiddleware.Runner)
protected.Post("/code", h.getAuthCode)

View File

@ -102,7 +102,6 @@ func (q *Queries) GetServiceSessionByAccessJTI(ctx context.Context, accessTokenI
const getServiceSessionByRefreshJTI = `-- name: GetServiceSessionByRefreshJTI :one
SELECT id, service_id, client_id, user_id, issued_at, expires_at, last_active, ip_address, user_agent, access_token_id, refresh_token_id, is_active, revoked_at, scope, claims FROM service_sessions
WHERE refresh_token_id = $1
AND is_active = TRUE
`
func (q *Queries) GetServiceSessionByRefreshJTI(ctx context.Context, refreshTokenID *uuid.UUID) (ServiceSession, error) {
@ -210,6 +209,17 @@ func (q *Queries) GetServiceSessions(ctx context.Context, arg GetServiceSessions
return items, nil
}
const getServiceSessionsCount = `-- name: GetServiceSessionsCount :one
SELECT COUNT(*) FROM service_sessions
`
func (q *Queries) GetServiceSessionsCount(ctx context.Context) (int64, error) {
row := q.db.QueryRow(ctx, getServiceSessionsCount)
var count int64
err := row.Scan(&count)
return count, err
}
const listActiveServiceSessionsByClient = `-- name: ListActiveServiceSessionsByClient :many
SELECT id, service_id, client_id, user_id, issued_at, expires_at, last_active, ip_address, user_agent, access_token_id, refresh_token_id, is_active, revoked_at, scope, claims FROM service_sessions
WHERE client_id = $1

View File

@ -98,7 +98,6 @@ func (q *Queries) GetUserSessionByAccessJTI(ctx context.Context, accessTokenID *
const getUserSessionByRefreshJTI = `-- name: GetUserSessionByRefreshJTI :one
SELECT id, user_id, session_type, issued_at, expires_at, last_active, ip_address, user_agent, access_token_id, refresh_token_id, device_info, is_active, revoked_at FROM user_sessions
WHERE refresh_token_id = $1
AND is_active = TRUE
`
func (q *Queries) GetUserSessionByRefreshJTI(ctx context.Context, refreshTokenID *uuid.UUID) (UserSession, error) {
@ -188,6 +187,17 @@ func (q *Queries) GetUserSessions(ctx context.Context, arg GetUserSessionsParams
return items, nil
}
const getUserSessionsCount = `-- name: GetUserSessionsCount :one
SELECT COUNT(*) FROM user_sessions
`
func (q *Queries) GetUserSessionsCount(ctx context.Context) (int64, error) {
row := q.db.QueryRow(ctx, getUserSessionsCount)
var count int64
err := row.Scan(&count)
return count, err
}
const listActiveUserSessions = `-- name: ListActiveUserSessions :many
SELECT id, user_id, session_type, issued_at, expires_at, last_active, ip_address, user_agent, access_token_id, refresh_token_id, device_info, is_active, revoked_at FROM user_sessions
WHERE user_id = $1

View File

@ -38,7 +38,7 @@ func NewUserHandler(repo *repository.Queries, minio *storage.FileStorage, cfg *c
func (h *UserHandler) RegisterRoutes(api chi.Router) {
api.Group(func(protected chi.Router) {
authMiddleware := imiddleware.NewAuthMiddleware(h.cfg)
authMiddleware := imiddleware.NewAuthMiddleware(h.cfg, h.repo)
protected.Use(authMiddleware.Runner)
protected.Put("/avatar", h.uploadAvatar)

View File

@ -31,8 +31,7 @@ WHERE access_token_id = $1
-- name: GetServiceSessionByRefreshJTI :one
SELECT * FROM service_sessions
WHERE refresh_token_id = $1
AND is_active = TRUE;
WHERE refresh_token_id = $1;
-- name: RevokeServiceSession :exec
UPDATE service_sessions
@ -59,3 +58,6 @@ JOIN api_services AS service ON service.id = session.service_id
JOIN users AS u ON u.id = session.user_id
ORDER BY session.issued_at DESC
LIMIT $1 OFFSET $2;
-- name: GetServiceSessionsCount :one
SELECT COUNT(*) FROM service_sessions;

View File

@ -23,8 +23,7 @@ WHERE access_token_id = $1
-- name: GetUserSessionByRefreshJTI :one
SELECT * FROM user_sessions
WHERE refresh_token_id = $1
AND is_active = TRUE;
WHERE refresh_token_id = $1;
-- name: RevokeUserSession :exec
UPDATE user_sessions
@ -56,3 +55,6 @@ FROM user_sessions AS session
JOIN users AS u ON u.id = session.user_id
ORDER BY session.issued_at DESC
LIMIT $1 OFFSET $2;
-- name: GetUserSessionsCount :one
SELECT COUNT(*) FROM user_sessions;

View File

@ -2,11 +2,15 @@ import type { ServiceSession, UserSession } from "@/types";
import { axios, handleApiError } from "..";
export interface FetchUserSessionsRequest {
limit: number;
offset: number;
page: number;
size: number;
}
export type FetchUserSessionsResponse = UserSession[];
export interface FetchUserSessionsResponse {
items: UserSession[];
page: number;
total_pages: number;
}
export const adminGetUserSessionsApi = async (
req: FetchUserSessionsRequest,
@ -24,6 +28,17 @@ export const adminGetUserSessionsApi = async (
return response.data;
};
export const adminRevokeUserSessionApi = async (
sessionId: string,
): Promise<void> => {
const response = await axios.patch<FetchServiceSessionsResponse>(
`/api/v1/admin/user-sessions/revoke/${sessionId}`,
);
if (response.status !== 200 && response.status !== 201)
throw await handleApiError(response);
};
export interface FetchServiceSessionsRequest {
limit: number;
offset: number;

View File

@ -1,13 +1,14 @@
import { adminGetUserSessionsApi } from "@/api/admin/sessions";
import Breadcrumbs from "@/components/ui/breadcrumbs";
import { Button } from "@/components/ui/button";
import Avatar from "@/feature/Avatar";
import type { DeviceInfo, UserSession } from "@/types";
import type { DeviceInfo } from "@/types";
import { Ban } from "lucide-react";
import { useEffect, useMemo, useState, type FC } from "react";
import { useCallback, useEffect, useMemo, type FC } from "react";
import { Link } from "react-router";
import moment from "moment";
import Pagination from "@/components/ui/pagination";
import { useUserSessions } from "@/store/admin/userSessions";
import { useAuth } from "@/store/auth";
const SessionSource: FC<{ deviceInfo: string }> = ({ deviceInfo }) => {
const parsed = useMemo<DeviceInfo>(
@ -23,20 +24,29 @@ const SessionSource: FC<{ deviceInfo: string }> = ({ deviceInfo }) => {
};
const AdminSessionsPage: FC = () => {
const loading = false;
const [sessions, setSessions] = useState<UserSession[]>([]);
const loading = useUserSessions((s) => s.loading);
const sessions = useUserSessions((s) => s.items);
const page = useUserSessions((s) => s.page);
const totalPages = useUserSessions((s) => s.totalPages);
const fetchSessions = useUserSessions((s) => s.fetch);
const revokeSession = useUserSessions((s) => s.revoke);
const revokingId = useUserSessions((s) => s.revokingId);
const profile = useAuth((s) => s.profile);
const handleRevokeSession = useCallback(
(id: string) => {
revokeSession(id);
},
[revokeSession],
);
useEffect(() => {
adminGetUserSessionsApi({
limit: 10,
offset: 0,
}).then((res) => {
console.log("get sessions response:", res);
if (Array.isArray(res)) {
return setSessions(res);
}
});
}, []);
fetchSessions(1);
}, [fetchSessions]);
return (
<div className="relative flex flex-col items-stretch w-full">
@ -113,11 +123,6 @@ const AdminSessionsPage: FC = () => {
key={session.id}
className="hover:bg-gray-50 dark:hover:bg-gray-800"
>
{/* <td className="px-6 py-4 text-sm font-medium text-blue-600 border border-gray-300 dark:border-gray-700">
<span className="inline-block px-2 py-1 text-xs rounded-full font-semibold bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-300">
{sessionsType}
</span>
</td> */}
<td className="px-6 py-4 text-sm text-gray-700 dark:text-gray-300 border border-gray-300 dark:border-gray-700">
<div className="flex flex-row items-center gap-2 justify-start">
{typeof session.user?.profile_picture === "string" && (
@ -126,9 +131,11 @@ const AdminSessionsPage: FC = () => {
className="w-7 h-7 min-w-7"
/>
)}
<Link to={`/admin/users/${session.user_id}`}>
<p className="cursor-pointer text-blue-500">
{session.user?.full_name ?? ""}
<Link to={`/admin/users/view/${session.user_id}`}>
<p className="cursor-pointer text-blue-500 text-nowrap">
{session.user?.full_name ?? ""}{" "}
{session.user_id === profile?.id ? "(You)" : ""}
</p>
</Link>
</div>
@ -177,6 +184,8 @@ const AdminSessionsPage: FC = () => {
<Button
variant="contained"
className="bg-red-500 hover:bg-red-600 !px-1.5 !py-1.5"
onClick={() => handleRevokeSession(session.id)}
disabled={revokingId === session.id}
>
<Ban size={18} />
</Button>
@ -187,9 +196,12 @@ const AdminSessionsPage: FC = () => {
)}
</tbody>
</table>
<Pagination currentPage={1} onPageChange={console.log} totalPages={2} />
</div>
<Pagination
currentPage={page}
onPageChange={(newPage) => fetchSessions(newPage)}
totalPages={totalPages}
/>
</div>
);
};

View File

@ -62,8 +62,9 @@ export default function LoginPage() {
} catch (err: any) {
console.log(err);
setError(
"Failed to create account. " +
(err.message ?? "Unexpected error happened"),
err.response?.data?.error ??
err.message ??
"Unexpected error happened",
);
} finally {
setLoading(false);

View File

@ -0,0 +1,57 @@
import {
adminGetUserSessionsApi,
adminRevokeUserSessionApi,
} from "@/api/admin/sessions";
import type { UserSession } from "@/types";
import { create } from "zustand";
export const ADMIN_USER_SESSIONS_PAGE_SIZE = 10;
export interface IUserSessionsState {
items: UserSession[];
totalPages: number;
page: number;
loading: boolean;
revokingId: string | null;
fetch: (page: number) => Promise<void>;
revoke: (id: string) => Promise<void>;
}
export const useUserSessions = create<IUserSessionsState>((set) => ({
items: [],
totalPages: 0,
page: 1,
loading: false,
revokingId: null,
fetch: async (page) => {
set({ loading: true, page });
try {
const response = await adminGetUserSessionsApi({
page,
size: ADMIN_USER_SESSIONS_PAGE_SIZE,
});
set({ items: response.items, totalPages: response.total_pages });
} catch (err) {
console.log("ERR: Failed to fetch admin user sessions:", err);
} finally {
set({ loading: false });
}
},
revoke: async (id) => {
set({ revokingId: id });
try {
await adminRevokeUserSessionApi(id);
} catch (err) {
console.log("ERR: Failed to revoke user sessions:", err);
} finally {
set({ revokingId: null });
}
},
}));