mirror of https://github.com/milvus-io/milvus.git
Remove primary key duplicated query result on proxy (#10967)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/10624/head
parent
385eed1d17
commit
70188990dc
|
@ -2283,43 +2283,43 @@ func (qt *queryTask) Execute(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func copyQueryResultData(dst *milvuspb.QueryResults, src *internalpb.RetrieveResults) {
|
||||
// handles initialization, cannot use idx==0 since first result may be empty
|
||||
if len(dst.FieldsData) == 0 {
|
||||
dst.FieldsData = append(dst.FieldsData, src.FieldsData...)
|
||||
} else {
|
||||
for i, fieldData := range src.FieldsData {
|
||||
switch fieldType := fieldData.Field.(type) {
|
||||
case *schemapb.FieldData_Scalars:
|
||||
dstScalar := dst.FieldsData[i].GetScalars()
|
||||
switch srcScalar := fieldType.Scalars.Data.(type) {
|
||||
case *schemapb.ScalarField_BoolData:
|
||||
dstScalar.GetBoolData().Data = append(dstScalar.GetBoolData().Data, srcScalar.BoolData.Data...)
|
||||
case *schemapb.ScalarField_IntData:
|
||||
dstScalar.GetIntData().Data = append(dstScalar.GetIntData().Data, srcScalar.IntData.Data...)
|
||||
case *schemapb.ScalarField_LongData:
|
||||
dstScalar.GetLongData().Data = append(dstScalar.GetLongData().Data, srcScalar.LongData.Data...)
|
||||
case *schemapb.ScalarField_FloatData:
|
||||
dstScalar.GetFloatData().Data = append(dstScalar.GetFloatData().Data, srcScalar.FloatData.Data...)
|
||||
case *schemapb.ScalarField_DoubleData:
|
||||
dstScalar.GetDoubleData().Data = append(dstScalar.GetDoubleData().Data, srcScalar.DoubleData.Data...)
|
||||
default:
|
||||
log.Debug("Query received not supported data type", zap.String("field type", fieldData.Type.String()))
|
||||
}
|
||||
case *schemapb.FieldData_Vectors:
|
||||
dstVector := dst.FieldsData[i].GetVectors()
|
||||
switch srcVector := fieldType.Vectors.Data.(type) {
|
||||
case *schemapb.VectorField_BinaryVector:
|
||||
dstVector.Data.(*schemapb.VectorField_BinaryVector).BinaryVector =
|
||||
append(dstVector.Data.(*schemapb.VectorField_BinaryVector).BinaryVector, srcVector.BinaryVector...)
|
||||
case *schemapb.VectorField_FloatVector:
|
||||
dstVector.GetFloatVector().Data = append(dstVector.GetFloatVector().Data, srcVector.FloatVector.Data...)
|
||||
default:
|
||||
log.Debug("Query received not supported data type", zap.String("field type", fieldData.Type.String()))
|
||||
}
|
||||
func mergeRetrieveResults(retrieveResults []*internalpb.RetrieveResults) (*milvuspb.QueryResults, error) {
|
||||
var ret *milvuspb.QueryResults
|
||||
var skipDupCnt int64 = 0
|
||||
var idSet = make(map[int64]struct{})
|
||||
|
||||
// merge results and remove duplicates
|
||||
for _, rr := range retrieveResults {
|
||||
// skip empty result, it will break merge result
|
||||
if rr == nil || rr.Ids == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if ret == nil {
|
||||
ret = &milvuspb.QueryResults{
|
||||
FieldsData: make([]*schemapb.FieldData, len(rr.FieldsData)),
|
||||
}
|
||||
}
|
||||
|
||||
if len(ret.FieldsData) != len(rr.FieldsData) {
|
||||
return nil, fmt.Errorf("mismatch FieldData in RetrieveResults")
|
||||
}
|
||||
|
||||
for i, id := range rr.Ids.GetIntId().GetData() {
|
||||
if _, ok := idSet[id]; !ok {
|
||||
typeutil.AppendFieldData(ret.FieldsData, rr.FieldsData, int64(i))
|
||||
idSet[id] = struct{}{}
|
||||
} else {
|
||||
// primary keys duplicate
|
||||
skipDupCnt++
|
||||
}
|
||||
}
|
||||
}
|
||||
if skipDupCnt > 0 {
|
||||
log.Debug("skip duplicated query result", zap.Int64("count", skipDupCnt))
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (qt *queryTask) PostExecute(ctx context.Context) error {
|
||||
|
@ -2354,31 +2354,21 @@ func (qt *queryTask) PostExecute(ctx context.Context) error {
|
|||
return errors.New(reason)
|
||||
}
|
||||
|
||||
qt.result = &milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{
|
||||
var err error
|
||||
qt.result, err = mergeRetrieveResults(filterRetrieveResults)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(qt.result.FieldsData) > 0 {
|
||||
qt.result.Status = &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
FieldsData: make([]*schemapb.FieldData, 0),
|
||||
}
|
||||
|
||||
validRetrieveResults := make([]*internalpb.RetrieveResults, 0)
|
||||
for _, partialRetrieveResult := range filterRetrieveResults {
|
||||
if partialRetrieveResult.Ids != nil {
|
||||
validRetrieveResults = append(validRetrieveResults, partialRetrieveResult)
|
||||
}
|
||||
}
|
||||
|
||||
for _, partialRetrieveResult := range validRetrieveResults {
|
||||
copyQueryResultData(qt.result, partialRetrieveResult)
|
||||
}
|
||||
|
||||
if len(qt.result.FieldsData) == 0 {
|
||||
} else {
|
||||
log.Info("Query result is nil", zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
|
||||
qt.result = &milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_EmptyCollection,
|
||||
Reason: reason,
|
||||
},
|
||||
qt.result.Status = &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_EmptyCollection,
|
||||
Reason: reason,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue