diff --git a/core/loader.go b/core/loader.go index 9c16c29..750011f 100644 --- a/core/loader.go +++ b/core/loader.go @@ -2,7 +2,6 @@ package core import ( "io" - "strings" "mongo.games.com/goserver/core/logger" "mongo.games.com/goserver/core/viperx" @@ -47,12 +46,7 @@ func RegisterConfigEncryptor(h viperx.ConfigFileEncryptorHook) { // LoadPackages 加载功能包 func LoadPackages(configFile string) { - val := strings.Split(configFile, ".") - if len(val) != 2 { - panic("config file name error") - } - - vp := viperx.GetViper(val[0], val[1]) + vp := viperx.GetViper(configFile) var err error var notFoundConfig []string diff --git a/core/mongox/config.go b/core/mongox/config.go new file mode 100644 index 0000000..4845622 --- /dev/null +++ b/core/mongox/config.go @@ -0,0 +1,44 @@ +package mongox + +import ( + "fmt" + + "mongo.games.com/goserver/core" + "mongo.games.com/goserver/core/viperx" +) + +var config = Configuration{} + +type Configuration struct { + Path string +} + +func (c *Configuration) Name() string { + return "mongox" +} + +func (c *Configuration) Init() error { + if c.Path == "" { + c.Path = "mongo.yaml" + } + + vp := viperx.GetViper(c.Path) + + cfg := &Config{} + if err := vp.Unmarshal(cfg); err != nil { + panic(fmt.Sprintf("mongox init error: %v", err)) + } + + Init(cfg) + + return nil +} + +func (c *Configuration) Close() error { + Close() + return nil +} + +func init() { + core.RegistePackage(&config) +} diff --git a/core/mongox/function.go b/core/mongox/function.go index fe2b76a..e92387b 100644 --- a/core/mongox/function.go +++ b/core/mongox/function.go @@ -1,6 +1,7 @@ package mongox import ( + "errors" "reflect" "strings" @@ -12,7 +13,7 @@ import ( type DatabaseType string const ( - KeyGlobal = "global" + Global = "global" DatabaseUser DatabaseType = "user" DatabaseLog DatabaseType = "log" @@ -26,7 +27,7 @@ func GetClient() (*mongo.Client, error) { return nil, NotInitError } - c, err := _manager.GetCollection(KeyGlobal, string(DatabaseLog), "empty") + c, err := _manager.GetCollection(Global, string(DatabaseLog), "empty") if err != nil { return nil, err } @@ -60,7 +61,7 @@ func GetGlobalDatabase(database DatabaseType) (*Database, error) { return nil, NotInitError } - return _manager.GetDatabase(KeyGlobal, string(database)) + return _manager.GetDatabase(Global, string(database)) } func GetGlobalUserDatabase() (*Database, error) { @@ -83,7 +84,7 @@ func GetGlobalCollection(database DatabaseType, collection string) (*Collection, return nil, NotInitError } - return _manager.GetCollection(KeyGlobal, string(database), collection) + return _manager.GetCollection(Global, string(database), collection) } func GetGlobalUserCollection(collection string) (*Collection, error) { @@ -123,8 +124,8 @@ type ICollectionName interface { CollectionName() string } -// GetTableName 获取文档名 -func GetTableName(model any) string { +// GetCollectionName 获取文档名 +func GetCollectionName(model any) string { if m, ok := model.(ICollectionName); ok { return m.CollectionName() } @@ -140,17 +141,37 @@ func GetTableName(model any) string { return strings.ToLower(t.Name()) } -// GetCollectionDao 获取文档操作接口 +// IDatabaseName 数据库名称接口 +type IDatabaseName interface { + DatabaseName() string +} + +func GetDatabaseName(model any) (string, error) { + if m, ok := model.(IDatabaseName); ok { + return m.DatabaseName(), nil + } + + return "", errors.New("not set database name") +} + +// GetDao 获取文档操作接口 // key: 平台id 或 KeyGlobal // database: 数据库类型 DatabaseType -// f: 文档接口创建函数; 结合 tools/mongoctl 生成 -func GetCollectionDao[T any](key string, database DatabaseType, model any, f func(database *mongo.Database, c *mongo.Collection) T) (T, error) { - collectionName := GetTableName(model) - c, err := GetCollection(key, database, collectionName) +// f: 文档接口创建函数; 结合 tools/mongoctl 使用 +func GetDao[T, M any](key string, f func(database *mongo.Database, c *mongo.Collection) (T, M)) (T, error) { + var z T + t, m := f(nil, nil) + databaseName, err := GetDatabaseName(m) if err != nil { - var z T - logger.Logger.Errorf("GetCollectionDao key:%v database:%v model:%v error: %v", key, database, collectionName, err) + logger.Logger.Errorf("GetDao error: %v", err) return z, err } - return f(c.Database.Database, c.Collection), nil + collectionName := GetCollectionName(m) + c, err := GetCollection(key, DatabaseType(databaseName), collectionName) + if err != nil { + logger.Logger.Errorf("GetDao key:%v database:%v model:%v error: %v", key, databaseName, collectionName, err) + return z, err + } + t, _ = f(c.Database.Database, c.Collection) + return t, nil } diff --git a/core/tools/mongoctl/generator.go b/core/tools/mongoctl/generator.go index 25af672..da3b368 100644 --- a/core/tools/mongoctl/generator.go +++ b/core/tools/mongoctl/generator.go @@ -151,6 +151,9 @@ func (g *generator) makeModelExternalDao(m *model) { replaces[varDaoPrefixNameKey] = m.daoPrefixName replaces[varDaoPackageNameKey] = m.daoPkgName replaces[varDaoPackagePathKey] = m.daoPkgPath + replaces[varModelPackagePathKey] = m.modelPkgPath + replaces[varModelPackageNameKey] = m.modelPkgName + replaces[varModelClassNameKey] = m.modelClassName err = doWrite(file, template.ExternalTemplate, replaces) if err != nil { diff --git a/core/tools/mongoctl/template/template.go b/core/tools/mongoctl/template/template.go index 358aeb8..a444cee 100644 --- a/core/tools/mongoctl/template/template.go +++ b/core/tools/mongoctl/template/template.go @@ -4,8 +4,22 @@ const ExternalTemplate = ` package ${VarDaoPackageName} import ( + "context" + + "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" + "mongo.games.com/goserver/core/logger" + "mongo.games.com/goserver/core/mongox" + "${VarDaoPackagePath}/internal" + ${VarModelPackageName} "${VarModelPackagePath}" +) + +var ( + _ = context.Background() + _ = logger.Logger + _ = bson.M{} + _ = mongo.Database{} ) type ${VarDaoPrefixName}Columns = internal.${VarDaoPrefixName}Columns @@ -14,14 +28,28 @@ type ${VarDaoClassName} struct { *internal.${VarDaoClassName} } -func New${VarDaoClassName}(db *mongo.Database, c *mongo.Collection) *${VarDaoClassName} { +func Get${VarDaoClassName}(key string) (*${VarDaoClassName}, error) { + return mongox.GetDao(key, New${VarDaoClassName}) +} + +func New${VarDaoClassName}(db *mongo.Database, c *mongo.Collection) (*${VarDaoClassName}, any) { + if db == nil || c == nil { + return &${VarDaoClassName}{}, &${VarModelPackageName}.${VarModelClassName}{} + } + v := internal.New${VarDaoClassName}(nil) v.Database = db v.Collection = c - panic("创建索引") + + //todo: 创建索引,删除代码 + m :=&${VarModelPackageName}.${VarModelClassName}{} + name := mongox.GetCollectionName(m) + panic(fmt.Sprintf("创建索引 %s", name)) //c.Indexes().CreateOne() //c.Indexes().CreateMany() - return &${VarDaoClassName}{${VarDaoClassName}: v} + //todo: 创建索引,删除代码 + + return &${VarDaoClassName}{${VarDaoClassName}: v}, &${VarModelPackageName}.${VarModelClassName}{} } ` diff --git a/core/viperx/viper.go b/core/viperx/viper.go index 59ff759..6e7de4d 100644 --- a/core/viperx/viper.go +++ b/core/viperx/viper.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "os" + "path/filepath" "github.com/spf13/viper" ) @@ -29,12 +30,11 @@ func RegisterConfigEncryptor(h ConfigFileEncryptorHook) { } // GetViper 获取viper配置 -// name: 配置文件名,不带后缀 -// filetype: 配置文件类型,如json、yaml、ini等 -func GetViper(name, filetype string) *viper.Viper { - buf, err := ReadFile(name, filetype) +// name: 配置文件路径和名称,json、yaml、ini等 +func GetViper(name string) *viper.Viper { + buf, err := ReadFile(name) if err != nil { - panic(fmt.Sprintf("Error while reading config file %s: %v", name+filetype, err)) + panic(fmt.Sprintf("Error while reading config file %s: %v", name, err)) } if configFileEH != nil { @@ -44,17 +44,20 @@ func GetViper(name, filetype string) *viper.Viper { } vp := viper.New() - vp.SetConfigName(name) - vp.SetConfigType(filetype) + ext := filepath.Ext(name) + if len(ext) > 0 { + ext = ext[1:] // 去掉点 + } + vp.SetConfigType(ext) if err = vp.ReadConfig(bytes.NewReader(buf)); err != nil { - panic(fmt.Sprintf("Error while reading config file %s: %v", name+filetype, err)) + panic(fmt.Sprintf("Error while reading config file %s: %v", name, err)) } return vp } -func ReadFile(name, filetype string) ([]byte, error) { +func ReadFile(name string) ([]byte, error) { for _, v := range paths { - file := fmt.Sprintf("%s/%s.%s", v, name, filetype) + file := filepath.Join(v, name) if _, err := os.Stat(file); err == nil { return os.ReadFile(file) }