112 lines
2.5 KiB
Go
Executable File
112 lines
2.5 KiB
Go
Executable File
package database
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"strings"
|
|
"wx_service/config"
|
|
|
|
"gorm.io/driver/mysql"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
// Package-level connection handles. Both are assigned by InitDB.
var (
|
|
// DB is the default database connection; it is nil until InitDB succeeds
// in opening the default instance.
DB *gorm.DB
|
|
// additionalDBs holds the extra named connections, keyed by the
// lower-cased configured name. InitDB sets it to nil when no additional
// instances are configured.
additionalDBs map[string]*gorm.DB
|
|
)
|
|
|
|
func InitDB() error {
|
|
cfg := config.AppConfig.Database
|
|
|
|
defaultDB, err := openConnection(cfg.Default)
|
|
if err != nil {
|
|
return fmt.Errorf("连接数据库失败: %v", err)
|
|
}
|
|
DB = defaultDB
|
|
log.Println("默认数据库连接成功")
|
|
|
|
if len(cfg.Additional) > 0 {
|
|
additionalDBs = make(map[string]*gorm.DB)
|
|
for name, instanceCfg := range cfg.Additional {
|
|
conn, err := openConnection(instanceCfg)
|
|
if err != nil {
|
|
return fmt.Errorf("连接数据库[%s]失败: %v", name, err)
|
|
}
|
|
additionalDBs[strings.ToLower(name)] = conn
|
|
log.Printf("数据库连接成功: %s\n", name)
|
|
}
|
|
} else {
|
|
additionalDBs = nil
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func openConnection(cfg config.DatabaseInstanceConfig) (*gorm.DB, error) {
|
|
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
|
cfg.User,
|
|
cfg.Password,
|
|
cfg.Host,
|
|
cfg.Port,
|
|
cfg.DBName,
|
|
)
|
|
|
|
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Info),
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return db, nil
|
|
}
|
|
|
|
func GetAdditionalDB(name string) (*gorm.DB, bool) {
|
|
if name == "" || strings.EqualFold(name, "default") {
|
|
return DB, DB != nil
|
|
}
|
|
if additionalDBs == nil {
|
|
return nil, false
|
|
}
|
|
db, ok := additionalDBs[strings.ToLower(name)]
|
|
return db, ok
|
|
}
|
|
|
|
func AutoMigrate(models ...interface{}) error {
|
|
type tableCommenter interface {
|
|
TableComment() string
|
|
}
|
|
|
|
for _, m := range models {
|
|
tx := DB
|
|
comment := ""
|
|
if tc, ok := m.(tableCommenter); ok {
|
|
comment = strings.TrimSpace(tc.TableComment())
|
|
if comment != "" {
|
|
tx = tx.Set("gorm:table_options", fmt.Sprintf("COMMENT='%s'", escapeSQLComment(comment)))
|
|
}
|
|
}
|
|
|
|
if err := tx.AutoMigrate(m); err != nil {
|
|
return err
|
|
}
|
|
|
|
// 尝试为已存在的表补齐 table comment(即使表已创建,也能更新注释)。
|
|
if comment != "" {
|
|
stmt := &gorm.Statement{DB: DB}
|
|
if err := stmt.Parse(m); err == nil && stmt.Schema != nil && stmt.Schema.Table != "" {
|
|
_ = DB.Exec(fmt.Sprintf("ALTER TABLE `%s` COMMENT = '%s'", stmt.Schema.Table, escapeSQLComment(comment))).Error
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := repairSmokeAINextSmokeIndexes(DB); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// escapeSQLComment escapes s for embedding inside a single-quoted MySQL
// string literal.
//
// Single quotes are doubled and backslashes are doubled. Backslash must be
// escaped too because, under MySQL's default SQL mode, it is an escape
// character inside string literals: an unescaped trailing backslash would
// swallow the closing quote. Both replacements happen in a single pass, so
// characters produced by one rule are never re-escaped by the other.
func escapeSQLComment(s string) string {
	return strings.NewReplacer(
		`\`, `\\`,
		`'`, `''`,
	).Replace(s)
}
|