diff --git a/internal/smoke/service/smoke_ai_advice_service.go b/internal/smoke/service/smoke_ai_advice_service.go index a772e76..d37bb27 100644 --- a/internal/smoke/service/smoke_ai_advice_service.go +++ b/internal/smoke/service/smoke_ai_advice_service.go @@ -31,8 +31,8 @@ const ( ) const ( - SmokeAIAdviceTypeDaily = "daily_advice" - SmokeAIAdviceTypeNextSmoke = "next_smoke_time" + SmokeAIAdviceTypeDaily = "daily_advice" + SmokeAIAdviceTypeNextSmoke = "next_smoke_time" ) type SmokeAIAdviceService struct { @@ -70,16 +70,17 @@ type adviceSnapshot struct { } type adviceUserProfile struct { - BaselineCigsPerDay int `json:"baseline_cigs_per_day,omitempty"` - SmokingYears float64 `json:"smoking_years,omitempty"` - PackPriceCent int `json:"pack_price_cent,omitempty"` - SmokeMotivations []string `json:"smoke_motivations,omitempty"` - QuitMotivations []string `json:"quit_motivations,omitempty"` - WakeUpTime string `json:"wake_up_time,omitempty"` - SleepTime string `json:"sleep_time,omitempty"` - AwakeMinutes int `json:"awake_minutes,omitempty"` - BaselineIntervalMinutes int `json:"baseline_interval_minutes,omitempty"` - OnboardingCompletedAtISO string `json:"onboarding_completed_at,omitempty"` + BaselineCigsPerDay int `json:"baseline_cigs_per_day,omitempty"` + SmokingYears float64 `json:"smoking_years,omitempty"` + PackPriceCent int `json:"pack_price_cent,omitempty"` + SmokeMotivations []string `json:"smoke_motivations,omitempty"` + QuitMotivations []string `json:"quit_motivations,omitempty"` + WakeUpTime string `json:"wake_up_time,omitempty"` + SleepTime string `json:"sleep_time,omitempty"` + AwakeMinutes int `json:"awake_minutes,omitempty"` + BaselineIntervalMinutes int `json:"baseline_interval_minutes,omitempty"` + UserSegment string `json:"user_segment,omitempty"` + OnboardingCompletedAtISO string `json:"onboarding_completed_at,omitempty"` } func (s *SmokeAIAdviceService) GetOrGenerate(ctx context.Context, user *usermodel.User, adviceDate time.Time, promptVersion string) (*smokemodel.SmokeAIAdvice, error) { @@ -284,9 +285,9 @@ func (s *SmokeAIAdviceService) buildSnapshot(ctx context.Context, uid int, advic timeLabel = t.eventAt.Format("15:04") } nodes = append(nodes, adviceSnapshotNode{ - Time: timeLabel, - Num: t.log.Num, - Level: t.log.Level, + Time: timeLabel, + Num: t.log.Num, + Level: t.log.Level, Remark: t.log.Remark, }) } @@ -334,6 +335,7 @@ func loadAdviceUserProfile(ctx context.Context, db *gorm.DB, uid int) *adviceUse SleepTime: sleep, AwakeMinutes: awake, BaselineIntervalMinutes: baselineIntervalMinutes(awake, profile.BaselineCigsPerDay), + UserSegment: deriveUserSegment(profile.BaselineCigsPerDay, profile.SmokingYears), } if profile.OnboardingCompletedAt != nil { out.OnboardingCompletedAtISO = profile.OnboardingCompletedAt.In(time.Local).Format(time.RFC3339) @@ -467,3 +469,13 @@ func truncateString(s string, max int) string { } return s[:max] } + +func deriveUserSegment(baselineCigsPerDay int, smokingYears float64) string { + if baselineCigsPerDay >= 20 || smokingYears >= 10 { + return "heavy" + } + if baselineCigsPerDay >= 10 || smokingYears >= 3 { + return "moderate" + } + return "newbie" +} diff --git a/internal/smoke/service/smoke_ai_advice_service_test.go b/internal/smoke/service/smoke_ai_advice_service_test.go new file mode 100644 index 0000000..44432e5 --- /dev/null +++ b/internal/smoke/service/smoke_ai_advice_service_test.go @@ -0,0 +1,32 @@ +package service + +import "testing" + +func TestDeriveUserSegment(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + baselinePerDay int + smokingYears float64 + expectedSegment string + }{ + {name: "新手用户", baselinePerDay: 6, smokingYears: 1.5, expectedSegment: "newbie"}, + {name: "中度用户_按日均", baselinePerDay: 12, smokingYears: 1, expectedSegment: "moderate"}, + {name: "中度用户_按烟龄", baselinePerDay: 8, smokingYears: 3.2, expectedSegment: "moderate"}, + {name: "重度用户_按日均", baselinePerDay: 20, smokingYears: 2, expectedSegment: "heavy"}, + {name: "重度用户_按烟龄", baselinePerDay: 9, smokingYears: 10, expectedSegment: "heavy"}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := deriveUserSegment(tc.baselinePerDay, tc.smokingYears) + if got != tc.expectedSegment { + t.Fatalf("deriveUserSegment(%d, %.1f)=%s, want=%s", tc.baselinePerDay, tc.smokingYears, got, tc.expectedSegment) + } + }) + } +}