diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go new file mode 100644 index 0000000..38bf106 --- /dev/null +++ b/internal/auth/jwt.go @@ -0,0 +1,89 @@ +package auth + +import ( + "crypto/ecdsa" + "crypto/x509" + "encoding/base64" + "fmt" + "os" + + "github.com/golang-jwt/jwt/v5" +) + +func parseBase64PrivateKey(envVar string) (*ecdsa.PrivateKey, error) { + b64 := os.Getenv(envVar) + if b64 == "" { + return nil, fmt.Errorf("env var %s is empty", envVar) + } + + decoded, err := base64.StdEncoding.DecodeString(b64) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 key: %v", err) + } + + return x509.ParseECPrivateKey(decoded) +} + +func parseBase64PublicKey(envVar string) (*ecdsa.PublicKey, error) { + b64 := os.Getenv(envVar) + if b64 == "" { + return nil, fmt.Errorf("env var %s is empty", envVar) + } + + decoded, err := base64.StdEncoding.DecodeString(b64) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 key: %v", err) + } + + pubInterface, err := x509.ParsePKIXPublicKey(decoded) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %v", err) + } + + pubKey, ok := pubInterface.(*ecdsa.PublicKey) + if !ok { + return nil, fmt.Errorf("not an ECDSA public key") + } + + return pubKey, nil +} + +func SignJwtToken(claims jwt.Claims) (string, error) { + privateKey, err := parseBase64PrivateKey("JWT_PRIVATE_KEY") + if err != nil { + return "", err + } + + token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) + s, err := token.SignedString(privateKey) + if err != nil { + return "", err + } + + return s, nil +} + +func VerifyToken(token string, claims jwt.Claims) (*jwt.Token, error) { + publicKey, err := parseBase64PublicKey("JWT_PUBLIC_KEY") + if err != nil { + return nil, err + } + + parsed, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) { + if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + return publicKey, nil + }) + + if err != nil { + return nil, fmt.Errorf("invalid token: %w", err) + } + + if !parsed.Valid { + return nil, fmt.Errorf("token is not valid") + } + + return parsed, nil +} +