From 1c0aeb152ab8f9b725db7d76b06bab5fa2ba7ae1 Mon Sep 17 00:00:00 2001 From: nepiedg Date: Sat, 4 Apr 2026 04:02:06 +0800 Subject: [PATCH] feat(marketing): user profile update, ad placement management, logo limits - Add PUT /auth/profile endpoint for nickname and avatar updates - Add ad_placements table and CRUD admin API for managing ad units - Add GET /marketing/ad-config public API for mini-program to fetch ad config - Reduce logo limit from 10 to 3 per user, add 2MB file size validation Made-with: Cursor --- cmd/api/main.go | 4 + internal/common/auth/handler/auth_handler.go | 33 ++++ internal/common/auth/service/auth_service.go | 25 +++ .../marketing/handler/ad_placement_handler.go | 157 ++++++++++++++++++ .../marketing/handler/user_logo_handler.go | 2 +- internal/marketing/model/ad_placement.go | 19 +++ .../marketing/repository/ad_placement_repo.go | 76 +++++++++ .../marketing/service/user_logo_service.go | 12 +- internal/routes/admin_routes.go | 15 +- internal/routes/marketing_routes.go | 5 + internal/routes/routes.go | 6 +- 11 files changed, 347 insertions(+), 7 deletions(-) create mode 100644 internal/marketing/handler/ad_placement_handler.go create mode 100644 internal/marketing/model/ad_placement.go create mode 100644 internal/marketing/repository/ad_placement_repo.go diff --git a/cmd/api/main.go b/cmd/api/main.go index eaa52a2..c8abf45 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -83,6 +83,7 @@ func main() { &marketingmodel.MarketingTemplate{}, &marketingmodel.MarketingDownload{}, &marketingmodel.UserLogo{}, + &marketingmodel.AdPlacement{}, &quitcheckinmodel.Profile{}, &quitcheckinmodel.DailyStatus{}, &quitcheckinmodel.RelapseEvent{}, @@ -166,6 +167,8 @@ func main() { marketingTemplateHandler := marketinghandler.NewTemplateHandler(templateSvc) marketingDownloadHandler := marketinghandler.NewDownloadHandler(downloadSvc) marketingUserLogoHandler := marketinghandler.NewUserLogoHandler(userLogoSvc) + adPlacementRepo := marketingrepo.NewAdPlacementRepository(database.DB) + marketingAdPlacementHandler := marketinghandler.NewAdPlacementHandler(adPlacementRepo) adminService := adminmodule.NewService( database.DB, @@ -199,6 +202,7 @@ func main() { marketingTemplateHandler, marketingDownloadHandler, marketingUserLogoHandler, + marketingAdPlacementHandler, quitCheckinHandler, ) diff --git a/internal/common/auth/handler/auth_handler.go b/internal/common/auth/handler/auth_handler.go index f0b33a8..72b3a3f 100644 --- a/internal/common/auth/handler/auth_handler.go +++ b/internal/common/auth/handler/auth_handler.go @@ -8,6 +8,7 @@ import ( "github.com/gin-gonic/gin" "wx_service/internal/common/auth/service" + "wx_service/internal/middleware" "wx_service/internal/model" ) @@ -148,3 +149,35 @@ func (h *AuthHandler) DevLogin(c *gin.Context) { }, })) } + +type updateProfileRequest struct { + Nickname string `json:"nickname"` + AvatarURL string `json:"avatar_url"` +} + +func (h *AuthHandler) UpdateProfile(c *gin.Context) { + user := middleware.MustCurrentUser(c) + + var req updateProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, model.Error(http.StatusBadRequest, "参数错误")) + return + } + if req.Nickname == "" && req.AvatarURL == "" { + c.JSON(http.StatusBadRequest, model.Error(http.StatusBadRequest, "请提供昵称或头像")) + return + } + + updated, err := h.authService.UpdateProfile(c.Request.Context(), user.ID, req.Nickname, req.AvatarURL) + if err != nil { + log.Printf("[update_profile] error: %v", err) + c.JSON(http.StatusInternalServerError, model.Error(http.StatusInternalServerError, "更新失败")) + return + } + + c.JSON(http.StatusOK, model.Success(gin.H{ + "id": updated.ID, + "nickname": updated.NickName, + "avatar_url": updated.AvatarURL, + })) +} diff --git a/internal/common/auth/service/auth_service.go b/internal/common/auth/service/auth_service.go index e37b89c..4f6423b 100644 --- a/internal/common/auth/service/auth_service.go +++ b/internal/common/auth/service/auth_service.go @@ -192,6 +192,31 @@ func (s *AuthService) DevLogin(ctx context.Context, miniProgramID uint) (*LoginR }, nil } +// UpdateProfile 更新用户昵称和头像。 +func (s *AuthService) UpdateProfile(ctx context.Context, userID uint, nickname, avatarURL string) (*model.User, error) { + tx := s.db.WithContext(ctx) + var user model.User + if err := tx.First(&user, userID).Error; err != nil { + return nil, fmt.Errorf("find user: %w", err) + } + + updates := map[string]interface{}{} + if nickname != "" { + updates["nick_name"] = nickname + user.NickName = nickname + } + if avatarURL != "" { + updates["avatar_url"] = avatarURL + user.AvatarURL = avatarURL + } + if len(updates) > 0 { + if err := tx.Model(&user).Updates(updates).Error; err != nil { + return nil, fmt.Errorf("update profile: %w", err) + } + } + return &user, nil +} + func (s *AuthService) getSmokeMode(ctx context.Context, uid int) (string, error) { var profile smokemodel.SmokeUserProfile err := s.db.WithContext(ctx). diff --git a/internal/marketing/handler/ad_placement_handler.go b/internal/marketing/handler/ad_placement_handler.go new file mode 100644 index 0000000..9a3a046 --- /dev/null +++ b/internal/marketing/handler/ad_placement_handler.go @@ -0,0 +1,157 @@ +package handler + +import ( + "errors" + "net/http" + "strconv" + + "github.com/gin-gonic/gin" + + "wx_service/internal/marketing/model" + "wx_service/internal/marketing/repository" + commonmodel "wx_service/internal/model" +) + +type AdPlacementHandler struct { + repo *repository.AdPlacementRepository +} + +func NewAdPlacementHandler(repo *repository.AdPlacementRepository) *AdPlacementHandler { + return &AdPlacementHandler{repo: repo} +} + +func (h *AdPlacementHandler) AdminList(c *gin.Context) { + list, err := h.repo.ListAll() + if err != nil { + c.JSON(http.StatusInternalServerError, commonmodel.Error(http.StatusInternalServerError, "获取广告位列表失败")) + return + } + c.JSON(http.StatusOK, commonmodel.Success(list)) +} + +type adPlacementCreateReq struct { + MiniProgramID uint `json:"mini_program_id" binding:"required"` + Name string `json:"name" binding:"required"` + AdType string `json:"ad_type" binding:"required"` + AdUnitID string `json:"ad_unit_id"` + Status *int `json:"status"` + Description string `json:"description"` +} + +func (h *AdPlacementHandler) AdminCreate(c *gin.Context) { + var req adPlacementCreateReq + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, commonmodel.Error(http.StatusBadRequest, "参数错误")) + return + } + + p := &model.AdPlacement{ + MiniProgramID: req.MiniProgramID, + Name: req.Name, + AdType: req.AdType, + AdUnitID: req.AdUnitID, + Description: req.Description, + Status: 1, + } + if req.Status != nil { + p.Status = *req.Status + } + + if err := h.repo.Create(p); err != nil { + c.JSON(http.StatusInternalServerError, commonmodel.Error(http.StatusInternalServerError, "创建失败")) + return + } + c.JSON(http.StatusOK, commonmodel.Success(p)) +} + +type adPlacementUpdateReq struct { + Name *string `json:"name"` + AdUnitID *string `json:"ad_unit_id"` + Status *int `json:"status"` + Description *string `json:"description"` +} + +func (h *AdPlacementHandler) AdminUpdate(c *gin.Context) { + id, err := strconv.ParseUint(c.Param("id"), 10, 64) + if err != nil || id == 0 { + c.JSON(http.StatusBadRequest, commonmodel.Error(http.StatusBadRequest, "无效 ID")) + return + } + + p, err := h.repo.FindByID(uint(id)) + if err != nil { + if errors.Is(err, repository.ErrAdPlacementNotFound) { + c.JSON(http.StatusNotFound, commonmodel.Error(http.StatusNotFound, "广告位不存在")) + return + } + c.JSON(http.StatusInternalServerError, commonmodel.Error(http.StatusInternalServerError, "查询失败")) + return + } + + var req adPlacementUpdateReq + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, commonmodel.Error(http.StatusBadRequest, "参数错误")) + return + } + + if req.Name != nil { + p.Name = *req.Name + } + if req.AdUnitID != nil { + p.AdUnitID = *req.AdUnitID + } + if req.Status != nil { + p.Status = *req.Status + } + if req.Description != nil { + p.Description = *req.Description + } + + if err := h.repo.Update(p); err != nil { + c.JSON(http.StatusInternalServerError, commonmodel.Error(http.StatusInternalServerError, "更新失败")) + return + } + c.JSON(http.StatusOK, commonmodel.Success(p)) +} + +func (h *AdPlacementHandler) AdminDelete(c *gin.Context) { + id, err := strconv.ParseUint(c.Param("id"), 10, 64) + if err != nil || id == 0 { + c.JSON(http.StatusBadRequest, commonmodel.Error(http.StatusBadRequest, "无效 ID")) + return + } + + if err := h.repo.Delete(uint(id)); err != nil { + if errors.Is(err, repository.ErrAdPlacementNotFound) { + c.JSON(http.StatusNotFound, commonmodel.Error(http.StatusNotFound, "广告位不存在")) + return + } + c.JSON(http.StatusInternalServerError, commonmodel.Error(http.StatusInternalServerError, "删除失败")) + return + } + c.JSON(http.StatusOK, commonmodel.Success(nil)) +} + +func (h *AdPlacementHandler) GetAdConfig(c *gin.Context) { + miniProgramIDStr := c.Query("mini_program_id") + miniProgramID, _ := strconv.ParseUint(miniProgramIDStr, 10, 64) + if miniProgramID == 0 { + c.JSON(http.StatusBadRequest, commonmodel.Error(http.StatusBadRequest, "缺少 mini_program_id")) + return + } + + p, err := h.repo.FindByMiniProgramAndType(uint(miniProgramID), "rewarded_video") + if err != nil { + if errors.Is(err, repository.ErrAdPlacementNotFound) { + c.JSON(http.StatusOK, commonmodel.Success(gin.H{"ad_unit_id": "", "enabled": false})) + return + } + c.JSON(http.StatusInternalServerError, commonmodel.Error(http.StatusInternalServerError, "查询失败")) + return + } + + c.JSON(http.StatusOK, commonmodel.Success(gin.H{ + "ad_unit_id": p.AdUnitID, + "enabled": p.Status == 1 && p.AdUnitID != "", + })) +} diff --git a/internal/marketing/handler/user_logo_handler.go b/internal/marketing/handler/user_logo_handler.go index da02c5e..42aa4e9 100644 --- a/internal/marketing/handler/user_logo_handler.go +++ b/internal/marketing/handler/user_logo_handler.go @@ -42,7 +42,7 @@ func (h *UserLogoHandler) Save(c *gin.Context) { logo, err := h.svc.Save(user.ID, req) if err != nil { - if errors.Is(err, service.ErrLogoLimitReached) { + if errors.Is(err, service.ErrLogoLimitReached) || errors.Is(err, service.ErrLogoTooLarge) { c.JSON(http.StatusBadRequest, model.Error(http.StatusBadRequest, err.Error())) return } diff --git a/internal/marketing/model/ad_placement.go b/internal/marketing/model/ad_placement.go new file mode 100644 index 0000000..7285312 --- /dev/null +++ b/internal/marketing/model/ad_placement.go @@ -0,0 +1,19 @@ +package model + +import "time" + +type AdPlacement struct { + ID uint `json:"id" gorm:"primaryKey;comment:主键ID"` + MiniProgramID uint `json:"mini_program_id" gorm:"not null;index;comment:小程序ID"` + Name string `json:"name" gorm:"size:100;not null;comment:广告位名称"` + AdType string `json:"ad_type" gorm:"size:50;not null;default:rewarded_video;comment:广告类型(rewarded_video/banner/interstitial)"` + AdUnitID string `json:"ad_unit_id" gorm:"size:200;comment:微信广告单元ID"` + Status int `json:"status" gorm:"default:1;comment:状态(0禁用/1启用)"` + Description string `json:"description" gorm:"size:500;comment:备注说明"` + CreatedAt time.Time `json:"created_at" gorm:"comment:创建时间"` + UpdatedAt time.Time `json:"updated_at" gorm:"comment:更新时间"` +} + +func (AdPlacement) TableName() string { + return "marketing_ad_placements" +} diff --git a/internal/marketing/repository/ad_placement_repo.go b/internal/marketing/repository/ad_placement_repo.go new file mode 100644 index 0000000..4253975 --- /dev/null +++ b/internal/marketing/repository/ad_placement_repo.go @@ -0,0 +1,76 @@ +package repository + +import ( + "errors" + "fmt" + + "gorm.io/gorm" + + "wx_service/internal/marketing/model" +) + +var ErrAdPlacementNotFound = errors.New("ad placement not found") + +type AdPlacementRepository struct { + db *gorm.DB +} + +func NewAdPlacementRepository(db *gorm.DB) *AdPlacementRepository { + return &AdPlacementRepository{db: db} +} + +func (r *AdPlacementRepository) Create(p *model.AdPlacement) error { + if err := r.db.Create(p).Error; err != nil { + return fmt.Errorf("create ad placement: %w", err) + } + return nil +} + +func (r *AdPlacementRepository) Update(p *model.AdPlacement) error { + if err := r.db.Save(p).Error; err != nil { + return fmt.Errorf("update ad placement: %w", err) + } + return nil +} + +func (r *AdPlacementRepository) Delete(id uint) error { + tx := r.db.Delete(&model.AdPlacement{}, id) + if tx.Error != nil { + return fmt.Errorf("delete ad placement: %w", tx.Error) + } + if tx.RowsAffected == 0 { + return ErrAdPlacementNotFound + } + return nil +} + +func (r *AdPlacementRepository) FindByID(id uint) (*model.AdPlacement, error) { + var p model.AdPlacement + err := r.db.First(&p, id).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrAdPlacementNotFound + } + return nil, fmt.Errorf("find ad placement: %w", err) + } + return &p, nil +} + +func (r *AdPlacementRepository) ListAll() ([]model.AdPlacement, error) { + var list []model.AdPlacement + err := r.db.Order("id DESC").Find(&list).Error + return list, err +} + +func (r *AdPlacementRepository) FindByMiniProgramAndType(miniProgramID uint, adType string) (*model.AdPlacement, error) { + var p model.AdPlacement + err := r.db.Where("mini_program_id = ? AND ad_type = ? AND status = 1", miniProgramID, adType). + First(&p).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrAdPlacementNotFound + } + return nil, fmt.Errorf("find ad placement by type: %w", err) + } + return &p, nil +} diff --git a/internal/marketing/service/user_logo_service.go b/internal/marketing/service/user_logo_service.go index a7db7f6..7515b07 100644 --- a/internal/marketing/service/user_logo_service.go +++ b/internal/marketing/service/user_logo_service.go @@ -7,9 +7,13 @@ import ( "wx_service/internal/marketing/repository" ) -const MaxLogosPerUser = 10 +const MaxLogosPerUser = 3 +const MaxLogoFileSize = 2 * 1024 * 1024 // 2MB -var ErrLogoLimitReached = errors.New("Logo 数量已达上限") +var ( + ErrLogoLimitReached = errors.New("Logo 数量已达上限") + ErrLogoTooLarge = errors.New("Logo 文件不能超过 2MB") +) type UserLogoService struct { repo *repository.UserLogoRepository @@ -26,6 +30,10 @@ type SaveLogoRequest struct { } func (s *UserLogoService) Save(userID uint, req SaveLogoRequest) (*model.UserLogo, error) { + if req.FileSize > MaxLogoFileSize { + return nil, ErrLogoTooLarge + } + count, err := s.repo.CountByUser(userID) if err != nil { return nil, err diff --git a/internal/routes/admin_routes.go b/internal/routes/admin_routes.go index d433b09..70c35a8 100644 --- a/internal/routes/admin_routes.go +++ b/internal/routes/admin_routes.go @@ -13,6 +13,7 @@ func registerAdminRoutes( categoryHandler *marketinghandler.CategoryHandler, templateHandler *marketinghandler.TemplateHandler, downloadHandler *marketinghandler.DownloadHandler, + adPlacementHandler *marketinghandler.AdPlacementHandler, ) { if handler == nil { return @@ -107,11 +108,21 @@ func registerAdminRoutes( marketing.PUT("/templates/:id", templateHandler.AdminUpdate) marketing.DELETE("/templates/:id", templateHandler.AdminDelete) - marketing.GET("/stats", downloadHandler.AdminStats) - marketing.POST("/upload/oss/token", downloadHandler.AdminUploadToken) + marketing.GET("/stats", downloadHandler.AdminStats) + marketing.POST("/upload/oss/token", downloadHandler.AdminUploadToken) marketing.POST("/upload", downloadHandler.AdminUploadFile) + } + + if adPlacementHandler != nil { + ads := protected.Group("/marketing/ad-placements") + { + ads.GET("", adPlacementHandler.AdminList) + ads.POST("", adPlacementHandler.AdminCreate) + ads.PUT("/:id", adPlacementHandler.AdminUpdate) + ads.DELETE("/:id", adPlacementHandler.AdminDelete) } } } } } +} diff --git a/internal/routes/marketing_routes.go b/internal/routes/marketing_routes.go index 7192579..165b6c3 100644 --- a/internal/routes/marketing_routes.go +++ b/internal/routes/marketing_routes.go @@ -14,6 +14,7 @@ func registerMarketingRoutes( templateHandler *marketinghandler.TemplateHandler, downloadHandler *marketinghandler.DownloadHandler, userLogoHandler *marketinghandler.UserLogoHandler, + adPlacementHandler *marketinghandler.AdPlacementHandler, ) { if categoryHandler == nil || templateHandler == nil || downloadHandler == nil { return @@ -24,6 +25,10 @@ func registerMarketingRoutes( marketing.GET("/categories", categoryHandler.ListEnabled) marketing.GET("/templates", templateHandler.ListEnabled) marketing.GET("/templates/:id", templateHandler.GetDetail) + + if adPlacementHandler != nil { + marketing.GET("/ad-config", adPlacementHandler.GetAdConfig) + } } protectedMarketing := protected.Group("/marketing") diff --git a/internal/routes/routes.go b/internal/routes/routes.go index 9624bfc..6d873ca 100644 --- a/internal/routes/routes.go +++ b/internal/routes/routes.go @@ -40,6 +40,7 @@ func Register( marketingTemplateHandler *marketinghandler.TemplateHandler, marketingDownloadHandler *marketinghandler.DownloadHandler, marketingUserLogoHandler *marketinghandler.UserLogoHandler, + marketingAdPlacementHandler *marketinghandler.AdPlacementHandler, quitCheckinHandler *quitcheckinhandler.Handler, ) { // Register 用来集中注册所有 HTTP 路由,便于工程结构更清晰: @@ -65,13 +66,14 @@ func Register( protected.Use(middleware.AuthMiddleware(db, sessionCache)) protected.Use(middleware.RequireUserMiddleware()) { + protected.PUT("/auth/profile", authHandler.UpdateProfile) registerCommonRoutes(protected, uploadHandler) registerRemoveWatermarkRoutes(api, protected, videoHandler) registerMembershipRoutes(protected, redeemCodeHandler) registerSmokeRoutes(protected, smokeHandler, quitPlanHandler) } - registerMarketingRoutes(api, protected, adminToken, marketingCategoryHandler, marketingTemplateHandler, marketingDownloadHandler, marketingUserLogoHandler) + registerMarketingRoutes(api, protected, adminToken, marketingCategoryHandler, marketingTemplateHandler, marketingDownloadHandler, marketingUserLogoHandler, marketingAdPlacementHandler) } apiV2 := router.Group("/api/v2") @@ -84,7 +86,7 @@ func Register( } } - registerAdminRoutes(router, adminHandler, marketingCategoryHandler, marketingTemplateHandler, marketingDownloadHandler) + registerAdminRoutes(router, adminHandler, marketingCategoryHandler, marketingTemplateHandler, marketingDownloadHandler, marketingAdPlacementHandler) // 保质期提醒模块使用独立前缀 /api/expiry,与现有 /api/v1 并存。 expiryAPI := router.Group("/api/expiry")