forked from baron/baron-sso
614 lines
15 KiB
Go
614 lines
15 KiB
Go
package service
|
|
|
|
import (
|
|
"baron-sso-backend/internal/domain"
|
|
"baron-sso-backend/internal/pagination"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-redis/redis/v8"
|
|
)
|
|
|
|
// ctx is the package-level background context used by the RedisService code
// paths that take no context parameter (NewRedisService, the SMS
// verification-code helpers, Set, Get and Delete). It is never cancelled, so
// those Redis calls do not observe any caller deadline.
// NOTE(review): consider threading a caller-supplied context.Context instead.
var ctx = context.Background()
|
|
|
|
// RedisService wraps the go-redis client used for SMS verification codes,
// generic key/value storage, and the Redis mirror of identities.
type RedisService struct {
	// Client is the underlying go-redis client. Methods treat a nil Client as an
	// unavailable service (os.ErrInvalid or an "unavailable" status).
	Client *redis.Client
}
|
|
|
|
// identityMirrorStateStore is the JSON document kept under the
// "identity:mirror:state" key. GetIdentityCacheStatus decodes it and copies its
// fields into domain.IdentityCacheStatus; nothing in this file writes it.
type identityMirrorStateStore struct {
	// Status is the mirror's state label; GetIdentityCacheStatus reports
	// "unknown" when it is empty.
	Status string `json:"status"`
	// LastRefreshedAt is, presumably, when the mirror was last refreshed — confirm with the writer.
	LastRefreshedAt *time.Time `json:"lastRefreshedAt,omitempty"`
	// LastError is an error message recorded by the writer, passed through verbatim.
	LastError string `json:"lastError,omitempty"`
	// MirrorVersion is an opaque version label passed through verbatim.
	MirrorVersion string `json:"mirrorVersion,omitempty"`
	// ObservedCount is passed through verbatim.
	// NOTE(review): presumably the number of identities seen at the last refresh — confirm.
	ObservedCount int64 `json:"observedCount,omitempty"`
	// UpdatedAt is, presumably, when this state document was last written — confirm.
	UpdatedAt *time.Time `json:"updatedAt,omitempty"`
}
|
|
|
|
// IdentityMirrorPageQuery describes one page request for ListIdentityMirrorPage.
type IdentityMirrorPageQuery struct {
	// Limit is the maximum number of items to return; values <= 0 become 50.
	Limit int
	// Offset skips that many matching identities; negative values become 0.
	// It is ignored when Cursor decodes to a cursor.
	Offset int
	// Cursor is a token previously returned as NextCursor (decoded with
	// pagination.Decode); when it decodes to a cursor, paging resumes after it.
	Cursor string
	// Search is a case-insensitive substring filter applied to the identity's
	// ID and traits; blank means no filter.
	Search string
	// TenantID and TenantSlug, when set, restrict results to identities linked
	// to either value (trimmed and compared case-insensitively).
	TenantID   string
	TenantSlug string
	// AllowedTenantKeys, when non-empty, restricts results to identities linked
	// to at least one key with a true value. Identity tenant keys are
	// lower-cased before lookup.
	// NOTE(review): callers are presumably expected to supply lower-case keys — confirm.
	AllowedTenantKeys map[string]bool
}
|
|
|
|
// IdentityMirrorPageResult is one page returned by ListIdentityMirrorPage.
type IdentityMirrorPageResult struct {
	// Items holds at most the requested Limit identities, newest first by the
	// created_at index score.
	Items []KratosIdentity
	// Total counts every identity matching the filters; when a cursor was
	// supplied it counts only the matches after that cursor.
	Total int64
	// Cursor echoes the request's Cursor.
	Cursor string
	// NextCursor resumes after the last item; empty when no further match exists.
	NextCursor string
}
|
|
|
|
// NewRedisService creates and returns a new RedisService
|
|
func NewRedisService() (*RedisService, error) {
|
|
redisAddr := os.Getenv("REDIS_ADDR")
|
|
if redisAddr == "" {
|
|
redisAddr = "localhost:6389" // Fallback for local dev without Docker
|
|
}
|
|
|
|
rdb := redis.NewClient(&redis.Options{
|
|
Addr: redisAddr,
|
|
})
|
|
|
|
// Ping the server to check the connection
|
|
if _, err := rdb.Ping(ctx).Result(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// [DEV-FIX] Disable stop-writes-on-bgsave-error to allow writes even if persistence fails
|
|
// This is common in dev docker environments with permission issues.
|
|
rdb.ConfigSet(ctx, "stop-writes-on-bgsave-error", "no")
|
|
|
|
return &RedisService{Client: rdb}, nil
|
|
}
|
|
|
|
func (s *RedisService) Ping(ctx context.Context) error {
|
|
if s.Client == nil {
|
|
return os.ErrInvalid
|
|
}
|
|
return s.Client.Ping(ctx).Err()
|
|
}
|
|
|
|
// StoreVerificationCode saves the SMS verification code with a 3-minute expiration
|
|
func (s *RedisService) StoreVerificationCode(phone, code string) error {
|
|
if s == nil || s.Client == nil {
|
|
return os.ErrInvalid
|
|
}
|
|
// Key format: "sms_verify:01012345678"
|
|
key := "sms_verify:" + phone
|
|
expiration := 3 * time.Minute
|
|
err := s.Client.Set(ctx, key, code, expiration).Err()
|
|
return err
|
|
}
|
|
|
|
// GetVerificationCode retrieves the SMS verification code
|
|
func (s *RedisService) GetVerificationCode(phone string) (string, error) {
|
|
if s == nil || s.Client == nil {
|
|
return "", os.ErrInvalid
|
|
}
|
|
key := "sms_verify:" + phone
|
|
code, err := s.Client.Get(ctx, key).Result()
|
|
if err == redis.Nil {
|
|
// Key does not exist (expired or incorrect phone number)
|
|
return "", nil
|
|
} else if err != nil {
|
|
return "", err
|
|
}
|
|
return code, nil
|
|
}
|
|
|
|
// DeleteVerificationCode removes the verification code after successful verification
|
|
func (s *RedisService) DeleteVerificationCode(phone string) error {
|
|
if s == nil || s.Client == nil {
|
|
return os.ErrInvalid
|
|
}
|
|
key := "sms_verify:" + phone
|
|
return s.Client.Del(ctx, key).Err()
|
|
}
|
|
|
|
// Set stores a key-value pair with expiration
|
|
func (s *RedisService) Set(key string, value string, expiration time.Duration) error {
|
|
if s == nil || s.Client == nil {
|
|
return os.ErrInvalid
|
|
}
|
|
return s.Client.Set(ctx, key, value, expiration).Err()
|
|
}
|
|
|
|
// Get retrieves a value by key
|
|
func (s *RedisService) Get(key string) (string, error) {
|
|
if s == nil || s.Client == nil {
|
|
return "", os.ErrInvalid
|
|
}
|
|
val, err := s.Client.Get(ctx, key).Result()
|
|
if err == redis.Nil {
|
|
return "", nil
|
|
}
|
|
return val, err
|
|
}
|
|
|
|
// Delete removes a key
|
|
func (s *RedisService) Delete(key string) error {
|
|
if s == nil || s.Client == nil {
|
|
return os.ErrInvalid
|
|
}
|
|
return s.Client.Del(ctx, key).Err()
|
|
}
|
|
|
|
func (s *RedisService) DeleteByPrefix(ctx context.Context, prefix string) (int64, error) {
|
|
if s == nil || s.Client == nil {
|
|
return 0, os.ErrInvalid
|
|
}
|
|
prefix = strings.TrimSpace(prefix)
|
|
if prefix == "" {
|
|
return 0, os.ErrInvalid
|
|
}
|
|
|
|
var deleted int64
|
|
var cursor uint64
|
|
pattern := prefix + "*"
|
|
for {
|
|
keys, next, err := s.Client.Scan(ctx, cursor, pattern, 250).Result()
|
|
if err != nil {
|
|
return deleted, err
|
|
}
|
|
for len(keys) > 0 {
|
|
chunkSize := len(keys)
|
|
if chunkSize > 500 {
|
|
chunkSize = 500
|
|
}
|
|
chunk := keys[:chunkSize]
|
|
count, err := s.Client.Del(ctx, chunk...).Result()
|
|
if err != nil {
|
|
return deleted, err
|
|
}
|
|
deleted += count
|
|
keys = keys[chunkSize:]
|
|
}
|
|
cursor = next
|
|
if cursor == 0 {
|
|
break
|
|
}
|
|
}
|
|
return deleted, nil
|
|
}
|
|
|
|
func (s *RedisService) GetIdentityCacheStatus(ctx context.Context) (domain.IdentityCacheStatus, error) {
|
|
if s == nil || s.Client == nil {
|
|
return domain.IdentityCacheStatus{
|
|
Status: "unavailable",
|
|
RedisReady: false,
|
|
LastError: "redis service unavailable",
|
|
}, nil
|
|
}
|
|
|
|
if err := s.Client.Ping(ctx).Err(); err != nil {
|
|
return domain.IdentityCacheStatus{
|
|
Status: "failed",
|
|
RedisReady: false,
|
|
LastError: err.Error(),
|
|
}, nil
|
|
}
|
|
|
|
keyCount, err := s.countIdentityCacheKeys(ctx)
|
|
if err != nil {
|
|
return domain.IdentityCacheStatus{
|
|
Status: "failed",
|
|
RedisReady: true,
|
|
LastError: err.Error(),
|
|
}, nil
|
|
}
|
|
|
|
raw, err := s.Client.Get(ctx, "identity:mirror:state").Result()
|
|
if err == redis.Nil {
|
|
return domain.IdentityCacheStatus{
|
|
Status: "empty",
|
|
RedisReady: true,
|
|
KeyCount: keyCount,
|
|
}, nil
|
|
}
|
|
if err != nil {
|
|
return domain.IdentityCacheStatus{
|
|
Status: "failed",
|
|
RedisReady: true,
|
|
KeyCount: keyCount,
|
|
LastError: err.Error(),
|
|
}, nil
|
|
}
|
|
|
|
var stored identityMirrorStateStore
|
|
if err := json.Unmarshal([]byte(raw), &stored); err != nil {
|
|
return domain.IdentityCacheStatus{
|
|
Status: "failed",
|
|
RedisReady: true,
|
|
KeyCount: keyCount,
|
|
LastError: err.Error(),
|
|
}, nil
|
|
}
|
|
status := stored.Status
|
|
if status == "" {
|
|
status = "unknown"
|
|
}
|
|
return domain.IdentityCacheStatus{
|
|
Status: status,
|
|
RedisReady: true,
|
|
MirrorVersion: stored.MirrorVersion,
|
|
ObservedCount: stored.ObservedCount,
|
|
KeyCount: keyCount,
|
|
LastRefreshedAt: stored.LastRefreshedAt,
|
|
LastError: stored.LastError,
|
|
UpdatedAt: stored.UpdatedAt,
|
|
}, nil
|
|
}
|
|
|
|
func (s *RedisService) FlushIdentityCache(ctx context.Context) (domain.IdentityCacheFlushResult, error) {
|
|
if s == nil || s.Client == nil {
|
|
return domain.IdentityCacheFlushResult{}, os.ErrInvalid
|
|
}
|
|
|
|
keys, err := s.identityCacheKeys(ctx)
|
|
if err != nil {
|
|
return domain.IdentityCacheFlushResult{}, err
|
|
}
|
|
var deleted int64
|
|
for len(keys) > 0 {
|
|
chunkSize := len(keys)
|
|
if chunkSize > 500 {
|
|
chunkSize = 500
|
|
}
|
|
chunk := keys[:chunkSize]
|
|
count, err := s.Client.Del(ctx, chunk...).Result()
|
|
if err != nil {
|
|
return domain.IdentityCacheFlushResult{}, err
|
|
}
|
|
deleted += count
|
|
keys = keys[chunkSize:]
|
|
}
|
|
|
|
return domain.IdentityCacheFlushResult{
|
|
Status: "success",
|
|
FlushedKeys: deleted,
|
|
UpdatedAt: time.Now().UTC(),
|
|
}, nil
|
|
}
|
|
|
|
func (s *RedisService) ListIdentityMirrors(ctx context.Context) ([]KratosIdentity, error) {
|
|
if s == nil || s.Client == nil {
|
|
return nil, os.ErrInvalid
|
|
}
|
|
|
|
keys, err := s.identityCacheKeys(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
identities := make([]KratosIdentity, 0, len(keys))
|
|
for _, key := range keys {
|
|
if key == "identity:mirror:state" || !strings.HasPrefix(key, "identity:mirror:") {
|
|
continue
|
|
}
|
|
raw, err := s.Client.Get(ctx, key).Result()
|
|
if err == redis.Nil {
|
|
continue
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var identity KratosIdentity
|
|
if err := json.Unmarshal([]byte(raw), &identity); err != nil {
|
|
continue
|
|
}
|
|
if strings.TrimSpace(identity.ID) == "" {
|
|
continue
|
|
}
|
|
identities = append(identities, identity)
|
|
}
|
|
return identities, nil
|
|
}
|
|
|
|
func (s *RedisService) StoreIdentityMirror(ctx context.Context, identity KratosIdentity) error {
|
|
if s == nil || s.Client == nil {
|
|
return os.ErrInvalid
|
|
}
|
|
identityID := strings.TrimSpace(identity.ID)
|
|
if identityID == "" {
|
|
return os.ErrInvalid
|
|
}
|
|
raw, err := json.Marshal(identity)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := s.Client.Set(ctx, "identity:mirror:"+identityID, string(raw), 0).Err(); err != nil {
|
|
return err
|
|
}
|
|
score := float64(identityMirrorScoreTime(identity).UnixMilli())
|
|
if err := s.Client.ZAdd(ctx, "identity:index:created_at", &redis.Z{
|
|
Score: score,
|
|
Member: identityID,
|
|
}).Err(); err != nil {
|
|
return err
|
|
}
|
|
for _, tenantKey := range identityMirrorTenantKeys(identity.Traits) {
|
|
if err := s.Client.SAdd(ctx, "identity:index:tenant:"+tenantKey, identityID).Err(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *RedisService) ListIdentityMirrorPage(ctx context.Context, query IdentityMirrorPageQuery) (IdentityMirrorPageResult, error) {
|
|
if s == nil || s.Client == nil {
|
|
return IdentityMirrorPageResult{}, os.ErrInvalid
|
|
}
|
|
if query.Limit <= 0 {
|
|
query.Limit = 50
|
|
}
|
|
if query.Offset < 0 {
|
|
query.Offset = 0
|
|
}
|
|
cursor, err := pagination.Decode(query.Cursor)
|
|
if err != nil {
|
|
return IdentityMirrorPageResult{}, err
|
|
}
|
|
search := strings.ToLower(strings.TrimSpace(query.Search))
|
|
targetTenantKeys := identityMirrorTargetTenantKeys(query)
|
|
maxScore := "+inf"
|
|
if cursor != nil {
|
|
maxScore = strconv.FormatInt(cursor.Timestamp.UnixMilli(), 10)
|
|
}
|
|
|
|
const batchSize int64 = 250
|
|
var offset int64
|
|
var total int64
|
|
matched := make([]KratosIdentity, 0, query.Limit+1)
|
|
pageStart := query.Offset
|
|
if cursor != nil {
|
|
pageStart = 0
|
|
}
|
|
|
|
for {
|
|
zItems, err := s.Client.ZRevRangeByScoreWithScores(ctx, "identity:index:created_at", &redis.ZRangeBy{
|
|
Max: maxScore,
|
|
Min: "-inf",
|
|
Offset: offset,
|
|
Count: batchSize,
|
|
}).Result()
|
|
if err != nil {
|
|
return IdentityMirrorPageResult{}, err
|
|
}
|
|
if len(zItems) == 0 {
|
|
break
|
|
}
|
|
keys := make([]string, 0, len(zItems))
|
|
for _, item := range zItems {
|
|
id, ok := item.Member.(string)
|
|
if !ok || strings.TrimSpace(id) == "" {
|
|
continue
|
|
}
|
|
keys = append(keys, "identity:mirror:"+id)
|
|
}
|
|
rawItems, err := s.Client.MGet(ctx, keys...).Result()
|
|
if err != nil {
|
|
return IdentityMirrorPageResult{}, err
|
|
}
|
|
for _, raw := range rawItems {
|
|
rawString, ok := raw.(string)
|
|
if !ok || strings.TrimSpace(rawString) == "" {
|
|
continue
|
|
}
|
|
var identity KratosIdentity
|
|
if err := json.Unmarshal([]byte(rawString), &identity); err != nil {
|
|
continue
|
|
}
|
|
if strings.TrimSpace(identity.ID) == "" {
|
|
continue
|
|
}
|
|
if cursor != nil {
|
|
timestamp, id := identityMirrorCursorKey(identity)
|
|
if !pagination.ComesAfter(timestamp, id, cursor) {
|
|
continue
|
|
}
|
|
}
|
|
if !identityMirrorMatchesTenantScope(identity, targetTenantKeys, query.AllowedTenantKeys) {
|
|
continue
|
|
}
|
|
if !identityMirrorMatchesSearch(identity, search) {
|
|
continue
|
|
}
|
|
if total >= int64(pageStart) && len(matched) < query.Limit+1 {
|
|
matched = append(matched, identity)
|
|
}
|
|
total++
|
|
}
|
|
if len(zItems) < int(batchSize) {
|
|
break
|
|
}
|
|
offset += int64(len(zItems))
|
|
}
|
|
|
|
nextCursor := ""
|
|
items := matched
|
|
if len(matched) > query.Limit {
|
|
items = matched[:query.Limit]
|
|
lastTimestamp, lastID := identityMirrorCursorKey(items[len(items)-1])
|
|
nextCursor = pagination.Encode(lastTimestamp, lastID)
|
|
}
|
|
return IdentityMirrorPageResult{
|
|
Items: items,
|
|
Total: total,
|
|
Cursor: query.Cursor,
|
|
NextCursor: nextCursor,
|
|
}, nil
|
|
}
|
|
|
|
func identityMirrorScoreTime(identity KratosIdentity) time.Time {
|
|
if identity.CreatedAt.IsZero() {
|
|
return time.Unix(0, 0).UTC()
|
|
}
|
|
return identity.CreatedAt.UTC()
|
|
}
|
|
|
|
func identityMirrorCursorKey(identity KratosIdentity) (time.Time, string) {
|
|
return identityMirrorScoreTime(identity), identity.ID
|
|
}
|
|
|
|
func identityMirrorTenantKeys(traits map[string]any) []string {
|
|
keys := make([]string, 0, 4)
|
|
seen := make(map[string]bool)
|
|
appendKey := func(value string) {
|
|
key := strings.ToLower(strings.TrimSpace(value))
|
|
if key == "" || seen[key] {
|
|
return
|
|
}
|
|
seen[key] = true
|
|
keys = append(keys, key)
|
|
}
|
|
appendKey(identityMirrorTraitString(traits, "tenant_id"))
|
|
appendKey(identityMirrorTraitString(traits, "tenantSlug"))
|
|
appointments := identityMirrorAppointments(traits["additionalAppointments"])
|
|
if len(appointments) == 0 {
|
|
if metadata, ok := traits["metadata"].(map[string]any); ok {
|
|
appointments = identityMirrorAppointments(metadata["additionalAppointments"])
|
|
}
|
|
}
|
|
for _, appointment := range appointments {
|
|
appendKey(identityMirrorAnyString(appointment["tenantId"]))
|
|
appendKey(identityMirrorAnyString(appointment["tenantSlug"]))
|
|
appendKey(identityMirrorAnyString(appointment["slug"]))
|
|
}
|
|
return keys
|
|
}
|
|
|
|
func identityMirrorTargetTenantKeys(query IdentityMirrorPageQuery) map[string]bool {
|
|
targets := make(map[string]bool)
|
|
for _, value := range []string{query.TenantID, query.TenantSlug} {
|
|
key := strings.ToLower(strings.TrimSpace(value))
|
|
if key != "" {
|
|
targets[key] = true
|
|
}
|
|
}
|
|
return targets
|
|
}
|
|
|
|
func identityMirrorMatchesTenantScope(identity KratosIdentity, targetTenantKeys map[string]bool, allowedTenantKeys map[string]bool) bool {
|
|
identityKeys := identityMirrorTenantKeys(identity.Traits)
|
|
if len(allowedTenantKeys) > 0 && !identityMirrorAnyKeyAllowed(identityKeys, allowedTenantKeys) {
|
|
return false
|
|
}
|
|
if len(targetTenantKeys) > 0 && !identityMirrorAnyKeyAllowed(identityKeys, targetTenantKeys) {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
// identityMirrorAnyKeyAllowed reports whether at least one of keys maps to true
// in allowed. Empty keys or a nil/empty allowed set yield false.
func identityMirrorAnyKeyAllowed(keys []string, allowed map[string]bool) bool {
	found := false
	for i := 0; !found && i < len(keys); i++ {
		found = allowed[keys[i]]
	}
	return found
}
|
|
|
|
func identityMirrorMatchesSearch(identity KratosIdentity, search string) bool {
|
|
search = strings.TrimSpace(search)
|
|
if search == "" {
|
|
return true
|
|
}
|
|
values := []string{
|
|
identity.ID,
|
|
identityMirrorTraitString(identity.Traits, "email"),
|
|
identityMirrorTraitString(identity.Traits, "name"),
|
|
identityMirrorTraitString(identity.Traits, "phone_number"),
|
|
identityMirrorTraitString(identity.Traits, "loginId"),
|
|
}
|
|
for _, value := range values {
|
|
if strings.Contains(strings.ToLower(value), search) {
|
|
return true
|
|
}
|
|
}
|
|
rawTraits, err := json.Marshal(identity.Traits)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return strings.Contains(strings.ToLower(string(rawTraits)), search)
|
|
}
|
|
|
|
func identityMirrorTraitString(traits map[string]any, key string) string {
|
|
if traits == nil {
|
|
return ""
|
|
}
|
|
return identityMirrorAnyString(traits[key])
|
|
}
|
|
|
|
// identityMirrorAnyString converts value to a string: strings are returned
// unchanged, fmt.Stringer values via String(), and anything else (including
// nil and numbers) yields "".
func identityMirrorAnyString(value any) string {
	if text, isString := value.(string); isString {
		return text
	}
	if stringer, isStringer := value.(fmt.Stringer); isStringer {
		return stringer.String()
	}
	return ""
}
|
|
|
|
// identityMirrorAppointments interprets value as a list of appointment objects.
//
//   - A []map[string]any is returned as-is (not copied).
//   - A []any yields a new, non-nil slice holding only its map[string]any
//     elements.
//   - Anything else yields nil.
func identityMirrorAppointments(value any) []map[string]any {
	if direct, isTyped := value.([]map[string]any); isTyped {
		return direct
	}
	generic, isList := value.([]any)
	if !isList {
		return nil
	}
	appointments := make([]map[string]any, 0, len(generic))
	for i := range generic {
		if entry, isMap := generic[i].(map[string]any); isMap {
			appointments = append(appointments, entry)
		}
	}
	return appointments
}
|
|
|
|
func (s *RedisService) countIdentityCacheKeys(ctx context.Context) (int64, error) {
|
|
keys, err := s.identityCacheKeys(ctx)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return int64(len(keys)), nil
|
|
}
|
|
|
|
func (s *RedisService) identityCacheKeys(ctx context.Context) ([]string, error) {
|
|
seen := make(map[string]bool)
|
|
patterns := []string{
|
|
"identity:mirror:*",
|
|
"identity:index:*",
|
|
}
|
|
for _, pattern := range patterns {
|
|
var cursor uint64
|
|
for {
|
|
keys, next, err := s.Client.Scan(ctx, cursor, pattern, 250).Result()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, key := range keys {
|
|
seen[key] = true
|
|
}
|
|
cursor = next
|
|
if cursor == 0 {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
keys := make([]string, 0, len(seen))
|
|
for key := range seen {
|
|
keys = append(keys, key)
|
|
}
|
|
return keys, nil
|
|
}
|