Optimize proxy reduce search result data (#10327)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/10344/head
Cai Yudong 2021-10-21 10:46:36 +08:00 committed by GitHub
parent 86655f2221
commit b099179ac0
1 changed file with 171 additions and 153 deletions

@@ -1736,6 +1736,163 @@ func decodeSearchResults(searchResults []*internalpb.SearchResults) (res []*sche
	return
}
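// checkSearchResultData verifies that one node's result set has the expected
// shape: nq queries, topk hits per query, and nq*topk IDs and scores.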
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64) error {
	if data.NumQueries != nq {
		return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
	}
	if data.TopK != topk {
		return fmt.Errorf("search result's topk(%d) mis-match with %d", data.TopK, topk)
	}
	if len(data.Ids.GetIntId().Data) != (int)(nq*topk) {
		return fmt.Errorf("search result's id length %d invalid", len(data.Ids.GetIntId().Data))
	}
	if len(data.Scores) != (int)(nq*topk) {
		return fmt.Errorf("search result's score length %d invalid", len(data.Scores))
	}
	return nil
}
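// selectSearchResultData performs one step of a k-way merge for query idx:
// among the result sets whose offset is still within topk, it returns the
// index of the set whose current candidate has the highest score, skipping
// invalid results (id == -1); it returns -1 when no candidate remains.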
func selectSearchResultData(dataArray []*schemapb.SearchResultData, offsets []int64, topk int64, idx int64) int {
	sel := -1
	maxDistance := minFloat32
	for q, loc := range offsets { // q: result-set index; loc: current offset within that set's topk list
		if loc >= topk {
			continue
		}
		offset := idx*topk + loc
		id := dataArray[q].Ids.GetIntId().Data[offset]
		if id != -1 {
			distance := dataArray[q].Scores[offset]
			if distance > maxDistance {
				sel = q
				maxDistance = distance
			}
		}
	}
	return sel
}
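// copySearchResultData appends row idx of every field in src to the matching
// field in dst, lazily allocating each destination field on first use.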
func copySearchResultData(dst *schemapb.SearchResultData, src *schemapb.SearchResultData, idx int64) error {
	for i, fieldData := range src.FieldsData {
		switch fieldType := fieldData.Field.(type) {
		case *schemapb.FieldData_Scalars:
			if dst.FieldsData[i] == nil || dst.FieldsData[i].GetScalars() == nil {
				dst.FieldsData[i] = &schemapb.FieldData{
					FieldName: fieldData.FieldName,
					FieldId:   fieldData.FieldId,
					Field: &schemapb.FieldData_Scalars{
						Scalars: &schemapb.ScalarField{},
					},
				}
			}
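			// Append the idx-th scalar value, materializing the typed array on first use.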
			switch scalarType := fieldType.Scalars.Data.(type) {
			case *schemapb.ScalarField_BoolData:
				if dst.FieldsData[i].GetScalars().GetBoolData() == nil {
					dst.FieldsData[i].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
						Data: &schemapb.ScalarField_BoolData{
							BoolData: &schemapb.BoolArray{
								Data: []bool{scalarType.BoolData.Data[idx]},
							},
						},
					}
				} else {
					dst.FieldsData[i].GetScalars().GetBoolData().Data = append(dst.FieldsData[i].GetScalars().GetBoolData().Data, scalarType.BoolData.Data[idx])
				}
			case *schemapb.ScalarField_IntData:
				if dst.FieldsData[i].GetScalars().GetIntData() == nil {
					dst.FieldsData[i].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
						Data: &schemapb.ScalarField_IntData{
							IntData: &schemapb.IntArray{
								Data: []int32{scalarType.IntData.Data[idx]},
							},
						},
					}
				} else {
					dst.FieldsData[i].GetScalars().GetIntData().Data = append(dst.FieldsData[i].GetScalars().GetIntData().Data, scalarType.IntData.Data[idx])
				}
			case *schemapb.ScalarField_LongData:
				if dst.FieldsData[i].GetScalars().GetLongData() == nil {
					dst.FieldsData[i].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
						Data: &schemapb.ScalarField_LongData{
							LongData: &schemapb.LongArray{
								Data: []int64{scalarType.LongData.Data[idx]},
							},
						},
					}
				} else {
					dst.FieldsData[i].GetScalars().GetLongData().Data = append(dst.FieldsData[i].GetScalars().GetLongData().Data, scalarType.LongData.Data[idx])
				}
			case *schemapb.ScalarField_FloatData:
				if dst.FieldsData[i].GetScalars().GetFloatData() == nil {
					dst.FieldsData[i].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
						Data: &schemapb.ScalarField_FloatData{
							FloatData: &schemapb.FloatArray{
								Data: []float32{scalarType.FloatData.Data[idx]},
							},
						},
					}
				} else {
					dst.FieldsData[i].GetScalars().GetFloatData().Data = append(dst.FieldsData[i].GetScalars().GetFloatData().Data, scalarType.FloatData.Data[idx])
				}
			case *schemapb.ScalarField_DoubleData:
				if dst.FieldsData[i].GetScalars().GetDoubleData() == nil {
					dst.FieldsData[i].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
						Data: &schemapb.ScalarField_DoubleData{
							DoubleData: &schemapb.DoubleArray{
								Data: []float64{scalarType.DoubleData.Data[idx]},
							},
						},
					}
				} else {
					dst.FieldsData[i].GetScalars().GetDoubleData().Data = append(dst.FieldsData[i].GetScalars().GetDoubleData().Data, scalarType.DoubleData.Data[idx])
				}
			default:
				log.Debug("Not supported field type", zap.String("field type", fieldData.Type.String()))
				return fmt.Errorf("not supported field type: %s", fieldData.Type.String())
			}
		case *schemapb.FieldData_Vectors:
			dim := fieldType.Vectors.Dim
			if dst.FieldsData[i] == nil || dst.FieldsData[i].GetVectors() == nil {
				dst.FieldsData[i] = &schemapb.FieldData{
					FieldName: fieldData.FieldName,
					FieldId:   fieldData.FieldId,
					Field: &schemapb.FieldData_Vectors{
						Vectors: &schemapb.VectorField{
							Dim: dim,
						},
					},
				}
			}
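			// Vector data is stored flattened: row idx spans dim float32 values,
			// or dim/8 bytes for a binary vector.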
			switch vectorType := fieldType.Vectors.Data.(type) {
			case *schemapb.VectorField_BinaryVector:
				if dst.FieldsData[i].GetVectors().GetBinaryVector() == nil {
					bvec := &schemapb.VectorField_BinaryVector{
						BinaryVector: vectorType.BinaryVector[idx*(dim/8) : (idx+1)*(dim/8)],
					}
					dst.FieldsData[i].GetVectors().Data = bvec
				} else {
					dst.FieldsData[i].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector = append(dst.FieldsData[i].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector, vectorType.BinaryVector[idx*(dim/8):(idx+1)*(dim/8)]...)
				}
			case *schemapb.VectorField_FloatVector:
				if dst.FieldsData[i].GetVectors().GetFloatVector() == nil {
					fvec := &schemapb.VectorField_FloatVector{
						FloatVector: &schemapb.FloatArray{
							Data: vectorType.FloatVector.Data[idx*dim : (idx+1)*dim],
						},
					}
					dst.FieldsData[i].GetVectors().Data = fvec
				} else {
					dst.FieldsData[i].GetVectors().GetFloatVector().Data = append(dst.FieldsData[i].GetVectors().GetFloatVector().Data, vectorType.FloatVector.Data[idx*dim:(idx+1)*dim]...)
				}
			default:
				log.Debug("Not supported field type", zap.String("field type", fieldData.Type.String()))
				return fmt.Errorf("not supported field type: %s", fieldData.Type.String())
			}
		}
	}
	return nil
}
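// reduceSearchResultDataParallel merges the per-querynode results into a
// single reply: for each of the nq queries it repeatedly picks the best
// remaining candidate across all nodes (selectSearchResultData) and copies
// that row into the reply (copySearchResultData), up to topk hits per query.
// For example, with two nodes, nq=1, topk=2 and scores [0.9, 0.4] and
// [0.7, 0.6] (all IDs valid), the merged scores are [0.9, 0.7].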
func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64,
	nq int64, topk int64, metricType string, maxParallel int) (*milvuspb.SearchResults, error) {
@@ -1771,173 +1928,34 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
zap.Int64("nq", sData.NumQueries),
zap.Int64("topk", sData.TopK),
zap.Any("len(FieldsData)", len(sData.FieldsData)))
if sData.NumQueries != nq {
return ret, fmt.Errorf("search result's nq(%d) mis-match with %d", sData.NumQueries, nq)
}
if sData.TopK != topk {
return ret, fmt.Errorf("search result's topk(%d) mis-match with %d", sData.TopK, topk)
}
if len(sData.Ids.GetIntId().Data) != (int)(nq*topk) {
return ret, fmt.Errorf("search result's id length %d invalid", len(sData.Ids.GetIntId().Data))
}
if len(sData.Scores) != (int)(nq*topk) {
return ret, fmt.Errorf("search result's score length %d invalid", len(sData.Scores))
if err := checkSearchResultData(sData, nq, topk); err != nil {
return ret, err
}
}
// TODO(yukun): Use parallel function
var realTopK int64 = -1
var idx int64
var j int64
for idx = 0; idx < nq; idx++ {
locs := make([]int64, availableQueryNodeNum)
for i := int64(0); i < nq; i++ {
offsets := make([]int64, availableQueryNodeNum)
j = 0
for ; j < topk; j++ {
choice, maxDistance := -1, minFloat32
for q, loc := range locs { // query num, the number of ways to merge
if loc >= topk {
continue
}
curIdx := idx*topk + loc
id := searchResultData[q].Ids.GetIntId().Data[curIdx]
if id != -1 {
distance := searchResultData[q].Scores[curIdx]
if distance > maxDistance {
choice = q
maxDistance = distance
}
}
}
if choice == -1 {
var j int64
for j = 0; j < topk; j++ {
sel := selectSearchResultData(searchResultData, offsets, topk, i)
if sel == -1 {
break
}
choiceOffset := locs[choice]
curIdx := idx*topk + choiceOffset
offset := offsets[sel]
idx := i*topk + offset
// ignore invalid search result
id := searchResultData[choice].Ids.GetIntId().Data[curIdx]
id := searchResultData[sel].Ids.GetIntId().Data[idx]
if id == -1 {
continue
}
copySearchResultData(ret.Results, searchResultData[sel], idx)
ret.Results.Ids.GetIntId().Data = append(ret.Results.Ids.GetIntId().Data, id)
// TODO(yukun): Process searchResultData.FieldsData
for k, fieldData := range searchResultData[choice].FieldsData {
switch fieldType := fieldData.Field.(type) {
case *schemapb.FieldData_Scalars:
if ret.Results.FieldsData[k] == nil || ret.Results.FieldsData[k].GetScalars() == nil {
ret.Results.FieldsData[k] = &schemapb.FieldData{
FieldName: fieldData.FieldName,
FieldId: fieldData.FieldId,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{},
},
}
}
switch scalarType := fieldType.Scalars.Data.(type) {
case *schemapb.ScalarField_BoolData:
if ret.Results.FieldsData[k].GetScalars().GetBoolData() == nil {
ret.Results.FieldsData[k].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
Data: &schemapb.ScalarField_BoolData{
BoolData: &schemapb.BoolArray{
Data: []bool{scalarType.BoolData.Data[curIdx]},
},
},
}
} else {
ret.Results.FieldsData[k].GetScalars().GetBoolData().Data = append(ret.Results.FieldsData[k].GetScalars().GetBoolData().Data, scalarType.BoolData.Data[curIdx])
}
case *schemapb.ScalarField_IntData:
if ret.Results.FieldsData[k].GetScalars().GetIntData() == nil {
ret.Results.FieldsData[k].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: []int32{scalarType.IntData.Data[curIdx]},
},
},
}
} else {
ret.Results.FieldsData[k].GetScalars().GetIntData().Data = append(ret.Results.FieldsData[k].GetScalars().GetIntData().Data, scalarType.IntData.Data[curIdx])
}
case *schemapb.ScalarField_LongData:
if ret.Results.FieldsData[k].GetScalars().GetLongData() == nil {
ret.Results.FieldsData[k].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: []int64{scalarType.LongData.Data[curIdx]},
},
},
}
} else {
ret.Results.FieldsData[k].GetScalars().GetLongData().Data = append(ret.Results.FieldsData[k].GetScalars().GetLongData().Data, scalarType.LongData.Data[curIdx])
}
case *schemapb.ScalarField_FloatData:
if ret.Results.FieldsData[k].GetScalars().GetFloatData() == nil {
ret.Results.FieldsData[k].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
Data: &schemapb.ScalarField_FloatData{
FloatData: &schemapb.FloatArray{
Data: []float32{scalarType.FloatData.Data[curIdx]},
},
},
}
} else {
ret.Results.FieldsData[k].GetScalars().GetFloatData().Data = append(ret.Results.FieldsData[k].GetScalars().GetFloatData().Data, scalarType.FloatData.Data[curIdx])
}
case *schemapb.ScalarField_DoubleData:
if ret.Results.FieldsData[k].GetScalars().GetDoubleData() == nil {
ret.Results.FieldsData[k].Field.(*schemapb.FieldData_Scalars).Scalars = &schemapb.ScalarField{
Data: &schemapb.ScalarField_DoubleData{
DoubleData: &schemapb.DoubleArray{
Data: []float64{scalarType.DoubleData.Data[curIdx]},
},
},
}
} else {
ret.Results.FieldsData[k].GetScalars().GetDoubleData().Data = append(ret.Results.FieldsData[k].GetScalars().GetDoubleData().Data, scalarType.DoubleData.Data[curIdx])
}
default:
log.Debug("Not supported field type")
return nil, fmt.Errorf("not supported field type: %s", fieldData.Type.String())
}
case *schemapb.FieldData_Vectors:
dim := fieldType.Vectors.Dim
if ret.Results.FieldsData[k] == nil || ret.Results.FieldsData[k].GetVectors() == nil {
ret.Results.FieldsData[k] = &schemapb.FieldData{
FieldName: fieldData.FieldName,
FieldId: fieldData.FieldId,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: dim,
},
},
}
}
switch vectorType := fieldType.Vectors.Data.(type) {
case *schemapb.VectorField_BinaryVector:
if ret.Results.FieldsData[k].GetVectors().GetBinaryVector() == nil {
bvec := &schemapb.VectorField_BinaryVector{
BinaryVector: vectorType.BinaryVector[curIdx*(dim/8) : (curIdx+1)*(dim/8)],
}
ret.Results.FieldsData[k].GetVectors().Data = bvec
} else {
ret.Results.FieldsData[k].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector = append(ret.Results.FieldsData[k].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector, vectorType.BinaryVector[curIdx*(dim/8):(curIdx+1)*(dim/8)]...)
}
case *schemapb.VectorField_FloatVector:
if ret.Results.FieldsData[k].GetVectors().GetFloatVector() == nil {
fvec := &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: vectorType.FloatVector.Data[curIdx*dim : (curIdx+1)*dim],
},
}
ret.Results.FieldsData[k].GetVectors().Data = fvec
} else {
ret.Results.FieldsData[k].GetVectors().GetFloatVector().Data = append(ret.Results.FieldsData[k].GetVectors().GetFloatVector().Data, vectorType.FloatVector.Data[curIdx*dim:(curIdx+1)*dim]...)
}
}
}
}
ret.Results.Scores = append(ret.Results.Scores, searchResultData[choice].Scores[idx*topk+choiceOffset])
locs[choice]++
ret.Results.Scores = append(ret.Results.Scores, searchResultData[sel].Scores[i*topk+offset])
offsets[sel]++
}
if realTopK != -1 && realTopK != j {
log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))