feat: 完成 #38 营销图CRUD接口下载校验与测试

This commit is contained in:
root
2026-03-09 19:17:01 +08:00
parent e14255cf64
commit efff6eb7d4
4 changed files with 156 additions and 4 deletions
@@ -29,6 +29,10 @@ func (h *DownloadHandler) Create(c *gin.Context) {
dl, err := h.svc.Create(user.ID, req)
if err != nil {
if service.IsNotFoundError(err) {
c.JSON(http.StatusNotFound, model.Error(http.StatusNotFound, "模板不存在"))
return
}
if service.IsBadRequestError(err) {
c.JSON(http.StatusBadRequest, model.Error(http.StatusBadRequest, err.Error()))
return
@@ -116,5 +116,8 @@ func (r *TemplateRepository) IncrementDownloadCount(id uint) error {
if tx.Error != nil {
return fmt.Errorf("increment download count: %w", tx.Error)
}
if tx.RowsAffected == 0 {
return ErrTemplateNotFound
}
return nil
}
+13 -1
View File
@@ -9,6 +9,7 @@ import (
var (
ErrDownloadTemplateRequired = errors.New("请选择模板")
ErrDownloadTemplateDisabled = errors.New("模板已禁用")
)
type DownloadService struct {
@@ -44,6 +45,14 @@ func (s *DownloadService) Create(userID uint, req CreateDownloadRequest) (*model
return nil, ErrDownloadTemplateRequired
}
tpl, err := s.templateRepo.FindByID(req.TemplateID)
if err != nil {
return nil, err
}
if tpl.Status != 1 {
return nil, ErrDownloadTemplateDisabled
}
dl := &model.MarketingDownload{
UserID: userID,
TemplateID: req.TemplateID,
@@ -58,7 +67,9 @@ func (s *DownloadService) Create(userID uint, req CreateDownloadRequest) (*model
return nil, err
}
_ = s.templateRepo.IncrementDownloadCount(req.TemplateID)
if err := s.templateRepo.IncrementDownloadCount(req.TemplateID); err != nil {
return nil, err
}
return dl, nil
}
@@ -86,6 +97,7 @@ func (s *DownloadService) GetStats() (*repository.DownloadStats, error) {
// IsBadRequestError checks if the error is a client-side validation error.
func IsBadRequestError(err error) bool {
return errors.Is(err, ErrDownloadTemplateRequired) ||
errors.Is(err, ErrDownloadTemplateDisabled) ||
errors.Is(err, ErrCategoryNameRequired) ||
errors.Is(err, ErrCategoryNameTooLong) ||
errors.Is(err, ErrCategoryHasTemplates) ||
@@ -0,0 +1,133 @@
package service
import (
"errors"
"testing"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"wx_service/internal/marketing/model"
"wx_service/internal/marketing/repository"
)
func newDownloadTestService(t *testing.T) (*DownloadService, *gorm.DB) {
t.Helper()
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(
&model.MarketingCategory{},
&model.MarketingTemplate{},
&model.MarketingDownload{},
); err != nil {
t.Fatalf("auto migrate: %v", err)
}
downloadRepo := repository.NewDownloadRepository(db)
templateRepo := repository.NewTemplateRepository(db)
svc := NewDownloadService(downloadRepo, templateRepo)
return svc, db
}
func seedTemplate(t *testing.T, db *gorm.DB, status int) *model.MarketingTemplate {
t.Helper()
cat := &model.MarketingCategory{
Name: "测试分类",
SortOrder: 1,
Status: 1,
}
if err := db.Create(cat).Error; err != nil {
t.Fatalf("create category: %v", err)
}
tpl := &model.MarketingTemplate{
CategoryID: cat.ID,
Title: "测试模板",
ImageURL: "https://example.com/image.jpg",
ThumbnailURL: "https://example.com/thumb.jpg",
Width: 1080,
Height: 1920,
SortOrder: 1,
Status: status,
}
if err := db.Create(tpl).Error; err != nil {
t.Fatalf("create template: %v", err)
}
if status == 0 {
if err := db.Model(&model.MarketingTemplate{}).Where("id = ?", tpl.ID).Update("status", 0).Error; err != nil {
t.Fatalf("update template status: %v", err)
}
}
return tpl
}
func TestDownloadService_Create_DisabledTemplate(t *testing.T) {
svc, db := newDownloadTestService(t)
tpl := seedTemplate(t, db, 0)
_, err := svc.Create(100, CreateDownloadRequest{
TemplateID: tpl.ID,
LogoURL: "https://example.com/logo.png",
LogoX: 0.2,
LogoY: 0.3,
LogoW: 0.4,
LogoH: 0.5,
})
if !errors.Is(err, ErrDownloadTemplateDisabled) {
t.Fatalf("expected ErrDownloadTemplateDisabled, got %v", err)
}
}
func TestDownloadService_Create_TemplateNotFound(t *testing.T) {
svc, _ := newDownloadTestService(t)
_, err := svc.Create(100, CreateDownloadRequest{
TemplateID: 9999,
LogoURL: "https://example.com/logo.png",
})
if !errors.Is(err, repository.ErrTemplateNotFound) {
t.Fatalf("expected ErrTemplateNotFound, got %v", err)
}
}
func TestDownloadService_Create_Success(t *testing.T) {
svc, db := newDownloadTestService(t)
tpl := seedTemplate(t, db, 1)
record, err := svc.Create(101, CreateDownloadRequest{
TemplateID: tpl.ID,
LogoURL: "https://example.com/logo.png",
LogoX: 0.12,
LogoY: 0.34,
LogoW: 0.56,
LogoH: 0.78,
})
if err != nil {
t.Fatalf("create download: %v", err)
}
if record.ID == 0 {
t.Fatalf("expected created download id > 0")
}
var storedTpl model.MarketingTemplate
if err := db.First(&storedTpl, tpl.ID).Error; err != nil {
t.Fatalf("reload template: %v", err)
}
if storedTpl.DownloadCount != 1 {
t.Fatalf("expected download_count=1, got %d", storedTpl.DownloadCount)
}
var storedDl model.MarketingDownload
if err := db.First(&storedDl, record.ID).Error; err != nil {
t.Fatalf("reload download: %v", err)
}
if storedDl.LogoURL != "https://example.com/logo.png" {
t.Fatalf("unexpected logo url: %s", storedDl.LogoURL)
}
}