첫 커밋: 로컬 프로젝트 업로드
This commit is contained in:
192
baron-sso/backend/internal/service/backchannel_logout_service.go
Normal file
192
baron-sso/backend/internal/service/backchannel_logout_service.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
josejwt "github.com/go-jose/go-jose/v4/jwt"
|
||||
)
|
||||
|
||||
const backchannelLogoutEventURI = "http://schemas.openid.net/event/backchannel-logout"
|
||||
|
||||
type BackchannelLogoutService struct {
|
||||
issuer string
|
||||
keyID string
|
||||
signer jose.Signer
|
||||
publicJWK jose.JSONWebKey
|
||||
client *http.Client
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
func NewBackchannelLogoutService() (*BackchannelLogoutService, error) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate backchannel logout key: %w", err)
|
||||
}
|
||||
|
||||
keyID := randomBackchannelKeyID()
|
||||
if keyID == "" {
|
||||
keyID = fmt.Sprintf("bcl-%d", time.Now().UnixNano())
|
||||
}
|
||||
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.RS256,
|
||||
Key: jose.JSONWebKey{
|
||||
Key: privateKey,
|
||||
KeyID: keyID,
|
||||
Algorithm: string(jose.RS256),
|
||||
Use: "sig",
|
||||
},
|
||||
}, (&jose.SignerOptions{}).WithType("JWT"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize backchannel logout signer: %w", err)
|
||||
}
|
||||
|
||||
return &BackchannelLogoutService{
|
||||
issuer: resolveBackchannelLogoutIssuer(),
|
||||
keyID: keyID,
|
||||
signer: signer,
|
||||
publicJWK: jose.JSONWebKey{
|
||||
Key: &privateKey.PublicKey,
|
||||
KeyID: keyID,
|
||||
Algorithm: string(jose.RS256),
|
||||
Use: "sig",
|
||||
},
|
||||
client: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 3 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: 3 * time.Second,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func randomBackchannelKeyID() string {
|
||||
buf := make([]byte, 8)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return ""
|
||||
}
|
||||
return hex.EncodeToString(buf)
|
||||
}
|
||||
|
||||
func resolveBackchannelLogoutIssuer() string {
|
||||
if explicit := strings.TrimSpace(os.Getenv("BACKCHANNEL_LOGOUT_ISSUER")); explicit != "" {
|
||||
return strings.TrimRight(explicit, "/")
|
||||
}
|
||||
|
||||
if hydraPublic := strings.TrimSpace(os.Getenv("HYDRA_PUBLIC_URL")); hydraPublic != "" {
|
||||
return strings.TrimRight(hydraPublic, "/")
|
||||
}
|
||||
|
||||
if oathkeeperPublic := strings.TrimSpace(os.Getenv("OATHKEEPER_PUBLIC_URL")); oathkeeperPublic != "" {
|
||||
return strings.TrimRight(oathkeeperPublic, "/") + "/oidc"
|
||||
}
|
||||
|
||||
if userfrontURL := strings.TrimSpace(os.Getenv("USERFRONT_URL")); userfrontURL != "" {
|
||||
return strings.TrimRight(userfrontURL, "/") + "/oidc"
|
||||
}
|
||||
|
||||
return "http://localhost:5000/oidc"
|
||||
}
|
||||
|
||||
func (s *BackchannelLogoutService) Issuer() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.issuer
|
||||
}
|
||||
|
||||
func (s *BackchannelLogoutService) PublicJWKS() map[string]any {
|
||||
if s == nil {
|
||||
return map[string]any{"keys": []any{}}
|
||||
}
|
||||
return map[string]any{
|
||||
"keys": []jose.JSONWebKey{s.publicJWK.Public()},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackchannelLogoutService) BuildLogoutToken(clientID, subject, sessionID string) (string, error) {
|
||||
if s == nil || s.signer == nil {
|
||||
return "", fmt.Errorf("backchannel logout service is unavailable")
|
||||
}
|
||||
clientID = strings.TrimSpace(clientID)
|
||||
subject = strings.TrimSpace(subject)
|
||||
sessionID = strings.TrimSpace(sessionID)
|
||||
if clientID == "" {
|
||||
return "", fmt.Errorf("client id is required")
|
||||
}
|
||||
if subject == "" && sessionID == "" {
|
||||
return "", fmt.Errorf("subject or session id is required")
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
claims := josejwt.Claims{
|
||||
Issuer: s.issuer,
|
||||
Audience: josejwt.Audience{clientID},
|
||||
IssuedAt: josejwt.NewNumericDate(now),
|
||||
ID: fmt.Sprintf("%s-%d", s.keyID, now.UnixNano()),
|
||||
}
|
||||
if subject != "" {
|
||||
claims.Subject = subject
|
||||
}
|
||||
|
||||
extra := map[string]any{
|
||||
"events": map[string]any{
|
||||
backchannelLogoutEventURI: map[string]any{},
|
||||
},
|
||||
}
|
||||
if sessionID != "" {
|
||||
extra["sid"] = sessionID
|
||||
}
|
||||
|
||||
return josejwt.Signed(s.signer).Claims(claims).Claims(extra).Serialize()
|
||||
}
|
||||
|
||||
func (s *BackchannelLogoutService) SendLogoutToken(ctx context.Context, endpoint, logoutToken string) (int, error) {
|
||||
if s == nil {
|
||||
return 0, fmt.Errorf("backchannel logout service is unavailable")
|
||||
}
|
||||
form := url.Values{}
|
||||
form.Set("logout_token", logoutToken)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := s.client
|
||||
if s.HTTPClient != nil {
|
||||
client = s.HTTPClient
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return resp.StatusCode, fmt.Errorf("backchannel logout endpoint returned status %d", resp.StatusCode)
|
||||
}
|
||||
return resp.StatusCode, nil
|
||||
}
|
||||
|
||||
func (s *BackchannelLogoutService) MarshalPublicJWKS() ([]byte, error) {
|
||||
return json.Marshal(s.PublicJWKS())
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
josejwt "github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBackchannelLogoutService_BuildLogoutToken(t *testing.T) {
|
||||
t.Setenv("BACKCHANNEL_LOGOUT_ISSUER", "https://sso.example.com/oidc")
|
||||
|
||||
svc, err := NewBackchannelLogoutService()
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := svc.BuildLogoutToken("client-1", "user-1", "sid-1")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
|
||||
jwksRaw, err := svc.MarshalPublicJWKS()
|
||||
require.NoError(t, err)
|
||||
|
||||
var jwks struct {
|
||||
Keys []jose.JSONWebKey `json:"keys"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(jwksRaw, &jwks))
|
||||
require.Len(t, jwks.Keys, 1)
|
||||
|
||||
parsed, err := josejwt.ParseSigned(token, []jose.SignatureAlgorithm{jose.RS256})
|
||||
require.NoError(t, err)
|
||||
|
||||
var claims struct {
|
||||
Issuer string `json:"iss"`
|
||||
Subject string `json:"sub"`
|
||||
Aud any `json:"aud"`
|
||||
Iat int64 `json:"iat"`
|
||||
Jti string `json:"jti"`
|
||||
Sid string `json:"sid"`
|
||||
Events map[string]any `json:"events"`
|
||||
}
|
||||
require.NoError(t, parsed.Claims(jwks.Keys[0].Key, &claims))
|
||||
|
||||
assert.Equal(t, "https://sso.example.com/oidc", claims.Issuer)
|
||||
assert.Equal(t, "user-1", claims.Subject)
|
||||
switch aud := claims.Aud.(type) {
|
||||
case string:
|
||||
assert.Equal(t, "client-1", aud)
|
||||
case []any:
|
||||
assert.Len(t, aud, 1)
|
||||
assert.Equal(t, "client-1", aud[0])
|
||||
default:
|
||||
t.Fatalf("unexpected aud type: %T", claims.Aud)
|
||||
}
|
||||
assert.NotZero(t, claims.Iat)
|
||||
assert.NotEmpty(t, claims.Jti)
|
||||
assert.Equal(t, "sid-1", claims.Sid)
|
||||
_, ok := claims.Events[backchannelLogoutEventURI]
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func TestBackchannelLogoutService_SendLogoutToken(t *testing.T) {
|
||||
var body string
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
body = string(raw)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
|
||||
svc, err := NewBackchannelLogoutService()
|
||||
require.NoError(t, err)
|
||||
svc.HTTPClient = clientForHandler(handler)
|
||||
|
||||
statusCode, err := svc.SendLogoutToken(context.Background(), "https://rp.example.com/backchannel-logout", "signed-token")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNoContent, statusCode)
|
||||
assert.Equal(t, "logout_token=signed-token", body)
|
||||
}
|
||||
86
baron-sso/backend/internal/service/developer_service.go
Normal file
86
baron-sso/backend/internal/service/developer_service.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type DeveloperService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewDeveloperService(db *gorm.DB) *DeveloperService {
|
||||
return &DeveloperService{db: db}
|
||||
}
|
||||
|
||||
func (s *DeveloperService) RequestAccess(ctx context.Context, req domain.DeveloperRequest) error {
|
||||
// Check if there is already a pending request
|
||||
var existing domain.DeveloperRequest
|
||||
err := s.db.WithContext(ctx).Where("user_id = ? AND tenant_id = ? AND status = ?", req.UserID, req.TenantID, domain.DeveloperRequestStatusPending).First(&existing).Error
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.db.WithContext(ctx).Create(&req).Error
|
||||
}
|
||||
|
||||
func (s *DeveloperService) GetRequestStatus(ctx context.Context, userID, tenantID string) (*domain.DeveloperRequest, error) {
|
||||
var req domain.DeveloperRequest
|
||||
err := s.db.WithContext(ctx).Where("user_id = ? AND tenant_id = ?", userID, tenantID).Order("created_at DESC").First(&req).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
func (s *DeveloperService) GetRequestByID(ctx context.Context, id uint) (*domain.DeveloperRequest, error) {
|
||||
var req domain.DeveloperRequest
|
||||
err := s.db.WithContext(ctx).First(&req, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
func (s *DeveloperService) ListRequests(ctx context.Context, userID, status string) ([]domain.DeveloperRequest, error) {
|
||||
var requests []domain.DeveloperRequest
|
||||
query := s.db.WithContext(ctx)
|
||||
if userID != "" {
|
||||
query = query.Where("user_id = ?", userID)
|
||||
}
|
||||
if status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
err := query.Order("created_at DESC").Find(&requests).Error
|
||||
return requests, err
|
||||
}
|
||||
|
||||
func (s *DeveloperService) ApproveRequest(ctx context.Context, id uint, adminNotes string) error {
|
||||
return s.db.WithContext(ctx).Model(&domain.DeveloperRequest{}).Where("id = ?", id).Updates(map[string]any{
|
||||
"status": domain.DeveloperRequestStatusApproved,
|
||||
"admin_notes": adminNotes,
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (s *DeveloperService) RejectRequest(ctx context.Context, id uint, adminNotes string) error {
|
||||
return s.db.WithContext(ctx).Model(&domain.DeveloperRequest{}).Where("id = ?", id).Updates(map[string]any{
|
||||
"status": domain.DeveloperRequestStatusRejected,
|
||||
"admin_notes": adminNotes,
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (s *DeveloperService) CancelApprovedRequest(ctx context.Context, id uint, adminNotes string) error {
|
||||
return s.db.WithContext(ctx).Model(&domain.DeveloperRequest{}).Where("id = ?", id).Updates(map[string]any{
|
||||
"status": domain.DeveloperRequestStatusCancelled,
|
||||
"admin_notes": adminNotes,
|
||||
}).Error
|
||||
}
|
||||
19
baron-sso/backend/internal/service/dry_run_service.go
Normal file
19
baron-sso/backend/internal/service/dry_run_service.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/logger"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func IsProductionEnv() bool {
|
||||
env := strings.ToLower(os.Getenv("APP_ENV"))
|
||||
if env == "" {
|
||||
env = strings.ToLower(os.Getenv("GO_ENV"))
|
||||
}
|
||||
return logger.IsProductionLikeEnv(env)
|
||||
}
|
||||
|
||||
func IsDryRunAllowed() bool {
|
||||
return !IsProductionEnv()
|
||||
}
|
||||
43
baron-sso/backend/internal/service/dry_run_service_test.go
Normal file
43
baron-sso/backend/internal/service/dry_run_service_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsProductionEnv_StageIsProductionLike(t *testing.T) {
|
||||
t.Setenv("APP_ENV", "stage")
|
||||
t.Setenv("GO_ENV", "")
|
||||
|
||||
if !IsProductionEnv() {
|
||||
t.Fatalf("expected stage to be treated as production-like")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDryRunAllowed_DisabledInStage(t *testing.T) {
|
||||
t.Setenv("APP_ENV", "stage")
|
||||
t.Setenv("GO_ENV", "")
|
||||
|
||||
if IsDryRunAllowed() {
|
||||
t.Fatalf("expected dry-run to be disabled in stage")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsProductionEnv_FallsBackToGoEnv(t *testing.T) {
|
||||
originalAppEnv, hadAppEnv := os.LookupEnv("APP_ENV")
|
||||
if hadAppEnv {
|
||||
t.Cleanup(func() {
|
||||
_ = os.Setenv("APP_ENV", originalAppEnv)
|
||||
})
|
||||
} else {
|
||||
t.Cleanup(func() {
|
||||
_ = os.Unsetenv("APP_ENV")
|
||||
})
|
||||
}
|
||||
_ = os.Unsetenv("APP_ENV")
|
||||
t.Setenv("GO_ENV", "production")
|
||||
|
||||
if !IsProductionEnv() {
|
||||
t.Fatalf("expected GO_ENV=production fallback to be production-like")
|
||||
}
|
||||
}
|
||||
90
baron-sso/backend/internal/service/federation_service.go
Normal file
90
baron-sso/backend/internal/service/federation_service.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/repository"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type FederationService struct {
|
||||
repo repository.FederationRepository
|
||||
hydraSvc *HydraAdminService
|
||||
redisSvc *RedisService
|
||||
}
|
||||
|
||||
func NewFederationService(repo repository.FederationRepository, hydraSvc *HydraAdminService, redisSvc *RedisService) *FederationService {
|
||||
return &FederationService{repo: repo, hydraSvc: hydraSvc, redisSvc: redisSvc}
|
||||
}
|
||||
|
||||
func (s *FederationService) InitiateOIDCLogin(ctx context.Context, providerID, loginChallenge string) (string, error) {
|
||||
provider, err := s.repo.FindProviderByID(ctx, providerID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to find provider: %w", err)
|
||||
}
|
||||
|
||||
if provider == nil || provider.IssuerURL == nil || provider.OIDCClientID == nil || provider.OIDCClientSecret == nil || provider.Scopes == nil {
|
||||
return "", fmt.Errorf("OIDC configuration for provider %s is incomplete", providerID)
|
||||
}
|
||||
|
||||
oidcProvider, err := oidc.NewProvider(ctx, *provider.IssuerURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OIDC provider: %w", err)
|
||||
}
|
||||
|
||||
config := oauth2.Config{
|
||||
ClientID: *provider.OIDCClientID,
|
||||
ClientSecret: *provider.OIDCClientSecret,
|
||||
Endpoint: oidcProvider.Endpoint(),
|
||||
RedirectURL: "http://localhost:8080/api/v1/federation/oidc/callback", // This should be configurable
|
||||
Scopes: []string{*provider.Scopes},
|
||||
}
|
||||
|
||||
state, err := generateState()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
// Store state and login_challenge in Redis
|
||||
redisKey := fmt.Sprintf("oidc_state:%s", state)
|
||||
if err := s.redisSvc.Set(redisKey, loginChallenge, 10*time.Minute); err != nil {
|
||||
return "", fmt.Errorf("failed to save state to Redis: %w", err)
|
||||
}
|
||||
|
||||
return config.AuthCodeURL(state), nil
|
||||
}
|
||||
|
||||
func (s *FederationService) HandleOIDCCallback(ctx context.Context, code, state string) (string, error) {
|
||||
// 1. Retrieve login_challenge from Redis
|
||||
redisKey := fmt.Sprintf("oidc_state:%s", state)
|
||||
loginChallenge, err := s.redisSvc.Get(redisKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get state from Redis or state expired: %w", err)
|
||||
}
|
||||
// Delete the state from Redis now that it's been used
|
||||
s.redisSvc.Delete(redisKey)
|
||||
|
||||
// TODO: Finish the rest of the callback logic
|
||||
// 2. Exchange code for token
|
||||
// 3. Verify ID token
|
||||
// 4. JIT Provisioning
|
||||
// 5. Accept Hydra Login Request
|
||||
|
||||
fmt.Println("Login challenge found:", loginChallenge)
|
||||
|
||||
return "http://localhost:3000/login?login_successful=true", nil // Placeholder
|
||||
}
|
||||
|
||||
func generateState() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
545
baron-sso/backend/internal/service/headless_jwks_cache.go
Normal file
545
baron-sso/backend/internal/service/headless_jwks_cache.go
Normal file
@@ -0,0 +1,545 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
)
|
||||
|
||||
const (
|
||||
headlessJWKSCacheKeyPrefix = "headless_jwks_cache:"
|
||||
)
|
||||
|
||||
type HeadlessJWKSCacheService struct {
|
||||
Redis domain.RedisRepository
|
||||
HTTPClient *http.Client
|
||||
TTL time.Duration
|
||||
PrefetchWindow time.Duration
|
||||
RequestTimeout time.Duration
|
||||
FailureThreshold int
|
||||
FailureBackoff time.Duration
|
||||
}
|
||||
|
||||
type headlessJWKSCacheStateStore struct {
|
||||
ClientID string `json:"clientId"`
|
||||
JWKSURI string `json:"jwksUri"`
|
||||
CachedAt *time.Time `json:"cachedAt,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
|
||||
LastCheckedAt *time.Time `json:"lastCheckedAt,omitempty"`
|
||||
NextRetryAt *time.Time `json:"nextRetryAt,omitempty"`
|
||||
LastSuccessfulVerificationAt *time.Time `json:"lastSuccessfulVerificationAt,omitempty"`
|
||||
LastRefreshStatus string `json:"lastRefreshStatus,omitempty"`
|
||||
LastError string `json:"lastError,omitempty"`
|
||||
ConsecutiveFailures int `json:"consecutiveFailures,omitempty"`
|
||||
CachedKids []string `json:"cachedKids,omitempty"`
|
||||
ETag string `json:"etag,omitempty"`
|
||||
LastModified string `json:"lastModified,omitempty"`
|
||||
RawJWKS string `json:"rawJwks,omitempty"`
|
||||
}
|
||||
|
||||
type HeadlessJWKSCacheWorker struct {
|
||||
Hydra *HydraAdminService
|
||||
Cache *HeadlessJWKSCacheService
|
||||
Interval time.Duration
|
||||
PageSize int
|
||||
}
|
||||
|
||||
func NewHeadlessJWKSCacheService(redis domain.RedisRepository, httpClient *http.Client) *HeadlessJWKSCacheService {
|
||||
ttlSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_CACHE_TTL_SECONDS", "1800")))
|
||||
if ttlSeconds <= 0 {
|
||||
ttlSeconds = 1800
|
||||
}
|
||||
|
||||
prefetchSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_PREFETCH_WINDOW_SECONDS", "600")))
|
||||
if prefetchSeconds <= 0 {
|
||||
prefetchSeconds = 600
|
||||
}
|
||||
|
||||
timeoutSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_FETCH_TIMEOUT_SECONDS", "2")))
|
||||
if timeoutSeconds <= 0 {
|
||||
timeoutSeconds = 2
|
||||
}
|
||||
|
||||
failureThreshold, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_FAILURE_THRESHOLD", "3")))
|
||||
if failureThreshold <= 0 {
|
||||
failureThreshold = 3
|
||||
}
|
||||
|
||||
backoffSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_FAILURE_BACKOFF_SECONDS", "1800")))
|
||||
if backoffSeconds <= 0 {
|
||||
backoffSeconds = 1800
|
||||
}
|
||||
|
||||
return &HeadlessJWKSCacheService{
|
||||
Redis: redis,
|
||||
HTTPClient: httpClient,
|
||||
TTL: time.Duration(ttlSeconds) * time.Second,
|
||||
PrefetchWindow: time.Duration(prefetchSeconds) * time.Second,
|
||||
RequestTimeout: time.Duration(timeoutSeconds) * time.Second,
|
||||
FailureThreshold: failureThreshold,
|
||||
FailureBackoff: time.Duration(backoffSeconds) * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func NewHeadlessJWKSCacheWorker(hydra *HydraAdminService, cache *HeadlessJWKSCacheService) *HeadlessJWKSCacheWorker {
|
||||
intervalSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_REFRESH_INTERVAL_SECONDS", "600")))
|
||||
if intervalSeconds <= 0 {
|
||||
intervalSeconds = 600
|
||||
}
|
||||
|
||||
return &HeadlessJWKSCacheWorker{
|
||||
Hydra: hydra,
|
||||
Cache: cache,
|
||||
Interval: time.Duration(intervalSeconds) * time.Second,
|
||||
PageSize: 100,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HeadlessJWKSCacheService) httpClient() *http.Client {
|
||||
if s.HTTPClient != nil {
|
||||
return s.HTTPClient
|
||||
}
|
||||
timeout := s.RequestTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
return &http.Client{Timeout: timeout}
|
||||
}
|
||||
|
||||
func (s *HeadlessJWKSCacheService) cacheKey(clientID string) string {
|
||||
return headlessJWKSCacheKeyPrefix + strings.TrimSpace(clientID)
|
||||
}
|
||||
|
||||
func (s *HeadlessJWKSCacheService) SaveState(clientID string, state domain.HeadlessJWKSCacheState) error {
|
||||
if s == nil || s.Redis == nil || strings.TrimSpace(clientID) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(headlessJWKSCacheStateStore{
|
||||
ClientID: state.ClientID,
|
||||
JWKSURI: state.JWKSURI,
|
||||
CachedAt: state.CachedAt,
|
||||
ExpiresAt: state.ExpiresAt,
|
||||
LastCheckedAt: state.LastCheckedAt,
|
||||
NextRetryAt: state.NextRetryAt,
|
||||
LastSuccessfulVerificationAt: state.LastSuccessfulVerificationAt,
|
||||
LastRefreshStatus: state.LastRefreshStatus,
|
||||
LastError: state.LastError,
|
||||
ConsecutiveFailures: state.ConsecutiveFailures,
|
||||
CachedKids: state.CachedKids,
|
||||
ETag: state.ETag,
|
||||
LastModified: state.LastModified,
|
||||
RawJWKS: state.RawJWKS,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Redis.Set(s.cacheKey(clientID), string(payload), 0)
|
||||
}
|
||||
|
||||
func (s *HeadlessJWKSCacheService) GetState(clientID string) (*domain.HeadlessJWKSCacheState, error) {
|
||||
if s == nil || s.Redis == nil || strings.TrimSpace(clientID) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
raw, err := s.Redis.Get(s.cacheKey(clientID))
|
||||
if err != nil || strings.TrimSpace(raw) == "" {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var stored headlessJWKSCacheStateStore
|
||||
if err := json.Unmarshal([]byte(raw), &stored); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &domain.HeadlessJWKSCacheState{
|
||||
ClientID: stored.ClientID,
|
||||
JWKSURI: stored.JWKSURI,
|
||||
CachedAt: stored.CachedAt,
|
||||
ExpiresAt: stored.ExpiresAt,
|
||||
LastCheckedAt: stored.LastCheckedAt,
|
||||
NextRetryAt: stored.NextRetryAt,
|
||||
LastSuccessfulVerificationAt: stored.LastSuccessfulVerificationAt,
|
||||
LastRefreshStatus: stored.LastRefreshStatus,
|
||||
LastError: stored.LastError,
|
||||
ConsecutiveFailures: stored.ConsecutiveFailures,
|
||||
CachedKids: stored.CachedKids,
|
||||
ETag: stored.ETag,
|
||||
LastModified: stored.LastModified,
|
||||
RawJWKS: stored.RawJWKS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *HeadlessJWKSCacheService) DeleteState(clientID string) error {
|
||||
if s == nil || s.Redis == nil {
|
||||
return nil
|
||||
}
|
||||
return s.Redis.Delete(s.cacheKey(clientID))
|
||||
}
|
||||
|
||||
func (s *HeadlessJWKSCacheService) PublicState(clientID string) (*domain.HeadlessJWKSCacheState, error) {
|
||||
state, err := s.GetState(clientID)
|
||||
if err != nil || state == nil {
|
||||
return state, err
|
||||
}
|
||||
state.ParsedKeys = summarizeHeadlessJWKS(state.RawJWKS)
|
||||
state.RawJWKS = ""
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (s *HeadlessJWKSCacheService) MarkVerificationSuccess(clientID string) error {
|
||||
state, err := s.GetState(clientID)
|
||||
if err != nil || state == nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now()
|
||||
state.LastSuccessfulVerificationAt = &now
|
||||
return s.SaveState(clientID, *state)
|
||||
}
|
||||
|
||||
func (s *HeadlessJWKSCacheService) ShouldPrefetch(state *domain.HeadlessJWKSCacheState, now time.Time) bool {
|
||||
if state == nil {
|
||||
return true
|
||||
}
|
||||
if s.ShouldSkipRefresh(state, now) {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(state.RawJWKS) == "" {
|
||||
return true
|
||||
}
|
||||
if state.ExpiresAt == nil {
|
||||
return true
|
||||
}
|
||||
return !state.ExpiresAt.After(now.Add(s.PrefetchWindow))
|
||||
}
|
||||
|
||||
func (s *HeadlessJWKSCacheService) ShouldSkipRefresh(state *domain.HeadlessJWKSCacheState, now time.Time) bool {
|
||||
if state == nil || state.NextRetryAt == nil {
|
||||
return false
|
||||
}
|
||||
return state.NextRetryAt.After(now)
|
||||
}
|
||||
|
||||
func (s *HeadlessJWKSCacheService) EnsureFreshKeySet(ctx context.Context, client domain.HydraClient, expectedKid string) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, bool, error) {
|
||||
if s == nil {
|
||||
return nil, nil, false, fmt.Errorf("headless jwks cache service is not configured")
|
||||
}
|
||||
|
||||
jwksURI := strings.TrimSpace(client.HeadlessJWKSURI())
|
||||
if jwksURI == "" {
|
||||
return nil, nil, false, fmt.Errorf("headless login requires jwksUri; inline jwks is not supported")
|
||||
}
|
||||
|
||||
state, err := s.GetState(client.ClientID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to load headless jwks cache state", "clientID", client.ClientID, "error", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
switch {
|
||||
case state == nil:
|
||||
return s.refreshClient(ctx, client, nil, "cache_miss")
|
||||
case strings.TrimSpace(state.JWKSURI) != jwksURI:
|
||||
return s.refreshClient(ctx, client, state, "config_changed")
|
||||
case strings.TrimSpace(state.RawJWKS) == "":
|
||||
return s.refreshClient(ctx, client, state, "cache_empty")
|
||||
case state.ExpiresAt == nil || !state.ExpiresAt.After(now):
|
||||
return s.refreshClient(ctx, client, state, "ttl_expired")
|
||||
case expectedKid != "" && !containsString(state.CachedKids, expectedKid):
|
||||
return s.refreshClient(ctx, client, state, "kid_missing")
|
||||
default:
|
||||
keySet, err := decodeHeadlessJWKS(state.RawJWKS)
|
||||
if err != nil {
|
||||
return s.refreshClient(ctx, client, state, "cache_corrupt")
|
||||
}
|
||||
return keySet, state, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HeadlessJWKSCacheService) ForceRefresh(ctx context.Context, client domain.HydraClient, reason string) (*domain.HeadlessJWKSCacheState, error) {
|
||||
_, state, err := s.ForceRefreshKeySet(ctx, client, reason)
|
||||
return state, err
|
||||
}
|
||||
|
||||
func (s *HeadlessJWKSCacheService) ForceRefreshKeySet(ctx context.Context, client domain.HydraClient, reason string) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, error) {
|
||||
previous, err := s.GetState(client.ClientID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to load headless jwks cache state before force refresh", "clientID", client.ClientID, "error", err)
|
||||
}
|
||||
keySet, state, _, err := s.refreshClient(ctx, client, previous, reason)
|
||||
return keySet, state, err
|
||||
}
|
||||
|
||||
func (s *HeadlessJWKSCacheService) refreshClient(ctx context.Context, client domain.HydraClient, previous *domain.HeadlessJWKSCacheState, reason string) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, bool, error) {
|
||||
jwksURI := strings.TrimSpace(client.HeadlessJWKSURI())
|
||||
if jwksURI == "" {
|
||||
return nil, nil, false, fmt.Errorf("headless login requires jwksUri; inline jwks is not supported")
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURI, nil)
|
||||
if err != nil {
|
||||
return nil, s.persistRefreshFailure(client, previous, fmt.Errorf("failed to build jwks request: %w", err)), false, err
|
||||
}
|
||||
if previous != nil {
|
||||
if etag := strings.TrimSpace(previous.ETag); etag != "" {
|
||||
req.Header.Set("If-None-Match", etag)
|
||||
}
|
||||
if lastModified := strings.TrimSpace(previous.LastModified); lastModified != "" {
|
||||
req.Header.Set("If-Modified-Since", lastModified)
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, s.persistRefreshFailure(client, previous, fmt.Errorf("failed to fetch jwksUri: %w", err)), false, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
now := time.Now()
|
||||
if resp.StatusCode == http.StatusNotModified && previous != nil && strings.TrimSpace(previous.RawJWKS) != "" {
|
||||
updated := *previous
|
||||
updated.JWKSURI = jwksURI
|
||||
updated.LastCheckedAt = &now
|
||||
updated.ExpiresAt = new(now.Add(s.TTL))
|
||||
updated.NextRetryAt = nil
|
||||
updated.LastRefreshStatus = "success"
|
||||
updated.LastError = ""
|
||||
updated.ConsecutiveFailures = 0
|
||||
_ = s.SaveState(client.ClientID, updated)
|
||||
keySet, decodeErr := decodeHeadlessJWKS(updated.RawJWKS)
|
||||
return keySet, &updated, true, decodeErr
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
err = fmt.Errorf("failed to fetch jwksUri status=%d body=%s", resp.StatusCode, string(body))
|
||||
return nil, s.persistRefreshFailure(client, previous, err), false, err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
|
||||
if err != nil {
|
||||
return nil, s.persistRefreshFailure(client, previous, fmt.Errorf("failed to read jwks response: %w", err)), false, err
|
||||
}
|
||||
|
||||
keySet, err := decodeHeadlessJWKS(string(body))
|
||||
if err != nil {
|
||||
return nil, s.persistRefreshFailure(client, previous, err), false, err
|
||||
}
|
||||
|
||||
state := domain.HeadlessJWKSCacheState{
|
||||
ClientID: client.ClientID,
|
||||
JWKSURI: jwksURI,
|
||||
CachedAt: &now,
|
||||
ExpiresAt: new(now.Add(s.TTL)),
|
||||
LastCheckedAt: &now,
|
||||
NextRetryAt: nil,
|
||||
LastSuccessfulVerificationAt: previousLastVerification(previous),
|
||||
LastRefreshStatus: "success",
|
||||
LastError: "",
|
||||
ConsecutiveFailures: 0,
|
||||
CachedKids: extractHeadlessKids(keySet),
|
||||
ETag: strings.TrimSpace(resp.Header.Get("ETag")),
|
||||
LastModified: strings.TrimSpace(resp.Header.Get("Last-Modified")),
|
||||
RawJWKS: string(body),
|
||||
}
|
||||
if err := s.SaveState(client.ClientID, state); err != nil {
|
||||
return nil, &state, false, err
|
||||
}
|
||||
slog.Info("headless jwks cache refreshed", "clientID", client.ClientID, "reason", reason, "keyCount", len(keySet.Keys))
|
||||
return keySet, &state, true, nil
|
||||
}
|
||||
|
||||
func (s *HeadlessJWKSCacheService) persistRefreshFailure(client domain.HydraClient, previous *domain.HeadlessJWKSCacheState, refreshErr error) *domain.HeadlessJWKSCacheState {
|
||||
now := time.Now()
|
||||
state := domain.HeadlessJWKSCacheState{
|
||||
ClientID: client.ClientID,
|
||||
JWKSURI: strings.TrimSpace(client.HeadlessJWKSURI()),
|
||||
LastCheckedAt: &now,
|
||||
LastRefreshStatus: "failure",
|
||||
LastError: refreshErr.Error(),
|
||||
ConsecutiveFailures: 1,
|
||||
}
|
||||
if previous != nil {
|
||||
state.CachedAt = previous.CachedAt
|
||||
state.ExpiresAt = previous.ExpiresAt
|
||||
state.LastSuccessfulVerificationAt = previous.LastSuccessfulVerificationAt
|
||||
state.CachedKids = previous.CachedKids
|
||||
state.ETag = previous.ETag
|
||||
state.LastModified = previous.LastModified
|
||||
state.RawJWKS = previous.RawJWKS
|
||||
state.ConsecutiveFailures = previous.ConsecutiveFailures + 1
|
||||
}
|
||||
if s.shouldBackoff(state.ConsecutiveFailures) {
|
||||
state.NextRetryAt = new(now.Add(s.failureBackoffDuration()))
|
||||
}
|
||||
_ = s.SaveState(client.ClientID, state)
|
||||
return &state
|
||||
}
|
||||
|
||||
func (s *HeadlessJWKSCacheService) shouldBackoff(consecutiveFailures int) bool {
|
||||
threshold := s.FailureThreshold
|
||||
if threshold <= 0 {
|
||||
threshold = 3
|
||||
}
|
||||
return consecutiveFailures >= threshold
|
||||
}
|
||||
|
||||
func (s *HeadlessJWKSCacheService) failureBackoffDuration() time.Duration {
|
||||
if s.FailureBackoff > 0 {
|
||||
return s.FailureBackoff
|
||||
}
|
||||
return 30 * time.Minute
|
||||
}
|
||||
|
||||
func decodeHeadlessJWKS(raw string) (*jose.JSONWebKeySet, error) {
|
||||
var keySet jose.JSONWebKeySet
|
||||
if err := json.Unmarshal([]byte(raw), &keySet); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode jwks from jwksUri: %w", err)
|
||||
}
|
||||
if len(keySet.Keys) == 0 {
|
||||
return nil, fmt.Errorf("configured jwksUri returned no keys")
|
||||
}
|
||||
return &keySet, nil
|
||||
}
|
||||
|
||||
type headlessJWKSPreviewDocument struct {
|
||||
Keys []headlessJWKSPreviewKey `json:"keys"`
|
||||
}
|
||||
|
||||
type headlessJWKSPreviewKey struct {
|
||||
Kid string `json:"kid"`
|
||||
Kty string `json:"kty"`
|
||||
Use string `json:"use"`
|
||||
Alg string `json:"alg"`
|
||||
N string `json:"n"`
|
||||
}
|
||||
|
||||
func summarizeHeadlessJWKS(raw string) []domain.HeadlessJWKSParsedKey {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var document headlessJWKSPreviewDocument
|
||||
if err := json.Unmarshal([]byte(raw), &document); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
parsedKeys := make([]domain.HeadlessJWKSParsedKey, 0, len(document.Keys))
|
||||
for _, key := range document.Keys {
|
||||
parsedKeys = append(parsedKeys, domain.HeadlessJWKSParsedKey{
|
||||
Kid: strings.TrimSpace(key.Kid),
|
||||
Kty: strings.TrimSpace(key.Kty),
|
||||
Use: strings.TrimSpace(key.Use),
|
||||
Alg: strings.TrimSpace(key.Alg),
|
||||
N: strings.TrimSpace(key.N),
|
||||
})
|
||||
}
|
||||
return parsedKeys
|
||||
}
|
||||
|
||||
func extractHeadlessKids(keySet *jose.JSONWebKeySet) []string {
|
||||
if keySet == nil {
|
||||
return nil
|
||||
}
|
||||
kids := make([]string, 0, len(keySet.Keys))
|
||||
for _, key := range keySet.Keys {
|
||||
if kid := strings.TrimSpace(key.KeyID); kid != "" {
|
||||
kids = append(kids, kid)
|
||||
}
|
||||
}
|
||||
return kids
|
||||
}
|
||||
|
||||
func containsString(values []string, needle string) bool {
|
||||
needle = strings.TrimSpace(needle)
|
||||
if needle == "" {
|
||||
return false
|
||||
}
|
||||
for _, value := range values {
|
||||
if strings.TrimSpace(value) == needle {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func previousLastVerification(previous *domain.HeadlessJWKSCacheState) *time.Time {
|
||||
if previous == nil {
|
||||
return nil
|
||||
}
|
||||
return previous.LastSuccessfulVerificationAt
|
||||
}
|
||||
|
||||
//go:fix inline
|
||||
func ptrTime(value time.Time) *time.Time {
|
||||
return new(value)
|
||||
}
|
||||
|
||||
func (w *HeadlessJWKSCacheWorker) Start(ctx context.Context) {
|
||||
if w == nil || w.Hydra == nil || w.Cache == nil {
|
||||
return
|
||||
}
|
||||
w.runOnce(ctx)
|
||||
ticker := time.NewTicker(w.Interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
w.runOnce(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *HeadlessJWKSCacheWorker) runOnce(ctx context.Context) {
|
||||
offset := 0
|
||||
pageSize := w.PageSize
|
||||
if pageSize <= 0 {
|
||||
pageSize = 100
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
for {
|
||||
clients, err := w.Hydra.ListClients(ctx, pageSize, offset)
|
||||
if err != nil {
|
||||
slog.Warn("headless jwks worker failed to list clients", "error", err)
|
||||
return
|
||||
}
|
||||
if len(clients) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, client := range clients {
|
||||
if !client.IsHeadlessLoginEnabled() {
|
||||
continue
|
||||
}
|
||||
state, err := w.Cache.GetState(client.ClientID)
|
||||
if err != nil {
|
||||
slog.Warn("headless jwks worker failed to load cache state", "clientID", client.ClientID, "error", err)
|
||||
continue
|
||||
}
|
||||
if !w.Cache.ShouldPrefetch(state, now) {
|
||||
continue
|
||||
}
|
||||
if _, err := w.Cache.ForceRefresh(ctx, client, "cron_prefetch"); err != nil {
|
||||
slog.Warn("headless jwks worker refresh failed", "clientID", client.ClientID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(clients) < pageSize {
|
||||
return
|
||||
}
|
||||
offset += len(clients)
|
||||
}
|
||||
}
|
||||
450
baron-sso/backend/internal/service/headless_jwks_cache_test.go
Normal file
450
baron-sso/backend/internal/service/headless_jwks_cache_test.go
Normal file
@@ -0,0 +1,450 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type headlessJWKSCacheTestRedis struct {
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
func (m *headlessJWKSCacheTestRedis) Set(key string, value string, expiration time.Duration) error {
|
||||
if m.data == nil {
|
||||
m.data = map[string]string{}
|
||||
}
|
||||
m.data[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *headlessJWKSCacheTestRedis) Get(key string) (string, error) {
|
||||
if m.data == nil {
|
||||
return "", nil
|
||||
}
|
||||
return m.data[key], nil
|
||||
}
|
||||
|
||||
func (m *headlessJWKSCacheTestRedis) Delete(key string) error {
|
||||
if m.data != nil {
|
||||
delete(m.data, key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *headlessJWKSCacheTestRedis) StoreVerificationCode(phone, code string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *headlessJWKSCacheTestRedis) GetVerificationCode(phone string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *headlessJWKSCacheTestRedis) DeleteVerificationCode(phone string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHeadlessJWKSCacheService_EnsureFreshKeySet_UsesCachedJWKSWhenFresh(t *testing.T) {
|
||||
_, jwks := mustServiceHeadlessRSAJWK(t, "cached-key")
|
||||
raw, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
redisRepo := &headlessJWKSCacheTestRedis{}
|
||||
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
|
||||
now := time.Now()
|
||||
err = cacheService.SaveState("client-headless", domain.HeadlessJWKSCacheState{
|
||||
ClientID: "client-headless",
|
||||
JWKSURI: "https://rp.example.com/.well-known/jwks.json",
|
||||
RawJWKS: string(raw),
|
||||
CachedKids: []string{"cached-key"},
|
||||
CachedAt: &now,
|
||||
LastCheckedAt: &now,
|
||||
ExpiresAt: new(now.Add(30 * time.Minute)),
|
||||
LastRefreshStatus: "success",
|
||||
ConsecutiveFailures: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
cacheService.HTTPClient = clientForHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("unexpected network fetch: %s", r.URL.String())
|
||||
}))
|
||||
|
||||
keySet, state, refreshed, err := cacheService.EnsureFreshKeySet(context.Background(), domain.HydraClient{
|
||||
ClientID: "client-headless",
|
||||
Metadata: map[string]any{
|
||||
domain.MetadataHeadlessLoginEnabled: true,
|
||||
domain.MetadataHeadlessJWKSURI: "https://rp.example.com/.well-known/jwks.json",
|
||||
},
|
||||
}, "cached-key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, refreshed)
|
||||
require.NotNil(t, keySet)
|
||||
assert.Len(t, keySet.Keys, 1)
|
||||
require.NotNil(t, state)
|
||||
assert.Equal(t, []string{"cached-key"}, state.CachedKids)
|
||||
}
|
||||
|
||||
func TestHeadlessJWKSCacheService_EnsureFreshKeySet_RefreshesWhenKidMissing(t *testing.T) {
|
||||
_, staleJWKS := mustServiceHeadlessRSAJWK(t, "stale-key")
|
||||
staleRaw, err := json.Marshal(staleJWKS)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, freshJWKS := mustServiceHeadlessRSAJWK(t, "fresh-key")
|
||||
freshRaw, err := json.Marshal(freshJWKS)
|
||||
require.NoError(t, err)
|
||||
|
||||
redisRepo := &headlessJWKSCacheTestRedis{}
|
||||
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
|
||||
now := time.Now()
|
||||
err = cacheService.SaveState("client-headless", domain.HeadlessJWKSCacheState{
|
||||
ClientID: "client-headless",
|
||||
JWKSURI: "https://rp.example.com/.well-known/jwks.json",
|
||||
RawJWKS: string(staleRaw),
|
||||
CachedKids: []string{"stale-key"},
|
||||
CachedAt: &now,
|
||||
LastCheckedAt: &now,
|
||||
ExpiresAt: new(now.Add(30 * time.Minute)),
|
||||
LastRefreshStatus: "success",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
cacheService.HTTPClient = clientForHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "https://rp.example.com/.well-known/jwks.json", r.URL.String())
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(freshRaw)
|
||||
}))
|
||||
|
||||
keySet, state, refreshed, err := cacheService.EnsureFreshKeySet(context.Background(), domain.HydraClient{
|
||||
ClientID: "client-headless",
|
||||
Metadata: map[string]any{
|
||||
domain.MetadataHeadlessLoginEnabled: true,
|
||||
domain.MetadataHeadlessJWKSURI: "https://rp.example.com/.well-known/jwks.json",
|
||||
},
|
||||
}, "fresh-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, refreshed)
|
||||
require.NotNil(t, keySet)
|
||||
assert.Len(t, keySet.Keys, 1)
|
||||
require.NotNil(t, state)
|
||||
assert.Equal(t, []string{"fresh-key"}, state.CachedKids)
|
||||
|
||||
stored, err := cacheService.GetState("client-headless")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, stored)
|
||||
assert.Equal(t, []string{"fresh-key"}, stored.CachedKids)
|
||||
}
|
||||
|
||||
func TestHeadlessJWKSCacheService_PersistRefreshFailure_SetsNextRetryAtAfterThreshold(t *testing.T) {
|
||||
redisRepo := &headlessJWKSCacheTestRedis{}
|
||||
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
|
||||
cacheService.FailureThreshold = 3
|
||||
cacheService.FailureBackoff = 15 * time.Minute
|
||||
|
||||
client := domain.HydraClient{
|
||||
ClientID: "client-headless",
|
||||
Metadata: map[string]any{
|
||||
domain.MetadataHeadlessLoginEnabled: true,
|
||||
domain.MetadataHeadlessJWKSURI: "https://rp.example.com/.well-known/jwks.json",
|
||||
},
|
||||
}
|
||||
|
||||
previous := &domain.HeadlessJWKSCacheState{
|
||||
ClientID: client.ClientID,
|
||||
JWKSURI: "https://rp.example.com/.well-known/jwks.json",
|
||||
LastRefreshStatus: "failure",
|
||||
ConsecutiveFailures: 2,
|
||||
}
|
||||
|
||||
state := cacheService.persistRefreshFailure(client, previous, assert.AnError)
|
||||
require.NotNil(t, state)
|
||||
assert.Equal(t, 3, state.ConsecutiveFailures)
|
||||
require.NotNil(t, state.NextRetryAt)
|
||||
assert.WithinDuration(t, time.Now().Add(15*time.Minute), *state.NextRetryAt, 3*time.Second)
|
||||
}
|
||||
|
||||
func TestHeadlessJWKSCacheService_ShouldPrefetch_SkipsUntilNextRetryAt(t *testing.T) {
|
||||
cacheService := NewHeadlessJWKSCacheService(&headlessJWKSCacheTestRedis{}, nil)
|
||||
now := time.Now()
|
||||
|
||||
state := &domain.HeadlessJWKSCacheState{
|
||||
ClientID: "client-headless",
|
||||
LastRefreshStatus: "failure",
|
||||
ConsecutiveFailures: 3,
|
||||
NextRetryAt: new(now.Add(10 * time.Minute)),
|
||||
}
|
||||
|
||||
assert.False(t, cacheService.ShouldPrefetch(state, now))
|
||||
assert.True(t, cacheService.ShouldPrefetch(state, now.Add(11*time.Minute)))
|
||||
}
|
||||
|
||||
func TestHeadlessJWKSCacheWorker_RunOnce_SkipsBackoffTargets(t *testing.T) {
|
||||
clients := []domain.HydraClient{
|
||||
newTestHeadlessClient("client-fail", "https://fail.example.com/.well-known/jwks.json"),
|
||||
newTestHeadlessClient("client-skip", "https://skip.example.com/.well-known/jwks.json"),
|
||||
}
|
||||
|
||||
hydra := &HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: clientForHandler(jsonHandler(t, clients)),
|
||||
}
|
||||
|
||||
redisRepo := &headlessJWKSCacheTestRedis{}
|
||||
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
|
||||
cacheService.FailureThreshold = 3
|
||||
cacheService.FailureBackoff = 15 * time.Minute
|
||||
|
||||
now := time.Now()
|
||||
require.NoError(t, cacheService.SaveState("client-fail", domain.HeadlessJWKSCacheState{
|
||||
ClientID: "client-fail",
|
||||
JWKSURI: clients[0].HeadlessJWKSURI(),
|
||||
LastRefreshStatus: "failure",
|
||||
ConsecutiveFailures: 2,
|
||||
}))
|
||||
require.NoError(t, cacheService.SaveState("client-skip", domain.HeadlessJWKSCacheState{
|
||||
ClientID: "client-skip",
|
||||
JWKSURI: clients[1].HeadlessJWKSURI(),
|
||||
LastRefreshStatus: "failure",
|
||||
ConsecutiveFailures: 3,
|
||||
NextRetryAt: new(now.Add(10 * time.Minute)),
|
||||
}))
|
||||
|
||||
fetchCounts := map[string]int{}
|
||||
cacheService.HTTPClient = &http.Client{
|
||||
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
fetchCounts[req.URL.Host]++
|
||||
if req.URL.Host == "fail.example.com" {
|
||||
return jsonHTTPResponse(http.StatusInternalServerError, `{"error":"boom"}`), nil
|
||||
}
|
||||
t.Fatalf("unexpected fetch for host %s", req.URL.Host)
|
||||
return nil, nil
|
||||
}),
|
||||
}
|
||||
|
||||
worker := &HeadlessJWKSCacheWorker{
|
||||
Hydra: hydra,
|
||||
Cache: cacheService,
|
||||
PageSize: 100,
|
||||
}
|
||||
|
||||
worker.runOnce(context.Background())
|
||||
|
||||
assert.Equal(t, 1, fetchCounts["fail.example.com"])
|
||||
assert.Equal(t, 0, fetchCounts["skip.example.com"])
|
||||
|
||||
failedState, err := cacheService.GetState("client-fail")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, failedState)
|
||||
assert.Equal(t, 3, failedState.ConsecutiveFailures)
|
||||
require.NotNil(t, failedState.NextRetryAt)
|
||||
|
||||
skippedState, err := cacheService.GetState("client-skip")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, skippedState)
|
||||
assert.Equal(t, 3, skippedState.ConsecutiveFailures)
|
||||
require.NotNil(t, skippedState.NextRetryAt)
|
||||
assert.WithinDuration(t, now.Add(10*time.Minute), *skippedState.NextRetryAt, time.Second)
|
||||
}
|
||||
|
||||
func TestHeadlessJWKSCacheWorker_RunOnce_RetriesAfterBackoffAndClearsFailureStateOnSuccess(t *testing.T) {
|
||||
_, freshJWKS := mustServiceHeadlessRSAJWK(t, "fresh-key")
|
||||
freshRaw, err := json.Marshal(freshJWKS)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := newTestHeadlessClient("client-recover", "https://recover.example.com/.well-known/jwks.json")
|
||||
hydra := &HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: clientForHandler(jsonHandler(t, []domain.HydraClient{client})),
|
||||
}
|
||||
|
||||
redisRepo := &headlessJWKSCacheTestRedis{}
|
||||
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
|
||||
cacheService.FailureThreshold = 3
|
||||
cacheService.FailureBackoff = 15 * time.Minute
|
||||
|
||||
require.NoError(t, cacheService.SaveState("client-recover", domain.HeadlessJWKSCacheState{
|
||||
ClientID: "client-recover",
|
||||
JWKSURI: client.HeadlessJWKSURI(),
|
||||
LastRefreshStatus: "failure",
|
||||
LastError: "previous failure",
|
||||
ConsecutiveFailures: 3,
|
||||
NextRetryAt: new(time.Now().Add(-time.Minute)),
|
||||
}))
|
||||
|
||||
fetchCount := 0
|
||||
cacheService.HTTPClient = &http.Client{
|
||||
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
fetchCount++
|
||||
assert.Equal(t, "recover.example.com", req.URL.Host)
|
||||
return jsonHTTPResponse(http.StatusOK, string(freshRaw)), nil
|
||||
}),
|
||||
}
|
||||
|
||||
worker := &HeadlessJWKSCacheWorker{
|
||||
Hydra: hydra,
|
||||
Cache: cacheService,
|
||||
PageSize: 100,
|
||||
}
|
||||
|
||||
worker.runOnce(context.Background())
|
||||
|
||||
assert.Equal(t, 1, fetchCount)
|
||||
|
||||
recoveredState, err := cacheService.GetState("client-recover")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, recoveredState)
|
||||
assert.Equal(t, "success", recoveredState.LastRefreshStatus)
|
||||
assert.Empty(t, recoveredState.LastError)
|
||||
assert.Equal(t, 0, recoveredState.ConsecutiveFailures)
|
||||
assert.Nil(t, recoveredState.NextRetryAt)
|
||||
assert.Equal(t, []string{"fresh-key"}, recoveredState.CachedKids)
|
||||
}
|
||||
|
||||
func TestHeadlessJWKSCacheWorker_RunOnce_MixedClients(t *testing.T) {
|
||||
_, successJWKS := mustServiceHeadlessRSAJWK(t, "success-key")
|
||||
successRaw, err := json.Marshal(successJWKS)
|
||||
require.NoError(t, err)
|
||||
|
||||
successClient := newTestHeadlessClient("client-success", "https://success.example.com/.well-known/jwks.json")
|
||||
failClient := newTestHeadlessClient("client-fail", "https://fail.example.com/.well-known/jwks.json")
|
||||
skipClient := newTestHeadlessClient("client-skip", "https://skip.example.com/.well-known/jwks.json")
|
||||
disabledClient := domain.HydraClient{
|
||||
ClientID: "client-disabled",
|
||||
Metadata: map[string]any{
|
||||
domain.MetadataHeadlessLoginEnabled: false,
|
||||
domain.MetadataHeadlessJWKSURI: "https://disabled.example.com/.well-known/jwks.json",
|
||||
domain.MetadataHeadlessTokenEndpointAuthMethod: "private_key_jwt",
|
||||
},
|
||||
}
|
||||
|
||||
hydra := &HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: clientForHandler(jsonHandler(t, []domain.HydraClient{
|
||||
successClient,
|
||||
failClient,
|
||||
skipClient,
|
||||
disabledClient,
|
||||
})),
|
||||
}
|
||||
|
||||
redisRepo := &headlessJWKSCacheTestRedis{}
|
||||
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
|
||||
cacheService.FailureThreshold = 3
|
||||
cacheService.FailureBackoff = 20 * time.Minute
|
||||
|
||||
require.NoError(t, cacheService.SaveState("client-fail", domain.HeadlessJWKSCacheState{
|
||||
ClientID: "client-fail",
|
||||
JWKSURI: failClient.HeadlessJWKSURI(),
|
||||
LastRefreshStatus: "failure",
|
||||
ConsecutiveFailures: 2,
|
||||
}))
|
||||
require.NoError(t, cacheService.SaveState("client-skip", domain.HeadlessJWKSCacheState{
|
||||
ClientID: "client-skip",
|
||||
JWKSURI: skipClient.HeadlessJWKSURI(),
|
||||
LastRefreshStatus: "failure",
|
||||
ConsecutiveFailures: 3,
|
||||
NextRetryAt: new(time.Now().Add(10 * time.Minute)),
|
||||
}))
|
||||
|
||||
fetchCounts := map[string]int{}
|
||||
cacheService.HTTPClient = &http.Client{
|
||||
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
fetchCounts[req.URL.Host]++
|
||||
switch req.URL.Host {
|
||||
case "success.example.com":
|
||||
return jsonHTTPResponse(http.StatusOK, string(successRaw)), nil
|
||||
case "fail.example.com":
|
||||
return jsonHTTPResponse(http.StatusInternalServerError, `{"error":"boom"}`), nil
|
||||
default:
|
||||
t.Fatalf("unexpected fetch for host %s", req.URL.Host)
|
||||
return nil, nil
|
||||
}
|
||||
}),
|
||||
}
|
||||
|
||||
worker := &HeadlessJWKSCacheWorker{
|
||||
Hydra: hydra,
|
||||
Cache: cacheService,
|
||||
PageSize: 100,
|
||||
}
|
||||
|
||||
worker.runOnce(context.Background())
|
||||
|
||||
assert.Equal(t, 1, fetchCounts["success.example.com"])
|
||||
assert.Equal(t, 1, fetchCounts["fail.example.com"])
|
||||
assert.Equal(t, 0, fetchCounts["skip.example.com"])
|
||||
assert.Equal(t, 0, fetchCounts["disabled.example.com"])
|
||||
|
||||
successState, err := cacheService.GetState("client-success")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, successState)
|
||||
assert.Equal(t, "success", successState.LastRefreshStatus)
|
||||
assert.Equal(t, 0, successState.ConsecutiveFailures)
|
||||
assert.Nil(t, successState.NextRetryAt)
|
||||
|
||||
failState, err := cacheService.GetState("client-fail")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, failState)
|
||||
assert.Equal(t, "failure", failState.LastRefreshStatus)
|
||||
assert.Equal(t, 3, failState.ConsecutiveFailures)
|
||||
require.NotNil(t, failState.NextRetryAt)
|
||||
}
|
||||
|
||||
func mustServiceHeadlessRSAJWK(t *testing.T, kid string) (*rsa.PrivateKey, jose.JSONWebKeySet) {
|
||||
t.Helper()
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicJWK := jose.JSONWebKey{
|
||||
Key: &privateKey.PublicKey,
|
||||
KeyID: kid,
|
||||
Algorithm: string(jose.RS256),
|
||||
Use: "sig",
|
||||
}
|
||||
|
||||
return privateKey, jose.JSONWebKeySet{Keys: []jose.JSONWebKey{publicJWK}}
|
||||
}
|
||||
|
||||
//go:fix inline
|
||||
func ptrTestTime(value time.Time) *time.Time {
|
||||
return new(value)
|
||||
}
|
||||
|
||||
func newTestHeadlessClient(clientID, jwksURI string) domain.HydraClient {
|
||||
return domain.HydraClient{
|
||||
ClientID: clientID,
|
||||
Metadata: map[string]any{
|
||||
domain.MetadataHeadlessLoginEnabled: true,
|
||||
domain.MetadataHeadlessJWKSURI: jwksURI,
|
||||
domain.MetadataHeadlessTokenEndpointAuthMethod: "private_key_jwt",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func jsonHandler(t *testing.T, payload any) http.HandlerFunc {
|
||||
t.Helper()
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/clients", r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(w).Encode(payload))
|
||||
}
|
||||
}
|
||||
|
||||
func jsonHTTPResponse(status int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: status,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
}
|
||||
639
baron-sso/backend/internal/service/hydra_admin_service.go
Normal file
639
baron-sso/backend/internal/service/hydra_admin_service.go
Normal file
@@ -0,0 +1,639 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrHydraNotFound = errors.New("hydra admin: resource not found")
|
||||
|
||||
// HydraAdminService는 Hydra Admin API 호출을 래핑합니다.
|
||||
type HydraAdminService struct {
|
||||
AdminURL string
|
||||
PublicURL string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
func NewHydraAdminService() *HydraAdminService {
|
||||
return &HydraAdminService{
|
||||
AdminURL: getenv("HYDRA_ADMIN_URL", "http://hydra:4445"),
|
||||
PublicURL: getenv("HYDRA_PUBLIC_URL", "http://hydra:4444"),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) ListClients(ctx context.Context, limit, offset int) ([]domain.HydraClient, error) {
|
||||
endpoint, err := s.buildURL("/clients", map[string]int{
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, ErrHydraNotFound
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("hydra admin: list clients failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var clients []domain.HydraClient
|
||||
if err := json.NewDecoder(resp.Body).Decode(&clients); err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: decode clients failed: %w", err)
|
||||
}
|
||||
return clients, nil
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) GetClient(ctx context.Context, clientID string) (*domain.HydraClient, error) {
|
||||
endpoint := fmt.Sprintf("%s/clients/%s", strings.TrimRight(s.AdminURL, "/"), url.PathEscape(clientID))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, ErrHydraNotFound
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("hydra admin: get client failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var client domain.HydraClient
|
||||
if err := json.NewDecoder(resp.Body).Decode(&client); err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: decode client failed: %w", err)
|
||||
}
|
||||
return &client, nil
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) PatchClientStatus(ctx context.Context, clientID, status string) (*domain.HydraClient, error) {
|
||||
// JSON Patch format
|
||||
payload := []map[string]any{
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/metadata/status",
|
||||
"value": status,
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
endpoint := fmt.Sprintf("%s/clients/%s", strings.TrimRight(s.AdminURL, "/"), url.PathEscape(clientID))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json-patch+json")
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, ErrHydraNotFound
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("hydra admin: patch client failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var updated domain.HydraClient
|
||||
if err := json.NewDecoder(resp.Body).Decode(&updated); err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: decode patched client failed: %w", err)
|
||||
}
|
||||
return &updated, nil
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) CreateClient(ctx context.Context, client domain.HydraClient) (*domain.HydraClient, error) {
|
||||
body, _ := json.Marshal(client)
|
||||
endpoint := fmt.Sprintf("%s/clients", strings.TrimRight(s.AdminURL, "/"))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("hydra admin: create client failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var created domain.HydraClient
|
||||
if err := json.NewDecoder(resp.Body).Decode(&created); err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: decode created client failed: %w", err)
|
||||
}
|
||||
return &created, nil
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) UpdateClient(ctx context.Context, clientID string, client domain.HydraClient) (*domain.HydraClient, error) {
|
||||
client.ClientID = clientID
|
||||
body, _ := json.Marshal(client)
|
||||
endpoint := fmt.Sprintf("%s/clients/%s", strings.TrimRight(s.AdminURL, "/"), url.PathEscape(clientID))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, ErrHydraNotFound
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("hydra admin: update client failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var updated domain.HydraClient
|
||||
if err := json.NewDecoder(resp.Body).Decode(&updated); err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: decode updated client failed: %w", err)
|
||||
}
|
||||
return &updated, nil
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) DeleteClient(ctx context.Context, clientID string) error {
|
||||
endpoint := fmt.Sprintf("%s/clients/%s", strings.TrimRight(s.AdminURL, "/"), url.PathEscape(clientID))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, endpoint, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return ErrHydraNotFound
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return fmt.Errorf("hydra admin: delete client failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) ListConsentSessions(ctx context.Context, subject, clientID string) ([]domain.HydraConsentSession, error) {
|
||||
params := map[string]string{
|
||||
"subject": subject,
|
||||
}
|
||||
if clientID != "" {
|
||||
params["client"] = clientID
|
||||
}
|
||||
endpoint, err := s.buildURLWithParams("/oauth2/auth/sessions/consent", params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNoContent {
|
||||
return []domain.HydraConsentSession{}, nil
|
||||
}
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
|
||||
if resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("hydra admin: list consent sessions failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
if len(body) == 0 {
|
||||
return []domain.HydraConsentSession{}, nil
|
||||
}
|
||||
|
||||
var sessions []domain.HydraConsentSession
|
||||
if err := json.Unmarshal(body, &sessions); err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: decode consent sessions failed: %w body=%s", err, string(body))
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) RevokeConsentSessions(ctx context.Context, subject, clientID string) error {
|
||||
params := map[string]string{
|
||||
"subject": subject,
|
||||
}
|
||||
if clientID != "" {
|
||||
params["client"] = clientID
|
||||
} else {
|
||||
params["all"] = "true"
|
||||
}
|
||||
endpoint, err := s.buildURLWithParams("/oauth2/auth/sessions/consent", params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, endpoint, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return fmt.Errorf("hydra admin: revoke consent failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) httpClient() *http.Client {
|
||||
if s.HTTPClient != nil {
|
||||
return s.HTTPClient
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: 5 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) buildURL(path string, ints map[string]int) (string, error) {
|
||||
base := strings.TrimRight(s.AdminURL, "/")
|
||||
u, err := url.Parse(base + path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
q := u.Query()
|
||||
for key, value := range ints {
|
||||
if value > 0 {
|
||||
q.Set(key, strconv.Itoa(value))
|
||||
}
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) buildURLWithParams(path string, params map[string]string) (string, error) {
|
||||
base := strings.TrimRight(s.AdminURL, "/")
|
||||
u, err := url.Parse(base + path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
q := u.Query()
|
||||
for key, value := range params {
|
||||
if value != "" {
|
||||
q.Set(key, value)
|
||||
}
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
type AcceptLoginRequestResponse struct {
|
||||
RedirectTo string `json:"redirectTo"`
|
||||
}
|
||||
|
||||
type AcceptConsentRequestResponse struct {
|
||||
RedirectTo string `json:"redirectTo"`
|
||||
}
|
||||
|
||||
type RejectConsentRequestResponse struct {
|
||||
RedirectTo string `json:"redirectTo"`
|
||||
}
|
||||
|
||||
type RejectLoginRequestResponse struct {
|
||||
RedirectTo string `json:"redirectTo"`
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) GetConsentRequest(ctx context.Context, challenge string) (*domain.HydraConsentRequest, error) {
|
||||
params := map[string]string{
|
||||
"consent_challenge": challenge,
|
||||
}
|
||||
endpoint, err := s.buildURLWithParams("/oauth2/auth/requests/consent", params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: create request for get consent failed: %w", err)
|
||||
}
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: get consent request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("hydra admin: get consent failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var consentReq domain.HydraConsentRequest
|
||||
if err := json.Unmarshal(body, &consentReq); err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: decode get consent response failed: %w", err)
|
||||
}
|
||||
|
||||
return &consentReq, nil
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) RejectConsentRequest(ctx context.Context, challenge string) (*RejectConsentRequestResponse, error) {
|
||||
params := map[string]string{
|
||||
"consent_challenge": challenge,
|
||||
}
|
||||
endpoint, err := s.buildURLWithParams("/oauth2/auth/requests/consent/reject", params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"error": "access_denied",
|
||||
"error_description": "The user decided to reject the consent request.",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: create request for reject consent failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: reject consent request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("hydra admin: reject consent failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var hydraResp struct {
|
||||
RedirectTo string `json:"redirect_to"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &hydraResp); err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: decode reject consent response failed: %w", err)
|
||||
}
|
||||
|
||||
return &RejectConsentRequestResponse{RedirectTo: hydraResp.RedirectTo}, nil
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) RejectLoginRequest(ctx context.Context, challenge, error, errorDescription string) (*RejectLoginRequestResponse, error) {
|
||||
params := map[string]string{
|
||||
"login_challenge": challenge,
|
||||
}
|
||||
endpoint, err := s.buildURLWithParams("/oauth2/auth/requests/login/reject", params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"error": error,
|
||||
"error_description": errorDescription,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: create request for reject login failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: reject login request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("hydra admin: reject login failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var hydraResp struct {
|
||||
RedirectTo string `json:"redirect_to"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &hydraResp); err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: decode reject login response failed: %w", err)
|
||||
}
|
||||
|
||||
return &RejectLoginRequestResponse{RedirectTo: hydraResp.RedirectTo}, nil
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) GetLoginRequest(ctx context.Context, challenge string) (*domain.HydraLoginRequest, error) {
|
||||
params := map[string]string{
|
||||
"login_challenge": challenge,
|
||||
}
|
||||
endpoint, err := s.buildURLWithParams("/oauth2/auth/requests/login", params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: create request for get login failed: %w", err)
|
||||
}
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: get login request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("hydra admin: get login failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var loginReq domain.HydraLoginRequest
|
||||
if err := json.Unmarshal(body, &loginReq); err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: decode get login response failed: %w", err)
|
||||
}
|
||||
|
||||
return &loginReq, nil
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) AcceptConsentRequest(ctx context.Context, challenge string, grantInfo *domain.HydraConsentRequest, sessionClaims map[string]any) (*AcceptConsentRequestResponse, error) {
|
||||
params := map[string]string{
|
||||
"consent_challenge": challenge,
|
||||
}
|
||||
endpoint, err := s.buildURLWithParams("/oauth2/auth/requests/consent/accept", params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"grant_scope": grantInfo.RequestedScope,
|
||||
"grant_audience": grantInfo.RequestedAudience,
|
||||
"remember": true,
|
||||
"remember_for": 2592000,
|
||||
}
|
||||
if len(sessionClaims) > 0 {
|
||||
payload["session"] = map[string]any{
|
||||
"id_token": sessionClaims,
|
||||
"access_token": sessionClaims,
|
||||
}
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: create request for accept consent failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: accept consent request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("hydra admin: accept consent failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// Hydra 응답(redirect_to)을 읽어서 우리 응답(redirectTo)으로 변환
|
||||
var hydraResp struct {
|
||||
RedirectTo string `json:"redirect_to"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &hydraResp); err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: decode accept consent response failed: %w", err)
|
||||
}
|
||||
|
||||
return &AcceptConsentRequestResponse{RedirectTo: hydraResp.RedirectTo}, nil
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) AcceptLoginRequest(ctx context.Context, challenge string, subject string) (*AcceptLoginRequestResponse, error) {
|
||||
params := map[string]string{
|
||||
"login_challenge": challenge,
|
||||
}
|
||||
endpoint, err := s.buildURLWithParams("/oauth2/auth/requests/login/accept", params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"subject": subject,
|
||||
"remember": true,
|
||||
"remember_for": 2592000,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: create request for accept login failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: accept login request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("hydra admin: accept login failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// Hydra 응답(redirect_to)을 읽어서 우리 응답(redirectTo)으로 변환
|
||||
var hydraResp struct {
|
||||
RedirectTo string `json:"redirect_to"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &hydraResp); err != nil {
|
||||
return nil, fmt.Errorf("hydra admin: decode accept login response failed: %w", err)
|
||||
}
|
||||
|
||||
return &AcceptLoginRequestResponse{RedirectTo: hydraResp.RedirectTo}, nil
|
||||
}
|
||||
|
||||
type HydraIntrospectionResponse struct {
|
||||
Active bool `json:"active"`
|
||||
Subject string `json:"sub"`
|
||||
ClientID string `json:"client_id"`
|
||||
Scope string `json:"scope"`
|
||||
ExpiresAt int64 `json:"exp"`
|
||||
IssuedAt int64 `json:"iat"`
|
||||
Ext map[string]any `json:"ext"`
|
||||
}
|
||||
|
||||
func (s *HydraAdminService) IntrospectToken(ctx context.Context, token string) (*HydraIntrospectionResponse, error) {
|
||||
endpoint := fmt.Sprintf("%s/oauth2/introspect", strings.TrimRight(s.AdminURL, "/"))
|
||||
form := url.Values{}
|
||||
form.Set("token", token)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("hydra admin: introspection failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var res HydraIntrospectionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &res, nil
|
||||
}
|
||||
331
baron-sso/backend/internal/service/hydra_admin_service_test.go
Normal file
331
baron-sso/backend/internal/service/hydra_admin_service_test.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHydraAdminService_ListClients(t *testing.T) {
|
||||
clients := []domain.HydraClient{
|
||||
{ClientID: "client1", ClientName: "Client 1"},
|
||||
{ClientID: "client2", ClientName: "Client 2"},
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/clients", r.URL.Path)
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
assert.Equal(t, "10", r.URL.Query().Get("limit"))
|
||||
assert.Equal(t, "5", r.URL.Query().Get("offset"))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(clients)
|
||||
})
|
||||
|
||||
s := &HydraAdminService{
|
||||
AdminURL: "http://hydra-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
|
||||
result, err := s.ListClients(context.Background(), 10, 5)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, clients, result)
|
||||
}
|
||||
|
||||
func TestHydraAdminService_GetClient(t *testing.T) {
|
||||
client := domain.HydraClient{ClientID: "test-client", ClientName: "Test Client"}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/clients/test-client", r.URL.Path)
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(client)
|
||||
})
|
||||
|
||||
s := &HydraAdminService{
|
||||
AdminURL: "http://hydra-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
|
||||
result, err := s.GetClient(context.Background(), "test-client")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &client, result)
|
||||
}
|
||||
|
||||
func TestHydraAdminService_CreateClient(t *testing.T) {
|
||||
client := domain.HydraClient{ClientName: "New Client"}
|
||||
created := domain.HydraClient{ClientID: "new-id", ClientName: "New Client"}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/clients", r.URL.Path)
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
|
||||
var received domain.HydraClient
|
||||
_ = json.NewDecoder(r.Body).Decode(&received)
|
||||
assert.Equal(t, client.ClientName, received.ClientName)
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
_ = json.NewEncoder(w).Encode(created)
|
||||
})
|
||||
|
||||
s := &HydraAdminService{
|
||||
AdminURL: "http://hydra-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
|
||||
result, err := s.CreateClient(context.Background(), client)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &created, result)
|
||||
}
|
||||
|
||||
func TestHydraAdminService_DeleteClient(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/clients/to-delete", r.URL.Path)
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
|
||||
s := &HydraAdminService{
|
||||
AdminURL: "http://hydra-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
|
||||
err := s.DeleteClient(context.Background(), "to-delete")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHydraAdminService_GetConsentRequest(t *testing.T) {
|
||||
challenge := "challenge123"
|
||||
consentReq := domain.HydraConsentRequest{Challenge: challenge, Subject: "user1"}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/oauth2/auth/requests/consent", r.URL.Path)
|
||||
assert.Equal(t, challenge, r.URL.Query().Get("consent_challenge"))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(consentReq)
|
||||
})
|
||||
|
||||
s := &HydraAdminService{
|
||||
AdminURL: "http://hydra-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
|
||||
result, err := s.GetConsentRequest(context.Background(), challenge)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &consentReq, result)
|
||||
}
|
||||
|
||||
func TestHydraAdminService_PatchClientStatus(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/clients/test-client", r.URL.Path)
|
||||
assert.Equal(t, "PATCH", r.Method)
|
||||
assert.Equal(t, "application/json-patch+json", r.Header.Get("Content-Type"))
|
||||
|
||||
var payload []map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&payload)
|
||||
assert.Equal(t, "replace", payload[0]["op"])
|
||||
assert.Equal(t, "/metadata/status", payload[0]["path"])
|
||||
assert.Equal(t, "inactive", payload[0]["value"])
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(domain.HydraClient{ClientID: "test-client"})
|
||||
})
|
||||
|
||||
s := &HydraAdminService{
|
||||
AdminURL: "http://hydra-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
_, err := s.PatchClientStatus(context.Background(), "test-client", "inactive")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHydraAdminService_UpdateClient(t *testing.T) {
|
||||
client := domain.HydraClient{ClientName: "Updated Name"}
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/clients/test-client", r.URL.Path)
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(client)
|
||||
})
|
||||
|
||||
s := &HydraAdminService{
|
||||
AdminURL: "http://hydra-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
_, err := s.UpdateClient(context.Background(), "test-client", client)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHydraAdminService_ListConsentSessions(t *testing.T) {
|
||||
sessions := []domain.HydraConsentSession{{Subject: "user1"}}
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/oauth2/auth/sessions/consent", r.URL.Path)
|
||||
assert.Equal(t, "user1", r.URL.Query().Get("subject"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(sessions)
|
||||
})
|
||||
|
||||
s := &HydraAdminService{
|
||||
AdminURL: "http://hydra-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
result, err := s.ListConsentSessions(context.Background(), "user1", "")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sessions, result)
|
||||
}
|
||||
|
||||
func TestHydraAdminService_RevokeConsentSessions(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/oauth2/auth/sessions/consent", r.URL.Path)
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
|
||||
s := &HydraAdminService{
|
||||
AdminURL: "http://hydra-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
err := s.RevokeConsentSessions(context.Background(), "user1", "")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHydraAdminService_RejectConsentRequest(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/oauth2/auth/requests/consent/reject", r.URL.Path)
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://reject"})
|
||||
})
|
||||
|
||||
s := &HydraAdminService{
|
||||
AdminURL: "http://hydra-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
resp, err := s.RejectConsentRequest(context.Background(), "challenge")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "http://reject", resp.RedirectTo)
|
||||
}
|
||||
|
||||
func TestHydraAdminService_RejectLoginRequest(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/oauth2/auth/requests/login/reject", r.URL.Path)
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://reject-login"})
|
||||
})
|
||||
|
||||
s := &HydraAdminService{
|
||||
AdminURL: "http://hydra-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
resp, err := s.RejectLoginRequest(context.Background(), "challenge", "error", "desc")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "http://reject-login", resp.RedirectTo)
|
||||
}
|
||||
|
||||
func TestHydraAdminService_GetLoginRequest(t *testing.T) {
|
||||
loginReq := domain.HydraLoginRequest{Challenge: "challenge"}
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/oauth2/auth/requests/login", r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(loginReq)
|
||||
})
|
||||
|
||||
s := &HydraAdminService{
|
||||
AdminURL: "http://hydra-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
result, err := s.GetLoginRequest(context.Background(), "challenge")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &loginReq, result)
|
||||
}
|
||||
|
||||
func TestHydraAdminService_AcceptConsentRequest(t *testing.T) {
|
||||
grant := &domain.HydraConsentRequest{RequestedScope: []string{"openid"}}
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/oauth2/auth/requests/consent/accept", r.URL.Path)
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://accept"})
|
||||
})
|
||||
|
||||
s := &HydraAdminService{
|
||||
AdminURL: "http://hydra-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
resp, err := s.AcceptConsentRequest(context.Background(), "challenge", grant, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "http://accept", resp.RedirectTo)
|
||||
}
|
||||
|
||||
func TestHydraAdminService_AcceptLoginRequest(t *testing.T) {
|
||||
challenge := "login_challenge"
|
||||
subject := "user@example.com"
|
||||
redirectTo := "http://hydra/auth/confirm"
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/oauth2/auth/requests/login/accept", r.URL.Path)
|
||||
assert.Equal(t, challenge, r.URL.Query().Get("login_challenge"))
|
||||
|
||||
var body map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&body)
|
||||
assert.Equal(t, subject, body["subject"])
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": redirectTo})
|
||||
})
|
||||
|
||||
s := &HydraAdminService{
|
||||
AdminURL: "http://hydra-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
|
||||
result, err := s.AcceptLoginRequest(context.Background(), challenge, subject)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, redirectTo, result.RedirectTo)
|
||||
}
|
||||
|
||||
func TestHydraAdminService_ErrorHandling(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte("bad request"))
|
||||
})
|
||||
|
||||
s := &HydraAdminService{
|
||||
AdminURL: "http://hydra-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
|
||||
_, err := s.GetClient(context.Background(), "invalid")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "status=400")
|
||||
|
||||
err = s.DeleteClient(context.Background(), "invalid")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = s.ListClients(context.Background(), 10, 0)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = s.PatchClientStatus(context.Background(), "invalid", "active")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHydraAdminService_NotFound(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
})
|
||||
|
||||
s := &HydraAdminService{
|
||||
AdminURL: "http://hydra-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
|
||||
_, err := s.GetClient(context.Background(), "none")
|
||||
assert.Equal(t, ErrHydraNotFound, err)
|
||||
}
|
||||
78
baron-sso/backend/internal/service/keto_relay_worker.go
Normal file
78
baron-sso/backend/internal/service/keto_relay_worker.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type KetoRelayWorker interface {
|
||||
Start(ctx context.Context)
|
||||
}
|
||||
|
||||
type ketoRelayWorker struct {
|
||||
outboxRepo repository.KetoOutboxRepository
|
||||
ketoService KetoService
|
||||
interval time.Duration
|
||||
maxRetries int
|
||||
}
|
||||
|
||||
func NewKetoRelayWorker(outboxRepo repository.KetoOutboxRepository, ketoService KetoService) KetoRelayWorker {
|
||||
return &ketoRelayWorker{
|
||||
outboxRepo: outboxRepo,
|
||||
ketoService: ketoService,
|
||||
interval: 5 * time.Second, // Poll every 5 seconds
|
||||
maxRetries: 5,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *ketoRelayWorker) Start(ctx context.Context) {
|
||||
slog.Info("[KetoRelayWorker] Starting worker...")
|
||||
ticker := time.NewTicker(w.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Info("[KetoRelayWorker] Stopping worker...")
|
||||
return
|
||||
case <-ticker.C:
|
||||
w.processEntries(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *ketoRelayWorker) processEntries(ctx context.Context) {
|
||||
entries, err := w.outboxRepo.FindPending(ctx, 50) // Process up to 50 at once
|
||||
if err != nil {
|
||||
slog.Error("[KetoRelayWorker] Failed to fetch pending entries", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
w.processEntry(ctx, entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *ketoRelayWorker) processEntry(ctx context.Context, entry domain.KetoOutbox) {
|
||||
var err error
|
||||
if entry.Action == domain.KetoOutboxActionCreate {
|
||||
err = w.ketoService.CreateRelation(ctx, entry.Namespace, entry.Object, entry.Relation, entry.Subject)
|
||||
} else if entry.Action == domain.KetoOutboxActionDelete {
|
||||
err = w.ketoService.DeleteRelation(ctx, entry.Namespace, entry.Object, entry.Relation, entry.Subject)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
slog.Error("[KetoRelayWorker] Failed to process entry", "id", entry.ID, "error", err)
|
||||
newRetryCount := entry.RetryCount + 1
|
||||
status := domain.KetoOutboxStatusPending
|
||||
if newRetryCount >= w.maxRetries {
|
||||
status = domain.KetoOutboxStatusFailed
|
||||
}
|
||||
_ = w.outboxRepo.UpdateStatus(ctx, entry.ID, status, newRetryCount, err.Error())
|
||||
} else {
|
||||
_ = w.outboxRepo.MarkProcessed(ctx, entry.ID)
|
||||
}
|
||||
}
|
||||
267
baron-sso/backend/internal/service/keto_service.go
Normal file
267
baron-sso/backend/internal/service/keto_service.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type KetoService interface {
|
||||
CheckPermission(ctx context.Context, subject, namespace, object, relation string) (bool, error)
|
||||
CreateRelation(ctx context.Context, namespace, object, relation, subject string) error
|
||||
DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error
|
||||
ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]RelationTuple, error)
|
||||
ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error)
|
||||
}
|
||||
|
||||
type ketoService struct {
|
||||
readURL string
|
||||
writeURL string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewKetoService() KetoService {
|
||||
readURL := os.Getenv("KETO_READ_URL")
|
||||
if readURL == "" {
|
||||
readURL = "http://keto:4466"
|
||||
}
|
||||
writeURL := os.Getenv("KETO_WRITE_URL")
|
||||
if writeURL == "" {
|
||||
writeURL = "http://keto:4467"
|
||||
}
|
||||
|
||||
return &ketoService{
|
||||
readURL: readURL,
|
||||
writeURL: writeURL,
|
||||
client: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
type RelationTuple struct {
|
||||
Namespace string `json:"namespace"`
|
||||
Object string `json:"object"`
|
||||
Relation string `json:"relation"`
|
||||
SubjectID string `json:"subject_id"`
|
||||
}
|
||||
|
||||
type relationTuplesResponse struct {
|
||||
RelationTuples []RelationTuple `json:"relation_tuples"`
|
||||
NextPageToken string `json:"next_page_token"`
|
||||
}
|
||||
|
||||
func (s *ketoService) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]RelationTuple, error) {
|
||||
u, _ := url.Parse(fmt.Sprintf("%s/relation-tuples", s.readURL))
|
||||
q := u.Query()
|
||||
if namespace != "" {
|
||||
q.Set("namespace", namespace)
|
||||
}
|
||||
if object != "" {
|
||||
q.Set("object", object)
|
||||
}
|
||||
if relation != "" {
|
||||
q.Set("relation", relation)
|
||||
}
|
||||
if subject != "" {
|
||||
q.Set("subject_id", subject)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("keto returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var res relationTuplesResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.RelationTuples, nil
|
||||
}
|
||||
|
||||
type checkResponse struct {
|
||||
Allowed bool `json:"allowed"`
|
||||
}
|
||||
|
||||
func (s *ketoService) CheckPermission(ctx context.Context, subject, namespace, object, relation string) (bool, error) {
|
||||
u, _ := url.Parse(fmt.Sprintf("%s/relation-tuples/check", s.readURL))
|
||||
q := u.Query()
|
||||
q.Set("namespace", namespace)
|
||||
q.Set("object", object)
|
||||
q.Set("relation", relation)
|
||||
q.Set("subject_id", subject)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
var lastErr error
|
||||
maxRetries := 5
|
||||
backoff := 200 * time.Millisecond
|
||||
|
||||
for i := range maxRetries {
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
|
||||
resp, err := s.client.Do(req)
|
||||
if err == nil {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
var res checkResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return res.Allowed, nil
|
||||
}
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
return false, nil
|
||||
}
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
lastErr = fmt.Errorf("keto returned status %d: %s", resp.StatusCode, string(body))
|
||||
} else {
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
if i < maxRetries-1 {
|
||||
slog.Debug("Retrying Keto CheckPermission...", "attempt", i+1, "error", lastErr)
|
||||
time.Sleep(backoff)
|
||||
backoff *= 2
|
||||
}
|
||||
}
|
||||
|
||||
return false, lastErr
|
||||
}
|
||||
|
||||
func (s *ketoService) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
|
||||
u := fmt.Sprintf("%s/admin/relation-tuples", s.writeURL)
|
||||
payload := map[string]any{
|
||||
"namespace": namespace,
|
||||
"object": object,
|
||||
"relation": relation,
|
||||
"subject_id": subject,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
// Exponential Backoff Retry Logic
|
||||
var lastErr error
|
||||
maxRetries := 5
|
||||
backoff := 200 * time.Millisecond
|
||||
|
||||
for i := range maxRetries {
|
||||
req, _ := http.NewRequestWithContext(ctx, "PUT", u, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err == nil {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusCreated || resp.StatusCode == http.StatusNoContent || resp.StatusCode == http.StatusOK {
|
||||
slog.Debug("Keto relation created", "namespace", namespace, "object", object, "relation", relation, "subject", subject)
|
||||
return nil
|
||||
}
|
||||
resBody, _ := io.ReadAll(resp.Body)
|
||||
lastErr = fmt.Errorf("keto returned status %d: %s", resp.StatusCode, string(resBody))
|
||||
} else {
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
if i < maxRetries-1 {
|
||||
slog.Debug("Retrying Keto CreateRelation...", "attempt", i+1, "error", lastErr)
|
||||
time.Sleep(backoff)
|
||||
backoff *= 2
|
||||
}
|
||||
}
|
||||
|
||||
slog.Error("Keto create relation failed after retries", "error", lastErr, "namespace", namespace, "object", object, "relation", relation, "subject", subject)
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (s *ketoService) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
|
||||
u, _ := url.Parse(fmt.Sprintf("%s/admin/relation-tuples", s.writeURL))
|
||||
q := u.Query()
|
||||
q.Set("namespace", namespace)
|
||||
q.Set("object", object)
|
||||
q.Set("relation", relation)
|
||||
q.Set("subject_id", subject)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
var lastErr error
|
||||
maxRetries := 5
|
||||
backoff := 200 * time.Millisecond
|
||||
|
||||
for i := range maxRetries {
|
||||
req, _ := http.NewRequestWithContext(ctx, "DELETE", u.String(), nil)
|
||||
resp, err := s.client.Do(req)
|
||||
if err == nil {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusNoContent || resp.StatusCode == http.StatusOK {
|
||||
slog.Debug("Keto relation deleted", "namespace", namespace, "object", object, "relation", relation, "subject", subject)
|
||||
return nil
|
||||
}
|
||||
resBody, _ := io.ReadAll(resp.Body)
|
||||
lastErr = fmt.Errorf("keto returned status %d: %s", resp.StatusCode, string(resBody))
|
||||
} else {
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
if i < maxRetries-1 {
|
||||
slog.Debug("Retrying Keto DeleteRelation...", "attempt", i+1, "error", lastErr)
|
||||
time.Sleep(backoff)
|
||||
backoff *= 2
|
||||
}
|
||||
}
|
||||
|
||||
slog.Error("Keto delete relation failed after retries", "error", lastErr)
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (s *ketoService) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
|
||||
u, _ := url.Parse(fmt.Sprintf("%s/relation-tuples", s.readURL))
|
||||
q := u.Query()
|
||||
if namespace != "" {
|
||||
q.Set("namespace", namespace)
|
||||
}
|
||||
if relation != "" {
|
||||
q.Set("relation", relation)
|
||||
}
|
||||
if subject != "" {
|
||||
q.Set("subject_id", subject)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("keto returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var res relationTuplesResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
objects := make([]string, 0, len(res.RelationTuples))
|
||||
seen := make(map[string]bool)
|
||||
for _, rt := range res.RelationTuples {
|
||||
if !seen[rt.Object] {
|
||||
objects = append(objects, rt.Object)
|
||||
seen[rt.Object] = true
|
||||
}
|
||||
}
|
||||
|
||||
return objects, nil
|
||||
}
|
||||
156
baron-sso/backend/internal/service/keto_service_test.go
Normal file
156
baron-sso/backend/internal/service/keto_service_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestKetoService_CheckPermission(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/relation-tuples/check", r.URL.Path)
|
||||
assert.Equal(t, "user1", r.URL.Query().Get("subject_id"))
|
||||
assert.Equal(t, "tenants", r.URL.Query().Get("namespace"))
|
||||
assert.Equal(t, "tenant1", r.URL.Query().Get("object"))
|
||||
assert.Equal(t, "admin", r.URL.Query().Get("relation"))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(checkResponse{Allowed: true})
|
||||
})
|
||||
|
||||
s := &ketoService{
|
||||
readURL: "http://keto-read.local",
|
||||
client: clientForHandler(handler),
|
||||
}
|
||||
|
||||
allowed, err := s.CheckPermission(context.Background(), "user1", "tenants", "tenant1", "admin")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, allowed)
|
||||
}
|
||||
|
||||
func TestKetoService_CreateRelation(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/admin/relation-tuples", r.URL.Path)
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
|
||||
var body map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&body)
|
||||
assert.Equal(t, "tenants", body["namespace"])
|
||||
assert.Equal(t, "tenant1", body["object"])
|
||||
assert.Equal(t, "admin", body["relation"])
|
||||
assert.Equal(t, "user1", body["subject_id"])
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
})
|
||||
|
||||
s := &ketoService{
|
||||
writeURL: "http://keto-write.local",
|
||||
client: clientForHandler(handler),
|
||||
}
|
||||
|
||||
err := s.CreateRelation(context.Background(), "tenants", "tenant1", "admin", "user1")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestKetoService_DeleteRelation(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/admin/relation-tuples", r.URL.Path)
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
assert.Equal(t, "user1", r.URL.Query().Get("subject_id"))
|
||||
assert.Equal(t, "tenants", r.URL.Query().Get("namespace"))
|
||||
assert.Equal(t, "tenant1", r.URL.Query().Get("object"))
|
||||
assert.Equal(t, "admin", r.URL.Query().Get("relation"))
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
|
||||
s := &ketoService{
|
||||
writeURL: "http://keto-write.local",
|
||||
client: clientForHandler(handler),
|
||||
}
|
||||
|
||||
err := s.DeleteRelation(context.Background(), "tenants", "tenant1", "admin", "user1")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestKetoService_ListRelations(t *testing.T) {
|
||||
tuples := []RelationTuple{
|
||||
{Namespace: "tenants", Object: "tenant1", Relation: "admin", SubjectID: "user1"},
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/relation-tuples", r.URL.Path)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(relationTuplesResponse{RelationTuples: tuples})
|
||||
})
|
||||
|
||||
s := &ketoService{
|
||||
readURL: "http://keto-read.local",
|
||||
client: clientForHandler(handler),
|
||||
}
|
||||
|
||||
result, err := s.ListRelations(context.Background(), "tenants", "tenant1", "admin", "user1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tuples, result)
|
||||
}
|
||||
|
||||
func TestKetoService_ErrorHandling(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte("internal error"))
|
||||
})
|
||||
|
||||
s := &ketoService{
|
||||
readURL: "http://keto-read.local",
|
||||
writeURL: "http://keto-write.local",
|
||||
client: clientForHandler(handler),
|
||||
}
|
||||
|
||||
_, err := s.CheckPermission(context.Background(), "u", "n", "o", "r")
|
||||
assert.Error(t, err)
|
||||
|
||||
err = s.DeleteRelation(context.Background(), "n", "o", "r", "s")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = s.ListRelations(context.Background(), "n", "o", "r", "s")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestKetoService_CheckPermission_Forbidden(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
})
|
||||
|
||||
s := &ketoService{
|
||||
readURL: "http://keto-read.local",
|
||||
client: clientForHandler(handler),
|
||||
}
|
||||
allowed, err := s.CheckPermission(context.Background(), "u", "n", "o", "r")
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, allowed)
|
||||
}
|
||||
|
||||
func TestKetoService_CreateRelation_Retry(t *testing.T) {
|
||||
attempts := 0
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempts++
|
||||
if attempts < 2 {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
})
|
||||
|
||||
s := &ketoService{
|
||||
writeURL: "http://keto-write.local",
|
||||
client: clientForHandler(handler),
|
||||
}
|
||||
|
||||
err := s.CreateRelation(context.Background(), "n", "o", "r", "s")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, attempts)
|
||||
}
|
||||
556
baron-sso/backend/internal/service/kratos_admin_service.go
Normal file
556
baron-sso/backend/internal/service/kratos_admin_service.go
Normal file
@@ -0,0 +1,556 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type KratosIdentity struct {
|
||||
ID string `json:"id"`
|
||||
SchemaID string `json:"schema_id,omitempty"`
|
||||
Traits map[string]any `json:"traits"`
|
||||
State string `json:"state,omitempty"`
|
||||
MetadataAdmin any `json:"metadata_admin,omitempty"`
|
||||
MetadataPublic any `json:"metadata_public,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type KratosSessionDevice struct {
|
||||
UserAgent string `json:"user_agent,omitempty"`
|
||||
IPAddress string `json:"ip_address,omitempty"`
|
||||
}
|
||||
|
||||
type KratosSession struct {
|
||||
ID string `json:"id"`
|
||||
Active bool `json:"active"`
|
||||
AuthenticatedAt time.Time `json:"authenticated_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
IssuedAt time.Time `json:"issued_at"`
|
||||
Identity *KratosIdentity `json:"identity,omitempty"`
|
||||
Devices []KratosSessionDevice `json:"devices,omitempty"`
|
||||
}
|
||||
|
||||
type KratosAdminService interface {
|
||||
ListIdentities(ctx context.Context) ([]KratosIdentity, error)
|
||||
FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error)
|
||||
GetIdentity(ctx context.Context, identityID string) (*KratosIdentity, error)
|
||||
UpdateIdentity(ctx context.Context, identityID string, traits map[string]any, state string) (*KratosIdentity, error)
|
||||
UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error
|
||||
DeleteIdentity(ctx context.Context, identityID string) error
|
||||
CreateUser(ctx context.Context, user *domain.BrokerUser, password string) (string, error)
|
||||
ListIdentitySessions(ctx context.Context, identityID string) ([]KratosSession, error)
|
||||
GetSession(ctx context.Context, sessionID string) (*KratosSession, error)
|
||||
DeleteSession(ctx context.Context, sessionID string) error
|
||||
}
|
||||
|
||||
type kratosAdminService struct {
|
||||
AdminURL string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
func NewKratosAdminService() KratosAdminService {
|
||||
return &kratosAdminService{
|
||||
AdminURL: getenvKratos("KRATOS_ADMIN_URL", "http://kratos:4434"),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *kratosAdminService) ListIdentities(ctx context.Context) ([]KratosIdentity, error) {
|
||||
endpoint := strings.TrimRight(s.AdminURL, "/") + "/admin/identities"
|
||||
var identities []KratosIdentity
|
||||
pageToken := ""
|
||||
seenTokens := make(map[string]bool)
|
||||
|
||||
for {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
query := req.URL.Query()
|
||||
query.Set("page_size", "250")
|
||||
if pageToken != "" {
|
||||
query.Set("page_token", pageToken)
|
||||
}
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
_ = resp.Body.Close()
|
||||
return nil, fmt.Errorf("kratos admin list identities failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var page []KratosIdentity
|
||||
if err := json.NewDecoder(resp.Body).Decode(&page); err != nil {
|
||||
_ = resp.Body.Close()
|
||||
return nil, err
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
identities = append(identities, page...)
|
||||
|
||||
nextToken := kratosNextPageToken(resp.Header.Values("Link"))
|
||||
if nextToken == "" {
|
||||
return identities, nil
|
||||
}
|
||||
if seenTokens[nextToken] {
|
||||
return nil, fmt.Errorf("kratos admin list identities pagination loop detected page_token=%s", nextToken)
|
||||
}
|
||||
seenTokens[nextToken] = true
|
||||
pageToken = nextToken
|
||||
}
|
||||
}
|
||||
|
||||
func kratosNextPageToken(linkHeaders []string) string {
|
||||
for _, header := range linkHeaders {
|
||||
for _, part := range strings.Split(header, ",") {
|
||||
part = strings.TrimSpace(part)
|
||||
if !strings.Contains(part, `rel="next"`) && !strings.Contains(part, `rel=next`) {
|
||||
continue
|
||||
}
|
||||
start := strings.Index(part, "<")
|
||||
end := strings.Index(part, ">")
|
||||
if start < 0 || end <= start+1 {
|
||||
continue
|
||||
}
|
||||
rawURL := part[start+1 : end]
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if token := strings.TrimSpace(parsed.Query().Get("page_token")); token != "" {
|
||||
return token
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *kratosAdminService) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) {
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
if identifier == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
endpoint := strings.TrimRight(s.AdminURL, "/") + "/admin/identities"
|
||||
|
||||
// 1. Try credentials_identifier (Email/LoginID/Phone)
|
||||
id, err := s.searchIdentities(ctx, endpoint, "credentials_identifier", identifier)
|
||||
if err == nil && id != "" {
|
||||
// VERIFY: Kratos sometimes ignores unknown query params and returns the first identity.
|
||||
if s.verifyIdentityMatch(ctx, id, identifier) {
|
||||
return id, nil
|
||||
}
|
||||
}
|
||||
|
||||
identity, err := s.GetIdentity(ctx, identifier)
|
||||
if err == nil && identity != nil {
|
||||
return identity.ID, nil
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (s *kratosAdminService) verifyIdentityMatch(ctx context.Context, id, identifier string) bool {
|
||||
identity, err := s.GetIdentity(ctx, id)
|
||||
if err != nil || identity == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Exact ID match
|
||||
if strings.EqualFold(identity.ID, identifier) {
|
||||
return true
|
||||
}
|
||||
// Check traits (Email, CustomLoginIDs)
|
||||
if email, ok := identity.Traits["email"].(string); ok && strings.EqualFold(email, identifier) {
|
||||
return true
|
||||
}
|
||||
if phone, ok := identity.Traits["phone_number"].(string); ok && strings.EqualFold(phone, identifier) {
|
||||
return true
|
||||
}
|
||||
if lids, ok := identity.Traits["custom_login_ids"].([]any); ok {
|
||||
for _, lid := range lids {
|
||||
if s, ok := lid.(string); ok && strings.EqualFold(s, identifier) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else if lids, ok := identity.Traits["custom_login_ids"].([]string); ok {
|
||||
for _, lid := range lids {
|
||||
if strings.EqualFold(lid, identifier) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *kratosAdminService) searchIdentities(ctx context.Context, endpoint, key, value string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
query := req.URL.Query()
|
||||
query.Set(key, value)
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return "", nil
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return "", fmt.Errorf("kratos admin search by %s failed status=%d body=%s", key, resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var identities []struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&identities); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(identities) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
return identities[0].ID, nil
|
||||
}
|
||||
|
||||
func (s *kratosAdminService) GetIdentity(ctx context.Context, identityID string) (*KratosIdentity, error) {
|
||||
endpoint := fmt.Sprintf("%s/admin/identities/%s", strings.TrimRight(s.AdminURL, "/"), identityID)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("kratos admin get identity failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var identity KratosIdentity
|
||||
if err := json.NewDecoder(resp.Body).Decode(&identity); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &identity, nil
|
||||
}
|
||||
|
||||
func (s *kratosAdminService) UpdateIdentity(ctx context.Context, identityID string, traits map[string]any, state string) (*KratosIdentity, error) {
|
||||
payload := map[string]any{
|
||||
"schema_id": "default",
|
||||
"traits": traits,
|
||||
}
|
||||
if strings.TrimSpace(state) != "" {
|
||||
payload["state"] = strings.TrimSpace(state)
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(payload)
|
||||
endpoint := fmt.Sprintf("%s/admin/identities/%s", strings.TrimRight(s.AdminURL, "/"), identityID)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("kratos admin update identity failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var updated KratosIdentity
|
||||
if err := json.NewDecoder(resp.Body).Decode(&updated); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &updated, nil
|
||||
}
|
||||
|
||||
func (s *kratosAdminService) UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error {
|
||||
identity, err := s.GetIdentity(ctx, identityID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if identity == nil {
|
||||
return fmt.Errorf("kratos admin identity not found: %s", identityID)
|
||||
}
|
||||
|
||||
hashedPassword, err := hashPasswordForKratosAdmin(newPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"schema_id": identity.SchemaID,
|
||||
"traits": identity.Traits,
|
||||
"state": identity.State,
|
||||
"credentials": map[string]any{
|
||||
"password": map[string]any{
|
||||
"config": map[string]string{
|
||||
"hashed_password": hashedPassword,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if payload["schema_id"] == "" {
|
||||
payload["schema_id"] = "default"
|
||||
}
|
||||
if payload["state"] == "" {
|
||||
payload["state"] = "active"
|
||||
}
|
||||
if identity.MetadataAdmin != nil {
|
||||
payload["metadata_admin"] = identity.MetadataAdmin
|
||||
}
|
||||
if identity.MetadataPublic != nil {
|
||||
payload["metadata_public"] = identity.MetadataPublic
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
endpoint := fmt.Sprintf("%s/admin/identities/%s", strings.TrimRight(s.AdminURL, "/"), identityID)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return fmt.Errorf("kratos admin update password failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *kratosAdminService) CreateUser(ctx context.Context, user *domain.BrokerUser, password string) (string, error) {
|
||||
if user == nil {
|
||||
return "", fmt.Errorf("kratos admin: user payload is nil")
|
||||
}
|
||||
if strings.TrimSpace(user.ID) != "" {
|
||||
return "", fmt.Errorf("kratos admin: requested identity id import is disabled; use backup/restore")
|
||||
}
|
||||
|
||||
traits := map[string]any{
|
||||
"email": user.Email,
|
||||
"name": user.Name,
|
||||
}
|
||||
if user.PhoneNumber != "" {
|
||||
traits["phone_number"] = user.PhoneNumber
|
||||
}
|
||||
for k, v := range user.Attributes {
|
||||
if k == "id" || k == "email" {
|
||||
continue
|
||||
}
|
||||
traits[k] = v
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"schema_id": "default",
|
||||
"traits": traits,
|
||||
"credentials": map[string]any{
|
||||
"password": map[string]any{
|
||||
"config": map[string]string{
|
||||
"password": password,
|
||||
},
|
||||
},
|
||||
},
|
||||
"state": "active",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
endpoint := strings.TrimRight(s.AdminURL, "/") + "/admin/identities"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return "", fmt.Errorf("kratos admin create identity failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var created struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&created); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return created.ID, nil
|
||||
}
|
||||
|
||||
func hashPasswordForKratosAdmin(password string) (string, error) {
|
||||
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(hashed), nil
|
||||
}
|
||||
|
||||
func (s *kratosAdminService) DeleteIdentity(ctx context.Context, identityID string) error {
|
||||
endpoint := fmt.Sprintf("%s/admin/identities/%s", strings.TrimRight(s.AdminURL, "/"), identityID)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, endpoint, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := s.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return fmt.Errorf("kratos admin delete identity failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *kratosAdminService) httpClient() *http.Client {
|
||||
if s.HTTPClient != nil {
|
||||
return s.HTTPClient
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: 5 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getenvKratos(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func (s *kratosAdminService) ListIdentitySessions(ctx context.Context, identityID string) ([]KratosSession, error) {
|
||||
url := fmt.Sprintf("%s/admin/identities/%s/sessions", s.AdminURL, identityID)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := s.HTTPClient
|
||||
if client == nil {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return []KratosSession{}, nil
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var sessions []KratosSession
|
||||
if err := json.NewDecoder(resp.Body).Decode(&sessions); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (s *kratosAdminService) GetSession(ctx context.Context, sessionID string) (*KratosSession, error) {
|
||||
url := fmt.Sprintf("%s/admin/sessions/%s", s.AdminURL, sessionID)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := s.HTTPClient
|
||||
if client == nil {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var session KratosSession
|
||||
if err := json.NewDecoder(resp.Body).Decode(&session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
func (s *kratosAdminService) DeleteSession(ctx context.Context, sessionID string) error {
|
||||
url := fmt.Sprintf("%s/admin/sessions/%s", s.AdminURL, sessionID)
|
||||
req, err := http.NewRequestWithContext(ctx, "DELETE", url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := s.HTTPClient
|
||||
if client == nil {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("unexpected status: %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestKratosAdminService_ListIdentitiesFollowsNextPagination(t *testing.T) {
|
||||
var requestedTokens []string
|
||||
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
require.Equal(t, "/admin/identities", r.URL.Path)
|
||||
token := r.URL.Query().Get("page_token")
|
||||
requestedTokens = append(requestedTokens, token)
|
||||
|
||||
header := make(http.Header)
|
||||
header.Set("Content-Type", "application/json")
|
||||
status := http.StatusOK
|
||||
body := "[]"
|
||||
switch token {
|
||||
case "":
|
||||
header.Set(
|
||||
"Link",
|
||||
`</admin/identities?page_size=2&page_token=identity-2>; rel="next"`,
|
||||
)
|
||||
body = `[{"id":"identity-1","traits":{"email":"one@example.com"}},{"id":"identity-2","traits":{"email":"two@example.com"}}]`
|
||||
case "identity-2":
|
||||
body = `[{"id":"identity-3","traits":{"email":"three@example.com"}}]`
|
||||
default:
|
||||
t.Fatalf("unexpected page_token %q", token)
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: status,
|
||||
Header: header,
|
||||
Body: io.NopCloser(bytes.NewBufferString(body)),
|
||||
Request: r,
|
||||
}, nil
|
||||
})}
|
||||
|
||||
service := &kratosAdminService{
|
||||
AdminURL: "http://kratos.example",
|
||||
HTTPClient: client,
|
||||
}
|
||||
|
||||
identities, err := service.ListIdentities(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"", "identity-2"}, requestedTokens)
|
||||
require.Len(t, identities, 3)
|
||||
require.Equal(t, "identity-1", identities[0].ID)
|
||||
require.Equal(t, "identity-2", identities[1].ID)
|
||||
require.Equal(t, "identity-3", identities[2].ID)
|
||||
}
|
||||
139
baron-sso/backend/internal/service/mock_common_test.go
Normal file
139
baron-sso/backend/internal/service/mock_common_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
|
||||
"github.com/stretchr/testify/mock"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// --- Shared Mocks for Service Tests ---
|
||||
|
||||
type MockKetoOutboxRepositoryShared struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockKetoOutboxRepositoryShared) Create(ctx context.Context, entry *domain.KetoOutbox) error {
|
||||
return m.Called(ctx, entry).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockKetoOutboxRepositoryShared) CreateWithTx(tx *gorm.DB, entry *domain.KetoOutbox) error {
|
||||
return m.Called(tx, entry).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockKetoOutboxRepositoryShared) FindPending(ctx context.Context, limit int) ([]domain.KetoOutbox, error) {
|
||||
args := m.Called(ctx, limit)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]domain.KetoOutbox), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKetoOutboxRepositoryShared) ListCurrentBySubject(ctx context.Context, namespace, subject string) ([]domain.KetoOutbox, error) {
|
||||
args := m.Called(ctx, namespace, subject)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]domain.KetoOutbox), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKetoOutboxRepositoryShared) UpdateStatus(ctx context.Context, id string, status string, retryCount int, lastError string) error {
|
||||
return m.Called(ctx, id, status, retryCount, lastError).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockKetoOutboxRepositoryShared) MarkProcessed(ctx context.Context, id string) error {
|
||||
return m.Called(ctx, id).Error(0)
|
||||
}
|
||||
|
||||
type MockKetoServiceShared struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockKetoServiceShared) CheckPermission(ctx context.Context, subject, namespace, object, relation string) (bool, error) {
|
||||
args := m.Called(ctx, subject, namespace, object, relation)
|
||||
return args.Bool(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKetoServiceShared) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
|
||||
args := m.Called(ctx, namespace, object, relation, subject)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockKetoServiceShared) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
|
||||
args := m.Called(ctx, namespace, object, relation, subject)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockKetoServiceShared) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]RelationTuple, error) {
|
||||
args := m.Called(ctx, namespace, object, relation, subject)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]RelationTuple), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKetoServiceShared) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
|
||||
args := m.Called(ctx, namespace, relation, subject)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]string), args.Error(1)
|
||||
}
|
||||
|
||||
type MockKratosAdminServiceShared struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceShared) ListIdentities(ctx context.Context) ([]KratosIdentity, error) {
|
||||
args := m.Called(ctx)
|
||||
return args.Get(0).([]KratosIdentity), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceShared) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) {
|
||||
args := m.Called(ctx, identifier)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceShared) GetIdentity(ctx context.Context, identityID string) (*KratosIdentity, error) {
|
||||
args := m.Called(ctx, identityID)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*KratosIdentity), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceShared) UpdateIdentity(ctx context.Context, identityID string, traits map[string]any, state string) (*KratosIdentity, error) {
|
||||
args := m.Called(ctx, identityID, traits, state)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*KratosIdentity), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceShared) UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error {
|
||||
args := m.Called(ctx, identityID, newPassword)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceShared) DeleteIdentity(ctx context.Context, identityID string) error {
|
||||
args := m.Called(ctx, identityID)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceShared) CreateUser(ctx context.Context, user *domain.BrokerUser, password string) (string, error) {
|
||||
args := m.Called(ctx, user, password)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceShared) ListIdentitySessions(ctx context.Context, identityID string) ([]KratosSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceShared) GetSession(ctx context.Context, sessionID string) (*KratosSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceShared) DeleteSession(ctx context.Context, sessionID string) error {
|
||||
return nil
|
||||
}
|
||||
967
baron-sso/backend/internal/service/ory_service.go
Normal file
967
baron-sso/backend/internal/service/ory_service.go
Normal file
@@ -0,0 +1,967 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// OryProvider는 Kratos/Hydra를 기반으로 하는 IDP 어댑터의 최소 스켈레톤입니다.
|
||||
// 지금은 스키마 메타데이터만 반환하며, 나머지 동작은 후속 작업에서 구현합니다.
|
||||
type OryProvider struct {
|
||||
KratosAdminURL string
|
||||
KratosPublicURL string
|
||||
HydraAdminURL string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
func NewOryProvider() *OryProvider {
|
||||
return &OryProvider{
|
||||
KratosAdminURL: getenv("KRATOS_ADMIN_URL", "http://kratos:4434"),
|
||||
KratosPublicURL: getenv("KRATOS_PUBLIC_URL", "http://kratos:4433"),
|
||||
HydraAdminURL: getenv("HYDRA_ADMIN_URL", "http://hydra:4445"),
|
||||
}
|
||||
}
|
||||
|
||||
func (o *OryProvider) Name() string {
|
||||
return "Ory (Kratos/Hydra)"
|
||||
}
|
||||
|
||||
// GetMetadata는 BrokerUser가 요구하는 필드를 Kratos traits에 매핑 가능하다는 가정으로 반환합니다.
|
||||
func (o *OryProvider) GetMetadata() (*domain.IDPMetadata, error) {
|
||||
return &domain.IDPMetadata{
|
||||
SupportedFields: []string{
|
||||
"id", "custom_login_ids", "login_id", "email", "name", "phone_number",
|
||||
"grade", "department", "affiliationType", "tenant_id",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateUser는 Kratos Admin API를 통해 identity를 생성합니다.
|
||||
func (o *OryProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
|
||||
if user == nil {
|
||||
return "", fmt.Errorf("ory provider: user payload is nil")
|
||||
}
|
||||
if user.Email == "" || password == "" {
|
||||
return "", fmt.Errorf("ory provider: email and password are required")
|
||||
}
|
||||
if strings.TrimSpace(user.ID) != "" {
|
||||
return "", fmt.Errorf("ory provider: requested identity id import is disabled; use backup/restore")
|
||||
}
|
||||
|
||||
existingID, err := o.findIdentityID(user.Email)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ory provider: search identity failed: %w", err)
|
||||
}
|
||||
if existingID != "" {
|
||||
return "", fmt.Errorf("ory provider: identity already exists for email=%s", user.Email)
|
||||
}
|
||||
|
||||
// [New] Check all custom login IDs for collisions
|
||||
for _, lid := range user.CustomLoginIDs {
|
||||
if lid == "" {
|
||||
continue
|
||||
}
|
||||
existing, err := o.findIdentityID(lid)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ory provider: search identity failed for %s: %w", lid, err)
|
||||
}
|
||||
if existing != "" {
|
||||
return "", fmt.Errorf("ory provider: identifier %s already exists", lid)
|
||||
}
|
||||
}
|
||||
|
||||
// [Legacy] check single LoginID
|
||||
if user.LoginID != "" {
|
||||
existingLoginID, err := o.findIdentityID(user.LoginID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ory provider: search identity failed: %w", err)
|
||||
}
|
||||
if existingLoginID != "" {
|
||||
return "", fmt.Errorf("ory provider: identity already exists for login_id=%s", user.LoginID)
|
||||
}
|
||||
}
|
||||
|
||||
if user.PhoneNumber != "" {
|
||||
existingPhoneID, err := o.findIdentityID(user.PhoneNumber)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ory provider: search identity failed: %w", err)
|
||||
}
|
||||
if existingPhoneID != "" {
|
||||
return "", fmt.Errorf("ory provider: identity already exists for phone=%s", user.PhoneNumber)
|
||||
}
|
||||
}
|
||||
|
||||
traits := map[string]any{
|
||||
"email": user.Email,
|
||||
"name": user.Name,
|
||||
}
|
||||
if len(user.CustomLoginIDs) > 0 {
|
||||
traits["custom_login_ids"] = user.CustomLoginIDs
|
||||
} else if user.LoginID != "" {
|
||||
traits["custom_login_ids"] = []string{user.LoginID}
|
||||
}
|
||||
|
||||
if user.PhoneNumber != "" {
|
||||
traits["phone_number"] = user.PhoneNumber
|
||||
}
|
||||
for k, v := range user.Attributes {
|
||||
// [SoT Fix] Don't let attributes overwrite core traits or use old 'id' trait
|
||||
if k == "id" || k == "email" || k == "custom_login_ids" {
|
||||
continue
|
||||
}
|
||||
traits[k] = v
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"schema_id": "default",
|
||||
"traits": traits,
|
||||
"credentials": map[string]any{
|
||||
"password": map[string]any{
|
||||
"config": map[string]string{
|
||||
"password": password,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
verifiable := []map[string]any{
|
||||
{
|
||||
"value": user.Email,
|
||||
"verified": true,
|
||||
"via": "email",
|
||||
},
|
||||
}
|
||||
if user.PhoneNumber != "" {
|
||||
verifiable = append(verifiable, map[string]any{
|
||||
"value": user.PhoneNumber,
|
||||
"verified": true,
|
||||
"via": "sms",
|
||||
})
|
||||
}
|
||||
payload["verifiable_addresses"] = verifiable
|
||||
payload["recovery_addresses"] = []map[string]any{
|
||||
{
|
||||
"value": user.Email,
|
||||
"via": "email",
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(payload)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, fmt.Sprintf("%s/admin/identities", o.KratosAdminURL), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ory provider: build create request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ory provider: create identity request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return "", fmt.Errorf("ory provider: create identity failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var created struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&created); err != nil {
|
||||
return "", fmt.Errorf("ory provider: decode create identity response failed: %w", err)
|
||||
}
|
||||
slog.Info("Ory identity created", "identity_id", created.ID, "email", user.Email)
|
||||
return created.ID, nil
|
||||
}
|
||||
|
||||
// SignIn은 Kratos Public API의 login API 플로우를 사용해 세션 토큰을 발급합니다.
|
||||
func (o *OryProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
|
||||
if loginID == "" || password == "" {
|
||||
return nil, fmt.Errorf("ory provider: loginID and password are required")
|
||||
}
|
||||
|
||||
flowID, err := o.startLoginFlow("")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"identifier": loginID,
|
||||
"password": password,
|
||||
"method": "password",
|
||||
})
|
||||
loginURL := fmt.Sprintf("%s/self-service/login?flow=%s", o.KratosPublicURL, flowID)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, loginURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: build login request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: login request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("ory provider: login failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
SessionToken string `json:"session_token"`
|
||||
SessionTokenExpiresAt time.Time `json:"session_token_expires_at"`
|
||||
Session struct {
|
||||
ID string `json:"id"`
|
||||
Identity struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"identity"`
|
||||
} `json:"session"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("ory provider: decode login response failed: %w", err)
|
||||
}
|
||||
if result.SessionToken == "" {
|
||||
return nil, fmt.Errorf("ory provider: empty session token returned")
|
||||
}
|
||||
|
||||
slog.Info("Ory login successful",
|
||||
"identity_id", result.Session.Identity.ID,
|
||||
"loginID", loginID,
|
||||
"expires_at", result.SessionTokenExpiresAt,
|
||||
)
|
||||
|
||||
return &domain.AuthInfo{
|
||||
SessionToken: &domain.Token{
|
||||
JWT: result.SessionToken,
|
||||
Expiration: result.SessionTokenExpiresAt,
|
||||
SessionID: result.Session.ID,
|
||||
},
|
||||
Subject: result.Session.Identity.ID,
|
||||
SetCookies: resp.Cookies(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UserExists는 Kratos Admin API로 loginID 존재 여부를 확인합니다.
|
||||
func (o *OryProvider) UserExists(loginID string) (bool, error) {
|
||||
if loginID == "" {
|
||||
return false, fmt.Errorf("ory provider: loginID is empty")
|
||||
}
|
||||
identityID, err := o.findIdentityID(loginID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("ory provider: find identity failed: %w", err)
|
||||
}
|
||||
return identityID != "", nil
|
||||
}
|
||||
|
||||
// IssueSession은 Ory에서 별도 세션 발급이 필요할 때 사용합니다. (현재 미지원)
|
||||
func (o *OryProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
// InitiateLinkLogin은 Kratos Public API로 링크 로그인 플로우를 시작하고 이메일 전송을 트리거합니다.
|
||||
func (o *OryProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
if loginID == "" {
|
||||
return nil, fmt.Errorf("ory provider: loginID is required")
|
||||
}
|
||||
|
||||
effectiveLoginID, err := o.resolveEffectiveLoginID(loginID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := o.ensureCodeLoginIdentifier(effectiveLoginID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
init, err := o.submitLoginCodeInit(effectiveLoginID, returnTo)
|
||||
if err == nil {
|
||||
init.LoginID = effectiveLoginID
|
||||
return init, nil
|
||||
}
|
||||
|
||||
if shouldBootstrapCodeLogin(err) {
|
||||
if ensureErr := o.ensureCodeLoginIdentifier(effectiveLoginID); ensureErr == nil {
|
||||
init, initErr := o.submitLoginCodeInit(effectiveLoginID, returnTo)
|
||||
if initErr == nil {
|
||||
init.LoginID = effectiveLoginID
|
||||
}
|
||||
return init, initErr
|
||||
} else {
|
||||
slog.Warn("Ory code login bootstrap failed", "loginID", effectiveLoginID, "error", ensureErr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (o *OryProvider) resolveEffectiveLoginID(loginID string) (string, error) {
|
||||
if strings.Contains(loginID, "@") {
|
||||
return loginID, nil
|
||||
}
|
||||
|
||||
identityID, err := o.findIdentityID(loginID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if identityID == "" {
|
||||
return "", fmt.Errorf("ory provider: identity not found for loginID=%s", loginID)
|
||||
}
|
||||
|
||||
fullIdentity, err := o.fetchIdentityFull(identityID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if fullIdentity != nil {
|
||||
if emailRaw, ok := fullIdentity.Traits["email"]; ok {
|
||||
if email, ok := emailRaw.(string); ok && email != "" {
|
||||
return email, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("ory provider: email trait missing for loginID=%s", loginID)
|
||||
}
|
||||
|
||||
func (o *OryProvider) submitLoginCodeInit(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
flowID, err := o.startLoginFlow(returnTo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"method": "code",
|
||||
"identifier": loginID,
|
||||
})
|
||||
loginURL := fmt.Sprintf("%s/self-service/login?flow=%s", o.KratosPublicURL, flowID)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, loginURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: build link login request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: link login request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
if resp.StatusCode >= 300 {
|
||||
init, ok := parseKratosLinkLoginResponse(flowID, respBody)
|
||||
if ok {
|
||||
slog.Info("Ory link login initiated with non-2xx response", "loginID", loginID, "flow_id", flowID, "status", resp.StatusCode)
|
||||
return init, nil
|
||||
}
|
||||
return nil, fmt.Errorf("ory provider: link login failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
_ = json.Unmarshal(respBody, &result)
|
||||
|
||||
slog.Info("Ory link login initiated", "loginID", loginID, "flow_id", flowID)
|
||||
|
||||
return &domain.LinkLoginInit{
|
||||
FlowID: flowID,
|
||||
ExpiresAt: result.ExpiresAt,
|
||||
Mode: "link",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseKratosLinkLoginResponse(flowID string, body []byte) (*domain.LinkLoginInit, bool) {
|
||||
if len(body) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
var parsed struct {
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
State string `json:"state"`
|
||||
Active string `json:"active"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
state := strings.ToLower(parsed.State)
|
||||
active := strings.ToLower(parsed.Active)
|
||||
if strings.Contains(state, "sent") || active == "code" {
|
||||
return &domain.LinkLoginInit{
|
||||
FlowID: flowID,
|
||||
ExpiresAt: parsed.ExpiresAt,
|
||||
Mode: "link",
|
||||
}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func shouldBootstrapCodeLogin(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "has not setup sign in with code") ||
|
||||
strings.Contains(msg, "4000035")
|
||||
}
|
||||
|
||||
type kratosVerifiableAddress struct {
|
||||
Value string `json:"value"`
|
||||
Via string `json:"via"`
|
||||
Verified bool `json:"verified"`
|
||||
Status string `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
func (o *OryProvider) ensureCodeLoginIdentifier(loginID string) error {
|
||||
identityID, err := o.findIdentityID(loginID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: find identity failed: %w", err)
|
||||
}
|
||||
if identityID == "" {
|
||||
return fmt.Errorf("ory provider: identity not found for loginID=%s", loginID)
|
||||
}
|
||||
|
||||
identity, err := o.fetchIdentity(identityID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
via := "sms"
|
||||
if strings.Contains(loginID, "@") {
|
||||
via = "email"
|
||||
}
|
||||
|
||||
exists := false
|
||||
existingIndex := -1
|
||||
addresses := make([]kratosVerifiableAddress, 0, len(identity.VerifiableAddresses)+1)
|
||||
for idx, addr := range identity.VerifiableAddresses {
|
||||
addresses = append(addresses, kratosVerifiableAddress{
|
||||
Value: addr.Value,
|
||||
Via: addr.Via,
|
||||
Verified: addr.Verified,
|
||||
Status: addr.Status,
|
||||
})
|
||||
if addr.Value == loginID && addr.Via == via {
|
||||
exists = true
|
||||
existingIndex = idx
|
||||
}
|
||||
}
|
||||
ops := make([]map[string]any, 0, 2)
|
||||
if !exists {
|
||||
ops = append(ops, map[string]any{
|
||||
"op": "add",
|
||||
"path": "/verifiable_addresses/-",
|
||||
"value": map[string]any{
|
||||
"value": loginID,
|
||||
"via": via,
|
||||
"verified": true,
|
||||
"status": "completed",
|
||||
},
|
||||
})
|
||||
} else {
|
||||
addr := identity.VerifiableAddresses[existingIndex]
|
||||
if !addr.Verified {
|
||||
ops = append(ops, map[string]any{
|
||||
"op": "replace",
|
||||
"path": fmt.Sprintf("/verifiable_addresses/%d/verified", existingIndex),
|
||||
"value": true,
|
||||
})
|
||||
}
|
||||
if addr.Status != "" && addr.Status != "completed" {
|
||||
ops = append(ops, map[string]any{
|
||||
"op": "replace",
|
||||
"path": fmt.Sprintf("/verifiable_addresses/%d/status", existingIndex),
|
||||
"value": "completed",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(ops) == 0 {
|
||||
slog.Info("Ory identity verifiable address already ready", "identity_id", identityID, "loginID", loginID, "via", via)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := o.patchIdentity(identityID, ops); err != nil {
|
||||
slog.Warn("Ory identity patch failed, trying full update", "identity_id", identityID, "error", err)
|
||||
}
|
||||
|
||||
fullIdentity, err := o.fetchIdentityFull(identityID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addresses = make([]kratosVerifiableAddress, 0, len(fullIdentity.VerifiableAddresses)+1)
|
||||
found := false
|
||||
for _, addr := range fullIdentity.VerifiableAddresses {
|
||||
addresses = append(addresses, kratosVerifiableAddress{
|
||||
Value: addr.Value,
|
||||
Via: addr.Via,
|
||||
Verified: addr.Verified,
|
||||
Status: addr.Status,
|
||||
})
|
||||
if addr.Value == loginID && addr.Via == via {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
addresses = append(addresses, kratosVerifiableAddress{
|
||||
Value: loginID,
|
||||
Via: via,
|
||||
Verified: true,
|
||||
Status: "completed",
|
||||
})
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"schema_id": fullIdentity.SchemaID,
|
||||
"traits": fullIdentity.Traits,
|
||||
"verifiable_addresses": addresses,
|
||||
}
|
||||
if len(fullIdentity.RecoveryAddresses) > 0 {
|
||||
payload["recovery_addresses"] = fullIdentity.RecoveryAddresses
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(payload)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPut, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: build identity update failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: identity update failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return fmt.Errorf("ory provider: identity update failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
slog.Info("Ory identity updated with verifiable address", "identity_id", identityID, "loginID", loginID, "via", via)
|
||||
return nil
|
||||
}
|
||||
|
||||
type kratosIdentity struct {
|
||||
VerifiableAddresses []kratosVerifiableAddress `json:"verifiable_addresses"`
|
||||
}
|
||||
|
||||
type kratosRecoveryAddress struct {
|
||||
Value string `json:"value"`
|
||||
Via string `json:"via"`
|
||||
}
|
||||
|
||||
type kratosIdentityFull struct {
|
||||
SchemaID string `json:"schema_id"`
|
||||
Traits map[string]any `json:"traits"`
|
||||
VerifiableAddresses []kratosVerifiableAddress `json:"verifiable_addresses"`
|
||||
RecoveryAddresses []kratosRecoveryAddress `json:"recovery_addresses"`
|
||||
}
|
||||
|
||||
func (o *OryProvider) patchIdentity(identityID string, ops []map[string]any) error {
|
||||
body, _ := json.Marshal(ops)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPatch, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: build identity patch failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json-patch+json")
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: identity patch failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return fmt.Errorf("ory provider: identity patch failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
slog.Info("Ory identity patched", "identity_id", identityID, "ops", len(ops))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *OryProvider) fetchIdentity(identityID string) (*kratosIdentity, error) {
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: build identity get failed: %w", err)
|
||||
}
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: identity get failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return nil, fmt.Errorf("ory provider: identity get failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var identity kratosIdentity
|
||||
if err := json.NewDecoder(resp.Body).Decode(&identity); err != nil {
|
||||
return nil, fmt.Errorf("ory provider: decode identity failed: %w", err)
|
||||
}
|
||||
return &identity, nil
|
||||
}
|
||||
|
||||
func (o *OryProvider) fetchIdentityFull(identityID string) (*kratosIdentityFull, error) {
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: build identity get failed: %w", err)
|
||||
}
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: identity get failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return nil, fmt.Errorf("ory provider: identity get failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var identity kratosIdentityFull
|
||||
if err := json.NewDecoder(resp.Body).Decode(&identity); err != nil {
|
||||
return nil, fmt.Errorf("ory provider: decode identity failed: %w", err)
|
||||
}
|
||||
return &identity, nil
|
||||
}
|
||||
|
||||
// VerifyLoginCode는 Kratos 로그인 코드 제출로 세션을 발급합니다.
|
||||
func (o *OryProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
if loginID == "" || flowID == "" || code == "" {
|
||||
return nil, fmt.Errorf("ory provider: loginID, flowID and code are required")
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"method": "code",
|
||||
"identifier": loginID,
|
||||
"code": code,
|
||||
})
|
||||
loginURL := fmt.Sprintf("%s/self-service/login?flow=%s", o.KratosPublicURL, flowID)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, loginURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: build login code request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ory provider: login code request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("ory provider: login code failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
SessionToken string `json:"session_token"`
|
||||
SessionTokenExpiresAt time.Time `json:"session_token_expires_at"`
|
||||
Session struct {
|
||||
ID string `json:"id"`
|
||||
Identity struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"identity"`
|
||||
} `json:"session"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("ory provider: decode login code response failed: %w", err)
|
||||
}
|
||||
if result.SessionToken == "" {
|
||||
return nil, fmt.Errorf("ory provider: empty session token returned")
|
||||
}
|
||||
|
||||
slog.Info("Ory login code successful",
|
||||
"identity_id", result.Session.Identity.ID,
|
||||
"loginID", loginID,
|
||||
"expires_at", result.SessionTokenExpiresAt,
|
||||
)
|
||||
|
||||
return &domain.AuthInfo{
|
||||
SessionToken: &domain.Token{
|
||||
JWT: result.SessionToken,
|
||||
Expiration: result.SessionTokenExpiresAt,
|
||||
SessionID: result.Session.ID,
|
||||
},
|
||||
Subject: result.Session.Identity.ID,
|
||||
SetCookies: resp.Cookies(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetPasswordPolicy는 Ory 환경에서 사용하는 기본 정책을 반환합니다.
|
||||
func (o *OryProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
return &domain.PasswordPolicy{
|
||||
MinLength: 12,
|
||||
Lowercase: true,
|
||||
Uppercase: false,
|
||||
Number: true,
|
||||
NonAlphanumeric: true,
|
||||
MinCharacterTypes: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// InitiatePasswordReset는 현재 내부 토큰/메일 흐름을 사용하고 있으므로 NO-OP로 둡니다.
|
||||
func (o *OryProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
|
||||
slog.Info("Ory InitiatePasswordReset bypassed (handled by app internal flow)", "loginID", loginID, "redirect", redirectUrl)
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyPasswordResetToken는 내부 토큰 검증 흐름을 사용하므로 아직 구현하지 않습니다.
|
||||
func (o *OryProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
|
||||
return nil, fmt.Errorf("ory provider: VerifyPasswordResetToken not implemented (internal token flow expected)")
|
||||
}
|
||||
|
||||
// UpdateUserPassword: Kratos Admin API를 통해 비밀번호를 갱신합니다.
|
||||
func (o *OryProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
|
||||
if loginID == "" || newPassword == "" {
|
||||
return fmt.Errorf("ory provider: loginID or new password missing")
|
||||
}
|
||||
|
||||
identityID, err := o.findIdentityID(loginID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: find identity failed: %w", err)
|
||||
}
|
||||
|
||||
if identityID == "" {
|
||||
return fmt.Errorf("ory provider: identity not found for loginID=%s", loginID)
|
||||
}
|
||||
|
||||
identity, err := o.getIdentity(identityID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: load identity failed: %w", err)
|
||||
}
|
||||
if identity == nil {
|
||||
return fmt.Errorf("ory provider: identity payload missing for loginID=%s", loginID)
|
||||
}
|
||||
|
||||
hashedPassword, err := hashPasswordForKratos(newPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: hash password failed: %w", err)
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"schema_id": identity.SchemaID,
|
||||
"traits": identity.Traits,
|
||||
"state": identity.State,
|
||||
"credentials": map[string]any{
|
||||
"password": map[string]any{
|
||||
"config": map[string]string{
|
||||
"hashed_password": hashedPassword,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if payload["schema_id"] == "" {
|
||||
payload["schema_id"] = "default"
|
||||
}
|
||||
if payload["state"] == "" {
|
||||
payload["state"] = "active"
|
||||
}
|
||||
if identity.MetadataAdmin != nil {
|
||||
payload["metadata_admin"] = identity.MetadataAdmin
|
||||
}
|
||||
if identity.MetadataPublic != nil {
|
||||
payload["metadata_public"] = identity.MetadataPublic
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPut, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: build request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ory provider: request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return fmt.Errorf("ory provider: password update failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
slog.Info("Ory password updated via Kratos admin", "identity_id", identityID, "loginID", loginID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getenv(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// findIdentityByID: Kratos Admin API에서 ID(UUID)로 직접 조회
|
||||
func (o *OryProvider) findIdentityByID(id string) (string, error) {
|
||||
identity, err := o.getIdentity(id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if identity != nil {
|
||||
return identity.ID, nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// findIdentityID: Kratos Admin API에서 credentials_identifier로 검색 후 첫 번째 identity id 반환
|
||||
func (o *OryProvider) findIdentityID(loginID string) (string, error) {
|
||||
u, err := url.Parse(fmt.Sprintf("%s/admin/identities", o.KratosAdminURL))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
query := u.Query()
|
||||
query.Set("credentials_identifier", loginID)
|
||||
u.RawQuery = query.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return "", nil
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return "", fmt.Errorf("kratos admin search failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var identities []struct {
|
||||
ID string `json:"id"`
|
||||
Traits map[string]any `json:"traits"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&identities); err != nil {
|
||||
return "", fmt.Errorf("decode response failed: %w", err)
|
||||
}
|
||||
if len(identities) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// VERIFY: Double check traits to avoid Kratos ignoring the query param
|
||||
candidate := identities[0]
|
||||
if email, ok := candidate.Traits["email"].(string); ok && strings.EqualFold(email, loginID) {
|
||||
return candidate.ID, nil
|
||||
}
|
||||
if phone, ok := candidate.Traits["phone_number"].(string); ok && strings.EqualFold(phone, loginID) {
|
||||
return candidate.ID, nil
|
||||
}
|
||||
if lids, ok := candidate.Traits["custom_login_ids"].([]any); ok {
|
||||
for _, lid := range lids {
|
||||
if s, ok := lid.(string); ok && strings.EqualFold(s, loginID) {
|
||||
return candidate.ID, nil
|
||||
}
|
||||
}
|
||||
} else if lids, ok := candidate.Traits["custom_login_ids"].([]string); ok {
|
||||
for _, lid := range lids {
|
||||
if strings.EqualFold(lid, loginID) {
|
||||
return candidate.ID, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (o *OryProvider) getIdentity(identityID string) (*KratosIdentity, error) {
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return nil, fmt.Errorf("ory provider: get identity failed status=%d body=%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var identity KratosIdentity
|
||||
if err := json.NewDecoder(resp.Body).Decode(&identity); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &identity, nil
|
||||
}
|
||||
|
||||
func hashPasswordForKratos(password string) (string, error) {
|
||||
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(hashed), nil
|
||||
}
|
||||
|
||||
func (o *OryProvider) httpClient() *http.Client {
|
||||
if o.HTTPClient != nil {
|
||||
return o.HTTPClient
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: 5 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// startLoginFlow는 Kratos Public API에서 login flow ID를 발급받습니다.
|
||||
func (o *OryProvider) startLoginFlow(returnTo string) (string, error) {
|
||||
loginURL := fmt.Sprintf("%s/self-service/login/api", o.KratosPublicURL)
|
||||
if returnTo != "" {
|
||||
loginURL = loginURL + "?return_to=" + url.QueryEscape(returnTo)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, loginURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ory provider: build login flow request failed: %w", err)
|
||||
}
|
||||
|
||||
resp, err := o.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ory provider: login flow request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return "", fmt.Errorf("ory provider: login flow failed status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("ory provider: decode login flow failed: %w", err)
|
||||
}
|
||||
if result.ID == "" {
|
||||
return "", fmt.Errorf("ory provider: empty login flow id")
|
||||
}
|
||||
return result.ID, nil
|
||||
}
|
||||
226
baron-sso/backend/internal/service/ory_service_test.go
Normal file
226
baron-sso/backend/internal/service/ory_service_test.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// clientForHandler returns an http.Client that routes requests to the given handler
|
||||
// without real network sockets.
|
||||
func clientForHandler(h http.Handler) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
// Clone request body for handler
|
||||
var bodyBytes []byte
|
||||
if req.Body != nil {
|
||||
bodyBytes, _ = io.ReadAll(req.Body)
|
||||
}
|
||||
r := httptest.NewRequest(req.Method, req.URL.String(), bytes.NewReader(bodyBytes))
|
||||
r.Header = req.Header.Clone()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, r)
|
||||
return w.Result(), nil
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripperFunc func(req *http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) }
|
||||
|
||||
func TestUpdateUserPassword_Success(t *testing.T) {
|
||||
const (
|
||||
loginID = "user@example.com"
|
||||
identityID = "7f0dc8c3-9d5d-4f57-b3d1-123456789abc"
|
||||
newPassword = "Sup3rStr0ng!Pass#2026"
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.HasPrefix(r.URL.Path, "/admin/identities") && r.Method == http.MethodGet:
|
||||
if r.URL.Path == "/admin/identities" {
|
||||
q := r.URL.Query()
|
||||
if got := q.Get("credentials_identifier"); got != loginID {
|
||||
t.Fatalf("expected credentials_identifier=%s, got=%s", loginID, got)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode([]map[string]any{
|
||||
{
|
||||
"id": identityID,
|
||||
"traits": map[string]any{
|
||||
"email": loginID,
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
if r.URL.Path != "/admin/identities/"+identityID {
|
||||
t.Fatalf("unexpected identity lookup path: %s", r.URL.Path)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": identityID,
|
||||
"schema_id": "default",
|
||||
"state": "active",
|
||||
"traits": map[string]any{
|
||||
"email": loginID,
|
||||
},
|
||||
})
|
||||
return
|
||||
case r.URL.Path == "/admin/identities/"+identityID && r.Method == http.MethodPut:
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
if !strings.Contains(string(body), "\"hashed_password\"") {
|
||||
t.Fatalf("payload missing hashed_password, body=%s", string(body))
|
||||
}
|
||||
if strings.Contains(string(body), newPassword) {
|
||||
t.Fatalf("payload must not contain plain password, body=%s", string(body))
|
||||
}
|
||||
if !strings.Contains(string(body), "\"schema_id\":\"default\"") {
|
||||
t.Fatalf("payload missing schema_id, body=%s", string(body))
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
default:
|
||||
t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String())
|
||||
}
|
||||
})
|
||||
|
||||
provider := &OryProvider{
|
||||
KratosAdminURL: "http://kratos-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
|
||||
if err := provider.UpdateUserPassword(loginID, newPassword, nil); err != nil {
|
||||
t.Fatalf("UpdateUserPassword returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserPassword_NotFound(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/admin/identities") && r.Method == http.MethodGet {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String())
|
||||
})
|
||||
|
||||
provider := &OryProvider{
|
||||
KratosAdminURL: "http://kratos-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
|
||||
err := provider.UpdateUserPassword("user@example.com", "Sup3rStr0ng!Pass#2026", nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "identity not found") {
|
||||
t.Fatalf("expected identity not found error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserPassword_ServerError(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.HasPrefix(r.URL.Path, "/admin/identities") && r.Method == http.MethodGet:
|
||||
if r.URL.Path == "/admin/identities" {
|
||||
_ = json.NewEncoder(w).Encode([]map[string]any{
|
||||
{
|
||||
"id": "abc",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
if r.URL.Path == "/admin/identities/abc" {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "abc",
|
||||
"schema_id": "default",
|
||||
"state": "active",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String())
|
||||
case r.URL.Path == "/admin/identities/abc" && r.Method == http.MethodPut:
|
||||
http.Error(w, "boom", http.StatusInternalServerError)
|
||||
return
|
||||
default:
|
||||
t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String())
|
||||
}
|
||||
})
|
||||
|
||||
provider := &OryProvider{
|
||||
KratosAdminURL: "http://kratos-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
|
||||
err := provider.UpdateUserPassword("user@example.com", "Sup3rStr0ng!Pass#2026", nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "password update failed") {
|
||||
t.Fatalf("expected server error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindIdentityID_QueryEncoding(t *testing.T) {
|
||||
loginID := "user+alias@example.com"
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
values, _ := url.ParseQuery(r.URL.RawQuery)
|
||||
if values.Get("credentials_identifier") != loginID {
|
||||
t.Fatalf("expected credentials_identifier=%s, got=%s", loginID, values.Get("credentials_identifier"))
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode([]map[string]any{
|
||||
{
|
||||
"id": "id-123",
|
||||
"traits": map[string]any{
|
||||
"email": loginID,
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
provider := &OryProvider{
|
||||
KratosAdminURL: "http://kratos-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
|
||||
id, err := provider.findIdentityID(loginID)
|
||||
if err != nil {
|
||||
t.Fatalf("findIdentityID returned error: %v", err)
|
||||
}
|
||||
if id != "id-123" {
|
||||
t.Fatalf("expected id-123, got %s", id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOryProvider_CreateUser_RejectsRequestedIdentityID(t *testing.T) {
|
||||
const (
|
||||
email = "newuser@test.com"
|
||||
name = "New User"
|
||||
customUuid = "550e8400-e29b-41d4-a716-446655440000"
|
||||
password = "secret123456"
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String())
|
||||
})
|
||||
|
||||
provider := &OryProvider{
|
||||
KratosAdminURL: "http://kratos-admin.local",
|
||||
HTTPClient: clientForHandler(handler),
|
||||
}
|
||||
|
||||
id, err := provider.CreateUser(&domain.BrokerUser{
|
||||
ID: customUuid,
|
||||
Email: email,
|
||||
Name: name,
|
||||
}, password)
|
||||
if err == nil || !strings.Contains(err.Error(), "requested identity id import is disabled") {
|
||||
t.Fatalf("expected requested identity id rejection, got id=%s err=%v", id, err)
|
||||
}
|
||||
}
|
||||
238
baron-sso/backend/internal/service/redis_service.go
Normal file
238
baron-sso/backend/internal/service/redis_service.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
|
||||
type RedisService struct {
|
||||
Client *redis.Client
|
||||
}
|
||||
|
||||
type identityMirrorStateStore struct {
|
||||
Status string `json:"status"`
|
||||
LastRefreshedAt *time.Time `json:"lastRefreshedAt,omitempty"`
|
||||
LastError string `json:"lastError,omitempty"`
|
||||
ObservedCount int64 `json:"observedCount,omitempty"`
|
||||
UpdatedAt *time.Time `json:"updatedAt,omitempty"`
|
||||
}
|
||||
|
||||
// NewRedisService creates and returns a new RedisService
|
||||
func NewRedisService() (*RedisService, error) {
|
||||
redisAddr := os.Getenv("REDIS_ADDR")
|
||||
if redisAddr == "" {
|
||||
redisAddr = "localhost:6389" // Fallback for local dev without Docker
|
||||
}
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: redisAddr,
|
||||
})
|
||||
|
||||
// Ping the server to check the connection
|
||||
if _, err := rdb.Ping(ctx).Result(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// [DEV-FIX] Disable stop-writes-on-bgsave-error to allow writes even if persistence fails
|
||||
// This is common in dev docker environments with permission issues.
|
||||
rdb.ConfigSet(ctx, "stop-writes-on-bgsave-error", "no")
|
||||
|
||||
return &RedisService{Client: rdb}, nil
|
||||
}
|
||||
|
||||
func (s *RedisService) Ping(ctx context.Context) error {
|
||||
if s.Client == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
return s.Client.Ping(ctx).Err()
|
||||
}
|
||||
|
||||
// StoreVerificationCode saves the SMS verification code with a 3-minute expiration
|
||||
func (s *RedisService) StoreVerificationCode(phone, code string) error {
|
||||
// Key format: "sms_verify:01012345678"
|
||||
key := "sms_verify:" + phone
|
||||
expiration := 3 * time.Minute
|
||||
err := s.Client.Set(ctx, key, code, expiration).Err()
|
||||
return err
|
||||
}
|
||||
|
||||
// GetVerificationCode retrieves the SMS verification code
|
||||
func (s *RedisService) GetVerificationCode(phone string) (string, error) {
|
||||
key := "sms_verify:" + phone
|
||||
code, err := s.Client.Get(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
// Key does not exist (expired or incorrect phone number)
|
||||
return "", nil
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// DeleteVerificationCode removes the verification code after successful verification
|
||||
func (s *RedisService) DeleteVerificationCode(phone string) error {
|
||||
key := "sms_verify:" + phone
|
||||
return s.Client.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// Set stores a key-value pair with expiration
|
||||
func (s *RedisService) Set(key string, value string, expiration time.Duration) error {
|
||||
return s.Client.Set(ctx, key, value, expiration).Err()
|
||||
}
|
||||
|
||||
// Get retrieves a value by key
|
||||
func (s *RedisService) Get(key string) (string, error) {
|
||||
val, err := s.Client.Get(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return "", nil
|
||||
}
|
||||
return val, err
|
||||
}
|
||||
|
||||
// Delete removes a key
|
||||
func (s *RedisService) Delete(key string) error {
|
||||
return s.Client.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (s *RedisService) GetIdentityCacheStatus(ctx context.Context) (domain.IdentityCacheStatus, error) {
|
||||
if s == nil || s.Client == nil {
|
||||
return domain.IdentityCacheStatus{
|
||||
Status: "unavailable",
|
||||
RedisReady: false,
|
||||
LastError: "redis service unavailable",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := s.Client.Ping(ctx).Err(); err != nil {
|
||||
return domain.IdentityCacheStatus{
|
||||
Status: "failed",
|
||||
RedisReady: false,
|
||||
LastError: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
keyCount, err := s.countIdentityCacheKeys(ctx)
|
||||
if err != nil {
|
||||
return domain.IdentityCacheStatus{
|
||||
Status: "failed",
|
||||
RedisReady: true,
|
||||
LastError: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
raw, err := s.Client.Get(ctx, "identity:mirror:state").Result()
|
||||
if err == redis.Nil {
|
||||
return domain.IdentityCacheStatus{
|
||||
Status: "empty",
|
||||
RedisReady: true,
|
||||
KeyCount: keyCount,
|
||||
}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return domain.IdentityCacheStatus{
|
||||
Status: "failed",
|
||||
RedisReady: true,
|
||||
KeyCount: keyCount,
|
||||
LastError: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
var stored identityMirrorStateStore
|
||||
if err := json.Unmarshal([]byte(raw), &stored); err != nil {
|
||||
return domain.IdentityCacheStatus{
|
||||
Status: "failed",
|
||||
RedisReady: true,
|
||||
KeyCount: keyCount,
|
||||
LastError: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
status := stored.Status
|
||||
if status == "" {
|
||||
status = "unknown"
|
||||
}
|
||||
return domain.IdentityCacheStatus{
|
||||
Status: status,
|
||||
RedisReady: true,
|
||||
ObservedCount: stored.ObservedCount,
|
||||
KeyCount: keyCount,
|
||||
LastRefreshedAt: stored.LastRefreshedAt,
|
||||
LastError: stored.LastError,
|
||||
UpdatedAt: stored.UpdatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *RedisService) FlushIdentityCache(ctx context.Context) (domain.IdentityCacheFlushResult, error) {
|
||||
if s == nil || s.Client == nil {
|
||||
return domain.IdentityCacheFlushResult{}, os.ErrInvalid
|
||||
}
|
||||
|
||||
keys, err := s.identityCacheKeys(ctx)
|
||||
if err != nil {
|
||||
return domain.IdentityCacheFlushResult{}, err
|
||||
}
|
||||
var deleted int64
|
||||
for len(keys) > 0 {
|
||||
chunkSize := len(keys)
|
||||
if chunkSize > 500 {
|
||||
chunkSize = 500
|
||||
}
|
||||
chunk := keys[:chunkSize]
|
||||
count, err := s.Client.Del(ctx, chunk...).Result()
|
||||
if err != nil {
|
||||
return domain.IdentityCacheFlushResult{}, err
|
||||
}
|
||||
deleted += count
|
||||
keys = keys[chunkSize:]
|
||||
}
|
||||
|
||||
return domain.IdentityCacheFlushResult{
|
||||
Status: "success",
|
||||
FlushedKeys: deleted,
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *RedisService) countIdentityCacheKeys(ctx context.Context) (int64, error) {
|
||||
keys, err := s.identityCacheKeys(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(len(keys)), nil
|
||||
}
|
||||
|
||||
func (s *RedisService) identityCacheKeys(ctx context.Context) ([]string, error) {
|
||||
seen := make(map[string]bool)
|
||||
patterns := []string{
|
||||
"identity:mirror:*",
|
||||
"identity:index:*",
|
||||
}
|
||||
for _, pattern := range patterns {
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, next, err := s.Client.Scan(ctx, cursor, pattern, 250).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, key := range keys {
|
||||
seen[key] = true
|
||||
}
|
||||
cursor = next
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(seen))
|
||||
for key := range seen {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
150
baron-sso/backend/internal/service/redis_service_test.go
Normal file
150
baron-sso/backend/internal/service/redis_service_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type redisCommandStub struct {
|
||||
scans map[string][]string
|
||||
stateValue string
|
||||
deleted []string
|
||||
}
|
||||
|
||||
func (h *redisCommandStub) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (h *redisCommandStub) AfterProcess(ctx context.Context, cmd redis.Cmder) error {
|
||||
switch cmd.Name() {
|
||||
case "ping":
|
||||
if status, ok := cmd.(*redis.StatusCmd); ok {
|
||||
status.SetVal("PONG")
|
||||
}
|
||||
case "scan":
|
||||
if scan, ok := cmd.(*redis.ScanCmd); ok {
|
||||
scan.SetVal(h.scans[scanPattern(cmd.Args())], 0)
|
||||
}
|
||||
case "get":
|
||||
if str, ok := cmd.(*redis.StringCmd); ok {
|
||||
if h.stateValue == "" {
|
||||
str.SetErr(redis.Nil)
|
||||
return nil
|
||||
}
|
||||
str.SetVal(h.stateValue)
|
||||
}
|
||||
case "del":
|
||||
args := cmd.Args()
|
||||
keys := make([]string, 0, len(args)-1)
|
||||
for _, arg := range args[1:] {
|
||||
keys = append(keys, arg.(string))
|
||||
}
|
||||
h.deleted = append(h.deleted, keys...)
|
||||
if count, ok := cmd.(*redis.IntCmd); ok {
|
||||
count.SetVal(int64(len(keys)))
|
||||
}
|
||||
}
|
||||
cmd.SetErr(nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *redisCommandStub) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (h *redisCommandStub) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func scanPattern(args []interface{}) string {
|
||||
for index := 0; index < len(args)-1; index++ {
|
||||
value, ok := args[index].(string)
|
||||
if ok && value == "match" {
|
||||
if pattern, ok := args[index+1].(string); ok {
|
||||
return pattern
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func newStubbedRedisService(stub *redisCommandStub) *RedisService {
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: "127.0.0.1:1",
|
||||
MaxRetries: -1,
|
||||
})
|
||||
client.AddHook(stub)
|
||||
return &RedisService{Client: client}
|
||||
}
|
||||
|
||||
func TestRedisServiceGetIdentityCacheStatusReadsStateAndCountsCacheKeys(t *testing.T) {
|
||||
now := time.Date(2026, 6, 9, 3, 20, 0, 0, time.UTC)
|
||||
state, err := json.Marshal(identityMirrorStateStore{
|
||||
Status: "ready",
|
||||
LastRefreshedAt: &now,
|
||||
ObservedCount: 42,
|
||||
UpdatedAt: &now,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
stub := &redisCommandStub{
|
||||
stateValue: string(state),
|
||||
scans: map[string][]string{
|
||||
"identity:mirror:*": {"identity:mirror:state", "identity:mirror:user:1"},
|
||||
"identity:index:*": {"identity:index:email:a", "identity:mirror:user:1"},
|
||||
},
|
||||
}
|
||||
service := newStubbedRedisService(stub)
|
||||
|
||||
status, err := service.GetIdentityCacheStatus(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "ready", status.Status)
|
||||
require.True(t, status.RedisReady)
|
||||
require.Equal(t, int64(42), status.ObservedCount)
|
||||
require.Equal(t, int64(3), status.KeyCount)
|
||||
require.Equal(t, &now, status.LastRefreshedAt)
|
||||
require.Equal(t, &now, status.UpdatedAt)
|
||||
}
|
||||
|
||||
func TestRedisServiceFlushIdentityCacheDeletesOnlyIdentityMirrorAndIndexKeys(t *testing.T) {
|
||||
stub := &redisCommandStub{
|
||||
scans: map[string][]string{
|
||||
"identity:mirror:*": {"identity:mirror:state", "identity:mirror:user:1"},
|
||||
"identity:index:*": {"identity:index:email:a", "identity:mirror:user:1"},
|
||||
},
|
||||
}
|
||||
service := newStubbedRedisService(stub)
|
||||
|
||||
result, err := service.FlushIdentityCache(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "success", result.Status)
|
||||
require.Equal(t, int64(3), result.FlushedKeys)
|
||||
require.ElementsMatch(t, []string{
|
||||
"identity:mirror:state",
|
||||
"identity:mirror:user:1",
|
||||
"identity:index:email:a",
|
||||
}, stub.deleted)
|
||||
}
|
||||
|
||||
func TestRedisServiceGetIdentityCacheStatusReturnsUnavailableWithoutClient(t *testing.T) {
|
||||
status, err := (*RedisService)(nil).GetIdentityCacheStatus(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "unavailable", status.Status)
|
||||
require.False(t, status.RedisReady)
|
||||
require.NotEmpty(t, status.LastError)
|
||||
}
|
||||
|
||||
func TestRedisServiceFlushIdentityCacheFailsWithoutClient(t *testing.T) {
|
||||
_, err := (*RedisService)(nil).FlushIdentityCache(context.Background())
|
||||
|
||||
require.ErrorIs(t, err, os.ErrInvalid)
|
||||
}
|
||||
215
baron-sso/backend/internal/service/relying_party_service.go
Normal file
215
baron-sso/backend/internal/service/relying_party_service.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type RelyingPartyService interface {
|
||||
Create(ctx context.Context, tenantID string, client domain.HydraClient) (*domain.RelyingParty, error)
|
||||
Get(ctx context.Context, clientID string) (*domain.RelyingParty, *domain.HydraClient, error)
|
||||
List(ctx context.Context, tenantID string) ([]domain.RelyingParty, error)
|
||||
ListAll(ctx context.Context) ([]domain.RelyingParty, error)
|
||||
ListByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.RelyingParty, error)
|
||||
Update(ctx context.Context, clientID string, client domain.HydraClient) (*domain.RelyingParty, error)
|
||||
Delete(ctx context.Context, clientID string) error
|
||||
}
|
||||
|
||||
type relyingPartyService struct {
|
||||
hydraService *HydraAdminService
|
||||
ketoService KetoService
|
||||
outboxRepo repository.KetoOutboxRepository
|
||||
}
|
||||
|
||||
var defaultRelyingPartyOperatorRelations = []string{
|
||||
"admins",
|
||||
"creator",
|
||||
"config_editor",
|
||||
"secret_viewer",
|
||||
"secret_rotator",
|
||||
"jwks_viewer",
|
||||
"jwks_operator",
|
||||
"consent_viewer",
|
||||
"consent_revoker",
|
||||
"relationship_viewer",
|
||||
"audit_viewer",
|
||||
"status_operator",
|
||||
}
|
||||
|
||||
func NewRelyingPartyService(
|
||||
hydraService *HydraAdminService,
|
||||
ketoService KetoService,
|
||||
outboxRepo repository.KetoOutboxRepository,
|
||||
) RelyingPartyService {
|
||||
return &relyingPartyService{
|
||||
hydraService: hydraService,
|
||||
ketoService: ketoService,
|
||||
outboxRepo: outboxRepo,
|
||||
}
|
||||
}
|
||||
|
||||
func extractRelyingPartyCreatorSubject(client *domain.HydraClient) string {
|
||||
if client == nil || client.Metadata == nil {
|
||||
return ""
|
||||
}
|
||||
raw, _ := client.Metadata["user_id"].(string)
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
return "User:" + raw
|
||||
}
|
||||
|
||||
func (s *relyingPartyService) enqueueRelyingPartyTuple(ctx context.Context, action, object, relation, subject string) {
|
||||
if s.outboxRepo == nil || strings.TrimSpace(object) == "" || strings.TrimSpace(relation) == "" || strings.TrimSpace(subject) == "" {
|
||||
return
|
||||
}
|
||||
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "RelyingParty",
|
||||
Object: object,
|
||||
Relation: relation,
|
||||
Subject: subject,
|
||||
Action: action,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *relyingPartyService) enqueueDefaultRelyingPartyRelations(ctx context.Context, action string, client *domain.HydraClient, tenantID string) {
|
||||
if client == nil {
|
||||
return
|
||||
}
|
||||
|
||||
tenantID = strings.TrimSpace(tenantID)
|
||||
if tenantID != "" {
|
||||
s.enqueueRelyingPartyTuple(ctx, action, client.ClientID, "parents", "Tenant:"+tenantID)
|
||||
}
|
||||
|
||||
creatorSubject := extractRelyingPartyCreatorSubject(client)
|
||||
if creatorSubject == "" {
|
||||
return
|
||||
}
|
||||
|
||||
for _, relation := range defaultRelyingPartyOperatorRelations {
|
||||
s.enqueueRelyingPartyTuple(ctx, action, client.ClientID, relation, creatorSubject)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *relyingPartyService) Create(ctx context.Context, tenantID string, client domain.HydraClient) (*domain.RelyingParty, error) {
|
||||
// 1. Create Client in Hydra
|
||||
if client.Metadata == nil {
|
||||
client.Metadata = make(map[string]any)
|
||||
}
|
||||
client.Metadata["tenant_id"] = tenantID
|
||||
|
||||
createdClient, err := s.hydraService.CreateClient(ctx, client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create hydra client: %w", err)
|
||||
}
|
||||
|
||||
// 2. Create default relations in Keto via Outbox.
|
||||
s.enqueueDefaultRelyingPartyRelations(ctx, domain.KetoOutboxActionCreate, createdClient, tenantID)
|
||||
|
||||
return s.mapHydraToDomain(createdClient), nil
|
||||
}
|
||||
|
||||
func (s *relyingPartyService) Get(ctx context.Context, clientID string) (*domain.RelyingParty, *domain.HydraClient, error) {
|
||||
hydraClient, err := s.hydraService.GetClient(ctx, clientID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return s.mapHydraToDomain(hydraClient), hydraClient, nil
|
||||
}
|
||||
|
||||
func (s *relyingPartyService) List(ctx context.Context, tenantID string) ([]domain.RelyingParty, error) {
|
||||
// 1. Fetch ClientIDs from Keto
|
||||
// Relation tuple: RelyingParty:cid # parents @ Tenant:tid
|
||||
tuples, err := s.ketoService.ListRelations(ctx, "RelyingParty", "", "parents", "Tenant:"+tenantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rps []domain.RelyingParty
|
||||
for _, t := range tuples {
|
||||
clientID := t.Object
|
||||
client, err := s.hydraService.GetClient(ctx, clientID)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to fetch relying party from hydra", "client_id", clientID, "error", err)
|
||||
continue
|
||||
}
|
||||
if rp := s.mapHydraToDomain(client); rp != nil {
|
||||
rps = append(rps, *rp)
|
||||
}
|
||||
}
|
||||
|
||||
return rps, nil
|
||||
}
|
||||
|
||||
func (s *relyingPartyService) ListAll(ctx context.Context) ([]domain.RelyingParty, error) {
|
||||
return nil, fmt.Errorf("ListAll not implemented in SSOT mode yet")
|
||||
}
|
||||
|
||||
func (s *relyingPartyService) ListByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.RelyingParty, error) {
|
||||
var allRps []domain.RelyingParty
|
||||
for _, tid := range tenantIDs {
|
||||
rps, err := s.List(ctx, tid)
|
||||
if err == nil {
|
||||
allRps = append(allRps, rps...)
|
||||
}
|
||||
}
|
||||
return allRps, nil
|
||||
}
|
||||
|
||||
func (s *relyingPartyService) Update(ctx context.Context, clientID string, client domain.HydraClient) (*domain.RelyingParty, error) {
|
||||
updatedClient, err := s.hydraService.UpdateClient(ctx, clientID, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.mapHydraToDomain(updatedClient), nil
|
||||
}
|
||||
|
||||
func (s *relyingPartyService) Delete(ctx context.Context, clientID string) error {
|
||||
// 1. Get client to find tenantID (for Keto cleanup)
|
||||
client, err := s.hydraService.GetClient(ctx, clientID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tenantID := ""
|
||||
if client.Metadata != nil {
|
||||
if tid, ok := client.Metadata["tenant_id"].(string); ok {
|
||||
tenantID = tid
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Delete from Hydra
|
||||
if err := s.hydraService.DeleteClient(ctx, clientID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 3. Delete default relations from Keto via Outbox.
|
||||
s.enqueueDefaultRelyingPartyRelations(ctx, domain.KetoOutboxActionDelete, client, tenantID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *relyingPartyService) mapHydraToDomain(client *domain.HydraClient) *domain.RelyingParty {
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
rp := &domain.RelyingParty{
|
||||
ClientID: client.ClientID,
|
||||
Name: client.ClientName,
|
||||
}
|
||||
if client.Metadata != nil {
|
||||
if tid, ok := client.Metadata["tenant_id"].(string); ok {
|
||||
rp.TenantID = tid
|
||||
}
|
||||
if desc, ok := client.Metadata["description"].(string); ok {
|
||||
rp.Description = desc
|
||||
}
|
||||
}
|
||||
return rp
|
||||
}
|
||||
217
baron-sso/backend/internal/service/relying_party_service_test.go
Normal file
217
baron-sso/backend/internal/service/relying_party_service_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
/*
|
||||
이 테스트 파일은 RelyingPartyService의 기능을 검증하기 위한 유닛 테스트입니다.
|
||||
RelyingPartyService는 HydraAdminService, KetoService와 협력하므로
|
||||
각 의존성을 모킹(Mocking)하여 통합 로직을 검증합니다.
|
||||
|
||||
주요 테스트 항목:
|
||||
1. Create: Hydra 클라이언트 생성 -> Keto 권한 설정
|
||||
2. Get: Hydra에서 정보 조회
|
||||
3. Update: Hydra 업데이트
|
||||
4. Delete: Hydra 삭제 + Keto 권한 정리
|
||||
*/
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// --- Test Helpers ---
|
||||
|
||||
type hydraRoundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f hydraRoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func mockHydraClient(handler http.Handler) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: hydraRoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
return rec.Result(), nil
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
func TestRelyingPartyService_Create_Success(t *testing.T) {
|
||||
mockKeto := new(MockKetoServiceShared)
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
|
||||
tenantID := "tenant-1"
|
||||
inputClient := domain.HydraClient{
|
||||
ClientName: "Test App",
|
||||
Metadata: map[string]any{
|
||||
"user_id": "creator-1",
|
||||
},
|
||||
}
|
||||
|
||||
// Hydra Mock
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/clients") {
|
||||
var req domain.HydraClient
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// 메타데이터 tenant_id 주입 확인
|
||||
if req.Metadata["tenant_id"] != tenantID {
|
||||
t.Errorf("expected tenant_id in metadata")
|
||||
}
|
||||
|
||||
req.ClientID = "generated-client-id"
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
_ = json.NewEncoder(w).Encode(req)
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
hydraSvc := &HydraAdminService{
|
||||
AdminURL: "http://hydra:4445",
|
||||
HTTPClient: mockHydraClient(hydraHandler),
|
||||
}
|
||||
|
||||
// Keto sync via Outbox using 'parents' relation
|
||||
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "RelyingParty" && e.Object == "generated-client-id" && e.Relation == "parents" && e.Subject == "Tenant:"+tenantID
|
||||
})).Return(nil)
|
||||
for _, relation := range defaultRelyingPartyOperatorRelations {
|
||||
rel := relation
|
||||
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "RelyingParty" && e.Object == "generated-client-id" && e.Relation == rel && e.Subject == "User:creator-1"
|
||||
})).Return(nil)
|
||||
}
|
||||
|
||||
svc := NewRelyingPartyService(hydraSvc, mockKeto, mockOutbox)
|
||||
rp, err := svc.Create(context.Background(), tenantID, inputClient)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "generated-client-id", rp.ClientID)
|
||||
assert.Equal(t, tenantID, rp.TenantID)
|
||||
|
||||
mockOutbox.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestRelyingPartyService_Create_HydraFail(t *testing.T) {
|
||||
mockKeto := new(MockKetoServiceShared)
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
})
|
||||
hydraSvc := &HydraAdminService{
|
||||
AdminURL: "http://hydra:4445",
|
||||
HTTPClient: mockHydraClient(hydraHandler),
|
||||
}
|
||||
|
||||
svc := NewRelyingPartyService(hydraSvc, mockKeto, mockOutbox)
|
||||
_, err := svc.Create(context.Background(), "tenant-1", domain.HydraClient{})
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRelyingPartyService_Get_Success(t *testing.T) {
|
||||
mockKeto := new(MockKetoServiceShared)
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
clientID := "client-123"
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(domain.HydraClient{
|
||||
ClientID: clientID,
|
||||
ClientName: "Hydra Name",
|
||||
Metadata: map[string]any{
|
||||
"tenant_id": "tenant-1",
|
||||
},
|
||||
})
|
||||
})
|
||||
hydraSvc := &HydraAdminService{
|
||||
AdminURL: "http://hydra:4445",
|
||||
HTTPClient: mockHydraClient(hydraHandler),
|
||||
}
|
||||
|
||||
svc := NewRelyingPartyService(hydraSvc, mockKeto, mockOutbox)
|
||||
rp, hc, err := svc.Get(context.Background(), clientID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Hydra Name", rp.Name)
|
||||
assert.Equal(t, "Hydra Name", hc.ClientName)
|
||||
}
|
||||
|
||||
func TestRelyingPartyService_Update_Success(t *testing.T) {
|
||||
mockKeto := new(MockKetoServiceShared)
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
clientID := "client-123"
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodPut {
|
||||
var req domain.HydraClient
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
_ = json.NewEncoder(w).Encode(req)
|
||||
return
|
||||
}
|
||||
})
|
||||
hydraSvc := &HydraAdminService{
|
||||
AdminURL: "http://hydra:4445",
|
||||
HTTPClient: mockHydraClient(hydraHandler),
|
||||
}
|
||||
|
||||
svc := NewRelyingPartyService(hydraSvc, mockKeto, mockOutbox)
|
||||
|
||||
updateReq := domain.HydraClient{ClientName: "New Name"}
|
||||
rp, err := svc.Update(context.Background(), clientID, updateReq)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "New Name", rp.Name)
|
||||
}
|
||||
|
||||
func TestRelyingPartyService_Delete_Success(t *testing.T) {
|
||||
mockKeto := new(MockKetoServiceShared)
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
clientID := "client-123"
|
||||
tenantID := "tenant-1"
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodGet && strings.Contains(r.URL.Path, clientID) {
|
||||
_ = json.NewEncoder(w).Encode(domain.HydraClient{
|
||||
ClientID: clientID,
|
||||
Metadata: map[string]any{
|
||||
"tenant_id": tenantID,
|
||||
"user_id": "creator-1",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
if r.Method == http.MethodDelete && strings.Contains(r.URL.Path, clientID) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
hydraSvc := &HydraAdminService{
|
||||
AdminURL: "http://hydra:4445",
|
||||
HTTPClient: mockHydraClient(hydraHandler),
|
||||
}
|
||||
|
||||
// Delete relation via Outbox using 'parents'
|
||||
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "RelyingParty" && e.Object == clientID && e.Relation == "parents" && e.Subject == "Tenant:"+tenantID
|
||||
})).Return(nil)
|
||||
for _, relation := range defaultRelyingPartyOperatorRelations {
|
||||
rel := relation
|
||||
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "RelyingParty" && e.Object == clientID && e.Relation == rel && e.Subject == "User:creator-1"
|
||||
})).Return(nil)
|
||||
}
|
||||
|
||||
svc := NewRelyingPartyService(hydraSvc, mockKeto, mockOutbox)
|
||||
err := svc.Delete(context.Background(), clientID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockOutbox.AssertExpectations(t)
|
||||
}
|
||||
67
baron-sso/backend/internal/service/rp_usage_event_emitter.go
Normal file
67
baron-sso/backend/internal/service/rp_usage_event_emitter.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RPUsageEventEmitter struct {
|
||||
repo repository.RPUsageOutboxRepository
|
||||
}
|
||||
|
||||
func NewRPUsageEventEmitter(repo repository.RPUsageOutboxRepository) *RPUsageEventEmitter {
|
||||
return &RPUsageEventEmitter{repo: repo}
|
||||
}
|
||||
|
||||
func (e *RPUsageEventEmitter) EmitRPUsageEvent(ctx context.Context, event domain.RPUsageEvent) error {
|
||||
if e == nil || e.repo == nil {
|
||||
return nil
|
||||
}
|
||||
event.EventType = strings.TrimSpace(event.EventType)
|
||||
event.Subject = strings.TrimSpace(event.Subject)
|
||||
event.ClientID = strings.TrimSpace(event.ClientID)
|
||||
event.Source = strings.TrimSpace(event.Source)
|
||||
event.CorrelationID = strings.TrimSpace(event.CorrelationID)
|
||||
if event.EventType == "" {
|
||||
return fmt.Errorf("rp usage event type is required")
|
||||
}
|
||||
if event.Subject == "" {
|
||||
return fmt.Errorf("rp usage subject is required")
|
||||
}
|
||||
if event.ClientID == "" {
|
||||
return fmt.Errorf("rp usage client_id is required")
|
||||
}
|
||||
if event.Source == "" {
|
||||
event.Source = "backend"
|
||||
}
|
||||
if event.OccurredAt.IsZero() {
|
||||
event.OccurredAt = time.Now()
|
||||
}
|
||||
if event.DedupeKey == "" {
|
||||
event.DedupeKey = buildRPUsageDedupeKey(event)
|
||||
}
|
||||
if event.Payload == nil {
|
||||
event.Payload = domain.JSONMap{}
|
||||
}
|
||||
return e.repo.Create(ctx, &event)
|
||||
}
|
||||
|
||||
func buildRPUsageDedupeKey(event domain.RPUsageEvent) string {
|
||||
raw := strings.Join([]string{
|
||||
event.EventType,
|
||||
event.Subject,
|
||||
event.ClientID,
|
||||
event.SessionID,
|
||||
event.Source,
|
||||
event.CorrelationID,
|
||||
event.OccurredAt.UTC().Format("2006-01-02T15:04:05.000Z"),
|
||||
}, "|")
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeRPUsageOutboxRepo struct {
|
||||
created []domain.RPUsageEvent
|
||||
ready []domain.RPUsageEvent
|
||||
processing []string
|
||||
processed []string
|
||||
failed []string
|
||||
createErr error
|
||||
projectErr error
|
||||
}
|
||||
|
||||
func (f *fakeRPUsageOutboxRepo) Create(ctx context.Context, event *domain.RPUsageEvent) error {
|
||||
if f.createErr != nil {
|
||||
return f.createErr
|
||||
}
|
||||
f.created = append(f.created, *event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeRPUsageOutboxRepo) ListReady(ctx context.Context, limit int) ([]domain.RPUsageEvent, error) {
|
||||
return f.ready, nil
|
||||
}
|
||||
|
||||
func (f *fakeRPUsageOutboxRepo) MarkProcessing(ctx context.Context, id string) error {
|
||||
f.processing = append(f.processing, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeRPUsageOutboxRepo) MarkProcessed(ctx context.Context, id string) error {
|
||||
f.processed = append(f.processed, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeRPUsageOutboxRepo) MarkFailed(ctx context.Context, id string, message string, nextAttemptAt time.Time) error {
|
||||
f.failed = append(f.failed, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeRPUsageProjectionRepo struct {
|
||||
created []domain.RPUsageEvent
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeRPUsageProjectionRepo) CreateRPUsageEvent(ctx context.Context, event domain.RPUsageEvent) error {
|
||||
if f.err != nil {
|
||||
return f.err
|
||||
}
|
||||
f.created = append(f.created, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRPUsageEventEmitterRequiresCanonicalFields(t *testing.T) {
|
||||
repo := &fakeRPUsageOutboxRepo{}
|
||||
emitter := NewRPUsageEventEmitter(repo)
|
||||
|
||||
err := emitter.EmitRPUsageEvent(context.Background(), domain.RPUsageEvent{
|
||||
EventType: domain.RPUsageEventTypeAuthorizationGranted,
|
||||
ClientID: "client-app",
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
require.Empty(t, repo.created)
|
||||
}
|
||||
|
||||
func TestRPUsageEventEmitterCreatesPendingOutboxEvent(t *testing.T) {
|
||||
repo := &fakeRPUsageOutboxRepo{}
|
||||
emitter := NewRPUsageEventEmitter(repo)
|
||||
|
||||
err := emitter.EmitRPUsageEvent(context.Background(), domain.RPUsageEvent{
|
||||
EventType: domain.RPUsageEventTypeAuthorizationGranted,
|
||||
Subject: "user-123",
|
||||
ClientID: "client-app",
|
||||
Source: "hydra_consent",
|
||||
CorrelationID: "challenge-1",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, repo.created, 1)
|
||||
require.NotEmpty(t, repo.created[0].DedupeKey)
|
||||
require.Equal(t, domain.RPUsageEventTypeAuthorizationGranted, repo.created[0].EventType)
|
||||
require.Equal(t, "hydra_consent", repo.created[0].Source)
|
||||
}
|
||||
|
||||
func TestRPUsageProjectorWorkerMarksProcessedAfterProjection(t *testing.T) {
|
||||
outbox := &fakeRPUsageOutboxRepo{
|
||||
ready: []domain.RPUsageEvent{{
|
||||
ID: "event-1",
|
||||
EventType: domain.RPUsageEventTypeAuthorizationGranted,
|
||||
Subject: "user-123",
|
||||
ClientID: "client-app",
|
||||
}},
|
||||
}
|
||||
projection := &fakeRPUsageProjectionRepo{}
|
||||
worker := NewRPUsageProjectorWorker(outbox, projection)
|
||||
|
||||
worker.processOnce(context.Background())
|
||||
|
||||
require.Equal(t, []string{"event-1"}, outbox.processing)
|
||||
require.Equal(t, []string{"event-1"}, outbox.processed)
|
||||
require.Empty(t, outbox.failed)
|
||||
require.Len(t, projection.created, 1)
|
||||
}
|
||||
|
||||
func TestRPUsageProjectorWorkerMarksFailedWhenProjectionFails(t *testing.T) {
|
||||
outbox := &fakeRPUsageOutboxRepo{
|
||||
ready: []domain.RPUsageEvent{{
|
||||
ID: "event-1",
|
||||
EventType: domain.RPUsageEventTypeAuthorizationGranted,
|
||||
Subject: "user-123",
|
||||
ClientID: "client-app",
|
||||
}},
|
||||
}
|
||||
projection := &fakeRPUsageProjectionRepo{err: errors.New("clickhouse unavailable")}
|
||||
worker := NewRPUsageProjectorWorker(outbox, projection)
|
||||
|
||||
worker.processOnce(context.Background())
|
||||
|
||||
require.Equal(t, []string{"event-1"}, outbox.processing)
|
||||
require.Empty(t, outbox.processed)
|
||||
require.Equal(t, []string{"event-1"}, outbox.failed)
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RPUsageProjectorWorker struct {
|
||||
outbox repository.RPUsageOutboxRepository
|
||||
projection domain.RPUsageProjectionRepository
|
||||
interval time.Duration
|
||||
batchSize int
|
||||
}
|
||||
|
||||
func NewRPUsageProjectorWorker(outbox repository.RPUsageOutboxRepository, projection domain.RPUsageProjectionRepository) *RPUsageProjectorWorker {
|
||||
return &RPUsageProjectorWorker{
|
||||
outbox: outbox,
|
||||
projection: projection,
|
||||
interval: 5 * time.Second,
|
||||
batchSize: 50,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *RPUsageProjectorWorker) Start(ctx context.Context) {
|
||||
if w == nil || w.outbox == nil || w.projection == nil {
|
||||
return
|
||||
}
|
||||
ticker := time.NewTicker(w.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
w.processOnce(ctx)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *RPUsageProjectorWorker) processOnce(ctx context.Context) {
|
||||
events, err := w.outbox.ListReady(ctx, w.batchSize)
|
||||
if err != nil {
|
||||
slog.Warn("failed to list rp usage outbox", "error", err)
|
||||
return
|
||||
}
|
||||
for _, event := range events {
|
||||
if err := w.outbox.MarkProcessing(ctx, event.ID); err != nil {
|
||||
slog.Warn("failed to mark rp usage event processing", "event_id", event.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
if err := w.projection.CreateRPUsageEvent(ctx, event); err != nil {
|
||||
nextAttempt := time.Now().Add(backoffDuration(event.RetryCount))
|
||||
_ = w.outbox.MarkFailed(ctx, event.ID, err.Error(), nextAttempt)
|
||||
slog.Warn("failed to project rp usage event", "event_id", event.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
if err := w.outbox.MarkProcessed(ctx, event.ID); err != nil {
|
||||
slog.Warn("failed to mark rp usage event processed", "event_id", event.ID, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func backoffDuration(retryCount int) time.Duration {
|
||||
if retryCount < 0 {
|
||||
retryCount = 0
|
||||
}
|
||||
delay := time.Duration(retryCount+1) * time.Minute
|
||||
if delay > 30*time.Minute {
|
||||
return 30 * time.Minute
|
||||
}
|
||||
return delay
|
||||
}
|
||||
78
baron-sso/backend/internal/service/ses_service.go
Normal file
78
baron-sso/backend/internal/service/ses_service.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/ses"
|
||||
"github.com/aws/aws-sdk-go-v2/service/ses/types"
|
||||
)
|
||||
|
||||
type SesServiceImpl struct {
|
||||
client *ses.Client
|
||||
sender string
|
||||
}
|
||||
|
||||
func NewEmailService() domain.EmailService {
|
||||
region := os.Getenv("AWS_REGION")
|
||||
accessKey := os.Getenv("AWS_ACCESS_KEY_ID")
|
||||
secretKey := os.Getenv("AWS_SECRET_ACCESS_KEY")
|
||||
sender := os.Getenv("AWS_SES_SENDER")
|
||||
|
||||
if region == "" || accessKey == "" || secretKey == "" {
|
||||
slog.Warn("[EmailService] AWS configuration missing, email service will not work")
|
||||
return nil
|
||||
}
|
||||
cfg, err := config.LoadDefaultConfig(context.TODO(),
|
||||
config.WithRegion(region),
|
||||
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")),
|
||||
)
|
||||
if err != nil {
|
||||
slog.Error("Failed to load AWS config", "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &SesServiceImpl{
|
||||
client: ses.NewFromConfig(cfg),
|
||||
sender: sender,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SesServiceImpl) SendEmail(to, subject, body string) error {
|
||||
if s == nil || s.client == nil {
|
||||
return fmt.Errorf("email service not initialized")
|
||||
}
|
||||
|
||||
input := &ses.SendEmailInput{
|
||||
Destination: &types.Destination{
|
||||
ToAddresses: []string{to},
|
||||
},
|
||||
Message: &types.Message{
|
||||
Body: &types.Body{
|
||||
Html: &types.Content{
|
||||
Charset: aws.String("UTF-8"),
|
||||
Data: aws.String(body),
|
||||
},
|
||||
},
|
||||
Subject: &types.Content{
|
||||
Charset: aws.String("UTF-8"),
|
||||
Data: aws.String(subject),
|
||||
},
|
||||
},
|
||||
Source: aws.String(s.sender),
|
||||
}
|
||||
|
||||
_, err := s.client.SendEmail(context.TODO(), input)
|
||||
if err != nil {
|
||||
slog.Error("[EmailService] Failed to send email", "to", to, "error", err)
|
||||
} else {
|
||||
slog.Info("[EmailService] Email sent successfully", "to", to)
|
||||
}
|
||||
return err
|
||||
}
|
||||
63
baron-sso/backend/internal/service/shared_link_service.go
Normal file
63
baron-sso/backend/internal/service/shared_link_service.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SharedLinkService interface {
|
||||
CreateLink(ctx context.Context, tenantID, name, description string, expiresAt *time.Time) (*domain.SharedLink, error)
|
||||
ValidateToken(ctx context.Context, token string) (*domain.SharedLink, error)
|
||||
GetLinksByTenant(ctx context.Context, tenantID string) ([]domain.SharedLink, error)
|
||||
DeactivateLink(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
type sharedLinkService struct {
|
||||
repo repository.SharedLinkRepository
|
||||
}
|
||||
|
||||
func NewSharedLinkService(repo repository.SharedLinkRepository) SharedLinkService {
|
||||
return &sharedLinkService{repo: repo}
|
||||
}
|
||||
|
||||
func (s *sharedLinkService) CreateLink(ctx context.Context, tenantID, name, description string, expiresAt *time.Time) (*domain.SharedLink, error) {
|
||||
link := &domain.SharedLink{
|
||||
TenantID: tenantID,
|
||||
Name: name,
|
||||
Description: description,
|
||||
ExpiresAt: expiresAt,
|
||||
IsActive: true,
|
||||
AccessLevel: "READ_ONLY",
|
||||
}
|
||||
|
||||
if err := s.repo.Create(ctx, link); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return link, nil
|
||||
}
|
||||
|
||||
func (s *sharedLinkService) ValidateToken(ctx context.Context, token string) (*domain.SharedLink, error) {
|
||||
link, err := s.repo.FindByToken(ctx, token)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid or expired share link")
|
||||
}
|
||||
|
||||
if !link.IsValid() {
|
||||
return nil, errors.New("share link has expired or is inactive")
|
||||
}
|
||||
|
||||
return link, nil
|
||||
}
|
||||
|
||||
func (s *sharedLinkService) GetLinksByTenant(ctx context.Context, tenantID string) ([]domain.SharedLink, error) {
|
||||
return s.repo.FindByTenantID(ctx, tenantID)
|
||||
}
|
||||
|
||||
func (s *sharedLinkService) DeactivateLink(ctx context.Context, id string) error {
|
||||
// 실제 삭제 대신 비활성화 처리 (soft-delete와 유사)
|
||||
// 하지만 여기서는 간단히 활성 플래그만 끔
|
||||
return s.repo.Delete(ctx, id) // 리포지토리의 Delete는 GORM의 DeletedAt을 사용하여 soft-delete함
|
||||
}
|
||||
134
baron-sso/backend/internal/service/sms_service.go
Normal file
134
baron-sso/backend/internal/service/sms_service.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const naverSMSMaxBytes = 90
|
||||
|
||||
type SmsServiceImpl struct {
|
||||
accessKey string
|
||||
secretKey string
|
||||
serviceID string
|
||||
senderPhone string
|
||||
}
|
||||
|
||||
func NewSmsService() domain.SmsService {
|
||||
// Sanitize sender phone number right after reading from env
|
||||
rawSenderPhone := os.Getenv("NAVER_SENDER_PHONE_NUMBER")
|
||||
sanitizedSenderPhone := strings.ReplaceAll(rawSenderPhone, "-", "")
|
||||
slog.Info("[서비스 초기화] 발신자 번호 처리", "원본", rawSenderPhone, "정제후", sanitizedSenderPhone)
|
||||
|
||||
return &SmsServiceImpl{
|
||||
accessKey: os.Getenv("NAVER_CLOUD_ACCESS_KEY"),
|
||||
secretKey: os.Getenv("NAVER_CLOUD_SECRET_KEY"),
|
||||
serviceID: os.Getenv("NAVER_CLOUD_SERVICE_ID"),
|
||||
senderPhone: sanitizedSenderPhone,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SmsServiceImpl) SendSms(to, content string) error {
|
||||
timestamp := strconv.FormatInt(time.Now().UnixNano()/int64(time.Millisecond), 10)
|
||||
apiURL := fmt.Sprintf("https://sens.apigw.ntruss.com/sms/v2/services/%s/messages", s.serviceID)
|
||||
slog.Info("[SmsService] Requesting SENS API URL", "url", apiURL)
|
||||
|
||||
// Naver SENS API requires phone number without '+'
|
||||
sanitizedTo := strings.Replace(to, "+", "", 1)
|
||||
|
||||
reqBody := buildNaverSmsRequest(s.senderPhone, sanitizedTo, content)
|
||||
if reqBody.Type == "LMS" {
|
||||
slog.Info("[SmsService] Upgrading message type to LMS due to content length",
|
||||
"bytes", len([]byte(content)),
|
||||
)
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling request body: %w", err)
|
||||
}
|
||||
|
||||
signature, err := s.makeSignature("POST", fmt.Sprintf("/sms/v2/services/%s/messages", s.serviceID), timestamp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating signature: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-ncp-apigw-timestamp", timestamp)
|
||||
req.Header.Set("x-ncp-iam-access-key", s.accessKey)
|
||||
req.Header.Set("x-ncp-apigw-signature-v2", signature)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
slog.Error("[SmsService] error response from naver cloud sms api", "body", string(respBody))
|
||||
return fmt.Errorf("error sending sms: status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
slog.Info("[SmsService] sms sent successfully", "body", string(respBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildNaverSmsRequest(senderPhone, sanitizedTo, content string) domain.NaverSmsRequest {
|
||||
requestType := "SMS"
|
||||
subject := ""
|
||||
if len([]byte(content)) > naverSMSMaxBytes {
|
||||
requestType = "LMS"
|
||||
subject = "[Baron 로그인]"
|
||||
}
|
||||
|
||||
return domain.NaverSmsRequest{
|
||||
Type: requestType,
|
||||
ContentType: "COMM",
|
||||
CountryCode: "82",
|
||||
From: senderPhone,
|
||||
Subject: subject,
|
||||
Content: content,
|
||||
Messages: []domain.SmsMessage{
|
||||
{
|
||||
To: sanitizedTo,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SmsServiceImpl) makeSignature(method, url, timestamp string) (string, error) {
|
||||
space := " "
|
||||
newLine := "\n"
|
||||
message := method + space + url + newLine + timestamp + newLine + s.accessKey
|
||||
|
||||
h := hmac.New(sha256.New, []byte(s.secretKey))
|
||||
_, err := h.Write([]byte(message))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(h.Sum(nil)), nil
|
||||
}
|
||||
26
baron-sso/backend/internal/service/sms_service_test.go
Normal file
26
baron-sso/backend/internal/service/sms_service_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBuildNaverSmsRequest_UsesSMSForShortContent(t *testing.T) {
|
||||
req := buildNaverSmsRequest("0262857755", "821012345678", "123456")
|
||||
|
||||
if req.Type != "SMS" {
|
||||
t.Fatalf("expected SMS, got %s", req.Type)
|
||||
}
|
||||
if req.Subject != "" {
|
||||
t.Fatalf("expected empty subject for SMS, got %q", req.Subject)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildNaverSmsRequest_UsesLMSForLongContent(t *testing.T) {
|
||||
content := "[Baron 로그인] 비밀번호 재설정 링크: http://sso.example.test/api/v1/auth/password/reset/v/1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
req := buildNaverSmsRequest("0262857755", "821012345678", content)
|
||||
|
||||
if req.Type != "LMS" {
|
||||
t.Fatalf("expected LMS, got %s", req.Type)
|
||||
}
|
||||
if req.Subject == "" {
|
||||
t.Fatal("expected LMS subject to be set")
|
||||
}
|
||||
}
|
||||
394
baron-sso/backend/internal/service/tenant_service.go
Normal file
394
baron-sso/backend/internal/service/tenant_service.go
Normal file
@@ -0,0 +1,394 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"baron-sso-backend/internal/utils"
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type TenantService interface {
|
||||
RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error)
|
||||
RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error)
|
||||
GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error)
|
||||
GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error)
|
||||
GetTenant(ctx context.Context, id string) (*domain.Tenant, error)
|
||||
ListTenants(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error)
|
||||
ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error)
|
||||
ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error)
|
||||
IsDomainAllowed(ctx context.Context, domainName string) (bool, error)
|
||||
ApproveTenant(ctx context.Context, id string) error
|
||||
ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error)
|
||||
SetKetoService(keto KetoService)
|
||||
DeleteTenantsBulk(ctx context.Context, ids []string) error
|
||||
}
|
||||
|
||||
type tenantService struct {
|
||||
repo repository.TenantRepository
|
||||
userRepo repository.UserRepository
|
||||
userGroupRepo repository.UserGroupRepository
|
||||
keto KetoService
|
||||
outboxRepo repository.KetoOutboxRepository
|
||||
}
|
||||
|
||||
func NewTenantService(repo repository.TenantRepository, userRepo repository.UserRepository, userGroupRepo repository.UserGroupRepository, outboxRepo repository.KetoOutboxRepository) TenantService {
|
||||
return &tenantService{
|
||||
repo: repo,
|
||||
userRepo: userRepo,
|
||||
userGroupRepo: userGroupRepo,
|
||||
outboxRepo: outboxRepo,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *tenantService) SetKetoService(keto KetoService) {
|
||||
s.keto = keto
|
||||
}
|
||||
|
||||
func (s *tenantService) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
|
||||
return s.repo.FindByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *tenantService) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
||||
if s.keto == nil {
|
||||
return nil, errors.New("keto service not initialized")
|
||||
}
|
||||
|
||||
allIDs, err := s.keto.ListObjects(ctx, "Tenant", "manage", "User:"+userID)
|
||||
if err != nil {
|
||||
slog.Error("Failed to list manageable tenants from Keto", "userID", userID, "error", err)
|
||||
return []domain.Tenant{}, nil
|
||||
}
|
||||
|
||||
if len(allIDs) == 0 {
|
||||
directAdminIDs, _ := s.keto.ListObjects(ctx, "Tenant", "admins", "User:"+userID)
|
||||
directOwnerIDs, _ := s.keto.ListObjects(ctx, "Tenant", "owners", "User:"+userID)
|
||||
|
||||
idMap := make(map[string]bool)
|
||||
for _, id := range directAdminIDs {
|
||||
idMap[id] = true
|
||||
}
|
||||
for _, id := range directOwnerIDs {
|
||||
idMap[id] = true
|
||||
}
|
||||
|
||||
allIDs = make([]string, 0, len(idMap))
|
||||
for id := range idMap {
|
||||
allIDs = append(allIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allIDs) == 0 {
|
||||
return []domain.Tenant{}, nil
|
||||
}
|
||||
|
||||
return s.repo.FindByIDs(ctx, allIDs)
|
||||
}
|
||||
|
||||
func (s *tenantService) ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
||||
if s.keto == nil {
|
||||
return nil, errors.New("keto service not initialized")
|
||||
}
|
||||
|
||||
memberIDs, err := s.keto.ListObjects(ctx, "Tenant", "members", "User:"+userID)
|
||||
if err != nil {
|
||||
slog.Error("Failed to list joined tenants from Keto", "userID", userID, "error", err)
|
||||
return []domain.Tenant{}, nil
|
||||
}
|
||||
|
||||
ownerIDs, _ := s.keto.ListObjects(ctx, "Tenant", "owners", "User:"+userID)
|
||||
adminIDs, _ := s.keto.ListObjects(ctx, "Tenant", "admins", "User:"+userID)
|
||||
|
||||
idMap := make(map[string]bool)
|
||||
for _, id := range memberIDs {
|
||||
idMap[id] = true
|
||||
}
|
||||
for _, id := range ownerIDs {
|
||||
idMap[id] = true
|
||||
}
|
||||
for _, id := range adminIDs {
|
||||
idMap[id] = true
|
||||
}
|
||||
|
||||
allIDs := make([]string, 0, len(idMap))
|
||||
for id := range idMap {
|
||||
allIDs = append(allIDs, id)
|
||||
}
|
||||
|
||||
if len(allIDs) == 0 {
|
||||
return []domain.Tenant{}, nil
|
||||
}
|
||||
|
||||
return s.repo.FindByIDs(ctx, allIDs)
|
||||
}
|
||||
|
||||
func (s *tenantService) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
|
||||
if ok, msg := utils.ValidateSlug(slug); !ok {
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
existing, err := s.repo.FindBySlug(ctx, slug)
|
||||
if err == nil && existing != nil {
|
||||
return nil, errors.New("tenant slug already exists")
|
||||
}
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tenant := &domain.Tenant{
|
||||
Type: tenantType,
|
||||
Name: name,
|
||||
Slug: slug,
|
||||
Description: description,
|
||||
Status: domain.TenantStatusActive,
|
||||
ParentID: parentID,
|
||||
}
|
||||
|
||||
if err := s.repo.Create(ctx, tenant); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.outboxRepo != nil {
|
||||
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: tenant.ID,
|
||||
Relation: "admins",
|
||||
Subject: "System:global#super_admins",
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
|
||||
if tenant.ParentID != nil {
|
||||
if err := s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: tenant.ID,
|
||||
Relation: "parents",
|
||||
Subject: "Tenant:" + *tenant.ParentID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
}); err != nil {
|
||||
slog.Error("Failed to create outbox entry for tenant hierarchy", "tenant", tenant.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if creatorID != "" {
|
||||
slog.Info("Creating outbox entries for tenant creator", "tenant", tenant.ID, "creator", creatorID)
|
||||
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: tenant.ID,
|
||||
Relation: "owners",
|
||||
Subject: "User:" + creatorID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: tenant.ID,
|
||||
Relation: "admins",
|
||||
Subject: "User:" + creatorID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: tenant.ID,
|
||||
Relation: "members",
|
||||
Subject: "User:" + creatorID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, d := range domains {
|
||||
if err := s.repo.AddDomain(ctx, tenant.ID, d, true); err != nil {
|
||||
slog.Error("Failed to add domain to tenant", "tenant", slug, "domain", d, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return s.repo.FindBySlug(ctx, slug)
|
||||
}
|
||||
|
||||
func (s *tenantService) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
|
||||
if ok, msg := utils.ValidateSlug(slug); !ok {
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
parts := strings.Split(adminEmail, "@")
|
||||
if len(parts) != 2 || parts[1] != domainName {
|
||||
return nil, errors.New("admin email domain must match the tenant domain")
|
||||
}
|
||||
|
||||
tenant := &domain.Tenant{
|
||||
Type: domain.TenantTypeCompany,
|
||||
Name: name,
|
||||
Slug: slug,
|
||||
Description: description,
|
||||
Status: domain.TenantStatusPending,
|
||||
Config: domain.JSONMap{"adminEmail": adminEmail},
|
||||
}
|
||||
|
||||
if err := s.repo.Create(ctx, tenant); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.outboxRepo != nil {
|
||||
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: tenant.ID,
|
||||
Relation: "admins",
|
||||
Subject: "System:global#super_admins",
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
}
|
||||
|
||||
if err := s.repo.AddDomain(ctx, tenant.ID, domainName, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tenant, nil
|
||||
}
|
||||
|
||||
func (s *tenantService) ApproveTenant(ctx context.Context, id string) error {
|
||||
tenant, err := s.repo.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tenant.Status = domain.TenantStatusActive
|
||||
if err := s.repo.Update(ctx, tenant); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.outboxRepo != nil {
|
||||
if adminEmail, ok := tenant.Config["adminEmail"].(string); ok && adminEmail != "" {
|
||||
slog.Info("Queueing tenant admin/owner sync to Keto", "tenant", tenant.Slug, "adminEmail", adminEmail)
|
||||
if s.userRepo != nil {
|
||||
user, err := s.userRepo.FindByEmail(ctx, adminEmail)
|
||||
if err == nil && user != nil {
|
||||
slog.Info("Queueing tenant ownership/membership sync to Keto", "tenant", tenant.Slug, "userID", user.ID)
|
||||
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: tenant.ID,
|
||||
Relation: "owners",
|
||||
Subject: "User:" + user.ID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: tenant.ID,
|
||||
Relation: "admins",
|
||||
Subject: "User:" + user.ID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: tenant.ID,
|
||||
Relation: "members",
|
||||
Subject: "User:" + user.ID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
} else {
|
||||
slog.Info("Tenant admin user not found in local DB, will need manual sync or sync on signup", "email", adminEmail)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *tenantService) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) {
|
||||
tenant, err := s.repo.FindByDomain(ctx, emailDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tenant.Status != domain.TenantStatusActive {
|
||||
return nil, errors.New("tenant is not active")
|
||||
}
|
||||
|
||||
return tenant, nil
|
||||
}
|
||||
|
||||
func (s *tenantService) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
|
||||
return s.repo.FindBySlug(ctx, slug)
|
||||
}
|
||||
|
||||
func (s *tenantService) ListTenants(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
|
||||
return s.repo.List(ctx, limit, offset, parentID, search)
|
||||
}
|
||||
|
||||
func (s *tenantService) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
|
||||
tenant, err := s.repo.FindByDomain(ctx, domainName)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return tenant != nil && tenant.Status == domain.TenantStatusActive, nil
|
||||
}
|
||||
|
||||
func (s *tenantService) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
||||
groups, err := s.repo.ListByType(ctx, domain.TenantTypeCompanyGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, g := range groups {
|
||||
rawConfig, ok := g.Config["autoProvisioning"].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
enabled, _ := rawConfig["enabled"].(bool)
|
||||
if !enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
mapping, ok := rawConfig["mappingRules"].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
rule, ok := mapping[domainName].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
slug, _ := rule["slug"].(string)
|
||||
name, _ := rule["name"].(string)
|
||||
|
||||
if slug == "" || name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Info("[Provisioning] Found rule for domain, creating sub-tenant", "domain", domainName, "parent", g.Slug, "newTenant", slug)
|
||||
return s.RegisterTenant(ctx, name, slug, domain.TenantTypeCompany, "Automatically provisioned via group policy", []string{domainName}, &g.ID, "")
|
||||
}
|
||||
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (s *tenantService) DeleteTenantsBulk(ctx context.Context, ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.repo.DeleteBulk(ctx, ids); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.outboxRepo != nil {
|
||||
for _, id := range ids {
|
||||
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: id,
|
||||
Relation: "parents",
|
||||
Action: domain.KetoOutboxActionDelete,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
114
baron-sso/backend/internal/service/tenant_service_edge_test.go
Normal file
114
baron-sso/backend/internal/service/tenant_service_edge_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestTenantService_RegisterTenant_DuplicateSlug(t *testing.T) {
|
||||
mockRepo := new(MockTenantRepoForSvc)
|
||||
svc := NewTenantService(mockRepo, nil, nil, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
slug := "duplicate-slug"
|
||||
|
||||
// Mock: slug already exists
|
||||
mockRepo.On("FindBySlug", ctx, slug).Return(&domain.Tenant{ID: "existing-id", Slug: slug}, nil)
|
||||
|
||||
tenant, err := svc.RegisterTenant(ctx, "New Name", slug, domain.TenantTypeCompany, "", nil, nil, "")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already exists")
|
||||
assert.Nil(t, tenant)
|
||||
}
|
||||
|
||||
func TestTenantService_RegisterTenant_InvalidSlug(t *testing.T) {
|
||||
svc := NewTenantService(nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
// Case 1: Too short
|
||||
_, err := svc.RegisterTenant(ctx, "Name", "a", domain.TenantTypeCompany, "", nil, nil, "")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Case 2: Invalid characters
|
||||
_, err = svc.RegisterTenant(ctx, "Name", "Invalid Slug!", domain.TenantTypeCompany, "", nil, nil, "")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestTenantService_RequestRegistration_EmailMismatch(t *testing.T) {
|
||||
svc := NewTenantService(nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
// admin email domain (gmail.com) != tenant domain (company.com)
|
||||
tenant, err := svc.RequestRegistration(ctx, "Name", "slug", "", "company.com", "admin@gmail.com")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "must match")
|
||||
assert.Nil(t, tenant)
|
||||
}
|
||||
|
||||
func TestTenantService_ApproveTenant_NotFound(t *testing.T) {
|
||||
mockRepo := new(MockTenantRepoForSvc)
|
||||
svc := NewTenantService(mockRepo, nil, nil, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
id := "non-existent-id"
|
||||
|
||||
mockRepo.On("FindByID", ctx, id).Return(nil, gorm.ErrRecordNotFound)
|
||||
|
||||
err := svc.ApproveTenant(ctx, id)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, gorm.ErrRecordNotFound))
|
||||
}
|
||||
|
||||
func TestTenantService_GetTenantByDomain_Inactive(t *testing.T) {
|
||||
mockRepo := new(MockTenantRepoForSvc)
|
||||
svc := NewTenantService(mockRepo, nil, nil, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
domainName := "inactive.com"
|
||||
|
||||
mockRepo.On("FindByDomain", ctx, domainName).Return(&domain.Tenant{
|
||||
ID: "t1",
|
||||
Status: domain.TenantStatusPending,
|
||||
}, nil)
|
||||
|
||||
tenant, err := svc.GetTenantByDomain(ctx, domainName)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not active")
|
||||
assert.Nil(t, tenant)
|
||||
}
|
||||
|
||||
func TestTenantService_ApproveTenant_UserNotFound(t *testing.T) {
|
||||
mockRepo := new(MockTenantRepoForSvc)
|
||||
mockUserRepo := new(MockUserRepoForTenant)
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
|
||||
svc := NewTenantService(mockRepo, mockUserRepo, nil, mockOutbox)
|
||||
ctx := context.Background()
|
||||
tenantID := "t1"
|
||||
adminEmail := "notfound@tenant.com"
|
||||
|
||||
tenant := &domain.Tenant{
|
||||
ID: tenantID,
|
||||
Slug: "tenant-slug",
|
||||
Config: domain.JSONMap{"adminEmail": adminEmail},
|
||||
}
|
||||
|
||||
mockRepo.On("FindByID", ctx, tenantID).Return(tenant, nil)
|
||||
mockRepo.On("Update", ctx, mock.Anything).Return(nil)
|
||||
// User not found in DB
|
||||
mockUserRepo.On("FindByEmail", adminEmail).Return(nil, gorm.ErrRecordNotFound)
|
||||
|
||||
// Outbox should not be called since user is not found
|
||||
|
||||
err := svc.ApproveTenant(ctx, tenantID)
|
||||
assert.NoError(t, err) // Should succeed but just log that user is not found
|
||||
mockRepo.AssertExpectations(t)
|
||||
mockUserRepo.AssertExpectations(t)
|
||||
mockOutbox.AssertNotCalled(t, "Create")
|
||||
}
|
||||
345
baron-sso/backend/internal/service/tenant_service_test.go
Normal file
345
baron-sso/backend/internal/service/tenant_service_test.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// --- Local Mocks to avoid collisions ---
|
||||
|
||||
type MockTenantRepoForSvc struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockTenantRepoForSvc) Create(ctx context.Context, tenant *domain.Tenant) error {
|
||||
return m.Called(ctx, tenant).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockTenantRepoForSvc) Update(ctx context.Context, tenant *domain.Tenant) error {
|
||||
return m.Called(ctx, tenant).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockTenantRepoForSvc) FindByID(ctx context.Context, id string) (*domain.Tenant, error) {
|
||||
args := m.Called(ctx, id)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantRepoForSvc) FindBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
|
||||
args := m.Called(ctx, slug)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantRepoForSvc) FindByName(ctx context.Context, name string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantRepoForSvc) FindByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
||||
args := m.Called(ctx, domainName)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantRepoForSvc) FindByIDs(ctx context.Context, ids []string) ([]domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantRepoForSvc) AddDomain(ctx context.Context, tenantID string, domainName string, verified bool) error {
|
||||
return m.Called(ctx, tenantID, domainName, verified).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockTenantRepoForSvc) List(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
|
||||
args := m.Called(ctx, limit, offset, parentID, search)
|
||||
return args.Get(0).([]domain.Tenant), args.Get(1).(int64), args.Error(2)
|
||||
}
|
||||
|
||||
func (m *MockTenantRepoForSvc) ListByType(ctx context.Context, tenantType string) ([]domain.Tenant, error) {
|
||||
args := m.Called(ctx, tenantType)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantRepoForSvc) DeleteBulk(ctx context.Context, ids []string) error {
|
||||
args := m.Called(ctx, ids)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockKetoSvcForTenant struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockKetoSvcForTenant) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
|
||||
return m.Called(ctx, namespace, object, relation, subject).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockKetoSvcForTenant) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
|
||||
return m.Called(ctx, namespace, object, relation, subject).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockKetoSvcForTenant) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]RelationTuple, error) {
|
||||
args := m.Called(ctx, namespace, object, relation, subject)
|
||||
return args.Get(0).([]RelationTuple), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKetoSvcForTenant) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
|
||||
args := m.Called(ctx, namespace, relation, subject)
|
||||
return args.Get(0).([]string), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKetoSvcForTenant) CheckPermission(ctx context.Context, namespace, object, relation, subject string) (bool, error) {
|
||||
args := m.Called(ctx, namespace, object, relation, subject)
|
||||
return args.Bool(0), args.Error(1)
|
||||
}
|
||||
|
||||
type MockUserRepoForTenant struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockUserRepoForTenant) Create(ctx context.Context, user *domain.User) error { return nil }
|
||||
func (m *MockUserRepoForTenant) Update(ctx context.Context, user *domain.User) error { return nil }
|
||||
func (m *MockUserRepoForTenant) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||
args := m.Called(email)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.User), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserRepoForTenant) Delete(ctx context.Context, id string) error {
|
||||
return m.Called(ctx, id).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockUserRepoForTenant) FindByID(ctx context.Context, id string) (*domain.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepoForTenant) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepoForTenant) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepoForTenant) List(ctx context.Context, offset, limit int, search string, tenantIDs []string, cursor string) ([]domain.User, int64, string, error) {
|
||||
return nil, 0, "", nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepoForTenant) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
|
||||
args := m.Called(tenantID)
|
||||
return int64(args.Int(0)), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserRepoForTenant) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
|
||||
args := m.Called(ctx, tenantIDs)
|
||||
return args.Get(0).([]domain.User), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserRepoForTenant) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
|
||||
args := m.Called(tenantIDs)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(map[string]int64), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserRepoForTenant) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
|
||||
args := m.Called(ctx, codes)
|
||||
return args.Get(0).([]domain.User), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserRepoForTenant) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
|
||||
args := m.Called(ctx, codes)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(map[string]int64), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserRepoForTenant) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepoForTenant) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepoForTenant) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepoForTenant) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepoForTenant) DB() *gorm.DB {
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
func TestTenantService_RegisterTenant_AutoVerify(t *testing.T) {
|
||||
mockRepo := new(MockTenantRepoForSvc)
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
svc := NewTenantService(mockRepo, nil, nil, mockOutbox)
|
||||
|
||||
ctx := context.Background()
|
||||
name := "New Tenant"
|
||||
slug := "new-tenant"
|
||||
domains := []string{"example.com"}
|
||||
|
||||
// Use .Once() to ensure correct return values for sequential calls to FindBySlug
|
||||
mockRepo.On("FindBySlug", ctx, slug).Return(nil, nil).Once()
|
||||
mockRepo.On("Create", ctx, mock.Anything).Return(nil)
|
||||
mockRepo.On("AddDomain", ctx, mock.Anything, "example.com", true).Return(nil)
|
||||
mockOutbox.On("Create", ctx, mock.MatchedBy(func(k *domain.KetoOutbox) bool {
|
||||
return k.Relation == "admins" && k.Subject == "System:global#super_admins"
|
||||
})).Return(nil)
|
||||
mockRepo.On("FindBySlug", ctx, slug).Return(&domain.Tenant{ID: "t1", Slug: slug}, nil).Once()
|
||||
|
||||
tenant, err := svc.RegisterTenant(ctx, name, slug, domain.TenantTypeCompany, "", domains, nil, "")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tenant)
|
||||
assert.Equal(t, "t1", tenant.ID)
|
||||
mockRepo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTenantService_RegisterTenant_WithCreator(t *testing.T) {
|
||||
mockRepo := new(MockTenantRepoForSvc)
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
svc := NewTenantService(mockRepo, nil, nil, mockOutbox)
|
||||
|
||||
ctx := context.Background()
|
||||
name := "Creator Tenant"
|
||||
slug := "creator-tenant"
|
||||
creatorID := "creator-uuid"
|
||||
tenantID := "t-new"
|
||||
|
||||
mockRepo.On("FindBySlug", ctx, slug).Return(nil, nil).Once()
|
||||
mockRepo.On("Create", ctx, mock.MatchedBy(func(t *domain.Tenant) bool {
|
||||
return t.Slug == slug
|
||||
})).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*domain.Tenant)
|
||||
t.ID = tenantID
|
||||
}).Return(nil)
|
||||
|
||||
// Expect global super admin sync
|
||||
mockOutbox.On("Create", ctx, mock.MatchedBy(func(k *domain.KetoOutbox) bool {
|
||||
return k.Relation == "admins" && k.Subject == "System:global#super_admins"
|
||||
})).Return(nil)
|
||||
// Expect owners sync
|
||||
mockOutbox.On("Create", ctx, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "owners" && e.Subject == "User:"+creatorID
|
||||
})).Return(nil)
|
||||
// Expect admins sync
|
||||
mockOutbox.On("Create", ctx, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "admins" && e.Subject == "User:"+creatorID
|
||||
})).Return(nil)
|
||||
// Expect members sync
|
||||
mockOutbox.On("Create", ctx, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "members" && e.Subject == "User:"+creatorID
|
||||
})).Return(nil)
|
||||
|
||||
mockRepo.On("FindBySlug", ctx, slug).Return(&domain.Tenant{ID: tenantID, Slug: slug}, nil).Once()
|
||||
|
||||
tenant, err := svc.RegisterTenant(ctx, name, slug, domain.TenantTypeCompany, "", nil, nil, creatorID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tenant)
|
||||
mockRepo.AssertExpectations(t)
|
||||
mockOutbox.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTenantService_RequestRegistration_NoVerify(t *testing.T) {
|
||||
mockRepo := new(MockTenantRepoForSvc)
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
svc := NewTenantService(mockRepo, nil, nil, mockOutbox)
|
||||
|
||||
ctx := context.Background()
|
||||
name := "Public Tenant"
|
||||
slug := "public-tenant"
|
||||
domainName := "public.com"
|
||||
adminEmail := "admin@public.com"
|
||||
|
||||
mockRepo.On("Create", ctx, mock.MatchedBy(func(tenant *domain.Tenant) bool {
|
||||
return tenant.Status == domain.TenantStatusPending
|
||||
})).Return(nil)
|
||||
mockOutbox.On("Create", ctx, mock.MatchedBy(func(k *domain.KetoOutbox) bool {
|
||||
return k.Relation == "admins" && k.Subject == "System:global#super_admins"
|
||||
})).Return(nil)
|
||||
mockRepo.On("AddDomain", ctx, mock.Anything, domainName, false).Return(nil)
|
||||
|
||||
tenant, err := svc.RequestRegistration(ctx, name, slug, "", domainName, adminEmail)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tenant)
|
||||
mockRepo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTenantService_ApproveTenant_SyncAdmin(t *testing.T) {
|
||||
mockRepo := new(MockTenantRepoForSvc)
|
||||
mockUserRepo := new(MockUserRepoForTenant)
|
||||
mockKeto := new(MockKetoSvcForTenant)
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
|
||||
svc := NewTenantService(mockRepo, mockUserRepo, nil, mockOutbox)
|
||||
svc.SetKetoService(mockKeto)
|
||||
|
||||
ctx := context.Background()
|
||||
tenantID := "t1"
|
||||
adminEmail := "admin@tenant.com"
|
||||
userID := "user-uuid"
|
||||
|
||||
tenant := &domain.Tenant{
|
||||
ID: tenantID,
|
||||
Slug: "tenant-slug",
|
||||
Config: domain.JSONMap{"adminEmail": adminEmail},
|
||||
}
|
||||
|
||||
mockRepo.On("FindByID", ctx, tenantID).Return(tenant, nil)
|
||||
mockRepo.On("Update", ctx, mock.Anything).Return(nil)
|
||||
mockUserRepo.On("FindByEmail", adminEmail).Return(&domain.User{ID: userID, Email: adminEmail}, nil)
|
||||
// Now using Outbox instead of direct Keto call
|
||||
mockOutbox.On("Create", ctx, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "owners" && e.Subject == "User:"+userID
|
||||
})).Return(nil)
|
||||
mockOutbox.On("Create", ctx, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "admins" && e.Subject == "User:"+userID
|
||||
})).Return(nil)
|
||||
mockOutbox.On("Create", ctx, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "members" && e.Subject == "User:"+userID
|
||||
})).Return(nil)
|
||||
|
||||
err := svc.ApproveTenant(ctx, tenantID)
|
||||
assert.NoError(t, err)
|
||||
mockRepo.AssertExpectations(t)
|
||||
mockUserRepo.AssertExpectations(t)
|
||||
mockOutbox.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTenantService_ListTenants(t *testing.T) {
|
||||
mockRepo := new(MockTenantRepoForSvc)
|
||||
svc := NewTenantService(mockRepo, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
tenants := []domain.Tenant{{ID: "t1", Name: "Tenant 1"}}
|
||||
mockRepo.On("List", ctx, 10, 0, "", "").Return(tenants, int64(1), nil)
|
||||
|
||||
result, total, err := svc.ListTenants(ctx, 10, 0, "", "")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), total)
|
||||
assert.Equal(t, tenants, result)
|
||||
mockRepo.AssertExpectations(t)
|
||||
}
|
||||
466
baron-sso/backend/internal/service/user_group_service.go
Normal file
466
baron-sso/backend/internal/service/user_group_service.go
Normal file
@@ -0,0 +1,466 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type UserGroupService interface {
|
||||
Create(ctx context.Context, tenantID string, parentID *string, name, description, unitType string) (*domain.UserGroup, error)
|
||||
Get(ctx context.Context, id string) (*domain.UserGroup, error)
|
||||
List(ctx context.Context, tenantID string) ([]domain.UserGroup, error)
|
||||
Delete(ctx context.Context, tenantID, groupID string) error
|
||||
Update(ctx context.Context, tenantID, groupID string, name, description, unitType string, parentID *string) (*domain.UserGroup, error)
|
||||
SetWorksmobileSyncer(syncer WorksmobileSyncer)
|
||||
|
||||
// Member Management with Keto Sync
|
||||
AddMember(ctx context.Context, groupID, userID string) error
|
||||
RemoveMember(ctx context.Context, groupID, userID string) error
|
||||
|
||||
// Permission Management
|
||||
ListRoles(ctx context.Context, groupID string) ([]domain.GroupRole, error)
|
||||
AssignRoleToTenant(ctx context.Context, groupID, tenantID, relation string) error
|
||||
RemoveRoleFromTenant(ctx context.Context, groupID, tenantID, relation string) error
|
||||
}
|
||||
|
||||
type userGroupService struct {
|
||||
repo repository.UserGroupRepository
|
||||
userRepo repository.UserRepository
|
||||
tenantRepo repository.TenantRepository
|
||||
ketoService KetoService
|
||||
outboxRepo repository.KetoOutboxRepository
|
||||
kratos KratosAdminService
|
||||
worksmobile WorksmobileSyncer
|
||||
}
|
||||
|
||||
func NewUserGroupService(
|
||||
repo repository.UserGroupRepository,
|
||||
userRepo repository.UserRepository,
|
||||
tenantRepo repository.TenantRepository,
|
||||
keto KetoService,
|
||||
outbox repository.KetoOutboxRepository,
|
||||
kratos KratosAdminService,
|
||||
) UserGroupService {
|
||||
return &userGroupService{
|
||||
repo: repo,
|
||||
userRepo: userRepo,
|
||||
tenantRepo: tenantRepo,
|
||||
ketoService: keto,
|
||||
outboxRepo: outbox,
|
||||
kratos: kratos,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *userGroupService) SetWorksmobileSyncer(syncer WorksmobileSyncer) {
|
||||
s.worksmobile = syncer
|
||||
}
|
||||
|
||||
func (s *userGroupService) Create(ctx context.Context, tenantID string, parentID *string, name, description, unitType string) (*domain.UserGroup, error) {
|
||||
// For Keto and Tenant hierarchy, if no parent group, the company tenant is the parent.
|
||||
actualParentID := parentID
|
||||
if actualParentID == nil || *actualParentID == "" {
|
||||
actualParentID = &tenantID
|
||||
}
|
||||
|
||||
// Validate parent tenant exists
|
||||
if _, err := s.tenantRepo.FindByID(ctx, *actualParentID); err != nil {
|
||||
return nil, fmt.Errorf("parent tenant not found or invalid: %w", err)
|
||||
}
|
||||
|
||||
unitID := uuid.NewString()
|
||||
|
||||
// 1. Create Tenant (Type: ORGANIZATION)
|
||||
groupTenant := &domain.Tenant{
|
||||
ID: unitID,
|
||||
Type: domain.TenantTypeOrganization,
|
||||
ParentID: actualParentID,
|
||||
Name: name,
|
||||
Slug: fmt.Sprintf("ug-%s", unitID[:8]),
|
||||
Description: description,
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
|
||||
if err := s.tenantRepo.Create(ctx, groupTenant); err != nil {
|
||||
slog.Error("Failed to create tenant record for user group", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. Create UserGroup metadata
|
||||
// parent_id in user_groups refers to other groups, so use original parentID (which might be nil)
|
||||
group := &domain.UserGroup{
|
||||
ID: unitID,
|
||||
TenantID: tenantID,
|
||||
ParentID: parentID,
|
||||
Name: name,
|
||||
Description: description,
|
||||
UnitType: unitType,
|
||||
}
|
||||
|
||||
if err := s.repo.Create(ctx, group); err != nil {
|
||||
// Rollback Tenant creation? Or handle via cleanup job. For now, just log.
|
||||
slog.Error("Failed to create user group metadata after creating tenant", "tenantId", unitID, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. Keto Hierarchy via Outbox: Tenant:<child_id>#parents@Tenant:<parent_id>
|
||||
if s.outboxRepo != nil {
|
||||
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: unitID,
|
||||
Relation: "parents",
|
||||
Subject: "Tenant:" + *actualParentID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *userGroupService) Update(ctx context.Context, tenantID, groupID string, name, description, unitType string, parentID *string) (*domain.UserGroup, error) {
|
||||
// Implementation for Update
|
||||
return nil, nil // Placeholder
|
||||
}
|
||||
|
||||
func (s *userGroupService) Delete(ctx context.Context, tenantID, groupID string) error {
|
||||
// Implementation for Delete
|
||||
return nil // Placeholder
|
||||
}
|
||||
|
||||
func (s *userGroupService) populateMembers(ctx context.Context, group *domain.UserGroup) {
|
||||
tuples, err := s.ketoService.ListRelations(ctx, "Tenant", group.ID, "members", "")
|
||||
if err != nil {
|
||||
slog.Error("Failed to fetch group members from keto", "error", err, "group_id", group.ID)
|
||||
group.Members = []domain.User{}
|
||||
return
|
||||
}
|
||||
|
||||
var userIDs []string
|
||||
for _, t := range tuples {
|
||||
sid := t.SubjectID
|
||||
if len(sid) > 5 && sid[:5] == "User:" {
|
||||
userIDs = append(userIDs, sid[5:])
|
||||
} else {
|
||||
userIDs = append(userIDs, sid)
|
||||
}
|
||||
}
|
||||
|
||||
if len(userIDs) > 0 {
|
||||
members, err := s.userRepo.FindByIDs(ctx, userIDs)
|
||||
if err != nil {
|
||||
slog.Error("Failed to fetch member details from db", "error", err)
|
||||
}
|
||||
|
||||
memberMap := make(map[string]domain.User)
|
||||
for _, m := range members {
|
||||
memberMap[m.ID] = m
|
||||
}
|
||||
|
||||
var finalMembers []domain.User
|
||||
for _, uid := range userIDs {
|
||||
if m, ok := memberMap[uid]; ok {
|
||||
finalMembers = append(finalMembers, m)
|
||||
} else if s.kratos != nil {
|
||||
identity, err := s.kratos.GetIdentity(ctx, uid)
|
||||
if err == nil && identity != nil {
|
||||
name, _ := identity.Traits["name"].(string)
|
||||
email, _ := identity.Traits["email"].(string)
|
||||
finalMembers = append(finalMembers, domain.User{
|
||||
ID: uid,
|
||||
Name: name,
|
||||
Email: email,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
group.Members = finalMembers
|
||||
} else {
|
||||
group.Members = []domain.User{}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *userGroupService) Get(ctx context.Context, id string) (*domain.UserGroup, error) {
|
||||
group, err := s.repo.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.populateMembers(ctx, group)
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *userGroupService) List(ctx context.Context, tenantID string) ([]domain.UserGroup, error) {
|
||||
groups, err := s.repo.ListByTenantID(ctx, tenantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.ketoService == nil {
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
for i := range groups {
|
||||
s.populateMembers(ctx, &groups[i])
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (s *userGroupService) AddMember(ctx context.Context, groupID, userID string) error {
|
||||
// Validate group exists
|
||||
group, err := s.repo.FindByID(ctx, groupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("user group not found: %w", err)
|
||||
}
|
||||
|
||||
var tenant *domain.Tenant
|
||||
if s.tenantRepo != nil {
|
||||
tenant, _ = s.tenantRepo.FindByID(ctx, group.TenantID)
|
||||
}
|
||||
|
||||
// Kratos는 identity SSOT이고 조직/부서 정보의 원장이 아니므로 AddMember에서 traits를 수정하지 않습니다.
|
||||
if s.userRepo != nil && tenant != nil {
|
||||
localUser, err := s.userRepo.FindByID(ctx, userID)
|
||||
if err != nil || localUser == nil {
|
||||
if s.kratos != nil {
|
||||
identity, identityErr := s.kratos.GetIdentity(ctx, userID)
|
||||
if identityErr == nil && identity != nil {
|
||||
localUser = mapUserGroupKratosIdentityToLocalUser(*identity)
|
||||
} else {
|
||||
slog.Warn("Skipping local user sync during AddMember because identity read is unavailable", "user", userID, "error", identityErr)
|
||||
}
|
||||
} else {
|
||||
slog.Warn("Skipping local user sync during AddMember because identity projection is unavailable", "user", userID, "error", err)
|
||||
}
|
||||
}
|
||||
if localUser != nil {
|
||||
localUser.TenantID = &tenant.ID
|
||||
localUser.Department = group.Name
|
||||
if err := s.userRepo.Update(ctx, localUser); err != nil {
|
||||
slog.Error("Failed to sync local user during AddMember", "user", userID, "error", err)
|
||||
} else if s.worksmobile != nil {
|
||||
if err := s.worksmobile.EnqueueUserUpsertIfInScope(ctx, *localUser); err != nil {
|
||||
slog.Warn("Failed to enqueue Worksmobile user sync during AddMember", "user", userID, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Keto via Outbox: Tenant:<groupID>#members@User:<userID>
|
||||
if s.outboxRepo != nil {
|
||||
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: groupID,
|
||||
Relation: "members",
|
||||
Subject: "User:" + userID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
|
||||
// Also add direct Tenant membership to Keto for member counting
|
||||
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: group.TenantID,
|
||||
Relation: "members",
|
||||
Subject: "User:" + userID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func mapUserGroupKratosIdentityToLocalUser(identity KratosIdentity) *domain.User {
|
||||
traits := identity.Traits
|
||||
now := time.Now()
|
||||
createdAt := identity.CreatedAt
|
||||
if createdAt.IsZero() {
|
||||
createdAt = now
|
||||
}
|
||||
updatedAt := identity.UpdatedAt
|
||||
if updatedAt.IsZero() {
|
||||
updatedAt = now
|
||||
}
|
||||
|
||||
role, ok := domain.NormalizeRoleAlias(userGroupTraitString(traits, "role"))
|
||||
if !ok {
|
||||
role, ok = domain.NormalizeRoleAlias(userGroupTraitString(traits, "grade"))
|
||||
if !ok {
|
||||
role = domain.RoleUser
|
||||
}
|
||||
}
|
||||
grade := userGroupTraitString(traits, "grade")
|
||||
if _, ok := domain.NormalizeRoleAlias(grade); ok {
|
||||
grade = ""
|
||||
}
|
||||
|
||||
user := &domain.User{
|
||||
ID: identity.ID,
|
||||
Email: userGroupTraitString(traits, "email"),
|
||||
Name: userGroupTraitString(traits, "name"),
|
||||
Phone: domain.NormalizePhoneNumber(userGroupTraitString(traits, "phone_number")),
|
||||
Role: role,
|
||||
Status: userGroupIdentityStatus(identity.State),
|
||||
Department: userGroupTraitString(traits, "department"),
|
||||
Grade: grade,
|
||||
Position: userGroupTraitString(traits, "position"),
|
||||
JobTitle: userGroupTraitString(traits, "jobTitle"),
|
||||
AffiliationType: userGroupTraitString(traits, "affiliationType"),
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
Metadata: make(domain.JSONMap),
|
||||
}
|
||||
if tenantID := userGroupTraitString(traits, "tenant_id"); tenantID != "" {
|
||||
user.TenantID = &tenantID
|
||||
}
|
||||
if relyingPartyID := userGroupTraitString(traits, "relying_party_id"); relyingPartyID != "" {
|
||||
user.RelyingPartyID = &relyingPartyID
|
||||
}
|
||||
coreTraits := map[string]bool{
|
||||
"email": true, "name": true, "phone_number": true,
|
||||
"grade": true, "role": true, "companyCode": true, "company_code": true,
|
||||
"companyCodes": true, "tenant_id": true, "department": true,
|
||||
"position": true, "jobTitle": true, "affiliationType": true,
|
||||
"relying_party_id": true, "custom_login_ids": true, "id": true,
|
||||
}
|
||||
for key, value := range traits {
|
||||
if !coreTraits[key] {
|
||||
user.Metadata[key] = value
|
||||
}
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
func userGroupTraitString(traits map[string]any, key string) string {
|
||||
if traits == nil {
|
||||
return ""
|
||||
}
|
||||
value, ok := traits[key]
|
||||
if !ok || value == nil {
|
||||
return ""
|
||||
}
|
||||
if str, ok := value.(string); ok {
|
||||
return str
|
||||
}
|
||||
return fmt.Sprint(value)
|
||||
}
|
||||
|
||||
func userGroupTraitStringArray(traits map[string]any, key string) []string {
|
||||
if traits == nil {
|
||||
return nil
|
||||
}
|
||||
switch value := traits[key].(type) {
|
||||
case []string:
|
||||
return value
|
||||
case []any:
|
||||
items := make([]string, 0, len(value))
|
||||
for _, item := range value {
|
||||
if str, ok := item.(string); ok && str != "" {
|
||||
items = append(items, str)
|
||||
}
|
||||
}
|
||||
return items
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func userGroupIdentityStatus(state string) string {
|
||||
return domain.NormalizeUserStatus(state)
|
||||
}
|
||||
|
||||
func (s *userGroupService) RemoveMember(ctx context.Context, groupID, userID string) error {
|
||||
// Validate group exists
|
||||
if _, err := s.repo.FindByID(ctx, groupID); err != nil {
|
||||
return fmt.Errorf("user group not found: %w", err)
|
||||
}
|
||||
|
||||
// Keto via Outbox: Delete relation
|
||||
if s.outboxRepo != nil {
|
||||
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: groupID,
|
||||
Relation: "members",
|
||||
Subject: "User:" + userID,
|
||||
Action: domain.KetoOutboxActionDelete,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userGroupService) ListRoles(ctx context.Context, groupID string) ([]domain.GroupRole, error) {
|
||||
// Query: namespace=Tenant, subject=Tenant:groupID#members
|
||||
subject := "Tenant:" + groupID + "#members"
|
||||
tuples, err := s.ketoService.ListRelations(ctx, "Tenant", "", "", subject)
|
||||
if err != nil {
|
||||
slog.Error("Failed to fetch group roles from keto", "error", err, "group_id", groupID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var roles []domain.GroupRole
|
||||
tenantIDs := make([]string, 0, len(tuples))
|
||||
for _, t := range tuples {
|
||||
tenantIDs = append(tenantIDs, t.Object)
|
||||
}
|
||||
|
||||
if len(tenantIDs) > 0 {
|
||||
tenantList, err := s.tenantRepo.FindByIDs(ctx, tenantIDs)
|
||||
if err != nil {
|
||||
slog.Error("Failed to fetch tenant details for roles", "error", err)
|
||||
}
|
||||
|
||||
tenantMap := make(map[string]string)
|
||||
for _, t := range tenantList {
|
||||
tenantMap[t.ID] = t.Name
|
||||
}
|
||||
|
||||
for _, t := range tuples {
|
||||
roles = append(roles, domain.GroupRole{
|
||||
TenantID: t.Object,
|
||||
TenantName: tenantMap[t.Object],
|
||||
Relation: t.Relation,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
func (s *userGroupService) AssignRoleToTenant(ctx context.Context, groupID, tenantID, relation string) error {
|
||||
// Validate group exists
|
||||
if _, err := s.repo.FindByID(ctx, groupID); err != nil {
|
||||
return fmt.Errorf("user group not found: %w", err)
|
||||
}
|
||||
|
||||
// Keto via Outbox: Tenant:<tenantID>#<relation>@Tenant:<groupID>#members
|
||||
if s.outboxRepo != nil {
|
||||
subject := "Tenant:" + groupID + "#members"
|
||||
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: tenantID,
|
||||
Relation: relation,
|
||||
Subject: subject,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userGroupService) RemoveRoleFromTenant(ctx context.Context, groupID, tenantID, relation string) error {
|
||||
// Keto via Outbox: Delete relation
|
||||
if s.outboxRepo != nil {
|
||||
subject := "Tenant:" + groupID + "#members"
|
||||
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: tenantID,
|
||||
Relation: relation,
|
||||
Subject: subject,
|
||||
Action: domain.KetoOutboxActionDelete,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestUserGroupService_Create_InvalidParentID(t *testing.T) {
|
||||
mockRepo := new(MockUserGroupRepository)
|
||||
mockTenantRepo := new(MockTenantRepository)
|
||||
mockKeto := new(MockKetoServiceShared)
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
svc := NewUserGroupService(mockRepo, nil, mockTenantRepo, mockKeto, mockOutbox, nil)
|
||||
|
||||
tenantID := "company-1"
|
||||
invalidParentID := "invalid-uuid"
|
||||
name := "Invalid Parent Group"
|
||||
description := ""
|
||||
unitType := "Team"
|
||||
|
||||
// Mock: TenantRepo returns record not found for invalidParentID
|
||||
mockTenantRepo.On("FindByID", mock.Anything, invalidParentID).Return(nil, gorm.ErrRecordNotFound).Once()
|
||||
|
||||
// No Create calls should happen on any repo if parent is invalid
|
||||
mockRepo.AssertNotCalled(t, "Create")
|
||||
mockTenantRepo.AssertNotCalled(t, "Create")
|
||||
mockOutbox.AssertNotCalled(t, "Create")
|
||||
|
||||
group, err := svc.Create(context.Background(), tenantID, &invalidParentID, name, description, unitType)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "parent tenant not found or invalid")
|
||||
assert.Nil(t, group)
|
||||
|
||||
mockTenantRepo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUserGroupService_AddMember_GroupNotFound(t *testing.T) {
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
mockUserGroupRepo := new(MockUserGroupRepository)
|
||||
svc := NewUserGroupService(mockUserGroupRepo, nil, nil, nil, mockOutbox, nil)
|
||||
|
||||
groupID := "non-existent-group"
|
||||
userID := "user-1"
|
||||
|
||||
// Mock: Group does not exist
|
||||
mockUserGroupRepo.On("FindByID", mock.Anything, groupID).Return(nil, gorm.ErrRecordNotFound)
|
||||
|
||||
// No Outbox call should happen if group is not found
|
||||
mockOutbox.AssertNotCalled(t, "Create")
|
||||
|
||||
err := svc.AddMember(context.Background(), groupID, userID)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "user group not found")
|
||||
|
||||
mockUserGroupRepo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUserGroupService_RemoveMember_GroupNotFound(t *testing.T) {
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
mockUserGroupRepo := new(MockUserGroupRepository)
|
||||
svc := NewUserGroupService(mockUserGroupRepo, nil, nil, nil, mockOutbox, nil)
|
||||
|
||||
groupID := "non-existent-group"
|
||||
userID := "user-1"
|
||||
|
||||
// Mock: Group does not exist
|
||||
mockUserGroupRepo.On("FindByID", mock.Anything, groupID).Return(nil, gorm.ErrRecordNotFound)
|
||||
|
||||
// No Outbox call should happen if group is not found
|
||||
mockOutbox.AssertNotCalled(t, "Create")
|
||||
|
||||
err := svc.RemoveMember(context.Background(), groupID, userID)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "user group not found")
|
||||
|
||||
mockUserGroupRepo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUserGroupService_AssignRoleToTenant_GroupNotFound(t *testing.T) {
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
mockUserGroupRepo := new(MockUserGroupRepository)
|
||||
svc := NewUserGroupService(mockUserGroupRepo, nil, nil, nil, mockOutbox, nil)
|
||||
|
||||
groupID := "non-existent-group"
|
||||
tenantID := "tenant-alpha"
|
||||
relation := "manage"
|
||||
|
||||
// Mock: Group does not exist
|
||||
mockUserGroupRepo.On("FindByID", mock.Anything, groupID).Return(nil, gorm.ErrRecordNotFound)
|
||||
|
||||
// No Outbox call should happen if group is not found
|
||||
mockOutbox.AssertNotCalled(t, "Create")
|
||||
|
||||
err := svc.AssignRoleToTenant(context.Background(), groupID, tenantID, relation)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "user group not found")
|
||||
|
||||
mockUserGroupRepo.AssertExpectations(t)
|
||||
}
|
||||
463
baron-sso/backend/internal/service/user_group_service_test.go
Normal file
463
baron-sso/backend/internal/service/user_group_service_test.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// --- Mocks for Repositories ---
|
||||
|
||||
type MockUserGroupRepository struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockUserGroupRepository) Create(ctx context.Context, group *domain.UserGroup) error {
|
||||
return m.Called(ctx, group).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupRepository) Update(ctx context.Context, group *domain.UserGroup) error {
|
||||
return m.Called(ctx, group).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupRepository) Delete(ctx context.Context, id string) error {
|
||||
return m.Called(ctx, id).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupRepository) FindByID(ctx context.Context, id string) (*domain.UserGroup, error) {
|
||||
args := m.Called(ctx, id)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.UserGroup), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupRepository) ListByTenantID(ctx context.Context, tenantID string) ([]domain.UserGroup, error) {
|
||||
args := m.Called(ctx, tenantID)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]domain.UserGroup), args.Error(1)
|
||||
}
|
||||
|
||||
type MockUserRepository struct {
|
||||
mock.Mock
|
||||
updatedUsers []domain.User
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) Create(ctx context.Context, user *domain.User) error { return nil }
|
||||
func (m *MockUserRepository) Update(ctx context.Context, user *domain.User) error {
|
||||
copied := *user
|
||||
m.updatedUsers = append(m.updatedUsers, copied)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) Delete(ctx context.Context, id string) error {
|
||||
return m.Called(ctx, id).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) FindByID(ctx context.Context, id string) (*domain.User, error) {
|
||||
args := m.Called(ctx, id)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.User), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
|
||||
args := m.Called(ctx, ids)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]domain.User), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) List(ctx context.Context, offset, limit int, search string, tenantIDs []string, cursor string) ([]domain.User, int64, string, error) {
|
||||
return nil, 0, "", nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
|
||||
args := m.Called(tenantID)
|
||||
return int64(args.Int(0)), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
|
||||
args := m.Called(ctx, tenantIDs)
|
||||
return args.Get(0).([]domain.User), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
|
||||
args := m.Called(tenantIDs)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(map[string]int64), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
|
||||
args := m.Called(ctx, codes)
|
||||
return args.Get(0).([]domain.User), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
|
||||
args := m.Called(ctx, codes)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(map[string]int64), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) DB() *gorm.DB {
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeUserGroupWorksmobileSyncer struct {
|
||||
userUpserts []domain.User
|
||||
}
|
||||
|
||||
func (f *fakeUserGroupWorksmobileSyncer) EnqueueTenantUpsertIfInScope(ctx context.Context, tenant domain.Tenant) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeUserGroupWorksmobileSyncer) EnqueueTenantDeleteIfInScope(ctx context.Context, tenant domain.Tenant) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeUserGroupWorksmobileSyncer) EnqueueUserUpsertIfInScope(ctx context.Context, user domain.User) error {
|
||||
f.userUpserts = append(f.userUpserts, user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeUserGroupWorksmobileSyncer) EnqueueUserDeleteIfInScope(ctx context.Context, user domain.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type MockKetoOutboxRepository struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MockTenantRepository struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockTenantRepository) Create(ctx context.Context, tenant *domain.Tenant) error {
|
||||
return m.Called(ctx, tenant).Error(0)
|
||||
}
|
||||
func (m *MockTenantRepository) Update(ctx context.Context, tenant *domain.Tenant) error { return nil }
|
||||
func (m *MockTenantRepository) FindByID(ctx context.Context, id string) (*domain.Tenant, error) {
|
||||
args := m.Called(ctx, id)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.Tenant, error) {
|
||||
args := m.Called(ctx, ids)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantRepository) FindBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantRepository) FindByName(ctx context.Context, name string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantRepository) FindByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantRepository) List(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantRepository) ListByType(ctx context.Context, tenantType string) ([]domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantRepository) AddDomain(ctx context.Context, tenantID string, domainName string, verified bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTenantRepository) DeleteBulk(ctx context.Context, ids []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUserGroupService_Create(t *testing.T) {
|
||||
mockRepo := new(MockUserGroupRepository)
|
||||
mockTenantRepo := new(MockTenantRepository)
|
||||
mockKeto := new(MockKetoServiceShared)
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
svc := NewUserGroupService(mockRepo, nil, mockTenantRepo, mockKeto, mockOutbox, nil)
|
||||
|
||||
tenantID := "company-1"
|
||||
parentID := "parent-group-id"
|
||||
name := "Test Group"
|
||||
description := "Group Description"
|
||||
unitType := "Team"
|
||||
|
||||
// Mock Tenant FindByID for parent check
|
||||
mockTenantRepo.On("FindByID", mock.Anything, parentID).Return(&domain.Tenant{ID: parentID}, nil)
|
||||
|
||||
// Mock Tenant creation (Polymorphic)
|
||||
mockTenantRepo.On("Create", mock.Anything, mock.MatchedBy(func(ten *domain.Tenant) bool {
|
||||
return ten.Type == domain.TenantTypeOrganization && ten.Name == name && *ten.ParentID == parentID
|
||||
})).Return(nil)
|
||||
|
||||
// Mock UserGroup creation
|
||||
mockRepo.On("Create", mock.Anything, mock.MatchedBy(func(g *domain.UserGroup) bool {
|
||||
return g.Name == name && *g.ParentID == parentID && g.TenantID == tenantID
|
||||
})).Return(nil)
|
||||
|
||||
// Mock Keto sync via Outbox
|
||||
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "Tenant" && e.Relation == "parents" && e.Subject == "Tenant:"+parentID
|
||||
})).Return(nil)
|
||||
|
||||
group, err := svc.Create(context.Background(), tenantID, &parentID, name, description, unitType)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, group)
|
||||
mockTenantRepo.AssertExpectations(t)
|
||||
mockRepo.AssertExpectations(t)
|
||||
mockOutbox.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUserGroupService_AddMember(t *testing.T) {
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
mockUserGroupRepo := new(MockUserGroupRepository)
|
||||
mockUserRepo := new(MockUserRepository)
|
||||
mockTenantRepo := new(MockTenantRepository)
|
||||
mockKratos := new(MockKratosAdminServiceShared)
|
||||
svc := NewUserGroupService(mockUserGroupRepo, mockUserRepo, mockTenantRepo, nil, mockOutbox, mockKratos)
|
||||
|
||||
groupID := "group-1"
|
||||
userID := "user-1"
|
||||
tenantID := "tenant-1"
|
||||
tenantSlug := "tenant-slug"
|
||||
|
||||
mockUserGroupRepo.On("FindByID", mock.Anything, groupID).Return(&domain.UserGroup{ID: groupID, TenantID: tenantID, Name: "Sales"}, nil)
|
||||
mockUserRepo.On("FindByID", mock.Anything, userID).Return(&domain.User{ID: userID}, nil)
|
||||
mockTenantRepo.On("FindByID", mock.Anything, tenantID).Return(&domain.Tenant{ID: tenantID, Slug: tenantSlug}, nil)
|
||||
|
||||
// Mock local user repo update (Ignored since Update is hardcoded to return nil without calling m.Called)
|
||||
// mockUserRepo.On("Update", mock.Anything, mock.MatchedBy(func(u *domain.User) bool {
|
||||
// return u.CompanyCode == tenantSlug && *u.TenantID == tenantID && u.Department == "Sales"
|
||||
// })).Return(nil)
|
||||
|
||||
// First Outbox Create for Group
|
||||
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "Tenant" && e.Object == groupID && e.Relation == "members" && e.Subject == "User:"+userID
|
||||
})).Return(nil).Once()
|
||||
|
||||
// Second Outbox Create for Tenant
|
||||
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "members" && e.Subject == "User:"+userID
|
||||
})).Return(nil).Once()
|
||||
|
||||
err := svc.AddMember(context.Background(), groupID, userID)
|
||||
assert.NoError(t, err)
|
||||
mockOutbox.AssertExpectations(t)
|
||||
mockKratos.AssertExpectations(t)
|
||||
mockKratos.AssertNotCalled(t, "GetIdentity", mock.Anything, userID)
|
||||
mockKratos.AssertNotCalled(t, "UpdateIdentity", mock.Anything, userID, mock.Anything, mock.Anything)
|
||||
// mockUserRepo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUserGroupService_AddMemberUpsertsLocalReadModelWhenMissing(t *testing.T) {
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
mockUserGroupRepo := new(MockUserGroupRepository)
|
||||
mockUserRepo := new(MockUserRepository)
|
||||
mockTenantRepo := new(MockTenantRepository)
|
||||
mockKratos := new(MockKratosAdminServiceShared)
|
||||
svc := NewUserGroupService(mockUserGroupRepo, mockUserRepo, mockTenantRepo, nil, mockOutbox, mockKratos)
|
||||
|
||||
groupID := "group-1"
|
||||
userID := "user-1"
|
||||
tenantID := "tenant-1"
|
||||
tenantSlug := "tenant-slug"
|
||||
|
||||
mockUserGroupRepo.On("FindByID", mock.Anything, groupID).Return(&domain.UserGroup{ID: groupID, TenantID: tenantID, Name: "Sales"}, nil)
|
||||
mockUserRepo.On("FindByID", mock.Anything, userID).Return(nil, gorm.ErrRecordNotFound)
|
||||
mockTenantRepo.On("FindByID", mock.Anything, tenantID).Return(&domain.Tenant{ID: tenantID, Slug: tenantSlug}, nil)
|
||||
mockKratos.On("GetIdentity", mock.Anything, userID).Return(&KratosIdentity{
|
||||
ID: userID,
|
||||
Traits: map[string]any{
|
||||
"email": "user@test.com",
|
||||
"name": "User Test",
|
||||
},
|
||||
State: "active",
|
||||
}, nil)
|
||||
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "Tenant" && e.Object == groupID && e.Relation == "members" && e.Subject == "User:"+userID
|
||||
})).Return(nil).Once()
|
||||
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "members" && e.Subject == "User:"+userID
|
||||
})).Return(nil).Once()
|
||||
|
||||
err := svc.AddMember(context.Background(), groupID, userID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, mockUserRepo.updatedUsers, 1)
|
||||
assert.Equal(t, userID, mockUserRepo.updatedUsers[0].ID)
|
||||
assert.Empty(t, mockUserRepo.updatedUsers[0].CompanyCode)
|
||||
assert.NotNil(t, mockUserRepo.updatedUsers[0].TenantID)
|
||||
assert.Equal(t, tenantID, *mockUserRepo.updatedUsers[0].TenantID)
|
||||
assert.Equal(t, "Sales", mockUserRepo.updatedUsers[0].Department)
|
||||
mockOutbox.AssertExpectations(t)
|
||||
mockKratos.AssertExpectations(t)
|
||||
mockKratos.AssertNotCalled(t, "UpdateIdentity", mock.Anything, userID, mock.Anything, mock.Anything)
|
||||
}
|
||||
|
||||
func TestUserGroupService_AddMemberEnqueuesWorksmobileUserSync(t *testing.T) {
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
mockUserGroupRepo := new(MockUserGroupRepository)
|
||||
mockUserRepo := new(MockUserRepository)
|
||||
mockTenantRepo := new(MockTenantRepository)
|
||||
mockKratos := new(MockKratosAdminServiceShared)
|
||||
worksmobile := &fakeUserGroupWorksmobileSyncer{}
|
||||
svc := NewUserGroupService(mockUserGroupRepo, mockUserRepo, mockTenantRepo, nil, mockOutbox, mockKratos)
|
||||
svc.SetWorksmobileSyncer(worksmobile)
|
||||
|
||||
groupID := "group-1"
|
||||
userID := "user-1"
|
||||
tenantID := "tenant-1"
|
||||
|
||||
mockUserGroupRepo.On("FindByID", mock.Anything, groupID).Return(&domain.UserGroup{ID: groupID, TenantID: tenantID, Name: "Sales"}, nil)
|
||||
mockUserRepo.On("FindByID", mock.Anything, userID).Return(&domain.User{
|
||||
ID: userID,
|
||||
Email: "user@test.com",
|
||||
Name: "User Test",
|
||||
Status: "active",
|
||||
}, nil)
|
||||
mockTenantRepo.On("FindByID", mock.Anything, tenantID).Return(&domain.Tenant{ID: tenantID, Slug: "tenant-slug"}, nil)
|
||||
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "Tenant" && e.Object == groupID && e.Relation == "members" && e.Subject == "User:"+userID
|
||||
})).Return(nil).Once()
|
||||
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "members" && e.Subject == "User:"+userID
|
||||
})).Return(nil).Once()
|
||||
|
||||
err := svc.AddMember(context.Background(), groupID, userID)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, worksmobile.userUpserts, 1)
|
||||
assert.Equal(t, userID, worksmobile.userUpserts[0].ID)
|
||||
assert.NotNil(t, worksmobile.userUpserts[0].TenantID)
|
||||
assert.Equal(t, tenantID, *worksmobile.userUpserts[0].TenantID)
|
||||
assert.Equal(t, "Sales", worksmobile.userUpserts[0].Department)
|
||||
mockOutbox.AssertExpectations(t)
|
||||
mockKratos.AssertExpectations(t)
|
||||
mockKratos.AssertNotCalled(t, "GetIdentity", mock.Anything, userID)
|
||||
mockKratos.AssertNotCalled(t, "UpdateIdentity", mock.Anything, userID, mock.Anything, mock.Anything)
|
||||
}
|
||||
|
||||
func TestUserGroupService_AssignRoleToTenant(t *testing.T) {
|
||||
mockOutbox := new(MockKetoOutboxRepositoryShared)
|
||||
mockUserGroupRepo := new(MockUserGroupRepository)
|
||||
svc := NewUserGroupService(mockUserGroupRepo, nil, nil, nil, mockOutbox, nil)
|
||||
|
||||
groupID := "group-1"
|
||||
tenantID := "tenant-alpha"
|
||||
relation := "manage"
|
||||
|
||||
mockUserGroupRepo.On("FindByID", mock.Anything, groupID).Return(&domain.UserGroup{ID: groupID}, nil)
|
||||
|
||||
expectedSubject := "Tenant:" + groupID + "#members"
|
||||
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
|
||||
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == relation && e.Subject == expectedSubject
|
||||
})).Return(nil)
|
||||
|
||||
err := svc.AssignRoleToTenant(context.Background(), groupID, tenantID, relation)
|
||||
assert.NoError(t, err)
|
||||
mockOutbox.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUserGroupService_ListRoles(t *testing.T) {
|
||||
mockKeto := new(MockKetoServiceShared)
|
||||
mockTenantRepo := new(MockTenantRepository)
|
||||
mockUserGroupRepo := new(MockUserGroupRepository)
|
||||
svc := NewUserGroupService(mockUserGroupRepo, nil, mockTenantRepo, mockKeto, nil, nil)
|
||||
|
||||
groupID := "group-1"
|
||||
subject := "Tenant:" + groupID + "#members"
|
||||
|
||||
mockUserGroupRepo.On("FindByID", mock.Anything, groupID).Return(&domain.UserGroup{ID: groupID}, nil)
|
||||
|
||||
tuples := []RelationTuple{
|
||||
{Object: "t1", Relation: "manage", SubjectID: subject},
|
||||
{Object: "t2", Relation: "view", SubjectID: subject},
|
||||
}
|
||||
mockKeto.On("ListRelations", mock.Anything, "Tenant", "", "", subject).Return(tuples, nil)
|
||||
|
||||
tenants := []domain.Tenant{
|
||||
{ID: "t1", Name: "Tenant One"},
|
||||
{ID: "t2", Name: "Tenant Two"},
|
||||
}
|
||||
mockTenantRepo.On("FindByIDs", mock.Anything, []string{"t1", "t2"}).Return(tenants, nil)
|
||||
|
||||
roles, err := svc.ListRoles(context.Background(), groupID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, roles, 2)
|
||||
}
|
||||
|
||||
func TestUserGroupService_Get_WithKratosFallback(t *testing.T) {
|
||||
mockRepo := new(MockUserGroupRepository)
|
||||
mockKeto := new(MockKetoServiceShared)
|
||||
mockUserRepo := new(MockUserRepository)
|
||||
mockKratos := new(MockKratosAdminServiceShared)
|
||||
|
||||
svc := NewUserGroupService(mockRepo, mockUserRepo, nil, mockKeto, nil, mockKratos)
|
||||
|
||||
groupID := "group-1"
|
||||
mockRepo.On("FindByID", mock.Anything, groupID).Return(&domain.UserGroup{ID: groupID, Name: "Test"}, nil)
|
||||
|
||||
tuples := []RelationTuple{
|
||||
{Object: groupID, Relation: "members", SubjectID: "User:u1"},
|
||||
}
|
||||
mockKeto.On("ListRelations", mock.Anything, "Tenant", groupID, "members", "").Return(tuples, nil)
|
||||
|
||||
mockUserRepo.On("FindByIDs", mock.Anything, []string{"u1"}).Return([]domain.User{}, nil)
|
||||
|
||||
mockKratos.On("GetIdentity", mock.Anything, "u1").Return(&KratosIdentity{
|
||||
ID: "u1",
|
||||
Traits: map[string]any{"name": "User One", "email": "user1@example.com"},
|
||||
}, nil)
|
||||
|
||||
group, err := svc.Get(context.Background(), groupID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, group)
|
||||
assert.Len(t, group.Members, 1)
|
||||
assert.Equal(t, "User One", group.Members[0].Name)
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UserProjectionSyncService struct {
|
||||
kratos KratosAdminService
|
||||
repo repository.UserProjectionRepository
|
||||
}
|
||||
|
||||
type UserProjectionReconciler interface {
|
||||
Reconcile(ctx context.Context) (int, error)
|
||||
}
|
||||
|
||||
func NewUserProjectionSyncService(kratos KratosAdminService, repo repository.UserProjectionRepository) *UserProjectionSyncService {
|
||||
return &UserProjectionSyncService{
|
||||
kratos: kratos,
|
||||
repo: repo,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UserProjectionSyncService) Reconcile(ctx context.Context) (int, error) {
|
||||
if s == nil || s.kratos == nil || s.repo == nil {
|
||||
return 0, fmt.Errorf("user projection sync dependencies are not configured")
|
||||
}
|
||||
|
||||
identities, err := s.kratos.ListIdentities(ctx)
|
||||
if err != nil {
|
||||
_ = s.repo.MarkFailed(ctx, err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
users := make([]domain.User, 0, len(identities))
|
||||
for _, identity := range identities {
|
||||
users = append(users, MapKratosIdentityToLocalUser(identity))
|
||||
}
|
||||
if err := s.repo.ReplaceAllFromKratos(ctx, users); err != nil {
|
||||
_ = s.repo.MarkFailed(ctx, err)
|
||||
return 0, err
|
||||
}
|
||||
return len(users), nil
|
||||
}
|
||||
|
||||
func MapKratosIdentityToLocalUser(identity KratosIdentity) domain.User {
|
||||
traits := identity.Traits
|
||||
now := time.Now()
|
||||
createdAt := identity.CreatedAt
|
||||
if createdAt.IsZero() {
|
||||
createdAt = now
|
||||
}
|
||||
updatedAt := identity.UpdatedAt
|
||||
if updatedAt.IsZero() {
|
||||
updatedAt = now
|
||||
}
|
||||
|
||||
role, ok := domain.NormalizeRoleAlias(kratosProjectionTraitString(traits, "role"))
|
||||
if !ok {
|
||||
role, ok = domain.NormalizeRoleAlias(kratosProjectionTraitString(traits, "grade"))
|
||||
if !ok {
|
||||
role = domain.RoleUser
|
||||
}
|
||||
}
|
||||
grade := kratosProjectionTraitString(traits, "grade")
|
||||
if _, ok := domain.NormalizeRoleAlias(grade); ok {
|
||||
grade = ""
|
||||
}
|
||||
|
||||
user := domain.User{
|
||||
ID: identity.ID,
|
||||
Email: kratosProjectionTraitString(traits, "email"),
|
||||
Name: kratosProjectionTraitString(traits, "name"),
|
||||
Phone: domain.NormalizePhoneNumber(kratosProjectionTraitString(traits, "phone_number")),
|
||||
Role: role,
|
||||
Status: normalizeProjectionStatus(identity.State),
|
||||
Department: kratosProjectionTraitString(traits, "department"),
|
||||
Grade: grade,
|
||||
Position: kratosProjectionTraitString(traits, "position"),
|
||||
JobTitle: kratosProjectionTraitString(traits, "jobTitle"),
|
||||
AffiliationType: kratosProjectionTraitString(traits, "affiliationType"),
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
Metadata: make(domain.JSONMap),
|
||||
}
|
||||
if tenantID := kratosProjectionTraitString(traits, "tenant_id"); tenantID != "" {
|
||||
user.TenantID = &tenantID
|
||||
}
|
||||
if relyingPartyID := kratosProjectionTraitString(traits, "relying_party_id"); relyingPartyID != "" {
|
||||
user.RelyingPartyID = &relyingPartyID
|
||||
}
|
||||
|
||||
coreTraits := map[string]bool{
|
||||
"email": true, "name": true, "phone_number": true,
|
||||
"grade": true, "role": true,
|
||||
"companyCode": true, "company_code": true, "companyCodes": true,
|
||||
"tenant_id": true, "department": true,
|
||||
"position": true, "jobTitle": true, "affiliationType": true,
|
||||
"relying_party_id": true, "custom_login_ids": true, "id": true,
|
||||
}
|
||||
for key, value := range traits {
|
||||
if !coreTraits[key] {
|
||||
user.Metadata[key] = value
|
||||
}
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
func kratosProjectionTraitString(traits map[string]any, key string) string {
|
||||
if traits == nil {
|
||||
return ""
|
||||
}
|
||||
value, ok := traits[key]
|
||||
if !ok || value == nil {
|
||||
return ""
|
||||
}
|
||||
if str, ok := value.(string); ok {
|
||||
return str
|
||||
}
|
||||
return fmt.Sprint(value)
|
||||
}
|
||||
|
||||
func kratosProjectionTraitStringArray(traits map[string]any, key string) []string {
|
||||
if traits == nil {
|
||||
return nil
|
||||
}
|
||||
switch value := traits[key].(type) {
|
||||
case []string:
|
||||
return value
|
||||
case []any:
|
||||
items := make([]string, 0, len(value))
|
||||
for _, item := range value {
|
||||
if str, ok := item.(string); ok && strings.TrimSpace(str) != "" {
|
||||
items = append(items, str)
|
||||
}
|
||||
}
|
||||
return items
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeProjectionStatus(state string) string {
|
||||
normalized := domain.NormalizeUserStatus(state)
|
||||
if normalized == "" {
|
||||
return domain.UserStatusActive
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeUserProjectionRepo struct {
|
||||
replacedUsers []domain.User
|
||||
failedErr error
|
||||
replaceErr error
|
||||
}
|
||||
|
||||
func (f *fakeUserProjectionRepo) IsReady(ctx context.Context) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (f *fakeUserProjectionRepo) GetStatus(ctx context.Context) (domain.UserProjectionStatus, error) {
|
||||
return domain.UserProjectionStatus{}, nil
|
||||
}
|
||||
|
||||
func (f *fakeUserProjectionRepo) CountTenantMembers(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeUserProjectionRepo) CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeUserProjectionRepo) ReplaceAllFromKratos(ctx context.Context, users []domain.User) error {
|
||||
f.replacedUsers = append([]domain.User(nil), users...)
|
||||
return f.replaceErr
|
||||
}
|
||||
|
||||
func (f *fakeUserProjectionRepo) MarkFailed(ctx context.Context, syncErr error) error {
|
||||
f.failedErr = syncErr
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUserProjectionSyncService_ReconcileReplacesProjectionFromKratos(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
kratos := new(MockKratosAdminServiceShared)
|
||||
repo := &fakeUserProjectionRepo{}
|
||||
svc := NewUserProjectionSyncService(kratos, repo)
|
||||
|
||||
tenantID := "00000000-0000-0000-0000-000000000001"
|
||||
kratos.On("ListIdentities", ctx).Return([]KratosIdentity{
|
||||
{
|
||||
ID: "00000000-0000-0000-0000-000000000101",
|
||||
Traits: map[string]any{
|
||||
"email": "one@example.com",
|
||||
"name": "One",
|
||||
"phone_number": "+821012345678",
|
||||
"companyCode": "saman",
|
||||
"companyCodes": []any{"saman", "group-a"},
|
||||
"tenant_id": tenantID,
|
||||
"department": "DX",
|
||||
"customAttr": "kept",
|
||||
},
|
||||
State: "active",
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
count, err := svc.Reconcile(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
require.Len(t, repo.replacedUsers, 1)
|
||||
assert.Equal(t, "one@example.com", repo.replacedUsers[0].Email)
|
||||
assert.Equal(t, "One", repo.replacedUsers[0].Name)
|
||||
assert.Equal(t, "+821012345678", repo.replacedUsers[0].Phone)
|
||||
assert.Empty(t, repo.replacedUsers[0].CompanyCode)
|
||||
assert.Empty(t, repo.replacedUsers[0].CompanyCodes)
|
||||
require.NotNil(t, repo.replacedUsers[0].TenantID)
|
||||
assert.Equal(t, tenantID, *repo.replacedUsers[0].TenantID)
|
||||
assert.Equal(t, "kept", repo.replacedUsers[0].Metadata["customAttr"])
|
||||
assert.NoError(t, repo.failedErr)
|
||||
kratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUserProjectionSyncService_ReconcileDeduplicatesKoreanCountryCodePhone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
kratos := new(MockKratosAdminServiceShared)
|
||||
repo := &fakeUserProjectionRepo{}
|
||||
svc := NewUserProjectionSyncService(kratos, repo)
|
||||
|
||||
kratos.On("ListIdentities", ctx).Return([]KratosIdentity{
|
||||
{
|
||||
ID: "00000000-0000-0000-0000-000000000102",
|
||||
Traits: map[string]any{
|
||||
"email": "two@example.com",
|
||||
"name": "Two",
|
||||
"phone_number": "+82 +821091917771",
|
||||
},
|
||||
State: "active",
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
count, err := svc.Reconcile(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
require.Len(t, repo.replacedUsers, 1)
|
||||
assert.Equal(t, "+821091917771", repo.replacedUsers[0].Phone)
|
||||
kratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUserProjectionSyncService_ReconcileMarksFailedWhenKratosFails(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
kratos := new(MockKratosAdminServiceShared)
|
||||
repo := &fakeUserProjectionRepo{}
|
||||
svc := NewUserProjectionSyncService(kratos, repo)
|
||||
|
||||
expectedErr := errors.New("kratos down")
|
||||
kratos.On("ListIdentities", ctx).Return([]KratosIdentity{}, expectedErr).Once()
|
||||
|
||||
count, err := svc.Reconcile(ctx)
|
||||
|
||||
assert.Equal(t, 0, count)
|
||||
assert.ErrorIs(t, err, expectedErr)
|
||||
assert.ErrorIs(t, repo.failedErr, expectedErr)
|
||||
assert.Empty(t, repo.replacedUsers)
|
||||
kratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestMapKratosIdentityToLocalUserPreservesArchivedStatus(t *testing.T) {
|
||||
user := MapKratosIdentityToLocalUser(KratosIdentity{
|
||||
ID: "00000000-0000-0000-0000-000000000201",
|
||||
State: domain.UserStatusArchived,
|
||||
Traits: map[string]any{
|
||||
"email": "archived@example.com",
|
||||
"name": "Archived User",
|
||||
},
|
||||
})
|
||||
|
||||
assert.Equal(t, domain.UserStatusArchived, user.Status)
|
||||
}
|
||||
1527
baron-sso/backend/internal/service/worksmobile_client.go
Normal file
1527
baron-sso/backend/internal/service/worksmobile_client.go
Normal file
File diff suppressed because it is too large
Load Diff
1596
baron-sso/backend/internal/service/worksmobile_client_test.go
Normal file
1596
baron-sso/backend/internal/service/worksmobile_client_test.go
Normal file
File diff suppressed because it is too large
Load Diff
1124
baron-sso/backend/internal/service/worksmobile_live_flow_test.go
Normal file
1124
baron-sso/backend/internal/service/worksmobile_live_flow_test.go
Normal file
File diff suppressed because it is too large
Load Diff
972
baron-sso/backend/internal/service/worksmobile_mapper.go
Normal file
972
baron-sso/backend/internal/service/worksmobile_mapper.go
Normal file
@@ -0,0 +1,972 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/mail"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
WorksmobileUserActionUpsert = "UPSERT"
|
||||
WorksmobileUserActionSuspend = "SUSPEND"
|
||||
)
|
||||
|
||||
type WorksmobileOrgUnitPayload struct {
|
||||
DomainID int64 `json:"domainId"`
|
||||
OrgUnitName string `json:"orgUnitName"`
|
||||
Email string `json:"email,omitempty"`
|
||||
OrgUnitExternalKey string `json:"orgUnitExternalKey"`
|
||||
ParentOrgUnitID string `json:"parentOrgUnitId,omitempty"`
|
||||
DisplayOrder int `json:"displayOrder"`
|
||||
}
|
||||
|
||||
type WorksmobileUserPayload struct {
|
||||
DomainID int64 `json:"domainId"`
|
||||
Email string `json:"email"`
|
||||
UserExternalKey string `json:"userExternalKey,omitempty"`
|
||||
UserName WorksmobileUserName `json:"userName"`
|
||||
CellPhone string `json:"cellPhone,omitempty"`
|
||||
EmployeeNumber string `json:"employeeNumber,omitempty"`
|
||||
PrivateEmail string `json:"privateEmail,omitempty"`
|
||||
AliasEmails []string `json:"aliasEmails,omitempty"`
|
||||
Locale string `json:"locale,omitempty"`
|
||||
PasswordConfig WorksmobilePasswordConfig `json:"passwordConfig,omitempty"`
|
||||
Task string `json:"task,omitempty"`
|
||||
Organizations []WorksmobileUserOrganization `json:"organizations,omitempty"`
|
||||
}
|
||||
|
||||
type WorksmobileUserName struct {
|
||||
LastName string `json:"lastName,omitempty"`
|
||||
}
|
||||
|
||||
type WorksmobilePasswordConfig struct {
|
||||
PasswordCreationType string `json:"passwordCreationType"`
|
||||
Password string `json:"password"`
|
||||
ChangePasswordAtNextLogin *bool `json:"changePasswordAtNextLogin,omitempty"`
|
||||
}
|
||||
|
||||
func (c WorksmobilePasswordConfig) IsZero() bool {
|
||||
return strings.TrimSpace(c.PasswordCreationType) == "" &&
|
||||
strings.TrimSpace(c.Password) == "" &&
|
||||
c.ChangePasswordAtNextLogin == nil
|
||||
}
|
||||
|
||||
func (p WorksmobileUserPayload) MarshalJSON() ([]byte, error) {
|
||||
type payloadJSON struct {
|
||||
DomainID int64 `json:"domainId"`
|
||||
Email string `json:"email"`
|
||||
UserExternalKey string `json:"userExternalKey,omitempty"`
|
||||
UserName WorksmobileUserName `json:"userName"`
|
||||
CellPhone string `json:"cellPhone,omitempty"`
|
||||
EmployeeNumber string `json:"employeeNumber,omitempty"`
|
||||
PrivateEmail string `json:"privateEmail,omitempty"`
|
||||
AliasEmails []string `json:"aliasEmails,omitempty"`
|
||||
Locale string `json:"locale,omitempty"`
|
||||
PasswordConfig *WorksmobilePasswordConfig `json:"passwordConfig,omitempty"`
|
||||
Task string `json:"task,omitempty"`
|
||||
Organizations []WorksmobileUserOrganization `json:"organizations,omitempty"`
|
||||
}
|
||||
|
||||
var passwordConfig *WorksmobilePasswordConfig
|
||||
if !p.PasswordConfig.IsZero() {
|
||||
passwordConfig = &p.PasswordConfig
|
||||
}
|
||||
|
||||
return json.Marshal(payloadJSON{
|
||||
DomainID: p.DomainID,
|
||||
Email: p.Email,
|
||||
UserExternalKey: p.UserExternalKey,
|
||||
UserName: p.UserName,
|
||||
CellPhone: p.CellPhone,
|
||||
EmployeeNumber: p.EmployeeNumber,
|
||||
PrivateEmail: p.PrivateEmail,
|
||||
AliasEmails: p.AliasEmails,
|
||||
Locale: p.Locale,
|
||||
PasswordConfig: passwordConfig,
|
||||
Task: p.Task,
|
||||
Organizations: p.Organizations,
|
||||
})
|
||||
}
|
||||
|
||||
type WorksmobilePasswordResetPayload struct {
|
||||
Email string `json:"email"`
|
||||
PasswordConfig WorksmobilePasswordConfig `json:"passwordConfig"`
|
||||
}
|
||||
|
||||
type WorksmobileUserOrganization struct {
|
||||
DomainID int64 `json:"domainId,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Primary bool `json:"primary"`
|
||||
OrgUnits []WorksmobileUserOrgUnit `json:"orgUnits"`
|
||||
}
|
||||
|
||||
type WorksmobileUserOrgUnit struct {
|
||||
OrgUnitID string `json:"orgUnitId"`
|
||||
Primary bool `json:"primary"`
|
||||
PositionID string `json:"positionId,omitempty"`
|
||||
IsManager *bool `json:"isManager,omitempty"`
|
||||
}
|
||||
|
||||
func BuildWorksmobileOrgUnitPayload(tenant domain.Tenant, rootConfig domain.JSONMap, displayOrder int) (WorksmobileOrgUnitPayload, error) {
|
||||
return BuildWorksmobileOrgUnitPayloadForDomainTenant(tenant, tenant, rootConfig, displayOrder)
|
||||
}
|
||||
|
||||
func BuildWorksmobileOrgUnitPayloadForDomainTenant(tenant domain.Tenant, domainTenant domain.Tenant, rootConfig domain.JSONMap, displayOrder int) (WorksmobileOrgUnitPayload, error) {
|
||||
if err := ValidateWorksmobileExternalKey(tenant.ID); err != nil {
|
||||
return WorksmobileOrgUnitPayload{}, err
|
||||
}
|
||||
if displayOrder < 1 {
|
||||
displayOrder = 1
|
||||
}
|
||||
domainID, err := ResolveWorksmobileDomainIDFromTenant(domainTenant, rootConfig)
|
||||
if err != nil {
|
||||
return WorksmobileOrgUnitPayload{}, err
|
||||
}
|
||||
payload := WorksmobileOrgUnitPayload{
|
||||
DomainID: domainID,
|
||||
OrgUnitName: strings.TrimSpace(tenant.Name),
|
||||
Email: buildWorksmobileOrgUnitEmail(tenant, domainTenant),
|
||||
OrgUnitExternalKey: tenant.ID,
|
||||
DisplayOrder: displayOrder,
|
||||
}
|
||||
if tenant.ParentID != nil && *tenant.ParentID != "" {
|
||||
if err := ValidateWorksmobileExternalKey(*tenant.ParentID); err != nil {
|
||||
return WorksmobileOrgUnitPayload{}, err
|
||||
}
|
||||
payload.ParentOrgUnitID = "externalKey:" + *tenant.ParentID
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func buildWorksmobileOrgUnitEmail(tenant domain.Tenant, domainTenant domain.Tenant) string {
|
||||
slug := strings.ToLower(strings.TrimSpace(tenant.Slug))
|
||||
if slug == "" {
|
||||
return ""
|
||||
}
|
||||
if domainName := worksmobileTenantMailDomain(domainTenant); domainName != "" {
|
||||
return slug + "@" + domainName
|
||||
}
|
||||
for _, candidate := range append([]domain.TenantDomain{}, domainTenant.Domains...) {
|
||||
domainName := strings.ToLower(strings.TrimSpace(candidate.Domain))
|
||||
if domainName != "" {
|
||||
return slug + "@" + domainName
|
||||
}
|
||||
}
|
||||
for _, candidate := range tenant.Domains {
|
||||
domainName := strings.ToLower(strings.TrimSpace(candidate.Domain))
|
||||
if domainName != "" {
|
||||
return slug + "@" + domainName
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func worksmobileTenantMailDomain(tenant domain.Tenant) string {
|
||||
envKey := strings.TrimSuffix(worksmobileTenantDomainIDEnvKey(tenant), "_DOMAIN_ID")
|
||||
if domainName := strings.ToLower(strings.TrimSpace(os.Getenv("WORKS_DEFAULT_DOMAIN_" + envKey))); domainName != "" {
|
||||
return domainName
|
||||
}
|
||||
if domainName := strings.ToLower(strings.TrimSpace(os.Getenv(envKey + "_MAIL_DOMAIN"))); domainName != "" {
|
||||
return domainName
|
||||
}
|
||||
switch envKey {
|
||||
case "SAMAN":
|
||||
return "samaneng.com"
|
||||
case "HANMAC":
|
||||
return "hanmaceng.co.kr"
|
||||
case "GPDTDC":
|
||||
return "baroncs.co.kr"
|
||||
case "HALLA":
|
||||
return "hallasanup.com"
|
||||
case "BARONGROUP":
|
||||
return "brsw.kr"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func BuildWorksmobileUserPayload(user domain.User, tenant domain.Tenant, rootConfig domain.JSONMap) (WorksmobileUserPayload, error) {
|
||||
return BuildWorksmobileUserPayloadForDomainTenant(user, tenant, tenant, rootConfig)
|
||||
}
|
||||
|
||||
func BuildWorksmobileUserPayloadForDomainTenant(user domain.User, tenant domain.Tenant, _ domain.Tenant, rootConfig domain.JSONMap) (WorksmobileUserPayload, error) {
|
||||
return BuildWorksmobileUserPayloadForDomainTenants(user, tenant, map[string]domain.Tenant{tenant.ID: tenant}, rootConfig)
|
||||
}
|
||||
|
||||
func BuildWorksmobileUserPayloadForDomainTenants(user domain.User, tenant domain.Tenant, tenantByID map[string]domain.Tenant, rootConfig domain.JSONMap) (WorksmobileUserPayload, error) {
|
||||
if err := ValidateWorksmobileExternalKey(user.ID); err != nil {
|
||||
return WorksmobileUserPayload{}, err
|
||||
}
|
||||
if tenant.ID == "" {
|
||||
return WorksmobileUserPayload{}, errors.New("tenant is required")
|
||||
}
|
||||
if tenantByID == nil {
|
||||
tenantByID = map[string]domain.Tenant{}
|
||||
}
|
||||
tenantByID[tenant.ID] = tenant
|
||||
domainID, err := ResolveWorksmobileAccountDomainIDFromEmail(user.Email, tenant, rootConfig)
|
||||
if err != nil {
|
||||
return WorksmobileUserPayload{}, err
|
||||
}
|
||||
employeeNumber := metadataEmployeeNumber(user.Metadata)
|
||||
organizations, task, err := buildWorksmobileUserOrganizations(user, tenant, tenantByID, rootConfig)
|
||||
if err != nil {
|
||||
return WorksmobileUserPayload{}, err
|
||||
}
|
||||
if task == "" {
|
||||
task = strings.TrimSpace(user.JobTitle)
|
||||
}
|
||||
payload := WorksmobileUserPayload{
|
||||
DomainID: domainID,
|
||||
Email: strings.TrimSpace(user.Email),
|
||||
UserExternalKey: user.ID,
|
||||
UserName: WorksmobileUserName{LastName: strings.TrimSpace(user.Name)},
|
||||
CellPhone: domain.NormalizePhoneNumber(user.Phone),
|
||||
EmployeeNumber: employeeNumber,
|
||||
Locale: "ko_KR",
|
||||
Task: task,
|
||||
Organizations: organizations,
|
||||
}
|
||||
payload.AliasEmails = BuildWorksmobileAliasEmails(user, tenant)
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
type worksmobileAppointment struct {
|
||||
TenantID string
|
||||
IsPrimary bool
|
||||
IsManager bool
|
||||
HasManager bool
|
||||
JobTitle string
|
||||
PositionID string
|
||||
Source string
|
||||
}
|
||||
|
||||
func buildWorksmobileUserOrganizations(user domain.User, tenant domain.Tenant, tenantByID map[string]domain.Tenant, rootConfig domain.JSONMap) ([]WorksmobileUserOrganization, string, error) {
|
||||
appointments := worksmobileAppointmentsFromMetadata(user.Metadata)
|
||||
if len(appointments) == 0 {
|
||||
appointments = []worksmobileAppointment{{TenantID: tenant.ID, IsPrimary: true}}
|
||||
} else if !worksmobileAppointmentsContainTenant(appointments, tenant.ID) && !worksmobileAppointmentsHavePrimary(appointments) {
|
||||
appointments = append([]worksmobileAppointment{{
|
||||
TenantID: tenant.ID,
|
||||
IsPrimary: true,
|
||||
JobTitle: strings.TrimSpace(user.JobTitle),
|
||||
PositionID: metadataString(user.Metadata, "worksmobilePositionId", "positionId", "position_id"),
|
||||
}}, appointments...)
|
||||
}
|
||||
accountDomainTenant := worksmobileAccountDomainTenantFromEmail(user.Email, tenant, tenantByID)
|
||||
accountDomainEnvKey := worksmobileTenantDomainIDEnvKey(accountDomainTenant)
|
||||
if !worksmobileAppointmentsContainDomain(appointments, tenantByID, accountDomainEnvKey) && accountDomainTenant.ID != "" {
|
||||
appointments = append([]worksmobileAppointment{{
|
||||
TenantID: accountDomainTenant.ID,
|
||||
IsPrimary: true,
|
||||
JobTitle: strings.TrimSpace(user.JobTitle),
|
||||
PositionID: metadataString(user.Metadata, "worksmobilePositionId", "positionId", "position_id"),
|
||||
}}, appointments...)
|
||||
}
|
||||
|
||||
organizations := make([]WorksmobileUserOrganization, 0)
|
||||
organizationIndexByDomainID := map[int64]int{}
|
||||
seen := map[string]bool{}
|
||||
task := ""
|
||||
for _, appointment := range appointments {
|
||||
if appointment.TenantID == "" || seen[appointment.TenantID] {
|
||||
continue
|
||||
}
|
||||
appointmentTenant, ok := tenantByID[appointment.TenantID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if worksmobileShouldSkipEmailDomainRootAppointment(appointment, appointmentTenant, appointments, tenantByID) {
|
||||
seen[appointment.TenantID] = true
|
||||
continue
|
||||
}
|
||||
if isWorksmobileDomainRootTenant(appointmentTenant) {
|
||||
if appointment.IsPrimary && strings.TrimSpace(appointment.JobTitle) != "" && task == "" {
|
||||
task = strings.TrimSpace(appointment.JobTitle)
|
||||
}
|
||||
seen[appointment.TenantID] = true
|
||||
continue
|
||||
}
|
||||
if err := ValidateWorksmobileExternalKey(appointmentTenant.ID); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
domainTenant := worksmobileDomainClassificationTenant(appointmentTenant, tenantByID)
|
||||
domainID, err := ResolveWorksmobileDomainIDFromTenant(domainTenant, rootConfig)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
isAccountDomain := worksmobileTenantDomainIDEnvKey(domainTenant) == accountDomainEnvKey
|
||||
isPrimaryOrganization := isAccountDomain && !worksmobileOrganizationsHavePrimary(organizations)
|
||||
organizationIndex, organizationExists := organizationIndexByDomainID[domainID]
|
||||
orgUnit := WorksmobileUserOrgUnit{
|
||||
OrgUnitID: "externalKey:" + appointmentTenant.ID,
|
||||
Primary: !organizationExists,
|
||||
PositionID: appointment.PositionID,
|
||||
}
|
||||
if appointment.HasManager {
|
||||
isManager := appointment.IsManager
|
||||
orgUnit.IsManager = &isManager
|
||||
}
|
||||
if organizationExists {
|
||||
if isPrimaryOrganization {
|
||||
organizations[organizationIndex].Primary = true
|
||||
organizations[organizationIndex].Email = worksmobileOrganizationEmail(user, domainTenant)
|
||||
}
|
||||
organizations[organizationIndex].OrgUnits = append(organizations[organizationIndex].OrgUnits, orgUnit)
|
||||
} else {
|
||||
organizationIndexByDomainID[domainID] = len(organizations)
|
||||
organizations = append(organizations, WorksmobileUserOrganization{
|
||||
DomainID: domainID,
|
||||
Email: worksmobileOrganizationEmail(user, domainTenant),
|
||||
Primary: isPrimaryOrganization,
|
||||
OrgUnits: []WorksmobileUserOrgUnit{orgUnit},
|
||||
})
|
||||
}
|
||||
if isPrimaryOrganization && strings.TrimSpace(appointment.JobTitle) != "" {
|
||||
task = strings.TrimSpace(appointment.JobTitle)
|
||||
}
|
||||
seen[appointment.TenantID] = true
|
||||
}
|
||||
if len(organizations) == 0 {
|
||||
return nil, task, nil
|
||||
}
|
||||
if !worksmobileOrganizationsHavePrimary(organizations) {
|
||||
organizations[0].Primary = true
|
||||
if len(organizations[0].OrgUnits) > 0 {
|
||||
organizations[0].OrgUnits[0].Primary = true
|
||||
}
|
||||
}
|
||||
sortWorksmobileOrganizations(organizations)
|
||||
return organizations, task, nil
|
||||
}
|
||||
|
||||
func worksmobileAppointmentsContainTenant(appointments []worksmobileAppointment, tenantID string) bool {
|
||||
tenantID = strings.TrimSpace(tenantID)
|
||||
if tenantID == "" {
|
||||
return false
|
||||
}
|
||||
for _, appointment := range appointments {
|
||||
if strings.TrimSpace(appointment.TenantID) == tenantID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func worksmobileAppointmentsHavePrimary(appointments []worksmobileAppointment) bool {
|
||||
for _, appointment := range appointments {
|
||||
if appointment.IsPrimary {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func worksmobileAppointmentsContainDomain(appointments []worksmobileAppointment, tenantByID map[string]domain.Tenant, envKey string) bool {
|
||||
for _, appointment := range appointments {
|
||||
tenant, ok := tenantByID[appointment.TenantID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
domainTenant := worksmobileDomainClassificationTenant(tenant, tenantByID)
|
||||
if worksmobileTenantDomainIDEnvKey(domainTenant) == envKey {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func worksmobileShouldSkipEmailDomainRootAppointment(appointment worksmobileAppointment, tenant domain.Tenant, appointments []worksmobileAppointment, tenantByID map[string]domain.Tenant) bool {
|
||||
if strings.TrimSpace(appointment.Source) != "email_domain" || !isWorksmobileDomainRootTenant(tenant) {
|
||||
return false
|
||||
}
|
||||
envKey := worksmobileTenantDomainIDEnvKey(tenant)
|
||||
for _, candidate := range appointments {
|
||||
if strings.TrimSpace(candidate.TenantID) == "" || strings.TrimSpace(candidate.TenantID) == tenant.ID {
|
||||
continue
|
||||
}
|
||||
candidateTenant, ok := tenantByID[candidate.TenantID]
|
||||
if !ok || isWorksmobileDomainRootTenant(candidateTenant) {
|
||||
continue
|
||||
}
|
||||
if worksmobileTenantDomainIDEnvKey(worksmobileDomainClassificationTenant(candidateTenant, tenantByID)) == envKey {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func worksmobileOrganizationsHavePrimary(organizations []WorksmobileUserOrganization) bool {
|
||||
for _, organization := range organizations {
|
||||
if organization.Primary {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func worksmobileAppointmentsFromMetadata(metadata domain.JSONMap) []worksmobileAppointment {
|
||||
rawAppointments, ok := metadata["additionalAppointments"].([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
appointments := make([]worksmobileAppointment, 0, len(rawAppointments))
|
||||
for _, raw := range rawAppointments {
|
||||
item, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
appointment := worksmobileAppointment{
|
||||
TenantID: metadataString(domain.JSONMap(item), "tenantId", "tenant_id"),
|
||||
IsPrimary: metadataBool(domain.JSONMap(item), "isPrimary", "primary"),
|
||||
JobTitle: metadataString(domain.JSONMap(item), "jobTitle", "job_title", "task"),
|
||||
PositionID: metadataString(domain.JSONMap(item), "worksmobilePositionId", "positionId", "position_id"),
|
||||
Source: metadataString(domain.JSONMap(item), "assignmentSource", "source"),
|
||||
}
|
||||
if isManager, ok := metadataOptionalBool(domain.JSONMap(item), "isManager", "lead", "isLead"); ok {
|
||||
appointment.IsManager = isManager
|
||||
appointment.HasManager = true
|
||||
}
|
||||
appointments = append(appointments, appointment)
|
||||
}
|
||||
return appointments
|
||||
}
|
||||
|
||||
func sortWorksmobileOrganizations(organizations []WorksmobileUserOrganization) {
|
||||
sort.SliceStable(organizations, func(i, j int) bool {
|
||||
if organizations[i].Primary != organizations[j].Primary {
|
||||
return organizations[i].Primary
|
||||
}
|
||||
left := ""
|
||||
right := ""
|
||||
if len(organizations[i].OrgUnits) > 0 {
|
||||
left = organizations[i].OrgUnits[0].OrgUnitID
|
||||
}
|
||||
if len(organizations[j].OrgUnits) > 0 {
|
||||
right = organizations[j].OrgUnits[0].OrgUnitID
|
||||
}
|
||||
return left < right
|
||||
})
|
||||
}
|
||||
|
||||
func BuildWorksmobileAliasEmails(user domain.User, tenant domain.Tenant) []string {
|
||||
candidates := make([]string, 0)
|
||||
for _, key := range []string{
|
||||
"aliasEmails",
|
||||
"alias_emails",
|
||||
"worksmobileAliasEmails",
|
||||
"sub_email",
|
||||
"secondary_email",
|
||||
"secondary_emails",
|
||||
"additional_email",
|
||||
"additional_emails",
|
||||
"naverworks_sub_email",
|
||||
} {
|
||||
candidates = append(candidates, metadataStringList(user.Metadata, key)...)
|
||||
}
|
||||
employeeNumber := metadataEmployeeNumber(user.Metadata)
|
||||
if isHanmacWorksmobileTenant(tenant) && employeeNumber != "" {
|
||||
candidates = append(candidates, employeeNumber+"@hanmaceng.co.kr")
|
||||
}
|
||||
return normalizeWorksmobileAliasEmails(user.Email, candidates)
|
||||
}
|
||||
|
||||
func normalizeWorksmobileAliasEmails(primaryEmail string, candidates []string) []string {
|
||||
result := make([]string, 0, len(candidates))
|
||||
seen := map[string]bool{}
|
||||
primary := strings.ToLower(strings.TrimSpace(primaryEmail))
|
||||
for _, candidate := range candidates {
|
||||
normalized := strings.ToLower(strings.TrimSpace(candidate))
|
||||
if normalized == "" || normalized == primary || seen[normalized] {
|
||||
continue
|
||||
}
|
||||
if _, err := mail.ParseAddress(normalized); err != nil {
|
||||
continue
|
||||
}
|
||||
seen[normalized] = true
|
||||
result = append(result, normalized)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func ValidateWorksmobileAliasEmails(primaryEmail string, aliasEmails []string, existingEmails map[string]string) error {
|
||||
seen := map[string]string{strings.ToLower(strings.TrimSpace(primaryEmail)): primaryEmail}
|
||||
|
||||
for _, aliasEmail := range aliasEmails {
|
||||
normalized := strings.ToLower(strings.TrimSpace(aliasEmail))
|
||||
if _, err := mail.ParseAddress(normalized); err != nil {
|
||||
return err
|
||||
}
|
||||
if previous, ok := seen[normalized]; ok {
|
||||
return fmt.Errorf("worksmobile alias email duplicates: %s and %s", previous, aliasEmail)
|
||||
}
|
||||
if owner, ok := existingEmails[normalized]; ok {
|
||||
return fmt.Errorf("worksmobile alias email %s는 이미 사용 중입니다: %s", normalized, owner)
|
||||
}
|
||||
seen[normalized] = aliasEmail
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GenerateWorksmobileInitialPassword() string {
|
||||
digits := "0123456789"
|
||||
letters := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
symbols := "!@#$%"
|
||||
all := digits + letters + symbols
|
||||
|
||||
password := []byte{
|
||||
randomChar(digits),
|
||||
randomChar(letters),
|
||||
randomChar(symbols),
|
||||
}
|
||||
for len(password) < 16 {
|
||||
password = append(password, randomChar(all))
|
||||
}
|
||||
shuffleBytes(password)
|
||||
return string(password)
|
||||
}
|
||||
|
||||
func randomChar(chars string) byte {
|
||||
if chars == "" {
|
||||
return 'x'
|
||||
}
|
||||
index, err := rand.Int(rand.Reader, big.NewInt(int64(len(chars))))
|
||||
if err != nil {
|
||||
return chars[0]
|
||||
}
|
||||
return chars[index.Int64()]
|
||||
}
|
||||
|
||||
func shuffleBytes(values []byte) {
|
||||
for i := len(values) - 1; i > 0; i-- {
|
||||
j, err := rand.Int(rand.Reader, big.NewInt(int64(i+1)))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
values[i], values[j.Int64()] = values[j.Int64()], values[i]
|
||||
}
|
||||
}
|
||||
|
||||
func WorksmobileUserStatusAction(status string) string {
|
||||
normalized := domain.NormalizeUserStatus(status)
|
||||
if domain.IsWorksDeprovisionUserStatus(normalized) {
|
||||
return domain.WorksmobileActionDelete
|
||||
}
|
||||
switch normalized {
|
||||
case domain.UserStatusSuspended:
|
||||
return WorksmobileUserActionSuspend
|
||||
default:
|
||||
return WorksmobileUserActionUpsert
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateWorksmobileExternalKey(value string) error {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return errors.New("external key is required")
|
||||
}
|
||||
if strings.ContainsAny(value, `%\#/?`) {
|
||||
return fmt.Errorf("external key contains unsupported character: %s", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ResolveWorksmobileDomainIDFromTenant(tenant domain.Tenant, _ domain.JSONMap) (int64, error) {
|
||||
envKey := worksmobileTenantDomainIDEnvKey(tenant)
|
||||
if domainID, ok := worksmobileDomainIDFromEnv(envKey); ok {
|
||||
return domainID, nil
|
||||
}
|
||||
return 0, fmt.Errorf("worksmobile domain id env is missing for tenant: %s", envKey)
|
||||
}
|
||||
|
||||
func ResolveWorksmobileAccountDomainIDFromEmail(email string, fallbackTenant domain.Tenant, rootConfig domain.JSONMap) (int64, error) {
|
||||
switch worksmobileEmailDomainName(email) {
|
||||
case "samaneng.com":
|
||||
if domainID, ok := worksmobileDomainIDFromEnv("SAMAN_DOMAIN_ID"); ok {
|
||||
return domainID, nil
|
||||
}
|
||||
case "hanmaceng.co.kr":
|
||||
if domainID, ok := worksmobileDomainIDFromEnv("HANMAC_DOMAIN_ID"); ok {
|
||||
return domainID, nil
|
||||
}
|
||||
case "baroncs.co.kr":
|
||||
if domainID, ok := worksmobileDomainIDFromEnv("GPDTDC_DOMAIN_ID"); ok {
|
||||
return domainID, nil
|
||||
}
|
||||
case "hallasanup.com":
|
||||
if domainID, ok := worksmobileDomainIDFromEnv("HALLA_DOMAIN_ID"); ok {
|
||||
return domainID, nil
|
||||
}
|
||||
case "brsw.kr":
|
||||
if domainID, ok := worksmobileDomainIDFromEnv("BARONGROUP_DOMAIN_ID"); ok {
|
||||
return domainID, nil
|
||||
}
|
||||
}
|
||||
return ResolveWorksmobileDomainIDFromTenant(fallbackTenant, rootConfig)
|
||||
}
|
||||
|
||||
func worksmobileAccountDomainTenantFromEmail(email string, fallbackTenant domain.Tenant, tenantByID map[string]domain.Tenant) domain.Tenant {
|
||||
envKey := worksmobileDomainIDEnvKeyFromEmail(email)
|
||||
for _, tenant := range tenantByID {
|
||||
if isWorksmobileDomainRootTenant(tenant) && worksmobileTenantDomainIDEnvKey(tenant) == envKey {
|
||||
return tenant
|
||||
}
|
||||
}
|
||||
for _, tenant := range tenantByID {
|
||||
if worksmobileTenantDomainIDEnvKey(tenant) == envKey {
|
||||
return worksmobileDomainClassificationTenant(tenant, tenantByID)
|
||||
}
|
||||
}
|
||||
return worksmobileDomainClassificationTenant(fallbackTenant, tenantByID)
|
||||
}
|
||||
|
||||
func worksmobileDomainIDEnvKeyFromEmail(email string) string {
|
||||
switch worksmobileEmailDomainName(email) {
|
||||
case "samaneng.com":
|
||||
return "SAMAN_DOMAIN_ID"
|
||||
case "hanmaceng.co.kr":
|
||||
return "HANMAC_DOMAIN_ID"
|
||||
case "baroncs.co.kr":
|
||||
return "GPDTDC_DOMAIN_ID"
|
||||
case "hallasanup.com":
|
||||
return "HALLA_DOMAIN_ID"
|
||||
case "brsw.kr":
|
||||
return "BARONGROUP_DOMAIN_ID"
|
||||
default:
|
||||
return worksmobileTenantDomainIDEnvKey(domain.Tenant{})
|
||||
}
|
||||
}
|
||||
|
||||
func worksmobileEmailDomainName(email string) string {
|
||||
address, err := mail.ParseAddress(strings.TrimSpace(email))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
parts := strings.Split(address.Address, "@")
|
||||
if len(parts) != 2 {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(strings.TrimSpace(parts[1]))
|
||||
}
|
||||
|
||||
func worksmobileOrganizationEmail(user domain.User, domainTenant domain.Tenant) string {
|
||||
domainName := worksmobileTenantMailDomain(domainTenant)
|
||||
if domainName == "" {
|
||||
return ""
|
||||
}
|
||||
primaryEmail := strings.ToLower(strings.TrimSpace(user.Email))
|
||||
if worksmobileEmailDomainName(primaryEmail) == domainName {
|
||||
return primaryEmail
|
||||
}
|
||||
for _, alias := range BuildWorksmobileAliasEmails(user, domainTenant) {
|
||||
if worksmobileEmailDomainName(alias) == domainName {
|
||||
return alias
|
||||
}
|
||||
}
|
||||
localPart, err := domain.ExtractNormalizedEmailLocalPart(primaryEmail)
|
||||
if err != nil || localPart == "" {
|
||||
return ""
|
||||
}
|
||||
return localPart + "@" + domainName
|
||||
}
|
||||
|
||||
func worksmobileTenantDomainIDEnvKey(tenant domain.Tenant) string {
|
||||
if tenantHasDomain(tenant, "samaneng.com") || tenantMatchesAny(tenant, "saman", "삼안") {
|
||||
return "SAMAN_DOMAIN_ID"
|
||||
}
|
||||
if isHanmacWorksmobileTenant(tenant) {
|
||||
return "HANMAC_DOMAIN_ID"
|
||||
}
|
||||
if tenantMatchesAny(tenant, "gpdtdc", "총괄", "기술개발센터", "기술개발") {
|
||||
return "GPDTDC_DOMAIN_ID"
|
||||
}
|
||||
if isHallaWorksmobileTenant(tenant) {
|
||||
return "HALLA_DOMAIN_ID"
|
||||
}
|
||||
return "BARONGROUP_DOMAIN_ID"
|
||||
}
|
||||
|
||||
func worksmobileDomainIDFromEnv(key string) (int64, bool) {
|
||||
if key == "" {
|
||||
return 0, false
|
||||
}
|
||||
id, ok := parseDomainID(os.Getenv(key))
|
||||
return id, ok
|
||||
}
|
||||
|
||||
type worksmobileDomainEnvMapping struct {
|
||||
Key string
|
||||
Label string
|
||||
}
|
||||
|
||||
func worksmobileDomainEnvMappings() []worksmobileDomainEnvMapping {
|
||||
return []worksmobileDomainEnvMapping{
|
||||
{Key: "SAMAN_DOMAIN_ID", Label: "삼안"},
|
||||
{Key: "HANMAC_DOMAIN_ID", Label: "한맥기술"},
|
||||
{Key: "GPDTDC_DOMAIN_ID", Label: "총괄기획&기술개발센터"},
|
||||
{Key: "HALLA_DOMAIN_ID", Label: "한라산업개발"},
|
||||
{Key: "BARONGROUP_DOMAIN_ID", Label: "바론그룹"},
|
||||
}
|
||||
}
|
||||
|
||||
func WorksmobileDomainIDsFromEnv() []int64 {
|
||||
mappings := worksmobileDomainEnvMappings()
|
||||
result := make([]int64, 0, len(mappings))
|
||||
seen := map[int64]bool{}
|
||||
for _, mapping := range mappings {
|
||||
if id, ok := worksmobileDomainIDFromEnv(mapping.Key); ok && !seen[id] {
|
||||
seen[id] = true
|
||||
result = append(result, id)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func WorksmobileDomainLabelForID(domainID int64) string {
|
||||
for _, mapping := range worksmobileDomainEnvMappings() {
|
||||
if id, ok := worksmobileDomainIDFromEnv(mapping.Key); ok && id == domainID {
|
||||
return mapping.Label
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func isHanmacWorksmobileTenant(tenant domain.Tenant) bool {
|
||||
return tenantHasDomain(tenant, "hanmaceng.co.kr") || tenantMatchesAny(tenant, "hanmac", "한맥")
|
||||
}
|
||||
|
||||
func isHallaWorksmobileTenant(tenant domain.Tenant) bool {
|
||||
return tenantHasDomain(tenant, "hallasanup.com") || tenantMatchesAny(tenant, "halla", "hanlla", "한라산업개발")
|
||||
}
|
||||
|
||||
func tenantHasDomain(tenant domain.Tenant, domainName string) bool {
|
||||
domainName = strings.ToLower(strings.TrimSpace(domainName))
|
||||
for _, d := range tenant.Domains {
|
||||
if strings.EqualFold(strings.TrimSpace(d.Domain), domainName) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func tenantMatchesAny(tenant domain.Tenant, needles ...string) bool {
|
||||
haystack := strings.ToLower(strings.TrimSpace(tenant.Slug + " " + tenant.Name))
|
||||
for _, needle := range needles {
|
||||
if strings.Contains(haystack, strings.ToLower(strings.TrimSpace(needle))) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func WorksmobileEnabled(rootConfig domain.JSONMap) bool {
|
||||
rawWorksmobile, ok := rootConfig["worksmobile"].(map[string]any)
|
||||
if !ok {
|
||||
if raw, ok := rootConfig["worksmobile"].(domain.JSONMap); ok {
|
||||
rawWorksmobile = map[string]any(raw)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
enabled, _ := rawWorksmobile["enabled"].(bool)
|
||||
return enabled
|
||||
}
|
||||
|
||||
func WorksmobileDomainMappings(rootConfig domain.JSONMap) map[string]int64 {
|
||||
result := map[string]int64{}
|
||||
rawWorksmobile, ok := rootConfig["worksmobile"].(map[string]any)
|
||||
if !ok {
|
||||
if raw, ok := rootConfig["worksmobile"].(domain.JSONMap); ok {
|
||||
rawWorksmobile = map[string]any(raw)
|
||||
} else {
|
||||
return result
|
||||
}
|
||||
}
|
||||
rawMappings, ok := rawWorksmobile["domainMappings"].(map[string]any)
|
||||
if !ok {
|
||||
if raw, ok := rawWorksmobile["domainMappings"].(domain.JSONMap); ok {
|
||||
rawMappings = map[string]any(raw)
|
||||
} else {
|
||||
return result
|
||||
}
|
||||
}
|
||||
for key, raw := range rawMappings {
|
||||
if id, ok := parseDomainID(raw); ok {
|
||||
result[strings.ToLower(strings.TrimSpace(key))] = id
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseDomainID(raw any) (int64, bool) {
|
||||
switch value := raw.(type) {
|
||||
case int:
|
||||
return int64(value), value > 0
|
||||
case int64:
|
||||
return value, value > 0
|
||||
case float64:
|
||||
id := int64(value)
|
||||
return id, id > 0
|
||||
case string:
|
||||
id, err := strconv.ParseInt(strings.TrimSpace(value), 10, 64)
|
||||
return id, err == nil && id > 0
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func metadataString(metadata domain.JSONMap, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
if value, ok := metadata[key]; ok {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func metadataEmployeeNumber(metadata domain.JSONMap) string {
|
||||
for _, key := range []string{"employee_id", "employeeNumber", "employee_number"} {
|
||||
value, ok := metadata[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if normalized := normalizeMetadataEmployeeNumber(value); normalized != "" {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func normalizeMetadataEmployeeNumber(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
|
||||
return strings.TrimSpace(fmt.Sprint(v))
|
||||
case map[string]any:
|
||||
return normalizeMetadataCharacterMap(v)
|
||||
case domain.JSONMap:
|
||||
return normalizeMetadataCharacterMap(map[string]any(v))
|
||||
case map[string]string:
|
||||
converted := make(map[string]any, len(v))
|
||||
for key, value := range v {
|
||||
converted[key] = value
|
||||
}
|
||||
return normalizeMetadataCharacterMap(converted)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeMetadataCharacterMap(value map[string]any) string {
|
||||
type characterEntry struct {
|
||||
index int
|
||||
value string
|
||||
}
|
||||
entries := make([]characterEntry, 0, len(value))
|
||||
for key, raw := range value {
|
||||
index, err := strconv.Atoi(key)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
part, ok := raw.(string)
|
||||
if !ok || part == "" {
|
||||
return ""
|
||||
}
|
||||
entries = append(entries, characterEntry{index: index, value: part})
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
return ""
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].index < entries[j].index
|
||||
})
|
||||
var builder strings.Builder
|
||||
for _, entry := range entries {
|
||||
builder.WriteString(entry.value)
|
||||
}
|
||||
return strings.TrimSpace(builder.String())
|
||||
}
|
||||
|
||||
func metadataBool(metadata domain.JSONMap, keys ...string) bool {
|
||||
value, _ := metadataOptionalBool(metadata, keys...)
|
||||
return value
|
||||
}
|
||||
|
||||
func metadataOptionalBool(metadata domain.JSONMap, keys ...string) (bool, bool) {
|
||||
for _, key := range keys {
|
||||
value, ok := metadata[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case bool:
|
||||
return v, true
|
||||
case string:
|
||||
normalized := strings.ToLower(strings.TrimSpace(v))
|
||||
if normalized == "true" || normalized == "1" || normalized == "yes" {
|
||||
return true, true
|
||||
}
|
||||
if normalized == "false" || normalized == "0" || normalized == "no" {
|
||||
return false, true
|
||||
}
|
||||
case int:
|
||||
return v != 0, true
|
||||
case float64:
|
||||
return v != 0, true
|
||||
}
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
func metadataStringList(metadata domain.JSONMap, keys ...string) []string {
|
||||
for _, key := range keys {
|
||||
value, ok := metadata[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case []string:
|
||||
return splitWorksmobileAliasValues(v)
|
||||
case []any:
|
||||
values := make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
values = append(values, strings.TrimSpace(fmt.Sprint(item)))
|
||||
}
|
||||
return splitWorksmobileAliasValues(values)
|
||||
case string:
|
||||
return splitWorksmobileAliasValues([]string{v})
|
||||
default:
|
||||
return splitWorksmobileAliasValues([]string{fmt.Sprint(v)})
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func splitWorksmobileAliasValues(values []string) []string {
|
||||
result := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
fields := strings.FieldsFunc(value, func(r rune) bool {
|
||||
return r == ',' || r == ';' || r == '\n' || r == '\r' || r == '\t'
|
||||
})
|
||||
for _, field := range fields {
|
||||
if trimmed := strings.TrimSpace(field); trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
856
baron-sso/backend/internal/service/worksmobile_mapper_test.go
Normal file
856
baron-sso/backend/internal/service/worksmobile_mapper_test.go
Normal file
@@ -0,0 +1,856 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildWorksmobileOrgUnitPayloadUsesTenantExternalKeyAndEnvDomainClassification(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "1001")
|
||||
parentID := "11111111-1111-1111-1111-111111111111"
|
||||
tenant := domain.Tenant{
|
||||
ID: "22222222-2222-2222-2222-222222222222",
|
||||
Slug: "tech-dev-center",
|
||||
Name: "Saman Engineering",
|
||||
ParentID: &parentID,
|
||||
Domains: []domain.TenantDomain{
|
||||
{Domain: "samaneng.com"},
|
||||
},
|
||||
}
|
||||
rootConfig := domain.JSONMap{
|
||||
"worksmobile": map[string]any{
|
||||
"domainMappings": map[string]any{
|
||||
"samaneng.com": float64(9999),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
payload, err := BuildWorksmobileOrgUnitPayload(tenant, rootConfig, 7)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1001), payload.DomainID)
|
||||
require.Equal(t, "Saman Engineering", payload.OrgUnitName)
|
||||
require.Equal(t, "tech-dev-center@samaneng.com", payload.Email)
|
||||
require.Equal(t, tenant.ID, payload.OrgUnitExternalKey)
|
||||
require.Equal(t, "externalKey:"+parentID, payload.ParentOrgUnitID)
|
||||
require.Equal(t, 7, payload.DisplayOrder)
|
||||
}
|
||||
|
||||
func TestBuildWorksmobileOrgUnitPayloadUsesWorksmobileMailDomainForBarongroup(t *testing.T) {
|
||||
t.Setenv("BARONGROUP_DOMAIN_ID", "1004")
|
||||
tenant := domain.Tenant{
|
||||
ID: "11111111-1111-1111-1111-111111111111",
|
||||
Slug: "jangheon",
|
||||
Name: "(주)장헌",
|
||||
Type: domain.TenantTypeCompany,
|
||||
Domains: []domain.TenantDomain{{Domain: "jangheon.com"}},
|
||||
}
|
||||
|
||||
payload, err := BuildWorksmobileOrgUnitPayloadForDomainTenant(tenant, tenant, nil, 1)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1004), payload.DomainID)
|
||||
require.Equal(t, "jangheon@brsw.kr", payload.Email)
|
||||
}
|
||||
|
||||
func TestBuildWorksmobileOrgUnitPayloadDefaultsDisplayOrderToOne(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "1001")
|
||||
tenant := domain.Tenant{
|
||||
ID: "11111111-1111-1111-1111-111111111111",
|
||||
Slug: "tech-dev-center",
|
||||
Name: "기술개발센터",
|
||||
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
|
||||
}
|
||||
|
||||
payload, err := BuildWorksmobileOrgUnitPayload(tenant, nil, 0)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, payload.DisplayOrder)
|
||||
}
|
||||
|
||||
func TestNormalizeRootChildWorksmobileOrgUnitParentClearsCrossDomainParent(t *testing.T) {
|
||||
rootID := "038326b6-954a-48a7-a85f-efd83f62b82a"
|
||||
payload := WorksmobileOrgUnitPayload{ParentOrgUnitID: "externalKey:" + rootID}
|
||||
tenant := domain.Tenant{ParentID: &rootID}
|
||||
|
||||
normalized := normalizeWorksmobileOrgUnitParent(payload, tenant, nil, rootID)
|
||||
|
||||
require.Empty(t, normalized.ParentOrgUnitID)
|
||||
}
|
||||
|
||||
func TestBuildWorksmobileUserPayloadMapsBaronUserAndPrimaryTenant(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "1001")
|
||||
rootTenantID := "11111111-1111-1111-1111-111111111111"
|
||||
tenantID := "33333333-3333-3333-3333-333333333333"
|
||||
user := domain.User{
|
||||
ID: "44444444-4444-4444-4444-444444444444",
|
||||
Email: "john1@samaneng.com",
|
||||
Name: "John Doe",
|
||||
Phone: "+19144812222",
|
||||
Position: "Manager",
|
||||
JobTitle: "Sales management",
|
||||
TenantID: &tenantID,
|
||||
Metadata: domain.JSONMap{
|
||||
"employee_id": "AB001",
|
||||
},
|
||||
}
|
||||
tenant := domain.Tenant{
|
||||
ID: tenantID,
|
||||
Slug: "sales",
|
||||
Name: "Sales",
|
||||
Type: domain.TenantTypeOrganization,
|
||||
ParentID: &rootTenantID,
|
||||
}
|
||||
rootTenant := domain.Tenant{
|
||||
ID: rootTenantID,
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
Type: domain.TenantTypeCompany,
|
||||
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
|
||||
}
|
||||
rootConfig := domain.JSONMap{
|
||||
"worksmobile": map[string]any{
|
||||
"domainMappings": map[string]any{
|
||||
"samaneng.com": int64(9999),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
payload, err := BuildWorksmobileUserPayloadForDomainTenants(
|
||||
user,
|
||||
tenant,
|
||||
map[string]domain.Tenant{
|
||||
rootTenantID: rootTenant,
|
||||
tenantID: tenant,
|
||||
},
|
||||
rootConfig,
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1001), payload.DomainID)
|
||||
require.Equal(t, "john1@samaneng.com", payload.Email)
|
||||
require.Equal(t, user.ID, payload.UserExternalKey)
|
||||
require.Equal(t, "John Doe", payload.UserName.LastName)
|
||||
require.Equal(t, "+19144812222", payload.CellPhone)
|
||||
require.Equal(t, "AB001", payload.EmployeeNumber)
|
||||
require.Equal(t, "Sales management", payload.Task)
|
||||
require.Empty(t, payload.PrivateEmail)
|
||||
require.Empty(t, payload.AliasEmails)
|
||||
require.Equal(t, "ko_KR", payload.Locale)
|
||||
require.Empty(t, payload.PasswordConfig.PasswordCreationType)
|
||||
require.Empty(t, payload.PasswordConfig.Password)
|
||||
require.Len(t, payload.Organizations, 1)
|
||||
require.Equal(t, int64(1001), payload.Organizations[0].DomainID)
|
||||
require.True(t, payload.Organizations[0].Primary)
|
||||
require.Equal(t, "externalKey:"+tenantID, payload.Organizations[0].OrgUnits[0].OrgUnitID)
|
||||
}
|
||||
|
||||
func TestBuildWorksmobileUserPayloadDeduplicatesKoreanCountryCodeInCellPhone(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "1001")
|
||||
tenantID := "33333333-3333-3333-3333-333333333333"
|
||||
user := domain.User{
|
||||
ID: "44444444-4444-4444-4444-444444444444",
|
||||
Email: "john1@samaneng.com",
|
||||
Name: "John Doe",
|
||||
Phone: "+82 +821091917771",
|
||||
TenantID: &tenantID,
|
||||
}
|
||||
tenant := domain.Tenant{
|
||||
ID: tenantID,
|
||||
Slug: "saman",
|
||||
Name: "Saman",
|
||||
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
|
||||
}
|
||||
|
||||
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "+821091917771", payload.CellPhone)
|
||||
}
|
||||
|
||||
func TestWorksmobileUserPayloadJSONOmitsEmptyPasswordConfig(t *testing.T) {
|
||||
data, err := json.Marshal(WorksmobileUserPayload{
|
||||
DomainID: 1001,
|
||||
Email: "target@samaneng.com",
|
||||
UserExternalKey: "user-1",
|
||||
UserName: WorksmobileUserName{LastName: "Target"},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(data), "passwordConfig")
|
||||
}
|
||||
|
||||
func TestBuildWorksmobileUserPayloadOmitsOrganizationsForSamanRootTenant(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "1001")
|
||||
tenantID := "33333333-3333-3333-3333-333333333333"
|
||||
user := domain.User{
|
||||
ID: "44444444-4444-4444-4444-444444444444",
|
||||
Email: "root-user@samaneng.com",
|
||||
Name: "Root User",
|
||||
JobTitle: "Advisor",
|
||||
TenantID: &tenantID,
|
||||
}
|
||||
tenant := domain.Tenant{
|
||||
ID: tenantID,
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
Type: domain.TenantTypeCompany,
|
||||
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
|
||||
}
|
||||
|
||||
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1001), payload.DomainID)
|
||||
require.Equal(t, "root-user@samaneng.com", payload.Email)
|
||||
require.Equal(t, "Advisor", payload.Task)
|
||||
require.Empty(t, payload.Organizations)
|
||||
}
|
||||
|
||||
func TestBuildWorksmobileUserPayloadNormalizesLegacyCharacterMapEmployeeID(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "1001")
|
||||
tenantID := "33333333-3333-3333-3333-333333333333"
|
||||
user := domain.User{
|
||||
ID: "44444444-4444-4444-4444-444444444444",
|
||||
Email: "john1@samaneng.com",
|
||||
Name: "John Doe",
|
||||
TenantID: &tenantID,
|
||||
Metadata: domain.JSONMap{
|
||||
"employee_id": map[string]any{
|
||||
"0": "j",
|
||||
"1": "o",
|
||||
"2": "h",
|
||||
"3": "n",
|
||||
},
|
||||
},
|
||||
}
|
||||
tenant := domain.Tenant{
|
||||
ID: tenantID,
|
||||
Slug: "saman",
|
||||
Name: "Saman",
|
||||
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
|
||||
}
|
||||
|
||||
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "john", payload.EmployeeNumber)
|
||||
}
|
||||
|
||||
func TestBuildWorksmobileUserPayloadMapsAdditionalAppointmentsToOrgUnits(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "1001")
|
||||
t.Setenv("HANMAC_DOMAIN_ID", "1002")
|
||||
samanRootID := "11111111-1111-1111-1111-111111111111"
|
||||
hanmacRootID := "22222222-2222-2222-2222-222222222222"
|
||||
primaryTenantID := "33333333-3333-3333-3333-333333333333"
|
||||
secondaryTenantID := "55555555-5555-5555-5555-555555555555"
|
||||
user := domain.User{
|
||||
ID: "44444444-4444-4444-4444-444444444444",
|
||||
Email: "john1@samaneng.com",
|
||||
Name: "John Doe",
|
||||
Phone: "+19144812222",
|
||||
TenantID: &primaryTenantID,
|
||||
Metadata: domain.JSONMap{
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{
|
||||
"tenantId": secondaryTenantID,
|
||||
"isPrimary": false,
|
||||
"isManager": true,
|
||||
"jobTitle": "PM",
|
||||
"position": "팀장",
|
||||
},
|
||||
map[string]any{
|
||||
"tenantId": primaryTenantID,
|
||||
"isPrimary": true,
|
||||
"isOwner": true,
|
||||
"jobTitle": "Engineering",
|
||||
"position": "책임",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
samanRoot := domain.Tenant{
|
||||
ID: samanRootID,
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
Type: domain.TenantTypeCompany,
|
||||
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
|
||||
}
|
||||
hanmacRoot := domain.Tenant{
|
||||
ID: hanmacRootID,
|
||||
Slug: "hanmac",
|
||||
Name: "한맥기술",
|
||||
Type: domain.TenantTypeCompany,
|
||||
Domains: []domain.TenantDomain{{Domain: "hanmaceng.co.kr"}},
|
||||
}
|
||||
primaryTenant := domain.Tenant{
|
||||
ID: primaryTenantID,
|
||||
Slug: "saman-sales",
|
||||
Name: "Saman Sales",
|
||||
Type: domain.TenantTypeOrganization,
|
||||
ParentID: &samanRootID,
|
||||
}
|
||||
secondaryTenant := domain.Tenant{
|
||||
ID: secondaryTenantID,
|
||||
Slug: "hanmac-sales",
|
||||
Name: "Hanmac Sales",
|
||||
Type: domain.TenantTypeOrganization,
|
||||
ParentID: &hanmacRootID,
|
||||
}
|
||||
|
||||
payload, err := BuildWorksmobileUserPayloadForDomainTenants(
|
||||
user,
|
||||
primaryTenant,
|
||||
map[string]domain.Tenant{
|
||||
samanRootID: samanRoot,
|
||||
hanmacRootID: hanmacRoot,
|
||||
primaryTenantID: primaryTenant,
|
||||
secondaryTenantID: secondaryTenant,
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Engineering", payload.Task)
|
||||
require.Len(t, payload.Organizations, 2)
|
||||
require.Equal(t, int64(1001), payload.Organizations[0].DomainID)
|
||||
require.True(t, payload.Organizations[0].Primary)
|
||||
require.Equal(t, "externalKey:"+primaryTenantID, payload.Organizations[0].OrgUnits[0].OrgUnitID)
|
||||
require.True(t, payload.Organizations[0].OrgUnits[0].Primary)
|
||||
require.Nil(t, payload.Organizations[0].OrgUnits[0].IsManager)
|
||||
require.Equal(t, int64(1002), payload.Organizations[1].DomainID)
|
||||
require.False(t, payload.Organizations[1].Primary)
|
||||
require.Equal(t, "externalKey:"+secondaryTenantID, payload.Organizations[1].OrgUnits[0].OrgUnitID)
|
||||
require.True(t, payload.Organizations[1].OrgUnits[0].Primary)
|
||||
require.NotNil(t, payload.Organizations[1].OrgUnits[0].IsManager)
|
||||
require.True(t, *payload.Organizations[1].OrgUnits[0].IsManager)
|
||||
}
|
||||
|
||||
func TestBuildWorksmobileUserPayloadKeepsPrimaryTenantWhenEmailDomainAppointmentExists(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "1001")
|
||||
rootTenantID := "9caf62e1-297d-4e8f-870b-61780998bbeb"
|
||||
primaryTenantID := "1edc196d-020c-4519-9ec4-3d23b99076e6"
|
||||
user := domain.User{
|
||||
ID: "64231465-d5c0-4085-b4a2-603b90834f86",
|
||||
Email: "evenlee@samaneng.com",
|
||||
Name: "이용운",
|
||||
JobTitle: "부사장",
|
||||
TenantID: &primaryTenantID,
|
||||
Metadata: domain.JSONMap{
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{
|
||||
"tenantId": rootTenantID,
|
||||
"tenantSlug": "saman",
|
||||
"tenantName": "삼안",
|
||||
"assignmentSource": "email_domain",
|
||||
"sourceDomain": "samaneng.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
rootTenant := domain.Tenant{
|
||||
ID: rootTenantID,
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
Type: domain.TenantTypeCompany,
|
||||
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
|
||||
}
|
||||
primaryTenant := domain.Tenant{
|
||||
ID: primaryTenantID,
|
||||
Slug: "asset-management",
|
||||
Name: "자산관리",
|
||||
Type: domain.TenantTypeOrganization,
|
||||
ParentID: &rootTenantID,
|
||||
}
|
||||
|
||||
payload, err := BuildWorksmobileUserPayloadForDomainTenants(
|
||||
user,
|
||||
primaryTenant,
|
||||
map[string]domain.Tenant{
|
||||
rootTenantID: rootTenant,
|
||||
primaryTenantID: primaryTenant,
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, payload.Organizations, 1)
|
||||
require.Equal(t, int64(1001), payload.Organizations[0].DomainID)
|
||||
require.True(t, payload.Organizations[0].Primary)
|
||||
require.Len(t, payload.Organizations[0].OrgUnits, 1)
|
||||
require.Equal(t, "externalKey:"+primaryTenantID, payload.Organizations[0].OrgUnits[0].OrgUnitID)
|
||||
require.True(t, payload.Organizations[0].OrgUnits[0].Primary)
|
||||
}
|
||||
|
||||
func TestBuildWorksmobileUserPayloadKeepsFirstAffiliationPrimaryWhenBaronRepresentativeIsGPDTDC(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "1001")
|
||||
t.Setenv("GPDTDC_DOMAIN_ID", "1003")
|
||||
samanRootID := "11111111-1111-1111-1111-111111111111"
|
||||
gpdtdcID := "5530ca6e-c5e6-4bf0-84d6-76c6a8fb70ee"
|
||||
firstTenantID := "33333333-3333-3333-3333-333333333333"
|
||||
secondTenantID := "55555555-5555-5555-5555-555555555555"
|
||||
user := domain.User{
|
||||
ID: "44444444-4444-4444-4444-444444444444",
|
||||
Email: "gpdtdc-dual@samaneng.com",
|
||||
Name: "GPDTDC Dual User",
|
||||
TenantID: &gpdtdcID,
|
||||
Metadata: domain.JSONMap{
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{
|
||||
"tenantId": firstTenantID,
|
||||
"isPrimary": true,
|
||||
"jobTitle": "First affiliation task",
|
||||
},
|
||||
map[string]any{
|
||||
"tenantId": secondTenantID,
|
||||
"isPrimary": false,
|
||||
"jobTitle": "Second affiliation task",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
gpdtdcTenant := domain.Tenant{
|
||||
ID: gpdtdcID,
|
||||
Slug: "gpdtdc",
|
||||
Name: "총괄기획&기술개발센터",
|
||||
}
|
||||
samanRoot := domain.Tenant{
|
||||
ID: samanRootID,
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
Type: domain.TenantTypeCompany,
|
||||
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
|
||||
}
|
||||
firstTenant := domain.Tenant{
|
||||
ID: firstTenantID,
|
||||
Slug: "rnd-center",
|
||||
Name: "삼안기술개발센터",
|
||||
Type: domain.TenantTypeOrganization,
|
||||
ParentID: &samanRootID,
|
||||
}
|
||||
secondTenant := domain.Tenant{
|
||||
ID: secondTenantID,
|
||||
Slug: "tdc",
|
||||
Name: "기술개발센터",
|
||||
ParentID: &gpdtdcID,
|
||||
}
|
||||
|
||||
payload, err := BuildWorksmobileUserPayloadForDomainTenants(
|
||||
user,
|
||||
gpdtdcTenant,
|
||||
map[string]domain.Tenant{
|
||||
samanRootID: samanRoot,
|
||||
gpdtdcID: gpdtdcTenant,
|
||||
firstTenantID: firstTenant,
|
||||
secondTenantID: secondTenant,
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1001), payload.DomainID)
|
||||
require.Equal(t, "First affiliation task", payload.Task)
|
||||
require.Len(t, payload.Organizations, 2)
|
||||
require.Equal(t, int64(1001), payload.Organizations[0].DomainID)
|
||||
require.True(t, payload.Organizations[0].Primary)
|
||||
require.Equal(t, "externalKey:"+firstTenantID, payload.Organizations[0].OrgUnits[0].OrgUnitID)
|
||||
require.True(t, payload.Organizations[0].OrgUnits[0].Primary)
|
||||
require.Equal(t, int64(1003), payload.Organizations[1].DomainID)
|
||||
require.False(t, payload.Organizations[1].Primary)
|
||||
require.Equal(t, "externalKey:"+secondTenantID, payload.Organizations[1].OrgUnits[0].OrgUnitID)
|
||||
require.True(t, payload.Organizations[1].OrgUnits[0].Primary)
|
||||
}
|
||||
|
||||
func TestBuildWorksmobileUserPayloadUsesEmailDomainForAccountDomainWhenPrimaryOrgIsGPDTDC(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "1001")
|
||||
t.Setenv("GPDTDC_DOMAIN_ID", "1003")
|
||||
samanID := "11111111-1111-1111-1111-111111111111"
|
||||
gpdtdcID := "5530ca6e-c5e6-4bf0-84d6-76c6a8fb70ee"
|
||||
leafTenantID := "52f06c97-9d6f-4819-971b-43303062e193"
|
||||
user := domain.User{
|
||||
ID: "44444444-4444-4444-4444-444444444444",
|
||||
Email: "dhlee@samaneng.com",
|
||||
Name: "GPDTDC Saman User",
|
||||
TenantID: &leafTenantID,
|
||||
Metadata: domain.JSONMap{
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{
|
||||
"tenantId": leafTenantID,
|
||||
"isPrimary": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
samanTenant := domain.Tenant{
|
||||
ID: samanID,
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
|
||||
}
|
||||
gpdtdcTenant := domain.Tenant{
|
||||
ID: gpdtdcID,
|
||||
Slug: "gpdtdc",
|
||||
Name: "총괄기획&기술개발센터",
|
||||
}
|
||||
leafTenant := domain.Tenant{
|
||||
ID: leafTenantID,
|
||||
Slug: "infra-bim2",
|
||||
Name: "인프라 BIM2",
|
||||
ParentID: &gpdtdcID,
|
||||
}
|
||||
|
||||
payload, err := BuildWorksmobileUserPayloadForDomainTenants(
|
||||
user,
|
||||
leafTenant,
|
||||
map[string]domain.Tenant{
|
||||
samanID: samanTenant,
|
||||
gpdtdcID: gpdtdcTenant,
|
||||
leafTenantID: leafTenant,
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1001), payload.DomainID)
|
||||
require.Len(t, payload.Organizations, 1)
|
||||
require.Equal(t, int64(1003), payload.Organizations[0].DomainID)
|
||||
require.True(t, payload.Organizations[0].Primary)
|
||||
require.Equal(t, "dhlee@baroncs.co.kr", payload.Organizations[0].Email)
|
||||
require.Equal(t, "externalKey:"+leafTenantID, payload.Organizations[0].OrgUnits[0].OrgUnitID)
|
||||
require.True(t, payload.Organizations[0].OrgUnits[0].Primary)
|
||||
}
|
||||
|
||||
func TestWorksmobileUserPayloadJSONIncludesFalsePrimaryFields(t *testing.T) {
|
||||
payload := WorksmobileUserPayload{
|
||||
Email: "user@samaneng.com",
|
||||
Organizations: []WorksmobileUserOrganization{
|
||||
{
|
||||
DomainID: 1001,
|
||||
Primary: true,
|
||||
OrgUnits: []WorksmobileUserOrgUnit{
|
||||
{OrgUnitID: "externalKey:primary", Primary: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
DomainID: 1003,
|
||||
Primary: false,
|
||||
OrgUnits: []WorksmobileUserOrgUnit{
|
||||
{OrgUnitID: "externalKey:secondary", Primary: false},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(data), `"primary":false`)
|
||||
require.Contains(t, string(data), `"orgUnitId":"externalKey:secondary","primary":false`)
|
||||
}
|
||||
|
||||
func TestResolveWorksmobileDomainIDFromTenantIgnoresRootDomainMappings(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "1001")
|
||||
rootConfig := domain.JSONMap{
|
||||
"worksmobile": map[string]any{
|
||||
"domainMappings": map[string]any{
|
||||
"samaneng.com": int64(9999),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got, err := ResolveWorksmobileDomainIDFromTenant(
|
||||
domain.Tenant{
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
|
||||
},
|
||||
rootConfig,
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1001), got)
|
||||
}
|
||||
|
||||
func TestResolveWorksmobileDomainIDFromTenantRequiresFamilyDomainEnv(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "")
|
||||
|
||||
rootConfig := domain.JSONMap{
|
||||
"worksmobile": map[string]any{
|
||||
"domainMappings": map[string]any{
|
||||
"samaneng.com": int64(9999),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ResolveWorksmobileDomainIDFromTenant(
|
||||
domain.Tenant{
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
|
||||
},
|
||||
rootConfig,
|
||||
)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "SAMAN_DOMAIN_ID")
|
||||
}
|
||||
|
||||
func TestResolveWorksmobileDomainIDUsesEnvFamilyFallbacks(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "1001")
|
||||
t.Setenv("HANMAC_DOMAIN_ID", "1002")
|
||||
t.Setenv("GPDTDC_DOMAIN_ID", "1003")
|
||||
t.Setenv("HALLA_DOMAIN_ID", "1005")
|
||||
t.Setenv("BARONGROUP_DOMAIN_ID", "1004")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tenant domain.Tenant
|
||||
want int64
|
||||
}{
|
||||
{
|
||||
name: "saman",
|
||||
tenant: domain.Tenant{Slug: "saman", Domains: []domain.TenantDomain{{Domain: "samaneng.com"}}},
|
||||
want: 1001,
|
||||
},
|
||||
{
|
||||
name: "hanmac",
|
||||
tenant: domain.Tenant{Slug: "hanmac", Domains: []domain.TenantDomain{{Domain: "hanmaceng.co.kr"}}},
|
||||
want: 1002,
|
||||
},
|
||||
{
|
||||
name: "gpdtdc",
|
||||
tenant: domain.Tenant{Slug: "gpdtdc", Name: "총괄기획&기술개발센터"},
|
||||
want: 1003,
|
||||
},
|
||||
{
|
||||
name: "halla",
|
||||
tenant: domain.Tenant{Slug: "halla", Name: "한라산업개발", Domains: []domain.TenantDomain{{Domain: "hallasanup.com"}}},
|
||||
want: 1005,
|
||||
},
|
||||
{
|
||||
name: "hanlla legacy slug",
|
||||
tenant: domain.Tenant{Slug: "hanlla", Name: "한라산업개발"},
|
||||
want: 1005,
|
||||
},
|
||||
{
|
||||
name: "barongroup fallback",
|
||||
tenant: domain.Tenant{Slug: "family-company", Name: "기타 가족사"},
|
||||
want: 1004,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ResolveWorksmobileDomainIDFromTenant(tt.tenant, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveWorksmobileAccountDomainIDUsesHallaEmailDomain(t *testing.T) {
|
||||
t.Setenv("HALLA_DOMAIN_ID", "1005")
|
||||
t.Setenv("BARONGROUP_DOMAIN_ID", "1004")
|
||||
tenant := domain.Tenant{
|
||||
Slug: "halla",
|
||||
Name: "한라산업개발",
|
||||
Domains: []domain.TenantDomain{{Domain: "hallasanup.com"}},
|
||||
}
|
||||
|
||||
got, err := ResolveWorksmobileAccountDomainIDFromEmail("user@hallasanup.com", tenant, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1005), got)
|
||||
}
|
||||
|
||||
func TestWorksmobileDomainIDsFromEnvIncludesHallaBeforeFallback(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "1001")
|
||||
t.Setenv("HANMAC_DOMAIN_ID", "1002")
|
||||
t.Setenv("GPDTDC_DOMAIN_ID", "1003")
|
||||
t.Setenv("HALLA_DOMAIN_ID", "1005")
|
||||
t.Setenv("BARONGROUP_DOMAIN_ID", "1004")
|
||||
|
||||
got := WorksmobileDomainIDsFromEnv()
|
||||
|
||||
require.Equal(t, []int64{1001, 1002, 1003, 1005, 1004}, got)
|
||||
require.Equal(t, "한라산업개발", WorksmobileDomainLabelForID(1005))
|
||||
}
|
||||
|
||||
func TestBuildWorksmobileUserPayloadUsesHallaDomain(t *testing.T) {
|
||||
t.Setenv("HALLA_DOMAIN_ID", "1005")
|
||||
t.Setenv("WORKS_DEFAULT_DOMAIN_HALLA", "hallasanup.com")
|
||||
tenantID := "33333333-3333-3333-3333-333333333333"
|
||||
user := domain.User{
|
||||
ID: "44444444-4444-4444-4444-444444444444",
|
||||
Email: "main@hallasanup.com",
|
||||
Name: "Halla User",
|
||||
TenantID: &tenantID,
|
||||
}
|
||||
tenant := domain.Tenant{
|
||||
ID: tenantID,
|
||||
Slug: "halla",
|
||||
Name: "한라산업개발",
|
||||
Domains: []domain.TenantDomain{{Domain: "hallasanup.com"}},
|
||||
}
|
||||
|
||||
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1005), payload.DomainID)
|
||||
require.Equal(t, "main@hallasanup.com", payload.Email)
|
||||
}
|
||||
|
||||
func TestBuildWorksmobileUserPayloadAddsHanmacEmployeeAlias(t *testing.T) {
|
||||
t.Setenv("HANMAC_DOMAIN_ID", "1002")
|
||||
tenantID := "33333333-3333-3333-3333-333333333333"
|
||||
user := domain.User{
|
||||
ID: "44444444-4444-4444-4444-444444444444",
|
||||
Email: "main@hanmaceng.co.kr",
|
||||
Name: "Hanmac User",
|
||||
TenantID: &tenantID,
|
||||
Metadata: domain.JSONMap{
|
||||
"employee_id": "HM001",
|
||||
"personal_email": "private@example.com",
|
||||
},
|
||||
}
|
||||
tenant := domain.Tenant{
|
||||
ID: tenantID,
|
||||
Slug: "hanmac",
|
||||
Name: "한맥",
|
||||
Domains: []domain.TenantDomain{{Domain: "hanmaceng.co.kr"}},
|
||||
}
|
||||
|
||||
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1002), payload.DomainID)
|
||||
require.Equal(t, []string{"hm001@hanmaceng.co.kr"}, payload.AliasEmails)
|
||||
require.Empty(t, payload.PrivateEmail)
|
||||
require.Equal(t, "ko_KR", payload.Locale)
|
||||
}
|
||||
|
||||
func TestBuildWorksmobileUserPayloadAddsMultipleMetadataAliases(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "1001")
|
||||
tenantID := "33333333-3333-3333-3333-333333333333"
|
||||
user := domain.User{
|
||||
ID: "44444444-4444-4444-4444-444444444444",
|
||||
Email: "main@samaneng.com",
|
||||
Name: "Saman User",
|
||||
TenantID: &tenantID,
|
||||
Metadata: domain.JSONMap{
|
||||
"aliasEmails": []any{"alias1@samaneng.com", "alias2@samaneng.com", "main@samaneng.com"},
|
||||
},
|
||||
}
|
||||
tenant := domain.Tenant{
|
||||
ID: tenantID,
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
|
||||
}
|
||||
|
||||
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"alias1@samaneng.com", "alias2@samaneng.com"}, payload.AliasEmails)
|
||||
}
|
||||
|
||||
func TestBuildWorksmobileUserPayloadAddsSubEmailMetadataAlias(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "1001")
|
||||
tenantID := "33333333-3333-3333-3333-333333333333"
|
||||
user := domain.User{
|
||||
ID: "44444444-4444-4444-4444-444444444444",
|
||||
Email: "main@samaneng.com",
|
||||
Name: "Saman User",
|
||||
TenantID: &tenantID,
|
||||
Metadata: domain.JSONMap{
|
||||
"sub_email": "alias1@hanmaceng.co.kr",
|
||||
"secondary_emails": []any{"alias2@hanmaceng.co.kr"},
|
||||
},
|
||||
}
|
||||
tenant := domain.Tenant{
|
||||
ID: tenantID,
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
|
||||
}
|
||||
|
||||
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"alias1@hanmaceng.co.kr", "alias2@hanmaceng.co.kr"}, payload.AliasEmails)
|
||||
}
|
||||
|
||||
func TestBuildWorksmobileUserPayloadKeepsSubEmailAliasWithPrimaryLocalPart(t *testing.T) {
|
||||
t.Setenv("SAMAN_DOMAIN_ID", "1001")
|
||||
tenantID := "33333333-3333-3333-3333-333333333333"
|
||||
user := domain.User{
|
||||
ID: "44444444-4444-4444-4444-444444444444",
|
||||
Email: "ypshim@samaneng.com",
|
||||
Name: "Saman User",
|
||||
TenantID: &tenantID,
|
||||
Metadata: domain.JSONMap{
|
||||
"sub_email": "ypshim@hanmaceng.co.kr",
|
||||
},
|
||||
}
|
||||
tenant := domain.Tenant{
|
||||
ID: tenantID,
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
|
||||
}
|
||||
|
||||
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"ypshim@hanmaceng.co.kr"}, payload.AliasEmails)
|
||||
}
|
||||
|
||||
func TestValidateWorksmobileAliasEmailsAllowsSameLocalPartOnDifferentDomains(t *testing.T) {
|
||||
err := ValidateWorksmobileAliasEmails(
|
||||
"main@samaneng.com",
|
||||
[]string{"main@hanmaceng.co.kr"},
|
||||
map[string]string{},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ValidateWorksmobileAliasEmails(
|
||||
"main@samaneng.com",
|
||||
[]string{"main@samaneng.com"},
|
||||
map[string]string{},
|
||||
)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "duplicates")
|
||||
|
||||
err = ValidateWorksmobileAliasEmails(
|
||||
"main@samaneng.com",
|
||||
[]string{"alias@hanmaceng.co.kr"},
|
||||
map[string]string{"alias@hanmaceng.co.kr": "existing-user"},
|
||||
)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "이미 사용 중")
|
||||
}
|
||||
|
||||
func containsAny(value string, candidates string) bool {
|
||||
return strings.ContainsAny(value, candidates)
|
||||
}
|
||||
|
||||
func TestWorksmobileUserStatusAction(t *testing.T) {
|
||||
require.Equal(t, WorksmobileUserActionUpsert, WorksmobileUserStatusAction(domain.UserStatusActive))
|
||||
require.Equal(t, WorksmobileUserActionUpsert, WorksmobileUserStatusAction(domain.UserStatusTemporaryLeave))
|
||||
require.Equal(t, WorksmobileUserActionSuspend, WorksmobileUserStatusAction(domain.UserStatusSuspended))
|
||||
require.Equal(t, domain.WorksmobileActionDelete, WorksmobileUserStatusAction(domain.UserStatusExtendedLeave))
|
||||
require.Equal(t, domain.WorksmobileActionDelete, WorksmobileUserStatusAction(domain.UserStatusBaronGuest))
|
||||
require.Equal(t, domain.WorksmobileActionDelete, WorksmobileUserStatusAction(domain.UserStatusArchived))
|
||||
require.Equal(t, WorksmobileUserActionUpsert, WorksmobileUserStatusAction("leave_of_absence"))
|
||||
require.Equal(t, domain.WorksmobileActionDelete, WorksmobileUserStatusAction("baron_only"))
|
||||
}
|
||||
|
||||
func TestValidateWorksmobileExternalKeyRejectsUnsupportedCharacters(t *testing.T) {
|
||||
require.NoError(t, ValidateWorksmobileExternalKey("44444444-4444-4444-4444-444444444444"))
|
||||
require.Error(t, ValidateWorksmobileExternalKey("user/with/slash"))
|
||||
require.Error(t, ValidateWorksmobileExternalKey("user#with-hash"))
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
const (
|
||||
worksmobileRelayLeaderLockKey = "baron:worksmobile:relay:leader"
|
||||
worksmobileRelayLeaderLockTTL = 30 * time.Second
|
||||
)
|
||||
|
||||
const worksmobileRelayLeaderRenewScript = `
|
||||
if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("EXPIRE", KEYS[1], ARGV[2])
|
||||
end
|
||||
return 0
|
||||
`
|
||||
|
||||
type WorksmobileRedisRelayLeaderLock struct {
|
||||
client *redis.Client
|
||||
key string
|
||||
ttl time.Duration
|
||||
ownerID string
|
||||
}
|
||||
|
||||
func NewWorksmobileRedisRelayLeaderLock(redisService *RedisService) *WorksmobileRedisRelayLeaderLock {
|
||||
if redisService == nil || redisService.Client == nil {
|
||||
return nil
|
||||
}
|
||||
return &WorksmobileRedisRelayLeaderLock{
|
||||
client: redisService.Client,
|
||||
key: worksmobileRelayLeaderLockKey,
|
||||
ttl: worksmobileRelayLeaderLockTTL,
|
||||
ownerID: newWorksmobileRelayLeaderOwnerID(),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *WorksmobileRedisRelayLeaderLock) EnsureLeadership(ctx context.Context) (bool, error) {
|
||||
if l == nil || l.client == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
acquired, err := l.client.SetNX(ctx, l.key, l.ownerID, l.ttl).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if acquired {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
ttlSeconds := int64(l.ttl / time.Second)
|
||||
if ttlSeconds <= 0 {
|
||||
ttlSeconds = 30
|
||||
}
|
||||
result, err := l.client.Eval(ctx, worksmobileRelayLeaderRenewScript, []string{l.key}, l.ownerID, ttlSeconds).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func newWorksmobileRelayLeaderOwnerID() string {
|
||||
hostname, _ := os.Hostname()
|
||||
if hostname == "" {
|
||||
hostname = "unknown-host"
|
||||
}
|
||||
randomBytes := make([]byte, 8)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return fmt.Sprintf("%s:%d:%d", hostname, os.Getpid(), time.Now().UnixNano())
|
||||
}
|
||||
return fmt.Sprintf("%s:%d:%s", hostname, os.Getpid(), hex.EncodeToString(randomBytes))
|
||||
}
|
||||
302
baron-sso/backend/internal/service/worksmobile_relay_worker.go
Normal file
302
baron-sso/backend/internal/service/worksmobile_relay_worker.go
Normal file
@@ -0,0 +1,302 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WorksmobileRelayWorker struct {
|
||||
repo repository.WorksmobileOutboxRepository
|
||||
client WorksmobileDirectoryClient
|
||||
leaderLock WorksmobileRelayLeaderLock
|
||||
interval time.Duration
|
||||
batchLimit int
|
||||
}
|
||||
|
||||
type WorksmobileRelayLeaderLock interface {
|
||||
EnsureLeadership(ctx context.Context) (bool, error)
|
||||
}
|
||||
|
||||
func NewWorksmobileRelayWorker(repo repository.WorksmobileOutboxRepository, client WorksmobileDirectoryClient) *WorksmobileRelayWorker {
|
||||
return &WorksmobileRelayWorker{
|
||||
repo: repo,
|
||||
client: client,
|
||||
interval: 3 * time.Second,
|
||||
batchLimit: 10,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WorksmobileRelayWorker) SetLeaderLock(lock WorksmobileRelayLeaderLock) {
|
||||
w.leaderLock = lock
|
||||
}
|
||||
|
||||
func (w *WorksmobileRelayWorker) SetBatchLimit(limit int) {
|
||||
if limit <= 0 {
|
||||
return
|
||||
}
|
||||
w.batchLimit = limit
|
||||
}
|
||||
|
||||
func (w *WorksmobileRelayWorker) Start(ctx context.Context) {
|
||||
if w.repo == nil || w.client == nil {
|
||||
slog.Warn("Worksmobile relay worker disabled")
|
||||
return
|
||||
}
|
||||
ticker := time.NewTicker(w.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
if err := w.ProcessOnce(ctx); err != nil && !errors.Is(err, context.Canceled) {
|
||||
slog.Warn("Worksmobile relay tick failed", "error", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WorksmobileRelayWorker) ProcessOnce(ctx context.Context) (err error) {
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
err = fmt.Errorf("worksmobile relay panic: %v", recovered)
|
||||
}
|
||||
}()
|
||||
|
||||
if w.leaderLock != nil {
|
||||
isLeader, err := w.leaderLock.EnsureLeadership(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !isLeader {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
jobs, err := w.repo.ListReady(ctx, w.batchLimit)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
jobs = sortWorksmobileReadyJobs(jobs)
|
||||
for _, job := range jobs {
|
||||
if err := w.processJob(ctx, job); err != nil {
|
||||
slog.Warn("Worksmobile relay job failed", "jobID", job.ID, "resourceType", job.ResourceType, "resourceID", job.ResourceID, "error", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WorksmobileRelayWorker) processJob(ctx context.Context, job domain.WorksmobileOutbox) error {
|
||||
claimed, err := w.repo.MarkProcessing(ctx, job.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !claimed {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = w.dispatch(ctx, job)
|
||||
if err != nil {
|
||||
nextAttempt := time.Now().Add(worksmobileRetryDelay(job.RetryCount))
|
||||
_ = w.repo.MarkFailed(ctx, job.ID, err.Error(), nextAttempt)
|
||||
return err
|
||||
}
|
||||
return w.repo.MarkProcessed(ctx, job.ID)
|
||||
}
|
||||
|
||||
func (w *WorksmobileRelayWorker) dispatch(ctx context.Context, job domain.WorksmobileOutbox) error {
|
||||
if job.Action == domain.WorksmobileActionDryRun {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch job.ResourceType {
|
||||
case domain.WorksmobileResourceOrgUnit:
|
||||
if job.Action == domain.WorksmobileActionDelete {
|
||||
return w.client.DeleteOrgUnit(ctx, stringValue(job.Payload["worksmobileId"]))
|
||||
}
|
||||
if job.Action != domain.WorksmobileActionUpsert {
|
||||
return nil
|
||||
}
|
||||
var payload WorksmobileOrgUnitPayload
|
||||
if err := decodeWorksmobileRequest(job.Payload, &payload); err != nil {
|
||||
return err
|
||||
}
|
||||
return w.client.UpsertOrgUnit(ctx, payload, stringValue(job.Payload["matchLocalPart"]))
|
||||
case domain.WorksmobileResourceUser:
|
||||
switch job.Action {
|
||||
case domain.WorksmobileActionUpsert:
|
||||
var payload WorksmobileUserPayload
|
||||
if err := decodeWorksmobileRequest(job.Payload, &payload); err != nil {
|
||||
return err
|
||||
}
|
||||
aliasEmails := append([]string(nil), payload.AliasEmails...)
|
||||
payload.AliasEmails = nil
|
||||
if err := w.client.UpsertUser(ctx, payload); err != nil {
|
||||
return fmt.Errorf("worksmobile user upsert failed: %w", err)
|
||||
}
|
||||
for _, aliasEmail := range aliasEmails {
|
||||
if err := w.client.AddUserAliasEmail(ctx, payload.Email, aliasEmail); err != nil {
|
||||
return fmt.Errorf("worksmobile user alias add failed: %w", err)
|
||||
}
|
||||
}
|
||||
if stringValue(job.Payload["baronStatus"]) == domain.UserStatusActive {
|
||||
if err := w.client.SetUserActive(ctx, worksmobileOutboxUserIdentifier(job), true); err != nil {
|
||||
if isWorksmobileSCIMTokenNotConfiguredError(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("worksmobile user set active failed: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case domain.WorksmobileActionDelete:
|
||||
return w.client.DeleteUser(ctx, worksmobileOutboxUserIdentifier(job))
|
||||
case domain.WorksmobileActionSuspend:
|
||||
return w.client.SetUserActive(ctx, worksmobileOutboxUserIdentifier(job), false)
|
||||
case domain.WorksmobileActionPasswordReset:
|
||||
var payload WorksmobilePasswordResetPayload
|
||||
if err := decodeWorksmobileRequest(job.Payload, &payload); err != nil {
|
||||
return err
|
||||
}
|
||||
identifier := strings.TrimSpace(payload.Email)
|
||||
if identifier == "" {
|
||||
identifier = worksmobileOutboxUserIdentifier(job)
|
||||
}
|
||||
return w.client.ResetUserPassword(ctx, identifier, payload.PasswordConfig.Password)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func isWorksmobileSCIMTokenNotConfiguredError(err error) bool {
|
||||
return err != nil && strings.Contains(err.Error(), "worksmobile scim token is not configured")
|
||||
}
|
||||
|
||||
func sortWorksmobileReadyJobs(jobs []domain.WorksmobileOutbox) []domain.WorksmobileOutbox {
|
||||
sorted := append([]domain.WorksmobileOutbox(nil), jobs...)
|
||||
depthByID := worksmobileOrgUnitDepths(sorted)
|
||||
sort.SliceStable(sorted, func(i, j int) bool {
|
||||
leftClass := worksmobileRelayOrderClass(sorted[i])
|
||||
rightClass := worksmobileRelayOrderClass(sorted[j])
|
||||
if leftClass != rightClass {
|
||||
return leftClass < rightClass
|
||||
}
|
||||
leftDepth := depthByID[sorted[i].ID]
|
||||
rightDepth := depthByID[sorted[j].ID]
|
||||
if leftDepth != rightDepth {
|
||||
return leftDepth < rightDepth
|
||||
}
|
||||
return sorted[i].CreatedAt.Before(sorted[j].CreatedAt)
|
||||
})
|
||||
return sorted
|
||||
}
|
||||
|
||||
func worksmobileRelayOrderClass(job domain.WorksmobileOutbox) int {
|
||||
if job.ResourceType == domain.WorksmobileResourceOrgUnit && job.Action == domain.WorksmobileActionUpsert {
|
||||
return 0
|
||||
}
|
||||
if job.ResourceType == domain.WorksmobileResourceUser {
|
||||
return 1
|
||||
}
|
||||
return 2
|
||||
}
|
||||
|
||||
func worksmobileOrgUnitDepths(jobs []domain.WorksmobileOutbox) map[string]int {
|
||||
type orgUnitJob struct {
|
||||
jobID string
|
||||
parentKey string
|
||||
}
|
||||
byExternalKey := map[string]orgUnitJob{}
|
||||
for _, job := range jobs {
|
||||
externalKey, parentKey := worksmobileOrgUnitExternalKeys(job)
|
||||
if externalKey == "" {
|
||||
continue
|
||||
}
|
||||
byExternalKey[externalKey] = orgUnitJob{jobID: job.ID, parentKey: parentKey}
|
||||
}
|
||||
|
||||
depthByExternalKey := map[string]int{}
|
||||
var depth func(externalKey string, seen map[string]bool) int
|
||||
depth = func(externalKey string, seen map[string]bool) int {
|
||||
if value, ok := depthByExternalKey[externalKey]; ok {
|
||||
return value
|
||||
}
|
||||
job, ok := byExternalKey[externalKey]
|
||||
if !ok || job.parentKey == "" || seen[externalKey] {
|
||||
depthByExternalKey[externalKey] = 0
|
||||
return 0
|
||||
}
|
||||
seen[externalKey] = true
|
||||
value := depth(job.parentKey, seen) + 1
|
||||
delete(seen, externalKey)
|
||||
depthByExternalKey[externalKey] = value
|
||||
return value
|
||||
}
|
||||
|
||||
depthByJobID := map[string]int{}
|
||||
for externalKey, job := range byExternalKey {
|
||||
depthByJobID[job.jobID] = depth(externalKey, map[string]bool{})
|
||||
}
|
||||
return depthByJobID
|
||||
}
|
||||
|
||||
func worksmobileOrgUnitExternalKeys(job domain.WorksmobileOutbox) (string, string) {
|
||||
if job.ResourceType != domain.WorksmobileResourceOrgUnit || job.Action != domain.WorksmobileActionUpsert {
|
||||
return "", ""
|
||||
}
|
||||
var payload WorksmobileOrgUnitPayload
|
||||
if err := decodeWorksmobileRequest(job.Payload, &payload); err != nil {
|
||||
return "", ""
|
||||
}
|
||||
parentKey := strings.TrimSpace(payload.ParentOrgUnitID)
|
||||
if strings.HasPrefix(parentKey, "externalKey:") {
|
||||
parentKey = strings.TrimSpace(strings.TrimPrefix(parentKey, "externalKey:"))
|
||||
} else {
|
||||
parentKey = ""
|
||||
}
|
||||
return strings.TrimSpace(payload.OrgUnitExternalKey), parentKey
|
||||
}
|
||||
|
||||
func worksmobileOutboxUserIdentifier(job domain.WorksmobileOutbox) string {
|
||||
userID := stringValue(job.Payload["loginEmail"])
|
||||
if userID == "" {
|
||||
userID = stringValue(job.Payload["userExternalKey"])
|
||||
}
|
||||
return userID
|
||||
}
|
||||
|
||||
func decodeWorksmobileRequest(payload domain.JSONMap, target any) error {
|
||||
raw := payload["request"]
|
||||
if raw == nil {
|
||||
return errors.New("worksmobile request payload is missing")
|
||||
}
|
||||
data, err := json.Marshal(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decoder := json.NewDecoder(strings.NewReader(string(data)))
|
||||
decoder.DisallowUnknownFields()
|
||||
return decoder.Decode(target)
|
||||
}
|
||||
|
||||
func worksmobileRetryDelay(retryCount int) time.Duration {
|
||||
if retryCount < 0 {
|
||||
retryCount = 0
|
||||
}
|
||||
if retryCount > 5 {
|
||||
retryCount = 5
|
||||
}
|
||||
return time.Duration(1<<retryCount) * time.Minute
|
||||
}
|
||||
2066
baron-sso/backend/internal/service/worksmobile_sync_service.go
Normal file
2066
baron-sso/backend/internal/service/worksmobile_sync_service.go
Normal file
File diff suppressed because it is too large
Load Diff
2224
baron-sso/backend/internal/service/worksmobile_sync_service_test.go
Normal file
2224
baron-sso/backend/internal/service/worksmobile_sync_service_test.go
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user