diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 26090d5e59..287106df96 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -2283,43 +2283,43 @@ func (qt *queryTask) Execute(ctx context.Context) error { return err } -func copyQueryResultData(dst *milvuspb.QueryResults, src *internalpb.RetrieveResults) { - // handles initialization, cannot use idx==0 since first result may be empty - if len(dst.FieldsData) == 0 { - dst.FieldsData = append(dst.FieldsData, src.FieldsData...) - } else { - for i, fieldData := range src.FieldsData { - switch fieldType := fieldData.Field.(type) { - case *schemapb.FieldData_Scalars: - dstScalar := dst.FieldsData[i].GetScalars() - switch srcScalar := fieldType.Scalars.Data.(type) { - case *schemapb.ScalarField_BoolData: - dstScalar.GetBoolData().Data = append(dstScalar.GetBoolData().Data, srcScalar.BoolData.Data...) - case *schemapb.ScalarField_IntData: - dstScalar.GetIntData().Data = append(dstScalar.GetIntData().Data, srcScalar.IntData.Data...) - case *schemapb.ScalarField_LongData: - dstScalar.GetLongData().Data = append(dstScalar.GetLongData().Data, srcScalar.LongData.Data...) - case *schemapb.ScalarField_FloatData: - dstScalar.GetFloatData().Data = append(dstScalar.GetFloatData().Data, srcScalar.FloatData.Data...) - case *schemapb.ScalarField_DoubleData: - dstScalar.GetDoubleData().Data = append(dstScalar.GetDoubleData().Data, srcScalar.DoubleData.Data...) - default: - log.Debug("Query received not supported data type", zap.String("field type", fieldData.Type.String())) - } - case *schemapb.FieldData_Vectors: - dstVector := dst.FieldsData[i].GetVectors() - switch srcVector := fieldType.Vectors.Data.(type) { - case *schemapb.VectorField_BinaryVector: - dstVector.Data.(*schemapb.VectorField_BinaryVector).BinaryVector = - append(dstVector.Data.(*schemapb.VectorField_BinaryVector).BinaryVector, srcVector.BinaryVector...) - case *schemapb.VectorField_FloatVector: - dstVector.GetFloatVector().Data = append(dstVector.GetFloatVector().Data, srcVector.FloatVector.Data...) - default: - log.Debug("Query received not supported data type", zap.String("field type", fieldData.Type.String())) - } +func mergeRetrieveResults(retrieveResults []*internalpb.RetrieveResults) (*milvuspb.QueryResults, error) { + var ret *milvuspb.QueryResults + var skipDupCnt int64 = 0 + var idSet = make(map[int64]struct{}) + + // merge results and remove duplicates + for _, rr := range retrieveResults { + // skip empty result, it will break merge result + if rr == nil || rr.Ids == nil { + continue + } + + if ret == nil { + ret = &milvuspb.QueryResults{ + FieldsData: make([]*schemapb.FieldData, len(rr.FieldsData)), + } + } + + if len(ret.FieldsData) != len(rr.FieldsData) { + return nil, fmt.Errorf("mismatch FieldData in RetrieveResults") + } + + for i, id := range rr.Ids.GetIntId().GetData() { + if _, ok := idSet[id]; !ok { + typeutil.AppendFieldData(ret.FieldsData, rr.FieldsData, int64(i)) + idSet[id] = struct{}{} + } else { + // primary keys duplicate + skipDupCnt++ } } } + if skipDupCnt > 0 { + log.Debug("skip duplicated query result", zap.Int64("count", skipDupCnt)) + } + + return ret, nil } func (qt *queryTask) PostExecute(ctx context.Context) error { @@ -2354,31 +2354,21 @@ func (qt *queryTask) PostExecute(ctx context.Context) error { return errors.New(reason) } - qt.result = &milvuspb.QueryResults{ - Status: &commonpb.Status{ + var err error + qt.result, err = mergeRetrieveResults(filterRetrieveResults) + if err != nil { + return err + } + + if len(qt.result.FieldsData) > 0 { + qt.result.Status = &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, - }, - FieldsData: make([]*schemapb.FieldData, 0), - } - - validRetrieveResults := make([]*internalpb.RetrieveResults, 0) - for _, partialRetrieveResult := range filterRetrieveResults { - if partialRetrieveResult.Ids != nil { - validRetrieveResults = append(validRetrieveResults, partialRetrieveResult) } - } - - for _, partialRetrieveResult := range validRetrieveResults { - copyQueryResultData(qt.result, partialRetrieveResult) - } - - if len(qt.result.FieldsData) == 0 { + } else { log.Info("Query result is nil", zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query")) - qt.result = &milvuspb.QueryResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_EmptyCollection, - Reason: reason, - }, + qt.result.Status = &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_EmptyCollection, + Reason: reason, } return nil }