test(expiry): 完成 #31 后端单元测试与覆盖率提升

This commit is contained in:
root
2026-03-04 18:35:32 +08:00
parent e1b5382004
commit 6c303abd58
7 changed files with 909 additions and 10 deletions
+10 -10
View File
@@ -151,7 +151,9 @@ func (r *Repository) GetSummary(userID uint) (map[string]int, error) {
now := dateOnly(time.Now())
sevenDaysLater := now.AddDate(0, 0, 7)
base := r.db.Model(&ExpiryItem{}).Where("user_id = ?", userID)
base := func() *gorm.DB {
return r.db.Model(&ExpiryItem{}).Where("user_id = ?", userID)
}
count := func(tx *gorm.DB) (int64, error) {
var v int64
@@ -161,35 +163,35 @@ func (r *Repository) GetSummary(userID uint) (map[string]int, error) {
return v, nil
}
totalItems, err := count(base.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}))
totalItems, err := count(base().Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}))
if err != nil {
return nil, fmt.Errorf("count total items: %w", err)
}
expiringSoon, err := count(base.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}).
expiringSoon, err := count(base().Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}).
Where("expiry_date BETWEEN ? AND ?", now, sevenDaysLater))
if err != nil {
return nil, fmt.Errorf("count expiring items: %w", err)
}
expired, err := count(base.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}).
expired, err := count(base().Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}).
Where("expiry_date < ?", now))
if err != nil {
return nil, fmt.Errorf("count expired items: %w", err)
}
normal, err := count(base.Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}).
normal, err := count(base().Where("status NOT IN ?", []string{StatusUsed, StatusDiscarded}).
Where("expiry_date > ?", sevenDaysLater))
if err != nil {
return nil, fmt.Errorf("count normal items: %w", err)
}
used, err := count(base.Where("status = ?", StatusUsed))
used, err := count(base().Where("status = ?", StatusUsed))
if err != nil {
return nil, fmt.Errorf("count used items: %w", err)
}
discarded, err := count(base.Where("status = ?", StatusDiscarded))
discarded, err := count(base().Where("status = ?", StatusDiscarded))
if err != nil {
return nil, fmt.Errorf("count discarded items: %w", err)
}
@@ -250,9 +252,7 @@ func (r *Repository) UpdateSettings(userID uint, remindDays []int) (*ExpiryUserS
}
settings.RemindDays = copyIntSlice(remindDays)
if err := r.db.Model(&ExpiryUserSettings{}).
Where("user_id = ?", userID).
Update("remind_days", settings.RemindDays).Error; err != nil {
if err := r.db.Save(settings).Error; err != nil {
return nil, fmt.Errorf("update expiry settings: %w", err)
}
return settings, nil