diff --git a/internal/cache/mod.go b/internal/cache/mod.go index 70ad834..38f8492 100644 --- a/internal/cache/mod.go +++ b/internal/cache/mod.go @@ -1,6 +1,8 @@ package cache import ( + "encoding/json" + "fmt" "log" "time" @@ -27,10 +29,56 @@ func NewClient(cfg *config.AppConfig) *Client { } } +type OAuthCode struct { + ClientID string `json:"client_id"` + UserID string `json:"user_id"` + Nonce string `json:"nonce"` +} + +type SaveAuthCodeParams struct { + AuthCode string + UserID string + ClientID string + Nonce string +} + func (c *Client) Set(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd { return c.rClient.Set(ctx, key, value, expiration) } +func (c *Client) SaveAuthCode(ctx context.Context, params *SaveAuthCodeParams) error { + code := OAuthCode{ + ClientID: params.ClientID, + UserID: params.UserID, + Nonce: params.Nonce, + } + row, err := json.Marshal(&code) + if err != nil { + return err + } + + return c.Set(ctx, fmt.Sprintf("oauth.%s", params.AuthCode), string(row), 5*time.Minute).Err() +} + +func (c *Client) GetAuthCode(ctx context.Context, authCode string) (*OAuthCode, error) { + row, err := c.Get(ctx, fmt.Sprintf("oauth.%s", authCode)).Result() + if err != nil { + return nil, err + } + + if len(row) == 0 { + return nil, fmt.Errorf("no auth params found under %s", authCode) + } + + var parsed OAuthCode + + if err := json.Unmarshal([]byte(row), &parsed); err != nil { + return nil, err + } + + return &parsed, nil +} + func (c *Client) Get(ctx context.Context, key string) *redis.StringCmd { return c.rClient.Get(ctx, key) }