diff --git a/internal/querycoordv2/task/executor.go b/internal/querycoordv2/task/executor.go index e6c14d2ac9..cb6d1d28f8 100644 --- a/internal/querycoordv2/task/executor.go +++ b/internal/querycoordv2/task/executor.go @@ -253,6 +253,7 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error { log.Warn("failed to get partitions of collection", zap.Error(err)) return err } + loadMeta := packLoadMeta( ex.meta.GetLoadType(task.CollectionID()), "", diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go index c826834857..a8495d24b5 100644 --- a/internal/querycoordv2/task/task_test.go +++ b/internal/querycoordv2/task/task_test.go @@ -358,6 +358,10 @@ func (suite *TaskSuite) TestUnsubscribeChannelTask() { } } +func (suite *TaskSuite) expectationsForLoadSegments() { + +} + func (suite *TaskSuite) TestLoadSegmentTask() { ctx := context.Background() timeout := 10 * time.Second diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index bb7feddac0..188b5d3810 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -33,6 +33,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus-proto/go-api/msgpb" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" @@ -410,6 +411,21 @@ func (node *QueryNode) LoadPartitions(ctx context.Context, req *querypb.LoadPart proportion := paramtable.Get().DataCoordCfg.SegmentSealProportion.GetAsFloat() maxIndexRecordPerSegment = int64(threshold * proportion / float64(sizePerRecord)) } + vecField, err := typeutil.GetVectorFieldSchema(req.GetSchema()) + if err != nil { + return merr.Status(err), nil + } + indexInfo, ok := lo.Find(req.GetIndexInfoList(), func(info *indexpb.IndexInfo) bool { + return info.GetFieldID() == vecField.GetFieldID() + }) + if !ok || indexInfo == nil { + err = fmt.Errorf("cannot find index info for %s field", vecField.GetName()) + return merr.Status(err), nil + } + metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.MetricTypeKey, indexInfo.GetIndexParams()) + if err != nil { + return merr.Status(err), nil + } node.manager.Collection.Put(req.GetCollectionID(), req.GetSchema(), &segcorepb.CollectionIndexMeta{ IndexMetas: fieldIndexMetas, MaxIndexRowCount: maxIndexRecordPerSegment, @@ -417,6 +433,7 @@ func (node *QueryNode) LoadPartitions(ctx context.Context, req *querypb.LoadPart CollectionID: req.GetCollectionID(), PartitionIDs: req.GetPartitionIDs(), LoadType: querypb.LoadType_LoadCollection, // TODO: dyh, remove loadType in querynode + MetricType: metricType, }) log.Info("load partitions done") diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go index 8c239f9257..9534419983 100644 --- a/internal/querynodev2/services_test.go +++ b/internal/querynodev2/services_test.go @@ -32,12 +32,14 @@ import ( "github.com/milvus-io/milvus-proto/go-api/msgpb" "github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/etcd" @@ -45,6 +47,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type ServiceSuite struct { @@ -1477,10 +1480,47 @@ func (suite *ServiceSuite) TestLoadPartition() { suite.NoError(err) suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode()) - // collection not exist and schema is not nil + // no vec field in schema + req.Schema = &schemapb.CollectionSchema{} + status, err = suite.node.LoadPartitions(ctx, req) + suite.NoError(err) + suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode()) + + // no indexInfo req.Schema = segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) status, err = suite.node.LoadPartitions(ctx, req) suite.NoError(err) + suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode()) + + // no metric type + vecField, err := typeutil.GetVectorFieldSchema(req.GetSchema()) + suite.NoError(err) + req.IndexInfoList = []*indexpb.IndexInfo{ + { + CollectionID: suite.collectionID, + FieldID: vecField.GetFieldID(), + IndexParams: []*commonpb.KeyValuePair{}, + }, + } + status, err = suite.node.LoadPartitions(ctx, req) + suite.NoError(err) + suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode()) + + // collection not exist and schema is not nil + req.IndexInfoList = []*indexpb.IndexInfo{ + { + CollectionID: suite.collectionID, + FieldID: vecField.GetFieldID(), + IndexParams: []*commonpb.KeyValuePair{ + { + Key: common.MetricTypeKey, + Value: "L2", + }, + }, + }, + } + status, err = suite.node.LoadPartitions(ctx, req) + suite.NoError(err) suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode) // collection existed