// goserver_sync/core/mongox/function.go (157 lines, 3.9 KiB, Go)

package mongox
import (
"reflect"
"strings"
"go.mongodb.org/mongo-driver/mongo"
"mongo.games.com/goserver/core/logger"
)
// DatabaseType is the name of a database as handed to the lookup helpers in
// this package; it is converted to a plain string before reaching the manager.
type DatabaseType string

// KeyGlobal is the key the GetGlobal* helpers use in place of a platform id.
// It is deliberately left as an untyped string constant.
const KeyGlobal = "global"

// Database names understood by the typed lookup helpers.
const (
	// DatabaseUser selects the "user" database.
	DatabaseUser DatabaseType = "user"
	// DatabaseLog selects the "log" database.
	DatabaseLog DatabaseType = "log"
	// DatabaseMonitor selects the "monitor" database.
	DatabaseMonitor DatabaseType = "monitor"
)
// GetClient returns the mongo client used by the global log database.
//
// The client is obtained by looking up the "empty" collection of the global
// log database and reading the client off that collection's database.
// NotInitError is returned when the package manager has not been set; any
// error from the collection lookup is returned unchanged.
func GetClient() (*mongo.Client, error) {
	if _manager == nil {
		return nil, NotInitError
	}

	// The collection is only a handle used to reach the client.
	const placeholder = "empty"
	logColl, lookupErr := _manager.GetCollection(KeyGlobal, string(DatabaseLog), placeholder)
	if lookupErr != nil {
		return nil, lookupErr
	}
	return logColl.Database.Client, nil
}
// GetDatabase looks up a database through the package manager.
//
// platform: platform id
// database: database name
//
// It returns NotInitError when the manager has not been set; otherwise the
// manager's result is returned as-is.
func GetDatabase(platform string, database DatabaseType) (*Database, error) {
	if _manager == nil {
		return nil, NotInitError
	}
	name := string(database)
	return _manager.GetDatabase(platform, name)
}
// GetUserDatabase returns the user database of the given platform.
// It forwards to GetDatabase with DatabaseUser.
func GetUserDatabase(platform string) (*Database, error) {
	const kind = DatabaseUser
	return GetDatabase(platform, kind)
}
// GetLogDatabase returns the log database of the given platform.
// It forwards to GetDatabase with DatabaseLog.
func GetLogDatabase(platform string) (*Database, error) {
	const kind = DatabaseLog
	return GetDatabase(platform, kind)
}
// GetGlobalDatabase looks up a database under the KeyGlobal key.
//
// database: database name
//
// It behaves like GetDatabase called with KeyGlobal as the platform:
// NotInitError when the manager has not been set, otherwise the manager's
// result.
func GetGlobalDatabase(database DatabaseType) (*Database, error) {
	return GetDatabase(KeyGlobal, database)
}
// GetGlobalUserDatabase returns the global user database.
// It forwards to GetGlobalDatabase with DatabaseUser.
func GetGlobalUserDatabase() (*Database, error) {
	const kind = DatabaseUser
	return GetGlobalDatabase(kind)
}
// GetGlobalLogDatabase returns the global log database.
// It forwards to GetGlobalDatabase with DatabaseLog.
func GetGlobalLogDatabase() (*Database, error) {
	const kind = DatabaseLog
	return GetGlobalDatabase(kind)
}
// GetGlobalMonitorDatabase returns the global monitor database.
// It forwards to GetGlobalDatabase with DatabaseMonitor.
func GetGlobalMonitorDatabase() (*Database, error) {
	const kind = DatabaseMonitor
	return GetGlobalDatabase(kind)
}
// GetGlobalCollection looks up a collection under the KeyGlobal key.
//
// database: database name
// collection: collection name
//
// It behaves like GetCollection called with KeyGlobal as the platform:
// NotInitError when the manager has not been set, otherwise the manager's
// result.
func GetGlobalCollection(database DatabaseType, collection string) (*Collection, error) {
	return GetCollection(KeyGlobal, database, collection)
}
// GetGlobalUserCollection returns the named collection of the global user
// database. It forwards to GetGlobalCollection with DatabaseUser.
func GetGlobalUserCollection(collection string) (*Collection, error) {
	const kind = DatabaseUser
	return GetGlobalCollection(kind, collection)
}
// GetGlobalLogCollection returns the named collection of the global log
// database. It forwards to GetGlobalCollection with DatabaseLog.
func GetGlobalLogCollection(collection string) (*Collection, error) {
	const kind = DatabaseLog
	return GetGlobalCollection(kind, collection)
}
// GetGlobalMonitorCollection returns the named collection of the global
// monitor database. It forwards to GetGlobalCollection with DatabaseMonitor.
func GetGlobalMonitorCollection(collection string) (*Collection, error) {
	const kind = DatabaseMonitor
	return GetGlobalCollection(kind, collection)
}
// GetCollection looks up a collection through the package manager.
//
// platform: platform id
// database: database name
// collection: collection name
//
// It returns NotInitError when the manager has not been set; otherwise the
// manager's result is returned as-is.
func GetCollection(platform string, database DatabaseType, collection string) (*Collection, error) {
	if _manager == nil {
		return nil, NotInitError
	}
	dbName := string(database)
	return _manager.GetCollection(platform, dbName, collection)
}
// GetUserCollection returns the named collection of a platform's user
// database. It forwards to GetCollection with DatabaseUser.
func GetUserCollection(platform string, collection string) (*Collection, error) {
	const kind = DatabaseUser
	return GetCollection(platform, kind, collection)
}
// GetLogCollection returns the named collection of a platform's log
// database. It forwards to GetCollection with DatabaseLog.
func GetLogCollection(platform string, collection string) (*Collection, error) {
	const kind = DatabaseLog
	return GetCollection(platform, kind, collection)
}
// ICollectionName is implemented by models that supply their own collection
// name instead of having it derived from the type name.
type ICollectionName interface {
	CollectionName() string
}

// GetTableName returns the collection name for model.
//
// If model implements ICollectionName, the value of CollectionName is
// returned unchanged. Otherwise model must be a struct or a pointer to a
// struct, and the result is the lower-cased name of the struct type (an
// unnamed struct type therefore yields an empty string).
//
// It panics with "model must be a struct or a pointer to a struct" for every
// other input, including a nil model.
func GetTableName(model any) string {
	if named, ok := model.(ICollectionName); ok {
		return named.CollectionName()
	}

	t := reflect.TypeOf(model)
	// reflect.TypeOf yields a nil Type for a nil interface value; guard it so
	// the caller gets the descriptive panic below rather than a nil-pointer
	// dereference inside Kind.
	if t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t == nil || t.Kind() != reflect.Struct {
		panic("model must be a struct or a pointer to a struct")
	}
	return strings.ToLower(t.Name())
}
// GetCollectionDao builds a data-access value of type T for the collection
// that corresponds to model.
//
// key: platform id, or KeyGlobal
// database: database type (DatabaseType)
// model: value whose collection name is resolved with GetTableName
// f: constructor that receives the *mongo.Database and *mongo.Collection taken
// from the looked-up collection and returns the T to hand back
//
// When the collection lookup fails, the error is logged and returned together
// with the zero value of T; f is not called in that case. Because the name is
// resolved with GetTableName, a model that is neither a struct, a pointer to a
// struct, nor an ICollectionName causes a panic.
func GetCollectionDao[T any](key string, database DatabaseType, model any, f func(database *mongo.Database, c *mongo.Collection) T) (T, error) {
	collectionName := GetTableName(model)
	c, err := GetCollection(key, database, collectionName)
	if err != nil {
		var zero T
		// The log prefix names this function so the entry can be traced back to it.
		logger.Logger.Errorf("GetCollectionDao key:%v database:%v model:%v error: %v", key, database, collectionName, err)
		return zero, err
	}
	return f(c.Database.Database, c.Collection), nil
}