1
0
forked from baron/baron-sso
Files
baron-sso/backend/internal/service/developer_service.go

240 lines
6.4 KiB
Go

package service
import (
"baron-sso-backend/internal/domain"
"context"
"sort"
"strings"
"gorm.io/gorm"
)
// DeveloperService groups the database operations on developer access
// requests (domain.DeveloperRequest rows). All of its methods go through the
// wrapped GORM handle.
type DeveloperService struct {
// db is the GORM handle every method uses to build its queries.
db *gorm.DB
}
// NewDeveloperService builds a DeveloperService that issues its queries
// through db. The handle is stored as-is; no validation is performed here.
func NewDeveloperService(db *gorm.DB) *DeveloperService {
	svc := new(DeveloperService)
	svc.db = db
	return svc
}
// normalizeDeveloperAccessPages canonicalizes a list of developer access page
// identifiers.
//
// Each entry is trimmed and lower-cased before it is inspected. Blank entries
// and entries that are not one of the recognized pages (overview, client
// create, audit) are dropped, and duplicates are collapsed. As soon as
// domain.DeveloperAccessPageAll is encountered the result is exactly that one
// entry, regardless of what came before or after it. When nothing
// recognizable remains (including a nil or empty input) the result also falls
// back to domain.DeveloperAccessPageAll. Otherwise the surviving pages are
// returned ordered by accessPageSortIndex.
func normalizeDeveloperAccessPages(pages []string) []string {
	kept := make([]string, 0, len(pages))
	already := make(map[string]bool, len(pages))

	for _, raw := range pages {
		name := strings.ToLower(strings.TrimSpace(raw))
		if name == "" {
			continue
		}
		if name == domain.DeveloperAccessPageAll {
			// The catch-all entry supersedes every individual page.
			return []string{domain.DeveloperAccessPageAll}
		}

		recognized := name == domain.DeveloperAccessPageOverview ||
			name == domain.DeveloperAccessPageClientCreate ||
			name == domain.DeveloperAccessPageAudit
		if !recognized || already[name] {
			continue
		}
		already[name] = true
		kept = append(kept, name)
	}

	if len(kept) == 0 {
		// Nothing usable was supplied: default to the catch-all entry.
		return []string{domain.DeveloperAccessPageAll}
	}

	sort.SliceStable(kept, func(a, b int) bool {
		return accessPageSortIndex(kept[a]) < accessPageSortIndex(kept[b])
	})
	return kept
}
// accessPageSortIndex returns the display rank of a developer access page:
// overview is 0, client create is 1 and audit is 2. Any other value,
// including the catch-all page, ranks last with 99.
func accessPageSortIndex(page string) int {
	ranked := [...]string{
		domain.DeveloperAccessPageOverview,
		domain.DeveloperAccessPageClientCreate,
		domain.DeveloperAccessPageAudit,
	}
	for rank, candidate := range ranked {
		if page == candidate {
			return rank
		}
	}
	return 99
}
// accessPagesOverlap reports whether two page lists grant at least one page
// in common. Both sides are normalized first; if either normalized side is
// the catch-all domain.DeveloperAccessPageAll the lists always overlap.
//
// An empty list on either side never overlaps.
// NOTE(review): normalizeDeveloperAccessPages treats an empty list as the
// catch-all page, so this guard makes an empty list behave differently here
// than elsewhere in this file — confirm that is intended.
func accessPagesOverlap(left, right []string) bool {
	if len(left) == 0 || len(right) == 0 {
		return false
	}

	lhs := normalizeDeveloperAccessPages(left)
	rhs := normalizeDeveloperAccessPages(right)

	known := make(map[string]bool, len(lhs))
	for _, page := range lhs {
		if page == domain.DeveloperAccessPageAll {
			return true
		}
		known[page] = true
	}

	for _, page := range rhs {
		if page == domain.DeveloperAccessPageAll || known[page] {
			return true
		}
	}
	return false
}
// unionDeveloperAccessPages merges the access pages of the given requests
// into a single list.
//
// When one or more non-blank statuses are supplied, only requests whose
// trimmed Status matches one of them contribute; with no usable status every
// request contributes. Each request's pages are normalized before merging.
//
// The result is nil when no page was collected, a single
// domain.DeveloperAccessPageAll entry when any contributing request covers
// all pages, and otherwise the collected pages in the order given by
// domain.DeveloperAccessPageOrder.
func unionDeveloperAccessPages(requests []domain.DeveloperRequest, statuses ...string) []string {
	wanted := make(map[string]bool, len(statuses))
	for _, candidate := range statuses {
		candidate = strings.TrimSpace(candidate)
		if candidate == "" {
			continue
		}
		wanted[candidate] = true
	}
	filtering := len(wanted) > 0

	collected := make(map[string]bool)
	for i := range requests {
		if filtering && !wanted[strings.TrimSpace(requests[i].Status)] {
			continue
		}
		for _, page := range normalizeDeveloperAccessPages(requests[i].AccessPages) {
			collected[page] = true
		}
	}

	switch {
	case len(collected) == 0:
		return nil
	case collected[domain.DeveloperAccessPageAll]:
		// The catch-all entry makes the individual pages redundant.
		return []string{domain.DeveloperAccessPageAll}
	}

	ordered := make([]string, 0, len(collected))
	for _, page := range domain.DeveloperAccessPageOrder {
		if collected[page] {
			ordered = append(ordered, page)
		}
	}
	return ordered
}
// RequestAccess stores a new developer access request unless the same user
// and tenant already have a pending request covering an overlapping set of
// pages.
//
// The request's AccessPages are normalized before anything else. If an
// overlapping pending request exists the call is a silent no-op and returns
// nil; otherwise req is inserted. Any database error is returned unchanged.
//
// NOTE(review): the lookup and the insert are separate statements, so two
// concurrent calls could both pass the check.
func (s *DeveloperService) RequestAccess(ctx context.Context, req domain.DeveloperRequest) error {
	req.AccessPages = normalizeDeveloperAccessPages(req.AccessPages)

	// Load this user's pending requests for the tenant.
	var pending []domain.DeveloperRequest
	lookup := s.db.WithContext(ctx).
		Where("user_id = ? AND tenant_id = ? AND status = ?", req.UserID, req.TenantID, domain.DeveloperRequestStatusPending).
		Order("created_at DESC")
	if err := lookup.Find(&pending).Error; err != nil {
		return err
	}

	for i := range pending {
		if accessPagesOverlap(pending[i].AccessPages, req.AccessPages) {
			// Already requested: treat the duplicate as success.
			return nil
		}
	}

	return s.db.WithContext(ctx).Create(&req).Error
}
// CreateGrant inserts req after normalizing its AccessPages. Unlike
// RequestAccess it performs no duplicate check, and it stores req with
// whatever Status the caller set.
func (s *DeveloperService) CreateGrant(ctx context.Context, req domain.DeveloperRequest) error {
	req.AccessPages = normalizeDeveloperAccessPages(req.AccessPages)
	inserted := s.db.WithContext(ctx).Create(&req)
	return inserted.Error
}
// GetRequestStatus summarizes a user's developer access within a tenant.
//
// All of the user's requests for the tenant are loaded and their pages are
// merged separately for approved and pending requests. The reported Status
// is the approved status when any approved pages exist, otherwise the pending
// status when any pending pages exist, otherwise "none" (which is also the
// result when the user has no requests at all). A database error is returned
// with a nil status.
func (s *DeveloperService) GetRequestStatus(ctx context.Context, userID, tenantID string) (*domain.DeveloperAccessStatus, error) {
	var history []domain.DeveloperRequest
	lookup := s.db.WithContext(ctx).
		Where("user_id = ? AND tenant_id = ?", userID, tenantID).
		Order("created_at DESC")
	if err := lookup.Find(&history).Error; err != nil {
		return nil, err
	}

	summary := &domain.DeveloperAccessStatus{Status: "none"}
	if len(history) == 0 {
		return summary, nil
	}

	summary.ApprovedPages = unionDeveloperAccessPages(history, domain.DeveloperRequestStatusApproved)
	summary.PendingPages = unionDeveloperAccessPages(history, domain.DeveloperRequestStatusPending)

	// Approved access takes precedence over a pending request.
	if len(summary.ApprovedPages) > 0 {
		summary.Status = domain.DeveloperRequestStatusApproved
	} else if len(summary.PendingPages) > 0 {
		summary.Status = domain.DeveloperRequestStatusPending
	}
	return summary, nil
}
// GetRequestByID loads the developer request with the given primary key.
// Any error from the lookup is returned unchanged together with a nil
// request.
func (s *DeveloperService) GetRequestByID(ctx context.Context, id uint) (*domain.DeveloperRequest, error) {
	found := new(domain.DeveloperRequest)
	if err := s.db.WithContext(ctx).First(found, id).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// ListRequests returns developer requests, newest first, optionally narrowed
// by user, status and tenant. An empty string disables the corresponding
// filter. The slice is returned alongside any query error.
func (s *DeveloperService) ListRequests(ctx context.Context, userID, status, tenantID string) ([]domain.DeveloperRequest, error) {
	// Filters are applied in this order; blank values are skipped.
	filters := [...]struct {
		clause string
		value  string
	}{
		{"user_id = ?", userID},
		{"status = ?", status},
		{"tenant_id = ?", tenantID},
	}

	query := s.db.WithContext(ctx)
	for _, f := range filters {
		if f.value == "" {
			continue
		}
		query = query.Where(f.clause, f.value)
	}

	var found []domain.DeveloperRequest
	err := query.Order("created_at DESC").Find(&found).Error
	return found, err
}
// ApproveRequest sets the status of the request with the given id to
// approved and records adminNotes. Only the database error, if any, is
// returned; the method does not itself check that the row exists or what its
// current status is.
func (s *DeveloperService) ApproveRequest(ctx context.Context, id uint, adminNotes string) error {
	changes := map[string]any{
		"status":      domain.DeveloperRequestStatusApproved,
		"admin_notes": adminNotes,
	}
	target := s.db.WithContext(ctx).Model(&domain.DeveloperRequest{}).Where("id = ?", id)
	return target.Updates(changes).Error
}
// RejectRequest sets the status of the request with the given id to rejected
// and records adminNotes. Only the database error, if any, is returned; the
// method does not itself check that the row exists or what its current
// status is.
func (s *DeveloperService) RejectRequest(ctx context.Context, id uint, adminNotes string) error {
	changes := map[string]any{
		"status":      domain.DeveloperRequestStatusRejected,
		"admin_notes": adminNotes,
	}
	target := s.db.WithContext(ctx).Model(&domain.DeveloperRequest{}).Where("id = ?", id)
	return target.Updates(changes).Error
}
// CancelApprovedRequest sets the status of the request with the given id to
// cancelled and records adminNotes. Despite the name, the update is keyed on
// id alone: the method does not verify that the request is currently
// approved. Only the database error, if any, is returned.
func (s *DeveloperService) CancelApprovedRequest(ctx context.Context, id uint, adminNotes string) error {
	changes := map[string]any{
		"status":      domain.DeveloperRequestStatusCancelled,
		"admin_notes": adminNotes,
	}
	target := s.db.WithContext(ctx).Model(&domain.DeveloperRequest{}).Where("id = ?", id)
	return target.Updates(changes).Error
}