package service

import (
	"errors"
	"testing"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"

	"wx_service/internal/marketing/model"
	"wx_service/internal/marketing/repository"
)

// newDownloadTestService builds a DownloadService backed by an in-memory
// SQLite database with the marketing category, template and download tables
// migrated. It returns the service together with the raw *gorm.DB so tests
// can seed and inspect rows directly.
//
// Each test gets its own named in-memory database (keyed by t.Name()). A bare
// "file::memory:?cache=shared" DSN would give every test in the process the
// same shared-cache database for as long as any connection to it stays open,
// so rows seeded by one test would be visible to the next. The underlying
// *sql.DB is closed via t.Cleanup, which also releases the memory database.
func newDownloadTestService(t *testing.T) (*DownloadService, *gorm.DB) {
	t.Helper()

	// cache=shared keeps every pooled connection pointed at the same
	// in-memory database; without it each new connection opened by
	// database/sql would see a fresh, empty database with no tables.
	dsn := "file:" + t.Name() + "?mode=memory&cache=shared"
	db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
	if err != nil {
		t.Fatalf("open sqlite: %v", err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("get sql.DB: %v", err)
	}
	t.Cleanup(func() {
		// Closing the last connection discards the named in-memory database.
		if err := sqlDB.Close(); err != nil {
			t.Errorf("close sqlite: %v", err)
		}
	})

	if err := db.AutoMigrate(
		&model.MarketingCategory{},
		&model.MarketingTemplate{},
		&model.MarketingDownload{},
	); err != nil {
		t.Fatalf("auto migrate: %v", err)
	}

	downloadRepo := repository.NewDownloadRepository(db)
	templateRepo := repository.NewTemplateRepository(db)
	svc := NewDownloadService(downloadRepo, templateRepo)
	return svc, db
}

// seedTemplate inserts a fresh category plus one template in that category
// with the given status, and returns the created template.
func seedTemplate(t *testing.T, db *gorm.DB, status int) *model.MarketingTemplate {
	t.Helper()

	cat := &model.MarketingCategory{
		Name:      "测试分类",
		SortOrder: 1,
		Status:    1,
	}
	if err := db.Create(cat).Error; err != nil {
		t.Fatalf("create category: %v", err)
	}

	tpl := &model.MarketingTemplate{
		CategoryID:   cat.ID,
		Title:        "测试模板",
		ImageURL:     "https://example.com/image.jpg",
		ThumbnailURL: "https://example.com/thumb.jpg",
		Width:        1080,
		Height:       1920,
		SortOrder:    1,
		Status:       status,
	}
	if err := db.Create(tpl).Error; err != nil {
		t.Fatalf("create template: %v", err)
	}

	// gorm's Create omits zero-valued fields that have a column default, so a
	// requested status of 0 may not be persisted by the insert above
	// (presumably Status carries a default tag on the model — confirm).
	// A single-column Update does write zero values, forcing status=0.
	if status == 0 {
		if err := db.Model(&model.MarketingTemplate{}).
			Where("id = ?", tpl.ID).
			Update("status", 0).Error; err != nil {
			t.Fatalf("update template status: %v", err)
		}
	}
	return tpl
}

// TestDownloadService_Create_DisabledTemplate checks that creating a download
// for a template with status 0 fails with ErrDownloadTemplateDisabled and
// leaves no download record behind.
func TestDownloadService_Create_DisabledTemplate(t *testing.T) {
	svc, db := newDownloadTestService(t)
	tpl := seedTemplate(t, db, 0)

	_, err := svc.Create(100, CreateDownloadRequest{
		TemplateID: tpl.ID,
		LogoURL:    "https://example.com/logo.png",
		LogoX:      0.2,
		LogoY:      0.3,
		LogoW:      0.4,
		LogoH:      0.5,
	})
	if !errors.Is(err, ErrDownloadTemplateDisabled) {
		t.Fatalf("expected ErrDownloadTemplateDisabled, got %v", err)
	}

	// The database is private to this test, so any row here came from the
	// rejected Create call.
	var count int64
	if err := db.Model(&model.MarketingDownload{}).Count(&count).Error; err != nil {
		t.Fatalf("count downloads: %v", err)
	}
	if count != 0 {
		t.Fatalf("expected no download records for disabled template, got %d", count)
	}
}

// TestDownloadService_Create_TemplateNotFound checks that an unknown template
// ID surfaces repository.ErrTemplateNotFound.
func TestDownloadService_Create_TemplateNotFound(t *testing.T) {
	svc, _ := newDownloadTestService(t)

	_, err := svc.Create(100, CreateDownloadRequest{
		TemplateID: 9999,
		LogoURL:    "https://example.com/logo.png",
	})
	if !errors.Is(err, repository.ErrTemplateNotFound) {
		t.Fatalf("expected ErrTemplateNotFound, got %v", err)
	}
}

// TestDownloadService_Create_Success checks that creating a download for an
// enabled template persists the download record and increments the
// template's download_count to 1.
func TestDownloadService_Create_Success(t *testing.T) {
	svc, db := newDownloadTestService(t)
	tpl := seedTemplate(t, db, 1)

	record, err := svc.Create(101, CreateDownloadRequest{
		TemplateID: tpl.ID,
		LogoURL:    "https://example.com/logo.png",
		LogoX:      0.12,
		LogoY:      0.34,
		LogoW:      0.56,
		LogoH:      0.78,
	})
	if err != nil {
		t.Fatalf("create download: %v", err)
	}
	if record == nil {
		t.Fatalf("expected non-nil download record")
	}
	if record.ID == 0 {
		t.Fatalf("expected created download id > 0")
	}

	var storedTpl model.MarketingTemplate
	if err := db.First(&storedTpl, tpl.ID).Error; err != nil {
		t.Fatalf("reload template: %v", err)
	}
	if storedTpl.DownloadCount != 1 {
		t.Fatalf("expected download_count=1, got %d", storedTpl.DownloadCount)
	}

	var storedDl model.MarketingDownload
	if err := db.First(&storedDl, record.ID).Error; err != nil {
		t.Fatalf("reload download: %v", err)
	}
	if storedDl.LogoURL != "https://example.com/logo.png" {
		t.Fatalf("unexpected logo url: %s", storedDl.LogoURL)
	}
}