package expiry import ( "errors" "fmt" "time" "gorm.io/gorm" ) // Repository 封装保质期模块的数据访问能力。 type Repository struct { db *gorm.DB } var ( ErrExpiryItemNotFound = errors.New("expiry item not found") ) func NewRepository(db *gorm.DB) *Repository { return &Repository{db: db} } // Create 创建物品。 func (r *Repository) Create(item *ExpiryItem) error { if err := r.db.Create(item).Error; err != nil { return fmt.Errorf("create expiry item: %w", err) } return nil } // Update 更新物品。 func (r *Repository) Update(item *ExpiryItem) error { updates := map[string]interface{}{ "mini_program_id": item.MiniProgramID, "name": item.Name, "category": item.Category, "production_date": item.ProductionDate, "expiry_date": item.ExpiryDate, "shelf_life_days": item.ShelfLifeDays, "quantity": item.Quantity, "location": item.Location, "remark": item.Remark, "status": item.Status, } tx := r.db.Model(&ExpiryItem{}). Where("id = ? AND user_id = ?", item.ID, item.UserID). Updates(updates) if tx.Error != nil { return fmt.Errorf("update expiry item: %w", tx.Error) } if tx.RowsAffected == 0 { return ErrExpiryItemNotFound } return nil } // Delete 软删除物品。 func (r *Repository) Delete(id, userID uint) error { tx := r.db.Where("id = ? AND user_id = ?", id, userID).Delete(&ExpiryItem{}) if tx.Error != nil { return fmt.Errorf("delete expiry item: %w", tx.Error) } if tx.RowsAffected == 0 { return ErrExpiryItemNotFound } return nil } // FindByID 根据 ID 查询单个物品。 func (r *Repository) FindByID(id, userID uint) (*ExpiryItem, error) { var item ExpiryItem err := r.db.Where("id = ? AND user_id = ?", id, userID).First(&item).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrExpiryItemNotFound } return nil, fmt.Errorf("find expiry item by id: %w", err) } return &item, nil } // FindByUser 按用户 + 筛选条件分页查询物品列表。 func (r *Repository) FindByUser( userID uint, filters map[string]interface{}, page, pageSize int, ) ([]ExpiryItem, int64, error) { if page <= 0 { page = 1 } if pageSize <= 0 { pageSize = 20 } if pageSize > 100 { pageSize = 100 } query := r.db.Model(&ExpiryItem{}).Where("user_id = ?", userID) now := dateOnly(time.Now()) sevenDaysLater := now.AddDate(0, 0, 7) if status, ok := filters["status"].(string); ok && status != "" && status != "all" { switch status { case StatusExpiring: query = query.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}). Where("expiry_date BETWEEN ? AND ?", now, sevenDaysLater) case StatusExpired: query = query.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}). Where("expiry_date < ?", now) case StatusNormal: query = query.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}). Where("expiry_date > ?", sevenDaysLater) case StatusUsed, StatusDiscarded: query = query.Where("status = ?", status) } } if category, ok := filters["category"].(string); ok && category != "" && category != "all" { query = query.Where("category = ?", category) } orderBy := "expiry_date ASC" if sort, ok := filters["sort"].(string); ok && sort != "" { switch sort { case "created_at": orderBy = "created_at DESC" case "expiry_date": orderBy = "expiry_date ASC" } } query = query.Order(orderBy).Order("id DESC") var total int64 if err := query.Count(&total).Error; err != nil { return nil, 0, fmt.Errorf("count expiry items: %w", err) } var items []ExpiryItem offset := (page - 1) * pageSize if err := query.Offset(offset).Limit(pageSize).Find(&items).Error; err != nil { return nil, 0, fmt.Errorf("list expiry items: %w", err) } return items, total, nil } // GetSummary 获取首页汇总统计数据。 func (r *Repository) GetSummary(userID uint) (map[string]int, error) { now := dateOnly(time.Now()) sevenDaysLater := now.AddDate(0, 0, 7) base := r.db.Model(&ExpiryItem{}).Where("user_id = ?", userID) count := func(tx *gorm.DB) (int64, error) { var v int64 if err := tx.Count(&v).Error; err != nil { return 0, err } return v, nil } totalItems, err := count(base.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded})) if err != nil { return nil, fmt.Errorf("count total items: %w", err) } expiringSoon, err := count(base.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}). Where("expiry_date BETWEEN ? AND ?", now, sevenDaysLater)) if err != nil { return nil, fmt.Errorf("count expiring items: %w", err) } expired, err := count(base.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}). Where("expiry_date < ?", now)) if err != nil { return nil, fmt.Errorf("count expired items: %w", err) } normal, err := count(base.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}). Where("expiry_date > ?", sevenDaysLater)) if err != nil { return nil, fmt.Errorf("count normal items: %w", err) } used, err := count(base.Where("status = ?", StatusUsed)) if err != nil { return nil, fmt.Errorf("count used items: %w", err) } discarded, err := count(base.Where("status = ?", StatusDiscarded)) if err != nil { return nil, fmt.Errorf("count discarded items: %w", err) } return map[string]int{ "total_items": int(totalItems), "expiring_soon": int(expiringSoon), "expired": int(expired), "normal": int(normal), "used": int(used), "discarded": int(discarded), }, nil } // UpdateStatus 更新物品状态(used/discarded)。 func (r *Repository) UpdateStatus(id, userID uint, status string) error { tx := r.db.Model(&ExpiryItem{}). Where("id = ? AND user_id = ?", id, userID). Update("status", status) if tx.Error != nil { return fmt.Errorf("update expiry item status: %w", tx.Error) } if tx.RowsAffected == 0 { return ErrExpiryItemNotFound } return nil } // GetSettings 查询用户提醒设置。 func (r *Repository) GetSettings(userID uint) (*ExpiryUserSettings, error) { var settings ExpiryUserSettings err := r.db.Where("user_id = ?", userID).First(&settings).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("find expiry settings: %w", err) } return &settings, nil } // UpdateSettings 更新用户提醒设置(不存在则创建)。 func (r *Repository) UpdateSettings(userID uint, remindDays []int) (*ExpiryUserSettings, error) { settings, err := r.GetSettings(userID) if err != nil { return nil, err } if settings == nil { settings = &ExpiryUserSettings{ UserID: userID, RemindDays: copyIntSlice(remindDays), } if err := r.db.Create(settings).Error; err != nil { return nil, fmt.Errorf("create expiry settings: %w", err) } return settings, nil } settings.RemindDays = copyIntSlice(remindDays) if err := r.db.Model(&ExpiryUserSettings{}). Where("user_id = ?", userID). Update("remind_days", settings.RemindDays).Error; err != nil { return nil, fmt.Errorf("update expiry settings: %w", err) } return settings, nil } func copyIntSlice(values []int) []int { if len(values) == 0 { return nil } copied := make([]int, len(values)) copy(copied, values) return copied }