diff --git a/mongo/export.go b/mongo/export.go index 312e905..b464133 100644 --- a/mongo/export.go +++ b/mongo/export.go @@ -42,6 +42,27 @@ func Close() { internal.Close(_manager) } +// GetDatabase 获取数据库 +// platform: 平台id +// database: 数据库名称 +func GetDatabase(platform, database string) (*Database, error) { + if _manager == nil { + return nil, errors.New("mongo manager is nil, please call Init() first") + } + + return _manager.GetDatabase(platform, database) +} + +// GetGlobalDatabase 获取全局库 +// database: 数据库名称 +func GetGlobalDatabase(database string) (*Database, error) { + if _manager == nil { + return nil, errors.New("mongo manager is nil, please call Init() first") + } + + return _manager.GetDatabase("global", database) +} + // GetGlobalCollection 获取全局库 // database: 数据库名称 // collection: 集合名称 diff --git a/mongo/internal/mongo.go b/mongo/internal/mongo.go index 1c2282f..9d54b9c 100644 --- a/mongo/internal/mongo.go +++ b/mongo/internal/mongo.go @@ -98,6 +98,14 @@ type Manager struct { } func (m *Manager) GetCollection(key, database, collection string) (*Collection, error) { + d, err := m.GetDatabase(key, database) + if err != nil { + return nil, err + } + return d.GetCollection(collection) +} + +func (m *Manager) GetDatabase(key, database string) (*Database, error) { switch key { case "global": v, ok := m.global.Load(database) @@ -113,7 +121,7 @@ func (m *Manager) GetCollection(key, database, collection string) (*Collection, m.global.Store(database, v) } d, _ := v.(*Database) - return d.GetCollection(collection) + return d, nil default: var mp *sync.Map @@ -137,7 +145,7 @@ func (m *Manager) GetCollection(key, database, collection string) (*Collection, mp.Store(database, v) } d, _ := v.(*Database) - return d.GetCollection(collection) + return d, nil } }