diff --git a/internal/marketing/handler/download_handler.go b/internal/marketing/handler/download_handler.go index 410ef95..80e40cf 100644 --- a/internal/marketing/handler/download_handler.go +++ b/internal/marketing/handler/download_handler.go @@ -29,6 +29,10 @@ func (h *DownloadHandler) Create(c *gin.Context) { dl, err := h.svc.Create(user.ID, req) if err != nil { + if service.IsNotFoundError(err) { + c.JSON(http.StatusNotFound, model.Error(http.StatusNotFound, "模板不存在")) + return + } if service.IsBadRequestError(err) { c.JSON(http.StatusBadRequest, model.Error(http.StatusBadRequest, err.Error())) return diff --git a/internal/marketing/repository/template_repo.go b/internal/marketing/repository/template_repo.go index 8369210..e10d8f5 100644 --- a/internal/marketing/repository/template_repo.go +++ b/internal/marketing/repository/template_repo.go @@ -70,10 +70,10 @@ func (r *TemplateRepository) FindByID(id uint) (*model.MarketingTemplate, error) } type TemplateListParams struct { - CategoryID uint + CategoryID uint OnlyEnabled bool - Page int - PageSize int + Page int + PageSize int } func (r *TemplateRepository) FindList(params TemplateListParams) ([]model.MarketingTemplate, int64, error) { @@ -116,5 +116,8 @@ func (r *TemplateRepository) IncrementDownloadCount(id uint) error { if tx.Error != nil { return fmt.Errorf("increment download count: %w", tx.Error) } + if tx.RowsAffected == 0 { + return ErrTemplateNotFound + } return nil } diff --git a/internal/marketing/service/download_service.go b/internal/marketing/service/download_service.go index b3cb4df..0ca4c4e 100644 --- a/internal/marketing/service/download_service.go +++ b/internal/marketing/service/download_service.go @@ -9,6 +9,7 @@ import ( var ( ErrDownloadTemplateRequired = errors.New("请选择模板") + ErrDownloadTemplateDisabled = errors.New("模板已禁用") ) type DownloadService struct { @@ -44,6 +45,14 @@ func (s *DownloadService) Create(userID uint, req CreateDownloadRequest) (*model return nil, ErrDownloadTemplateRequired } + tpl, err := s.templateRepo.FindByID(req.TemplateID) + if err != nil { + return nil, err + } + if tpl.Status != 1 { + return nil, ErrDownloadTemplateDisabled + } + dl := &model.MarketingDownload{ UserID: userID, TemplateID: req.TemplateID, @@ -58,7 +67,9 @@ func (s *DownloadService) Create(userID uint, req CreateDownloadRequest) (*model return nil, err } - _ = s.templateRepo.IncrementDownloadCount(req.TemplateID) + if err := s.templateRepo.IncrementDownloadCount(req.TemplateID); err != nil { + return nil, err + } return dl, nil } @@ -86,6 +97,7 @@ func (s *DownloadService) GetStats() (*repository.DownloadStats, error) { // IsBadRequestError checks if the error is a client-side validation error. func IsBadRequestError(err error) bool { return errors.Is(err, ErrDownloadTemplateRequired) || + errors.Is(err, ErrDownloadTemplateDisabled) || errors.Is(err, ErrCategoryNameRequired) || errors.Is(err, ErrCategoryNameTooLong) || errors.Is(err, ErrCategoryHasTemplates) || diff --git a/internal/marketing/service/download_service_test.go b/internal/marketing/service/download_service_test.go new file mode 100644 index 0000000..3275f00 --- /dev/null +++ b/internal/marketing/service/download_service_test.go @@ -0,0 +1,133 @@ +package service + +import ( + "errors" + "testing" + + "gorm.io/driver/sqlite" + "gorm.io/gorm" + + "wx_service/internal/marketing/model" + "wx_service/internal/marketing/repository" +) + +func newDownloadTestService(t *testing.T) (*DownloadService, *gorm.DB) { + t.Helper() + + db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + + if err := db.AutoMigrate( + &model.MarketingCategory{}, + &model.MarketingTemplate{}, + &model.MarketingDownload{}, + ); err != nil { + t.Fatalf("auto migrate: %v", err) + } + + downloadRepo := repository.NewDownloadRepository(db) + templateRepo := repository.NewTemplateRepository(db) + svc := NewDownloadService(downloadRepo, templateRepo) + return svc, db +} + +func seedTemplate(t *testing.T, db *gorm.DB, status int) *model.MarketingTemplate { + t.Helper() + + cat := &model.MarketingCategory{ + Name: "测试分类", + SortOrder: 1, + Status: 1, + } + if err := db.Create(cat).Error; err != nil { + t.Fatalf("create category: %v", err) + } + + tpl := &model.MarketingTemplate{ + CategoryID: cat.ID, + Title: "测试模板", + ImageURL: "https://example.com/image.jpg", + ThumbnailURL: "https://example.com/thumb.jpg", + Width: 1080, + Height: 1920, + SortOrder: 1, + Status: status, + } + if err := db.Create(tpl).Error; err != nil { + t.Fatalf("create template: %v", err) + } + if status == 0 { + if err := db.Model(&model.MarketingTemplate{}).Where("id = ?", tpl.ID).Update("status", 0).Error; err != nil { + t.Fatalf("update template status: %v", err) + } + } + return tpl +} + +func TestDownloadService_Create_DisabledTemplate(t *testing.T) { + svc, db := newDownloadTestService(t) + tpl := seedTemplate(t, db, 0) + + _, err := svc.Create(100, CreateDownloadRequest{ + TemplateID: tpl.ID, + LogoURL: "https://example.com/logo.png", + LogoX: 0.2, + LogoY: 0.3, + LogoW: 0.4, + LogoH: 0.5, + }) + if !errors.Is(err, ErrDownloadTemplateDisabled) { + t.Fatalf("expected ErrDownloadTemplateDisabled, got %v", err) + } +} + +func TestDownloadService_Create_TemplateNotFound(t *testing.T) { + svc, _ := newDownloadTestService(t) + + _, err := svc.Create(100, CreateDownloadRequest{ + TemplateID: 9999, + LogoURL: "https://example.com/logo.png", + }) + if !errors.Is(err, repository.ErrTemplateNotFound) { + t.Fatalf("expected ErrTemplateNotFound, got %v", err) + } +} + +func TestDownloadService_Create_Success(t *testing.T) { + svc, db := newDownloadTestService(t) + tpl := seedTemplate(t, db, 1) + + record, err := svc.Create(101, CreateDownloadRequest{ + TemplateID: tpl.ID, + LogoURL: "https://example.com/logo.png", + LogoX: 0.12, + LogoY: 0.34, + LogoW: 0.56, + LogoH: 0.78, + }) + if err != nil { + t.Fatalf("create download: %v", err) + } + + if record.ID == 0 { + t.Fatalf("expected created download id > 0") + } + + var storedTpl model.MarketingTemplate + if err := db.First(&storedTpl, tpl.ID).Error; err != nil { + t.Fatalf("reload template: %v", err) + } + if storedTpl.DownloadCount != 1 { + t.Fatalf("expected download_count=1, got %d", storedTpl.DownloadCount) + } + + var storedDl model.MarketingDownload + if err := db.First(&storedDl, record.ID).Error; err != nil { + t.Fatalf("reload download: %v", err) + } + if storedDl.LogoURL != "https://example.com/logo.png" { + t.Fatalf("unexpected logo url: %s", storedDl.LogoURL) + } +}