mirror of https://github.com/milvus-io/milvus.git
Optimize decodeSearchResults (#10728)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/10733/head
parent
555912f40e
commit
c51155a542
|
@ -1690,39 +1690,33 @@ func (st *searchTask) Execute(ctx context.Context) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeSearchResultsSerial(searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
|
func decodeSearchResults(searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
|
||||||
log.Debug("reduceSearchResultDataParallel", zap.Any("lenOfSearchResults", len(searchResults)))
|
tr := timerecord.NewTimeRecorder("decodeSearchResults")
|
||||||
|
log.Debug("decodeSearchResults", zap.Any("lenOfSearchResults", len(searchResults)))
|
||||||
|
|
||||||
results := make([]*schemapb.SearchResultData, 0)
|
results := make([]*schemapb.SearchResultData, 0)
|
||||||
// necessary to parallel this?
|
// necessary to parallel this?
|
||||||
for i, partialSearchResult := range searchResults {
|
for i, partialSearchResult := range searchResults {
|
||||||
log.Debug("decodeSearchResultsSerial", zap.Any("i", i), zap.Any("len(SlicedBob)", len(partialSearchResult.SlicedBlob)))
|
log.Debug("decodeSearchResults", zap.Any("i", i), zap.Any("len(SlicedBob)", len(partialSearchResult.SlicedBlob)))
|
||||||
if partialSearchResult.SlicedBlob == nil {
|
if partialSearchResult.SlicedBlob == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
var partialResultData schemapb.SearchResultData
|
var partialResultData schemapb.SearchResultData
|
||||||
err := proto.Unmarshal(partialSearchResult.SlicedBlob, &partialResultData)
|
err := proto.Unmarshal(partialSearchResult.SlicedBlob, &partialResultData)
|
||||||
log.Debug("decodeSearchResultsSerial, Unmarshal partitalSearchResult.SliceBlob", zap.Error(err))
|
log.Debug("decodeSearchResults, Unmarshal partitalSearchResult.SliceBlob", zap.Error(err))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
results = append(results, &partialResultData)
|
results = append(results, &partialResultData)
|
||||||
}
|
}
|
||||||
log.Debug("reduceSearchResultDataParallel", zap.Any("lenOfResults", len(results)))
|
log.Debug("decodeSearchResults", zap.Any("lenOfResults", len(results)))
|
||||||
|
tr.Elapse("done")
|
||||||
|
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeSearchResults(searchResults []*internalpb.SearchResults) (res []*schemapb.SearchResultData, err error) {
|
|
||||||
tr := timerecord.NewTimeRecorder("decodeSearchResults")
|
|
||||||
res, err = decodeSearchResultsSerial(searchResults)
|
|
||||||
// res, err = decodeSearchResultsParallelByCPU(searchResults)
|
|
||||||
tr.Elapse("done")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64) error {
|
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64) error {
|
||||||
if data.NumQueries != nq {
|
if data.NumQueries != nq {
|
||||||
return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
|
return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
|
||||||
|
@ -1892,17 +1886,14 @@ func copySearchResultData(dst *schemapb.SearchResultData, src *schemapb.SearchRe
|
||||||
// }
|
// }
|
||||||
//}
|
//}
|
||||||
|
|
||||||
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64,
|
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string) (*milvuspb.SearchResults, error) {
|
||||||
nq int64, topk int64, metricType string) (*milvuspb.SearchResults, error) {
|
|
||||||
|
|
||||||
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
|
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
|
||||||
defer func() {
|
defer func() {
|
||||||
tr.Elapse("done")
|
tr.Elapse("done")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.Debug("reduceSearchResultData",
|
log.Debug("reduceSearchResultData", zap.Int("len(searchResultData)", len(searchResultData)),
|
||||||
zap.Int("len(searchResultData)", len(searchResultData)),
|
|
||||||
zap.Int64("availableQueryNodeNum", availableQueryNodeNum),
|
|
||||||
zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType))
|
zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType))
|
||||||
|
|
||||||
ret := &milvuspb.SearchResults{
|
ret := &milvuspb.SearchResults{
|
||||||
|
@ -1939,7 +1930,7 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, avail
|
||||||
|
|
||||||
var realTopK int64 = -1
|
var realTopK int64 = -1
|
||||||
for i := int64(0); i < nq; i++ {
|
for i := int64(0); i < nq; i++ {
|
||||||
offsets := make([]int64, availableQueryNodeNum)
|
offsets := make([]int64, len(searchResultData))
|
||||||
|
|
||||||
var prevIDSet = make(map[int64]struct{})
|
var prevIDSet = make(map[int64]struct{})
|
||||||
var prevScore float32 = math.MaxFloat32
|
var prevScore float32 = math.MaxFloat32
|
||||||
|
@ -2032,11 +2023,11 @@ func (st *searchTask) PostExecute(ctx context.Context) error {
|
||||||
return fmt.Errorf("searchTask:wait to finish failed, timeout: %d", st.ID())
|
return fmt.Errorf("searchTask:wait to finish failed, timeout: %d", st.ID())
|
||||||
case searchResults := <-st.resultBuf:
|
case searchResults := <-st.resultBuf:
|
||||||
// fmt.Println("searchResults: ", searchResults)
|
// fmt.Println("searchResults: ", searchResults)
|
||||||
filterSearchResult := make([]*internalpb.SearchResults, 0)
|
filterSearchResults := make([]*internalpb.SearchResults, 0)
|
||||||
var filterReason string
|
var filterReason string
|
||||||
for _, partialSearchResult := range searchResults {
|
for _, partialSearchResult := range searchResults {
|
||||||
if partialSearchResult.Status.ErrorCode == commonpb.ErrorCode_Success {
|
if partialSearchResult.Status.ErrorCode == commonpb.ErrorCode_Success {
|
||||||
filterSearchResult = append(filterSearchResult, partialSearchResult)
|
filterSearchResults = append(filterSearchResults, partialSearchResult)
|
||||||
// For debugging, please don't delete.
|
// For debugging, please don't delete.
|
||||||
// printSearchResult(partialSearchResult)
|
// printSearchResult(partialSearchResult)
|
||||||
} else {
|
} else {
|
||||||
|
@ -2044,7 +2035,7 @@ func (st *searchTask) PostExecute(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
availableQueryNodeNum := len(filterSearchResult)
|
availableQueryNodeNum := len(filterSearchResults)
|
||||||
log.Debug("Proxy Search PostExecute stage1",
|
log.Debug("Proxy Search PostExecute stage1",
|
||||||
zap.Any("availableQueryNodeNum", availableQueryNodeNum))
|
zap.Any("availableQueryNodeNum", availableQueryNodeNum))
|
||||||
tr.Record("Proxy Search PostExecute stage1 done")
|
tr.Record("Proxy Search PostExecute stage1 done")
|
||||||
|
@ -2058,19 +2049,17 @@ func (st *searchTask) PostExecute(ctx context.Context) error {
|
||||||
return fmt.Errorf("No Available Query node result, filter reason %s: id %d", filterReason, st.ID())
|
return fmt.Errorf("No Available Query node result, filter reason %s: id %d", filterReason, st.ID())
|
||||||
}
|
}
|
||||||
|
|
||||||
availableQueryNodeNum = 0
|
validSearchResults, err := decodeSearchResults(filterSearchResults)
|
||||||
for _, partialSearchResult := range filterSearchResult {
|
if err != nil {
|
||||||
if partialSearchResult.SlicedBlob == nil {
|
return err
|
||||||
filterReason += "empty search result\n"
|
|
||||||
} else {
|
|
||||||
availableQueryNodeNum++
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("Proxy Search PostExecute stage2", zap.Any("availableQueryNodeNum", availableQueryNodeNum))
|
log.Debug("Proxy Search PostExecute stage2", zap.Any("availableQueryNodeNum", availableQueryNodeNum))
|
||||||
|
|
||||||
if availableQueryNodeNum <= 0 {
|
if len(validSearchResults) <= 0 {
|
||||||
log.Debug("Proxy Search PostExecute stage2 failed", zap.Any("filterReason", filterReason))
|
log.Debug("Proxy Search PostExecute stage2 failed", zap.Any("filterReason", filterReason))
|
||||||
|
|
||||||
|
filterReason += "empty search result\n"
|
||||||
st.result = &milvuspb.SearchResults{
|
st.result = &milvuspb.SearchResults{
|
||||||
Status: &commonpb.Status{
|
Status: &commonpb.Status{
|
||||||
ErrorCode: commonpb.ErrorCode_Success,
|
ErrorCode: commonpb.ErrorCode_Success,
|
||||||
|
@ -2084,13 +2073,7 @@ func (st *searchTask) PostExecute(ctx context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
results, err := decodeSearchResults(filterSearchResult)
|
st.result, err = reduceSearchResultData(validSearchResults, searchResults[0].NumQueries, searchResults[0].TopK, searchResults[0].MetricType)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
st.result, err = reduceSearchResultData(results, int64(availableQueryNodeNum),
|
|
||||||
searchResults[0].NumQueries, searchResults[0].TopK, searchResults[0].MetricType)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -3093,7 +3093,7 @@ func TestSearchTask_Reduce(t *testing.T) {
|
||||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||||
dataArray = append(dataArray, data1)
|
dataArray = append(dataArray, data1)
|
||||||
dataArray = append(dataArray, data2)
|
dataArray = append(dataArray, data2)
|
||||||
res, err := reduceSearchResultData(dataArray, 2, nq, topk, metricType)
|
res, err := reduceSearchResultData(dataArray, nq, topk, metricType)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, ids, res.Results.Ids.GetIntId().Data)
|
assert.Equal(t, ids, res.Results.Ids.GetIntId().Data)
|
||||||
assert.Equal(t, []float32{1.0, 2.0, 3.0, 4.0}, res.Results.Scores)
|
assert.Equal(t, []float32{1.0, 2.0, 3.0, 4.0}, res.Results.Scores)
|
||||||
|
@ -3108,7 +3108,7 @@ func TestSearchTask_Reduce(t *testing.T) {
|
||||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||||
dataArray = append(dataArray, data1)
|
dataArray = append(dataArray, data1)
|
||||||
dataArray = append(dataArray, data2)
|
dataArray = append(dataArray, data2)
|
||||||
res, err := reduceSearchResultData(dataArray, 2, nq, topk, metricType)
|
res, err := reduceSearchResultData(dataArray, nq, topk, metricType)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.ElementsMatch(t, []int64{1, 5, 2, 3}, res.Results.Ids.GetIntId().Data)
|
assert.ElementsMatch(t, []int64{1, 5, 2, 3}, res.Results.Ids.GetIntId().Data)
|
||||||
})
|
})
|
||||||
|
|
Loading…
Reference in New Issue