diff --git a/srvlib/serversessionmgr.go b/srvlib/serversessionmgr.go index baad454..b58883d 100644 --- a/srvlib/serversessionmgr.go +++ b/srvlib/serversessionmgr.go @@ -9,8 +9,7 @@ import ( ) /* - 服务器信息注册,单个服务器可能包含多个服务端口 - 服务器信息->session + 服务器信息注册,单个服务器可能包含多个子服务端口 */ var ( @@ -35,64 +34,79 @@ func (ssm *ServerSessionMgr) AddListener(l ServerSessionRegisteListener) ServerS func (ssm *ServerSessionMgr) RegisteSession(s *netlib.Session) bool { attr := s.GetAttribute(SessionAttributeServerInfo) - if attr != nil { - if srvInfo, ok := attr.(*protocol.SSSrvRegiste); ok && srvInfo != nil { - areaId := int(srvInfo.GetAreaId()) - srvType := int(srvInfo.GetType()) - srvId := int(srvInfo.GetId()) - if a, exist := ssm.sessions[areaId]; !exist { - ssm.sessions[areaId] = make(map[int]map[int]*netlib.Session) - a = ssm.sessions[areaId] - a[srvType] = make(map[int]*netlib.Session) - } else { - if _, exist := a[srvType]; !exist { - a[srvType] = make(map[int]*netlib.Session) - } - } - - if _, exist := ssm.sessions[areaId][srvType][srvId]; !exist { - logger.Logger.Infof("(ssm *ServerSessionMgr) RegisteSession %v", srvInfo) - ssm.sessions[areaId][srvType][srvId] = s - if len(ssm.listeners) != 0 { - for _, l := range ssm.listeners { - l.OnRegiste(s) - } - } - } else { - logger.Logger.Warnf("###(ssm *ServerSessionMgr) RegisteSession repeated areaid:%v srvType:%v srvId:%v", areaId, srvType, srvId) - } - } - } else { - logger.Logger.Warnf("ServerSessionMgr.RegisteSession SessionAttributeServerInfo=nil") + if attr == nil { + logger.Logger.Warnf("服务器注册信息为空") + return false } + + srvInfo, ok := attr.(*protocol.SSSrvRegiste) + if !ok || srvInfo == nil { + logger.Logger.Warnf("服务器注册信息错误") + return false + } + + areaId := int(srvInfo.GetAreaId()) + srvType := int(srvInfo.GetType()) + srvId := int(srvInfo.GetId()) + + if _, ok := ssm.sessions[areaId]; !ok { + ssm.sessions[areaId] = make(map[int]map[int]*netlib.Session) + } + if _, ok := ssm.sessions[areaId][srvType]; !ok { + ssm.sessions[areaId][srvType] = make(map[int]*netlib.Session) + } + + session, has := ssm.sessions[areaId][srvType][srvId] + if has && session != nil && session != s { + logger.Logger.Warnf("删除旧服务器注册: %v", srvInfo) + ssm.UnregisteSession(session) + } + + logger.Logger.Infof("服务器注册成功:%v", srvInfo) + ssm.sessions[areaId][srvType][srvId] = s + if len(ssm.listeners) != 0 { + for _, l := range ssm.listeners { + l.OnRegiste(s) + } + } + return true } +// UnregisteSession 服务器关闭 +// 需要支持幂等性 func (ssm *ServerSessionMgr) UnregisteSession(s *netlib.Session) bool { attr := s.GetAttribute(SessionAttributeServerInfo) - if attr != nil { - if srvInfo, ok := attr.(*protocol.SSSrvRegiste); ok && srvInfo != nil { - logger.Logger.Infof("ServerSessionMgr.UnregisteSession try %v", srvInfo) - areaId := int(srvInfo.GetAreaId()) - srvType := int(srvInfo.GetType()) - srvId := int(srvInfo.GetId()) - if a, exist := ssm.sessions[areaId]; exist { - if b, exist := a[srvType]; exist { - if _, exist := b[srvId]; exist { - logger.Logger.Infof("ServerSessionMgr.UnregisteSession %v success", srvInfo) - delete(b, srvId) - if len(ssm.listeners) != 0 { - for _, l := range ssm.listeners { - l.OnUnregiste(s) - } - } - } else { - logger.Logger.Warnf("(ssm *ServerSessionMgr) UnregisteSession found not fit session, area:%v type:%v id:%v", areaId, srvType, srvId) + if attr == nil { + return false + } + + srvInfo, ok := attr.(*protocol.SSSrvRegiste) + if !ok || srvInfo == nil { + return false + } + + logger.Logger.Tracef("尝试删除服务器注册:%v", srvInfo) + areaId := int(srvInfo.GetAreaId()) + srvType := int(srvInfo.GetType()) + srvId := int(srvInfo.GetId()) + + if a, exist := ssm.sessions[areaId]; exist { + if b, exist := a[srvType]; exist { + if conn, exist := b[srvId]; exist && s == conn { + logger.Logger.Infof("删除服务器注册成功 %v", srvInfo) + delete(b, srvId) + if len(ssm.listeners) != 0 { + for _, l := range ssm.listeners { + l.OnUnregiste(s) } } + } else { + logger.Logger.Tracef("服务器注册信息已经删除") } } } + return true }