forked from baron/baron-sso
240 lines
6.4 KiB
Go
240 lines
6.4 KiB
Go
package service
|
|
|
|
import (
|
|
"baron-sso-backend/internal/domain"
|
|
"context"
|
|
"sort"
|
|
"strings"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type DeveloperService struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewDeveloperService(db *gorm.DB) *DeveloperService {
|
|
return &DeveloperService{db: db}
|
|
}
|
|
|
|
func normalizeDeveloperAccessPages(pages []string) []string {
|
|
seen := make(map[string]struct{})
|
|
normalized := make([]string, 0, len(pages))
|
|
|
|
add := func(page string) {
|
|
page = strings.ToLower(strings.TrimSpace(page))
|
|
if page == "" {
|
|
return
|
|
}
|
|
if page == domain.DeveloperAccessPageAll {
|
|
normalized = []string{domain.DeveloperAccessPageAll}
|
|
seen = map[string]struct{}{domain.DeveloperAccessPageAll: struct{}{}}
|
|
return
|
|
}
|
|
if page != domain.DeveloperAccessPageOverview &&
|
|
page != domain.DeveloperAccessPageClientCreate &&
|
|
page != domain.DeveloperAccessPageAudit {
|
|
return
|
|
}
|
|
if _, exists := seen[page]; exists {
|
|
return
|
|
}
|
|
seen[page] = struct{}{}
|
|
normalized = append(normalized, page)
|
|
}
|
|
|
|
for _, page := range pages {
|
|
add(page)
|
|
if len(normalized) == 1 && normalized[0] == domain.DeveloperAccessPageAll {
|
|
return normalized
|
|
}
|
|
}
|
|
|
|
if len(normalized) == 0 {
|
|
return []string{domain.DeveloperAccessPageAll}
|
|
}
|
|
|
|
sort.SliceStable(normalized, func(i, j int) bool {
|
|
return accessPageSortIndex(normalized[i]) < accessPageSortIndex(normalized[j])
|
|
})
|
|
|
|
return normalized
|
|
}
|
|
|
|
func accessPageSortIndex(page string) int {
|
|
switch page {
|
|
case domain.DeveloperAccessPageOverview:
|
|
return 0
|
|
case domain.DeveloperAccessPageClientCreate:
|
|
return 1
|
|
case domain.DeveloperAccessPageAudit:
|
|
return 2
|
|
default:
|
|
return 99
|
|
}
|
|
}
|
|
|
|
func accessPagesOverlap(left, right []string) bool {
|
|
if len(left) == 0 || len(right) == 0 {
|
|
return false
|
|
}
|
|
|
|
leftSet := make(map[string]struct{}, len(left))
|
|
for _, page := range normalizeDeveloperAccessPages(left) {
|
|
if page == domain.DeveloperAccessPageAll {
|
|
return true
|
|
}
|
|
leftSet[page] = struct{}{}
|
|
}
|
|
|
|
for _, page := range normalizeDeveloperAccessPages(right) {
|
|
if page == domain.DeveloperAccessPageAll {
|
|
return true
|
|
}
|
|
if _, ok := leftSet[page]; ok {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func unionDeveloperAccessPages(requests []domain.DeveloperRequest, statuses ...string) []string {
|
|
statusSet := make(map[string]struct{}, len(statuses))
|
|
for _, status := range statuses {
|
|
if trimmed := strings.TrimSpace(status); trimmed != "" {
|
|
statusSet[trimmed] = struct{}{}
|
|
}
|
|
}
|
|
|
|
acc := make(map[string]struct{})
|
|
for _, req := range requests {
|
|
if len(statusSet) > 0 {
|
|
if _, ok := statusSet[strings.TrimSpace(req.Status)]; !ok {
|
|
continue
|
|
}
|
|
}
|
|
pages := normalizeDeveloperAccessPages(req.AccessPages)
|
|
for _, page := range pages {
|
|
acc[page] = struct{}{}
|
|
}
|
|
}
|
|
|
|
if len(acc) == 0 {
|
|
return nil
|
|
}
|
|
|
|
result := make([]string, 0, len(acc))
|
|
if _, ok := acc[domain.DeveloperAccessPageAll]; ok {
|
|
return []string{domain.DeveloperAccessPageAll}
|
|
}
|
|
for _, page := range domain.DeveloperAccessPageOrder {
|
|
if _, ok := acc[page]; ok {
|
|
result = append(result, page)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (s *DeveloperService) RequestAccess(ctx context.Context, req domain.DeveloperRequest) error {
|
|
req.AccessPages = normalizeDeveloperAccessPages(req.AccessPages)
|
|
// Check if there is already a pending request
|
|
var existing []domain.DeveloperRequest
|
|
err := s.db.WithContext(ctx).
|
|
Where("user_id = ? AND tenant_id = ? AND status = ?", req.UserID, req.TenantID, domain.DeveloperRequestStatusPending).
|
|
Order("created_at DESC").
|
|
Find(&existing).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, current := range existing {
|
|
if accessPagesOverlap(current.AccessPages, req.AccessPages) {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return s.db.WithContext(ctx).Create(&req).Error
|
|
}
|
|
|
|
func (s *DeveloperService) CreateGrant(ctx context.Context, req domain.DeveloperRequest) error {
|
|
req.AccessPages = normalizeDeveloperAccessPages(req.AccessPages)
|
|
return s.db.WithContext(ctx).Create(&req).Error
|
|
}
|
|
|
|
func (s *DeveloperService) GetRequestStatus(ctx context.Context, userID, tenantID string) (*domain.DeveloperAccessStatus, error) {
|
|
var requests []domain.DeveloperRequest
|
|
err := s.db.WithContext(ctx).
|
|
Where("user_id = ? AND tenant_id = ?", userID, tenantID).
|
|
Order("created_at DESC").
|
|
Find(&requests).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(requests) == 0 {
|
|
return &domain.DeveloperAccessStatus{Status: "none"}, nil
|
|
}
|
|
|
|
approvedPages := unionDeveloperAccessPages(requests, domain.DeveloperRequestStatusApproved)
|
|
pendingPages := unionDeveloperAccessPages(requests, domain.DeveloperRequestStatusPending)
|
|
|
|
status := "none"
|
|
switch {
|
|
case len(approvedPages) > 0:
|
|
status = domain.DeveloperRequestStatusApproved
|
|
case len(pendingPages) > 0:
|
|
status = domain.DeveloperRequestStatusPending
|
|
}
|
|
|
|
return &domain.DeveloperAccessStatus{
|
|
Status: status,
|
|
ApprovedPages: approvedPages,
|
|
PendingPages: pendingPages,
|
|
}, nil
|
|
}
|
|
|
|
func (s *DeveloperService) GetRequestByID(ctx context.Context, id uint) (*domain.DeveloperRequest, error) {
|
|
var req domain.DeveloperRequest
|
|
err := s.db.WithContext(ctx).First(&req, id).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &req, nil
|
|
}
|
|
|
|
func (s *DeveloperService) ListRequests(ctx context.Context, userID, status, tenantID string) ([]domain.DeveloperRequest, error) {
|
|
var requests []domain.DeveloperRequest
|
|
query := s.db.WithContext(ctx)
|
|
if userID != "" {
|
|
query = query.Where("user_id = ?", userID)
|
|
}
|
|
if status != "" {
|
|
query = query.Where("status = ?", status)
|
|
}
|
|
if tenantID != "" {
|
|
query = query.Where("tenant_id = ?", tenantID)
|
|
}
|
|
err := query.Order("created_at DESC").Find(&requests).Error
|
|
return requests, err
|
|
}
|
|
|
|
func (s *DeveloperService) ApproveRequest(ctx context.Context, id uint, adminNotes string) error {
|
|
return s.db.WithContext(ctx).Model(&domain.DeveloperRequest{}).Where("id = ?", id).Updates(map[string]any{
|
|
"status": domain.DeveloperRequestStatusApproved,
|
|
"admin_notes": adminNotes,
|
|
}).Error
|
|
}
|
|
|
|
func (s *DeveloperService) RejectRequest(ctx context.Context, id uint, adminNotes string) error {
|
|
return s.db.WithContext(ctx).Model(&domain.DeveloperRequest{}).Where("id = ?", id).Updates(map[string]any{
|
|
"status": domain.DeveloperRequestStatusRejected,
|
|
"admin_notes": adminNotes,
|
|
}).Error
|
|
}
|
|
|
|
func (s *DeveloperService) CancelApprovedRequest(ctx context.Context, id uint, adminNotes string) error {
|
|
return s.db.WithContext(ctx).Model(&domain.DeveloperRequest{}).Where("id = ?", id).Updates(map[string]any{
|
|
"status": domain.DeveloperRequestStatusCancelled,
|
|
"admin_notes": adminNotes,
|
|
}).Error
|
|
}
|