mirror of https://github.com/milvus-io/milvus.git
Fix collection and channel not match (#25859)
Signed-off-by: lixinguo <xinguo.li@zilliz.com> Co-authored-by: lixinguo <xinguo.li@zilliz.com>pull/26058/head
parent
eade5f9b7f
commit
9614e61f14
|
@ -34,20 +34,22 @@ import (
|
|||
type executeFunc func(context.Context, UniqueID, types.QueryNode, ...string) error
|
||||
|
||||
type ChannelWorkload struct {
|
||||
db string
|
||||
collection string
|
||||
channel string
|
||||
shardLeaders []int64
|
||||
nq int64
|
||||
exec executeFunc
|
||||
retryTimes uint
|
||||
db string
|
||||
collectionName string
|
||||
collectionID int64
|
||||
channel string
|
||||
shardLeaders []int64
|
||||
nq int64
|
||||
exec executeFunc
|
||||
retryTimes uint
|
||||
}
|
||||
|
||||
type CollectionWorkLoad struct {
|
||||
db string
|
||||
collection string
|
||||
nq int64
|
||||
exec executeFunc
|
||||
db string
|
||||
collectionName string
|
||||
collectionID int64
|
||||
nq int64
|
||||
exec executeFunc
|
||||
}
|
||||
|
||||
type LBPolicy interface {
|
||||
|
@ -89,7 +91,8 @@ func (lb *LBPolicyImpl) Start(ctx context.Context) {
|
|||
// try to select the best node from the available nodes
|
||||
func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) {
|
||||
log := log.With(
|
||||
zap.String("collectionName", workload.collection),
|
||||
zap.Int64("collectionID", workload.collectionID),
|
||||
zap.String("collectionName", workload.collectionName),
|
||||
zap.String("channelName", workload.channel),
|
||||
)
|
||||
|
||||
|
@ -98,7 +101,7 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload
|
|||
}
|
||||
|
||||
getShardLeaders := func() ([]int64, error) {
|
||||
shardLeaders, err := globalMetaCache.GetShards(ctx, false, workload.db, workload.collection)
|
||||
shardLeaders, err := globalMetaCache.GetShards(ctx, false, workload.db, workload.collectionName, workload.collectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -109,7 +112,7 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload
|
|||
availableNodes := lo.Filter(workload.shardLeaders, filterAvailableNodes)
|
||||
targetNode, err := lb.balancer.SelectNode(ctx, availableNodes, workload.nq)
|
||||
if err != nil {
|
||||
globalMetaCache.DeprecateShardCache(workload.db, workload.collection)
|
||||
globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName)
|
||||
nodes, err := getShardLeaders()
|
||||
if err != nil || len(nodes) == 0 {
|
||||
log.Warn("failed to get shard delegator",
|
||||
|
@ -141,7 +144,8 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload
|
|||
func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error {
|
||||
excludeNodes := typeutil.NewUniqueSet()
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.String("collectionName", workload.collection),
|
||||
zap.Int64("collectionID", workload.collectionID),
|
||||
zap.String("collectionName", workload.collectionName),
|
||||
zap.String("channelName", workload.channel),
|
||||
)
|
||||
|
||||
|
@ -185,7 +189,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
|
|||
|
||||
// Execute will execute collection workload in parallel
|
||||
func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad) error {
|
||||
dml2leaders, err := globalMetaCache.GetShards(ctx, true, workload.db, workload.collection)
|
||||
dml2leaders, err := globalMetaCache.GetShards(ctx, true, workload.db, workload.collectionName, workload.collectionID)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("failed to get shards", zap.Error(err))
|
||||
return err
|
||||
|
@ -197,13 +201,14 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
|
|||
nodes := lo.Map(nodes, func(node nodeInfo, _ int) int64 { return node.nodeID })
|
||||
wg.Go(func() error {
|
||||
err := lb.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
db: workload.db,
|
||||
collection: workload.collection,
|
||||
channel: channel,
|
||||
shardLeaders: nodes,
|
||||
nq: workload.nq,
|
||||
exec: workload.exec,
|
||||
retryTimes: uint(len(nodes)),
|
||||
db: workload.db,
|
||||
collectionName: workload.collectionName,
|
||||
collectionID: workload.collectionID,
|
||||
channel: channel,
|
||||
shardLeaders: nodes,
|
||||
nq: workload.nq,
|
||||
exec: workload.exec,
|
||||
retryTimes: uint(len(nodes)),
|
||||
})
|
||||
return err
|
||||
})
|
||||
|
|
|
@ -53,7 +53,8 @@ type LBPolicySuite struct {
|
|||
channels []string
|
||||
qnList []*mocks.MockQueryNode
|
||||
|
||||
collection string
|
||||
collectionName string
|
||||
collectionID int64
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) SetupSuite() {
|
||||
|
@ -108,7 +109,7 @@ func (s *LBPolicySuite) SetupTest() {
|
|||
err := InitMetaCache(context.Background(), s.rc, s.qc, s.mgr)
|
||||
s.NoError(err)
|
||||
|
||||
s.collection = "test_lb_policy"
|
||||
s.collectionName = "test_lb_policy"
|
||||
s.loadCollection()
|
||||
}
|
||||
|
||||
|
@ -125,7 +126,7 @@ func (s *LBPolicySuite) loadCollection() {
|
|||
fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
|
||||
}
|
||||
|
||||
schema := constructCollectionSchemaByDataType(s.collection, fieldName2Types, testInt64Field, false)
|
||||
schema := constructCollectionSchemaByDataType(s.collectionName, fieldName2Types, testInt64Field, false)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
s.NoError(err)
|
||||
|
||||
|
@ -133,7 +134,7 @@ func (s *LBPolicySuite) loadCollection() {
|
|||
createColT := &createCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
CollectionName: s.collection,
|
||||
CollectionName: s.collectionName,
|
||||
DbName: dbName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: common.DefaultShardsNum,
|
||||
|
@ -147,7 +148,7 @@ func (s *LBPolicySuite) loadCollection() {
|
|||
s.NoError(createColT.Execute(ctx))
|
||||
s.NoError(createColT.PostExecute(ctx))
|
||||
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, s.collection)
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, s.collectionName)
|
||||
s.NoError(err)
|
||||
|
||||
status, err := s.qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
|
||||
|
@ -159,17 +160,19 @@ func (s *LBPolicySuite) loadCollection() {
|
|||
})
|
||||
s.NoError(err)
|
||||
s.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
|
||||
s.collectionID = collectionID
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestSelectNode() {
|
||||
ctx := context.Background()
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil)
|
||||
targetNode, err := s.lbPolicy.selectNode(ctx, ChannelWorkload{
|
||||
db: dbName,
|
||||
collection: s.collection,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
}, typeutil.NewUniqueSet())
|
||||
s.NoError(err)
|
||||
s.Equal(int64(5), targetNode)
|
||||
|
@ -179,11 +182,12 @@ func (s *LBPolicySuite) TestSelectNode() {
|
|||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil)
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
|
||||
db: dbName,
|
||||
collection: s.collection,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: []int64{},
|
||||
nq: 1,
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: []int64{},
|
||||
nq: 1,
|
||||
}, typeutil.NewUniqueSet())
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), targetNode)
|
||||
|
@ -192,11 +196,12 @@ func (s *LBPolicySuite) TestSelectNode() {
|
|||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
|
||||
db: dbName,
|
||||
collection: s.collection,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: []int64{},
|
||||
nq: 1,
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: []int64{},
|
||||
nq: 1,
|
||||
}, typeutil.NewUniqueSet())
|
||||
s.ErrorIs(err, merr.ErrNodeNotAvailable)
|
||||
s.Equal(int64(-1), targetNode)
|
||||
|
@ -205,11 +210,12 @@ func (s *LBPolicySuite) TestSelectNode() {
|
|||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
|
||||
db: dbName,
|
||||
collection: s.collection,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
}, typeutil.NewUniqueSet(s.nodes...))
|
||||
s.ErrorIs(err, merr.ErrServiceUnavailable)
|
||||
s.Equal(int64(-1), targetNode)
|
||||
|
@ -220,11 +226,12 @@ func (s *LBPolicySuite) TestSelectNode() {
|
|||
s.qc.ExpectedCalls = nil
|
||||
s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, merr.ErrServiceUnavailable)
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
|
||||
db: dbName,
|
||||
collection: s.collection,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
}, typeutil.NewUniqueSet())
|
||||
s.ErrorIs(err, merr.ErrServiceUnavailable)
|
||||
s.Equal(int64(-1), targetNode)
|
||||
|
@ -239,11 +246,12 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
|
|||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
db: dbName,
|
||||
collection: s.collection,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
|
||||
return nil
|
||||
},
|
||||
|
@ -255,11 +263,12 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
|
|||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
|
||||
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
db: dbName,
|
||||
collection: s.collection,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
|
||||
return nil
|
||||
},
|
||||
|
@ -274,11 +283,12 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
|
|||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
db: dbName,
|
||||
collection: s.collection,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
|
||||
return nil
|
||||
},
|
||||
|
@ -291,11 +301,12 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
|
|||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
db: dbName,
|
||||
collection: s.collection,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
|
||||
return nil
|
||||
},
|
||||
|
@ -311,11 +322,12 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
|
|||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
counter := 0
|
||||
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
db: dbName,
|
||||
collection: s.collection,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
|
||||
counter++
|
||||
if counter == 1 {
|
||||
|
@ -336,9 +348,10 @@ func (s *LBPolicySuite) TestExecute() {
|
|||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{
|
||||
db: dbName,
|
||||
collection: s.collection,
|
||||
nq: 1,
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
|
||||
return nil
|
||||
},
|
||||
|
@ -348,9 +361,10 @@ func (s *LBPolicySuite) TestExecute() {
|
|||
// test some channel failed
|
||||
counter := atomic.NewInt64(0)
|
||||
err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{
|
||||
db: dbName,
|
||||
collection: s.collection,
|
||||
nq: 1,
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
|
||||
if counter.Add(1) == 1 {
|
||||
return nil
|
||||
|
@ -363,12 +377,13 @@ func (s *LBPolicySuite) TestExecute() {
|
|||
|
||||
// test get shard leader failed
|
||||
s.qc.ExpectedCalls = nil
|
||||
globalMetaCache.DeprecateShardCache(dbName, s.collection)
|
||||
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
|
||||
s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, mockErr)
|
||||
err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{
|
||||
db: dbName,
|
||||
collection: s.collection,
|
||||
nq: 1,
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error {
|
||||
return nil
|
||||
},
|
||||
|
|
|
@ -57,8 +57,8 @@ type Cache interface {
|
|||
GetCollectionID(ctx context.Context, database, collectionName string) (typeutil.UniqueID, error)
|
||||
// GetDatabaseAndCollectionName get collection's name and database by id
|
||||
GetDatabaseAndCollectionName(ctx context.Context, collectionID int64) (string, string, error)
|
||||
// GetCollectionInfo get collection's information by name, such as collection id, schema, and etc.
|
||||
GetCollectionInfo(ctx context.Context, database, collectionName string) (*collectionInfo, error)
|
||||
// GetCollectionInfo get collection's information by name or collection id, such as schema, and etc.
|
||||
GetCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionInfo, error)
|
||||
// GetPartitionID get partition's identifier of specific collection.
|
||||
GetPartitionID(ctx context.Context, database, collectionName string, partitionName string) (typeutil.UniqueID, error)
|
||||
// GetPartitions get all partitions' id of specific collection.
|
||||
|
@ -67,7 +67,7 @@ type Cache interface {
|
|||
GetPartitionInfo(ctx context.Context, database, collectionName string, partitionName string) (*partitionInfo, error)
|
||||
// GetCollectionSchema get collection's schema.
|
||||
GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemapb.CollectionSchema, error)
|
||||
GetShards(ctx context.Context, withCache bool, database, collectionName string) (map[string][]nodeInfo, error)
|
||||
GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error)
|
||||
DeprecateShardCache(database, collectionName string)
|
||||
expireShardLeaderCache(ctx context.Context)
|
||||
RemoveCollection(ctx context.Context, database, collectionName string)
|
||||
|
@ -229,7 +229,7 @@ func (m *MetaCache) GetCollectionID(ctx context.Context, database, collectionNam
|
|||
collInfo, ok = db[collectionName]
|
||||
}
|
||||
|
||||
method := "GeCollectionID"
|
||||
method := "GetCollectionID"
|
||||
if !ok || !collInfo.isCollectionCached() {
|
||||
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
|
||||
tr := timerecord.NewTimeRecorder("UpdateCache")
|
||||
|
@ -289,7 +289,7 @@ func (m *MetaCache) GetDatabaseAndCollectionName(ctx context.Context, collection
|
|||
|
||||
// GetCollectionInfo returns the collection information related to provided collection name
|
||||
// If the information is not found, proxy will try to fetch information for other source (RootCoord for now)
|
||||
func (m *MetaCache) GetCollectionInfo(ctx context.Context, database, collectionName string) (*collectionInfo, error) {
|
||||
func (m *MetaCache) GetCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionInfo, error) {
|
||||
m.mu.RLock()
|
||||
var collInfo *collectionInfo
|
||||
var ok bool
|
||||
|
@ -301,10 +301,17 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database, collectionN
|
|||
m.mu.RUnlock()
|
||||
|
||||
method := "GetCollectionInfo"
|
||||
if !ok || !collInfo.isCollectionCached() {
|
||||
// if collInfo.collID != collectionID, means that the cache is not trustable
|
||||
// try to get collection according to collectionID
|
||||
if !ok || !collInfo.isCollectionCached() || collInfo.collID != collectionID {
|
||||
tr := timerecord.NewTimeRecorder("UpdateCache")
|
||||
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
|
||||
coll, err := m.describeCollection(ctx, database, collectionName, 0)
|
||||
var coll *milvuspb.DescribeCollectionResponse
|
||||
var err error
|
||||
|
||||
// collectionName maybe not trustable, get collection according to id
|
||||
coll, err = m.describeCollection(ctx, database, "", collectionID)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -695,8 +702,12 @@ func (m *MetaCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
|
|||
}
|
||||
|
||||
// GetShards update cache if withCache == false
|
||||
func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, collectionName string) (map[string][]nodeInfo, error) {
|
||||
info, err := m.GetCollectionInfo(ctx, database, collectionName)
|
||||
func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error) {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.String("collectionName", collectionName),
|
||||
zap.Int64("collectionID", collectionID))
|
||||
|
||||
info, err := m.GetCollectionInfo(ctx, database, collectionName, collectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -715,8 +726,7 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
|
|||
}
|
||||
|
||||
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
|
||||
log.Info("no shard cache for collection, try to get shard leaders from QueryCoord",
|
||||
zap.String("collectionName", collectionName))
|
||||
log.Info("no shard cache for collection, try to get shard leaders from QueryCoord")
|
||||
}
|
||||
req := &querypb.GetShardLeadersRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
|
@ -754,9 +764,9 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
|
|||
|
||||
shards := parseShardLeaderList2QueryNode(resp.GetShards())
|
||||
|
||||
info, err = m.GetCollectionInfo(ctx, database, collectionName)
|
||||
info, err = m.GetCollectionInfo(ctx, database, collectionName, collectionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get shards, collection %s not found", collectionName)
|
||||
return nil, fmt.Errorf("failed to get shards, collectionName %s, colectionID %d not found", collectionName, collectionID)
|
||||
}
|
||||
// lock leader
|
||||
info.leaderMutex.Lock()
|
||||
|
|
|
@ -474,6 +474,7 @@ func TestMetaCache_GetShards(t *testing.T) {
|
|||
var (
|
||||
ctx = context.Background()
|
||||
collectionName = "collection1"
|
||||
collectionID = int64(1)
|
||||
)
|
||||
|
||||
rootCoord := &MockRootCoordClientInterface{}
|
||||
|
@ -488,7 +489,7 @@ func TestMetaCache_GetShards(t *testing.T) {
|
|||
defer qc.Stop()
|
||||
|
||||
t.Run("No collection in meta cache", func(t *testing.T) {
|
||||
shards, err := globalMetaCache.GetShards(ctx, true, dbName, "non-exists")
|
||||
shards, err := globalMetaCache.GetShards(ctx, true, dbName, "non-exists", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, shards)
|
||||
})
|
||||
|
@ -503,7 +504,7 @@ func TestMetaCache_GetShards(t *testing.T) {
|
|||
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
}, nil)
|
||||
shards, err := globalMetaCache.GetShards(ctx, false, dbName, collectionName)
|
||||
shards, err := globalMetaCache.GetShards(ctx, false, dbName, collectionName, collectionID)
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, shards)
|
||||
})
|
||||
|
@ -524,7 +525,7 @@ func TestMetaCache_GetShards(t *testing.T) {
|
|||
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
}, nil)
|
||||
shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName)
|
||||
shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, shards)
|
||||
assert.Equal(t, 1, len(shards))
|
||||
|
@ -537,7 +538,7 @@ func TestMetaCache_GetShards(t *testing.T) {
|
|||
Reason: "not implemented",
|
||||
},
|
||||
}, nil)
|
||||
shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName)
|
||||
shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, shards)
|
||||
|
@ -550,6 +551,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
|
|||
var (
|
||||
ctx = context.TODO()
|
||||
collectionName = "collection1"
|
||||
collectionID = int64(1)
|
||||
)
|
||||
|
||||
rootCoord := &MockRootCoordClientInterface{}
|
||||
|
@ -588,7 +590,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
|
|||
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
}, nil)
|
||||
shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName)
|
||||
shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, shards)
|
||||
require.Equal(t, 1, len(shards))
|
||||
|
@ -602,7 +604,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
|
|||
Reason: "not implemented",
|
||||
},
|
||||
}, nil)
|
||||
shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName)
|
||||
shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID)
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, shards)
|
||||
})
|
||||
|
@ -706,26 +708,26 @@ func TestMetaCache_RemoveCollection(t *testing.T) {
|
|||
InMemoryPercentages: []int64{100, 50},
|
||||
}, nil)
|
||||
|
||||
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1")
|
||||
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1)
|
||||
assert.NoError(t, err)
|
||||
// no collectionInfo of collection1, should access RootCoord
|
||||
assert.Equal(t, rootCoord.GetAccessCount(), 1)
|
||||
|
||||
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1")
|
||||
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1)
|
||||
assert.NoError(t, err)
|
||||
// shouldn't access RootCoord again
|
||||
assert.Equal(t, rootCoord.GetAccessCount(), 1)
|
||||
|
||||
globalMetaCache.RemoveCollection(ctx, dbName, "collection1")
|
||||
// no collectionInfo of collection2, should access RootCoord
|
||||
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1")
|
||||
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1)
|
||||
assert.NoError(t, err)
|
||||
// shouldn't access RootCoord again
|
||||
assert.Equal(t, rootCoord.GetAccessCount(), 2)
|
||||
|
||||
globalMetaCache.RemoveCollectionsByID(ctx, UniqueID(1))
|
||||
// no collectionInfo of collection2, should access RootCoord
|
||||
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1")
|
||||
_, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1)
|
||||
assert.NoError(t, err)
|
||||
// shouldn't access RootCoord again
|
||||
assert.Equal(t, rootCoord.GetAccessCount(), 3)
|
||||
|
@ -761,7 +763,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}, nil)
|
||||
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1")
|
||||
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodeInfos["channel-1"], 3)
|
||||
|
||||
|
@ -780,7 +782,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
|
|||
}, nil)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1")
|
||||
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
|
||||
assert.NoError(t, err)
|
||||
return len(nodeInfos["channel-1"]) == 2
|
||||
}, 3*time.Second, 1*time.Second)
|
||||
|
@ -800,7 +802,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
|
|||
}, nil)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1")
|
||||
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
|
||||
assert.NoError(t, err)
|
||||
return len(nodeInfos["channel-1"]) == 3
|
||||
}, 3*time.Second, 1*time.Second)
|
||||
|
@ -825,7 +827,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) {
|
|||
}, nil)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1")
|
||||
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
|
||||
assert.NoError(t, err)
|
||||
return len(nodeInfos["channel-1"]) == 3 && len(nodeInfos["channel-2"]) == 3
|
||||
}, 3*time.Second, 1*time.Second)
|
||||
|
|
|
@ -1,16 +1,16 @@
|
|||
// Code generated by mockery v2.16.0. DO NOT EDIT.
|
||||
// Code generated by mockery v2.23.1. DO NOT EDIT.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
context "context"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/stretchr/testify/mock"
|
||||
internalpb "github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
schemapb "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
typeutil "github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
// MockCache is an autogenerated mock type for the Cache type
|
||||
|
@ -55,18 +55,26 @@ func (_c *MockCache_DeprecateShardCache_Call) Return() *MockCache_DeprecateShard
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_DeprecateShardCache_Call) RunAndReturn(run func(string, string)) *MockCache_DeprecateShardCache_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetCollectionID provides a mock function with given fields: ctx, database, collectionName
|
||||
func (_m *MockCache) GetCollectionID(ctx context.Context, database string, collectionName string) (int64, error) {
|
||||
ret := _m.Called(ctx, database, collectionName)
|
||||
|
||||
var r0 int64
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) (int64, error)); ok {
|
||||
return rf(ctx, database, collectionName)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) int64); ok {
|
||||
r0 = rf(ctx, database, collectionName)
|
||||
} else {
|
||||
r0 = ret.Get(0).(int64)
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
|
||||
r1 = rf(ctx, database, collectionName)
|
||||
} else {
|
||||
|
@ -101,22 +109,30 @@ func (_c *MockCache_GetCollectionID_Call) Return(_a0 int64, _a1 error) *MockCach
|
|||
return _c
|
||||
}
|
||||
|
||||
// GetCollectionInfo provides a mock function with given fields: ctx, database, collectionName
|
||||
func (_m *MockCache) GetCollectionInfo(ctx context.Context, database string, collectionName string) (*collectionInfo, error) {
|
||||
ret := _m.Called(ctx, database, collectionName)
|
||||
func (_c *MockCache_GetCollectionID_Call) RunAndReturn(run func(context.Context, string, string) (int64, error)) *MockCache_GetCollectionID_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetCollectionInfo provides a mock function with given fields: ctx, database, collectionName, collectionID
|
||||
func (_m *MockCache) GetCollectionInfo(ctx context.Context, database string, collectionName string, collectionID int64) (*collectionInfo, error) {
|
||||
ret := _m.Called(ctx, database, collectionName, collectionID)
|
||||
|
||||
var r0 *collectionInfo
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) *collectionInfo); ok {
|
||||
r0 = rf(ctx, database, collectionName)
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) (*collectionInfo, error)); ok {
|
||||
return rf(ctx, database, collectionName, collectionID)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) *collectionInfo); ok {
|
||||
r0 = rf(ctx, database, collectionName, collectionID)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*collectionInfo)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
|
||||
r1 = rf(ctx, database, collectionName)
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string, int64) error); ok {
|
||||
r1 = rf(ctx, database, collectionName, collectionID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
@ -133,13 +149,14 @@ type MockCache_GetCollectionInfo_Call struct {
|
|||
// - ctx context.Context
|
||||
// - database string
|
||||
// - collectionName string
|
||||
func (_e *MockCache_Expecter) GetCollectionInfo(ctx interface{}, database interface{}, collectionName interface{}) *MockCache_GetCollectionInfo_Call {
|
||||
return &MockCache_GetCollectionInfo_Call{Call: _e.mock.On("GetCollectionInfo", ctx, database, collectionName)}
|
||||
// - collectionID int64
|
||||
func (_e *MockCache_Expecter) GetCollectionInfo(ctx interface{}, database interface{}, collectionName interface{}, collectionID interface{}) *MockCache_GetCollectionInfo_Call {
|
||||
return &MockCache_GetCollectionInfo_Call{Call: _e.mock.On("GetCollectionInfo", ctx, database, collectionName, collectionID)}
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetCollectionInfo_Call) Run(run func(ctx context.Context, database string, collectionName string)) *MockCache_GetCollectionInfo_Call {
|
||||
func (_c *MockCache_GetCollectionInfo_Call) Run(run func(ctx context.Context, database string, collectionName string, collectionID int64)) *MockCache_GetCollectionInfo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(string), args[2].(string))
|
||||
run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(int64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
@ -149,11 +166,20 @@ func (_c *MockCache_GetCollectionInfo_Call) Return(_a0 *collectionInfo, _a1 erro
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetCollectionInfo_Call) RunAndReturn(run func(context.Context, string, string, int64) (*collectionInfo, error)) *MockCache_GetCollectionInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetCollectionSchema provides a mock function with given fields: ctx, database, collectionName
|
||||
func (_m *MockCache) GetCollectionSchema(ctx context.Context, database string, collectionName string) (*schemapb.CollectionSchema, error) {
|
||||
ret := _m.Called(ctx, database, collectionName)
|
||||
|
||||
var r0 *schemapb.CollectionSchema
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) (*schemapb.CollectionSchema, error)); ok {
|
||||
return rf(ctx, database, collectionName)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) *schemapb.CollectionSchema); ok {
|
||||
r0 = rf(ctx, database, collectionName)
|
||||
} else {
|
||||
|
@ -162,7 +188,6 @@ func (_m *MockCache) GetCollectionSchema(ctx context.Context, database string, c
|
|||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
|
||||
r1 = rf(ctx, database, collectionName)
|
||||
} else {
|
||||
|
@ -197,11 +222,20 @@ func (_c *MockCache_GetCollectionSchema_Call) Return(_a0 *schemapb.CollectionSch
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetCollectionSchema_Call) RunAndReturn(run func(context.Context, string, string) (*schemapb.CollectionSchema, error)) *MockCache_GetCollectionSchema_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetCredentialInfo provides a mock function with given fields: ctx, username
|
||||
func (_m *MockCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) {
|
||||
ret := _m.Called(ctx, username)
|
||||
|
||||
var r0 *internalpb.CredentialInfo
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) (*internalpb.CredentialInfo, error)); ok {
|
||||
return rf(ctx, username)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) *internalpb.CredentialInfo); ok {
|
||||
r0 = rf(ctx, username)
|
||||
} else {
|
||||
|
@ -210,7 +244,6 @@ func (_m *MockCache) GetCredentialInfo(ctx context.Context, username string) (*i
|
|||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = rf(ctx, username)
|
||||
} else {
|
||||
|
@ -244,25 +277,33 @@ func (_c *MockCache_GetCredentialInfo_Call) Return(_a0 *internalpb.CredentialInf
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetCredentialInfo_Call) RunAndReturn(run func(context.Context, string) (*internalpb.CredentialInfo, error)) *MockCache_GetCredentialInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetDatabaseAndCollectionName provides a mock function with given fields: ctx, collectionID
|
||||
func (_m *MockCache) GetDatabaseAndCollectionName(ctx context.Context, collectionID int64) (string, string, error) {
|
||||
ret := _m.Called(ctx, collectionID)
|
||||
|
||||
var r0 string
|
||||
var r1 string
|
||||
var r2 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int64) (string, string, error)); ok {
|
||||
return rf(ctx, collectionID)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int64) string); ok {
|
||||
r0 = rf(ctx, collectionID)
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
|
||||
var r1 string
|
||||
if rf, ok := ret.Get(1).(func(context.Context, int64) string); ok {
|
||||
r1 = rf(ctx, collectionID)
|
||||
} else {
|
||||
r1 = ret.Get(1).(string)
|
||||
}
|
||||
|
||||
var r2 error
|
||||
if rf, ok := ret.Get(2).(func(context.Context, int64) error); ok {
|
||||
r2 = rf(ctx, collectionID)
|
||||
} else {
|
||||
|
@ -296,18 +337,26 @@ func (_c *MockCache_GetDatabaseAndCollectionName_Call) Return(_a0 string, _a1 st
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetDatabaseAndCollectionName_Call) RunAndReturn(run func(context.Context, int64) (string, string, error)) *MockCache_GetDatabaseAndCollectionName_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetPartitionID provides a mock function with given fields: ctx, database, collectionName, partitionName
|
||||
func (_m *MockCache) GetPartitionID(ctx context.Context, database string, collectionName string, partitionName string) (int64, error) {
|
||||
ret := _m.Called(ctx, database, collectionName, partitionName)
|
||||
|
||||
var r0 int64
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (int64, error)); ok {
|
||||
return rf(ctx, database, collectionName, partitionName)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) int64); ok {
|
||||
r0 = rf(ctx, database, collectionName, partitionName)
|
||||
} else {
|
||||
r0 = ret.Get(0).(int64)
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
|
||||
r1 = rf(ctx, database, collectionName, partitionName)
|
||||
} else {
|
||||
|
@ -343,11 +392,20 @@ func (_c *MockCache_GetPartitionID_Call) Return(_a0 int64, _a1 error) *MockCache
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetPartitionID_Call) RunAndReturn(run func(context.Context, string, string, string) (int64, error)) *MockCache_GetPartitionID_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetPartitionInfo provides a mock function with given fields: ctx, database, collectionName, partitionName
|
||||
func (_m *MockCache) GetPartitionInfo(ctx context.Context, database string, collectionName string, partitionName string) (*partitionInfo, error) {
|
||||
ret := _m.Called(ctx, database, collectionName, partitionName)
|
||||
|
||||
var r0 *partitionInfo
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (*partitionInfo, error)); ok {
|
||||
return rf(ctx, database, collectionName, partitionName)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *partitionInfo); ok {
|
||||
r0 = rf(ctx, database, collectionName, partitionName)
|
||||
} else {
|
||||
|
@ -356,7 +414,6 @@ func (_m *MockCache) GetPartitionInfo(ctx context.Context, database string, coll
|
|||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
|
||||
r1 = rf(ctx, database, collectionName, partitionName)
|
||||
} else {
|
||||
|
@ -392,11 +449,20 @@ func (_c *MockCache_GetPartitionInfo_Call) Return(_a0 *partitionInfo, _a1 error)
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetPartitionInfo_Call) RunAndReturn(run func(context.Context, string, string, string) (*partitionInfo, error)) *MockCache_GetPartitionInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetPartitions provides a mock function with given fields: ctx, database, collectionName
|
||||
func (_m *MockCache) GetPartitions(ctx context.Context, database string, collectionName string) (map[string]int64, error) {
|
||||
ret := _m.Called(ctx, database, collectionName)
|
||||
|
||||
var r0 map[string]int64
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) (map[string]int64, error)); ok {
|
||||
return rf(ctx, database, collectionName)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) map[string]int64); ok {
|
||||
r0 = rf(ctx, database, collectionName)
|
||||
} else {
|
||||
|
@ -405,7 +471,6 @@ func (_m *MockCache) GetPartitions(ctx context.Context, database string, collect
|
|||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
|
||||
r1 = rf(ctx, database, collectionName)
|
||||
} else {
|
||||
|
@ -440,6 +505,11 @@ func (_c *MockCache_GetPartitions_Call) Return(_a0 map[string]int64, _a1 error)
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetPartitions_Call) RunAndReturn(run func(context.Context, string, string) (map[string]int64, error)) *MockCache_GetPartitions_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetPrivilegeInfo provides a mock function with given fields: ctx
|
||||
func (_m *MockCache) GetPrivilegeInfo(ctx context.Context) []string {
|
||||
ret := _m.Called(ctx)
|
||||
|
@ -479,22 +549,30 @@ func (_c *MockCache_GetPrivilegeInfo_Call) Return(_a0 []string) *MockCache_GetPr
|
|||
return _c
|
||||
}
|
||||
|
||||
// GetShards provides a mock function with given fields: ctx, withCache, database, collectionName
|
||||
func (_m *MockCache) GetShards(ctx context.Context, withCache bool, database string, collectionName string) (map[string][]nodeInfo, error) {
|
||||
ret := _m.Called(ctx, withCache, database, collectionName)
|
||||
func (_c *MockCache_GetPrivilegeInfo_Call) RunAndReturn(run func(context.Context) []string) *MockCache_GetPrivilegeInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetShards provides a mock function with given fields: ctx, withCache, database, collectionName, collectionID
|
||||
func (_m *MockCache) GetShards(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64) (map[string][]nodeInfo, error) {
|
||||
ret := _m.Called(ctx, withCache, database, collectionName, collectionID)
|
||||
|
||||
var r0 map[string][]nodeInfo
|
||||
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string) map[string][]nodeInfo); ok {
|
||||
r0 = rf(ctx, withCache, database, collectionName)
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64) (map[string][]nodeInfo, error)); ok {
|
||||
return rf(ctx, withCache, database, collectionName, collectionID)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64) map[string][]nodeInfo); ok {
|
||||
r0 = rf(ctx, withCache, database, collectionName, collectionID)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(map[string][]nodeInfo)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, bool, string, string) error); ok {
|
||||
r1 = rf(ctx, withCache, database, collectionName)
|
||||
if rf, ok := ret.Get(1).(func(context.Context, bool, string, string, int64) error); ok {
|
||||
r1 = rf(ctx, withCache, database, collectionName, collectionID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
@ -512,13 +590,14 @@ type MockCache_GetShards_Call struct {
|
|||
// - withCache bool
|
||||
// - database string
|
||||
// - collectionName string
|
||||
func (_e *MockCache_Expecter) GetShards(ctx interface{}, withCache interface{}, database interface{}, collectionName interface{}) *MockCache_GetShards_Call {
|
||||
return &MockCache_GetShards_Call{Call: _e.mock.On("GetShards", ctx, withCache, database, collectionName)}
|
||||
// - collectionID int64
|
||||
func (_e *MockCache_Expecter) GetShards(ctx interface{}, withCache interface{}, database interface{}, collectionName interface{}, collectionID interface{}) *MockCache_GetShards_Call {
|
||||
return &MockCache_GetShards_Call{Call: _e.mock.On("GetShards", ctx, withCache, database, collectionName, collectionID)}
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetShards_Call) Run(run func(ctx context.Context, withCache bool, database string, collectionName string)) *MockCache_GetShards_Call {
|
||||
func (_c *MockCache_GetShards_Call) Run(run func(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64)) *MockCache_GetShards_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(bool), args[2].(string), args[3].(string))
|
||||
run(args[0].(context.Context), args[1].(bool), args[2].(string), args[3].(string), args[4].(int64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
@ -528,6 +607,11 @@ func (_c *MockCache_GetShards_Call) Return(_a0 map[string][]nodeInfo, _a1 error)
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetShards_Call) RunAndReturn(run func(context.Context, bool, string, string, int64) (map[string][]nodeInfo, error)) *MockCache_GetShards_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetUserRole provides a mock function with given fields: username
|
||||
func (_m *MockCache) GetUserRole(username string) []string {
|
||||
ret := _m.Called(username)
|
||||
|
@ -567,6 +651,11 @@ func (_c *MockCache_GetUserRole_Call) Return(_a0 []string) *MockCache_GetUserRol
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetUserRole_Call) RunAndReturn(run func(string) []string) *MockCache_GetUserRole_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// InitPolicyInfo provides a mock function with given fields: info, userRoles
|
||||
func (_m *MockCache) InitPolicyInfo(info []string, userRoles []string) {
|
||||
_m.Called(info, userRoles)
|
||||
|
@ -596,6 +685,11 @@ func (_c *MockCache_InitPolicyInfo_Call) Return() *MockCache_InitPolicyInfo_Call
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_InitPolicyInfo_Call) RunAndReturn(run func([]string, []string)) *MockCache_InitPolicyInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RefreshPolicyInfo provides a mock function with given fields: op
|
||||
func (_m *MockCache) RefreshPolicyInfo(op typeutil.CacheOp) error {
|
||||
ret := _m.Called(op)
|
||||
|
@ -633,6 +727,11 @@ func (_c *MockCache_RefreshPolicyInfo_Call) Return(_a0 error) *MockCache_Refresh
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_RefreshPolicyInfo_Call) RunAndReturn(run func(typeutil.CacheOp) error) *MockCache_RefreshPolicyInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RemoveCollection provides a mock function with given fields: ctx, database, collectionName
|
||||
func (_m *MockCache) RemoveCollection(ctx context.Context, database string, collectionName string) {
|
||||
_m.Called(ctx, database, collectionName)
|
||||
|
@ -663,6 +762,11 @@ func (_c *MockCache_RemoveCollection_Call) Return() *MockCache_RemoveCollection_
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_RemoveCollection_Call) RunAndReturn(run func(context.Context, string, string)) *MockCache_RemoveCollection_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RemoveCollectionsByID provides a mock function with given fields: ctx, collectionID
|
||||
func (_m *MockCache) RemoveCollectionsByID(ctx context.Context, collectionID int64) []string {
|
||||
ret := _m.Called(ctx, collectionID)
|
||||
|
@ -703,6 +807,11 @@ func (_c *MockCache_RemoveCollectionsByID_Call) Return(_a0 []string) *MockCache_
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_RemoveCollectionsByID_Call) RunAndReturn(run func(context.Context, int64) []string) *MockCache_RemoveCollectionsByID_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RemoveCredential provides a mock function with given fields: username
|
||||
func (_m *MockCache) RemoveCredential(username string) {
|
||||
_m.Called(username)
|
||||
|
@ -731,6 +840,11 @@ func (_c *MockCache_RemoveCredential_Call) Return() *MockCache_RemoveCredential_
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_RemoveCredential_Call) RunAndReturn(run func(string)) *MockCache_RemoveCredential_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RemoveDatabase provides a mock function with given fields: ctx, database
|
||||
func (_m *MockCache) RemoveDatabase(ctx context.Context, database string) {
|
||||
_m.Called(ctx, database)
|
||||
|
@ -760,6 +874,11 @@ func (_c *MockCache_RemoveDatabase_Call) Return() *MockCache_RemoveDatabase_Call
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_RemoveDatabase_Call) RunAndReturn(run func(context.Context, string)) *MockCache_RemoveDatabase_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RemovePartition provides a mock function with given fields: ctx, database, collectionName, partitionName
|
||||
func (_m *MockCache) RemovePartition(ctx context.Context, database string, collectionName string, partitionName string) {
|
||||
_m.Called(ctx, database, collectionName, partitionName)
|
||||
|
@ -791,6 +910,11 @@ func (_c *MockCache_RemovePartition_Call) Return() *MockCache_RemovePartition_Ca
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_RemovePartition_Call) RunAndReturn(run func(context.Context, string, string, string)) *MockCache_RemovePartition_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateCredential provides a mock function with given fields: credInfo
|
||||
func (_m *MockCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
|
||||
_m.Called(credInfo)
|
||||
|
@ -819,6 +943,11 @@ func (_c *MockCache_UpdateCredential_Call) Return() *MockCache_UpdateCredential_
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_UpdateCredential_Call) RunAndReturn(run func(*internalpb.CredentialInfo)) *MockCache_UpdateCredential_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// expireShardLeaderCache provides a mock function with given fields: ctx
|
||||
func (_m *MockCache) expireShardLeaderCache(ctx context.Context) {
|
||||
_m.Called(ctx)
|
||||
|
@ -847,6 +976,11 @@ func (_c *MockCache_expireShardLeaderCache_Call) Return() *MockCache_expireShard
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_expireShardLeaderCache_Call) RunAndReturn(run func(context.Context)) *MockCache_expireShardLeaderCache_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
type mockConstructorTestingTNewMockCache interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Code generated by mockery v2.21.1. DO NOT EDIT.
|
||||
// Code generated by mockery v2.23.1. DO NOT EDIT.
|
||||
|
||||
package proxy
|
||||
|
||||
|
|
|
@ -709,7 +709,7 @@ func (sct *showCollectionsTask) Execute(ctx context.Context) error {
|
|||
zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections"))
|
||||
continue
|
||||
}
|
||||
collectionInfo, err := globalMetaCache.GetCollectionInfo(ctx, sct.GetDbName(), collectionName)
|
||||
collectionInfo, err := globalMetaCache.GetCollectionInfo(ctx, sct.GetDbName(), collectionName, id)
|
||||
if err != nil {
|
||||
log.Debug("Failed to get collection info.", zap.Any("collectionName", collectionName),
|
||||
zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections"))
|
||||
|
|
|
@ -370,10 +370,11 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName)
|
||||
collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName, t.CollectionID)
|
||||
if err2 != nil {
|
||||
log.Warn("Proxy::queryTask::PreExecute failed to GetCollectionInfo from cache",
|
||||
zap.String("collectionName", collectionName), zap.Error(err2))
|
||||
zap.String("collectionName", collectionName), zap.Int64("collectionID", t.CollectionID),
|
||||
zap.Error(err2))
|
||||
return err2
|
||||
}
|
||||
|
||||
|
@ -417,10 +418,11 @@ func (t *queryTask) Execute(ctx context.Context) error {
|
|||
|
||||
t.resultBuf = typeutil.NewConcurrentSet[*internalpb.RetrieveResults]()
|
||||
err := t.lb.Execute(ctx, CollectionWorkLoad{
|
||||
db: t.request.GetDbName(),
|
||||
collection: t.collectionName,
|
||||
nq: 1,
|
||||
exec: t.queryShard,
|
||||
db: t.request.GetDbName(),
|
||||
collectionID: t.CollectionID,
|
||||
collectionName: t.collectionName,
|
||||
nq: 1,
|
||||
exec: t.queryShard,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn("fail to execute query", zap.Error(err))
|
||||
|
|
|
@ -361,10 +361,10 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName)
|
||||
collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName, t.CollectionID)
|
||||
if err2 != nil {
|
||||
log.Warn("Proxy::searchTask::PreExecute failed to GetCollectionInfo from cache",
|
||||
zap.Any("collectionName", collectionName), zap.Error(err2))
|
||||
zap.String("collectionName", collectionName), zap.Int64("collectionID", t.CollectionID), zap.Error(err2))
|
||||
return err2
|
||||
}
|
||||
guaranteeTs := t.request.GetGuaranteeTimestamp()
|
||||
|
@ -417,10 +417,11 @@ func (t *searchTask) Execute(ctx context.Context) error {
|
|||
t.resultBuf = typeutil.NewConcurrentSet[*internalpb.SearchResults]()
|
||||
|
||||
err := t.lb.Execute(ctx, CollectionWorkLoad{
|
||||
db: t.request.GetDbName(),
|
||||
collection: t.collectionName,
|
||||
nq: t.Nq,
|
||||
exec: t.searchShard,
|
||||
db: t.request.GetDbName(),
|
||||
collectionID: t.SearchRequest.CollectionID,
|
||||
collectionName: t.collectionName,
|
||||
nq: t.Nq,
|
||||
exec: t.searchShard,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn("search execute failed", zap.Error(err))
|
||||
|
|
|
@ -139,26 +139,28 @@ func (g *getStatisticsTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// check if collection/partitions are loaded into query node
|
||||
loaded, unloaded, err := checkFullLoaded(ctx, g.qc, g.collectionName, partIDs)
|
||||
log := log.Ctx(ctx)
|
||||
loaded, unloaded, err := checkFullLoaded(ctx, g.qc, g.collectionName, g.GetStatisticsRequest.CollectionID, partIDs)
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.String("collectionName", g.collectionName),
|
||||
zap.Int64("collectionID", g.CollectionID),
|
||||
)
|
||||
if err != nil {
|
||||
g.fromDataCoord = true
|
||||
g.unloadedPartitionIDs = partIDs
|
||||
log.Info("checkFullLoaded failed, try get statistics from DataCoord", zap.Error(err))
|
||||
log.Info("checkFullLoaded failed, try get statistics from DataCoord",
|
||||
zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
if len(unloaded) > 0 {
|
||||
g.fromDataCoord = true
|
||||
g.unloadedPartitionIDs = unloaded
|
||||
log.Info("some partitions has not been loaded, try get statistics from DataCoord",
|
||||
zap.String("collection", g.collectionName),
|
||||
zap.Int64s("unloaded partitions", unloaded))
|
||||
}
|
||||
if len(loaded) > 0 {
|
||||
g.fromQueryNode = true
|
||||
g.loadedPartitionIDs = loaded
|
||||
log.Info("some partitions has been loaded, try get statistics from QueryNode",
|
||||
zap.String("collection", g.collectionName),
|
||||
zap.Int64s("loaded partitions", loaded))
|
||||
}
|
||||
return nil
|
||||
|
@ -266,10 +268,11 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro
|
|||
g.resultBuf = typeutil.NewConcurrentSet[*internalpb.GetStatisticsResponse]()
|
||||
}
|
||||
err := g.lb.Execute(ctx, CollectionWorkLoad{
|
||||
db: g.request.GetDbName(),
|
||||
collection: g.collectionName,
|
||||
nq: 1,
|
||||
exec: g.getStatisticsShard,
|
||||
db: g.request.GetDbName(),
|
||||
collectionID: g.GetStatisticsRequest.CollectionID,
|
||||
collectionName: g.collectionName,
|
||||
nq: 1,
|
||||
exec: g.getStatisticsShard,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
@ -317,14 +320,14 @@ func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64
|
|||
|
||||
// checkFullLoaded check if collection / partition was fully loaded into QueryNode
|
||||
// return loaded partitions, unloaded partitions and error
|
||||
func checkFullLoaded(ctx context.Context, qc types.QueryCoord, collectionName string, searchPartitionIDs []UniqueID) ([]UniqueID, []UniqueID, error) {
|
||||
func checkFullLoaded(ctx context.Context, qc types.QueryCoord, collectionName string, collectionID int64, searchPartitionIDs []UniqueID) ([]UniqueID, []UniqueID, error) {
|
||||
var loadedPartitionIDs []UniqueID
|
||||
var unloadPartitionIDs []UniqueID
|
||||
|
||||
// TODO: Consider to check if partition loaded from cache to save rpc.
|
||||
info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
||||
info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName, collectionID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("GetCollectionInfo failed, collection = %s, err = %s", collectionName, err)
|
||||
return nil, nil, fmt.Errorf("GetCollectionInfo failed, collectionName = %s,collectionID = %d, err = %s", collectionName, collectionID, err)
|
||||
}
|
||||
|
||||
// If request to search partitions
|
||||
|
@ -338,10 +341,10 @@ func checkFullLoaded(ctx context.Context, qc types.QueryCoord, collectionName st
|
|||
PartitionIDs: searchPartitionIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, err = %s", collectionName, searchPartitionIDs, err)
|
||||
return nil, nil, fmt.Errorf("showPartitions failed, collection = %d, partitionIDs = %v, err = %s", collectionID, searchPartitionIDs, err)
|
||||
}
|
||||
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
return nil, nil, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, reason = %s", collectionName, searchPartitionIDs, resp.GetStatus().GetReason())
|
||||
return nil, nil, fmt.Errorf("showPartitions failed, collection = %d, partitionIDs = %v, reason = %s", collectionID, searchPartitionIDs, resp.GetStatus().GetReason())
|
||||
}
|
||||
|
||||
for i, percentage := range resp.GetInMemoryPercentages() {
|
||||
|
@ -363,10 +366,10 @@ func checkFullLoaded(ctx context.Context, qc types.QueryCoord, collectionName st
|
|||
CollectionID: info.collID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, err = %s", collectionName, searchPartitionIDs, err)
|
||||
return nil, nil, fmt.Errorf("showPartitions failed, collection = %d, partitionIDs = %v, err = %s", collectionID, searchPartitionIDs, err)
|
||||
}
|
||||
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
return nil, nil, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, reason = %s", collectionName, searchPartitionIDs, resp.GetStatus().GetReason())
|
||||
return nil, nil, fmt.Errorf("showPartitions failed, collection = %d, partitionIDs = %v, reason = %s", collectionID, searchPartitionIDs, resp.GetStatus().GetReason())
|
||||
}
|
||||
|
||||
loadedMap := make(map[UniqueID]bool)
|
||||
|
|
|
@ -44,7 +44,8 @@ type StatisticTaskSuite struct {
|
|||
|
||||
lb LBPolicy
|
||||
|
||||
collection string
|
||||
collectionName string
|
||||
collectionID int64
|
||||
}
|
||||
|
||||
func (s *StatisticTaskSuite) SetupSuite() {
|
||||
|
@ -87,7 +88,7 @@ func (s *StatisticTaskSuite) SetupTest() {
|
|||
err := InitMetaCache(context.Background(), s.rc, s.qc, mgr)
|
||||
s.NoError(err)
|
||||
|
||||
s.collection = "test_statistics_task"
|
||||
s.collectionName = "test_statistics_task"
|
||||
s.loadCollection()
|
||||
}
|
||||
|
||||
|
@ -104,7 +105,7 @@ func (s *StatisticTaskSuite) loadCollection() {
|
|||
fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
|
||||
}
|
||||
|
||||
schema := constructCollectionSchemaByDataType(s.collection, fieldName2Types, testInt64Field, false)
|
||||
schema := constructCollectionSchemaByDataType(s.collectionName, fieldName2Types, testInt64Field, false)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
s.NoError(err)
|
||||
|
||||
|
@ -112,7 +113,7 @@ func (s *StatisticTaskSuite) loadCollection() {
|
|||
createColT := &createCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
CollectionName: s.collection,
|
||||
CollectionName: s.collectionName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: common.DefaultShardsNum,
|
||||
},
|
||||
|
@ -125,7 +126,7 @@ func (s *StatisticTaskSuite) loadCollection() {
|
|||
s.NoError(createColT.Execute(ctx))
|
||||
s.NoError(createColT.PostExecute(ctx))
|
||||
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), s.collection)
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), s.collectionName)
|
||||
s.NoError(err)
|
||||
|
||||
status, err := s.qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
|
||||
|
@ -137,6 +138,7 @@ func (s *StatisticTaskSuite) loadCollection() {
|
|||
})
|
||||
s.NoError(err)
|
||||
s.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
|
||||
s.collectionID = collectionID
|
||||
}
|
||||
|
||||
func (s *StatisticTaskSuite) TearDownSuite() {
|
||||
|
@ -164,7 +166,7 @@ func (s *StatisticTaskSuite) getStatisticsTask(ctx context.Context) *getStatisti
|
|||
return &getStatisticsTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
ctx: ctx,
|
||||
collectionName: s.collection,
|
||||
collectionName: s.collectionName,
|
||||
result: &milvuspb.GetStatisticsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
|
@ -175,7 +177,7 @@ func (s *StatisticTaskSuite) getStatisticsTask(ctx context.Context) *getStatisti
|
|||
MsgType: commonpb.MsgType_Retrieve,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionName: s.collection,
|
||||
CollectionName: s.collectionName,
|
||||
},
|
||||
qc: s.qc,
|
||||
lb: s.lb,
|
||||
|
@ -195,6 +197,7 @@ func (s *StatisticTaskSuite) TestStatisticTask_NotShardLeader() {
|
|||
Reason: "error",
|
||||
},
|
||||
}, nil)
|
||||
s.NoError(task.PreExecute(ctx))
|
||||
s.Error(task.Execute(ctx))
|
||||
s.NoError(task.PostExecute(ctx))
|
||||
}
|
||||
|
@ -211,6 +214,7 @@ func (s *StatisticTaskSuite) TestStatisticTask_UnexpectedError() {
|
|||
Reason: "error",
|
||||
},
|
||||
}, nil)
|
||||
s.NoError(task.PreExecute(ctx))
|
||||
s.Error(task.Execute(ctx))
|
||||
s.NoError(task.PostExecute(ctx))
|
||||
}
|
||||
|
@ -220,8 +224,10 @@ func (s *StatisticTaskSuite) TestStatisticTask_Success() {
|
|||
task := s.getStatisticsTask(ctx)
|
||||
|
||||
s.NoError(task.OnEnqueue())
|
||||
task.fromQueryNode = true
|
||||
s.qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(nil, nil)
|
||||
s.NoError(task.PreExecute(ctx))
|
||||
task.fromQueryNode = true
|
||||
task.fromDataCoord = false
|
||||
s.NoError(task.Execute(ctx))
|
||||
s.NoError(task.PostExecute(ctx))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue