update shard leader cache (#22632)

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
pull/22660/head
wei liu 2023-03-09 16:35:53 +08:00 committed by GitHub
parent 1c65d56825
commit 336378c198
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 1847 additions and 29 deletions

View File

@ -15,6 +15,7 @@ import (
// validAuth validates the authentication
func TestValidAuth(t *testing.T) {
ctx := context.Background()
Params.InitOnce()
// no metadata
res := validAuth(ctx, nil)
assert.False(t, res)

View File

@ -151,7 +151,7 @@ func TestProxy_CheckHealth(t *testing.T) {
}
func TestProxy_ResourceGroup(t *testing.T) {
Params.Init()
Params.InitOnce()
factory := dependency.NewDefaultFactory(true)
ctx := context.Background()

View File

@ -26,6 +26,7 @@ import (
"github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/samber/lo"
"go.uber.org/atomic"
"go.uber.org/zap"
@ -64,6 +65,7 @@ type Cache interface {
GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error)
GetShards(ctx context.Context, withCache bool, collectionName string) (map[string][]nodeInfo, error)
ClearShards(collectionName string)
expireShardLeaderCache(ctx context.Context)
RemoveCollection(ctx context.Context, collectionName string)
RemoveCollectionsByID(ctx context.Context, collectionID UniqueID) []string
RemovePartition(ctx context.Context, collectionName string, partitionName string)
@ -172,6 +174,7 @@ func InitMetaCache(ctx context.Context, rootCoord types.RootCoord, queryCoord ty
}
globalMetaCache.InitPolicyInfo(resp.PolicyInfos, resp.UserRoles)
log.Info("success to init meta cache", zap.Strings("policy_infos", resp.PolicyInfos))
globalMetaCache.expireShardLeaderCache(ctx)
return nil
}
@ -743,14 +746,56 @@ func (m *MetaCache) ClearShards(collectionName string) {
log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName))
m.mu.Lock()
info, ok := m.collInfo[collectionName]
if ok {
m.collInfo[collectionName].shardLeaders = nil
}
m.mu.Unlock()
// delete refcnt in shardClientMgr
if ok && info.shardLeaders != nil {
_ = m.shardMgr.UpdateShardLeaders(info.shardLeaders.shardLeaders, nil)
var shardLeaders *shardLeaders
if ok {
info.leaderMutex.Lock()
m.collInfo[collectionName].shardLeaders = nil
shardLeaders = info.shardLeaders
info.leaderMutex.Unlock()
}
// delete refcnt in shardClientMgr
if ok && shardLeaders != nil {
_ = m.shardMgr.UpdateShardLeaders(shardLeaders.shardLeaders, nil)
}
}
func (m *MetaCache) expireShardLeaderCache(ctx context.Context) {
go func() {
updateInterval := Params.ProxyCfg.ShardLeaderCacheInterval.Load()
if updateInterval == nil {
updateInterval = time.Duration(30) * time.Second
}
log.Info("updating shard leader cache every", zap.Duration("interval", updateInterval.(time.Duration)))
ticker := time.NewTicker(updateInterval.(time.Duration))
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Info("stop periodically update meta cache")
return
case <-ticker.C:
m.mu.Lock()
log.Info("expire all shard leader cache",
zap.Strings("collections", lo.Keys(m.collInfo)))
for _, info := range m.collInfo {
info.leaderMutex.Lock()
shardLeaders := info.shardLeaders
info.shardLeaders = nil
info.leaderMutex.Unlock()
if shardLeaders != nil {
err := m.shardMgr.UpdateShardLeaders(shardLeaders.shardLeaders, nil)
if err != nil {
// unreachable logic path
log.Warn("failed to update shard leaders reference", zap.Error(err))
}
}
}
m.mu.Unlock()
}
}
}()
}
func (m *MetaCache) InitPolicyInfo(info []string, userRoles []string) {

View File

@ -28,6 +28,7 @@ import (
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
@ -743,3 +744,101 @@ func TestMetaCache_RemoveCollection(t *testing.T) {
// shouldn't access RootCoord again
assert.Equal(t, rootCoord.GetAccessCount(), 3)
}
func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
queryCoord := &types.MockQueryCoord{}
shardMgr := newShardClientMgr()
Params.InitOnce()
Params.ProxyCfg.ShardLeaderCacheInterval.Store(time.Duration(1) * time.Second)
err := InitMetaCache(ctx, rootCoord, queryCoord, shardMgr)
assert.Nil(t, err)
queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []UniqueID{1},
InMemoryPercentages: []int64{100},
}, nil)
queryCoord.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
},
}, nil).Times(1)
nodeInfos, err := globalMetaCache.GetShards(ctx, true, "collection1")
assert.NoError(t, err)
assert.Len(t, nodeInfos["channel-1"], 3)
queryCoord.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2},
NodeAddrs: []string{"localhost:9000", "localhost:9001"},
},
},
}, nil).Times(1)
assert.Eventually(t, func() bool {
nodeInfos, err := globalMetaCache.GetShards(ctx, true, "collection1")
assert.NoError(t, err)
return assert.Len(t, nodeInfos["channel-1"], 2)
}, 3*time.Second, 1*time.Second)
queryCoord.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
},
}, nil).Times(1)
assert.Eventually(t, func() bool {
nodeInfos, err := globalMetaCache.GetShards(ctx, true, "collection1")
assert.NoError(t, err)
return assert.Len(t, nodeInfos["channel-1"], 3)
}, 3*time.Second, 1*time.Second)
queryCoord.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
{
ChannelName: "channel-2",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
},
}, nil).Times(1)
assert.Eventually(t, func() bool {
nodeInfos, err := globalMetaCache.GetShards(ctx, true, "collection1")
assert.NoError(t, err)
return assert.Len(t, nodeInfos["channel-1"], 3) && assert.Len(t, nodeInfos["channel-2"], 3)
}, 3*time.Second, 1*time.Second)
}

View File

@ -27,7 +27,7 @@ import (
)
func TestMultiRateLimiter(t *testing.T) {
Params.Init()
Params.InitOnce()
t.Run("test multiRateLimiter", func(t *testing.T) {
bak := Params.QuotaConfig.QuotaAndLimitsEnabled
Params.QuotaConfig.QuotaAndLimitsEnabled = true
@ -78,7 +78,7 @@ func TestMultiRateLimiter(t *testing.T) {
}
func TestRateLimiter(t *testing.T) {
Params.Init()
Params.InitOnce()
t.Run("test limit", func(t *testing.T) {
limiter := newRateLimiter()
limiter.registerLimiters()

View File

@ -98,7 +98,7 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
fieldName := "field1"
indexName := "_default_idx_101"
Params.Init()
Params.InitOnce()
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
@ -210,7 +210,7 @@ func TestCreateIndexTask_PreExecute(t *testing.T) {
collectionID := UniqueID(1)
fieldName := newTestSchema().Fields[0].Name
Params.Init()
Params.InitOnce()
ic := newMockIndexCoord()
ctx := context.Background()

View File

@ -57,7 +57,7 @@ func TestUpdateShardsWithRoundRobin(t *testing.T) {
func TestGroupShardLeadersWithSameQueryNode(t *testing.T) {
var err error
Params.Init()
Params.InitOnce()
var (
ctx = context.TODO()
)
@ -121,7 +121,7 @@ func TestGroupShardLeadersWithSameQueryNode(t *testing.T) {
func TestMergeRoundRobinPolicy(t *testing.T) {
var err error
Params.Init()
Params.InitOnce()
var (
ctx = context.TODO()
)

View File

@ -27,7 +27,7 @@ import (
)
func TestQueryTask_all(t *testing.T) {
Params.Init()
Params.InitOnce()
var (
err error
@ -376,7 +376,7 @@ func Test_translateToOutputFieldIDs(t *testing.T) {
}
func TestTaskQuery_functions(t *testing.T) {
Params.Init()
Params.InitOnce()
t.Run("test parseQueryParams", func(t *testing.T) {
tests := []struct {
description string

View File

@ -26,7 +26,7 @@ import (
)
func TestBaseTaskQueue(t *testing.T) {
Params.Init()
Params.InitOnce()
var err error
var unissuedTask task
@ -104,7 +104,7 @@ func TestBaseTaskQueue(t *testing.T) {
}
func TestDdTaskQueue(t *testing.T) {
Params.Init()
Params.InitOnce()
var err error
var unissuedTask task
@ -183,7 +183,7 @@ func TestDdTaskQueue(t *testing.T) {
// test the logic of queue
func TestDmTaskQueue_Basic(t *testing.T) {
Params.Init()
Params.InitOnce()
var err error
var unissuedTask task
@ -262,7 +262,7 @@ func TestDmTaskQueue_Basic(t *testing.T) {
// test the timestamp statistics
func TestDmTaskQueue_TimestampStatistics(t *testing.T) {
Params.Init()
Params.InitOnce()
var err error
var unissuedTask task
@ -301,7 +301,7 @@ func TestDmTaskQueue_TimestampStatistics(t *testing.T) {
}
func TestDqTaskQueue(t *testing.T) {
Params.Init()
Params.InitOnce()
var err error
var unissuedTask task
@ -379,7 +379,7 @@ func TestDqTaskQueue(t *testing.T) {
}
func TestTaskScheduler(t *testing.T) {
Params.Init()
Params.InitOnce()
var err error

View File

@ -439,7 +439,7 @@ func TestSearchTask_Reduce(t *testing.T) {
func TestSearchTaskWithInvalidRoundDecimal(t *testing.T) {
// var err error
//
// Params.Init()
// Params.InitOnce()
// Params.ProxyCfg.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
//
// rc := NewRootCoordMock()
@ -682,7 +682,7 @@ func TestSearchTaskWithInvalidRoundDecimal(t *testing.T) {
func TestSearchTaskV2_all(t *testing.T) {
// var err error
//
// Params.Init()
// Params.InitOnce()
// Params.ProxyCfg.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
//
// rc := NewRootCoordMock()
@ -927,7 +927,7 @@ func TestSearchTaskV2_all(t *testing.T) {
func TestSearchTaskV2_7803_reduce(t *testing.T) {
// var err error
//
// Params.Init()
// Params.InitOnce()
// Params.ProxyCfg.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
//
// rc := NewRootCoordMock()
@ -1742,7 +1742,7 @@ func Test_checkIfLoaded(t *testing.T) {
}
func TestSearchTask_ErrExecute(t *testing.T) {
Params.Init()
Params.InitOnce()
var (
err error

View File

@ -2126,7 +2126,7 @@ func Test_checkTrain(t *testing.T) {
func Test_createIndexTask_PreExecute(t *testing.T) {
collectionName := "test"
fieldName := "test"
Params.Init()
Params.InitOnce()
cit := &createIndexTask{
req: &milvuspb.CreateIndexRequest{

View File

@ -804,7 +804,7 @@ func TestPasswordVerify(t *testing.T) {
}
func TestValidateTravelTimestamp(t *testing.T) {
Params.Init()
Params.InitOnce()
originalRetentionDuration := Params.CommonCfg.RetentionDuration
defer func() {
Params.CommonCfg.RetentionDuration = originalRetentionDuration

File diff suppressed because it is too large Load Diff

View File

@ -582,8 +582,9 @@ type proxyConfig struct {
MaxTaskNum int64
CreatedTime time.Time
UpdatedTime time.Time
CreatedTime time.Time
UpdatedTime time.Time
ShardLeaderCacheInterval atomic.Value
}
func (p *proxyConfig) init(base *BaseTable) {
@ -606,6 +607,7 @@ func (p *proxyConfig) init(base *BaseTable) {
p.initMaxRoleNum()
p.initSoPath()
p.initShardLeaderCacheInterval()
}
// InitAlias initialize Alias member.
@ -728,6 +730,11 @@ func (p *proxyConfig) initMaxRoleNum() {
p.MaxRoleNum = int(maxRoleNum)
}
func (p *proxyConfig) initShardLeaderCacheInterval() {
interval := p.Base.ParseIntWithDefault("proxy.shardLeaderCacheInterval", 30)
p.ShardLeaderCacheInterval.Store(time.Duration(interval) * time.Second)
}
// /////////////////////////////////////////////////////////////////////////////
// --- querycoord ---
type queryCoordConfig struct {

View File

@ -185,6 +185,8 @@ func TestComponentParam(t *testing.T) {
t.Logf("MaxDimension: %d", Params.MaxDimension)
t.Logf("MaxTaskNum: %d", Params.MaxTaskNum)
t.Logf("ShardLeaderCacheInterval: %d", Params.ShardLeaderCacheInterval.Load())
})
t.Run("test proxyConfig panic", func(t *testing.T) {