Fix collection and channel not match (#25859)

Signed-off-by: lixinguo <xinguo.li@zilliz.com>
Co-authored-by: lixinguo <xinguo.li@zilliz.com>
pull/26058/head
smellthemoon 2023-08-01 17:33:06 +08:00 committed by GitHub
parent eade5f9b7f
commit 9614e61f14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 369 additions and 191 deletions

View File

@ -34,20 +34,22 @@ import (
type executeFunc func(context.Context, UniqueID, types.QueryNode, ...string) error
type ChannelWorkload struct {
db string
collection string
channel string
shardLeaders []int64
nq int64
exec executeFunc
retryTimes uint
db string
collectionName string
collectionID int64
channel string
shardLeaders []int64
nq int64
exec executeFunc
retryTimes uint
}
type CollectionWorkLoad struct {
db string
collection string
nq int64
exec executeFunc
db string
collectionName string
collectionID int64
nq int64
exec executeFunc
}
type LBPolicy interface {
@ -89,7 +91,8 @@ func (lb *LBPolicyImpl) Start(ctx context.Context) {
// try to select the best node from the available nodes
func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) {
log := log.With(
zap.String("collectionName", workload.collection),
zap.Int64("collectionID", workload.collectionID),
zap.String("collectionName", workload.collectionName),
zap.String("channelName", workload.channel),
)
@ -98,7 +101,7 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload
}
getShardLeaders := func() ([]int64, error) {
shardLeaders, err := globalMetaCache.GetShards(ctx, false, workload.db, workload.collection)
shardLeaders, err := globalMetaCache.GetShards(ctx, false, workload.db, workload.collectionName, workload.collectionID)
if err != nil {
return nil, err
}
@ -109,7 +112,7 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload
availableNodes := lo.Filter(workload.shardLeaders, filterAvailableNodes)
targetNode, err := lb.balancer.SelectNode(ctx, availableNodes, workload.nq)
if err != nil {
globalMetaCache.DeprecateShardCache(workload.db, workload.collection)
globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName)
nodes, err := getShardLeaders()
if err != nil || len(nodes) == 0 {
log.Warn("failed to get shard delegator",
@ -141,7 +144,8 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload
func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error {
excludeNodes := typeutil.NewUniqueSet()
log := log.Ctx(ctx).With(
zap.String("collectionName", workload.collection),
zap.Int64("collectionID", workload.collectionID),
zap.String("collectionName", workload.collectionName),
zap.String("channelName", workload.channel),
)
@ -185,7 +189,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
// Execute will execute collection workload in parallel
func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad) error {
dml2leaders, err := globalMetaCache.GetShards(ctx, true, workload.db, workload.collection)
dml2leaders, err := globalMetaCache.GetShards(ctx, true, workload.db, workload.collectionName, workload.collectionID)
if err != nil {
log.Ctx(ctx).Warn("failed to get shards", zap.Error(err))
return err
@ -197,13 +201,14 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
nodes := lo.Map(nodes, func(node nodeInfo, _ int) int64 { return node.nodeID })
wg.Go(func() error {
err := lb.ExecuteWithRetry(ctx, ChannelWorkload{
db: workload.db,
collection: workload.collection,
channel: channel,
shardLeaders: nodes,
nq: workload.nq,
exec: workload.exec,
retryTimes: uint(len(nodes)),
db: workload.db,
collectionName: workload.collectionName,
collectionID: workload.collectionID,
channel: channel,
shardLeaders: nodes,
nq: workload.nq,
exec: workload.exec,
retryTimes: uint(len(nodes)),
})
return err
})

View File

@ -53,7 +53,8 @@ type LBPolicySuite struct {
channels []string
qnList []*mocks.MockQueryNode
collection string
collectionName string
collectionID int64
}
func (s *LBPolicySuite) SetupSuite() {
@ -108,7 +109,7 @@ func (s *LBPolicySuite) SetupTest() {
err := InitMetaCache(context.Background(), s.rc, s.qc, s.mgr)
s.NoError(err)
s.collection = "test_lb_policy"
s.collectionName = "test_lb_policy"
s.loadCollection()
}
@ -125,7 +126,7 @@ func (s *LBPolicySuite) loadCollection() {
fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
}
schema := constructCollectionSchemaByDataType(s.collection, fieldName2Types, testInt64Field, false)
schema := constructCollectionSchemaByDataType(s.collectionName, fieldName2Types, testInt64Field, false)
marshaledSchema, err := proto.Marshal(schema)
s.NoError(err)
@ -133,7 +134,7 @@ func (s *LBPolicySuite) loadCollection() {
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
CollectionName: s.collection,
CollectionName: s.collectionName,
DbName: dbName,
Schema: marshaledSchema,
ShardsNum: common.DefaultShardsNum,
@ -147,7 +148,7 @@ func (s *LBPolicySuite) loadCollection() {
s.NoError(createColT.Execute(ctx))
s.NoError(createColT.PostExecute(ctx))
collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, s.collection)
collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, s.collectionName)
s.NoError(err)
status, err := s.qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
@ -159,17 +160,19 @@ func (s *LBPolicySuite) loadCollection() {
})
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
s.collectionID = collectionID
}
func (s *LBPolicySuite) TestSelectNode() {
ctx := context.Background()
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil)
targetNode, err := s.lbPolicy.selectNode(ctx, ChannelWorkload{
db: dbName,
collection: s.collection,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
}, typeutil.NewUniqueSet())
s.NoError(err)
s.Equal(int64(5), targetNode)
@ -179,11 +182,12 @@ func (s *LBPolicySuite) TestSelectNode() {
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil)
targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
db: dbName,
collection: s.collection,
channel: s.channels[0],
shardLeaders: []int64{},
nq: 1,
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: []int64{},
nq: 1,
}, typeutil.NewUniqueSet())
s.NoError(err)
s.Equal(int64(3), targetNode)
@ -192,11 +196,12 @@ func (s *LBPolicySuite) TestSelectNode() {
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
db: dbName,
collection: s.collection,
channel: s.channels[0],
shardLeaders: []int64{},
nq: 1,
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: []int64{},
nq: 1,
}, typeutil.NewUniqueSet())
s.ErrorIs(err, merr.ErrNodeNotAvailable)
s.Equal(int64(-1), targetNode)
@ -205,11 +210,12 @@ func (s *LBPolicySuite) TestSelectNode() {
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
db: dbName,
collection: s.collection,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
}, typeutil.NewUniqueSet(s.nodes...))
s.ErrorIs(err, merr.ErrServiceUnavailable)
s.Equal(int64(-1), targetNode)
@ -220,11 +226,12 @@ func (s *LBPolicySuite) TestSelectNode() {
s.qc.ExpectedCalls = nil
s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, merr.ErrServiceUnavailable)
targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
db: dbName,
collection: s.collection,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
}, typeutil.NewUniqueSet())
s.ErrorIs(err, merr.ErrServiceUnavailable)
s.Equal(int64(-1), targetNode)
@ -239,11 +246,12 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
collection: s.collection,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
return nil
},
@ -255,11 +263,12 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
collection: s.collection,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
return nil
},
@ -274,11 +283,12 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
collection: s.collection,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
return nil
},
@ -291,11 +301,12 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
collection: s.collection,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
return nil
},
@ -311,11 +322,12 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
counter := 0
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
collection: s.collection,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
counter++
if counter == 1 {
@ -336,9 +348,10 @@ func (s *LBPolicySuite) TestExecute() {
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{
db: dbName,
collection: s.collection,
nq: 1,
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
return nil
},
@ -348,9 +361,10 @@ func (s *LBPolicySuite) TestExecute() {
// test some channel failed
counter := atomic.NewInt64(0)
err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{
db: dbName,
collection: s.collection,
nq: 1,
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
if counter.Add(1) == 1 {
return nil
@ -363,12 +377,13 @@ func (s *LBPolicySuite) TestExecute() {
// test get shard leader failed
s.qc.ExpectedCalls = nil
globalMetaCache.DeprecateShardCache(dbName, s.collection)
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, mockErr)
err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{
db: dbName,
collection: s.collection,
nq: 1,
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
return nil
},

View File

@ -57,8 +57,8 @@ type Cache interface {
GetCollectionID(ctx context.Context, database, collectionName string) (typeutil.UniqueID, error)
// GetDatabaseAndCollectionName get collection's name and database by id
GetDatabaseAndCollectionName(ctx context.Context, collectionID int64) (string, string, error)
// GetCollectionInfo get collection's information by name, such as collection id, schema, and etc.
GetCollectionInfo(ctx context.Context, database, collectionName string) (*collectionInfo, error)
// GetCollectionInfo get collection's information by name or collection id, such as schema, and etc.
GetCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionInfo, error)
// GetPartitionID get partition's identifier of specific collection.
GetPartitionID(ctx context.Context, database, collectionName string, partitionName string) (typeutil.UniqueID, error)
// GetPartitions get all partitions' id of specific collection.
@ -67,7 +67,7 @@ type Cache interface {
GetPartitionInfo(ctx context.Context, database, collectionName string, partitionName string) (*partitionInfo, error)
// GetCollectionSchema get collection's schema.
GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemapb.CollectionSchema, error)
GetShards(ctx context.Context, withCache bool, database, collectionName string) (map[string][]nodeInfo, error)
GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error)
DeprecateShardCache(database, collectionName string)
expireShardLeaderCache(ctx context.Context)
RemoveCollection(ctx context.Context, database, collectionName string)
@ -229,7 +229,7 @@ func (m *MetaCache) GetCollectionID(ctx context.Context, database, collectionNam
collInfo, ok = db[collectionName]
}
method := "GeCollectionID"
method := "GetCollectionID"
if !ok || !collInfo.isCollectionCached() {
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
tr := timerecord.NewTimeRecorder("UpdateCache")
@ -289,7 +289,7 @@ func (m *MetaCache) GetDatabaseAndCollectionName(ctx context.Context, collection
// GetCollectionInfo returns the collection information related to provided collection name
// If the information is not found, proxy will try to fetch information for other source (RootCoord for now)
func (m *MetaCache) GetCollectionInfo(ctx context.Context, database, collectionName string) (*collectionInfo, error) {
func (m *MetaCache) GetCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionInfo, error) {
m.mu.RLock()
var collInfo *collectionInfo
var ok bool
@ -301,10 +301,17 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database, collectionN
m.mu.RUnlock()
method := "GetCollectionInfo"
if !ok || !collInfo.isCollectionCached() {
// if collInfo.collID != collectionID, means that the cache is not trustable
// try to get collection according to collectionID
if !ok || !collInfo.isCollectionCached() || collInfo.collID != collectionID {
tr := timerecord.NewTimeRecorder("UpdateCache")
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
coll, err := m.describeCollection(ctx, database, collectionName, 0)
var coll *milvuspb.DescribeCollectionResponse
var err error
// collectionName maybe not trustable, get collection according to id
coll, err = m.describeCollection(ctx, database, "", collectionID)
if err != nil {
return nil, err
}
@ -695,8 +702,12 @@ func (m *MetaCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
}
// GetShards update cache if withCache == false
func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, collectionName string) (map[string][]nodeInfo, error) {
info, err := m.GetCollectionInfo(ctx, database, collectionName)
func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error) {
log := log.Ctx(ctx).With(
zap.String("collectionName", collectionName),
zap.Int64("collectionID", collectionID))
info, err := m.GetCollectionInfo(ctx, database, collectionName, collectionID)
if err != nil {
return nil, err
}
@ -715,8 +726,7 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
}
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
log.Info("no shard cache for collection, try to get shard leaders from QueryCoord",
zap.String("collectionName", collectionName))
log.Info("no shard cache for collection, try to get shard leaders from QueryCoord")
}
req := &querypb.GetShardLeadersRequest{
Base: commonpbutil.NewMsgBase(
@ -754,9 +764,9 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
shards := parseShardLeaderList2QueryNode(resp.GetShards())
info, err = m.GetCollectionInfo(ctx, database, collectionName)
info, err = m.GetCollectionInfo(ctx, database, collectionName, collectionID)
if err != nil {
return nil, fmt.Errorf("failed to get shards, collection %s not found", collectionName)
return nil, fmt.Errorf("failed to get shards, collectionName %s, colectionID %d not found", collectionName, collectionID)
}
// lock leader
info.leaderMutex.Lock()

View File

@ -474,6 +474,7 @@ func TestMetaCache_GetShards(t *testing.T) {
var (
ctx = context.Background()
collectionName = "collection1"
collectionID = int64(1)
)
rootCoord := &MockRootCoordClientInterface{}
@ -488,7 +489,7 @@ func TestMetaCache_GetShards(t *testing.T) {
defer qc.Stop()
t.Run("No collection in meta cache", func(t *testing.T) {
shards, err := globalMetaCache.GetShards(ctx, true, dbName, "non-exists")
shards, err := globalMetaCache.GetShards(ctx, true, dbName, "non-exists", 0)
assert.Error(t, err)
assert.Empty(t, shards)
})
@ -503,7 +504,7 @@ func TestMetaCache_GetShards(t *testing.T) {
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil)
shards, err := globalMetaCache.GetShards(ctx, false, dbName, collectionName)
shards, err := globalMetaCache.GetShards(ctx, false, dbName, collectionName, collectionID)
assert.Error(t, err)
assert.Empty(t, shards)
})
@ -524,7 +525,7 @@ func TestMetaCache_GetShards(t *testing.T) {
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil)
shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName)
shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID)
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))
@ -537,7 +538,7 @@ func TestMetaCache_GetShards(t *testing.T) {
Reason: "not implemented",
},
}, nil)
shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName)
shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID)
assert.NoError(t, err)
assert.NotEmpty(t, shards)
@ -550,6 +551,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
var (
ctx = context.TODO()
collectionName = "collection1"
collectionID = int64(1)
)
rootCoord := &MockRootCoordClientInterface{}
@ -588,7 +590,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil)
shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName)
shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID)
require.NoError(t, err)
require.NotEmpty(t, shards)
require.Equal(t, 1, len(shards))
@ -602,7 +604,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
Reason: "not implemented",
},
}, nil)
shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName)
shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID)
assert.Error(t, err)
assert.Empty(t, shards)
})
@ -706,26 +708,26 @@ func TestMetaCache_RemoveCollection(t *testing.T) {
InMemoryPercentages: []int64{100, 50},
}, nil)
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1")
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1)
assert.NoError(t, err)
// no collectionInfo of collection1, should access RootCoord
assert.Equal(t, rootCoord.GetAccessCount(), 1)
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1")
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1)
assert.NoError(t, err)
// shouldn't access RootCoord again
assert.Equal(t, rootCoord.GetAccessCount(), 1)
globalMetaCache.RemoveCollection(ctx, dbName, "collection1")
// no collectionInfo of collection2, should access RootCoord
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1")
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1)
assert.NoError(t, err)
// shouldn't access RootCoord again
assert.Equal(t, rootCoord.GetAccessCount(), 2)
globalMetaCache.RemoveCollectionsByID(ctx, UniqueID(1))
// no collectionInfo of collection2, should access RootCoord
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1")
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1)
assert.NoError(t, err)
// shouldn't access RootCoord again
assert.Equal(t, rootCoord.GetAccessCount(), 3)
@ -761,7 +763,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
},
},
}, nil)
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1")
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
assert.NoError(t, err)
assert.Len(t, nodeInfos["channel-1"], 3)
@ -780,7 +782,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
}, nil)
assert.Eventually(t, func() bool {
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1")
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
assert.NoError(t, err)
return len(nodeInfos["channel-1"]) == 2
}, 3*time.Second, 1*time.Second)
@ -800,7 +802,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
}, nil)
assert.Eventually(t, func() bool {
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1")
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
assert.NoError(t, err)
return len(nodeInfos["channel-1"]) == 3
}, 3*time.Second, 1*time.Second)
@ -825,7 +827,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
}, nil)
assert.Eventually(t, func() bool {
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1")
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
assert.NoError(t, err)
return len(nodeInfos["channel-1"]) == 3 && len(nodeInfos["channel-2"]) == 3
}, 3*time.Second, 1*time.Second)

View File

@ -1,16 +1,16 @@
// Code generated by mockery v2.16.0. DO NOT EDIT.
// Code generated by mockery v2.23.1. DO NOT EDIT.
package proxy
import (
"context"
context "context"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/stretchr/testify/mock"
internalpb "github.com/milvus-io/milvus/internal/proto/internalpb"
mock "github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
schemapb "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/typeutil"
typeutil "github.com/milvus-io/milvus/pkg/util/typeutil"
)
// MockCache is an autogenerated mock type for the Cache type
@ -55,18 +55,26 @@ func (_c *MockCache_DeprecateShardCache_Call) Return() *MockCache_DeprecateShard
return _c
}
func (_c *MockCache_DeprecateShardCache_Call) RunAndReturn(run func(string, string)) *MockCache_DeprecateShardCache_Call {
_c.Call.Return(run)
return _c
}
// GetCollectionID provides a mock function with given fields: ctx, database, collectionName
func (_m *MockCache) GetCollectionID(ctx context.Context, database string, collectionName string) (int64, error) {
ret := _m.Called(ctx, database, collectionName)
var r0 int64
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) (int64, error)); ok {
return rf(ctx, database, collectionName)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) int64); ok {
r0 = rf(ctx, database, collectionName)
} else {
r0 = ret.Get(0).(int64)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, database, collectionName)
} else {
@ -101,22 +109,30 @@ func (_c *MockCache_GetCollectionID_Call) Return(_a0 int64, _a1 error) *MockCach
return _c
}
// GetCollectionInfo provides a mock function with given fields: ctx, database, collectionName
func (_m *MockCache) GetCollectionInfo(ctx context.Context, database string, collectionName string) (*collectionInfo, error) {
ret := _m.Called(ctx, database, collectionName)
func (_c *MockCache_GetCollectionID_Call) RunAndReturn(run func(context.Context, string, string) (int64, error)) *MockCache_GetCollectionID_Call {
_c.Call.Return(run)
return _c
}
// GetCollectionInfo provides a mock function with given fields: ctx, database, collectionName, collectionID
func (_m *MockCache) GetCollectionInfo(ctx context.Context, database string, collectionName string, collectionID int64) (*collectionInfo, error) {
ret := _m.Called(ctx, database, collectionName, collectionID)
var r0 *collectionInfo
if rf, ok := ret.Get(0).(func(context.Context, string, string) *collectionInfo); ok {
r0 = rf(ctx, database, collectionName)
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) (*collectionInfo, error)); ok {
return rf(ctx, database, collectionName, collectionID)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) *collectionInfo); ok {
r0 = rf(ctx, database, collectionName, collectionID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*collectionInfo)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, database, collectionName)
if rf, ok := ret.Get(1).(func(context.Context, string, string, int64) error); ok {
r1 = rf(ctx, database, collectionName, collectionID)
} else {
r1 = ret.Error(1)
}
@ -133,13 +149,14 @@ type MockCache_GetCollectionInfo_Call struct {
// - ctx context.Context
// - database string
// - collectionName string
func (_e *MockCache_Expecter) GetCollectionInfo(ctx interface{}, database interface{}, collectionName interface{}) *MockCache_GetCollectionInfo_Call {
return &MockCache_GetCollectionInfo_Call{Call: _e.mock.On("GetCollectionInfo", ctx, database, collectionName)}
// - collectionID int64
func (_e *MockCache_Expecter) GetCollectionInfo(ctx interface{}, database interface{}, collectionName interface{}, collectionID interface{}) *MockCache_GetCollectionInfo_Call {
return &MockCache_GetCollectionInfo_Call{Call: _e.mock.On("GetCollectionInfo", ctx, database, collectionName, collectionID)}
}
func (_c *MockCache_GetCollectionInfo_Call) Run(run func(ctx context.Context, database string, collectionName string)) *MockCache_GetCollectionInfo_Call {
func (_c *MockCache_GetCollectionInfo_Call) Run(run func(ctx context.Context, database string, collectionName string, collectionID int64)) *MockCache_GetCollectionInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string))
run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(int64))
})
return _c
}
@ -149,11 +166,20 @@ func (_c *MockCache_GetCollectionInfo_Call) Return(_a0 *collectionInfo, _a1 erro
return _c
}
func (_c *MockCache_GetCollectionInfo_Call) RunAndReturn(run func(context.Context, string, string, int64) (*collectionInfo, error)) *MockCache_GetCollectionInfo_Call {
_c.Call.Return(run)
return _c
}
// GetCollectionSchema provides a mock function with given fields: ctx, database, collectionName
func (_m *MockCache) GetCollectionSchema(ctx context.Context, database string, collectionName string) (*schemapb.CollectionSchema, error) {
ret := _m.Called(ctx, database, collectionName)
var r0 *schemapb.CollectionSchema
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) (*schemapb.CollectionSchema, error)); ok {
return rf(ctx, database, collectionName)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) *schemapb.CollectionSchema); ok {
r0 = rf(ctx, database, collectionName)
} else {
@ -162,7 +188,6 @@ func (_m *MockCache) GetCollectionSchema(ctx context.Context, database string, c
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, database, collectionName)
} else {
@ -197,11 +222,20 @@ func (_c *MockCache_GetCollectionSchema_Call) Return(_a0 *schemapb.CollectionSch
return _c
}
func (_c *MockCache_GetCollectionSchema_Call) RunAndReturn(run func(context.Context, string, string) (*schemapb.CollectionSchema, error)) *MockCache_GetCollectionSchema_Call {
_c.Call.Return(run)
return _c
}
// GetCredentialInfo provides a mock function with given fields: ctx, username
func (_m *MockCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) {
ret := _m.Called(ctx, username)
var r0 *internalpb.CredentialInfo
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (*internalpb.CredentialInfo, error)); ok {
return rf(ctx, username)
}
if rf, ok := ret.Get(0).(func(context.Context, string) *internalpb.CredentialInfo); ok {
r0 = rf(ctx, username)
} else {
@ -210,7 +244,6 @@ func (_m *MockCache) GetCredentialInfo(ctx context.Context, username string) (*i
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, username)
} else {
@ -244,25 +277,33 @@ func (_c *MockCache_GetCredentialInfo_Call) Return(_a0 *internalpb.CredentialInf
return _c
}
func (_c *MockCache_GetCredentialInfo_Call) RunAndReturn(run func(context.Context, string) (*internalpb.CredentialInfo, error)) *MockCache_GetCredentialInfo_Call {
_c.Call.Return(run)
return _c
}
// GetDatabaseAndCollectionName provides a mock function with given fields: ctx, collectionID
func (_m *MockCache) GetDatabaseAndCollectionName(ctx context.Context, collectionID int64) (string, string, error) {
ret := _m.Called(ctx, collectionID)
var r0 string
var r1 string
var r2 error
if rf, ok := ret.Get(0).(func(context.Context, int64) (string, string, error)); ok {
return rf(ctx, collectionID)
}
if rf, ok := ret.Get(0).(func(context.Context, int64) string); ok {
r0 = rf(ctx, collectionID)
} else {
r0 = ret.Get(0).(string)
}
var r1 string
if rf, ok := ret.Get(1).(func(context.Context, int64) string); ok {
r1 = rf(ctx, collectionID)
} else {
r1 = ret.Get(1).(string)
}
var r2 error
if rf, ok := ret.Get(2).(func(context.Context, int64) error); ok {
r2 = rf(ctx, collectionID)
} else {
@ -296,18 +337,26 @@ func (_c *MockCache_GetDatabaseAndCollectionName_Call) Return(_a0 string, _a1 st
return _c
}
func (_c *MockCache_GetDatabaseAndCollectionName_Call) RunAndReturn(run func(context.Context, int64) (string, string, error)) *MockCache_GetDatabaseAndCollectionName_Call {
_c.Call.Return(run)
return _c
}
// GetPartitionID provides a mock function with given fields: ctx, database, collectionName, partitionName
func (_m *MockCache) GetPartitionID(ctx context.Context, database string, collectionName string, partitionName string) (int64, error) {
ret := _m.Called(ctx, database, collectionName, partitionName)
var r0 int64
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (int64, error)); ok {
return rf(ctx, database, collectionName, partitionName)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) int64); ok {
r0 = rf(ctx, database, collectionName, partitionName)
} else {
r0 = ret.Get(0).(int64)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
r1 = rf(ctx, database, collectionName, partitionName)
} else {
@ -343,11 +392,20 @@ func (_c *MockCache_GetPartitionID_Call) Return(_a0 int64, _a1 error) *MockCache
return _c
}
func (_c *MockCache_GetPartitionID_Call) RunAndReturn(run func(context.Context, string, string, string) (int64, error)) *MockCache_GetPartitionID_Call {
_c.Call.Return(run)
return _c
}
// GetPartitionInfo provides a mock function with given fields: ctx, database, collectionName, partitionName
func (_m *MockCache) GetPartitionInfo(ctx context.Context, database string, collectionName string, partitionName string) (*partitionInfo, error) {
ret := _m.Called(ctx, database, collectionName, partitionName)
var r0 *partitionInfo
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (*partitionInfo, error)); ok {
return rf(ctx, database, collectionName, partitionName)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *partitionInfo); ok {
r0 = rf(ctx, database, collectionName, partitionName)
} else {
@ -356,7 +414,6 @@ func (_m *MockCache) GetPartitionInfo(ctx context.Context, database string, coll
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
r1 = rf(ctx, database, collectionName, partitionName)
} else {
@ -392,11 +449,20 @@ func (_c *MockCache_GetPartitionInfo_Call) Return(_a0 *partitionInfo, _a1 error)
return _c
}
func (_c *MockCache_GetPartitionInfo_Call) RunAndReturn(run func(context.Context, string, string, string) (*partitionInfo, error)) *MockCache_GetPartitionInfo_Call {
_c.Call.Return(run)
return _c
}
// GetPartitions provides a mock function with given fields: ctx, database, collectionName
func (_m *MockCache) GetPartitions(ctx context.Context, database string, collectionName string) (map[string]int64, error) {
ret := _m.Called(ctx, database, collectionName)
var r0 map[string]int64
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) (map[string]int64, error)); ok {
return rf(ctx, database, collectionName)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) map[string]int64); ok {
r0 = rf(ctx, database, collectionName)
} else {
@ -405,7 +471,6 @@ func (_m *MockCache) GetPartitions(ctx context.Context, database string, collect
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, database, collectionName)
} else {
@ -440,6 +505,11 @@ func (_c *MockCache_GetPartitions_Call) Return(_a0 map[string]int64, _a1 error)
return _c
}
func (_c *MockCache_GetPartitions_Call) RunAndReturn(run func(context.Context, string, string) (map[string]int64, error)) *MockCache_GetPartitions_Call {
_c.Call.Return(run)
return _c
}
// GetPrivilegeInfo provides a mock function with given fields: ctx
func (_m *MockCache) GetPrivilegeInfo(ctx context.Context) []string {
ret := _m.Called(ctx)
@ -479,22 +549,30 @@ func (_c *MockCache_GetPrivilegeInfo_Call) Return(_a0 []string) *MockCache_GetPr
return _c
}
// GetShards provides a mock function with given fields: ctx, withCache, database, collectionName
func (_m *MockCache) GetShards(ctx context.Context, withCache bool, database string, collectionName string) (map[string][]nodeInfo, error) {
ret := _m.Called(ctx, withCache, database, collectionName)
func (_c *MockCache_GetPrivilegeInfo_Call) RunAndReturn(run func(context.Context) []string) *MockCache_GetPrivilegeInfo_Call {
_c.Call.Return(run)
return _c
}
// GetShards provides a mock function with given fields: ctx, withCache, database, collectionName, collectionID
func (_m *MockCache) GetShards(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64) (map[string][]nodeInfo, error) {
ret := _m.Called(ctx, withCache, database, collectionName, collectionID)
var r0 map[string][]nodeInfo
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string) map[string][]nodeInfo); ok {
r0 = rf(ctx, withCache, database, collectionName)
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64) (map[string][]nodeInfo, error)); ok {
return rf(ctx, withCache, database, collectionName, collectionID)
}
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64) map[string][]nodeInfo); ok {
r0 = rf(ctx, withCache, database, collectionName, collectionID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string][]nodeInfo)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, bool, string, string) error); ok {
r1 = rf(ctx, withCache, database, collectionName)
if rf, ok := ret.Get(1).(func(context.Context, bool, string, string, int64) error); ok {
r1 = rf(ctx, withCache, database, collectionName, collectionID)
} else {
r1 = ret.Error(1)
}
@ -512,13 +590,14 @@ type MockCache_GetShards_Call struct {
// - withCache bool
// - database string
// - collectionName string
func (_e *MockCache_Expecter) GetShards(ctx interface{}, withCache interface{}, database interface{}, collectionName interface{}) *MockCache_GetShards_Call {
return &MockCache_GetShards_Call{Call: _e.mock.On("GetShards", ctx, withCache, database, collectionName)}
// - collectionID int64
func (_e *MockCache_Expecter) GetShards(ctx interface{}, withCache interface{}, database interface{}, collectionName interface{}, collectionID interface{}) *MockCache_GetShards_Call {
return &MockCache_GetShards_Call{Call: _e.mock.On("GetShards", ctx, withCache, database, collectionName, collectionID)}
}
func (_c *MockCache_GetShards_Call) Run(run func(ctx context.Context, withCache bool, database string, collectionName string)) *MockCache_GetShards_Call {
func (_c *MockCache_GetShards_Call) Run(run func(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64)) *MockCache_GetShards_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(bool), args[2].(string), args[3].(string))
run(args[0].(context.Context), args[1].(bool), args[2].(string), args[3].(string), args[4].(int64))
})
return _c
}
@ -528,6 +607,11 @@ func (_c *MockCache_GetShards_Call) Return(_a0 map[string][]nodeInfo, _a1 error)
return _c
}
func (_c *MockCache_GetShards_Call) RunAndReturn(run func(context.Context, bool, string, string, int64) (map[string][]nodeInfo, error)) *MockCache_GetShards_Call {
_c.Call.Return(run)
return _c
}
// GetUserRole provides a mock function with given fields: username
func (_m *MockCache) GetUserRole(username string) []string {
ret := _m.Called(username)
@ -567,6 +651,11 @@ func (_c *MockCache_GetUserRole_Call) Return(_a0 []string) *MockCache_GetUserRol
return _c
}
func (_c *MockCache_GetUserRole_Call) RunAndReturn(run func(string) []string) *MockCache_GetUserRole_Call {
_c.Call.Return(run)
return _c
}
// InitPolicyInfo provides a mock function with given fields: info, userRoles
func (_m *MockCache) InitPolicyInfo(info []string, userRoles []string) {
_m.Called(info, userRoles)
@ -596,6 +685,11 @@ func (_c *MockCache_InitPolicyInfo_Call) Return() *MockCache_InitPolicyInfo_Call
return _c
}
func (_c *MockCache_InitPolicyInfo_Call) RunAndReturn(run func([]string, []string)) *MockCache_InitPolicyInfo_Call {
_c.Call.Return(run)
return _c
}
// RefreshPolicyInfo provides a mock function with given fields: op
func (_m *MockCache) RefreshPolicyInfo(op typeutil.CacheOp) error {
ret := _m.Called(op)
@ -633,6 +727,11 @@ func (_c *MockCache_RefreshPolicyInfo_Call) Return(_a0 error) *MockCache_Refresh
return _c
}
func (_c *MockCache_RefreshPolicyInfo_Call) RunAndReturn(run func(typeutil.CacheOp) error) *MockCache_RefreshPolicyInfo_Call {
_c.Call.Return(run)
return _c
}
// RemoveCollection provides a mock function with given fields: ctx, database, collectionName
func (_m *MockCache) RemoveCollection(ctx context.Context, database string, collectionName string) {
_m.Called(ctx, database, collectionName)
@ -663,6 +762,11 @@ func (_c *MockCache_RemoveCollection_Call) Return() *MockCache_RemoveCollection_
return _c
}
func (_c *MockCache_RemoveCollection_Call) RunAndReturn(run func(context.Context, string, string)) *MockCache_RemoveCollection_Call {
_c.Call.Return(run)
return _c
}
// RemoveCollectionsByID provides a mock function with given fields: ctx, collectionID
func (_m *MockCache) RemoveCollectionsByID(ctx context.Context, collectionID int64) []string {
ret := _m.Called(ctx, collectionID)
@ -703,6 +807,11 @@ func (_c *MockCache_RemoveCollectionsByID_Call) Return(_a0 []string) *MockCache_
return _c
}
func (_c *MockCache_RemoveCollectionsByID_Call) RunAndReturn(run func(context.Context, int64) []string) *MockCache_RemoveCollectionsByID_Call {
_c.Call.Return(run)
return _c
}
// RemoveCredential provides a mock function with given fields: username
func (_m *MockCache) RemoveCredential(username string) {
_m.Called(username)
@ -731,6 +840,11 @@ func (_c *MockCache_RemoveCredential_Call) Return() *MockCache_RemoveCredential_
return _c
}
func (_c *MockCache_RemoveCredential_Call) RunAndReturn(run func(string)) *MockCache_RemoveCredential_Call {
_c.Call.Return(run)
return _c
}
// RemoveDatabase provides a mock function with given fields: ctx, database
func (_m *MockCache) RemoveDatabase(ctx context.Context, database string) {
_m.Called(ctx, database)
@ -760,6 +874,11 @@ func (_c *MockCache_RemoveDatabase_Call) Return() *MockCache_RemoveDatabase_Call
return _c
}
func (_c *MockCache_RemoveDatabase_Call) RunAndReturn(run func(context.Context, string)) *MockCache_RemoveDatabase_Call {
_c.Call.Return(run)
return _c
}
// RemovePartition provides a mock function with given fields: ctx, database, collectionName, partitionName
func (_m *MockCache) RemovePartition(ctx context.Context, database string, collectionName string, partitionName string) {
_m.Called(ctx, database, collectionName, partitionName)
@ -791,6 +910,11 @@ func (_c *MockCache_RemovePartition_Call) Return() *MockCache_RemovePartition_Ca
return _c
}
func (_c *MockCache_RemovePartition_Call) RunAndReturn(run func(context.Context, string, string, string)) *MockCache_RemovePartition_Call {
_c.Call.Return(run)
return _c
}
// UpdateCredential provides a mock function with given fields: credInfo
func (_m *MockCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
_m.Called(credInfo)
@ -819,6 +943,11 @@ func (_c *MockCache_UpdateCredential_Call) Return() *MockCache_UpdateCredential_
return _c
}
func (_c *MockCache_UpdateCredential_Call) RunAndReturn(run func(*internalpb.CredentialInfo)) *MockCache_UpdateCredential_Call {
_c.Call.Return(run)
return _c
}
// expireShardLeaderCache provides a mock function with given fields: ctx
func (_m *MockCache) expireShardLeaderCache(ctx context.Context) {
_m.Called(ctx)
@ -847,6 +976,11 @@ func (_c *MockCache_expireShardLeaderCache_Call) Return() *MockCache_expireShard
return _c
}
func (_c *MockCache_expireShardLeaderCache_Call) RunAndReturn(run func(context.Context)) *MockCache_expireShardLeaderCache_Call {
_c.Call.Return(run)
return _c
}
type mockConstructorTestingTNewMockCache interface {
mock.TestingT
Cleanup(func())

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.21.1. DO NOT EDIT.
// Code generated by mockery v2.23.1. DO NOT EDIT.
package proxy

View File

@ -709,7 +709,7 @@ func (sct *showCollectionsTask) Execute(ctx context.Context) error {
zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections"))
continue
}
collectionInfo, err := globalMetaCache.GetCollectionInfo(ctx, sct.GetDbName(), collectionName)
collectionInfo, err := globalMetaCache.GetCollectionInfo(ctx, sct.GetDbName(), collectionName, id)
if err != nil {
log.Debug("Failed to get collection info.", zap.Any("collectionName", collectionName),
zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections"))

View File

@ -370,10 +370,11 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
if err != nil {
return err
}
collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName)
collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName, t.CollectionID)
if err2 != nil {
log.Warn("Proxy::queryTask::PreExecute failed to GetCollectionInfo from cache",
zap.String("collectionName", collectionName), zap.Error(err2))
zap.String("collectionName", collectionName), zap.Int64("collectionID", t.CollectionID),
zap.Error(err2))
return err2
}
@ -417,10 +418,11 @@ func (t *queryTask) Execute(ctx context.Context) error {
t.resultBuf = typeutil.NewConcurrentSet[*internalpb.RetrieveResults]()
err := t.lb.Execute(ctx, CollectionWorkLoad{
db: t.request.GetDbName(),
collection: t.collectionName,
nq: 1,
exec: t.queryShard,
db: t.request.GetDbName(),
collectionID: t.CollectionID,
collectionName: t.collectionName,
nq: 1,
exec: t.queryShard,
})
if err != nil {
log.Warn("fail to execute query", zap.Error(err))

View File

@ -361,10 +361,10 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
if err != nil {
return err
}
collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName)
collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName, t.CollectionID)
if err2 != nil {
log.Warn("Proxy::searchTask::PreExecute failed to GetCollectionInfo from cache",
zap.Any("collectionName", collectionName), zap.Error(err2))
zap.String("collectionName", collectionName), zap.Int64("collectionID", t.CollectionID), zap.Error(err2))
return err2
}
guaranteeTs := t.request.GetGuaranteeTimestamp()
@ -417,10 +417,11 @@ func (t *searchTask) Execute(ctx context.Context) error {
t.resultBuf = typeutil.NewConcurrentSet[*internalpb.SearchResults]()
err := t.lb.Execute(ctx, CollectionWorkLoad{
db: t.request.GetDbName(),
collection: t.collectionName,
nq: t.Nq,
exec: t.searchShard,
db: t.request.GetDbName(),
collectionID: t.SearchRequest.CollectionID,
collectionName: t.collectionName,
nq: t.Nq,
exec: t.searchShard,
})
if err != nil {
log.Warn("search execute failed", zap.Error(err))

View File

@ -139,26 +139,28 @@ func (g *getStatisticsTask) PreExecute(ctx context.Context) error {
}
// check if collection/partitions are loaded into query node
loaded, unloaded, err := checkFullLoaded(ctx, g.qc, g.collectionName, partIDs)
log := log.Ctx(ctx)
loaded, unloaded, err := checkFullLoaded(ctx, g.qc, g.collectionName, g.GetStatisticsRequest.CollectionID, partIDs)
log := log.Ctx(ctx).With(
zap.String("collectionName", g.collectionName),
zap.Int64("collectionID", g.CollectionID),
)
if err != nil {
g.fromDataCoord = true
g.unloadedPartitionIDs = partIDs
log.Info("checkFullLoaded failed, try get statistics from DataCoord", zap.Error(err))
log.Info("checkFullLoaded failed, try get statistics from DataCoord",
zap.Error(err))
return nil
}
if len(unloaded) > 0 {
g.fromDataCoord = true
g.unloadedPartitionIDs = unloaded
log.Info("some partitions has not been loaded, try get statistics from DataCoord",
zap.String("collection", g.collectionName),
zap.Int64s("unloaded partitions", unloaded))
}
if len(loaded) > 0 {
g.fromQueryNode = true
g.loadedPartitionIDs = loaded
log.Info("some partitions has been loaded, try get statistics from QueryNode",
zap.String("collection", g.collectionName),
zap.Int64s("loaded partitions", loaded))
}
return nil
@ -266,10 +268,11 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro
g.resultBuf = typeutil.NewConcurrentSet[*internalpb.GetStatisticsResponse]()
}
err := g.lb.Execute(ctx, CollectionWorkLoad{
db: g.request.GetDbName(),
collection: g.collectionName,
nq: 1,
exec: g.getStatisticsShard,
db: g.request.GetDbName(),
collectionID: g.GetStatisticsRequest.CollectionID,
collectionName: g.collectionName,
nq: 1,
exec: g.getStatisticsShard,
})
if err != nil {
@ -317,14 +320,14 @@ func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64
// checkFullLoaded check if collection / partition was fully loaded into QueryNode
// return loaded partitions, unloaded partitions and error
func checkFullLoaded(ctx context.Context, qc types.QueryCoord, collectionName string, searchPartitionIDs []UniqueID) ([]UniqueID, []UniqueID, error) {
func checkFullLoaded(ctx context.Context, qc types.QueryCoord, collectionName string, collectionID int64, searchPartitionIDs []UniqueID) ([]UniqueID, []UniqueID, error) {
var loadedPartitionIDs []UniqueID
var unloadPartitionIDs []UniqueID
// TODO: Consider to check if partition loaded from cache to save rpc.
info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName, collectionID)
if err != nil {
return nil, nil, fmt.Errorf("GetCollectionInfo failed, collection = %s, err = %s", collectionName, err)
return nil, nil, fmt.Errorf("GetCollectionInfo failed, collectionName = %s,collectionID = %d, err = %s", collectionName, collectionID, err)
}
// If request to search partitions
@ -338,10 +341,10 @@ func checkFullLoaded(ctx context.Context, qc types.QueryCoord, collectionName st
PartitionIDs: searchPartitionIDs,
})
if err != nil {
return nil, nil, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, err = %s", collectionName, searchPartitionIDs, err)
return nil, nil, fmt.Errorf("showPartitions failed, collection = %d, partitionIDs = %v, err = %s", collectionID, searchPartitionIDs, err)
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return nil, nil, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, reason = %s", collectionName, searchPartitionIDs, resp.GetStatus().GetReason())
return nil, nil, fmt.Errorf("showPartitions failed, collection = %d, partitionIDs = %v, reason = %s", collectionID, searchPartitionIDs, resp.GetStatus().GetReason())
}
for i, percentage := range resp.GetInMemoryPercentages() {
@ -363,10 +366,10 @@ func checkFullLoaded(ctx context.Context, qc types.QueryCoord, collectionName st
CollectionID: info.collID,
})
if err != nil {
return nil, nil, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, err = %s", collectionName, searchPartitionIDs, err)
return nil, nil, fmt.Errorf("showPartitions failed, collection = %d, partitionIDs = %v, err = %s", collectionID, searchPartitionIDs, err)
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return nil, nil, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, reason = %s", collectionName, searchPartitionIDs, resp.GetStatus().GetReason())
return nil, nil, fmt.Errorf("showPartitions failed, collection = %d, partitionIDs = %v, reason = %s", collectionID, searchPartitionIDs, resp.GetStatus().GetReason())
}
loadedMap := make(map[UniqueID]bool)

View File

@ -44,7 +44,8 @@ type StatisticTaskSuite struct {
lb LBPolicy
collection string
collectionName string
collectionID int64
}
func (s *StatisticTaskSuite) SetupSuite() {
@ -87,7 +88,7 @@ func (s *StatisticTaskSuite) SetupTest() {
err := InitMetaCache(context.Background(), s.rc, s.qc, mgr)
s.NoError(err)
s.collection = "test_statistics_task"
s.collectionName = "test_statistics_task"
s.loadCollection()
}
@ -104,7 +105,7 @@ func (s *StatisticTaskSuite) loadCollection() {
fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
}
schema := constructCollectionSchemaByDataType(s.collection, fieldName2Types, testInt64Field, false)
schema := constructCollectionSchemaByDataType(s.collectionName, fieldName2Types, testInt64Field, false)
marshaledSchema, err := proto.Marshal(schema)
s.NoError(err)
@ -112,7 +113,7 @@ func (s *StatisticTaskSuite) loadCollection() {
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
CollectionName: s.collection,
CollectionName: s.collectionName,
Schema: marshaledSchema,
ShardsNum: common.DefaultShardsNum,
},
@ -125,7 +126,7 @@ func (s *StatisticTaskSuite) loadCollection() {
s.NoError(createColT.Execute(ctx))
s.NoError(createColT.PostExecute(ctx))
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), s.collection)
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), s.collectionName)
s.NoError(err)
status, err := s.qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
@ -137,6 +138,7 @@ func (s *StatisticTaskSuite) loadCollection() {
})
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
s.collectionID = collectionID
}
func (s *StatisticTaskSuite) TearDownSuite() {
@ -164,7 +166,7 @@ func (s *StatisticTaskSuite) getStatisticsTask(ctx context.Context) *getStatisti
return &getStatisticsTask{
Condition: NewTaskCondition(ctx),
ctx: ctx,
collectionName: s.collection,
collectionName: s.collectionName,
result: &milvuspb.GetStatisticsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
@ -175,7 +177,7 @@ func (s *StatisticTaskSuite) getStatisticsTask(ctx context.Context) *getStatisti
MsgType: commonpb.MsgType_Retrieve,
SourceID: paramtable.GetNodeID(),
},
CollectionName: s.collection,
CollectionName: s.collectionName,
},
qc: s.qc,
lb: s.lb,
@ -195,6 +197,7 @@ func (s *StatisticTaskSuite) TestStatisticTask_NotShardLeader() {
Reason: "error",
},
}, nil)
s.NoError(task.PreExecute(ctx))
s.Error(task.Execute(ctx))
s.NoError(task.PostExecute(ctx))
}
@ -211,6 +214,7 @@ func (s *StatisticTaskSuite) TestStatisticTask_UnexpectedError() {
Reason: "error",
},
}, nil)
s.NoError(task.PreExecute(ctx))
s.Error(task.Execute(ctx))
s.NoError(task.PostExecute(ctx))
}
@ -220,8 +224,10 @@ func (s *StatisticTaskSuite) TestStatisticTask_Success() {
task := s.getStatisticsTask(ctx)
s.NoError(task.OnEnqueue())
task.fromQueryNode = true
s.qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(nil, nil)
s.NoError(task.PreExecute(ctx))
task.fromQueryNode = true
task.fromDataCoord = false
s.NoError(task.Execute(ctx))
s.NoError(task.PostExecute(ctx))
}