package service import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "net/http" "sort" "strings" "time" "gorm.io/gorm" "wx_service/config" usermodel "wx_service/internal/model" smokemodel "wx_service/internal/smoke/model" ) var ( ErrAIAdviceLocked = errors.New("ai advice is locked, vip or ad unlock required") ErrAIServiceDisabled = errors.New("ai service is not configured") ErrNoSmokeLogs = errors.New("no smoke logs for date") ) const ( DefaultAdvicePromptVersion = "v1" defaultTemperature = 0.7 ) type SmokeAIAdviceService struct { db *gorm.DB cfg config.AIConfig client *http.Client } func NewSmokeAIAdviceService(db *gorm.DB, cfg config.AIConfig) *SmokeAIAdviceService { timeout := cfg.RequestTimeout if timeout <= 0 { timeout = 15 * time.Second } return &SmokeAIAdviceService{ db: db, cfg: cfg, client: &http.Client{ Timeout: timeout, }, } } type adviceSnapshotNode struct { Time string `json:"time"` Num int `json:"num"` Level int64 `json:"level"` Remark string `json:"remark,omitempty"` } type adviceSnapshot struct { Date string `json:"date"` TotalNum int `json:"total_num"` Nodes []adviceSnapshotNode `json:"nodes"` } func (s *SmokeAIAdviceService) GetOrGenerate(ctx context.Context, user *usermodel.User, adviceDate time.Time, promptVersion string) (*smokemodel.SmokeAIAdvice, error) { if promptVersion == "" { promptVersion = DefaultAdvicePromptVersion } cached, err := s.getCached(ctx, int(user.ID), adviceDate, promptVersion) if err != nil { return nil, err } if cached != nil { return cached, nil } allowed, err := s.isAllowed(ctx, user, adviceDate) if err != nil { return nil, err } if !allowed { return nil, ErrAIAdviceLocked } snapshot, snapshotJSON, err := s.buildSnapshot(ctx, int(user.ID), adviceDate) if err != nil { return nil, err } adviceText, modelName, tokensIn, tokensOut, err := s.callAI(ctx, snapshot) if err != nil { return nil, err } now := time.Now().Unix() createTime := now updateTime := now record := smokemodel.SmokeAIAdvice{ UID: int(user.ID), AdviceDate: dateOnly(adviceDate), PromptVersion: promptVersion, Provider: "openai-compatible", Model: modelName, InputSnapshot: snapshotJSON, Advice: adviceText, TokensIn: tokensIn, TokensOut: tokensOut, CreateTime: &createTime, UpdateTime: &updateTime, } if err := s.db.WithContext(ctx).Create(&record).Error; err != nil { return nil, fmt.Errorf("save ai advice: %w", err) } return &record, nil } func (s *SmokeAIAdviceService) Unlock(ctx context.Context, user *usermodel.User, unlockDate time.Time) error { now := time.Now() nowUnix := now.Unix() startOfDay := dateOnly(unlockDate) var existing smokemodel.SmokeAIAdviceUnlock tx := s.db.WithContext(ctx) err := tx.Where("uid = ? AND unlock_date = ? AND (deletetime IS NULL OR deletetime = 0)", user.ID, startOfDay.Format("2006-01-02")). First(&existing).Error if err == nil { return tx.Model(&existing).Updates(map[string]interface{}{ "ad_watched_at": now, "updatetime": nowUnix, }).Error } if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return fmt.Errorf("load unlock record: %w", err) } createTime := nowUnix updateTime := nowUnix record := smokemodel.SmokeAIAdviceUnlock{ UID: int(user.ID), UnlockDate: startOfDay, AdWatchedAt: now, CreateTime: &createTime, UpdateTime: &updateTime, } if err := tx.Create(&record).Error; err != nil { return fmt.Errorf("create unlock record: %w", err) } return nil } func (s *SmokeAIAdviceService) getCached(ctx context.Context, uid int, adviceDate time.Time, promptVersion string) (*smokemodel.SmokeAIAdvice, error) { var record smokemodel.SmokeAIAdvice err := s.db.WithContext(ctx). Where("uid = ? AND advice_date = ? AND prompt_version = ? AND (deletetime IS NULL OR deletetime = 0)", uid, dateOnly(adviceDate).Format("2006-01-02"), promptVersion). First(&record).Error if err == nil { return &record, nil } if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("load cached advice: %w", err) } func (s *SmokeAIAdviceService) isAllowed(ctx context.Context, user *usermodel.User, adviceDate time.Time) (bool, error) { isVIP, err := s.isVIP(ctx, user) if err != nil { return false, err } if isVIP { return true, nil } return s.isUnlocked(ctx, int(user.ID), adviceDate) } func (s *SmokeAIAdviceService) isVIP(ctx context.Context, user *usermodel.User) (bool, error) { now := time.Now() var count int64 if err := s.db.WithContext(ctx). Model(&usermodel.UserMembership{}). Where("mini_program_id = ? AND user_id = ? AND status = ? AND ends_at > ?", user.MiniProgramID, user.ID, "active", now). Count(&count).Error; err != nil { return false, fmt.Errorf("check vip: %w", err) } return count > 0, nil } func (s *SmokeAIAdviceService) isUnlocked(ctx context.Context, uid int, adviceDate time.Time) (bool, error) { startOfDay := dateOnly(adviceDate) var unlock smokemodel.SmokeAIAdviceUnlock err := s.db.WithContext(ctx). Where("uid = ? AND unlock_date = ? AND (deletetime IS NULL OR deletetime = 0)", uid, startOfDay.Format("2006-01-02")). First(&unlock).Error if err == nil { return true, nil } if errors.Is(err, gorm.ErrRecordNotFound) { return false, nil } return false, fmt.Errorf("check unlock: %w", err) } func (s *SmokeAIAdviceService) buildSnapshot(ctx context.Context, uid int, adviceDate time.Time) (adviceSnapshot, []byte, error) { var logs []smokemodel.SmokeLog if err := s.db.WithContext(ctx). Where("uid = ? AND smoke_time = ? AND (deletetime IS NULL OR deletetime = 0)", uid, dateOnly(adviceDate).Format("2006-01-02")). Find(&logs).Error; err != nil { return adviceSnapshot{}, nil, fmt.Errorf("load smoke logs: %w", err) } if len(logs) == 0 { return adviceSnapshot{}, nil, ErrNoSmokeLogs } type timedLog struct { log smokemodel.SmokeLog eventAt time.Time hasEvent bool } timed := make([]timedLog, 0, len(logs)) total := 0 for _, l := range logs { total += l.Num var eventAt time.Time has := false if l.SmokeAt != nil { eventAt = *l.SmokeAt has = true } else if l.CreateTime != nil && *l.CreateTime > 0 { eventAt = time.Unix(*l.CreateTime, 0).In(time.Local) has = true } timed = append(timed, timedLog{log: l, eventAt: eventAt, hasEvent: has}) } // 按时间节点排序(若没有 event time,则排在最后,仍保持稳定) sort.SliceStable(timed, func(i, j int) bool { a := timed[i] b := timed[j] if a.hasEvent && b.hasEvent { if !a.eventAt.Equal(b.eventAt) { return a.eventAt.Before(b.eventAt) } return a.log.ID < b.log.ID } if a.hasEvent != b.hasEvent { return a.hasEvent } return a.log.ID < b.log.ID }) nodes := make([]adviceSnapshotNode, 0, len(timed)) for _, t := range timed { timeLabel := "" if t.hasEvent { timeLabel = t.eventAt.Format("15:04") } nodes = append(nodes, adviceSnapshotNode{ Time: timeLabel, Num: t.log.Num, Level: t.log.Level, Remark: t.log.Remark, }) } snap := adviceSnapshot{ Date: dateOnly(adviceDate).Format("2006-01-02"), TotalNum: total, Nodes: nodes, } b, err := json.Marshal(snap) if err != nil { return adviceSnapshot{}, nil, fmt.Errorf("marshal snapshot: %w", err) } return snap, b, nil } func (s *SmokeAIAdviceService) callAI(ctx context.Context, snap adviceSnapshot) (string, string, *int, *int, error) { if s.cfg.APIKey == "" || s.cfg.Model == "" || s.cfg.BaseURL == "" { return "", "", nil, nil, ErrAIServiceDisabled } systemPrompt := strings.TrimSpace(` 你是一名专业的戒烟教练与行为改变顾问。你需要基于用户昨天的抽烟总量与时间节点,给出可执行、可量化的戒烟/控烟建议。 要求: 1) 用中文输出; 2) 先给出对昨天模式的简短分析(1-3条); 3) 给出今天的具体行动方案(至少5条,包含替代行为、触发场景应对、时间节点策略); 4) 给出一个“如果忍不住想抽”的 60 秒应对流程; 5) 语气友好、不指责;不提供医疗诊断。 `) userPrompt := fmt.Sprintf("用户昨日数据(JSON):\n%s", mustJSON(snap)) reqBody := chatCompletionRequest{ Model: s.cfg.Model, Messages: []chatMessage{ {Role: "system", Content: systemPrompt}, {Role: "user", Content: userPrompt}, }, Temperature: defaultTemperature, } payload, err := json.Marshal(reqBody) if err != nil { return "", "", nil, nil, fmt.Errorf("marshal ai request: %w", err) } endpoint := strings.TrimRight(s.cfg.BaseURL, "/") + "/chat/completions" httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload)) if err != nil { return "", "", nil, nil, fmt.Errorf("build ai request: %w", err) } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+s.cfg.APIKey) resp, err := s.client.Do(httpReq) if err != nil { return "", "", nil, nil, fmt.Errorf("call ai: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return "", "", nil, nil, fmt.Errorf("read ai response: %w", err) } if resp.StatusCode != http.StatusOK { return "", "", nil, nil, fmt.Errorf("ai http %d: %s", resp.StatusCode, truncateString(string(body), 512)) } var parsed chatCompletionResponse if err := json.Unmarshal(body, &parsed); err != nil { return "", "", nil, nil, fmt.Errorf("parse ai response: %w", err) } if len(parsed.Choices) == 0 { return "", "", nil, nil, errors.New("ai response has no choices") } content := strings.TrimSpace(parsed.Choices[0].Message.Content) if content == "" { return "", "", nil, nil, errors.New("ai response content is empty") } modelName := parsed.Model if modelName == "" { modelName = s.cfg.Model } var tokensIn, tokensOut *int if parsed.Usage != nil { tokensIn = &parsed.Usage.PromptTokens tokensOut = &parsed.Usage.CompletionTokens } return content, modelName, tokensIn, tokensOut, nil } type chatMessage struct { Role string `json:"role"` Content string `json:"content"` } type chatCompletionRequest struct { Model string `json:"model"` Messages []chatMessage `json:"messages"` Temperature float64 `json:"temperature,omitempty"` } type chatCompletionResponse struct { Model string `json:"model"` Choices []struct { Message chatMessage `json:"message"` } `json:"choices"` Usage *struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` } `json:"usage"` } func dateOnly(t time.Time) time.Time { local := t.In(time.Local) return time.Date(local.Year(), local.Month(), local.Day(), 0, 0, 0, 0, time.Local) } func mustJSON(v any) string { b, err := json.Marshal(v) if err != nil { return "{}" } return string(b) } func truncateString(s string, max int) string { if max <= 0 || len(s) <= max { return s } return s[:max] }