diff --git a/cmd/api/main.go b/cmd/api/main.go index 1a1d7f1..9abfed8 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -9,8 +9,14 @@ import ( "wx_service/internal/database" "wx_service/internal/handler" "wx_service/internal/model" + rmhandler "wx_service/internal/remove_watermark/handler" + rmmodel "wx_service/internal/remove_watermark/model" + rmservice "wx_service/internal/remove_watermark/service" "wx_service/internal/routes" "wx_service/internal/service" + smokehandler "wx_service/internal/smoke/handler" + smokemodel "wx_service/internal/smoke/model" + smokeservice "wx_service/internal/smoke/service" ) func main() { @@ -22,7 +28,13 @@ func main() { log.Fatalf("init database failed: %v", err) } // 3) 自动建表/迁移(开发阶段很方便;生产环境可改为手动迁移) - if err := database.AutoMigrate(&model.MiniProgram{}, &model.User{}, &model.VideoParseLog{}, &model.VideoParseUnlock{}); err != nil { + if err := database.AutoMigrate( + &model.MiniProgram{}, + &model.User{}, + &rmmodel.VideoParseLog{}, + &rmmodel.VideoParseUnlock{}, + &smokemodel.SmokeLog{}, + ); err != nil { log.Fatalf("auto migrate failed: %v", err) } @@ -34,14 +46,17 @@ func main() { miniProgramService := service.NewMiniProgramService(database.DB) authService := service.NewAuthService(database.DB, miniProgramService) authHandler := handler.NewAuthHandler(authService) - videoService, err := service.NewVideoService(database.DB, config.AppConfig.ShortVideo) + videoService, err := rmservice.NewVideoService(database.DB, config.AppConfig.ShortVideo) if err != nil { log.Fatalf("init video service failed: %v", err) } - videoHandler := handler.NewVideoHandler(videoService) + videoHandler := rmhandler.NewVideoHandler(videoService) + + smokeLogService := smokeservice.NewSmokeLogService(database.DB) + smokeHandler := smokehandler.NewSmokeHandler(smokeLogService) // 6) 注册路由:把 URL 映射到 handler - routes.Register(router, database.DB, authHandler, videoHandler) + routes.Register(router, database.DB, authHandler, videoHandler, smokeHandler) // 7) 启动监听端口 addr := ":" + config.AppConfig.Server.Port diff --git a/internal/middleware/current_user.go b/internal/middleware/current_user.go new file mode 100644 index 0000000..5cdcc7d --- /dev/null +++ b/internal/middleware/current_user.go @@ -0,0 +1,18 @@ +package middleware + +import ( + "github.com/gin-gonic/gin" + + "wx_service/internal/model" +) + +// CurrentUser 从 gin.Context 中取出鉴权中间件写入的当前用户。 +// 返回值 ok=false 表示:未经过鉴权中间件,或 token 无效导致未设置用户。 +func CurrentUser(c *gin.Context) (*model.User, bool) { + userVal, exists := c.Get(ContextCurrentUserKey) + if !exists { + return nil, false + } + user, ok := userVal.(*model.User) + return user, ok +} diff --git a/internal/handler/video_handler.go b/internal/remove_watermark/handler/video_handler.go similarity index 89% rename from internal/handler/video_handler.go rename to internal/remove_watermark/handler/video_handler.go index 5f1c4fd..7b680e3 100644 --- a/internal/handler/video_handler.go +++ b/internal/remove_watermark/handler/video_handler.go @@ -8,7 +8,7 @@ import ( "wx_service/internal/middleware" "wx_service/internal/model" - "wx_service/internal/service" + "wx_service/internal/remove_watermark/service" ) type VideoHandler struct { @@ -32,7 +32,7 @@ func (h *VideoHandler) RemoveWatermark(c *gin.Context) { return } - user, ok := getCurrentUser(c) + user, ok := middleware.CurrentUser(c) if !ok { c.JSON(http.StatusUnauthorized, model.Error(http.StatusUnauthorized, "未登录或登录已过期")) return @@ -69,7 +69,7 @@ func (h *VideoHandler) RemoveWatermark(c *gin.Context) { } func (h *VideoHandler) UnlockQuota(c *gin.Context) { - user, ok := getCurrentUser(c) + user, ok := middleware.CurrentUser(c) if !ok { c.JSON(http.StatusUnauthorized, model.Error(http.StatusUnauthorized, "未登录或登录已过期")) return @@ -84,12 +84,3 @@ func (h *VideoHandler) UnlockQuota(c *gin.Context) { "unlocked": true, })) } - -func getCurrentUser(c *gin.Context) (*model.User, bool) { - userVal, exists := c.Get(middleware.ContextCurrentUserKey) - if !exists { - return nil, false - } - user, ok := userVal.(*model.User) - return user, ok -} diff --git a/internal/model/video_parse.go b/internal/remove_watermark/model/video_parse.go similarity index 100% rename from internal/model/video_parse.go rename to internal/remove_watermark/model/video_parse.go diff --git a/internal/service/video_service.go b/internal/remove_watermark/service/video_service.go similarity index 93% rename from internal/service/video_service.go rename to internal/remove_watermark/service/video_service.go index 7c81f77..32406d3 100644 --- a/internal/service/video_service.go +++ b/internal/remove_watermark/service/video_service.go @@ -15,7 +15,8 @@ import ( "gorm.io/gorm" "wx_service/config" - "wx_service/internal/model" + usermodel "wx_service/internal/model" + rmmodel "wx_service/internal/remove_watermark/model" ) const removeWatermarkEndpoint = "https://api.23bt.cn/api/d1w/index" @@ -73,7 +74,7 @@ func NewVideoService(db *gorm.DB, cfg config.ShortVideoConfig) (*VideoService, e }, nil } -func (s *VideoService) RemoveWatermark(ctx context.Context, user *model.User, content string) (*RemoveWatermarkResult, error) { +func (s *VideoService) RemoveWatermark(ctx context.Context, user *usermodel.User, content string) (*RemoveWatermarkResult, error) { // RemoveWatermark 的整体流程: // 1) 从 content 里提取链接 // 2) 检查每日免费额度(或是否已解锁) @@ -99,7 +100,7 @@ func (s *VideoService) RemoveWatermark(ctx context.Context, user *model.User, co statusCode, body, requestErr = s.callThirdParty(ctx, link) duration := int(time.Since(now).Milliseconds()) - logEntry := model.VideoParseLog{ + logEntry := rmmodel.VideoParseLog{ MiniProgramID: user.MiniProgramID, UserID: user.ID, RequestContent: content, @@ -130,11 +131,11 @@ func (s *VideoService) RemoveWatermark(ctx context.Context, user *model.User, co }, nil } -func (s *VideoService) UnlockForToday(ctx context.Context, user *model.User) error { +func (s *VideoService) UnlockForToday(ctx context.Context, user *usermodel.User) error { // “看广告解锁”的实现方式:在当天写一条 unlock 记录即可(存在则更新时间戳) startOfDay, _ := dayRange(time.Now()) - var unlock model.VideoParseUnlock + var unlock rmmodel.VideoParseUnlock tx := s.db.WithContext(ctx) err := tx.Where("user_id = ? AND mini_program_id = ? AND unlock_date = ?", user.ID, user.MiniProgramID, startOfDay).First(&unlock).Error if err == nil { @@ -146,7 +147,7 @@ func (s *VideoService) UnlockForToday(ctx context.Context, user *model.User) err return fmt.Errorf("load unlock record: %w", err) } - record := model.VideoParseUnlock{ + record := rmmodel.VideoParseUnlock{ MiniProgramID: user.MiniProgramID, UserID: user.ID, UnlockDate: startOfDay, @@ -158,7 +159,7 @@ func (s *VideoService) UnlockForToday(ctx context.Context, user *model.User) err return nil } -func (s *VideoService) ensureQuota(ctx context.Context, user *model.User) (bool, error) { +func (s *VideoService) ensureQuota(ctx context.Context, user *usermodel.User) (bool, error) { // ensureQuota 返回值 freeQuotaUsed 的含义: // - true:这次调用会消耗一次“免费额度” // - false:不消耗(例如今日已解锁或未启用限额) @@ -169,7 +170,7 @@ func (s *VideoService) ensureQuota(ctx context.Context, user *model.User) (bool, startOfDay, endOfDay := dayRange(time.Now()) tx := s.db.WithContext(ctx) - var unlock model.VideoParseUnlock + var unlock rmmodel.VideoParseUnlock if err := tx.Where("user_id = ? AND mini_program_id = ? AND unlock_date = ?", user.ID, user.MiniProgramID, startOfDay).First(&unlock).Error; err == nil { return false, nil } else if err != gorm.ErrRecordNotFound { @@ -177,7 +178,7 @@ func (s *VideoService) ensureQuota(ctx context.Context, user *model.User) (bool, } var count int64 - if err := tx.Model(&model.VideoParseLog{}). + if err := tx.Model(&rmmodel.VideoParseLog{}). Where("user_id = ? AND mini_program_id = ? AND free_quota_used = ? AND created_at >= ? AND created_at < ?", user.ID, user.MiniProgramID, true, startOfDay, endOfDay). Count(&count).Error; err != nil { diff --git a/internal/routes/remove_watermark_routes.go b/internal/routes/remove_watermark_routes.go new file mode 100644 index 0000000..20f5ff6 --- /dev/null +++ b/internal/routes/remove_watermark_routes.go @@ -0,0 +1,13 @@ +package routes + +import ( + "github.com/gin-gonic/gin" + + rmhandler "wx_service/internal/remove_watermark/handler" +) + +func registerRemoveWatermarkRoutes(protected *gin.RouterGroup, videoHandler *rmhandler.VideoHandler) { + // 去水印相关接口(保持原有路径不变) + protected.POST("/video/remove_watermark", videoHandler.RemoveWatermark) + protected.POST("/video/remove_watermark/unlock", videoHandler.UnlockQuota) +} diff --git a/internal/routes/routes.go b/internal/routes/routes.go index fcb10d0..5234e5f 100644 --- a/internal/routes/routes.go +++ b/internal/routes/routes.go @@ -8,9 +8,17 @@ import ( "wx_service/internal/handler" "wx_service/internal/middleware" + rmhandler "wx_service/internal/remove_watermark/handler" + smokehandler "wx_service/internal/smoke/handler" ) -func Register(router *gin.Engine, db *gorm.DB, authHandler *handler.AuthHandler, videoHandler *handler.VideoHandler) { +func Register( + router *gin.Engine, + db *gorm.DB, + authHandler *handler.AuthHandler, + videoHandler *rmhandler.VideoHandler, + smokeHandler *smokehandler.SmokeHandler, +) { // Register 用来集中注册所有 HTTP 路由,便于工程结构更清晰: // - main 只负责初始化(配置/DB/依赖注入) // - routes 只负责把 URL 映射到 handler @@ -23,8 +31,8 @@ func Register(router *gin.Engine, db *gorm.DB, authHandler *handler.AuthHandler, protected := api.Group("") protected.Use(middleware.AuthMiddleware(db)) { - protected.POST("/video/remove_watermark", videoHandler.RemoveWatermark) - protected.POST("/video/remove_watermark/unlock", videoHandler.UnlockQuota) + registerRemoveWatermarkRoutes(protected, videoHandler) + registerSmokeRoutes(protected, smokeHandler) } } diff --git a/internal/routes/smoke_routes.go b/internal/routes/smoke_routes.go new file mode 100644 index 0000000..22f31f9 --- /dev/null +++ b/internal/routes/smoke_routes.go @@ -0,0 +1,19 @@ +package routes + +import ( + "github.com/gin-gonic/gin" + + smokehandler "wx_service/internal/smoke/handler" +) + +func registerSmokeRoutes(protected *gin.RouterGroup, smokeHandler *smokehandler.SmokeHandler) { + // 戒烟/抽烟记录(与 video 去水印功能在路由前缀上区分开) + smoke := protected.Group("/smoke") + { + smoke.POST("/logs", smokeHandler.Create) + smoke.GET("/logs", smokeHandler.List) + smoke.GET("/logs/:id", smokeHandler.Get) + smoke.PUT("/logs/:id", smokeHandler.Update) + smoke.DELETE("/logs/:id", smokeHandler.Delete) + } +} diff --git a/internal/smoke/handler/smoke_handler.go b/internal/smoke/handler/smoke_handler.go new file mode 100644 index 0000000..4fd4a44 --- /dev/null +++ b/internal/smoke/handler/smoke_handler.go @@ -0,0 +1,231 @@ +package handler + +import ( + "errors" + "net/http" + "strconv" + "time" + + "github.com/gin-gonic/gin" + + "wx_service/internal/middleware" + "wx_service/internal/model" + smokeservice "wx_service/internal/smoke/service" +) + +type SmokeHandler struct { + smokeLogService *smokeservice.SmokeLogService +} + +func NewSmokeHandler(smokeLogService *smokeservice.SmokeLogService) *SmokeHandler { + return &SmokeHandler{smokeLogService: smokeLogService} +} + +// dateLayout 用于解析前端传入的日期字符串(例如:2025-12-31) +const dateLayout = "2006-01-02" + +type createSmokeLogRequest struct { + // 只记录“日期”即可;如果不传,后端会按当天处理 + SmokeTime string `json:"smoke_time"` + Remark string `json:"remark"` + Level int64 `json:"level"` + Num int `json:"num"` +} + +func (h *SmokeHandler) Create(c *gin.Context) { + user, ok := middleware.CurrentUser(c) + if !ok { + c.JSON(http.StatusUnauthorized, model.Error(http.StatusUnauthorized, "未登录或登录已过期")) + return + } + + var req createSmokeLogRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, model.Error(http.StatusBadRequest, "请求参数错误")) + return + } + + var smokeTime *time.Time + if req.SmokeTime != "" { + parsed, err := time.ParseInLocation(dateLayout, req.SmokeTime, time.Local) + if err != nil { + c.JSON(http.StatusBadRequest, model.Error(http.StatusBadRequest, "smoke_time 格式错误,应为 YYYY-MM-DD")) + return + } + smokeTime = &parsed + } + + record, err := h.smokeLogService.Create(c.Request.Context(), int(user.ID), smokeservice.CreateSmokeLogRequest{ + SmokeTime: smokeTime, + Remark: req.Remark, + Level: req.Level, + Num: req.Num, + }) + if err != nil { + c.JSON(http.StatusInternalServerError, model.Error(http.StatusInternalServerError, "创建记录失败,请稍后重试")) + return + } + + c.JSON(http.StatusOK, model.Success(record)) +} + +func (h *SmokeHandler) Get(c *gin.Context) { + user, ok := middleware.CurrentUser(c) + if !ok { + c.JSON(http.StatusUnauthorized, model.Error(http.StatusUnauthorized, "未登录或登录已过期")) + return + } + + id, err := strconv.Atoi(c.Param("id")) + if err != nil || id <= 0 { + c.JSON(http.StatusBadRequest, model.Error(http.StatusBadRequest, "id 参数错误")) + return + } + + record, err := h.smokeLogService.GetByID(c.Request.Context(), int(user.ID), id) + if err != nil { + if errors.Is(err, smokeservice.ErrSmokeLogNotFound) { + c.JSON(http.StatusNotFound, model.Error(http.StatusNotFound, "记录不存在")) + return + } + c.JSON(http.StatusInternalServerError, model.Error(http.StatusInternalServerError, "查询失败,请稍后重试")) + return + } + + c.JSON(http.StatusOK, model.Success(record)) +} + +func (h *SmokeHandler) List(c *gin.Context) { + user, ok := middleware.CurrentUser(c) + if !ok { + c.JSON(http.StatusUnauthorized, model.Error(http.StatusUnauthorized, "未登录或登录已过期")) + return + } + + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + var start *time.Time + if v := c.Query("start"); v != "" { + parsed, err := time.ParseInLocation(dateLayout, v, time.Local) + if err != nil { + c.JSON(http.StatusBadRequest, model.Error(http.StatusBadRequest, "start 格式错误,应为 YYYY-MM-DD")) + return + } + start = &parsed + } + var end *time.Time + if v := c.Query("end"); v != "" { + parsed, err := time.ParseInLocation(dateLayout, v, time.Local) + if err != nil { + c.JSON(http.StatusBadRequest, model.Error(http.StatusBadRequest, "end 格式错误,应为 YYYY-MM-DD")) + return + } + end = &parsed + } + + result, err := h.smokeLogService.List(c.Request.Context(), int(user.ID), smokeservice.ListSmokeLogsRequest{ + Page: page, + PageSize: pageSize, + Start: start, + End: end, + }) + if err != nil { + c.JSON(http.StatusInternalServerError, model.Error(http.StatusInternalServerError, "查询列表失败,请稍后重试")) + return + } + + c.JSON(http.StatusOK, model.Success(gin.H{ + "items": result.Items, + "total": result.Total, + "page": result.Page, + "page_size": result.PageSize, + })) +} + +type updateSmokeLogRequest struct { + SmokeTime *string `json:"smoke_time"` + Remark *string `json:"remark"` + Level *int64 `json:"level"` + Num *int `json:"num"` +} + +func (h *SmokeHandler) Update(c *gin.Context) { + user, ok := middleware.CurrentUser(c) + if !ok { + c.JSON(http.StatusUnauthorized, model.Error(http.StatusUnauthorized, "未登录或登录已过期")) + return + } + + id, err := strconv.Atoi(c.Param("id")) + if err != nil || id <= 0 { + c.JSON(http.StatusBadRequest, model.Error(http.StatusBadRequest, "id 参数错误")) + return + } + + var req updateSmokeLogRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, model.Error(http.StatusBadRequest, "请求参数错误")) + return + } + + smokeTimeProvided := req.SmokeTime != nil + var smokeTime *time.Time + if req.SmokeTime != nil { + if *req.SmokeTime == "" { + smokeTime = nil + } else { + parsed, err := time.ParseInLocation(dateLayout, *req.SmokeTime, time.Local) + if err != nil { + c.JSON(http.StatusBadRequest, model.Error(http.StatusBadRequest, "smoke_time 格式错误,应为 YYYY-MM-DD")) + return + } + smokeTime = &parsed + } + } + + record, err := h.smokeLogService.Update(c.Request.Context(), int(user.ID), id, smokeservice.UpdateSmokeLogRequest{ + SmokeTimeProvided: smokeTimeProvided, + SmokeTime: smokeTime, + Remark: req.Remark, + Level: req.Level, + Num: req.Num, + }) + if err != nil { + if errors.Is(err, smokeservice.ErrSmokeLogNotFound) { + c.JSON(http.StatusNotFound, model.Error(http.StatusNotFound, "记录不存在")) + return + } + c.JSON(http.StatusInternalServerError, model.Error(http.StatusInternalServerError, "更新失败,请稍后重试")) + return + } + + c.JSON(http.StatusOK, model.Success(record)) +} + +func (h *SmokeHandler) Delete(c *gin.Context) { + user, ok := middleware.CurrentUser(c) + if !ok { + c.JSON(http.StatusUnauthorized, model.Error(http.StatusUnauthorized, "未登录或登录已过期")) + return + } + + id, err := strconv.Atoi(c.Param("id")) + if err != nil || id <= 0 { + c.JSON(http.StatusBadRequest, model.Error(http.StatusBadRequest, "id 参数错误")) + return + } + + if err := h.smokeLogService.Delete(c.Request.Context(), int(user.ID), id); err != nil { + if errors.Is(err, smokeservice.ErrSmokeLogNotFound) { + c.JSON(http.StatusNotFound, model.Error(http.StatusNotFound, "记录不存在")) + return + } + c.JSON(http.StatusInternalServerError, model.Error(http.StatusInternalServerError, "删除失败,请稍后重试")) + return + } + + c.JSON(http.StatusOK, model.Success(gin.H{ + "deleted": true, + })) +} diff --git a/internal/smoke/model/smoke_log.go b/internal/smoke/model/smoke_log.go new file mode 100644 index 0000000..9819e87 --- /dev/null +++ b/internal/smoke/model/smoke_log.go @@ -0,0 +1,30 @@ +package model + +import "time" + +// SmokeLog 对应数据库表 fa_smoke_log(戒烟/抽烟记录)。 +// +// 注意:这个表的字段命名来自旧系统(createtime/updatetime/deletetime 为秒级时间戳), +// 因此这里不使用 gorm.Model 的 created_at/updated_at/deleted_at。 +type SmokeLog struct { + // 复合主键(id, uid),其中 id 自增。 + ID int `gorm:"column:id;primaryKey;autoIncrement" json:"id"` + UID int `gorm:"column:uid;primaryKey" json:"-"` + + // smoke_time 在库里是 date 类型(只包含日期,不包含时分秒)。 + SmokeTime *time.Time `gorm:"column:smoke_time;type:date" json:"smoke_time,omitempty"` + + Remark string `gorm:"column:remark;type:text" json:"remark,omitempty"` + + // createtime/updatetime/deletetime:秒级 Unix 时间戳(与 gorm 默认字段不同) + CreateTime *int64 `gorm:"column:createtime" json:"createtime,omitempty"` + UpdateTime *int64 `gorm:"column:updatetime" json:"updatetime,omitempty"` + DeleteTime *int64 `gorm:"column:deletetime" json:"deletetime,omitempty"` + + Level int64 `gorm:"column:level;default:1" json:"level"` + Num int `gorm:"column:num;default:1" json:"num"` +} + +func (SmokeLog) TableName() string { + return "fa_smoke_log" +} diff --git a/internal/smoke/service/smoke_log_service.go b/internal/smoke/service/smoke_log_service.go new file mode 100644 index 0000000..d1271c4 --- /dev/null +++ b/internal/smoke/service/smoke_log_service.go @@ -0,0 +1,217 @@ +package service + +import ( + "context" + "errors" + "fmt" + "time" + + "gorm.io/gorm" + + smokemodel "wx_service/internal/smoke/model" +) + +var ( + ErrSmokeLogNotFound = errors.New("smoke log not found") +) + +type SmokeLogService struct { + db *gorm.DB +} + +func NewSmokeLogService(db *gorm.DB) *SmokeLogService { + return &SmokeLogService{db: db} +} + +type CreateSmokeLogRequest struct { + SmokeTime *time.Time + Remark string + Level int64 + Num int +} + +func (s *SmokeLogService) Create(ctx context.Context, uid int, req CreateSmokeLogRequest) (*smokemodel.SmokeLog, error) { + now := time.Now().Unix() + createTime := now + updateTime := now + + level := req.Level + if level <= 0 { + level = 1 + } + num := req.Num + if num <= 0 { + num = 1 + } + + smokeTime := req.SmokeTime + if smokeTime == nil { + today := time.Now() + startOfDay := time.Date(today.Year(), today.Month(), today.Day(), 0, 0, 0, 0, today.Location()) + smokeTime = &startOfDay + } + + record := smokemodel.SmokeLog{ + UID: uid, + SmokeTime: smokeTime, + Remark: req.Remark, + CreateTime: &createTime, + UpdateTime: &updateTime, + Level: level, + Num: num, + } + + if err := s.db.WithContext(ctx).Create(&record).Error; err != nil { + return nil, fmt.Errorf("create smoke log: %w", err) + } + return &record, nil +} + +func (s *SmokeLogService) GetByID(ctx context.Context, uid int, id int) (*smokemodel.SmokeLog, error) { + var record smokemodel.SmokeLog + err := s.db.WithContext(ctx). + Where("id = ? AND uid = ? AND (deletetime IS NULL OR deletetime = 0)", id, uid). + First(&record).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrSmokeLogNotFound + } + return nil, fmt.Errorf("load smoke log: %w", err) + } + return &record, nil +} + +type ListSmokeLogsRequest struct { + Page int + PageSize int + Start *time.Time + End *time.Time +} + +type ListSmokeLogsResult struct { + Items []smokemodel.SmokeLog + Total int64 + Page int + PageSize int +} + +func (s *SmokeLogService) List(ctx context.Context, uid int, req ListSmokeLogsRequest) (ListSmokeLogsResult, error) { + page := req.Page + if page <= 0 { + page = 1 + } + pageSize := req.PageSize + if pageSize <= 0 { + pageSize = 20 + } + if pageSize > 200 { + pageSize = 200 + } + + tx := s.db.WithContext(ctx).Model(&smokemodel.SmokeLog{}). + Where("uid = ? AND (deletetime IS NULL OR deletetime = 0)", uid) + + if req.Start != nil { + tx = tx.Where("smoke_time >= ?", req.Start.Format("2006-01-02")) + } + if req.End != nil { + tx = tx.Where("smoke_time <= ?", req.End.Format("2006-01-02")) + } + + var total int64 + if err := tx.Count(&total).Error; err != nil { + return ListSmokeLogsResult{}, fmt.Errorf("count smoke logs: %w", err) + } + + var items []smokemodel.SmokeLog + offset := (page - 1) * pageSize + if err := tx. + Order("smoke_time DESC"). + Order("id DESC"). + Limit(pageSize). + Offset(offset). + Find(&items).Error; err != nil { + return ListSmokeLogsResult{}, fmt.Errorf("list smoke logs: %w", err) + } + + return ListSmokeLogsResult{ + Items: items, + Total: total, + Page: page, + PageSize: pageSize, + }, nil +} + +type UpdateSmokeLogRequest struct { + // SmokeTimeProvided 用于区分: + // - false:前端没传 smoke_time(不修改) + // - true:前端传了 smoke_time(可以设置为具体日期,也可以清空为 NULL) + SmokeTimeProvided bool + SmokeTime *time.Time + Remark *string + Level *int64 + Num *int +} + +func (s *SmokeLogService) Update(ctx context.Context, uid int, id int, req UpdateSmokeLogRequest) (*smokemodel.SmokeLog, error) { + record, err := s.GetByID(ctx, uid, id) + if err != nil { + return nil, err + } + + updates := map[string]interface{}{} + if req.SmokeTimeProvided { + updates["smoke_time"] = req.SmokeTime + } + if req.Remark != nil { + updates["remark"] = *req.Remark + } + if req.Level != nil { + if *req.Level <= 0 { + updates["level"] = int64(1) + } else { + updates["level"] = *req.Level + } + } + if req.Num != nil { + if *req.Num <= 0 { + updates["num"] = 1 + } else { + updates["num"] = *req.Num + } + } + + now := time.Now().Unix() + updates["updatetime"] = now + + if len(updates) == 1 { + return record, nil + } + + if err := s.db.WithContext(ctx). + Model(&smokemodel.SmokeLog{}). + Where("id = ? AND uid = ?", id, uid). + Updates(updates).Error; err != nil { + return nil, fmt.Errorf("update smoke log: %w", err) + } + + return s.GetByID(ctx, uid, id) +} + +func (s *SmokeLogService) Delete(ctx context.Context, uid int, id int) error { + now := time.Now().Unix() + result := s.db.WithContext(ctx). + Model(&smokemodel.SmokeLog{}). + Where("id = ? AND uid = ? AND (deletetime IS NULL OR deletetime = 0)", id, uid). + Updates(map[string]interface{}{ + "deletetime": now, + "updatetime": now, + }) + if result.Error != nil { + return fmt.Errorf("delete smoke log: %w", result.Error) + } + if result.RowsAffected == 0 { + return ErrSmokeLogNotFound + } + return nil +}