enhance: Make proxy connection manager concurrency-safe (#31008)

See also #31007

This PR:
- Adds param items for connection manager behavior: check interval and client info TTL
- Changes the clientInfo map to a concurrent map (a standalone sketch of this pattern follows below)
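
For context, the concurrency pattern adopted here boils down to a lock-free map keyed by client identifier plus an atomic counter, replacing a mutex-guarded map. The sketch below is a minimal standalone version using only the Go standard library (sync.Map and sync/atomic); the actual change uses Milvus's typeutil.ConcurrentMap, and the registry/clientEntry names are illustrative, not part of the PR.

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
    "time"
)

// clientEntry is an illustrative stand-in for the proxy's clientInfo value.
type clientEntry struct {
    lastActiveTime time.Time
}

// registry mirrors the pattern: a concurrent map plus an atomic counter,
// so registration, keep-alive and expiry need no explicit RWMutex.
type registry struct {
    clients sync.Map     // map[int64]clientEntry
    count   atomic.Int64 // rough size hint, like connectionManager.count
}

func (r *registry) register(id int64) {
    r.clients.Store(id, clientEntry{lastActiveTime: time.Now()})
    r.count.Add(1)
}

func (r *registry) keepActive(id int64) {
    if v, ok := r.clients.Load(id); ok {
        e := v.(clientEntry)
        e.lastActiveTime = time.Now()
        r.clients.Store(id, e)
    }
}

// expire drops entries idle longer than ttl, like removeLongInactiveClients.
func (r *registry) expire(ttl time.Duration) {
    r.clients.Range(func(k, v any) bool {
        if time.Since(v.(clientEntry).lastActiveTime) > ttl {
            r.clients.Delete(k)
            r.count.Add(-1)
        }
        return true
    })
}

func main() {
    r := &registry{}
    r.register(1)
    r.keepActive(1)
    r.expire(24 * time.Hour)
    fmt.Println("tracked clients:", r.count.Load())
}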

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/31023/head
congqixia 2024-03-05 10:39:00 +08:00 committed by GitHub
parent 8f7019468f
commit 3b5ce73ded
7 changed files with 78 additions and 114 deletions


@@ -208,6 +208,8 @@ proxy:
   ginLogging: true
   ginLogSkipPaths: "/" # skipped url path for gin log split by comma
   maxTaskNum: 1024 # max task number of proxy task queue
+  connectionMgrCheckInterval: 120 # the interval time(in seconds) for connection manager to scan inactive client info
+  connectionClientInfoTTL: 86400 # inactive client info TTL duration, in seconds
   accessLog:
     enable: false
     # Log filename, set as "" to use stdout.


@@ -33,6 +33,11 @@ import (
     "github.com/milvus-io/milvus/pkg/util/paramtable"
 )
 
+func TestMain(m *testing.M) {
+    paramtable.Init()
+    os.Exit(m.Run())
+}
+
 func TestAccessLogger_NotEnable(t *testing.T) {
     var Params paramtable.ComponentParam


@@ -46,6 +46,10 @@ type GrpcAccessInfoSuite struct {
     info *GrpcAccessInfo
 }
 
+func (s *GrpcAccessInfoSuite) SetupSuite() {
+    paramtable.Init()
+}
+
 func (s *GrpcAccessInfoSuite) SetupTest() {
     s.username = "test-user"
     s.traceID = "test-trace"


@@ -8,9 +8,7 @@ var getConnectionManagerInstanceOnce sync.Once
 
 func GetManager() *connectionManager {
     getConnectionManagerInstanceOnce.Do(func() {
-        connectionManagerInstance = newConnectionManager(
-            withDuration(defaultConnCheckDuration),
-            withTTL(defaultTTLForInactiveConn))
+        connectionManagerInstance = newConnectionManager()
     })
     return connectionManagerInstance
 }
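
With the functional options gone, GetManager simply builds the manager, and the check interval and TTL now come from paramtable at runtime instead of being wired in at construction. A hedged sketch of the call-site view (trackClient is a hypothetical helper and assumes it sits in the same package as GetManager; only APIs visible in this diff are used):

// trackClient is a hypothetical helper showing how callers use the singleton
// after this change: no options, no TTL wiring at construction time.
func trackClient(ctx context.Context, identifier int64) {
    // Register once when the client connects.
    GetManager().Register(ctx, identifier, &commonpb.ClientInfo{
        Reserved: map[string]string{"example": "true"},
    })

    // Refresh the last-active timestamp on each request; this is now a direct
    // concurrent-map update instead of a send into a buffered channel.
    GetManager().KeepActive(identifier)

    // Snapshot of currently tracked clients, e.g. for an admin endpoint.
    _ = GetManager().List()
}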


@@ -6,52 +6,23 @@ import (
     "sync"
     "time"
 
-    "github.com/golang/protobuf/proto"
+    "go.uber.org/atomic"
 
     "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
     "github.com/milvus-io/milvus/pkg/log"
-)
-
-const (
-    // we shouldn't check this too frequently.
-    defaultConnCheckDuration  = 2 * time.Minute
-    defaultTTLForInactiveConn = 24 * time.Hour
+    "github.com/milvus-io/milvus/pkg/util/paramtable"
+    "github.com/milvus-io/milvus/pkg/util/typeutil"
 )
 
 type connectionManager struct {
-    mu sync.RWMutex
-
     initOnce sync.Once
     stopOnce sync.Once
 
     closeSignal chan struct{}
     wg          sync.WaitGroup
 
-    buffer   chan int64
-    duration time.Duration
-    ttl      time.Duration
-
-    clientInfos map[int64]clientInfo
-}
-
-type connectionManagerOption func(s *connectionManager)
-
-func withDuration(duration time.Duration) connectionManagerOption {
-    return func(s *connectionManager) {
-        s.duration = duration
-    }
-}
-
-func withTTL(ttl time.Duration) connectionManagerOption {
-    return func(s *connectionManager) {
-        s.ttl = ttl
-    }
-}
-
-func (s *connectionManager) apply(opts ...connectionManagerOption) {
-    for _, opt := range opts {
-        opt(s)
-    }
+    clientInfos *typeutil.ConcurrentMap[int64, clientInfo]
+    count       atomic.Int64
 }
 
 func (s *connectionManager) init() {
@@ -71,7 +42,7 @@ func (s *connectionManager) Stop() {
 func (s *connectionManager) checkLoop() {
     defer s.wg.Done()
 
-    t := time.NewTicker(s.duration)
+    t := time.NewTicker(paramtable.Get().ProxyCfg.ConnectionCheckIntervalSeconds.GetAsDuration(time.Second))
     defer t.Stop()
 
     for {
@@ -79,10 +50,9 @@ func (s *connectionManager) checkLoop() {
         case <-s.closeSignal:
             log.Info("connection manager closed")
             return
-        case identifier := <-s.buffer:
-            s.Update(identifier)
         case <-t.C:
             s.removeLongInactiveClients()
+            t.Reset(paramtable.Get().ProxyCfg.ConnectionCheckIntervalSeconds.GetAsDuration(time.Second))
         }
     }
 }
@@ -94,49 +64,42 @@ func (s *connectionManager) Register(ctx context.Context, identifier int64, info
         lastActiveTime: time.Now(),
     }
 
-    s.mu.Lock()
-    defer s.mu.Unlock()
-    s.clientInfos[identifier] = cli
+    s.count.Inc()
+    s.clientInfos.Insert(identifier, cli)
     log.Ctx(ctx).Info("client register", cli.GetLogger()...)
 }
 
 func (s *connectionManager) KeepActive(identifier int64) {
-    // make this asynchronous and then the rpc won't be blocked too long.
-    s.buffer <- identifier
+    s.Update(identifier)
 }
 
 func (s *connectionManager) List() []*commonpb.ClientInfo {
-    s.mu.RLock()
-    defer s.mu.RUnlock()
-
-    clients := make([]*commonpb.ClientInfo, 0, len(s.clientInfos))
-
-    for identifier, cli := range s.clientInfos {
-        if cli.ClientInfo != nil {
-            client := proto.Clone(cli.ClientInfo).(*commonpb.ClientInfo)
+    clients := make([]*commonpb.ClientInfo, 0, s.count.Load())
+
+    s.clientInfos.Range(func(identifier int64, info clientInfo) bool {
+        if info.ClientInfo != nil {
+            client := typeutil.Clone(info.ClientInfo)
             if client.Reserved == nil {
                 client.Reserved = make(map[string]string)
             }
             client.Reserved["identifier"] = string(strconv.AppendInt(nil, identifier, 10))
-            client.Reserved["last_active_time"] = cli.lastActiveTime.String()
+            client.Reserved["last_active_time"] = info.lastActiveTime.String()
             clients = append(clients, client)
         }
-    }
+        return true
+    })
 
     return clients
 }
 
 func (s *connectionManager) Get(ctx context.Context) *commonpb.ClientInfo {
-    s.mu.RLock()
-    defer s.mu.RUnlock()
-
     identifier, err := GetIdentifierFromContext(ctx)
     if err != nil {
         return nil
     }
 
-    cli, ok := s.clientInfos[identifier]
+    cli, ok := s.clientInfos.Get(identifier)
     if !ok {
         return nil
     }
@@ -144,37 +107,30 @@ func (s *connectionManager) Get(ctx context.Context) *commonpb.ClientInfo {
 }
 
 func (s *connectionManager) Update(identifier int64) {
-    s.mu.Lock()
-    defer s.mu.Unlock()
-
-    cli, ok := s.clientInfos[identifier]
+    info, ok := s.clientInfos.Get(identifier)
     if ok {
-        cli.lastActiveTime = time.Now()
-        s.clientInfos[identifier] = cli
+        info.lastActiveTime = time.Now()
+        s.clientInfos.Insert(identifier, info)
     }
 }
 
 func (s *connectionManager) removeLongInactiveClients() {
-    s.mu.Lock()
-    defer s.mu.Unlock()
-
-    for candidate, cli := range s.clientInfos {
-        if time.Since(cli.lastActiveTime) > s.ttl {
-            log.Info("client deregister", cli.GetLogger()...)
-            delete(s.clientInfos, candidate)
+    ttl := paramtable.Get().ProxyCfg.ConnectionClientInfoTTLSeconds.GetAsDuration(time.Second)
+    s.clientInfos.Range(func(candidate int64, info clientInfo) bool {
+        if time.Since(info.lastActiveTime) > ttl {
+            log.Info("client deregister", info.GetLogger()...)
+            s.clientInfos.Remove(candidate)
+            s.count.Dec()
         }
-    }
+        return true
+    })
 }
 
-func newConnectionManager(opts ...connectionManagerOption) *connectionManager {
+func newConnectionManager() *connectionManager {
     s := &connectionManager{
         closeSignal: make(chan struct{}, 1),
-        buffer:      make(chan int64, 64),
-        duration:    defaultConnCheckDuration,
-        ttl:         defaultTTLForInactiveConn,
-        clientInfos: make(map[int64]clientInfo),
+        clientInfos: typeutil.NewConcurrentMap[int64, clientInfo](),
     }
-    s.apply(opts...)
     s.init()
     return s
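
One behavior worth calling out in checkLoop above: because the interval param is refreshable, the ticker is re-armed with a freshly read value after every scan, so a runtime config change takes effect on the next cycle without restarting the proxy. Below is a minimal standalone sketch of that pattern; getInterval and the stop channel are illustrative stand-ins for the paramtable lookup and closeSignal.

package main

import (
    "fmt"
    "sync/atomic"
    "time"
)

// intervalSeconds stands in for the refreshable paramtable value.
var intervalSeconds atomic.Int64

func getInterval() time.Duration {
    return time.Duration(intervalSeconds.Load()) * time.Second
}

// scanLoop re-reads the interval and resets the ticker after each pass,
// mirroring how checkLoop picks up runtime changes to the check interval.
func scanLoop(stop <-chan struct{}, scan func()) {
    t := time.NewTicker(getInterval())
    defer t.Stop()
    for {
        select {
        case <-stop:
            return
        case <-t.C:
            scan()
            t.Reset(getInterval()) // pick up any refreshed value
        }
    }
}

func main() {
    intervalSeconds.Store(1)
    stop := make(chan struct{})
    go scanLoop(stop, func() { fmt.Println("scan at", time.Now().Format(time.Kitchen)) })

    time.Sleep(1500 * time.Millisecond) // first scan happens after ~1s
    intervalSeconds.Store(3)            // applied when the ticker is next reset
    time.Sleep(1500 * time.Millisecond)
    close(stop)
}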


@@ -8,39 +8,19 @@ import (
     "github.com/stretchr/testify/assert"
 
     "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
+    "github.com/milvus-io/milvus/pkg/util/paramtable"
 )
 
-func Test_withDuration(t *testing.T) {
-    s := &connectionManager{}
-    s.apply(withDuration(defaultConnCheckDuration))
-    assert.Equal(t, defaultConnCheckDuration, s.duration)
-}
-
-func Test_withTTL(t *testing.T) {
-    s := &connectionManager{}
-    s.apply(withTTL(defaultTTLForInactiveConn))
-    assert.Equal(t, defaultTTLForInactiveConn, s.ttl)
-}
-
-func Test_connectionManager_apply(t *testing.T) {
-    s := &connectionManager{}
-    s.apply(
-        withDuration(defaultConnCheckDuration),
-        withTTL(defaultTTLForInactiveConn))
-    assert.Equal(t, defaultConnCheckDuration, s.duration)
-    assert.Equal(t, defaultTTLForInactiveConn, s.ttl)
-}
-
-func TestGetConnectionManager(t *testing.T) {
-    s := GetManager()
-    assert.Equal(t, defaultConnCheckDuration, s.duration)
-    assert.Equal(t, defaultTTLForInactiveConn, s.ttl)
-}
-
 func TestConnectionManager(t *testing.T) {
-    s := newConnectionManager(
-        withDuration(time.Millisecond*5),
-        withTTL(time.Millisecond*100))
+    paramtable.Init()
+
+    pt := paramtable.Get()
+    pt.Save(pt.ProxyCfg.ConnectionCheckIntervalSeconds.Key, "2")
+    pt.Save(pt.ProxyCfg.ConnectionClientInfoTTLSeconds.Key, "1")
+    defer pt.Reset(pt.ProxyCfg.ConnectionCheckIntervalSeconds.Key)
+    defer pt.Reset(pt.ProxyCfg.ConnectionClientInfoTTLSeconds.Key)
+
+    s := newConnectionManager()
+    defer s.Stop()
 
     s.Register(context.TODO(), 1, &commonpb.ClientInfo{
         Reserved: map[string]string{"for_test": "for_test"},
@@ -60,10 +40,7 @@ func TestConnectionManager(t *testing.T) {
     time.Sleep(time.Millisecond * 5)
     assert.Equal(t, 2, len(s.List()))
 
-    time.Sleep(time.Millisecond * 100)
-    assert.Equal(t, 0, len(s.List()))
-
-    s.Stop()
-    time.Sleep(time.Millisecond * 5)
+    assert.Eventually(t, func() bool {
+        return len(s.List()) == 0
+    }, time.Second*5, time.Second)
 }


@@ -1048,6 +1048,10 @@ type proxyConfig struct {
     AccessLog AccessLogConfig
 
+    // connection manager
+    ConnectionCheckIntervalSeconds ParamItem `refreshable:"true"`
+    ConnectionClientInfoTTLSeconds ParamItem `refreshable:"true"`
+
     GracefulStopTimeout ParamItem `refreshable:"true"`
 }
@@ -1364,6 +1368,24 @@ please adjust in embedded Milvus: false`,
         Export: true,
     }
     p.GracefulStopTimeout.Init(base.mgr)
+
+    p.ConnectionCheckIntervalSeconds = ParamItem{
+        Key:          "proxy.connectionMgrCheckInterval",
+        Version:      "2.3.11",
+        Doc:          "the interval time(in seconds) for connection manager to scan inactive client info",
+        DefaultValue: "120",
+        Export:       true,
+    }
+    p.ConnectionCheckIntervalSeconds.Init(base.mgr)
+
+    p.ConnectionClientInfoTTLSeconds = ParamItem{
+        Key:          "proxy.connectionClientInfoTTL",
+        Version:      "2.3.11",
+        Doc:          "inactive client info TTL duration, in seconds",
+        DefaultValue: "86400",
+        Export:       true,
+    }
+    p.ConnectionClientInfoTTLSeconds.Init(base.mgr)
 }
 
 // /////////////////////////////////////////////////////////////////////////////
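
For reference, the two new items are consumed like any other refreshable proxy param. The snippet below is an illustrative standalone program (not part of this PR) built only from calls that appear in this diff: it reads both values as durations and shows a runtime override that the next scan cycle would pick up.

package main

import (
    "fmt"
    "time"

    "github.com/milvus-io/milvus/pkg/util/paramtable"
)

func main() {
    paramtable.Init()
    pt := paramtable.Get()

    // Defaults: 120s scan interval, 86400s (24h) client info TTL.
    interval := pt.ProxyCfg.ConnectionCheckIntervalSeconds.GetAsDuration(time.Second)
    ttl := pt.ProxyCfg.ConnectionClientInfoTTLSeconds.GetAsDuration(time.Second)
    fmt.Println("check interval:", interval, "client info TTL:", ttl)

    // Both items are refreshable; an override is picked up on the next scan cycle.
    pt.Save(pt.ProxyCfg.ConnectionCheckIntervalSeconds.Key, "30")
    defer pt.Reset(pt.ProxyCfg.ConnectionCheckIntervalSeconds.Key)
}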