diff --git a/internal/proxy/mock_cache_test.go b/internal/proxy/mock_cache_test.go index 069cacdc5d..e8adfce412 100644 --- a/internal/proxy/mock_cache_test.go +++ b/internal/proxy/mock_cache_test.go @@ -9,11 +9,13 @@ import ( type getCollectionIDFunc func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) type getCollectionSchemaFunc func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) +type getCollectionInfoFunc func(ctx context.Context, collectionName string) (*collectionInfo, error) type mockCache struct { Cache getIDFunc getCollectionIDFunc getSchemaFunc getCollectionSchemaFunc + getInfoFunc getCollectionInfoFunc } func (m *mockCache) GetCollectionID(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { @@ -30,6 +32,13 @@ func (m *mockCache) GetCollectionSchema(ctx context.Context, collectionName stri return nil, nil } +func (m *mockCache) GetCollectionInfo(ctx context.Context, collectionName string) (*collectionInfo, error) { + if m.getInfoFunc != nil { + return m.getInfoFunc(ctx, collectionName) + } + return nil, nil +} + func (m *mockCache) RemoveCollection(ctx context.Context, collectionName string) { } @@ -41,6 +50,10 @@ func (m *mockCache) setGetSchemaFunc(f getCollectionSchemaFunc) { m.getSchemaFunc = f } +func (m *mockCache) setGetInfoFunc(f getCollectionInfoFunc) { + m.getInfoFunc = f +} + func newMockCache() *mockCache { return &mockCache{} } diff --git a/internal/proxy/query_coord_mock_test.go b/internal/proxy/query_coord_mock_test.go index 7b1032ba4e..d109ab94e0 100644 --- a/internal/proxy/query_coord_mock_test.go +++ b/internal/proxy/query_coord_mock_test.go @@ -243,20 +243,10 @@ func (coord *QueryCoordMock) ResetShowPartitionsFunc() { } func (coord *QueryCoordMock) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { - if !coord.healthy() { - return &querypb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - 
ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "unhealthy", - }, - }, nil - } - if coord.showPartitionsFunc != nil { return coord.showPartitionsFunc(ctx, req) } - - panic("implement me") + return nil, nil } func (coord *QueryCoordMock) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 79c47684d2..51677f02f4 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "regexp" "strconv" "strings" "sync" @@ -56,6 +55,41 @@ type queryTask struct { shardMgr *shardClientMgr } +// translateToOutputFieldIDs translates output field names to output field IDs. +func translateToOutputFieldIDs(outputFields []string, schema *schemapb.CollectionSchema) ([]UniqueID, error) { + outputFieldIDs := make([]UniqueID, 0, len(outputFields)) + if len(outputFields) == 0 { + for _, field := range schema.Fields { + if field.FieldID >= 100 && field.DataType != schemapb.DataType_FloatVector && field.DataType != schemapb.DataType_BinaryVector { + outputFieldIDs = append(outputFieldIDs, field.FieldID) + } + } + } else { + addPrimaryKey := false + for _, reqField := range outputFields { + findField := false + for _, field := range schema.Fields { + if reqField == field.Name { + if field.IsPrimaryKey { + addPrimaryKey = true + } + findField = true + outputFieldIDs = append(outputFieldIDs, field.FieldID) + } else { + if field.IsPrimaryKey && !addPrimaryKey { + outputFieldIDs = append(outputFieldIDs, field.FieldID) + addPrimaryKey = true + } + } + } + if !findField { + return nil, fmt.Errorf("field %s not exist", reqField) + } + } + } + return outputFieldIDs, nil +} + func (t *queryTask) PreExecute(ctx context.Context) error { if t.queryShardPolicy == nil { t.queryShardPolicy = roundRobinPolicy @@ -68,73 +102,46 @@ func (t *queryTask) PreExecute(ctx context.Context) error { t.collectionName = 
collectionName if err := validateCollectionName(collectionName); err != nil { log.Warn("Invalid collection name.", zap.String("collectionName", collectionName), - zap.Int64("requestID", t.Base.MsgID), zap.String("requestType", "query")) + zap.Int64("msgID", t.ID()), zap.String("requestType", "query")) return err } log.Info("Validate collection name.", zap.Any("collectionName", collectionName), - zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query")) + zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) collID, err := globalMetaCache.GetCollectionID(ctx, collectionName) if err != nil { log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName), - zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query")) + zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) return err } t.CollectionID = collID log.Info("Get collection ID by name", zap.Int64("collectionID", t.CollectionID), zap.String("collection name", collectionName), - zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query")) + zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) for _, tag := range t.request.PartitionNames { if err := validatePartitionTag(tag, false); err != nil { log.Warn("invalid partition name", zap.String("partition name", tag), - zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query")) + zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) return err } } log.Debug("Validate partition names.", - zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query")) + zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) - t.PartitionIDs = make([]UniqueID, 0) - partitionsMap, err := globalMetaCache.GetPartitions(ctx, collectionName) + t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, collectionName, t.request.GetPartitionNames()) if err != nil { log.Warn("failed to get partitions in collection.", zap.String("collection name", collectionName), - zap.Any("requestID", 
t.Base.MsgID), zap.Any("requestType", "query")) + zap.Error(err), + zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) return err } log.Debug("Get partitions in collection.", zap.Any("collectionName", collectionName), - zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query")) + zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) - // Check if partitions are valid partitions in collection - partitionsRecord := make(map[UniqueID]bool) - for _, partitionName := range t.request.PartitionNames { - pattern := fmt.Sprintf("^%s$", partitionName) - re, err := regexp.Compile(pattern) - if err != nil { - log.Debug("failed to compile partition name regex expression.", zap.Any("partition name", partitionName), - zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query")) - return errors.New("invalid partition names") - } - found := false - for name, pID := range partitionsMap { - if re.MatchString(name) { - if _, exist := partitionsRecord[pID]; !exist { - t.PartitionIDs = append(t.PartitionIDs, pID) - partitionsRecord[pID] = true - } - found = true - } - } - if !found { - // FIXME(wxyu): undefined behavior - errMsg := fmt.Sprintf("partition name: %s not found", partitionName) - return errors.New(errMsg) - } - } - - loaded, err := t.checkIfLoaded(collID, t.PartitionIDs) + loaded, err := checkIfLoaded(ctx, t.qc, collectionName, t.RetrieveRequest.GetPartitionIDs()) if err != nil { return fmt.Errorf("checkIfLoaded failed when query, collection:%v, partitions:%v, err = %s", collectionName, t.request.GetPartitionNames(), err) } @@ -167,41 +174,16 @@ func (t *queryTask) PreExecute(ctx context.Context) error { return err } log.Debug("translate output fields", zap.Any("OutputFields", t.request.OutputFields), - zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query")) + zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) - if len(t.request.OutputFields) == 0 { - for _, field := range schema.Fields { - if field.FieldID >= 100 && 
field.DataType != schemapb.DataType_FloatVector && field.DataType != schemapb.DataType_BinaryVector { - t.OutputFieldsId = append(t.OutputFieldsId, field.FieldID) - } - } - } else { - addPrimaryKey := false - for _, reqField := range t.request.OutputFields { - findField := false - for _, field := range schema.Fields { - if reqField == field.Name { - if field.IsPrimaryKey { - addPrimaryKey = true - } - findField = true - t.OutputFieldsId = append(t.OutputFieldsId, field.FieldID) - plan.OutputFieldIds = append(plan.OutputFieldIds, field.FieldID) - } else { - if field.IsPrimaryKey && !addPrimaryKey { - t.OutputFieldsId = append(t.OutputFieldsId, field.FieldID) - plan.OutputFieldIds = append(plan.OutputFieldIds, field.FieldID) - addPrimaryKey = true - } - } - } - if !findField { - return fmt.Errorf("field %s not exist", reqField) - } - } + outputFieldIDs, err := translateToOutputFieldIDs(t.request.GetOutputFields(), schema) + if err != nil { + return err } + t.RetrieveRequest.OutputFieldsId = outputFieldIDs + plan.OutputFieldIds = outputFieldIDs log.Debug("translate output fields to field ids", zap.Any("OutputFieldsID", t.OutputFieldsId), - zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query")) + zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) t.RetrieveRequest.SerializedExprPlan, err = proto.Marshal(plan) if err != nil { @@ -229,7 +211,9 @@ func (t *queryTask) PreExecute(ctx context.Context) error { t.DbID = 0 // TODO log.Info("Query PreExecute done.", - zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query")) + zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"), + zap.Uint64("guarantee_ts", guaranteeTs), zap.Uint64("travel_ts", t.GetTravelTimestamp()), + zap.Uint64("timeout_ts", t.GetTimeoutTimestamp())) return nil } @@ -250,6 +234,7 @@ func (t *queryTask) Execute(ctx context.Context) error { leaders := leaders t.runningGroup.Go(func() error { log.Debug("proxy starting to query one shard", + zap.Int64("msgID", t.ID()), 
zap.Int64("collectionID", t.CollectionID), zap.String("collection name", t.collectionName), zap.String("shard channel", channelID), @@ -269,7 +254,8 @@ func (t *queryTask) Execute(ctx context.Context) error { err := executeQuery(WithCache) if errors.Is(err, errInvalidShardLeaders) || funcutil.IsGrpcErr(err) || errors.Is(err, grpcclient.ErrConnect) { - log.Warn("invalid shard leaders cache, updating shardleader caches and retry search") + log.Warn("invalid shard leaders cache, updating shardleader caches and retry search", + zap.Int64("msgID", t.ID()), zap.Error(err)) return executeQuery(WithoutCache) } if err != nil { @@ -277,7 +263,7 @@ func (t *queryTask) Execute(ctx context.Context) error { } log.Info("Query Execute done.", - zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query")) + zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) return nil } @@ -294,14 +280,14 @@ func (t *queryTask) PostExecute(ctx context.Context) error { for { select { case <-t.TraceCtx().Done(): - log.Warn("proxy", zap.Int64("Query: wait to finish failed, timeout!, taskID:", t.ID())) + log.Warn("proxy", zap.Int64("Query: wait to finish failed, timeout!, msgID:", t.ID())) return case <-t.runningGroupCtx.Done(): - log.Debug("all queries are finished or canceled", zap.Any("taskID", t.ID())) + log.Debug("all queries are finished or canceled", zap.Int64("msgID", t.ID())) close(t.resultBuf) for res := range t.resultBuf { t.toReduceResults = append(t.toReduceResults, res) - log.Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Any("taskID", t.ID())) + log.Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Any("msgID", t.ID())) } wg.Done() return @@ -325,7 +311,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error { ErrorCode: commonpb.ErrorCode_Success, } } else { - log.Info("Query result is nil", zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query")) + 
log.Info("Query result is nil", zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) t.result.Status = &commonpb.Status{ ErrorCode: commonpb.ErrorCode_EmptyCollection, Reason: "emptly collection", // TODO @@ -346,7 +332,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error { } } } - log.Info("Query PostExecute done", zap.Any("requestID", t.Base.MsgID), zap.String("requestType", "query")) + log.Info("Query PostExecute done", zap.Int64("msgID", t.ID()), zap.String("requestType", "query")) return nil } @@ -360,93 +346,37 @@ func (t *queryTask) queryShard(ctx context.Context, leaders []nodeInfo, channelI result, err := qn.Query(ctx, req) if err != nil { - log.Warn("QueryNode query return error", zap.Int64("nodeID", nodeID), zap.String("channel", channelID), - zap.Error(err)) + log.Warn("QueryNode query return error", zap.Int64("msgID", t.ID()), + zap.Int64("nodeID", nodeID), zap.String("channel", channelID), zap.Error(err)) return err } if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader { - log.Warn("QueryNode is not shardLeader", zap.Int64("nodeID", nodeID), zap.String("channel", channelID)) + log.Warn("QueryNode is not shardLeader", zap.Int64("msgID", t.ID()), + zap.Int64("nodeID", nodeID), zap.String("channel", channelID)) return errInvalidShardLeaders } if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - log.Warn("QueryNode query result error", zap.Int64("nodeID", nodeID), - zap.String("reason", result.GetStatus().GetReason())) + log.Warn("QueryNode query result error", zap.Int64("msgID", t.ID()), + zap.Int64("nodeID", nodeID), zap.String("reason", result.GetStatus().GetReason())) return fmt.Errorf("fail to Query, QueryNode ID = %d, reason=%s", nodeID, result.GetStatus().GetReason()) } - log.Debug("get query result", zap.Int64("nodeID", nodeID), zap.String("channelID", channelID)) + log.Debug("get query result", zap.Int64("msgID", t.ID()), + zap.Int64("nodeID", nodeID), zap.String("channelID", channelID)) 
t.resultBuf <- result return nil } err := t.queryShardPolicy(t.TraceCtx(), t.shardMgr, query, leaders) if err != nil { - log.Warn("fail to Query to all shard leaders", zap.Int64("taskID", t.ID()), zap.Any("shard leaders", leaders)) + log.Warn("fail to Query to all shard leaders", zap.Int64("msgID", t.ID()), + zap.Int64("taskID", t.ID()), zap.Any("shard leaders", leaders)) return err } return nil } -func (t *queryTask) checkIfLoaded(collectionID UniqueID, queryPartitionIDs []UniqueID) (bool, error) { - // check if collection was loaded into QueryNode - info, err := globalMetaCache.GetCollectionInfo(t.ctx, t.collectionName) - if err != nil { - return false, fmt.Errorf("GetCollectionInfo failed, collectionID = %d, err = %s", collectionID, err) - } - if info.isLoaded { - return true, nil - } - - // If request to query partitions - if len(queryPartitionIDs) > 0 { - resp, err := t.qc.ShowPartitions(t.ctx, &querypb.ShowPartitionsRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_ShowCollections, - MsgID: t.Base.MsgID, - Timestamp: t.Base.Timestamp, - SourceID: Params.ProxyCfg.GetNodeID(), - }, - CollectionID: collectionID, - PartitionIDs: queryPartitionIDs, - }) - if err != nil { - return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, err = %s", collectionID, queryPartitionIDs, err) - } - if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, reason = %s", collectionID, queryPartitionIDs, resp.GetStatus().GetReason()) - } - // Current logic: show partitions won't return error if the given partitions are all loaded - return true, nil - } - - // If request to query collection and collection is not fully loaded - resp, err := t.qc.ShowPartitions(t.ctx, &querypb.ShowPartitionsRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_ShowCollections, - MsgID: t.Base.MsgID, - Timestamp: t.Base.Timestamp, - SourceID: 
Params.ProxyCfg.GetNodeID(), - }, - CollectionID: collectionID, - }) - if err != nil { - return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, err = %s", collectionID, queryPartitionIDs, err) - } - if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, reason = %s", collectionID, queryPartitionIDs, resp.GetStatus().GetReason()) - } - - if len(resp.GetPartitionIDs()) > 0 { - log.Warn("collection not fully loaded, query on these partitions", - zap.Int64("collectionID", collectionID), - zap.Int64s("partitionIDs", resp.GetPartitionIDs())) - return true, nil - } - - return false, nil -} - // IDs2Expr converts ids slices to bool expresion with specified field name func IDs2Expr(fieldName string, ids *schemapb.IDs) string { var idsStr string diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 18a8ee8be8..558096a890 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -6,11 +6,12 @@ import ( "testing" "time" + "github.com/milvus-io/milvus/internal/common" + "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" @@ -129,9 +130,6 @@ func TestQueryTask_all(t *testing.T) { queryShardPolicy: roundRobinPolicy, shardMgr: mgr, } - for i := 0; i < len(fieldName2Types); i++ { - task.RetrieveRequest.OutputFieldsId[i] = int64(common.StartOfUserFieldID + i) - } assert.NoError(t, task.OnEnqueue()) @@ -157,6 +155,11 @@ func TestQueryTask_all(t *testing.T) { }, } + outputFieldIDs := make([]UniqueID, 0, len(fieldName2Types)) + for i := 0; i < len(fieldName2Types); i++ { + outputFieldIDs = append(outputFieldIDs, 
int64(common.StartOfUserFieldID+i)) + } + task.RetrieveRequest.OutputFieldsId = outputFieldIDs for fieldName, dataType := range fieldName2Types { result1.FieldsData = append(result1.FieldsData, generateFieldData(dataType, fieldName, hitNum)) } @@ -168,159 +171,3 @@ func TestQueryTask_all(t *testing.T) { assert.NoError(t, task.PostExecute(ctx)) } - -func TestCheckIfLoaded(t *testing.T) { - var err error - - Params.Init() - var ( - rc = NewRootCoordMock() - qc = NewQueryCoordMock() - ctx = context.TODO() - ) - - err = rc.Start() - defer rc.Stop() - require.NoError(t, err) - mgr := newShardClientMgr() - err = InitMetaCache(rc, qc, mgr) - require.NoError(t, err) - - err = qc.Start() - defer qc.Stop() - require.NoError(t, err) - - getQueryTask := func(t *testing.T, collName string) *queryTask { - task := &queryTask{ - ctx: ctx, - RetrieveRequest: &internalpb.RetrieveRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Retrieve, - }, - }, - request: &milvuspb.QueryRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Retrieve, - }, - CollectionName: collName, - }, - qc: qc, - } - require.NoError(t, task.OnEnqueue()) - return task - } - - t.Run("test checkIfLoaded error", func(t *testing.T) { - collName := "test_checkIfLoaded_error" + funcutil.GenRandomStr() - createColl(t, collName, rc) - collID, err := globalMetaCache.GetCollectionID(context.TODO(), collName) - require.NoError(t, err) - task := getQueryTask(t, collName) - task.collectionName = collName - - t.Run("show collection err", func(t *testing.T) { - qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { - return nil, fmt.Errorf("mock") - }) - - loaded, err := task.checkIfLoaded(collID, []UniqueID{}) - assert.Error(t, err) - assert.False(t, loaded) - }) - - t.Run("show collection status unexpected error", func(t *testing.T) { - qc.SetShowCollectionsFunc(func(ctx context.Context, request 
*querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { - return &querypb.ShowCollectionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "mock", - }, - }, nil - }) - - loaded, err := task.checkIfLoaded(collID, []UniqueID{}) - assert.Error(t, err) - assert.False(t, loaded) - assert.Error(t, task.PreExecute(ctx)) - qc.ResetShowCollectionsFunc() - }) - - t.Run("show partition error", func(t *testing.T) { - qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { - return &querypb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "mock", - }, - }, nil - }) - loaded, err := task.checkIfLoaded(collID, []UniqueID{1}) - assert.Error(t, err) - assert.False(t, loaded) - }) - - t.Run("show partition status unexpected error", func(t *testing.T) { - qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { - return nil, fmt.Errorf("mock error") - }) - loaded, err := task.checkIfLoaded(collID, []UniqueID{1}) - assert.Error(t, err) - assert.False(t, loaded) - }) - - t.Run("show partitions success", func(t *testing.T) { - qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { - return &querypb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - }, nil - }) - loaded, err := task.checkIfLoaded(collID, []UniqueID{1}) - assert.NoError(t, err) - assert.True(t, loaded) - qc.ResetShowPartitionsFunc() - }) - - t.Run("show collection success but not loaded", func(t *testing.T) { - qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { - return &querypb.ShowCollectionsResponse{ - Status: &commonpb.Status{ - ErrorCode: 
commonpb.ErrorCode_Success, - }, - CollectionIDs: []UniqueID{collID}, - InMemoryPercentages: []int64{0}, - }, nil - }) - - qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { - return nil, fmt.Errorf("mock error") - }) - loaded, err := task.checkIfLoaded(collID, []UniqueID{}) - assert.Error(t, err) - assert.False(t, loaded) - - qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { - return nil, fmt.Errorf("mock error") - }) - loaded, err = task.checkIfLoaded(collID, []UniqueID{}) - assert.Error(t, err) - assert.False(t, loaded) - - qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { - return &querypb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - PartitionIDs: []UniqueID{1}, - }, nil - }) - loaded, err = task.checkIfLoaded(collID, []UniqueID{}) - assert.NoError(t, err) - assert.True(t, loaded) - }) - - qc.ResetShowCollectionsFunc() - qc.ResetShowPartitionsFunc() - }) -} diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 969ac1c5ab..819d56b057 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -202,7 +202,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error { } // check if collection/partitions are loaded into query node - loaded, err := t.checkIfLoaded(collID, t.SearchRequest.GetPartitionIDs()) + loaded, err := checkIfLoaded(ctx, t.qc, collectionName, t.SearchRequest.GetPartitionIDs()) if err != nil { return fmt.Errorf("checkIfLoaded failed when search, collection:%v, partitions:%v, err = %s", collectionName, t.request.GetPartitionNames(), err) } @@ -472,11 +472,11 @@ func (t *searchTask) searchShard(ctx context.Context, leaders []nodeInfo, channe return nil } -func (t *searchTask) checkIfLoaded(collectionID 
UniqueID, searchPartitionIDs []UniqueID) (bool, error) { - // check if collection was loaded into QueryNode - info, err := globalMetaCache.GetCollectionInfo(t.ctx, t.collectionName) +// checkIfLoaded checks whether the collection was loaded into QueryNode +func checkIfLoaded(ctx context.Context, qc types.QueryCoord, collectionName string, searchPartitionIDs []UniqueID) (bool, error) { + info, err := globalMetaCache.GetCollectionInfo(ctx, collectionName) if err != nil { - return false, fmt.Errorf("GetCollectionInfo failed, collectionID = %d, err = %s", collectionID, err) + return false, fmt.Errorf("GetCollectionInfo failed, collection = %s, err = %s", collectionName, err) } if info.isLoaded { return true, nil @@ -484,48 +484,45 @@ func (t *searchTask) checkIfLoaded(collectionID UniqueID, searchPartitionIDs []U // If request to search partitions if len(searchPartitionIDs) > 0 { - resp, err := t.qc.ShowPartitions(t.ctx, &querypb.ShowPartitionsRequest{ + resp, err := qc.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{ Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_ShowCollections, - MsgID: t.Base.MsgID, - Timestamp: t.Base.Timestamp, - SourceID: Params.ProxyCfg.GetNodeID(), + MsgType: commonpb.MsgType_ShowCollections, + SourceID: Params.ProxyCfg.GetNodeID(), }, - CollectionID: collectionID, + CollectionID: info.collID, PartitionIDs: searchPartitionIDs, }) if err != nil { - return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, err = %s", collectionID, searchPartitionIDs, err) + return false, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, err = %s", collectionName, searchPartitionIDs, err) } if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, reason = %s", collectionID, searchPartitionIDs, resp.GetStatus().GetReason()) + return false, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, reason = %s", 
collectionName, searchPartitionIDs, resp.GetStatus().GetReason()) } // Current logic: show partitions won't return error if the given partitions are all loaded return true, nil } // If request to search collection and collection is not fully loaded - resp, err := t.qc.ShowPartitions(t.ctx, &querypb.ShowPartitionsRequest{ + resp, err := qc.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{ Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_ShowCollections, - MsgID: t.Base.MsgID, - Timestamp: t.Base.Timestamp, - SourceID: Params.ProxyCfg.GetNodeID(), + MsgType: commonpb.MsgType_ShowCollections, + SourceID: Params.ProxyCfg.GetNodeID(), }, - CollectionID: collectionID, + CollectionID: info.collID, }) if err != nil { - return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, err = %s", collectionID, searchPartitionIDs, err) + return false, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, err = %s", collectionName, searchPartitionIDs, err) } if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, reason = %s", collectionID, searchPartitionIDs, resp.GetStatus().GetReason()) + return false, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, reason = %s", collectionName, searchPartitionIDs, resp.GetStatus().GetReason()) } if len(resp.GetPartitionIDs()) > 0 { - log.Warn("collection not fully loaded, search on these partitions", zap.Int64("msgID", t.ID()), - zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", resp.GetPartitionIDs())) + log.Warn("collection not fully loaded, search on these partitions", + zap.String("collection", collectionName), + zap.Int64("collectionID", info.collID), zap.Int64s("partitionIDs", resp.GetPartitionIDs())) return true, nil } diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index a1e267181c..b8d4c14905 100644 --- 
a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -196,21 +196,11 @@ func TestSearchTask_PreExecute(t *testing.T) { t.Run("test checkIfLoaded error", func(t *testing.T) { collName := "test_checkIfLoaded_error" + funcutil.GenRandomStr() createColl(t, collName, rc) - collID, err := globalMetaCache.GetCollectionID(context.TODO(), collName) + _, err := globalMetaCache.GetCollectionID(context.TODO(), collName) require.NoError(t, err) task := getSearchTask(t, collName) task.collectionName = collName - t.Run("show collection err", func(t *testing.T) { - qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { - return nil, errors.New("mock") - }) - - loaded, err := task.checkIfLoaded(collID, []UniqueID{}) - assert.Error(t, err) - assert.False(t, loaded) - }) - t.Run("show collection status unexpected error", func(t *testing.T) { qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { return &querypb.ShowCollectionsResponse{ @@ -221,88 +211,10 @@ func TestSearchTask_PreExecute(t *testing.T) { }, nil }) - loaded, err := task.checkIfLoaded(collID, []UniqueID{}) - assert.Error(t, err) - assert.False(t, loaded) assert.Error(t, task.PreExecute(ctx)) qc.ResetShowCollectionsFunc() }) - t.Run("show partition error", func(t *testing.T) { - qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { - return &querypb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "mock", - }, - }, nil - }) - loaded, err := task.checkIfLoaded(collID, []UniqueID{1}) - assert.Error(t, err) - assert.False(t, loaded) - }) - - t.Run("show partition status unexpected error", func(t *testing.T) { - qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) 
(*querypb.ShowPartitionsResponse, error) { - return nil, errors.New("mock error") - }) - loaded, err := task.checkIfLoaded(collID, []UniqueID{1}) - assert.Error(t, err) - assert.False(t, loaded) - }) - - t.Run("show partitions success", func(t *testing.T) { - qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { - return &querypb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - }, nil - }) - loaded, err := task.checkIfLoaded(collID, []UniqueID{1}) - assert.NoError(t, err) - assert.True(t, loaded) - qc.ResetShowPartitionsFunc() - }) - - t.Run("show collection success but not loaded", func(t *testing.T) { - qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { - return &querypb.ShowCollectionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - CollectionIDs: []UniqueID{collID}, - InMemoryPercentages: []int64{0}, - }, nil - }) - - qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { - return nil, errors.New("mock error") - }) - loaded, err := task.checkIfLoaded(collID, []UniqueID{}) - assert.Error(t, err) - assert.False(t, loaded) - - qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { - return nil, errors.New("mock error") - }) - loaded, err = task.checkIfLoaded(collID, []UniqueID{}) - assert.Error(t, err) - assert.False(t, loaded) - - qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { - return &querypb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - PartitionIDs: []UniqueID{1}, - }, nil - }) - loaded, err = task.checkIfLoaded(collID, 
[]UniqueID{}) - assert.NoError(t, err) - assert.True(t, loaded) - }) - qc.ResetShowCollectionsFunc() qc.ResetShowPartitionsFunc() }) @@ -1553,3 +1465,129 @@ func Test_reduceSearchResultData_str(t *testing.T) { // hard to compare floating point value. // TODO: compare scores. } + +func Test_checkIfLoaded(t *testing.T) { + t.Run("failed to get collection info", func(t *testing.T) { + cache := newMockCache() + cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) { + return nil, errors.New("mock") + }) + globalMetaCache = cache + var qc types.QueryCoord + _, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{}) + assert.Error(t, err) + }) + + t.Run("collection loaded", func(t *testing.T) { + cache := newMockCache() + cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) { + return &collectionInfo{isLoaded: true}, nil + }) + globalMetaCache = cache + var qc types.QueryCoord + loaded, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{}) + assert.NoError(t, err) + assert.True(t, loaded) + }) + + t.Run("show partitions failed", func(t *testing.T) { + cache := newMockCache() + cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) { + return &collectionInfo{isLoaded: false}, nil + }) + globalMetaCache = cache + qc := NewQueryCoordMock() + qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { + return nil, errors.New("mock") + }) + _, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{1, 2}) + assert.Error(t, err) + }) + + t.Run("show partitions but didn't success", func(t *testing.T) { + cache := newMockCache() + cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) { + return &collectionInfo{isLoaded: false}, nil + }) + globalMetaCache = cache + qc := NewQueryCoordMock() + 
qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { + return &querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_CollectionNotExists}}, nil + }) + _, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{1, 2}) + assert.Error(t, err) + }) + + t.Run("partitions loaded", func(t *testing.T) { + cache := newMockCache() + cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) { + return &collectionInfo{isLoaded: false}, nil + }) + globalMetaCache = cache + qc := NewQueryCoordMock() + qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { + return &querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil + }) + loaded, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{1, 2}) + assert.NoError(t, err) + assert.True(t, loaded) + }) + + t.Run("no specified partitions, show partitions failed", func(t *testing.T) { + cache := newMockCache() + cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) { + return &collectionInfo{isLoaded: false}, nil + }) + globalMetaCache = cache + qc := NewQueryCoordMock() + qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { + return nil, errors.New("mock") + }) + _, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{}) + assert.Error(t, err) + }) + + t.Run("no specified partitions, show partitions but didn't succeed", func(t *testing.T) { + cache := newMockCache() + cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) { + return &collectionInfo{isLoaded: false}, nil + }) + globalMetaCache = cache + qc := NewQueryCoordMock() + qc.SetShowPartitionsFunc(func(ctx 
context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { + return &querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_CollectionNotExists}}, nil + }) + _, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{}) + assert.Error(t, err) + }) + + t.Run("not fully loaded", func(t *testing.T) { + cache := newMockCache() + cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) { + return &collectionInfo{isLoaded: false}, nil + }) + globalMetaCache = cache + qc := NewQueryCoordMock() + qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { + return &querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, PartitionIDs: []UniqueID{1, 2}}, nil + }) + loaded, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{}) + assert.NoError(t, err) + assert.True(t, loaded) + }) + + t.Run("not loaded", func(t *testing.T) { + cache := newMockCache() + cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) { + return &collectionInfo{isLoaded: false}, nil + }) + globalMetaCache = cache + qc := NewQueryCoordMock() + qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { + return &querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, PartitionIDs: []UniqueID{}}, nil + }) + loaded, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{}) + assert.NoError(t, err) + assert.False(t, loaded) + }) +} diff --git a/internal/querynode/mock_test.go b/internal/querynode/mock_test.go index 65622b2d27..dfbd2f996b 100644 --- a/internal/querynode/mock_test.go +++ b/internal/querynode/mock_test.go @@ -1505,7 +1505,7 @@ func genSimpleRetrievePlan(collection *Collection) 
(*RetrievePlan, error) { } timestamp := retrieveMsg.RetrieveRequest.TravelTimestamp - plan, err2 := createRetrievePlanByExpr(collection, retrieveMsg.SerializedExprPlan, timestamp) + plan, err2 := createRetrievePlanByExpr(collection, retrieveMsg.SerializedExprPlan, timestamp, 100) return plan, err2 } diff --git a/internal/querynode/plan.go b/internal/querynode/plan.go index a89bd32a51..5598e0c859 100644 --- a/internal/querynode/plan.go +++ b/internal/querynode/plan.go @@ -136,7 +136,7 @@ func newSearchRequest(collection *Collection, req *querypb.SearchRequest, placeh plan: plan, cPlaceholderGroup: cPlaceholderGroup, timestamp: req.Req.GetTravelTimestamp(), - msgID: req.GetReq().GetReqID(), + msgID: req.GetReq().GetBase().GetMsgID(), } return ret, nil @@ -175,9 +175,10 @@ func parseSearchRequest(plan *SearchPlan, searchRequestBlob []byte) (*searchRequ type RetrievePlan struct { cRetrievePlan C.CRetrievePlan Timestamp Timestamp + msgID UniqueID // only used to debug. } -func createRetrievePlanByExpr(col *Collection, expr []byte, timestamp Timestamp) (*RetrievePlan, error) { +func createRetrievePlanByExpr(col *Collection, expr []byte, timestamp Timestamp, msgID UniqueID) (*RetrievePlan, error) { var cPlan C.CRetrievePlan status := C.CreateRetrievePlanByExpr(col.collectionPtr, unsafe.Pointer(&expr[0]), (C.int64_t)(len(expr)), &cPlan) @@ -189,6 +190,7 @@ func createRetrievePlanByExpr(col *Collection, expr []byte, timestamp Timestamp) var newPlan = &RetrievePlan{ cRetrievePlan: cPlan, Timestamp: timestamp, + msgID: msgID, } return newPlan, nil } diff --git a/internal/querynode/segment.go b/internal/querynode/segment.go index 7c87ac5a59..b83f31e779 100644 --- a/internal/querynode/segment.go +++ b/internal/querynode/segment.go @@ -303,7 +303,9 @@ func (s *Segment) retrieve(plan *RetrievePlan) (*segcorepb.RetrieveResults, erro status := C.Retrieve(s.segmentPtr, plan.cRetrievePlan, ts, &retrieveResult.cRetrieveResult) 
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) - log.Debug("do retrieve on segment", zap.Int64("segmentID", s.segmentID), zap.String("segmentType", s.segmentType.String())) + log.Debug("do retrieve on segment", + zap.Int64("msgID", plan.msgID), + zap.Int64("segmentID", s.segmentID), zap.String("segmentType", s.segmentType.String())) if err := HandleCStatus(&status, "Retrieve failed"); err != nil { return nil, err } diff --git a/internal/querynode/segment_test.go b/internal/querynode/segment_test.go index 55d520f5d6..a6afa26e78 100644 --- a/internal/querynode/segment_test.go +++ b/internal/querynode/segment_test.go @@ -200,7 +200,7 @@ func TestSegment_retrieve(t *testing.T) { // } planExpr, err := proto.Marshal(planNode) assert.NoError(t, err) - plan, err := createRetrievePlanByExpr(collection, planExpr, 100) + plan, err := createRetrievePlanByExpr(collection, planExpr, 100, 100) defer plan.delete() assert.NoError(t, err) diff --git a/internal/querynode/task_query.go b/internal/querynode/task_query.go index 7c0444c0a4..22f9c77577 100644 --- a/internal/querynode/task_query.go +++ b/internal/querynode/task_query.go @@ -66,12 +66,13 @@ func (q *queryTask) queryOnStreaming() error { q.QS.collection.RLock() // locks the collectionPtr defer q.QS.collection.RUnlock() if _, released := q.QS.collection.getReleaseTime(); released { - log.Debug("collection release before search", zap.Int64("collectionID", q.CollectionID)) + log.Debug("collection release before search", zap.Int64("msgID", q.ID()), + zap.Int64("collectionID", q.CollectionID)) return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", q.CollectionID) } // deserialize query plan - plan, err := createRetrievePlanByExpr(q.QS.collection, q.iReq.GetSerializedExprPlan(), q.TravelTimestamp) + plan, err := createRetrievePlanByExpr(q.QS.collection, 
q.iReq.GetSerializedExprPlan(), q.TravelTimestamp, q.ID()) if err != nil { return err } @@ -113,12 +114,13 @@ func (q *queryTask) queryOnHistorical() error { defer q.QS.collection.RUnlock() if _, released := q.QS.collection.getReleaseTime(); released { - log.Debug("collection release before search", zap.Int64("collectionID", q.CollectionID)) + log.Debug("collection release before search", zap.Int64("msgID", q.ID()), + zap.Int64("collectionID", q.CollectionID)) return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", q.CollectionID) } // deserialize query plan - plan, err := createRetrievePlanByExpr(q.QS.collection, q.iReq.GetSerializedExprPlan(), q.TravelTimestamp) + plan, err := createRetrievePlanByExpr(q.QS.collection, q.iReq.GetSerializedExprPlan(), q.TravelTimestamp, q.ID()) if err != nil { return err }