6cf7eb2294
- Added new API endpoint `GET /api/v1/smoke/next_smoke_time` to provide AI-generated suggestions for the next smoking time based on user data. - Introduced a new database table `fa_smoke_ai_next_smoke` to store structured AI time node suggestions. - Updated smoke handler and service to integrate the new AI next smoke time functionality. - Enhanced documentation to reflect the new API endpoint and its usage, including details on how to generate AI time nodes.
470 lines
14 KiB
Go
470 lines
14 KiB
Go
package service
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"gorm.io/gorm"
|
|
|
|
"wx_service/config"
|
|
usermodel "wx_service/internal/model"
|
|
smokemodel "wx_service/internal/smoke/model"
|
|
)
|
|
|
|
// Sentinel errors returned by SmokeAIAdviceService; compare with errors.Is.
var (
	// ErrAIAdviceLocked is returned by GetOrGenerate when the user is neither an
	// active member nor has an unlock record for the requested day.
	ErrAIAdviceLocked = errors.New("ai advice is locked, vip or ad unlock required")

	// ErrAIServiceDisabled is returned when the AI endpoint configuration
	// (API key, model or base URL) is incomplete.
	ErrAIServiceDisabled = errors.New("ai service is not configured")

	// ErrNoSmokeLogs is returned when the requested day has no smoke logs, so
	// there is nothing to base advice on.
	ErrNoSmokeLogs = errors.New("no smoke logs for date")
)

const (
	// DefaultAdvicePromptVersion replaces an empty prompt version passed to
	// GetOrGenerate; the prompt version is part of the cached-advice lookup.
	DefaultAdvicePromptVersion = "v2"

	// defaultTemperature is the sampling temperature sent with every
	// chat-completions request.
	defaultTemperature = 0.7

	// SmokeAIAdviceTypeDaily is the type value of daily advice records.
	SmokeAIAdviceTypeDaily = "daily_advice"

	// SmokeAIAdviceTypeNextSmoke is the type value reserved for next-smoke-time
	// suggestions; the daily-advice flow in this file does not use it.
	SmokeAIAdviceTypeNextSmoke = "next_smoke_time"
)
|
|
|
|
type SmokeAIAdviceService struct {
|
|
db *gorm.DB
|
|
cfg config.AIConfig
|
|
client *http.Client
|
|
}
|
|
|
|
func NewSmokeAIAdviceService(db *gorm.DB, cfg config.AIConfig) *SmokeAIAdviceService {
|
|
timeout := cfg.RequestTimeout
|
|
if timeout <= 0 {
|
|
timeout = 15 * time.Second
|
|
}
|
|
return &SmokeAIAdviceService{
|
|
db: db,
|
|
cfg: cfg,
|
|
client: &http.Client{
|
|
Timeout: timeout,
|
|
},
|
|
}
|
|
}
|
|
|
|
// Wire shapes for the per-day summary that is sent to the AI and stored as the
// advice record's input snapshot. Field order determines JSON key order.
type (
	// adviceSnapshotNode is one smoke-log entry: its HH:MM label (empty when
	// the log has no usable timestamp), count, level and optional remark.
	adviceSnapshotNode struct {
		Time   string `json:"time"`
		Num    int    `json:"num"`
		Level  int64  `json:"level"`
		Remark string `json:"remark,omitempty"`
	}

	// adviceSnapshot is the whole day: date (YYYY-MM-DD), summed count, the
	// ordered nodes, and the user's profile when one could be loaded.
	adviceSnapshot struct {
		Date     string               `json:"date"`
		TotalNum int                  `json:"total_num"`
		Nodes    []adviceSnapshotNode `json:"nodes"`
		Profile  *adviceUserProfile   `json:"profile,omitempty"`
	}

	// adviceUserProfile is the subset of the user's smoke profile given to the
	// AI as context; zero-valued fields are omitted from the JSON.
	adviceUserProfile struct {
		BaselineCigsPerDay       int      `json:"baseline_cigs_per_day,omitempty"`
		SmokingYears             float64  `json:"smoking_years,omitempty"`
		PackPriceCent            int      `json:"pack_price_cent,omitempty"`
		SmokeMotivations         []string `json:"smoke_motivations,omitempty"`
		QuitMotivations          []string `json:"quit_motivations,omitempty"`
		WakeUpTime               string   `json:"wake_up_time,omitempty"`
		SleepTime                string   `json:"sleep_time,omitempty"`
		AwakeMinutes             int      `json:"awake_minutes,omitempty"`
		BaselineIntervalMinutes  int      `json:"baseline_interval_minutes,omitempty"`
		OnboardingCompletedAtISO string   `json:"onboarding_completed_at,omitempty"`
	}
)
|
|
|
|
func (s *SmokeAIAdviceService) GetOrGenerate(ctx context.Context, user *usermodel.User, adviceDate time.Time, promptVersion string) (*smokemodel.SmokeAIAdvice, error) {
|
|
if promptVersion == "" {
|
|
promptVersion = DefaultAdvicePromptVersion
|
|
}
|
|
|
|
cached, err := s.getCached(ctx, int(user.ID), SmokeAIAdviceTypeDaily, adviceDate, promptVersion)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if cached != nil {
|
|
return cached, nil
|
|
}
|
|
|
|
allowed, err := s.isAllowed(ctx, user, adviceDate)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !allowed {
|
|
return nil, ErrAIAdviceLocked
|
|
}
|
|
|
|
snapshot, snapshotJSON, err := s.buildSnapshot(ctx, int(user.ID), adviceDate)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
adviceText, modelName, tokensIn, tokensOut, err := s.callAI(ctx, snapshot)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
now := time.Now().Unix()
|
|
createTime := now
|
|
updateTime := now
|
|
|
|
record := smokemodel.SmokeAIAdvice{
|
|
UID: int(user.ID),
|
|
Type: SmokeAIAdviceTypeDaily,
|
|
AdviceDate: dateOnly(adviceDate),
|
|
PromptVersion: promptVersion,
|
|
Provider: "openai-compatible",
|
|
Model: modelName,
|
|
InputSnapshot: snapshotJSON,
|
|
Advice: adviceText,
|
|
TokensIn: tokensIn,
|
|
TokensOut: tokensOut,
|
|
CreateTime: &createTime,
|
|
UpdateTime: &updateTime,
|
|
}
|
|
|
|
if err := s.db.WithContext(ctx).Create(&record).Error; err != nil {
|
|
return nil, fmt.Errorf("save ai advice: %w", err)
|
|
}
|
|
return &record, nil
|
|
}
|
|
|
|
func (s *SmokeAIAdviceService) Unlock(ctx context.Context, user *usermodel.User, unlockDate time.Time) error {
|
|
now := time.Now()
|
|
nowUnix := now.Unix()
|
|
startOfDay := dateOnly(unlockDate)
|
|
|
|
var existing smokemodel.SmokeAIAdviceUnlock
|
|
tx := s.db.WithContext(ctx)
|
|
err := tx.Where("uid = ? AND unlock_date = ? AND (deletetime IS NULL OR deletetime = 0)", user.ID, startOfDay.Format("2006-01-02")).
|
|
First(&existing).Error
|
|
if err == nil {
|
|
return tx.Model(&existing).Updates(map[string]interface{}{
|
|
"ad_watched_at": now,
|
|
"updatetime": nowUnix,
|
|
}).Error
|
|
}
|
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return fmt.Errorf("load unlock record: %w", err)
|
|
}
|
|
|
|
createTime := nowUnix
|
|
updateTime := nowUnix
|
|
record := smokemodel.SmokeAIAdviceUnlock{
|
|
UID: int(user.ID),
|
|
UnlockDate: startOfDay,
|
|
AdWatchedAt: now,
|
|
CreateTime: &createTime,
|
|
UpdateTime: &updateTime,
|
|
}
|
|
if err := tx.Create(&record).Error; err != nil {
|
|
return fmt.Errorf("create unlock record: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *SmokeAIAdviceService) getCached(ctx context.Context, uid int, adviceType string, adviceDate time.Time, promptVersion string) (*smokemodel.SmokeAIAdvice, error) {
|
|
var record smokemodel.SmokeAIAdvice
|
|
err := s.db.WithContext(ctx).
|
|
Where("uid = ? AND type = ? AND advice_date = ? AND prompt_version = ? AND (deletetime IS NULL OR deletetime = 0)",
|
|
uid, adviceType, dateOnly(adviceDate).Format("2006-01-02"), promptVersion).
|
|
First(&record).Error
|
|
if err == nil {
|
|
return &record, nil
|
|
}
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("load cached advice: %w", err)
|
|
}
|
|
|
|
func (s *SmokeAIAdviceService) isAllowed(ctx context.Context, user *usermodel.User, adviceDate time.Time) (bool, error) {
|
|
isVIP, err := s.isVIP(ctx, user)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if isVIP {
|
|
return true, nil
|
|
}
|
|
return s.isUnlocked(ctx, int(user.ID), adviceDate)
|
|
}
|
|
|
|
func (s *SmokeAIAdviceService) isVIP(ctx context.Context, user *usermodel.User) (bool, error) {
|
|
now := time.Now()
|
|
var count int64
|
|
if err := s.db.WithContext(ctx).
|
|
Model(&usermodel.UserMembership{}).
|
|
Where("mini_program_id = ? AND user_id = ? AND status = ? AND ends_at > ?",
|
|
user.MiniProgramID, user.ID, "active", now).
|
|
Count(&count).Error; err != nil {
|
|
return false, fmt.Errorf("check vip: %w", err)
|
|
}
|
|
return count > 0, nil
|
|
}
|
|
|
|
func (s *SmokeAIAdviceService) isUnlocked(ctx context.Context, uid int, adviceDate time.Time) (bool, error) {
|
|
startOfDay := dateOnly(adviceDate)
|
|
var unlock smokemodel.SmokeAIAdviceUnlock
|
|
err := s.db.WithContext(ctx).
|
|
Where("uid = ? AND unlock_date = ? AND (deletetime IS NULL OR deletetime = 0)", uid, startOfDay.Format("2006-01-02")).
|
|
First(&unlock).Error
|
|
if err == nil {
|
|
return true, nil
|
|
}
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return false, nil
|
|
}
|
|
return false, fmt.Errorf("check unlock: %w", err)
|
|
}
|
|
|
|
func (s *SmokeAIAdviceService) buildSnapshot(ctx context.Context, uid int, adviceDate time.Time) (adviceSnapshot, []byte, error) {
|
|
var logs []smokemodel.SmokeLog
|
|
if err := s.db.WithContext(ctx).
|
|
Where("uid = ? AND smoke_time = ? AND (deletetime IS NULL OR deletetime = 0)", uid, dateOnly(adviceDate).Format("2006-01-02")).
|
|
Find(&logs).Error; err != nil {
|
|
return adviceSnapshot{}, nil, fmt.Errorf("load smoke logs: %w", err)
|
|
}
|
|
if len(logs) == 0 {
|
|
return adviceSnapshot{}, nil, ErrNoSmokeLogs
|
|
}
|
|
|
|
profile := loadAdviceUserProfile(ctx, s.db, uid)
|
|
|
|
type timedLog struct {
|
|
log smokemodel.SmokeLog
|
|
eventAt time.Time
|
|
hasEvent bool
|
|
}
|
|
|
|
timed := make([]timedLog, 0, len(logs))
|
|
total := 0
|
|
for _, l := range logs {
|
|
total += l.Num
|
|
var eventAt time.Time
|
|
has := false
|
|
if l.SmokeAt != nil {
|
|
eventAt = *l.SmokeAt
|
|
has = true
|
|
} else if l.CreateTime != nil && *l.CreateTime > 0 {
|
|
eventAt = time.Unix(*l.CreateTime, 0).In(time.Local)
|
|
has = true
|
|
}
|
|
timed = append(timed, timedLog{log: l, eventAt: eventAt, hasEvent: has})
|
|
}
|
|
|
|
// 按时间节点排序(若没有 event time,则排在最后,仍保持稳定)
|
|
sort.SliceStable(timed, func(i, j int) bool {
|
|
a := timed[i]
|
|
b := timed[j]
|
|
if a.hasEvent && b.hasEvent {
|
|
if !a.eventAt.Equal(b.eventAt) {
|
|
return a.eventAt.Before(b.eventAt)
|
|
}
|
|
return a.log.ID < b.log.ID
|
|
}
|
|
if a.hasEvent != b.hasEvent {
|
|
return a.hasEvent
|
|
}
|
|
return a.log.ID < b.log.ID
|
|
})
|
|
|
|
nodes := make([]adviceSnapshotNode, 0, len(timed))
|
|
for _, t := range timed {
|
|
timeLabel := ""
|
|
if t.hasEvent {
|
|
timeLabel = t.eventAt.Format("15:04")
|
|
}
|
|
nodes = append(nodes, adviceSnapshotNode{
|
|
Time: timeLabel,
|
|
Num: t.log.Num,
|
|
Level: t.log.Level,
|
|
Remark: t.log.Remark,
|
|
})
|
|
}
|
|
|
|
snap := adviceSnapshot{
|
|
Date: dateOnly(adviceDate).Format("2006-01-02"),
|
|
TotalNum: total,
|
|
Nodes: nodes,
|
|
Profile: profile,
|
|
}
|
|
|
|
b, err := json.Marshal(snap)
|
|
if err != nil {
|
|
return adviceSnapshot{}, nil, fmt.Errorf("marshal snapshot: %w", err)
|
|
}
|
|
return snap, b, nil
|
|
}
|
|
|
|
func loadAdviceUserProfile(ctx context.Context, db *gorm.DB, uid int) *adviceUserProfile {
|
|
var profile smokemodel.SmokeUserProfile
|
|
err := db.WithContext(ctx).
|
|
Where("uid = ? AND deleted_at IS NULL", uid).
|
|
First(&profile).Error
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
awake := defaultAwakeMinutes
|
|
wake := strings.TrimSpace(profile.WakeUpTime)
|
|
sleep := strings.TrimSpace(profile.SleepTime)
|
|
if v, err := awakeMinutesWithFallback(wake, sleep); err == nil {
|
|
awake = v
|
|
} else {
|
|
wake = ""
|
|
sleep = ""
|
|
}
|
|
|
|
out := adviceUserProfile{
|
|
BaselineCigsPerDay: profile.BaselineCigsPerDay,
|
|
SmokingYears: profile.SmokingYears,
|
|
PackPriceCent: profile.PackPriceCent,
|
|
SmokeMotivations: []string(profile.SmokeMotivations),
|
|
QuitMotivations: []string(profile.QuitMotivations),
|
|
WakeUpTime: wake,
|
|
SleepTime: sleep,
|
|
AwakeMinutes: awake,
|
|
BaselineIntervalMinutes: baselineIntervalMinutes(awake, profile.BaselineCigsPerDay),
|
|
}
|
|
if profile.OnboardingCompletedAt != nil {
|
|
out.OnboardingCompletedAtISO = profile.OnboardingCompletedAt.In(time.Local).Format(time.RFC3339)
|
|
}
|
|
return &out
|
|
}
|
|
|
|
func (s *SmokeAIAdviceService) callAI(ctx context.Context, snap adviceSnapshot) (string, string, *int, *int, error) {
|
|
if s.cfg.APIKey == "" || s.cfg.Model == "" || s.cfg.BaseURL == "" {
|
|
return "", "", nil, nil, ErrAIServiceDisabled
|
|
}
|
|
|
|
systemPrompt := strings.TrimSpace(`
|
|
你是一名专业的戒烟教练与行为改变顾问。你需要基于用户昨天的抽烟总量与时间节点,给出可执行、可量化的戒烟/控烟建议。
|
|
要求:
|
|
1) 用中文输出;
|
|
2) 先给出对昨天模式的简短分析(1-3条);
|
|
3) 给出今天的具体行动方案(至少5条,包含替代行为、触发场景应对、时间节点策略);
|
|
4) 如果 profile 中提供了「作息时间」,建议的执行时间点要避开用户睡眠区间;
|
|
5) 如果 profile 中提供了「抽烟动机/戒烟动力」,你需要在建议中更有针对性地引用它们:
|
|
- 动机:用于解释触发场景与 remark 的关联,给出替代行为;
|
|
- 动力:用于“情感阻断/动摇时的自我提醒”(给 2-3 条可复述的话术);
|
|
6) 给出一个“如果忍不住想抽”的 60 秒应对流程;
|
|
7) 语气友好、不指责;不提供医疗诊断。
|
|
`)
|
|
|
|
userPrompt := fmt.Sprintf("用户昨日数据(JSON):\n%s", mustJSON(snap))
|
|
|
|
reqBody := chatCompletionRequest{
|
|
Model: s.cfg.Model,
|
|
Messages: []chatMessage{
|
|
{Role: "system", Content: systemPrompt},
|
|
{Role: "user", Content: userPrompt},
|
|
},
|
|
Temperature: defaultTemperature,
|
|
}
|
|
|
|
payload, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return "", "", nil, nil, fmt.Errorf("marshal ai request: %w", err)
|
|
}
|
|
|
|
endpoint := strings.TrimRight(s.cfg.BaseURL, "/") + "/chat/completions"
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
|
|
if err != nil {
|
|
return "", "", nil, nil, fmt.Errorf("build ai request: %w", err)
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
httpReq.Header.Set("Authorization", "Bearer "+s.cfg.APIKey)
|
|
|
|
resp, err := s.client.Do(httpReq)
|
|
if err != nil {
|
|
return "", "", nil, nil, fmt.Errorf("call ai: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return "", "", nil, nil, fmt.Errorf("read ai response: %w", err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
return "", "", nil, nil, fmt.Errorf("ai http %d: %s", resp.StatusCode, truncateString(string(body), 512))
|
|
}
|
|
|
|
var parsed chatCompletionResponse
|
|
if err := json.Unmarshal(body, &parsed); err != nil {
|
|
return "", "", nil, nil, fmt.Errorf("parse ai response: %w", err)
|
|
}
|
|
if len(parsed.Choices) == 0 {
|
|
return "", "", nil, nil, errors.New("ai response has no choices")
|
|
}
|
|
|
|
content := strings.TrimSpace(parsed.Choices[0].Message.Content)
|
|
if content == "" {
|
|
return "", "", nil, nil, errors.New("ai response content is empty")
|
|
}
|
|
|
|
modelName := parsed.Model
|
|
if modelName == "" {
|
|
modelName = s.cfg.Model
|
|
}
|
|
|
|
var tokensIn, tokensOut *int
|
|
if parsed.Usage != nil {
|
|
tokensIn = &parsed.Usage.PromptTokens
|
|
tokensOut = &parsed.Usage.CompletionTokens
|
|
}
|
|
|
|
return content, modelName, tokensIn, tokensOut, nil
|
|
}
|
|
|
|
// Minimal wire types for an OpenAI-compatible /chat/completions call; only the
// fields this service sends or reads are modeled.
type (
	// chatMessage is one entry of the request's messages array, and also the
	// shape of choices[].message in the response.
	chatMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// chatCompletionRequest is the request body. Temperature is left out of the
	// JSON when it is zero.
	chatCompletionRequest struct {
		Model       string        `json:"model"`
		Messages    []chatMessage `json:"messages"`
		Temperature float64       `json:"temperature,omitempty"`
	}

	// chatCompletionResponse is the decoded subset of the response. Usage is a
	// pointer so a missing usage block stays nil rather than reading as zeros.
	chatCompletionResponse struct {
		Model   string `json:"model"`
		Choices []struct {
			Message chatMessage `json:"message"`
		} `json:"choices"`
		Usage *struct {
			PromptTokens     int `json:"prompt_tokens"`
			CompletionTokens int `json:"completion_tokens"`
			TotalTokens      int `json:"total_tokens"`
		} `json:"usage"`
	}
)
|
|
|
|
// dateOnly returns midnight, in time.Local, of the calendar day that t falls on
// when viewed in time.Local.
func dateOnly(t time.Time) time.Time {
	year, month, day := t.In(time.Local).Date()
	return time.Date(year, month, day, 0, 0, 0, 0, time.Local)
}
|
|
|
|
// mustJSON returns the JSON encoding of v as a string. It is best-effort and
// never panics despite its name: if v cannot be marshaled, it returns "{}".
func mustJSON(v any) string {
	if encoded, err := json.Marshal(v); err == nil {
		return string(encoded)
	}
	return "{}"
}
|
|
|
|
// truncateString returns s cut to at most limit bytes without splitting a UTF-8
// sequence. When the cut would land inside a multi-byte character, that whole
// character is dropped. A non-positive limit disables truncation. It is used to
// bound error messages built from response bodies, which may contain non-ASCII
// text.
func truncateString(s string, limit int) string {
	if limit <= 0 || len(s) <= limit {
		return s
	}
	// Ranging over a string yields the byte offset of each rune start; keep the
	// last offset <= limit so s[:cut] ends on a character boundary.
	cut := 0
	for i := range s {
		if i > limit {
			break
		}
		cut = i
	}
	return s[:cut]
}
|