diff --git a/go.mod b/go.mod index d7bdb04..d8960dc 100755 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect + github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect github.com/bytedance/sonic v1.14.0 // indirect github.com/bytedance/sonic/loader v0.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect diff --git a/go.sum b/go.sum index 32f787c..0cce27b 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -48,6 +50,7 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= diff --git a/internal/smoke/service/membership_gate.go b/internal/smoke/service/membership_gate.go new file mode 100644 index 0000000..759b51d --- /dev/null +++ b/internal/smoke/service/membership_gate.go @@ -0,0 +1,23 @@ +package service + +import ( + "context" + "fmt" + "time" + + usermodel "wx_service/internal/model" + + "gorm.io/gorm" +) + +func hasActiveMembership(ctx context.Context, db *gorm.DB, miniProgramID uint, userID uint, now time.Time) (bool, error) { + var count int64 + if err := db.WithContext(ctx). + Model(&usermodel.UserMembership{}). + Where("mini_program_id = ? AND user_id = ? AND status = ? AND ends_at > ?", + miniProgramID, userID, "active", now). + Count(&count).Error; err != nil { + return false, fmt.Errorf("check membership: %w", err) + } + return count > 0, nil +} diff --git a/internal/smoke/service/membership_gate_test.go b/internal/smoke/service/membership_gate_test.go new file mode 100644 index 0000000..f68522a --- /dev/null +++ b/internal/smoke/service/membership_gate_test.go @@ -0,0 +1,104 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +func newMockGormDB(t *testing.T) (*gorm.DB, sqlmock.Sqlmock, func()) { + t.Helper() + + sqlDB, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("sqlmock.New: %v", err) + } + + gdb, err := gorm.Open(mysql.New(mysql.Config{ + Conn: sqlDB, + SkipInitializeWithVersion: true, + }), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + _ = sqlDB.Close() + t.Fatalf("gorm.Open: %v", err) + } + + cleanup := func() { + _ = sqlDB.Close() + } + return gdb, mock, cleanup +} + +func TestHasActiveMembership(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 2, 28, 16, 0, 0, 0, time.Local) + db, mock, cleanup := newMockGormDB(t) + defer cleanup() + + mock.ExpectQuery("SELECT count\\(\\*\\) FROM `user_memberships`"). + WithArgs(uint(100), uint(200), "active", now). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + + ok, err := hasActiveMembership(context.Background(), db, 100, 200, now) + if err != nil { + t.Fatalf("hasActiveMembership: %v", err) + } + if !ok { + t.Fatalf("hasActiveMembership got=false, want=true") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} + +func TestHasActiveMembershipNotFound(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 2, 28, 16, 0, 0, 0, time.Local) + db, mock, cleanup := newMockGormDB(t) + defer cleanup() + + mock.ExpectQuery("SELECT count\\(\\*\\) FROM `user_memberships`"). + WithArgs(uint(101), uint(201), "active", now). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + + ok, err := hasActiveMembership(context.Background(), db, 101, 201, now) + if err != nil { + t.Fatalf("hasActiveMembership: %v", err) + } + if ok { + t.Fatalf("hasActiveMembership got=true, want=false") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} + +func TestHasActiveMembershipDBError(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 2, 28, 16, 0, 0, 0, time.Local) + db, mock, cleanup := newMockGormDB(t) + defer cleanup() + + mock.ExpectQuery("SELECT count\\(\\*\\) FROM `user_memberships`"). + WithArgs(uint(102), uint(202), "active", now). + WillReturnError(errors.New("db unavailable")) + + _, err := hasActiveMembership(context.Background(), db, 102, 202, now) + if err == nil { + t.Fatalf("expected error when query fails") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} diff --git a/internal/smoke/service/smoke_ai_advice_service.go b/internal/smoke/service/smoke_ai_advice_service.go index d37bb27..38c4ab0 100644 --- a/internal/smoke/service/smoke_ai_advice_service.go +++ b/internal/smoke/service/smoke_ai_advice_service.go @@ -189,7 +189,7 @@ func (s *SmokeAIAdviceService) getCached(ctx context.Context, uid int, adviceTyp } func (s *SmokeAIAdviceService) isAllowed(ctx context.Context, user *usermodel.User, adviceDate time.Time) (bool, error) { - isVIP, err := s.isVIP(ctx, user) + isVIP, err := hasActiveMembership(ctx, s.db, user.MiniProgramID, user.ID, time.Now()) if err != nil { return false, err } @@ -199,19 +199,6 @@ func (s *SmokeAIAdviceService) isAllowed(ctx context.Context, user *usermodel.Us return s.isUnlocked(ctx, int(user.ID), adviceDate) } -func (s *SmokeAIAdviceService) isVIP(ctx context.Context, user *usermodel.User) (bool, error) { - now := time.Now() - var count int64 - if err := s.db.WithContext(ctx). - Model(&usermodel.UserMembership{}). - Where("mini_program_id = ? AND user_id = ? AND status = ? AND ends_at > ?", - user.MiniProgramID, user.ID, "active", now). - Count(&count).Error; err != nil { - return false, fmt.Errorf("check vip: %w", err) - } - return count > 0, nil -} - func (s *SmokeAIAdviceService) isUnlocked(ctx context.Context, uid int, adviceDate time.Time) (bool, error) { startOfDay := dateOnly(adviceDate) var unlock smokemodel.SmokeAIAdviceUnlock diff --git a/internal/smoke/service/smoke_ai_next_smoke_service.go b/internal/smoke/service/smoke_ai_next_smoke_service.go index a3d626b..5f85add 100644 --- a/internal/smoke/service/smoke_ai_next_smoke_service.go +++ b/internal/smoke/service/smoke_ai_next_smoke_service.go @@ -538,7 +538,7 @@ func (s *SmokeAINextSmokeService) loadRecent3Days(ctx context.Context, uid int, } func (s *SmokeAINextSmokeService) isAllowed(ctx context.Context, user *usermodel.User, planDate time.Time) (bool, error) { - isVIP, err := s.isVIP(ctx, user) + isVIP, err := hasActiveMembership(ctx, s.db, user.MiniProgramID, user.ID, time.Now()) if err != nil { return false, err } @@ -548,19 +548,6 @@ func (s *SmokeAINextSmokeService) isAllowed(ctx context.Context, user *usermodel return s.isUnlocked(ctx, int(user.ID), planDate) } -func (s *SmokeAINextSmokeService) isVIP(ctx context.Context, user *usermodel.User) (bool, error) { - now := time.Now() - var count int64 - if err := s.db.WithContext(ctx). - Model(&usermodel.UserMembership{}). - Where("mini_program_id = ? AND user_id = ? AND status = ? AND ends_at > ?", - user.MiniProgramID, user.ID, "active", now). - Count(&count).Error; err != nil { - return false, fmt.Errorf("check vip: %w", err) - } - return count > 0, nil -} - func (s *SmokeAINextSmokeService) isUnlocked(ctx context.Context, uid int, planDate time.Time) (bool, error) { startOfDay := dateOnly(planDate) var unlock smokemodel.SmokeAIAdviceUnlock diff --git a/internal/smoke/service/smoke_membership_gate_test.go b/internal/smoke/service/smoke_membership_gate_test.go new file mode 100644 index 0000000..de26245 --- /dev/null +++ b/internal/smoke/service/smoke_membership_gate_test.go @@ -0,0 +1,96 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + + usermodel "wx_service/internal/model" +) + +func TestSmokeAIAdviceServiceIsAllowedMember(t *testing.T) { + t.Parallel() + + db, mock, cleanup := newMockGormDB(t) + defer cleanup() + + svc := &SmokeAIAdviceService{db: db} + user := &usermodel.User{ID: 200, MiniProgramID: 100} + adviceDate := time.Date(2026, 3, 1, 0, 0, 0, 0, time.Local) + + mock.ExpectQuery("SELECT count\\(\\*\\) FROM `user_memberships`"). + WithArgs(uint(100), uint(200), "active", sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + + allowed, err := svc.isAllowed(context.Background(), user, adviceDate) + if err != nil { + t.Fatalf("isAllowed: %v", err) + } + if !allowed { + t.Fatalf("isAllowed got=false, want=true for active member") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} + +func TestSmokeAIAdviceServiceIsAllowedNonMemberLocked(t *testing.T) { + t.Parallel() + + db, mock, cleanup := newMockGormDB(t) + defer cleanup() + + svc := &SmokeAIAdviceService{db: db} + user := &usermodel.User{ID: 201, MiniProgramID: 101} + adviceDate := time.Date(2026, 3, 1, 0, 0, 0, 0, time.Local) + + mock.ExpectQuery("SELECT count\\(\\*\\) FROM `user_memberships`"). + WithArgs(uint(101), uint(201), "active", sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + mock.ExpectQuery("SELECT \\* FROM `fa_smoke_ai_advice_unlocks`"). + WithArgs(201, adviceDate.Format("2006-01-02"), sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"id", "uid", "unlock_date", "ad_watched_at"})) + + allowed, err := svc.isAllowed(context.Background(), user, adviceDate) + if err != nil { + t.Fatalf("isAllowed: %v", err) + } + if allowed { + t.Fatalf("isAllowed got=true, want=false for non-member locked user") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} + +func TestSmokeAINextSmokeServiceIsAllowedNonMemberUnlocked(t *testing.T) { + t.Parallel() + + db, mock, cleanup := newMockGormDB(t) + defer cleanup() + + svc := &SmokeAINextSmokeService{db: db} + user := &usermodel.User{ID: 202, MiniProgramID: 102} + planDate := time.Date(2026, 3, 1, 0, 0, 0, 0, time.Local) + + mock.ExpectQuery("SELECT count\\(\\*\\) FROM `user_memberships`"). + WithArgs(uint(102), uint(202), "active", sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + mock.ExpectQuery("SELECT \\* FROM `fa_smoke_ai_advice_unlocks`"). + WithArgs(202, planDate.Format("2006-01-02"), sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"id", "uid", "unlock_date", "ad_watched_at"}). + AddRow(1, 202, planDate, time.Now())) + + allowed, err := svc.isAllowed(context.Background(), user, planDate) + if err != nil { + t.Fatalf("isAllowed: %v", err) + } + if !allowed { + t.Fatalf("isAllowed got=false, want=true for non-member unlocked user") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +}