1
0
forked from baron/baron-sso

refactor: backend tenant_group 제거 및 리팩터 반영

This commit is contained in:
Lectom C Han
2026-02-12 22:14:34 +09:00
parent b0792113ae
commit a8a219d7ef
26 changed files with 494 additions and 1001 deletions

View File

@@ -1,51 +1,22 @@
package handler
import (
"baron-sso-backend/internal/service"
"runtime"
"time"
"github.com/gofiber/fiber/v2"
)
type AdminHandler struct {
Keto service.KetoService
}
type AdminHandler struct{}
func NewAdminHandler(keto service.KetoService) *AdminHandler {
return &AdminHandler{Keto: keto}
func NewAdminHandler() *AdminHandler {
return &AdminHandler{}
}
func (h *AdminHandler) CheckAuth(c *fiber.Ctx) error {
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "ok"})
}
func (h *AdminHandler) CheckPermission(c *fiber.Ctx) error {
namespace := c.Query("namespace")
object := c.Query("object")
relation := c.Query("relation")
subject := c.Query("subject")
if namespace == "" || object == "" || relation == "" || subject == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "namespace, object, relation, and subject are required"})
}
allowed, err := h.Keto.CheckPermission(c.Context(), subject, namespace, object, relation)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(fiber.Map{
"allowed": allowed,
"query": fiber.Map{
"namespace": namespace,
"object": object,
"relation": relation,
"subject": subject,
},
})
}
// GetSystemStats returns runtime statistics for monitoring
func (h *AdminHandler) GetSystemStats(c *fiber.Ctx) error {
var m runtime.MemStats

View File

@@ -125,11 +125,10 @@ func GenerateSecureAlnumToken(length int) string {
func GenerateUserCode() string {
const letters = "ABCDEFGHJKLMNPQRSTUVWXYZ"
// [Fixed] 요청하신 포맷 (영문 2자리 + 숫자 6자리, 하이픈 없음)으로 변경
return fmt.Sprintf("%c%c%06d",
return fmt.Sprintf("%c%c-%03d",
letters[rand.Intn(len(letters))],
letters[rand.Intn(len(letters))],
rand.Intn(1000000),
rand.Intn(1000),
)
}
@@ -455,7 +454,8 @@ func (h *AuthHandler) Signup(c *fiber.Ctx) error {
slog.Info("[Signup] New user registered", "email", req.Email, "type", req.AffiliationType, "provider", h.IdpProvider.Name(), "subject", providerID)
// [New] Local DB Sync
// [SoT Policy] Kratos가 SoT이므로 로컬 DB 저장은 비동기 Read-Model 동기화로 처리합니다.
// 로컬 DB 저장이 실패하더라도 회원가입 프로세스는 성공으로 간주합니다.
localUser := &domain.User{
ID: providerID, // Match IDP Subject
Email: req.Email,
@@ -471,9 +471,17 @@ func (h *AuthHandler) Signup(c *fiber.Ctx) error {
}
if h.UserRepo != nil {
if err := h.UserRepo.Create(c.Context(), localUser); err != nil {
slog.Error("[Signup] Failed to sync user to local DB", "email", req.Email, "error", err)
}
go func(u *domain.User) {
// 요청 Context가 취소될 수 있으므로 Background Context 사용
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.UserRepo.Create(ctx, u); err != nil {
slog.Error("[Signup] Failed to sync user to Read-Model (Local DB)", "email", u.Email, "error", err)
} else {
slog.Debug("[Signup] Synced user to Read-Model", "email", u.Email)
}
}(localUser)
}
// [Keto] Sync user-tenant relationship
@@ -959,20 +967,13 @@ func (h *AuthHandler) InitEnchantedLink(c *fiber.Ctx) error {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "Identity provider unavailable"})
}
// [Changed] 토큰 길이를 사용자의 요청에 맞춰 6글자(3바이트)로, pendingRef를 8글자(4바이트)로 조정
userCode := GenerateUserCode()
token := GenerateSecureToken(3)
pendingRef := GenerateSecureToken(3)
slog.Info("[Enchanted] Initiating enchanted link", "loginID", loginID, "token", token, "pendingRef", pendingRef)
// [Added] 사용자가 입력할 간편 코드를 Redis에 저장합니다. (이게 없으면 인증이 안 됩니다)
shortCodePayload, _ := json.Marshal(shortLoginCodePayload{
LoginID: lookupLoginID,
Code: token,
PendingRef: pendingRef,
})
h.RedisService.Set(prefixLoginCodeShort+userCode, string(shortCodePayload), defaultExpiration)
// Store in Redis
sessionData, _ := json.Marshal(map[string]string{
"status": statusPending,
@@ -1026,13 +1027,12 @@ func (h *AuthHandler) InitEnchantedLink(c *fiber.Ctx) error {
}
} else {
// Send SMS
phone := sanitizePhoneForSms(loginID)
content := fmt.Sprintf("[Baron 로그인] 로그인 링크: %s | 간편 코드: %s", link, userCode)
content := fmt.Sprintf("[Baron 로그인] 로그인 링크: %s | 코드: %s", link, userCode)
if drySend {
slog.Info("[Enchanted][DrySend] SMS send skipped", "loginID", phone, "content", content)
slog.Info("[Enchanted][DrySend] SMS send skipped", "loginID", loginID, "content", content)
} else {
slog.Info("[Enchanted] Sending SMS via Naver Cloud", "to", phone)
if err := h.SmsService.SendSms(phone, content); err != nil {
slog.Info("[Enchanted] Sending SMS via Naver Cloud", "loginID", loginID)
if err := h.SmsService.SendSms(loginID, content); err != nil {
slog.Error("[Enchanted] SMS Failed", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to send SMS"})
}
@@ -1526,7 +1526,7 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
loginID := strings.TrimSpace(req.LoginID)
ale.LoginIDs["loginId"] = req.LoginID // 원문
ale.LoginIDs["loginId_normalized"] = loginID
// ale.NewPassword = req.Password // For test only, logging password (sensitive)
ale.NewPassword = req.Password // For test only, logging password (sensitive)
ale.Log(slog.LevelInfo, "Attempting to login")
@@ -1568,25 +1568,22 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
// --- OIDC 로그인 흐름 처리 ---
if req.LoginChallenge != "" {
slog.Info("OIDC login flow detected", "challenge", req.LoginChallenge, "subject", subject)
slog.Info("OIDC login flow detected", "challenge", req.LoginChallenge)
// Check if the client is active
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), req.LoginChallenge)
if err == nil && loginReq != nil {
slog.Info("OIDC Client Info", "client_id", loginReq.Client.ClientID, "name", loginReq.Client.ClientName)
if loginReq.Client.Metadata != nil {
if status, ok := loginReq.Client.Metadata["status"].(string); ok {
if strings.ToLower(status) == "inactive" {
slog.Warn("Login rejected for inactive client in PasswordLogin", "client_id", loginReq.Client.ClientID)
return fiber.NewError(fiber.StatusForbidden, "The client application is disabled.")
}
if err == nil && loginReq != nil && loginReq.Client.Metadata != nil {
if status, ok := loginReq.Client.Metadata["status"].(string); ok {
if strings.ToLower(status) == "inactive" {
slog.Warn("Login rejected for inactive client in PasswordLogin", "client_id", loginReq.Client.ClientID)
return fiber.NewError(fiber.StatusForbidden, "The client application is disabled.")
}
}
}
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), req.LoginChallenge, subject)
if err != nil {
slog.Error("failed to accept hydra login request", "error", err, "challenge", req.LoginChallenge)
slog.Error("failed to accept hydra login request", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")
}
slog.Info("Hydra login request accepted", "redirectTo", acceptResp.RedirectTo)
@@ -1597,13 +1594,12 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
// --- OIDC 로그인 흐름 처리 끝 ---
resp := fiber.Map{
"sessionToken": authInfo.SessionToken.JWT,
"sessionJwt": authInfo.SessionToken.JWT, // Frontend compatibility
"status": "ok",
"provider": h.IdpProvider.Name(),
"sessionJwt": authInfo.SessionToken.JWT,
"status": "ok",
"provider": h.IdpProvider.Name(),
}
if authInfo.RefreshToken != nil {
resp["refreshToken"] = authInfo.RefreshToken.JWT
resp["refreshJwt"] = authInfo.RefreshToken.JWT
}
if authInfo.Subject != "" {
resp["subject"] = authInfo.Subject
@@ -2079,16 +2075,6 @@ type kratosCourierRequest struct {
Body string `json:"body"`
}
// sanitizePhoneForSms - 네이버 SMS 등 국내 발송기를 위해 +82 형식을 010 형식으로 변환합니다.
func sanitizePhoneForSms(phone string) string {
p := strings.ReplaceAll(phone, "-", "")
p = strings.ReplaceAll(p, " ", "")
if strings.HasPrefix(p, "+82") {
return "0" + p[3:]
}
return p
}
// HandleKratosCourierRelay - Kratos courier HTTP 요청을 받아 메일/SMS 발송으로 변환합니다.
func (h *AuthHandler) HandleKratosCourierRelay(c *fiber.Ctx) error {
var req kratosCourierRequest
@@ -2467,6 +2453,16 @@ func extractFirstString(data map[string]interface{}, keys ...string) string {
return ""
}
func sanitizePhoneForSms(phone string) string {
sanitized := strings.TrimSpace(phone)
if strings.HasPrefix(sanitized, "+82") {
sanitized = "0" + sanitized[3:]
}
sanitized = strings.ReplaceAll(sanitized, "-", "")
sanitized = strings.ReplaceAll(sanitized, " ", "")
return sanitized
}
// --- User Profile Handlers ---
func (h *AuthHandler) formatPhoneForDisplay(phone string) string {
@@ -2484,56 +2480,7 @@ func (h *AuthHandler) formatPhoneForStorage(phone string) string {
return phone
}
// ProxyOidc - 프론트엔드의 OIDC 요청을 내부 Hydra 서비스로 프록시합니다.
func (h *AuthHandler) ProxyOidc(c *fiber.Ctx) error {
path := c.Params("*")
// [Strict] Always use internal Docker network address for proxying to avoid external loops
targetURL := "http://hydra:4444"
// 프록시 URL 구성
u, err := url.Parse(targetURL)
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, "invalid hydra public url")
}
u.Path = strings.TrimRight(u.Path, "/") + "/" + path
u.RawQuery = string(c.Request().URI().QueryString())
slog.Debug("Proxying OIDC request", "from", c.Path(), "to", u.String())
// 요청 준비
req, err := http.NewRequestWithContext(c.Context(), c.Method(), u.String(), bytes.NewReader(c.Body()))
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, "failed to create proxy request")
}
// 헤더 복사
c.Request().Header.VisitAll(func(key, value []byte) {
k := string(key)
if k != "Host" && k != "Connection" {
req.Header.Add(k, string(value))
}
})
// 요청 실행 (Hydra 내부 HttpClient 사용)
resp, err := h.Hydra.HttpClient().Do(req)
if err != nil {
return fiber.NewError(fiber.StatusServiceUnavailable, "hydra public api unavailable")
}
defer resp.Body.Close()
// 응답 헤더 복사
for k, values := range resp.Header {
for _, v := range values {
c.Set(k, v)
}
}
// 상태 코드 및 바디 설정
c.Status(resp.StatusCode)
_, err = io.Copy(c.Response().BodyWriter(), resp.Body)
return err
}
// GetMe - Returns current user's profile with enriched data from local DB
func (h *AuthHandler) GetMe(c *fiber.Ctx) error {
profile, err := h.resolveCurrentProfile(c)
if err != nil {
@@ -4006,13 +3953,6 @@ func (h *AuthHandler) resolveCurrentProfile(c *fiber.Ctx) (*domain.UserProfileRe
}
}
// Fetch Manageable Tenants for Admins
if profile.Role == domain.RoleSuperAdmin || profile.Role == domain.RoleTenantAdmin || profile.Role == domain.RoleRPAdmin {
if tenants, err := h.TenantService.ListManageableTenants(c.Context(), profile.ID); err == nil {
profile.ManageableTenants = tenants
}
}
// 4. Save to Redis Cache (Short TTL)
if h.RedisService != nil && cacheKey != "" {
if data, err := json.Marshal(profile); err == nil {
@@ -4842,7 +4782,10 @@ func extractLoginIDFromClaims(claims map[string]any) string {
}
func (h *AuthHandler) getKratosIdentity(sessionToken string) (string, map[string]interface{}, error) {
kratosURL := strings.TrimRight(utils.GetEnv("KRATOS_PUBLIC_URL", "http://kratos:4433"), "/")
kratosURL := strings.TrimRight(os.Getenv("KRATOS_PUBLIC_URL"), "/")
if kratosURL == "" {
kratosURL = "http://kratos:4433"
}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, kratosURL+"/sessions/whoami", nil)
if err != nil {
return "", nil, err
@@ -4850,44 +4793,33 @@ func (h *AuthHandler) getKratosIdentity(sessionToken string) (string, map[string
req.Header.Set("X-Session-Token", sessionToken)
resp, err := http.DefaultClient.Do(req)
if err == nil {
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
var result struct {
Identity struct {
ID string `json:"id"`
Traits map[string]interface{} `json:"traits"`
} `json:"identity"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err == nil {
return result.Identity.ID, result.Identity.Traits, nil
}
}
if err != nil {
return "", nil, err
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return "", nil, fmt.Errorf("kratos whoami failed status=%d body=%s", resp.StatusCode, string(body))
}
// 2. Kratos 실패 시 Hydra Introspection 시도 (OIDC Access Token 대응)
if h.Hydra != nil {
slog.Debug("[Auth] Kratos whoami failed, trying Hydra introspection", "token_prefix", sessionToken[:min(len(sessionToken), 10)])
introspection, err := h.Hydra.IntrospectToken(context.Background(), sessionToken)
if err == nil && introspection["active"] == true {
subject, _ := introspection["sub"].(string)
if subject != "" {
// Hydra는 Traits를 직접 주지 않으므로, Kratos Admin API로 상세 정보를 가져옴
identity, err := h.KratosAdmin.GetIdentity(context.Background(), subject)
if err == nil && identity != nil {
return identity.ID, identity.Traits, nil
}
// Identity 정보가 없더라도 최소한 Subject는 반환
return subject, map[string]interface{}{}, nil
}
}
var result struct {
Identity struct {
ID string `json:"id"`
Traits map[string]interface{} `json:"traits"`
} `json:"identity"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", nil, err
}
return "", nil, fmt.Errorf("invalid session or token")
return result.Identity.ID, result.Identity.Traits, nil
}
func (h *AuthHandler) getKratosSessionID(sessionToken string) (string, error) {
kratosURL := strings.TrimRight(utils.GetEnv("KRATOS_PUBLIC_URL", "http://kratos:4433"), "/")
kratosURL := strings.TrimRight(os.Getenv("KRATOS_PUBLIC_URL"), "/")
if kratosURL == "" {
kratosURL = "http://kratos:4433"
}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, kratosURL+"/sessions/whoami", nil)
if err != nil {
return "", err
@@ -4910,7 +4842,6 @@ func (h *AuthHandler) getKratosSessionID(sessionToken string) (string, error) {
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", err
}
return result.ID, nil
}
@@ -4919,7 +4850,10 @@ func (h *AuthHandler) issueKratosSession(ctx context.Context, identityID string)
return "", fmt.Errorf("kratos identity id is empty")
}
kratosAdminURL := strings.TrimRight(utils.GetEnv("KRATOS_ADMIN_URL", "http://kratos:4434"), "/")
kratosAdminURL := strings.TrimRight(os.Getenv("KRATOS_ADMIN_URL"), "/")
if kratosAdminURL == "" {
kratosAdminURL = "http://kratos:4434"
}
payload := map[string]interface{}{
"identity_id": identityID,

View File

@@ -0,0 +1,251 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// --- Async Test Mocks ---
type AsyncMockIdpProvider struct {
mock.Mock
}
func (m *AsyncMockIdpProvider) Name() string { return "mock-idp" }
func (m *AsyncMockIdpProvider) GetMetadata() (*domain.IDPMetadata, error) {
return &domain.IDPMetadata{}, nil
}
func (m *AsyncMockIdpProvider) UserExists(loginID string) (bool, error) {
args := m.Called(loginID)
return args.Bool(0), args.Error(1)
}
func (m *AsyncMockIdpProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
args := m.Called(user, password)
return args.String(0), args.Error(1)
}
func (m *AsyncMockIdpProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) IssueSession(loginID string) (*domain.AuthInfo, error) { return nil, nil }
func (m *AsyncMockIdpProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
return &domain.PasswordPolicy{MinLength: 12}, nil
}
func (m *AsyncMockIdpProvider) InitiatePasswordReset(loginID, redirectUrl string) error { return nil }
func (m *AsyncMockIdpProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
return nil
}
type AsyncMockUserRepo struct {
mock.Mock
createCalled chan bool
}
func (m *AsyncMockUserRepo) Create(ctx context.Context, user *domain.User) error {
// Simulate DB latency
time.Sleep(50 * time.Millisecond)
args := m.Called(ctx, user)
if m.createCalled != nil {
m.createCalled <- true
}
return args.Error(0)
}
func (m *AsyncMockUserRepo) Update(ctx context.Context, user *domain.User) error { return nil }
func (m *AsyncMockUserRepo) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) FindByID(ctx context.Context, id string) (*domain.User, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) List(ctx context.Context, offset, limit int, search string) ([]domain.User, int64, error) {
return nil, 0, nil
}
type AsyncMockRedisRepo struct {
mock.Mock
}
func (m *AsyncMockRedisRepo) Set(key string, value string, expiration time.Duration) error {
args := m.Called(key, value, expiration)
return args.Error(0)
}
func (m *AsyncMockRedisRepo) Get(key string) (string, error) {
args := m.Called(key)
return args.String(0), args.Error(1)
}
func (m *AsyncMockRedisRepo) Delete(key string) error {
args := m.Called(key)
return args.Error(0)
}
func (m *AsyncMockRedisRepo) StoreVerificationCode(phone, code string) error { return nil }
func (m *AsyncMockRedisRepo) GetVerificationCode(phone string) (string, error) { return "", nil }
func (m *AsyncMockRedisRepo) DeleteVerificationCode(phone string) error { return nil }
type AsyncMockTenantService struct {
mock.Mock
}
func (m *AsyncMockTenantService) RegisterTenant(ctx context.Context, name, slug, description string, domains []string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) {
args := m.Called(ctx, emailDomain)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *AsyncMockTenantService) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) ApproveTenant(ctx context.Context, id string) error { return nil }
func (m *AsyncMockTenantService) SetKetoService(keto service.KetoService) {}
func (m *AsyncMockTenantService) AddTenantAdmin(ctx context.Context, tenantID, userID string) error {
return nil
}
func (m *AsyncMockTenantService) RemoveTenantAdmin(ctx context.Context, tenantID, userID string) error {
return nil
}
func (m *AsyncMockTenantService) ListTenantAdmins(ctx context.Context, tenantID string) ([]string, error) {
return nil, nil
}
type AsyncMockKetoService struct {
mock.Mock
}
func (m *AsyncMockKetoService) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
args := m.Called(ctx, namespace, object, relation, subject)
return args.Error(0)
}
func (m *AsyncMockKetoService) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
return nil
}
func (m *AsyncMockKetoService) CheckPermission(ctx context.Context, namespace, object, relation, subject string) (bool, error) {
return false, nil
}
func (m *AsyncMockKetoService) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
return nil, nil
}
func (m *AsyncMockKetoService) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]service.RelationTuple, error) {
return nil, nil
}
// --- Tests ---
func TestSignup_AsyncDB_Isolation(t *testing.T) {
mockIdp := new(AsyncMockIdpProvider)
mockUserRepo := new(AsyncMockUserRepo)
mockRedis := new(AsyncMockRedisRepo)
mockTenant := new(AsyncMockTenantService)
mockKeto := new(AsyncMockKetoService)
h := &AuthHandler{
IdpProvider: mockIdp,
UserRepo: mockUserRepo,
RedisService: mockRedis,
TenantService: mockTenant,
KetoService: mockKeto,
}
app := fiber.New()
app.Post("/signup", h.Signup)
t.Run("SoT_DB_Failure_Ignored_And_Async", func(t *testing.T) {
email := "test@example.com"
phone := "010-1234-5678"
emailKey := "signup:email:" + email
phoneKey := "signup:phone:" + "01012345678"
// Redis Mocks
mockRedis.On("Get", emailKey).Return(`{"verified": true, "expires_at": 9999999999}`, nil)
mockRedis.On("Get", phoneKey).Return(`{"verified": true, "expires_at": 9999999999}`, nil)
mockRedis.On("Delete", emailKey).Return(nil)
mockRedis.On("Delete", phoneKey).Return(nil)
// Tenant Mocks
mockTenant.On("GetTenantByDomain", mock.Anything, "example.com").Return(nil, errors.New("not found"))
// Kratos Mocks (Success)
mockIdp.On("CreateUser", mock.Anything, "Password123!").Return("new-user-uuid", nil)
// UserRepo Mocks (Async & Failure)
mockUserRepo.createCalled = make(chan bool, 1)
mockUserRepo.On("Create", mock.Anything, mock.MatchedBy(func(u *domain.User) bool {
return u.Email == email
})).Return(errors.New("db connection error"))
// Keto Mocks (Optional, since it's also async)
// We won't block on this either
body, _ := json.Marshal(domain.SignupRequest{
Email: email,
Password: "Password123!",
Name: "Test User",
Phone: phone,
TermsAccepted: true,
})
req := httptest.NewRequest("POST", "/signup", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
start := time.Now()
resp, err := app.Test(req, 5000)
elapsed := time.Since(start)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
assert.Equal(t, 200, resp.StatusCode)
// Ensure API responded faster than DB latency (50ms)
assert.Less(t, int64(elapsed), int64(60*time.Millisecond), "API should return before DB timeout")
// Wait for async execution
select {
case <-mockUserRepo.createCalled:
// Pass
case <-time.After(2 * time.Second):
t.Fatal("UserRepo.Create was not called asynchronously")
}
mockRedis.AssertExpectations(t)
mockIdp.AssertExpectations(t)
mockUserRepo.AssertExpectations(t)
})
}

View File

@@ -288,8 +288,8 @@ func TestPasswordLogin_NoOIDC_Success(t *testing.T) {
}
var got map[string]string
json.NewDecoder(resp.Body).Decode(&got)
if got["sessionToken"] != "valid-jwt" {
t.Errorf("expected jwt valid-jwt, got %s", got["sessionToken"])
if got["sessionJwt"] != "valid-jwt" {
t.Errorf("expected jwt valid-jwt, got %s", got["sessionJwt"])
}
// No redirectTo
if _, ok := got["redirectTo"]; ok {

View File

@@ -22,17 +22,15 @@ type DevHandler struct {
SecretRepo domain.ClientSecretRepository
KratosAdmin *service.KratosAdminService
ConsentRepo repository.ClientConsentRepository
RPService service.RelyingPartyService
}
func NewDevHandler(redis domain.RedisRepository, secretRepo domain.ClientSecretRepository, consentRepo repository.ClientConsentRepository, rpService service.RelyingPartyService) *DevHandler {
func NewDevHandler(redis domain.RedisRepository, secretRepo domain.ClientSecretRepository, consentRepo repository.ClientConsentRepository) *DevHandler {
return &DevHandler{
Hydra: service.NewHydraAdminService(),
Redis: redis,
SecretRepo: secretRepo,
KratosAdmin: service.NewKratosAdminService(),
ConsentRepo: consentRepo,
RPService: rpService,
}
}
@@ -97,58 +95,38 @@ type clientUpsertRequest struct {
}
func (h *DevHandler) ListClients(c *fiber.Ctx) error {
profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse)
if !ok {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "unauthorized: user profile not found"})
limit := c.QueryInt("limit", 50)
offset := c.QueryInt("offset", 0)
if limit <= 0 {
limit = 50
}
if offset < 0 {
offset = 0
}
// Super Admin sees all (best effort via Hydra list for now, or we can use RPService if it's improved)
if profile.Role == domain.RoleSuperAdmin {
limit := c.QueryInt("limit", 50)
offset := c.QueryInt("offset", 0)
clients, err := h.Hydra.ListClients(c.Context(), limit, offset)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
items := make([]clientSummary, 0, len(clients))
for _, client := range clients {
items = append(items, h.mapClientSummary(client))
}
return c.JSON(clientListResponse{Items: items, Limit: limit, Offset: offset})
}
// For others, only show manageable tenants' clients
var tenantIDs []string
for _, t := range profile.ManageableTenants {
tenantIDs = append(tenantIDs, t.ID)
}
if len(tenantIDs) == 0 && profile.TenantID != nil {
tenantIDs = append(tenantIDs, *profile.TenantID)
}
if len(tenantIDs) == 0 {
return c.JSON(clientListResponse{Items: []clientSummary{}, Limit: 50, Offset: 0})
}
rps, err := h.RPService.ListByTenantIDs(c.Context(), tenantIDs)
clients, err := h.Hydra.ListClients(c.Context(), limit, offset)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
if errors.Is(err, service.ErrHydraNotFound) {
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "clients not found"})
}
errMsg := err.Error()
if strings.Contains(errMsg, "connection refused") || strings.Contains(errMsg, "dial tcp") {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
"error": "Hydra service is unavailable. Please check if Ory Hydra is running.",
})
}
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": errMsg})
}
items := make([]clientSummary, 0, len(rps))
for _, rp := range rps {
// We need HydraClient details for the summary
client, err := h.Hydra.GetClient(c.Context(), rp.ClientID)
if err == nil {
items = append(items, h.mapClientSummary(*client))
}
items := make([]clientSummary, 0, len(clients))
for _, client := range clients {
items = append(items, h.mapClientSummary(client))
}
return c.JSON(clientListResponse{
Items: items,
Limit: len(items),
Offset: 0,
Limit: limit,
Offset: offset,
})
}
@@ -166,11 +144,6 @@ func (h *DevHandler) GetClient(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
// Set for audit logging
if tid, ok := client.Metadata["tenant_id"].(string); ok {
c.Locals("tenant_id", tid)
}
summary := h.mapClientSummary(*client)
return c.JSON(clientDetailResponse{
Client: summary,
@@ -224,49 +197,11 @@ func (h *DevHandler) UpdateClientStatus(c *fiber.Ctx) error {
}
func (h *DevHandler) CreateClient(c *fiber.Ctx) error {
profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse)
if !ok {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "unauthorized"})
}
var req clientUpsertRequest
if err := c.BodyParser(&req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid request body"})
}
// Determine Tenant ID
targetTenantID := c.Get("X-Tenant-ID")
if targetTenantID == "" && profile.TenantID != nil {
targetTenantID = *profile.TenantID
}
if targetTenantID == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "X-Tenant-ID header is required"})
}
// Set for audit logging
c.Locals("tenant_id", targetTenantID)
// Validate Permission
isAllowed := false
if profile.Role == domain.RoleSuperAdmin {
isAllowed = true
} else {
for _, t := range profile.ManageableTenants {
if t.ID == targetTenantID {
isAllowed = true
break
}
}
if !isAllowed && profile.TenantID != nil && *profile.TenantID == targetTenantID {
isAllowed = true
}
}
if !isAllowed {
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "you do not have permission to create clients for this tenant"})
}
clientID := strings.TrimSpace(valueOr(req.ID, ""))
if clientID == "" {
clientID = uuid.NewString()
@@ -322,18 +257,11 @@ func (h *DevHandler) CreateClient(c *fiber.Ctx) error {
Metadata: metadata,
}
// Use RPService to ensure Keto relations are created
rp, err := h.RPService.Create(c.Context(), targetTenantID, clientReq)
created, err := h.Hydra.CreateClient(c.Context(), clientReq)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
// Fetch back the Hydra client to get the secret (RPService.Create returns domain.RelyingParty which has limited fields)
created, err := h.Hydra.GetClient(c.Context(), rp.ClientID)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "client created but failed to retrieve details"})
}
// Store secret in metadata for later retrieval
if created.ClientSecret != "" {
// 1. Store in PostgreSQL (Source of Truth)
@@ -379,11 +307,6 @@ func (h *DevHandler) UpdateClient(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
// Set for audit logging
if tid, ok := current.Metadata["tenant_id"].(string); ok {
c.Locals("tenant_id", tid)
}
clientType := ""
if req.Type != nil {
clientType = strings.ToLower(strings.TrimSpace(*req.Type))
@@ -459,14 +382,6 @@ func (h *DevHandler) DeleteClient(c *fiber.Ctx) error {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "client id is required"})
}
// Fetch first for audit log tenant_id
client, err := h.Hydra.GetClient(c.Context(), clientID)
if err == nil {
if tid, ok := client.Metadata["tenant_id"].(string); ok {
c.Locals("tenant_id", tid)
}
}
if err := h.Hydra.DeleteClient(c.Context(), clientID); err != nil {
if errors.Is(err, service.ErrHydraNotFound) {
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "client not found"})
@@ -488,24 +403,11 @@ func (h *DevHandler) DeleteClient(c *fiber.Ctx) error {
}
func (h *DevHandler) ListConsents(c *fiber.Ctx) error {
profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse)
if !ok {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "unauthorized"})
}
clientID := strings.TrimSpace(c.Query("client_id"))
if clientID == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "client_id is required"})
}
// Permission Check
if profile.Role != domain.RoleSuperAdmin {
allowed, err := h.RPService.CheckPermission(c.Context(), profile.ID, clientID, "view")
if err != nil || !allowed {
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "forbidden: you do not have permission to view consents for this client"})
}
}
subject := strings.TrimSpace(c.Query("subject"))
limit := c.QueryInt("limit", 50)
offset := c.QueryInt("offset", 0)
@@ -582,28 +484,12 @@ func (h *DevHandler) ListConsents(c *fiber.Ctx) error {
}
func (h *DevHandler) RevokeConsents(c *fiber.Ctx) error {
profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse)
if !ok {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "unauthorized"})
}
subject := strings.TrimSpace(c.Query("subject"))
if subject == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "subject is required"})
}
clientID := strings.TrimSpace(c.Query("client_id"))
// Permission Check (if clientID is provided)
if clientID != "" && profile.Role != domain.RoleSuperAdmin {
allowed, err := h.RPService.CheckPermission(c.Context(), profile.ID, clientID, "manage")
if err != nil || !allowed {
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "forbidden: you do not have permission to revoke consents for this client"})
}
} else if clientID == "" && profile.Role != domain.RoleSuperAdmin {
// If clientID is not provided, we might need a more global check or just disallow it for non-superadmins
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "client_id is required for non-superadmins"})
}
// If subject is not a UUID, try to resolve it as an identifier (email/username)
if _, err := uuid.Parse(subject); err != nil {
resolved, err := h.KratosAdmin.FindIdentityIDByIdentifier(c.Context(), subject)
@@ -646,11 +532,6 @@ func (h *DevHandler) RotateClientSecret(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
// Set for audit logging
if tid, ok := current.Metadata["tenant_id"].(string); ok {
c.Locals("tenant_id", tid)
}
// 3. Update Hydra
current.ClientSecret = newSecret
updated, err := h.Hydra.UpdateClient(c.Context(), clientID, *current)

View File

@@ -1,10 +1,8 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
@@ -12,75 +10,8 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type MockRPService struct {
mock.Mock
}
func (m *MockRPService) Create(ctx context.Context, tenantID string, client domain.HydraClient) (*domain.RelyingParty, error) {
args := m.Called(ctx, tenantID, client)
return args.Get(0).(*domain.RelyingParty), args.Error(1)
}
func (m *MockRPService) Get(ctx context.Context, clientID string) (*domain.RelyingParty, *domain.HydraClient, error) {
args := m.Called(ctx, clientID)
return args.Get(0).(*domain.RelyingParty), args.Get(1).(*domain.HydraClient), args.Error(2)
}
func (m *MockRPService) List(ctx context.Context, tenantID string) ([]domain.RelyingParty, error) {
args := m.Called(ctx, tenantID)
return args.Get(0).([]domain.RelyingParty), args.Error(1)
}
func (m *MockRPService) ListAll(ctx context.Context) ([]domain.RelyingParty, error) {
args := m.Called(ctx)
return args.Get(0).([]domain.RelyingParty), args.Error(1)
}
func (m *MockRPService) ListByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.RelyingParty, error) {
args := m.Called(ctx, tenantIDs)
return args.Get(0).([]domain.RelyingParty), args.Error(1)
}
func (m *MockRPService) Update(ctx context.Context, clientID string, client domain.HydraClient) (*domain.RelyingParty, error) {
args := m.Called(ctx, clientID, client)
return args.Get(0).(*domain.RelyingParty), args.Error(1)
}
func (m *MockRPService) Delete(ctx context.Context, clientID string) error {
args := m.Called(ctx, clientID)
return args.Error(0)
}
func (m *MockRPService) CheckPermission(ctx context.Context, userID, clientID, relation string) (bool, error) {
args := m.Called(ctx, userID, clientID, relation)
return args.Bool(0), args.Error(1)
}
func (m *MockRPService) AddOwner(ctx context.Context, clientID, subject string) error {
args := m.Called(ctx, clientID, subject)
return args.Error(0)
}
func (m *MockRPService) RemoveOwner(ctx context.Context, clientID, subject string) error {
args := m.Called(ctx, clientID, subject)
return args.Error(0)
}
func (m *MockRPService) ListOwners(ctx context.Context, clientID string) ([]string, error) {
args := m.Called(ctx, clientID)
return args.Get(0).([]string), args.Error(1)
}
func withMockProfile(profile *domain.UserProfileResponse) fiber.Handler {
return func(c *fiber.Ctx) error {
c.Locals("user_profile", profile)
return c.Next()
}
}
func TestListClients_Success(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/clients" {
@@ -99,11 +30,7 @@ func TestListClients_Success(t *testing.T) {
},
}
app := fiber.New()
adminProfile := &domain.UserProfileResponse{
ID: "admin-1",
Role: domain.RoleSuperAdmin,
}
app.Get("/api/v1/dev/clients", withMockProfile(adminProfile), h.ListClients)
app.Get("/api/v1/dev/clients", h.ListClients)
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
resp, _ := app.Test(req, -1)
@@ -139,11 +66,7 @@ func TestGetClient_Success(t *testing.T) {
},
}
app := fiber.New()
adminProfile := &domain.UserProfileResponse{
ID: "admin-1",
Role: domain.RoleSuperAdmin,
}
app.Get("/api/v1/dev/clients/:id", withMockProfile(adminProfile), h.GetClient)
app.Get("/api/v1/dev/clients/:id", h.GetClient)
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-123", nil)
resp, _ := app.Test(req, -1)
@@ -169,11 +92,7 @@ func TestGetClient_NotFound(t *testing.T) {
},
}
app := fiber.New()
adminProfile := &domain.UserProfileResponse{
ID: "admin-1",
Role: domain.RoleSuperAdmin,
}
app.Get("/api/v1/dev/clients/:id", withMockProfile(adminProfile), h.GetClient)
app.Get("/api/v1/dev/clients/:id", h.GetClient)
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/non-existent", nil)
resp, _ := app.Test(req, -1)
@@ -190,49 +109,30 @@ func TestCreateClient_Success(t *testing.T) {
"client_secret": "secret-123",
}), nil
}
if r.Method == http.MethodGet && r.URL.Path == "/clients/new-client-123" {
return httpJSONAny(r, http.StatusOK, map[string]interface{}{
"client_id": "new-client-123",
"client_name": "New App",
"client_secret": "secret-123",
"metadata": map[string]interface{}{"status": "active"},
}), nil
}
return httpJSONAny(r, http.StatusInternalServerError, map[string]string{"error": "hydra error path: " + r.URL.Path}), nil
return httpJSONAny(r, http.StatusInternalServerError, map[string]string{"error": "hydra error"}), nil
})
secretRepo := &mockSecretRepo{secrets: make(map[string]string)}
redisRepo := &mockRedisRepo{data: make(map[string]string)}
mockRP := new(MockRPService)
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
PublicURL: "http://hydra-public.test",
HTTPClient: &http.Client{Transport: transport},
},
SecretRepo: secretRepo,
Redis: redisRepo,
RPService: mockRP,
}
app := fiber.New()
adminProfile := &domain.UserProfileResponse{
ID: "admin-1",
Role: domain.RoleSuperAdmin,
}
app.Post("/api/v1/dev/clients", withMockProfile(adminProfile), h.CreateClient)
app.Post("/api/v1/dev/clients", h.CreateClient)
body, _ := json.Marshal(map[string]interface{}{
"client_name": "New App",
"type": "confidential",
"redirectUris": []string{"http://localhost/cb"},
})
mockRP.On("Create", mock.Anything, "t1", mock.Anything).Return(&domain.RelyingParty{ClientID: "new-client-123"}, nil)
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Tenant-ID", "t1")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusCreated, resp.StatusCode)

View File

@@ -10,11 +10,10 @@ import (
type RelyingPartyHandler struct {
Service service.RelyingPartyService
UserSvc *service.KratosAdminService
}
func NewRelyingPartyHandler(s service.RelyingPartyService, userSvc *service.KratosAdminService) *RelyingPartyHandler {
return &RelyingPartyHandler{Service: s, UserSvc: userSvc}
func NewRelyingPartyHandler(s service.RelyingPartyService) *RelyingPartyHandler {
return &RelyingPartyHandler{Service: s}
}
func (h *RelyingPartyHandler) Create(c *fiber.Ctx) error {
@@ -111,58 +110,3 @@ func (h *RelyingPartyHandler) Delete(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusNoContent)
}
func (h *RelyingPartyHandler) ListOwners(c *fiber.Ctx) error {
clientID := c.Params("id")
subjects, err := h.Service.ListOwners(c.Context(), clientID)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
type ownerInfo struct {
Subject string `json:"subject"`
Name string `json:"name,omitempty"`
Email string `json:"email,omitempty"`
Type string `json:"type"` // "user" or "group"
}
owners := make([]ownerInfo, 0, len(subjects))
for _, s := range subjects {
info := ownerInfo{Subject: s, Type: "unknown"}
if len(s) > 5 && s[:5] == "User:" {
info.Type = "user"
userID := s[5:]
identity, err := h.UserSvc.GetIdentity(c.Context(), userID)
if err == nil && identity != nil {
info.Name, _ = identity.Traits["name"].(string)
info.Email, _ = identity.Traits["email"].(string)
}
} else if len(s) > 10 && s[:10] == "UserGroup:" {
info.Type = "group"
// Group name enrichment could be added if we have a GroupService here
}
owners = append(owners, info)
}
return c.JSON(owners)
}
func (h *RelyingPartyHandler) AddOwner(c *fiber.Ctx) error {
clientID := c.Params("id")
subject := c.Params("subject") // e.g. "User:uuid"
if err := h.Service.AddOwner(c.Context(), clientID, subject); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(fiber.Map{"message": "owner added"})
}
func (h *RelyingPartyHandler) RemoveOwner(c *fiber.Ctx) error {
clientID := c.Params("id")
subject := c.Params("subject")
if err := h.Service.RemoveOwner(c.Context(), clientID, subject); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(fiber.Map{"message": "owner removed"})
}

View File

@@ -14,25 +14,22 @@ import (
type TenantHandler struct {
DB *gorm.DB
Service service.TenantService
Keto service.KetoService
UserSvc *service.KratosAdminService
}
func NewTenantHandler(db *gorm.DB, svc service.TenantService, keto service.KetoService, userSvc *service.KratosAdminService) *TenantHandler {
return &TenantHandler{DB: db, Service: svc, Keto: keto, UserSvc: userSvc}
func NewTenantHandler(db *gorm.DB, svc service.TenantService) *TenantHandler {
return &TenantHandler{DB: db, Service: svc}
}
type tenantSummary struct {
ID string `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
Description string `json:"description"`
Status string `json:"status"`
TenantGroupID *string `json:"tenantGroupId,omitempty"`
Domains []string `json:"domains,omitempty"`
Config domain.JSONMap `json:"config,omitempty"`
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
ID string `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
Description string `json:"description"`
Status string `json:"status"`
Domains []string `json:"domains,omitempty"`
Config domain.JSONMap `json:"config,omitempty"`
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
}
type tenantListResponse struct {
@@ -103,7 +100,7 @@ func (h *TenantHandler) ListTenants(c *fiber.Ctx) error {
}
var tenants []domain.Tenant
if err := h.DB.Order("created_at desc").Limit(limit).Offset(offset).Preload("Domains").Preload("TenantGroup").Find(&tenants).Error; err != nil {
if err := h.DB.Order("created_at desc").Limit(limit).Offset(offset).Preload("Domains").Find(&tenants).Error; err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
@@ -126,7 +123,7 @@ func (h *TenantHandler) GetTenant(c *fiber.Ctx) error {
}
var tenant domain.Tenant
if err := h.DB.Preload("Domains").Preload("TenantGroup").First(&tenant, "id = ?", tenantID).Error; err != nil {
if err := h.DB.Preload("Domains").First(&tenant, "id = ?", tenantID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "tenant not found"})
}
@@ -207,13 +204,12 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
}
var req struct {
Name *string `json:"name"`
Slug *string `json:"slug"`
Description *string `json:"description"`
Status *string `json:"status"`
TenantGroupID *string `json:"tenantGroupId"`
Domains []string `json:"domains"`
Config map[string]any `json:"config"`
Name *string `json:"name"`
Slug *string `json:"slug"`
Description *string `json:"description"`
Status *string `json:"status"`
Domains []string `json:"domains"`
Config map[string]any `json:"config"`
}
if err := c.BodyParser(&req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid request body"})
@@ -255,29 +251,6 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
tenant.Config = req.Config
}
// Handle Group Change
if req.TenantGroupID != nil {
oldGroupID := tenant.TenantGroupID
newGroupID := req.TenantGroupID
if *newGroupID == "" {
newGroupID = nil
}
// Update Keto if group changed
if h.Keto != nil {
// Remove old group relation if existed
if oldGroupID != nil && (newGroupID == nil || *oldGroupID != *newGroupID) {
_ = h.Keto.DeleteRelation(c.Context(), "Tenant", tenant.ID, "parent_group", *oldGroupID)
}
// Add new group relation
if newGroupID != nil && (oldGroupID == nil || *oldGroupID != *newGroupID) {
_ = h.Keto.CreateRelation(c.Context(), "Tenant", tenant.ID, "parent_group", *newGroupID)
}
}
tenant.TenantGroupID = newGroupID
}
if err := h.DB.Save(&tenant).Error; err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
@@ -328,58 +301,6 @@ func (h *TenantHandler) DeleteTenant(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusNoContent)
}
func (h *TenantHandler) ListAdmins(c *fiber.Ctx) error {
tenantID := c.Params("id")
userIDs, err := h.Service.ListTenantAdmins(c.Context(), tenantID)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
type adminInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
}
admins := make([]adminInfo, 0, len(userIDs))
for _, uid := range userIDs {
identity, err := h.UserSvc.GetIdentity(c.Context(), uid)
if err == nil && identity != nil {
name, _ := identity.Traits["name"].(string)
email, _ := identity.Traits["email"].(string)
admins = append(admins, adminInfo{
ID: uid,
Name: name,
Email: email,
})
} else {
admins = append(admins, adminInfo{ID: uid})
}
}
return c.JSON(admins)
}
func (h *TenantHandler) AddAdmin(c *fiber.Ctx) error {
tenantID := c.Params("id")
userID := c.Params("userId")
if err := h.Service.AddTenantAdmin(c.Context(), tenantID, userID); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(fiber.Map{"message": "admin added to tenant"})
}
func (h *TenantHandler) RemoveAdmin(c *fiber.Ctx) error {
tenantID := c.Params("id")
userID := c.Params("userId")
if err := h.Service.RemoveTenantAdmin(c.Context(), tenantID, userID); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(fiber.Map{"message": "admin removed from tenant"})
}
func mapTenantSummary(t domain.Tenant) tenantSummary {
domains := make([]string, 0, len(t.Domains))
for _, d := range t.Domains {
@@ -387,16 +308,15 @@ func mapTenantSummary(t domain.Tenant) tenantSummary {
}
return tenantSummary{
ID: t.ID,
Name: t.Name,
Slug: t.Slug,
Description: t.Description,
Status: t.Status,
TenantGroupID: t.TenantGroupID,
Domains: domains,
Config: t.GetMergedConfig(),
CreatedAt: t.CreatedAt.Format(time.RFC3339),
UpdatedAt: t.UpdatedAt.Format(time.RFC3339),
ID: t.ID,
Name: t.Name,
Slug: t.Slug,
Description: t.Description,
Status: t.Status,
Domains: domains,
Config: t.Config,
CreatedAt: t.CreatedAt.Format(time.RFC3339),
UpdatedAt: t.UpdatedAt.Format(time.RFC3339),
}
}

View File

@@ -70,26 +70,6 @@ func (m *MockTenantService) SetKetoService(keto service.KetoService) {
m.Called(keto)
}
func (m *MockTenantService) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
args := m.Called(ctx, userID)
return args.Get(0).([]domain.Tenant), args.Error(1)
}
func (m *MockTenantService) AddTenantAdmin(ctx context.Context, tenantID, userID string) error {
args := m.Called(ctx, tenantID, userID)
return args.Error(0)
}
func (m *MockTenantService) RemoveTenantAdmin(ctx context.Context, tenantID, userID string) error {
args := m.Called(ctx, tenantID, userID)
return args.Error(0)
}
func (m *MockTenantService) ListTenantAdmins(ctx context.Context, tenantID string) ([]string, error) {
args := m.Called(ctx, tenantID)
return args.Get(0).([]string), args.Error(1)
}
func TestTenantHandler_CreateTenant(t *testing.T) {
app := fiber.New()
mockSvc := new(MockTenantService)

View File

@@ -304,10 +304,15 @@ func (h *UserHandler) CreateUser(c *fiber.Ctx) error {
localUser.TenantID = &tenantID
}
// [SoT Policy] Kratos가 SoT이므로 로컬 DB 저장은 비동기 Read-Model 동기화로 처리합니다.
if h.UserRepo != nil {
if err := h.UserRepo.Create(c.Context(), localUser); err != nil {
slog.Error("[UserHandler] Failed to sync user to local DB", "email", email, "error", err)
}
go func(u *domain.User) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.UserRepo.Create(ctx, u); err != nil {
slog.Error("[UserHandler] Failed to sync user to local DB", "email", u.Email, "error", err)
}
}(localUser)
}
// [Keto] Sync relations
@@ -483,27 +488,32 @@ func (h *UserHandler) UpdateUser(c *fiber.Ctx) error {
localUser.Metadata = req.Metadata
}
if err := h.UserRepo.Update(c.Context(), localUser); err == nil {
// [Keto Sync on Role Change]
if h.KetoService != nil && req.Role != nil && *req.Role != oldRole {
go func(uID, oldR, newR, tID string) {
ctx := context.Background()
if oldR == domain.RoleSuperAdmin {
// [SoT Policy] Kratos가 SoT이므로 로컬 DB 저장은 비동기 Read-Model 동기화로 처리합니다.
go func(u *domain.User, rRole *string, oRole string, oTenantID string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.UserRepo.Update(ctx, u); err == nil {
// [Keto Sync on Role Change]
if h.KetoService != nil && rRole != nil && *rRole != oRole {
uID := u.ID
newR := *rRole
if oRole == domain.RoleSuperAdmin {
_ = h.KetoService.DeleteRelation(ctx, "System", "global", "super_admins", uID)
} else if oldR == domain.RoleTenantAdmin && tID != "" {
_ = h.KetoService.DeleteRelation(ctx, "Tenant", tID, "admins", uID)
} else if oRole == domain.RoleTenantAdmin && oTenantID != "" {
_ = h.KetoService.DeleteRelation(ctx, "Tenant", oTenantID, "admins", uID)
}
if newR == domain.RoleSuperAdmin {
_ = h.KetoService.CreateRelation(ctx, "System", "global", "super_admins", uID)
} else if newR == domain.RoleTenantAdmin && tID != "" {
_ = h.KetoService.CreateRelation(ctx, "Tenant", tID, "admins", uID)
} else if newR == domain.RoleTenantAdmin && u.TenantID != nil {
_ = h.KetoService.CreateRelation(ctx, "Tenant", *u.TenantID, "admins", uID)
}
}(userID, oldRole, *req.Role, oldTenantID)
}
} else {
slog.Error("[UserHandler] Failed to sync user update to local DB", "userID", u.ID, "error", err)
}
} else {
slog.Error("[UserHandler] Failed to sync user update to local DB", "userID", userID, "error", err)
}
}(localUser, req.Role, oldRole, oldTenantID)
}
}