diff --git a/.env.example b/.env.example index a400514..f115268 100755 --- a/.env.example +++ b/.env.example @@ -8,6 +8,13 @@ DB_PORT=3306 DB_USER=root DB_PASSWORD=your_password DB_NAME=wx_service +# 多数据源(可选) +# DB_INSTANCES=lawyer,reporting +# DB_LAWYER_HOST=127.0.0.1 +# DB_LAWYER_PORT=3306 +# DB_LAWYER_USER=another_user +# DB_LAWYER_PASSWORD=another_password +# DB_LAWYER_NAME=lawyer # JWT配置 JWT_SECRET=your-secret-key-change-in-production diff --git a/config/config.go b/config/config.go index 5504fc4..525a1a9 100755 --- a/config/config.go +++ b/config/config.go @@ -4,6 +4,7 @@ import ( "log" "os" "strconv" + "strings" "time" "github.com/joho/godotenv" @@ -27,6 +28,11 @@ type ServerConfig struct { } type DatabaseConfig struct { + Default DatabaseInstanceConfig + Additional map[string]DatabaseInstanceConfig +} + +type DatabaseInstanceConfig struct { Host string Port string User string @@ -94,17 +100,22 @@ func LoadConfig() { log.Println("未找到 .env 文件,使用环境变量") } + defaultDB := DatabaseInstanceConfig{ + Host: getEnv("DB_HOST", "localhost"), + Port: getEnv("DB_PORT", "3306"), + User: getEnv("DB_USER", "root"), + Password: getEnv("DB_PASSWORD", ""), + DBName: getEnv("DB_NAME", "wx_service"), + } + AppConfig = &Config{ Server: ServerConfig{ Port: getEnv("SERVER_PORT", "8080"), Mode: getEnv("GIN_MODE", "debug"), }, Database: DatabaseConfig{ - Host: getEnv("DB_HOST", "localhost"), - Port: getEnv("DB_PORT", "3306"), - User: getEnv("DB_USER", "root"), - Password: getEnv("DB_PASSWORD", ""), - DBName: getEnv("DB_NAME", "wx_service"), + Default: defaultDB, + Additional: loadAdditionalDBConfigs(defaultDB), }, JWT: JWTConfig{ Secret: getEnv("JWT_SECRET", "your-secret-key"), @@ -148,6 +159,30 @@ func LoadConfig() { } } +func loadAdditionalDBConfigs(defaultCfg DatabaseInstanceConfig) map[string]DatabaseInstanceConfig { + instances := strings.Split(getEnv("DB_INSTANCES", ""), ",") + result := make(map[string]DatabaseInstanceConfig) + for _, rawName := range instances { + name := strings.TrimSpace(rawName) + if name == "" { + continue + } + upperName := strings.ToUpper(name) + prefix := "DB_" + upperName + "_" + result[strings.ToLower(name)] = DatabaseInstanceConfig{ + Host: getEnv(prefix+"HOST", defaultCfg.Host), + Port: getEnv(prefix+"PORT", defaultCfg.Port), + User: getEnv(prefix+"USER", defaultCfg.User), + Password: getEnv(prefix+"PASSWORD", defaultCfg.Password), + DBName: getEnv(prefix+"NAME", defaultCfg.DBName), + } + } + if len(result) == 0 { + return nil + } + return result +} + func getEnv(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value diff --git a/docs/README.md b/docs/README.md index dd12c83..acea473 100644 --- a/docs/README.md +++ b/docs/README.md @@ -27,6 +27,7 @@ 2. 按实际环境填写以下变量: - `SERVER_PORT`:HTTP 服务端口,例如 `8080`。 - `DB_HOST/DB_PORT/DB_USER/DB_PASSWORD/DB_NAME`:MySQL 连接信息。 + - 若需连接额外数据库,可设置 `DB_INSTANCES=lawyer,reporting` 并依次提供 `DB__HOST/PORT/USER/PASSWORD/NAME`;未指定的字段会回落到默认数据库的同名配置。 3. 如果需要,替换 `GIN_MODE`、`JWT_SECRET` 等其他变量。 4. 通过 `docs/sql/users.sql` 初始化 `mini_programs` 与 `users` 表,并插入每个小程序的 `name/app_id/app_secret`。 diff --git a/internal/database/database.go b/internal/database/database.go index 36097dc..52b6624 100755 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -11,11 +11,38 @@ import ( "gorm.io/gorm/logger" ) -var DB *gorm.DB +var ( + DB *gorm.DB + additionalDBs map[string]*gorm.DB +) func InitDB() error { cfg := config.AppConfig.Database + defaultDB, err := openConnection(cfg.Default) + if err != nil { + return fmt.Errorf("连接数据库失败: %v", err) + } + DB = defaultDB + log.Println("默认数据库连接成功") + + if len(cfg.Additional) > 0 { + additionalDBs = make(map[string]*gorm.DB) + for name, instanceCfg := range cfg.Additional { + conn, err := openConnection(instanceCfg) + if err != nil { + return fmt.Errorf("连接数据库[%s]失败: %v", name, err) + } + additionalDBs[strings.ToLower(name)] = conn + log.Printf("数据库连接成功: %s\n", name) + } + } else { + additionalDBs = nil + } + return nil +} + +func openConnection(cfg config.DatabaseInstanceConfig) (*gorm.DB, error) { dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", cfg.User, cfg.Password, @@ -24,17 +51,24 @@ func InitDB() error { cfg.DBName, ) - var err error - DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ + db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Info), }) - if err != nil { - return fmt.Errorf("连接数据库失败: %v", err) + return nil, err } + return db, nil +} - log.Println("数据库连接成功") - return nil +func GetAdditionalDB(name string) (*gorm.DB, bool) { + if name == "" || strings.EqualFold(name, "default") { + return DB, DB != nil + } + if additionalDBs == nil { + return nil, false + } + db, ok := additionalDBs[strings.ToLower(name)] + return db, ok } func AutoMigrate(models ...interface{}) error {