mirror of https://github.com/milvus-io/milvus.git
Optimize proxy reduce search result data (#10327)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/10344/head
parent
86655f2221
commit
b099179ac0
|
@ -1736,6 +1736,163 @@ func decodeSearchResults(searchResults []*internalpb.SearchResults) (res []*sche
|
|||
return
|
||||
}
|
||||
|
||||
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64) error {
|
||||
if data.NumQueries != nq {
|
||||
return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
|
||||
}
|
||||
if data.TopK != topk {
|
||||
return fmt.Errorf("search result's topk(%d) mis-match with %d", data.TopK, topk)
|
||||
}
|
||||
if len(data.Ids.GetIntId().Data) != (int)(nq*topk) {
|
||||
return fmt.Errorf("search result's id length %d invalid", len(data.Ids.GetIntId().Data))
|
||||
}
|
||||
if len(data.Scores) != (int)(nq*topk) {
|
||||
return fmt.Errorf("search result's score length %d invalid", len(data.Scores))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func selectSearchResultData(dataArray []*schemapb.SearchResultData, offsets []int64, topk int64, idx int64) int {
|
||||
sel := -1
|
||||
maxDistance := minFloat32
|
||||
for q, loc := range offsets { // query num, the number of ways to merge
|
||||
if loc >= topk {
|
||||
continue
|
||||
}
|
||||
offset := idx*topk + loc
|
||||
id := dataArray[q].Ids.GetIntId().Data[offset]
|
||||
if id != -1 {
|
||||
distance := dataArray[q].Scores[offset]
|
||||
if distance > maxDistance {
|
||||
sel = q
|
||||
maxDistance = distance
|
||||
}
|
||||
}
|
||||
}
|
||||
return sel
|
||||
}
|
||||
|
||||
func copySearchResultData(dst *schemapb.SearchResultData, src *schemapb.SearchResultData, idx int64) error {
|
||||
for i, fieldData := range src.FieldsData {
|
||||
switch fieldType := fieldData.Field.(type) {
|
||||
case *schemapb.FieldData_Scalars:
|
||||
if dst.FieldsData[i] == nil || dst.FieldsData[i].GetScalars() == nil {
|
||||
dst.FieldsData[i] = &schemapb.FieldData{
|
||||
FieldName: fieldData.FieldName,
|
||||
FieldId: fieldData.FieldId,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{},
|
||||
},
|
||||
}
|
||||
}
|
||||
switch scalarType := fieldType.Scalars.Data.(type) {
|
||||
case *schemapb.ScalarField_BoolData:
|
||||
if dst.FieldsData[i].GetScalars().GetBoolData() == nil {
|
||||
dst.FieldsData[i].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: []bool{scalarType.BoolData.Data[idx]},
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
dst.FieldsData[i].GetScalars().GetBoolData().Data = append(dst.FieldsData[i].GetScalars().GetBoolData().Data, scalarType.BoolData.Data[idx])
|
||||
}
|
||||
case *schemapb.ScalarField_IntData:
|
||||
if dst.FieldsData[i].GetScalars().GetIntData() == nil {
|
||||
dst.FieldsData[i].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: []int32{scalarType.IntData.Data[idx]},
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
dst.FieldsData[i].GetScalars().GetIntData().Data = append(dst.FieldsData[i].GetScalars().GetIntData().Data, scalarType.IntData.Data[idx])
|
||||
}
|
||||
case *schemapb.ScalarField_LongData:
|
||||
if dst.FieldsData[i].GetScalars().GetLongData() == nil {
|
||||
dst.FieldsData[i].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{scalarType.LongData.Data[idx]},
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
dst.FieldsData[i].GetScalars().GetLongData().Data = append(dst.FieldsData[i].GetScalars().GetLongData().Data, scalarType.LongData.Data[idx])
|
||||
}
|
||||
case *schemapb.ScalarField_FloatData:
|
||||
if dst.FieldsData[i].GetScalars().GetFloatData() == nil {
|
||||
dst.FieldsData[i].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_FloatData{
|
||||
FloatData: &schemapb.FloatArray{
|
||||
Data: []float32{scalarType.FloatData.Data[idx]},
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
dst.FieldsData[i].GetScalars().GetFloatData().Data = append(dst.FieldsData[i].GetScalars().GetFloatData().Data, scalarType.FloatData.Data[idx])
|
||||
}
|
||||
case *schemapb.ScalarField_DoubleData:
|
||||
if dst.FieldsData[i].GetScalars().GetDoubleData() == nil {
|
||||
dst.FieldsData[i].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_DoubleData{
|
||||
DoubleData: &schemapb.DoubleArray{
|
||||
Data: []float64{scalarType.DoubleData.Data[idx]},
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
dst.FieldsData[i].GetScalars().GetDoubleData().Data = append(dst.FieldsData[i].GetScalars().GetDoubleData().Data, scalarType.DoubleData.Data[idx])
|
||||
}
|
||||
default:
|
||||
log.Debug("Not supported field type", zap.String("field type", fieldData.Type.String()))
|
||||
return fmt.Errorf("not supported field type: %s", fieldData.Type.String())
|
||||
}
|
||||
case *schemapb.FieldData_Vectors:
|
||||
dim := fieldType.Vectors.Dim
|
||||
if dst.FieldsData[i] == nil || dst.FieldsData[i].GetVectors() == nil {
|
||||
dst.FieldsData[i] = &schemapb.FieldData{
|
||||
FieldName: fieldData.FieldName,
|
||||
FieldId: fieldData.FieldId,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: dim,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
switch vectorType := fieldType.Vectors.Data.(type) {
|
||||
case *schemapb.VectorField_BinaryVector:
|
||||
if dst.FieldsData[i].GetVectors().GetBinaryVector() == nil {
|
||||
bvec := &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: vectorType.BinaryVector[idx*(dim/8) : (idx+1)*(dim/8)],
|
||||
}
|
||||
dst.FieldsData[i].GetVectors().Data = bvec
|
||||
} else {
|
||||
dst.FieldsData[i].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector = append(dst.FieldsData[i].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector, vectorType.BinaryVector[idx*(dim/8):(idx+1)*(dim/8)]...)
|
||||
}
|
||||
case *schemapb.VectorField_FloatVector:
|
||||
if dst.FieldsData[i].GetVectors().GetFloatVector() == nil {
|
||||
fvec := &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: vectorType.FloatVector.Data[idx*dim : (idx+1)*dim],
|
||||
},
|
||||
}
|
||||
dst.FieldsData[i].GetVectors().Data = fvec
|
||||
} else {
|
||||
dst.FieldsData[i].GetVectors().GetFloatVector().Data = append(dst.FieldsData[i].GetVectors().GetFloatVector().Data, vectorType.FloatVector.Data[idx*dim:(idx+1)*dim]...)
|
||||
}
|
||||
default:
|
||||
log.Debug("Not supported field type", zap.String("field type", fieldData.Type.String()))
|
||||
return fmt.Errorf("not supported field type: %s", fieldData.Type.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64,
|
||||
nq int64, topk int64, metricType string, maxParallel int) (*milvuspb.SearchResults, error) {
|
||||
|
||||
|
@ -1771,173 +1928,34 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
|
|||
zap.Int64("nq", sData.NumQueries),
|
||||
zap.Int64("topk", sData.TopK),
|
||||
zap.Any("len(FieldsData)", len(sData.FieldsData)))
|
||||
if sData.NumQueries != nq {
|
||||
return ret, fmt.Errorf("search result's nq(%d) mis-match with %d", sData.NumQueries, nq)
|
||||
}
|
||||
if sData.TopK != topk {
|
||||
return ret, fmt.Errorf("search result's topk(%d) mis-match with %d", sData.TopK, topk)
|
||||
}
|
||||
if len(sData.Ids.GetIntId().Data) != (int)(nq*topk) {
|
||||
return ret, fmt.Errorf("search result's id length %d invalid", len(sData.Ids.GetIntId().Data))
|
||||
}
|
||||
if len(sData.Scores) != (int)(nq*topk) {
|
||||
return ret, fmt.Errorf("search result's score length %d invalid", len(sData.Scores))
|
||||
if err := checkSearchResultData(sData, nq, topk); err != nil {
|
||||
return ret, err
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(yukun): Use parallel function
|
||||
var realTopK int64 = -1
|
||||
var idx int64
|
||||
var j int64
|
||||
for idx = 0; idx < nq; idx++ {
|
||||
locs := make([]int64, availableQueryNodeNum)
|
||||
for i := int64(0); i < nq; i++ {
|
||||
offsets := make([]int64, availableQueryNodeNum)
|
||||
|
||||
j = 0
|
||||
for ; j < topk; j++ {
|
||||
choice, maxDistance := -1, minFloat32
|
||||
for q, loc := range locs { // query num, the number of ways to merge
|
||||
if loc >= topk {
|
||||
continue
|
||||
}
|
||||
curIdx := idx*topk + loc
|
||||
id := searchResultData[q].Ids.GetIntId().Data[curIdx]
|
||||
if id != -1 {
|
||||
distance := searchResultData[q].Scores[curIdx]
|
||||
if distance > maxDistance {
|
||||
choice = q
|
||||
maxDistance = distance
|
||||
}
|
||||
}
|
||||
}
|
||||
if choice == -1 {
|
||||
var j int64
|
||||
for j = 0; j < topk; j++ {
|
||||
sel := selectSearchResultData(searchResultData, offsets, topk, i)
|
||||
if sel == -1 {
|
||||
break
|
||||
}
|
||||
choiceOffset := locs[choice]
|
||||
curIdx := idx*topk + choiceOffset
|
||||
offset := offsets[sel]
|
||||
idx := i*topk + offset
|
||||
|
||||
// ignore invalid search result
|
||||
id := searchResultData[choice].Ids.GetIntId().Data[curIdx]
|
||||
id := searchResultData[sel].Ids.GetIntId().Data[idx]
|
||||
if id == -1 {
|
||||
continue
|
||||
}
|
||||
copySearchResultData(ret.Results, searchResultData[sel], idx)
|
||||
ret.Results.Ids.GetIntId().Data = append(ret.Results.Ids.GetIntId().Data, id)
|
||||
// TODO(yukun): Process searchResultData.FieldsData
|
||||
for k, fieldData := range searchResultData[choice].FieldsData {
|
||||
switch fieldType := fieldData.Field.(type) {
|
||||
case *schemapb.FieldData_Scalars:
|
||||
if ret.Results.FieldsData[k] == nil || ret.Results.FieldsData[k].GetScalars() == nil {
|
||||
ret.Results.FieldsData[k] = &schemapb.FieldData{
|
||||
FieldName: fieldData.FieldName,
|
||||
FieldId: fieldData.FieldId,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{},
|
||||
},
|
||||
}
|
||||
}
|
||||
switch scalarType := fieldType.Scalars.Data.(type) {
|
||||
case *schemapb.ScalarField_BoolData:
|
||||
if ret.Results.FieldsData[k].GetScalars().GetBoolData() == nil {
|
||||
ret.Results.FieldsData[k].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: []bool{scalarType.BoolData.Data[curIdx]},
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
ret.Results.FieldsData[k].GetScalars().GetBoolData().Data = append(ret.Results.FieldsData[k].GetScalars().GetBoolData().Data, scalarType.BoolData.Data[curIdx])
|
||||
}
|
||||
case *schemapb.ScalarField_IntData:
|
||||
if ret.Results.FieldsData[k].GetScalars().GetIntData() == nil {
|
||||
ret.Results.FieldsData[k].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: []int32{scalarType.IntData.Data[curIdx]},
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
ret.Results.FieldsData[k].GetScalars().GetIntData().Data = append(ret.Results.FieldsData[k].GetScalars().GetIntData().Data, scalarType.IntData.Data[curIdx])
|
||||
}
|
||||
case *schemapb.ScalarField_LongData:
|
||||
if ret.Results.FieldsData[k].GetScalars().GetLongData() == nil {
|
||||
ret.Results.FieldsData[k].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{scalarType.LongData.Data[curIdx]},
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
ret.Results.FieldsData[k].GetScalars().GetLongData().Data = append(ret.Results.FieldsData[k].GetScalars().GetLongData().Data, scalarType.LongData.Data[curIdx])
|
||||
}
|
||||
case *schemapb.ScalarField_FloatData:
|
||||
if ret.Results.FieldsData[k].GetScalars().GetFloatData() == nil {
|
||||
ret.Results.FieldsData[k].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_FloatData{
|
||||
FloatData: &schemapb.FloatArray{
|
||||
Data: []float32{scalarType.FloatData.Data[curIdx]},
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
ret.Results.FieldsData[k].GetScalars().GetFloatData().Data = append(ret.Results.FieldsData[k].GetScalars().GetFloatData().Data, scalarType.FloatData.Data[curIdx])
|
||||
}
|
||||
case *schemapb.ScalarField_DoubleData:
|
||||
if ret.Results.FieldsData[k].GetScalars().GetDoubleData() == nil {
|
||||
ret.Results.FieldsData[k].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_DoubleData{
|
||||
DoubleData: &schemapb.DoubleArray{
|
||||
Data: []float64{scalarType.DoubleData.Data[curIdx]},
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
ret.Results.FieldsData[k].GetScalars().GetDoubleData().Data = append(ret.Results.FieldsData[k].GetScalars().GetDoubleData().Data, scalarType.DoubleData.Data[curIdx])
|
||||
}
|
||||
default:
|
||||
log.Debug("Not supported field type")
|
||||
return nil, fmt.Errorf("not supported field type: %s", fieldData.Type.String())
|
||||
}
|
||||
case *schemapb.FieldData_Vectors:
|
||||
dim := fieldType.Vectors.Dim
|
||||
if ret.Results.FieldsData[k] == nil || ret.Results.FieldsData[k].GetVectors() == nil {
|
||||
ret.Results.FieldsData[k] = &schemapb.FieldData{
|
||||
FieldName: fieldData.FieldName,
|
||||
FieldId: fieldData.FieldId,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: dim,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
switch vectorType := fieldType.Vectors.Data.(type) {
|
||||
case *schemapb.VectorField_BinaryVector:
|
||||
if ret.Results.FieldsData[k].GetVectors().GetBinaryVector() == nil {
|
||||
bvec := &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: vectorType.BinaryVector[curIdx*(dim/8) : (curIdx+1)*(dim/8)],
|
||||
}
|
||||
ret.Results.FieldsData[k].GetVectors().Data = bvec
|
||||
} else {
|
||||
ret.Results.FieldsData[k].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector = append(ret.Results.FieldsData[k].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector, vectorType.BinaryVector[curIdx*(dim/8):(curIdx+1)*(dim/8)]...)
|
||||
}
|
||||
case *schemapb.VectorField_FloatVector:
|
||||
if ret.Results.FieldsData[k].GetVectors().GetFloatVector() == nil {
|
||||
fvec := &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: vectorType.FloatVector.Data[curIdx*dim : (curIdx+1)*dim],
|
||||
},
|
||||
}
|
||||
ret.Results.FieldsData[k].GetVectors().Data = fvec
|
||||
} else {
|
||||
ret.Results.FieldsData[k].GetVectors().GetFloatVector().Data = append(ret.Results.FieldsData[k].GetVectors().GetFloatVector().Data, vectorType.FloatVector.Data[curIdx*dim:(curIdx+1)*dim]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ret.Results.Scores = append(ret.Results.Scores, searchResultData[choice].Scores[idx*topk+choiceOffset])
|
||||
locs[choice]++
|
||||
ret.Results.Scores = append(ret.Results.Scores, searchResultData[sel].Scores[i*topk+offset])
|
||||
offsets[sel]++
|
||||
}
|
||||
if realTopK != -1 && realTopK != j {
|
||||
log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
|
||||
|
|
Loading…
Reference in New Issue