From b6f69fe7f2f41ebee8aa999105b386c1e08284b9 Mon Sep 17 00:00:00 2001 From: "yihao.dai" Date: Mon, 3 Apr 2023 14:26:23 +0800 Subject: [PATCH] Check if loaded before delegator search/query (#23162) Signed-off-by: bigsheeper --- internal/querynodev2/delegator/delegator.go | 11 +++ .../querynodev2/delegator/delegator_test.go | 67 +++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index 94eab862dc..357e5786a0 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -37,6 +37,7 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tsafe" "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/merr" "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/tsoutil" "github.com/samber/lo" @@ -205,6 +206,11 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels()) } + partitions := req.GetReq().GetPartitionIDs() + if !sd.collection.ExistPartition(partitions...) { + return nil, merr.WrapErrPartitionNotLoaded(partitions) + } + // wait tsafe err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) if err != nil { @@ -251,6 +257,11 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels()) } + partitions := req.GetReq().GetPartitionIDs() + if !sd.collection.ExistPartition(partitions...) { + return nil, merr.WrapErrPartitionNotLoaded(partitions) + } + // wait tsafe err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) if err != nil { diff --git a/internal/querynodev2/delegator/delegator_test.go b/internal/querynodev2/delegator/delegator_test.go index 08b3c3eb61..5f61dc5c00 100644 --- a/internal/querynodev2/delegator/delegator_test.go +++ b/internal/querynodev2/delegator/delegator_test.go @@ -33,6 +33,7 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tsafe" "github.com/milvus-io/milvus/internal/util/commonpbutil" + "github.com/milvus-io/milvus/internal/util/merr" "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/samber/lo" "github.com/stretchr/testify/assert" @@ -259,6 +260,39 @@ func (s *DelegatorSuite) TestSearch() { s.Equal(3, len(results)) }) + s.Run("partition_not_loaded", func() { + defer func() { + s.workerManager.ExpectedCalls = nil + }() + workers := make(map[int64]*cluster.MockWorker) + worker1 := &cluster.MockWorker{} + worker2 := &cluster.MockWorker{} + + workers[1] = worker1 + workers[2] = worker2 + + worker1.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")). + Return(&internalpb.SearchResults{}, nil) + worker2.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")). + Return(&internalpb.SearchResults{}, nil) + + s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker { + return workers[nodeID] + }, nil) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := s.delegator.Search(ctx, &querypb.SearchRequest{ + Req: &internalpb.SearchRequest{ + Base: commonpbutil.NewMsgBase(), + PartitionIDs: []int64{500}, + }, + DmlChannels: []string{s.vchannelName}, + }) + + errors.Is(err, merr.ErrPartitionNotLoaded) + }) + s.Run("worker_return_error", func() { defer func() { s.workerManager.ExpectedCalls = nil @@ -478,6 +512,39 @@ func (s *DelegatorSuite) TestQuery() { s.Equal(3, len(results)) }) + s.Run("partition_not_loaded", func() { + defer func() { + s.workerManager.ExpectedCalls = nil + }() + workers := make(map[int64]*cluster.MockWorker) + worker1 := &cluster.MockWorker{} + worker2 := &cluster.MockWorker{} + + workers[1] = worker1 + workers[2] = worker2 + + worker1.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")). + Return(&internalpb.RetrieveResults{}, nil) + worker2.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")). + Return(&internalpb.RetrieveResults{}, nil) + + s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker { + return workers[nodeID] + }, nil) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := s.delegator.Query(ctx, &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + Base: commonpbutil.NewMsgBase(), + PartitionIDs: []int64{500}, + }, + DmlChannels: []string{s.vchannelName}, + }) + + errors.Is(err, merr.ErrPartitionNotLoaded) + }) + s.Run("worker_return_error", func() { defer func() { s.workerManager.ExpectedCalls = nil