package service import ( "context" "crypto/rand" "crypto/sha256" "encoding/hex" "errors" "fmt" "strings" "time" "gorm.io/gorm" "gorm.io/gorm/clause" "wx_service/internal/membership/model" usermodel "wx_service/internal/model" ) var ( ErrAdminTokenRequired = errors.New("admin api token is not configured") ErrInvalidAdminToken = errors.New("invalid admin token") ErrRedeemCodeInvalid = errors.New("redeem code is invalid") ErrRedeemCodeExpired = errors.New("redeem code is expired") ErrRedeemCodeUsedUp = errors.New("redeem code is already used") ErrRedeemCodeDisabled = errors.New("redeem code is disabled") ) type RedeemCodeService struct { db *gorm.DB adminToken string } func NewRedeemCodeService(db *gorm.DB, adminToken string) *RedeemCodeService { return &RedeemCodeService{ db: db, adminToken: adminToken, } } type GenerateRedeemCodesRequest struct { Count int Plan string DurationDays int ExpiresAt *time.Time MaxUses int } type GeneratedRedeemCode struct { Code string `json:"code"` Plan string `json:"plan"` } func (s *RedeemCodeService) ValidateAdminToken(token string) error { if s.adminToken == "" { return ErrAdminTokenRequired } if token == "" || token != s.adminToken { return ErrInvalidAdminToken } return nil } func (s *RedeemCodeService) Generate(ctx context.Context, req GenerateRedeemCodesRequest) ([]GeneratedRedeemCode, error) { count := req.Count if count <= 0 { count = 1 } if count > 500 { count = 500 } plan := strings.TrimSpace(req.Plan) if plan == "" { plan = "default" } if req.DurationDays <= 0 { return nil, fmt.Errorf("duration_days must be > 0") } maxUses := req.MaxUses if maxUses <= 0 { maxUses = 1 } results := make([]GeneratedRedeemCode, 0, count) records := make([]model.MembershipRedeemCode, 0, count) // 生成时尽量保证唯一性;如遇碰撞(极低概率)则重试。 for len(records) < count { code, err := generateCode(20) if err != nil { return nil, err } hash := hashCode(code) suffix := suffixOf(code, 6) records = append(records, model.MembershipRedeemCode{ CodeHash: hash, CodeSuffix: suffix, Plan: plan, DurationDays: req.DurationDays, ExpiresAt: req.ExpiresAt, MaxUses: maxUses, UsedUses: 0, Status: "active", }) results = append(results, GeneratedRedeemCode{Code: code, Plan: plan}) } if err := s.db.WithContext(ctx).Create(&records).Error; err != nil { return nil, fmt.Errorf("save redeem codes: %w", err) } return results, nil } type RedeemResult struct { Plan string `json:"plan"` StartsAt time.Time `json:"starts_at"` EndsAt time.Time `json:"ends_at"` Extended bool `json:"extended"` CodeSuffix string `json:"code_suffix"` } func (s *RedeemCodeService) Redeem(ctx context.Context, user *usermodel.User, code string, clientIP string, userAgent string) (*RedeemResult, error) { normalized := normalizeCode(code) if normalized == "" { return nil, ErrRedeemCodeInvalid } hash := hashCode(normalized) now := time.Now() var result *RedeemResult err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { var redeemCode model.MembershipRedeemCode if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). Where("code_hash = ? AND deleted_at IS NULL", hash). First(&redeemCode).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrRedeemCodeInvalid } return fmt.Errorf("load redeem code: %w", err) } if redeemCode.Status != "" && redeemCode.Status != "active" { return ErrRedeemCodeDisabled } if redeemCode.ExpiresAt != nil && redeemCode.ExpiresAt.Before(now) { return ErrRedeemCodeExpired } if redeemCode.MaxUses <= 0 { redeemCode.MaxUses = 1 } if redeemCode.UsedUses >= redeemCode.MaxUses { return ErrRedeemCodeUsedUp } if redeemCode.DurationDays <= 0 { return fmt.Errorf("redeem code duration_days invalid") } // 兑换:先创建/延长会员,再计数,最后写 redemption log。 var membership usermodel.UserMembership var hasActive bool if err := tx. Where("mini_program_id = ? AND user_id = ? AND status = ? AND ends_at > ?", user.MiniProgramID, user.ID, "active", now). Order("ends_at DESC"). First(&membership).Error; err == nil { hasActive = true } else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return fmt.Errorf("load membership: %w", err) } base := now if hasActive && membership.EndsAt.After(now) { base = membership.EndsAt } newEnds := base.AddDate(0, 0, redeemCode.DurationDays) if hasActive { if err := tx.Model(&membership).Updates(map[string]interface{}{ "ends_at": newEnds, "updated_at": now, }).Error; err != nil { return fmt.Errorf("extend membership: %w", err) } } else { membership = usermodel.UserMembership{ MiniProgramID: user.MiniProgramID, UserID: user.ID, Plan: redeemCode.Plan, Status: "active", StartsAt: now, EndsAt: newEnds, } if err := tx.Create(&membership).Error; err != nil { return fmt.Errorf("create membership: %w", err) } } if err := tx.Model(&redeemCode).UpdateColumn("used_uses", gorm.Expr("used_uses + 1")).Error; err != nil { return fmt.Errorf("update redeem usage: %w", err) } redemption := model.MembershipRedemption{ MiniProgramID: user.MiniProgramID, UserID: user.ID, RedeemCodeID: redeemCode.ID, CodeSuffix: redeemCode.CodeSuffix, MembershipID: membership.ID, ClientIP: truncateString(clientIP, 64), UserAgent: truncateString(userAgent, 255), } if err := tx.Create(&redemption).Error; err != nil { return fmt.Errorf("create redemption log: %w", err) } result = &RedeemResult{ Plan: redeemCode.Plan, StartsAt: membership.StartsAt, EndsAt: newEnds, Extended: hasActive, CodeSuffix: redeemCode.CodeSuffix, } return nil }) if err != nil { return nil, err } return result, nil } func normalizeCode(code string) string { c := strings.TrimSpace(code) c = strings.ReplaceAll(c, "-", "") c = strings.ReplaceAll(c, " ", "") c = strings.ToUpper(c) return c } func hashCode(code string) string { sum := sha256.Sum256([]byte(code)) return hex.EncodeToString(sum[:]) } func suffixOf(code string, n int) string { if n <= 0 { return "" } if len(code) <= n { return code } return code[len(code)-n:] } func generateCode(length int) (string, error) { if length <= 0 { length = 16 } // 去掉易混淆字符:0/O, 1/I/L const alphabet = "ABCDEFGHJKMNPQRSTUVWXYZ23456789" buf := make([]byte, length) if _, err := rand.Read(buf); err != nil { return "", fmt.Errorf("rand: %w", err) } out := make([]byte, length) for i, b := range buf { out[i] = alphabet[int(b)%len(alphabet)] } return string(out), nil } func truncateString(s string, max int) string { if max <= 0 || len(s) <= max { return s } return s[:max] }