diff --git a/internal/expiry/repository.go b/internal/expiry/repository.go index d29bbd5..5ae7422 100644 --- a/internal/expiry/repository.go +++ b/internal/expiry/repository.go @@ -1,12 +1,219 @@ package expiry -import "gorm.io/gorm" +import ( + "errors" + "fmt" + "time" + + "gorm.io/gorm" +) // Repository 封装保质期模块的数据访问能力。 type Repository struct { db *gorm.DB } +var ( + ErrExpiryItemNotFound = errors.New("expiry item not found") +) + func NewRepository(db *gorm.DB) *Repository { return &Repository{db: db} } + +// Create 创建物品。 +func (r *Repository) Create(item *ExpiryItem) error { + if err := r.db.Create(item).Error; err != nil { + return fmt.Errorf("create expiry item: %w", err) + } + return nil +} + +// Update 更新物品。 +func (r *Repository) Update(item *ExpiryItem) error { + updates := map[string]interface{}{ + "mini_program_id": item.MiniProgramID, + "name": item.Name, + "category": item.Category, + "production_date": item.ProductionDate, + "expiry_date": item.ExpiryDate, + "shelf_life_days": item.ShelfLifeDays, + "quantity": item.Quantity, + "location": item.Location, + "remark": item.Remark, + "status": item.Status, + } + + tx := r.db.Model(&ExpiryItem{}). + Where("id = ? AND user_id = ?", item.ID, item.UserID). + Updates(updates) + if tx.Error != nil { + return fmt.Errorf("update expiry item: %w", tx.Error) + } + if tx.RowsAffected == 0 { + return ErrExpiryItemNotFound + } + return nil +} + +// Delete 软删除物品。 +func (r *Repository) Delete(id, userID uint) error { + tx := r.db.Where("id = ? AND user_id = ?", id, userID).Delete(&ExpiryItem{}) + if tx.Error != nil { + return fmt.Errorf("delete expiry item: %w", tx.Error) + } + if tx.RowsAffected == 0 { + return ErrExpiryItemNotFound + } + return nil +} + +// FindByID 根据 ID 查询单个物品。 +func (r *Repository) FindByID(id, userID uint) (*ExpiryItem, error) { + var item ExpiryItem + err := r.db.Where("id = ? AND user_id = ?", id, userID).First(&item).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrExpiryItemNotFound + } + return nil, fmt.Errorf("find expiry item by id: %w", err) + } + return &item, nil +} + +// FindByUser 按用户 + 筛选条件分页查询物品列表。 +func (r *Repository) FindByUser( + userID uint, + filters map[string]interface{}, + page, pageSize int, +) ([]ExpiryItem, int64, error) { + if page <= 0 { + page = 1 + } + if pageSize <= 0 { + pageSize = 20 + } + if pageSize > 100 { + pageSize = 100 + } + + query := r.db.Model(&ExpiryItem{}).Where("user_id = ?", userID) + now := dateOnly(time.Now()) + sevenDaysLater := now.AddDate(0, 0, 7) + + if status, ok := filters["status"].(string); ok && status != "" && status != "all" { + switch status { + case StatusExpiring: + query = query.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}). + Where("expiry_date BETWEEN ? AND ?", now, sevenDaysLater) + case StatusExpired: + query = query.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}). + Where("expiry_date < ?", now) + case StatusNormal: + query = query.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}). + Where("expiry_date > ?", sevenDaysLater) + case StatusUsed, StatusDiscarded: + query = query.Where("status = ?", status) + } + } + + if category, ok := filters["category"].(string); ok && category != "" && category != "all" { + query = query.Where("category = ?", category) + } + + orderBy := "expiry_date ASC" + if sort, ok := filters["sort"].(string); ok && sort != "" { + switch sort { + case "created_at": + orderBy = "created_at DESC" + case "expiry_date": + orderBy = "expiry_date ASC" + } + } + query = query.Order(orderBy).Order("id DESC") + + var total int64 + if err := query.Count(&total).Error; err != nil { + return nil, 0, fmt.Errorf("count expiry items: %w", err) + } + + var items []ExpiryItem + offset := (page - 1) * pageSize + if err := query.Offset(offset).Limit(pageSize).Find(&items).Error; err != nil { + return nil, 0, fmt.Errorf("list expiry items: %w", err) + } + + return items, total, nil +} + +// GetSummary 获取首页汇总统计数据。 +func (r *Repository) GetSummary(userID uint) (map[string]int, error) { + now := dateOnly(time.Now()) + sevenDaysLater := now.AddDate(0, 0, 7) + + base := r.db.Model(&ExpiryItem{}).Where("user_id = ?", userID) + + count := func(tx *gorm.DB) (int64, error) { + var v int64 + if err := tx.Count(&v).Error; err != nil { + return 0, err + } + return v, nil + } + + totalItems, err := count(base.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded})) + if err != nil { + return nil, fmt.Errorf("count total items: %w", err) + } + + expiringSoon, err := count(base.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}). + Where("expiry_date BETWEEN ? AND ?", now, sevenDaysLater)) + if err != nil { + return nil, fmt.Errorf("count expiring items: %w", err) + } + + expired, err := count(base.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}). + Where("expiry_date < ?", now)) + if err != nil { + return nil, fmt.Errorf("count expired items: %w", err) + } + + normal, err := count(base.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}). + Where("expiry_date > ?", sevenDaysLater)) + if err != nil { + return nil, fmt.Errorf("count normal items: %w", err) + } + + used, err := count(base.Where("status = ?", StatusUsed)) + if err != nil { + return nil, fmt.Errorf("count used items: %w", err) + } + + discarded, err := count(base.Where("status = ?", StatusDiscarded)) + if err != nil { + return nil, fmt.Errorf("count discarded items: %w", err) + } + + return map[string]int{ + "total_items": int(totalItems), + "expiring_soon": int(expiringSoon), + "expired": int(expired), + "normal": int(normal), + "used": int(used), + "discarded": int(discarded), + }, nil +} + +// UpdateStatus 更新物品状态(used/discarded)。 +func (r *Repository) UpdateStatus(id, userID uint, status string) error { + tx := r.db.Model(&ExpiryItem{}). + Where("id = ? AND user_id = ?", id, userID). + Update("status", status) + if tx.Error != nil { + return fmt.Errorf("update expiry item status: %w", tx.Error) + } + if tx.RowsAffected == 0 { + return ErrExpiryItemNotFound + } + return nil +}