diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 056eb47a58..a6bd702242 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -80,29 +80,39 @@ func getPartitionIDs(ctx context.Context, dbName string, collectionName string, return nil, err } - partitionsRecord := make(map[UniqueID]bool) - partitionIDs = make([]UniqueID, 0, len(partitionNames)) + useRegexp := Params.ProxyCfg.PartitionNameRegexp.GetAsBool() + + partitionsSet := typeutil.NewSet[int64]() for _, partitionName := range partitionNames { - pattern := fmt.Sprintf("^%s$", partitionName) - re, err := regexp.Compile(pattern) - if err != nil { - return nil, fmt.Errorf("invalid partition: %s", partitionName) - } - found := false - for name, pID := range partitionsMap { - if re.MatchString(name) { - if _, exist := partitionsRecord[pID]; !exist { - partitionIDs = append(partitionIDs, pID) - partitionsRecord[pID] = true + if useRegexp { + // Legacy feature, use partition name as regexp + pattern := fmt.Sprintf("^%s$", partitionName) + re, err := regexp.Compile(pattern) + if err != nil { + return nil, fmt.Errorf("invalid partition: %s", partitionName) + } + var found bool + for name, pID := range partitionsMap { + if re.MatchString(name) { + partitionsSet.Insert(pID) + found = true } - found = true + } + if !found { + return nil, fmt.Errorf("partition name %s not found", partitionName) + } + } else { + partitionID, found := partitionsMap[partitionName] + if !found { + // TODO change after testcase updated: return nil, merr.WrapErrPartitionNotFound(partitionName) + return nil, fmt.Errorf("partition name %s not found", partitionName) + } + if !partitionsSet.Contain(partitionID) { + partitionsSet.Insert(partitionID) } } - if !found { - return nil, fmt.Errorf("partition name %s not found", partitionName) - } } - return partitionIDs, nil + return partitionsSet.Collect(), nil } // parseSearchInfo returns QueryInfo and offset diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index f744a4a777..bb7216ebe8 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -28,6 +28,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -2144,3 +2145,83 @@ func TestSearchTask_Requery(t *testing.T) { assert.Error(t, err) }) } + +type GetPartitionIDsSuite struct { + suite.Suite + + mockMetaCache *MockCache +} + +func (s *GetPartitionIDsSuite) SetupTest() { + s.mockMetaCache = NewMockCache(s.T()) + globalMetaCache = s.mockMetaCache +} + +func (s *GetPartitionIDsSuite) TearDownTest() { + globalMetaCache = nil + Params.Reset(Params.ProxyCfg.PartitionNameRegexp.Key) +} + +func (s *GetPartitionIDsSuite) TestPlainPartitionNames() { + Params.Save(Params.ProxyCfg.PartitionNameRegexp.Key, "false") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.mockMetaCache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything). + Return(map[string]int64{"partition_1": 100, "partition_2": 200}, nil).Once() + + result, err := getPartitionIDs(ctx, "default_db", "test_collection", []string{"partition_1", "partition_2"}) + + s.NoError(err) + s.ElementsMatch([]int64{100, 200}, result) + + s.mockMetaCache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything). + Return(map[string]int64{"partition_1": 100}, nil).Once() + + _, err = getPartitionIDs(ctx, "default_db", "test_collection", []string{"partition_1", "partition_2"}) + s.Error(err) + + s.mockMetaCache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything). + Return(nil, errors.New("mocked")).Once() + _, err = getPartitionIDs(ctx, "default_db", "test_collection", []string{"partition_1", "partition_2"}) + s.Error(err) +} + +func (s *GetPartitionIDsSuite) TestRegexpPartitionNames() { + Params.Save(Params.ProxyCfg.PartitionNameRegexp.Key, "true") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.mockMetaCache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything). + Return(map[string]int64{"partition_1": 100, "partition_2": 200}, nil).Once() + + result, err := getPartitionIDs(ctx, "default_db", "test_collection", []string{"partition_1", "partition_2"}) + + s.NoError(err) + s.ElementsMatch([]int64{100, 200}, result) + + s.mockMetaCache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything). + Return(map[string]int64{"partition_1": 100, "partition_2": 200}, nil).Once() + + result, err = getPartitionIDs(ctx, "default_db", "test_collection", []string{"partition_.*"}) + + s.NoError(err) + s.ElementsMatch([]int64{100, 200}, result) + + s.mockMetaCache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything). + Return(map[string]int64{"partition_1": 100}, nil).Once() + + _, err = getPartitionIDs(ctx, "default_db", "test_collection", []string{"partition_1", "partition_2"}) + s.Error(err) + + s.mockMetaCache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything). + Return(nil, errors.New("mocked")).Once() + _, err = getPartitionIDs(ctx, "default_db", "test_collection", []string{"partition_1", "partition_2"}) + s.Error(err) +} + +func TestGetPartitionIDs(t *testing.T) { + suite.Run(t, new(GetPartitionIDsSuite)) +} diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 1025570d3a..64611e2289 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -905,6 +905,7 @@ type proxyConfig struct { CostMetricsExpireTime ParamItem `refreshable:"true"` RetryTimesOnReplica ParamItem `refreshable:"true"` RetryTimesOnHealthCheck ParamItem `refreshable:"true"` + PartitionNameRegexp ParamItem `refreshable:"true"` AccessLog AccessLogConfig } @@ -1190,6 +1191,14 @@ please adjust in embedded Milvus: false`, Doc: "set query node unavailable on proxy when heartbeat failures reach this limit", } p.RetryTimesOnHealthCheck.Init(base.mgr) + + p.PartitionNameRegexp = ParamItem{ + Key: "proxy.partitionNameRegexp", + Version: "2.3.4", + DefaultValue: "false", + Doc: "switch for whether proxy shall use partition name as regexp when searching", + } + p.PartitionNameRegexp.Init(base.mgr) } // /////////////////////////////////////////////////////////////////////////////