diff --git a/internal/proxy/task.go b/internal/proxy/task.go index f27597129b..bb8a2f2179 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -1753,124 +1753,6 @@ func selectSearchResultData(dataArray []*schemapb.SearchResultData, offsets []in return sel } -func copySearchResultData(dst *schemapb.SearchResultData, src *schemapb.SearchResultData, idx int64) { - for i, fieldData := range src.FieldsData { - switch fieldType := fieldData.Field.(type) { - case *schemapb.FieldData_Scalars: - if dst.FieldsData[i] == nil || dst.FieldsData[i].GetScalars() == nil { - dst.FieldsData[i] = &schemapb.FieldData{ - FieldName: fieldData.FieldName, - FieldId: fieldData.FieldId, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{}, - }, - } - } - switch scalarType := fieldType.Scalars.Data.(type) { - case *schemapb.ScalarField_BoolData: - if dst.FieldsData[i].GetScalars().GetBoolData() == nil { - dst.FieldsData[i].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{ - Data: &schemapb.ScalarField_BoolData{ - BoolData: &schemapb.BoolArray{ - Data: []bool{scalarType.BoolData.Data[idx]}, - }, - }, - } - } else { - dst.FieldsData[i].GetScalars().GetBoolData().Data = append(dst.FieldsData[i].GetScalars().GetBoolData().Data, scalarType.BoolData.Data[idx]) - } - case *schemapb.ScalarField_IntData: - if dst.FieldsData[i].GetScalars().GetIntData() == nil { - dst.FieldsData[i].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{ - Data: &schemapb.ScalarField_IntData{ - IntData: &schemapb.IntArray{ - Data: []int32{scalarType.IntData.Data[idx]}, - }, - }, - } - } else { - dst.FieldsData[i].GetScalars().GetIntData().Data = append(dst.FieldsData[i].GetScalars().GetIntData().Data, scalarType.IntData.Data[idx]) - } - case *schemapb.ScalarField_LongData: - if dst.FieldsData[i].GetScalars().GetLongData() == nil { - dst.FieldsData[i].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{ - Data: &schemapb.ScalarField_LongData{ - LongData: &schemapb.LongArray{ - Data: []int64{scalarType.LongData.Data[idx]}, - }, - }, - } - } else { - dst.FieldsData[i].GetScalars().GetLongData().Data = append(dst.FieldsData[i].GetScalars().GetLongData().Data, scalarType.LongData.Data[idx]) - } - case *schemapb.ScalarField_FloatData: - if dst.FieldsData[i].GetScalars().GetFloatData() == nil { - dst.FieldsData[i].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{ - Data: &schemapb.ScalarField_FloatData{ - FloatData: &schemapb.FloatArray{ - Data: []float32{scalarType.FloatData.Data[idx]}, - }, - }, - } - } else { - dst.FieldsData[i].GetScalars().GetFloatData().Data = append(dst.FieldsData[i].GetScalars().GetFloatData().Data, scalarType.FloatData.Data[idx]) - } - case *schemapb.ScalarField_DoubleData: - if dst.FieldsData[i].GetScalars().GetDoubleData() == nil { - dst.FieldsData[i].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{ - Data: &schemapb.ScalarField_DoubleData{ - DoubleData: &schemapb.DoubleArray{ - Data: []float64{scalarType.DoubleData.Data[idx]}, - }, - }, - } - } else { - dst.FieldsData[i].GetScalars().GetDoubleData().Data = append(dst.FieldsData[i].GetScalars().GetDoubleData().Data, scalarType.DoubleData.Data[idx]) - } - default: - log.Debug("Not supported field type", zap.String("field type", fieldData.Type.String())) - } - case *schemapb.FieldData_Vectors: - dim := fieldType.Vectors.Dim - if dst.FieldsData[i] == nil || dst.FieldsData[i].GetVectors() == nil { - dst.FieldsData[i] = &schemapb.FieldData{ - FieldName: fieldData.FieldName, - FieldId: fieldData.FieldId, - Field: &schemapb.FieldData_Vectors{ - Vectors: &schemapb.VectorField{ - Dim: dim, - }, - }, - } - } - switch vectorType := fieldType.Vectors.Data.(type) { - case *schemapb.VectorField_BinaryVector: - if dst.FieldsData[i].GetVectors().GetBinaryVector() == nil { - bvec := &schemapb.VectorField_BinaryVector{ - BinaryVector: vectorType.BinaryVector[idx*(dim/8) : (idx+1)*(dim/8)], - } - dst.FieldsData[i].GetVectors().Data = bvec - } else { - dst.FieldsData[i].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector = append(dst.FieldsData[i].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector, vectorType.BinaryVector[idx*(dim/8):(idx+1)*(dim/8)]...) - } - case *schemapb.VectorField_FloatVector: - if dst.FieldsData[i].GetVectors().GetFloatVector() == nil { - fvec := &schemapb.VectorField_FloatVector{ - FloatVector: &schemapb.FloatArray{ - Data: vectorType.FloatVector.Data[idx*dim : (idx+1)*dim], - }, - } - dst.FieldsData[i].GetVectors().Data = fvec - } else { - dst.FieldsData[i].GetVectors().GetFloatVector().Data = append(dst.FieldsData[i].GetVectors().GetFloatVector().Data, vectorType.FloatVector.Data[idx*dim:(idx+1)*dim]...) - } - default: - log.Debug("Not supported field type", zap.String("field type", fieldData.Type.String())) - } - } - } -} - //func printSearchResultData(data *schemapb.SearchResultData, header string) { // size := len(data.Ids.GetIntId().Data) // if size != len(data.Scores) { @@ -1949,7 +1831,7 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in // remove duplicates if math.Abs(float64(score)-float64(prevScore)) > 0.00001 { - copySearchResultData(ret.Results, searchResultData[sel], idx) + typeutil.AppendFieldData(ret.Results.FieldsData, searchResultData[sel].FieldsData, idx) ret.Results.Ids.GetIntId().Data = append(ret.Results.Ids.GetIntId().Data, id) ret.Results.Scores = append(ret.Results.Scores, score) prevScore = score @@ -1961,7 +1843,7 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in // e2: [101, 0.99] ==> not duplicated, should keep // e3: [100, 0.99] ==> duplicated, should remove if _, ok := prevIDSet[id]; !ok { - copySearchResultData(ret.Results, searchResultData[sel], idx) + typeutil.AppendFieldData(ret.Results.FieldsData, searchResultData[sel].FieldsData, idx) ret.Results.Ids.GetIntId().Data = append(ret.Results.Ids.GetIntId().Data, id) ret.Results.Scores = append(ret.Results.Scores, score) prevIDSet[id] = struct{}{} diff --git a/internal/util/typeutil/schema.go b/internal/util/typeutil/schema.go index 0f40bbe810..426e669292 100644 --- a/internal/util/typeutil/schema.go +++ b/internal/util/typeutil/schema.go @@ -16,6 +16,9 @@ import ( "fmt" "strconv" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/schemapb" ) @@ -181,3 +184,113 @@ func IsBoolType(dataType schemapb.DataType) bool { return false } } + +// AppendFieldData appends fields data of specified index from src to dst +func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx int64) { + for i, fieldData := range src { + switch fieldType := fieldData.Field.(type) { + case *schemapb.FieldData_Scalars: + if dst[i] == nil || dst[i].GetScalars() == nil { + dst[i] = &schemapb.FieldData{ + FieldName: fieldData.FieldName, + FieldId: fieldData.FieldId, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{}, + }, + } + } + dstScalar := dst[i].GetScalars() + switch srcScalar := fieldType.Scalars.Data.(type) { + case *schemapb.ScalarField_BoolData: + if dstScalar.GetBoolData() == nil { + dstScalar.Data = &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: []bool{srcScalar.BoolData.Data[idx]}, + }, + } + } else { + dstScalar.GetBoolData().Data = append(dstScalar.GetBoolData().Data, srcScalar.BoolData.Data[idx]) + } + case *schemapb.ScalarField_IntData: + if dstScalar.GetIntData() == nil { + dstScalar.Data = &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{srcScalar.IntData.Data[idx]}, + }, + } + } else { + dstScalar.GetIntData().Data = append(dstScalar.GetIntData().Data, srcScalar.IntData.Data[idx]) + } + case *schemapb.ScalarField_LongData: + if dstScalar.GetLongData() == nil { + dstScalar.Data = &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{srcScalar.LongData.Data[idx]}, + }, + } + } else { + dstScalar.GetLongData().Data = append(dstScalar.GetLongData().Data, srcScalar.LongData.Data[idx]) + } + case *schemapb.ScalarField_FloatData: + if dstScalar.GetFloatData() == nil { + dstScalar.Data = &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: []float32{srcScalar.FloatData.Data[idx]}, + }, + } + } else { + dstScalar.GetFloatData().Data = append(dstScalar.GetFloatData().Data, srcScalar.FloatData.Data[idx]) + } + case *schemapb.ScalarField_DoubleData: + if dstScalar.GetDoubleData() == nil { + dstScalar.Data = &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: []float64{srcScalar.DoubleData.Data[idx]}, + }, + } + } else { + dstScalar.GetDoubleData().Data = append(dstScalar.GetDoubleData().Data, srcScalar.DoubleData.Data[idx]) + } + default: + log.Error("Not supported field type", zap.String("field type", fieldData.Type.String())) + } + case *schemapb.FieldData_Vectors: + dim := fieldType.Vectors.Dim + if dst[i] == nil || dst[i].GetVectors() == nil { + dst[i] = &schemapb.FieldData{ + FieldName: fieldData.FieldName, + FieldId: fieldData.FieldId, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + }, + }, + } + } + dstVector := dst[i].GetVectors() + switch srcVector := fieldType.Vectors.Data.(type) { + case *schemapb.VectorField_BinaryVector: + if dstVector.GetBinaryVector() == nil { + dstVector.Data = &schemapb.VectorField_BinaryVector{ + BinaryVector: srcVector.BinaryVector[idx*(dim/8) : (idx+1)*(dim/8)], + } + } else { + dstBinaryVector := dstVector.Data.(*schemapb.VectorField_BinaryVector) + dstBinaryVector.BinaryVector = append(dstBinaryVector.BinaryVector, srcVector.BinaryVector[idx*(dim/8):(idx+1)*(dim/8)]...) + } + case *schemapb.VectorField_FloatVector: + if dstVector.GetFloatVector() == nil { + dstVector.Data = &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: srcVector.FloatVector.Data[idx*dim : (idx+1)*dim], + }, + } + } else { + dstVector.GetFloatVector().Data = append(dstVector.GetFloatVector().Data, srcVector.FloatVector.Data[idx*dim:(idx+1)*dim]...) + } + default: + log.Error("Not supported field type", zap.String("field type", fieldData.Type.String())) + } + } + } +} diff --git a/internal/util/typeutil/schema_test.go b/internal/util/typeutil/schema_test.go index deec971d24..0ae707576d 100644 --- a/internal/util/typeutil/schema_test.go +++ b/internal/util/typeutil/schema_test.go @@ -14,6 +14,10 @@ package typeutil import ( "testing" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/common" + "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/stretchr/testify/assert" @@ -324,3 +328,174 @@ func TestSchema_invalid(t *testing.T) { assert.NotNil(t, err) }) } + +func genFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType, fieldValue interface{}, dim int64) *schemapb.FieldData { + var fieldData *schemapb.FieldData + switch fieldType { + case schemapb.DataType_Bool: + fieldData = &schemapb.FieldData{ + Type: schemapb.DataType_Bool, + FieldName: fieldName, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: fieldValue.([]bool), + }, + }, + }, + }, + FieldId: fieldID, + } + case schemapb.DataType_Int32: + fieldData = &schemapb.FieldData{ + Type: schemapb.DataType_Int32, + FieldName: fieldName, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: fieldValue.([]int32), + }, + }, + }, + }, + FieldId: fieldID, + } + case schemapb.DataType_Int64: + fieldData = &schemapb.FieldData{ + Type: schemapb.DataType_Int64, + FieldName: fieldName, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: fieldValue.([]int64), + }, + }, + }, + }, + FieldId: fieldID, + } + case schemapb.DataType_Float: + fieldData = &schemapb.FieldData{ + Type: schemapb.DataType_Float, + FieldName: fieldName, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: fieldValue.([]float32), + }, + }, + }, + }, + FieldId: fieldID, + } + case schemapb.DataType_Double: + fieldData = &schemapb.FieldData{ + Type: schemapb.DataType_Double, + FieldName: fieldName, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: fieldValue.([]float64), + }, + }, + }, + }, + FieldId: fieldID, + } + case schemapb.DataType_BinaryVector: + fieldData = &schemapb.FieldData{ + Type: schemapb.DataType_BinaryVector, + FieldName: fieldName, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: fieldValue.([]byte), + }, + }, + }, + FieldId: fieldID, + } + case schemapb.DataType_FloatVector: + fieldData = &schemapb.FieldData{ + Type: schemapb.DataType_FloatVector, + FieldName: fieldName, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: fieldValue.([]float32), + }, + }, + }, + }, + FieldId: fieldID, + } + default: + log.Error("not supported field type", zap.String("field type", fieldType.String())) + } + + return fieldData +} + +func TestAppendFieldData(t *testing.T) { + const ( + BoolFieldName = "BoolField" + Int32FieldName = "Int32Field" + Int64FieldName = "Int64Field" + FloatFieldName = "FloatField" + DoubleFieldName = "DoubleField" + BinaryVectorFieldName = "BinaryVectorField" + FloatVectorFieldName = "FloatVectorField" + BoolFieldID = common.StartOfUserFieldID + 1 + Int32FieldID = common.StartOfUserFieldID + 2 + Int64FieldID = common.StartOfUserFieldID + 3 + FloatFieldID = common.StartOfUserFieldID + 4 + DoubleFieldID = common.StartOfUserFieldID + 5 + BinaryVectorFieldID = common.StartOfUserFieldID + 6 + FloatVectorFieldID = common.StartOfUserFieldID + 7 + ) + BoolArray := []bool{true, false} + Int32Array := []int32{1, 2} + Int64Array := []int64{11, 22} + FloatArray := []float32{1.0, 2.0} + DoubleArray := []float64{11.0, 22.0} + BinaryVector := []byte{0x12, 0x34} + FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0} + + result := make([]*schemapb.FieldData, 7) + var fieldDataArray1 []*schemapb.FieldData + fieldDataArray1 = append(fieldDataArray1, genFieldData(BoolFieldName, BoolFieldID, schemapb.DataType_Bool, BoolArray[0:1], 1)) + fieldDataArray1 = append(fieldDataArray1, genFieldData(Int32FieldName, Int32FieldID, schemapb.DataType_Int32, Int32Array[0:1], 1)) + fieldDataArray1 = append(fieldDataArray1, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:1], 1)) + fieldDataArray1 = append(fieldDataArray1, genFieldData(FloatFieldName, FloatFieldID, schemapb.DataType_Float, FloatArray[0:1], 1)) + fieldDataArray1 = append(fieldDataArray1, genFieldData(DoubleFieldName, DoubleFieldID, schemapb.DataType_Double, DoubleArray[0:1], 1)) + fieldDataArray1 = append(fieldDataArray1, genFieldData(BinaryVectorFieldName, BinaryVectorFieldID, schemapb.DataType_BinaryVector, BinaryVector[0:1], 8)) + fieldDataArray1 = append(fieldDataArray1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:8], 8)) + + var fieldDataArray2 []*schemapb.FieldData + fieldDataArray2 = append(fieldDataArray2, genFieldData(BoolFieldName, BoolFieldID, schemapb.DataType_Bool, BoolArray[1:2], 1)) + fieldDataArray2 = append(fieldDataArray2, genFieldData(Int32FieldName, Int32FieldID, schemapb.DataType_Int32, Int32Array[1:2], 1)) + fieldDataArray2 = append(fieldDataArray2, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[1:2], 1)) + fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatFieldName, FloatFieldID, schemapb.DataType_Float, FloatArray[1:2], 1)) + fieldDataArray2 = append(fieldDataArray2, genFieldData(DoubleFieldName, DoubleFieldID, schemapb.DataType_Double, DoubleArray[1:2], 1)) + fieldDataArray2 = append(fieldDataArray2, genFieldData(BinaryVectorFieldName, BinaryVectorFieldID, schemapb.DataType_BinaryVector, BinaryVector[1:2], 8)) + fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[8:16], 8)) + + AppendFieldData(result, fieldDataArray1, 0) + AppendFieldData(result, fieldDataArray2, 0) + + assert.Equal(t, BoolArray, result[0].GetScalars().GetBoolData().Data) + assert.Equal(t, Int32Array, result[1].GetScalars().GetIntData().Data) + assert.Equal(t, Int64Array, result[2].GetScalars().GetLongData().Data) + assert.Equal(t, FloatArray, result[3].GetScalars().GetFloatData().Data) + assert.Equal(t, DoubleArray, result[4].GetScalars().GetDoubleData().Data) + assert.Equal(t, BinaryVector, result[5].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector) + assert.Equal(t, FloatVector, result[6].GetVectors().GetFloatVector().Data) +}