mirror of https://github.com/milvus-io/milvus.git
deprecate shard cache immediately if query failed (#22779)
Signed-off-by: Wei Liu <wei.liu@zilliz.com>pull/22828/head
parent
3b2bd089e7
commit
6b5dfa6db2
1
Makefile
1
Makefile
|
@ -346,4 +346,5 @@ generate-mockery: getdeps
|
|||
$(PWD)/bin/mockery --name=GarbageCollector --dir=$(PWD)/internal/rootcoord --output=$(PWD)/internal/rootcoord/mocks --filename=garbage_collector.go --with-expecter --outpkg=mockrootcoord
|
||||
#internal/types
|
||||
$(PWD)/bin/mockery --name=QueryCoordComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/types --filename=mock_querycoord.go --with-expecter --structname=MockQueryCoord --outpkg=types --inpackage
|
||||
$(PWD)/bin/mockery --name=QueryNodeComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/types --filename=mock_querynode.go --with-expecter --structname=MockQueryNode --outpkg=types --inpackage
|
||||
ci-ut: build-cpp-with-coverage generated-proto-go-without-cpp codecov-cpp codecov-go
|
||||
|
|
|
@ -68,7 +68,7 @@ type Cache interface {
|
|||
// GetCollectionSchema get collection's schema.
|
||||
GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error)
|
||||
GetShards(ctx context.Context, withCache bool, collectionName string) (map[string][]nodeInfo, error)
|
||||
ClearShards(collectionName string)
|
||||
DeprecateShardCache(collectionName string)
|
||||
expireShardLeaderCache(ctx context.Context)
|
||||
RemoveCollection(ctx context.Context, collectionName string)
|
||||
RemoveCollectionsByID(ctx context.Context, collectionID UniqueID) []string
|
||||
|
@ -754,8 +754,8 @@ func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) m
|
|||
return shard2QueryNodes
|
||||
}
|
||||
|
||||
// ClearShards clear the shard leader cache of a collection
|
||||
func (m *MetaCache) ClearShards(collectionName string) {
|
||||
// DeprecateShardCache clear the shard leader cache of a collection
|
||||
func (m *MetaCache) DeprecateShardCache(collectionName string) {
|
||||
log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName))
|
||||
m.mu.RLock()
|
||||
info, ok := m.collInfo[collectionName]
|
||||
|
|
|
@ -559,11 +559,11 @@ func TestMetaCache_ClearShards(t *testing.T) {
|
|||
defer qc.Stop()
|
||||
|
||||
t.Run("Clear with no collection info", func(t *testing.T) {
|
||||
globalMetaCache.ClearShards("collection_not_exist")
|
||||
globalMetaCache.DeprecateShardCache("collection_not_exist")
|
||||
})
|
||||
|
||||
t.Run("Clear valid collection empty cache", func(t *testing.T) {
|
||||
globalMetaCache.ClearShards(collectionName)
|
||||
globalMetaCache.DeprecateShardCache(collectionName)
|
||||
})
|
||||
|
||||
t.Run("Clear valid collection valid cache", func(t *testing.T) {
|
||||
|
@ -589,7 +589,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
|
|||
require.Equal(t, 1, len(shards))
|
||||
require.Equal(t, 3, len(shards["channel-1"]))
|
||||
|
||||
globalMetaCache.ClearShards(collectionName)
|
||||
globalMetaCache.DeprecateShardCache(collectionName)
|
||||
|
||||
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
|
||||
Status: &commonpb.Status{
|
||||
|
|
|
@ -1477,7 +1477,7 @@ func (rct *releaseCollectionTask) Execute(ctx context.Context) (err error) {
|
|||
}
|
||||
|
||||
func (rct *releaseCollectionTask) PostExecute(ctx context.Context) error {
|
||||
globalMetaCache.ClearShards(rct.CollectionName)
|
||||
globalMetaCache.DeprecateShardCache(rct.CollectionName)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1697,7 +1697,7 @@ func (rpt *releasePartitionsTask) Execute(ctx context.Context) (err error) {
|
|||
}
|
||||
|
||||
func (rpt *releasePartitionsTask) PostExecute(ctx context.Context) error {
|
||||
globalMetaCache.ClearShards(rpt.CollectionName)
|
||||
globalMetaCache.DeprecateShardCache(rpt.CollectionName)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -405,7 +405,7 @@ func (t *queryTask) Execute(ctx context.Context) error {
|
|||
log.Warn("invalid shard leaders cache, updating shardleader caches and retry query",
|
||||
zap.Error(err))
|
||||
// invalidate cache first, since ctx may be canceled or timeout here
|
||||
globalMetaCache.ClearShards(t.collectionName)
|
||||
globalMetaCache.DeprecateShardCache(t.collectionName)
|
||||
err = executeQuery(WithoutCache)
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -468,10 +468,12 @@ func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.Query
|
|||
log.Ctx(ctx).Warn("QueryNode query return error",
|
||||
zap.Int64("nodeID", nodeID),
|
||||
zap.Strings("channels", channelIDs), zap.Error(err))
|
||||
globalMetaCache.DeprecateShardCache(t.collectionName)
|
||||
return err
|
||||
}
|
||||
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
|
||||
log.Ctx(ctx).Warn("QueryNode is not shardLeader", zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs))
|
||||
globalMetaCache.DeprecateShardCache(t.collectionName)
|
||||
return errInvalidShardLeaders
|
||||
}
|
||||
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
|
|
|
@ -427,7 +427,7 @@ func (t *searchTask) Execute(ctx context.Context) error {
|
|||
log.Warn("first search failed, updating shardleader caches and retry search",
|
||||
zap.Error(err))
|
||||
// invalidate cache first, since ctx may be canceled or timeout here
|
||||
globalMetaCache.ClearShards(t.collectionName)
|
||||
globalMetaCache.DeprecateShardCache(t.collectionName)
|
||||
err = executeSearch(WithoutCache)
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -519,12 +519,14 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que
|
|||
zap.Int64("nodeID", nodeID),
|
||||
zap.Strings("channels", channelIDs),
|
||||
zap.Error(err))
|
||||
globalMetaCache.DeprecateShardCache(t.collectionName)
|
||||
return err
|
||||
}
|
||||
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
|
||||
log.Ctx(ctx).Warn("QueryNode is not shardLeader",
|
||||
zap.Int64("nodeID", nodeID),
|
||||
zap.Strings("channels", channelIDs))
|
||||
globalMetaCache.DeprecateShardCache(t.collectionName)
|
||||
return errInvalidShardLeaders
|
||||
}
|
||||
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
|
|
|
@ -288,7 +288,7 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro
|
|||
log.Warn("first get statistics failed, updating shard leader caches and retry",
|
||||
zap.Error(err))
|
||||
// invalidate cache first, since ctx may be canceled or timeout here
|
||||
globalMetaCache.ClearShards(g.collectionName)
|
||||
globalMetaCache.DeprecateShardCache(g.collectionName)
|
||||
err = executeGetStatistics(WithoutCache)
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -304,24 +304,28 @@ func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64
|
|||
DmlChannels: channelIDs,
|
||||
Scope: querypb.DataScope_All,
|
||||
}
|
||||
log.Info("xxxx")
|
||||
result, err := qn.GetStatistics(ctx, req)
|
||||
if err != nil {
|
||||
log.Warn("QueryNode statistic return error",
|
||||
zap.Int64("nodeID", nodeID),
|
||||
zap.Strings("channels", channelIDs),
|
||||
zap.Error(err))
|
||||
globalMetaCache.DeprecateShardCache(g.collectionName)
|
||||
return err
|
||||
}
|
||||
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
|
||||
log.Warn("QueryNode is not shardLeader",
|
||||
zap.Int64("nodeID", nodeID),
|
||||
zap.Strings("channels", channelIDs))
|
||||
globalMetaCache.DeprecateShardCache(g.collectionName)
|
||||
return errInvalidShardLeaders
|
||||
}
|
||||
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("QueryNode statistic result error",
|
||||
zap.Int64("nodeID", nodeID),
|
||||
zap.String("reason", result.GetStatus().GetReason()))
|
||||
globalMetaCache.DeprecateShardCache(g.collectionName)
|
||||
return fmt.Errorf("fail to get statistic, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason())
|
||||
}
|
||||
g.resultBuf <- result
|
||||
|
|
|
@ -0,0 +1,205 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStatisticTask_all(t *testing.T) {
|
||||
var (
|
||||
err error
|
||||
ctx = context.TODO()
|
||||
|
||||
rc = NewRootCoordMock()
|
||||
qc = types.NewMockQueryCoord(t)
|
||||
qn = types.NewMockQueryNode(t)
|
||||
|
||||
shardsNum = int32(2)
|
||||
collectionName = t.Name() + funcutil.GenRandomStr()
|
||||
)
|
||||
|
||||
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
|
||||
qc.EXPECT().Start().Return(nil)
|
||||
qc.EXPECT().Stop().Return(nil)
|
||||
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&successStatus, nil)
|
||||
|
||||
mockCreator := func(ctx context.Context, address string) (types.QueryNode, error) {
|
||||
return qn, nil
|
||||
}
|
||||
|
||||
mgr := newShardClientMgr(withShardClientCreator(mockCreator))
|
||||
|
||||
rc.Start()
|
||||
defer rc.Stop()
|
||||
qc.Start()
|
||||
defer qc.Stop()
|
||||
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
|
||||
Status: &successStatus,
|
||||
Shards: []*querypb.ShardLeadersList{
|
||||
{
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
err = InitMetaCache(ctx, rc, qc, mgr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
fieldName2Types := map[string]schemapb.DataType{
|
||||
testBoolField: schemapb.DataType_Bool,
|
||||
testInt32Field: schemapb.DataType_Int32,
|
||||
testInt64Field: schemapb.DataType_Int64,
|
||||
testFloatField: schemapb.DataType_Float,
|
||||
testDoubleField: schemapb.DataType_Double,
|
||||
testFloatVecField: schemapb.DataType_FloatVector,
|
||||
}
|
||||
if enableMultipleVectorFields {
|
||||
fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
|
||||
}
|
||||
|
||||
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
createColT := &createCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
CollectionName: collectionName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: shardsNum,
|
||||
},
|
||||
ctx: ctx,
|
||||
rootCoord: rc,
|
||||
}
|
||||
|
||||
require.NoError(t, createColT.OnEnqueue())
|
||||
require.NoError(t, createColT.PreExecute(ctx))
|
||||
require.NoError(t, createColT.Execute(ctx))
|
||||
require.NoError(t, createColT.PostExecute(ctx))
|
||||
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
|
||||
Status: &successStatus,
|
||||
CollectionIDs: []int64{collectionID},
|
||||
InMemoryPercentages: []int64{100},
|
||||
}, nil)
|
||||
|
||||
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_LoadCollection,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionID: collectionID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
||||
|
||||
// test begins
|
||||
task := &getStatisticsTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
ctx: ctx,
|
||||
result: &milvuspb.GetStatisticsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
},
|
||||
request: &milvuspb.GetStatisticsRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
qc: qc,
|
||||
shardMgr: mgr,
|
||||
}
|
||||
|
||||
assert.NoError(t, task.OnEnqueue())
|
||||
|
||||
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
PartitionIDs: []int64{1, 2, 3},
|
||||
}, nil)
|
||||
|
||||
// test query task with timeout
|
||||
ctx1, cancel1 := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel1()
|
||||
// before preExecute
|
||||
assert.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp)
|
||||
task.ctx = ctx1
|
||||
assert.NoError(t, task.PreExecute(ctx))
|
||||
// after preExecute
|
||||
assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)
|
||||
|
||||
task.ctx = ctx
|
||||
task.statisticShardPolicy = func(context.Context, *shardClientMgr, func(context.Context, int64, types.QueryNode, []string) error, map[string][]nodeInfo) error {
|
||||
return fmt.Errorf("fake error")
|
||||
}
|
||||
task.fromQueryNode = true
|
||||
assert.Error(t, task.Execute(ctx))
|
||||
assert.NoError(t, task.PostExecute(ctx))
|
||||
|
||||
task.statisticShardPolicy = func(context.Context, *shardClientMgr, func(context.Context, int64, types.QueryNode, []string) error, map[string][]nodeInfo) error {
|
||||
return errInvalidShardLeaders
|
||||
}
|
||||
task.fromQueryNode = true
|
||||
assert.Error(t, task.Execute(ctx))
|
||||
assert.NoError(t, task.PostExecute(ctx))
|
||||
|
||||
task.statisticShardPolicy = mergeRoundRobinPolicy
|
||||
task.fromQueryNode = true
|
||||
qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(nil, errors.New("GetStatistics failed")).Times(3)
|
||||
assert.Error(t, task.Execute(ctx))
|
||||
assert.NoError(t, task.PostExecute(ctx))
|
||||
|
||||
task.statisticShardPolicy = mergeRoundRobinPolicy
|
||||
task.fromQueryNode = true
|
||||
qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(&internalpb.GetStatisticsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_NotShardLeader,
|
||||
Reason: "error",
|
||||
},
|
||||
}, nil).Times(6)
|
||||
assert.Error(t, task.Execute(ctx))
|
||||
assert.NoError(t, task.PostExecute(ctx))
|
||||
|
||||
task.statisticShardPolicy = mergeRoundRobinPolicy
|
||||
task.fromQueryNode = true
|
||||
qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(&internalpb.GetStatisticsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "error",
|
||||
},
|
||||
}, nil).Times(3)
|
||||
assert.Error(t, task.Execute(ctx))
|
||||
assert.NoError(t, task.PostExecute(ctx))
|
||||
|
||||
task.statisticShardPolicy = mergeRoundRobinPolicy
|
||||
task.fromQueryNode = true
|
||||
qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(nil, nil).Once()
|
||||
assert.NoError(t, task.Execute(ctx))
|
||||
assert.NoError(t, task.PostExecute(ctx))
|
||||
}
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue