Optimize decodeSearchResults (#10728)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/10733/head
Cai Yudong 2021-10-27 14:10:27 +08:00 committed by GitHub
parent 555912f40e
commit c51155a542
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 39 deletions

View File

@ -1690,39 +1690,33 @@ func (st *searchTask) Execute(ctx context.Context) error {
return err
}
// NOTE(review): this span looks like a rendered diff whose +/- markers were
// stripped, so several statements appear twice: once under the
// decodeSearchResultsSerial name and once under decodeSearchResults. Confirm
// against the real file before editing anything here.
//
// The function walks searchResults, skips every entry whose SlicedBlob is nil,
// unmarshals each remaining blob into a schemapb.SearchResultData and returns
// the decoded values in input order. The first Unmarshal error aborts the call
// and is returned together with a nil slice.
func decodeSearchResultsSerial(searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
log.Debug("reduceSearchResultDataParallel", zap.Any("lenOfSearchResults", len(searchResults)))
func decodeSearchResults(searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
// Presumably measures how long decoding takes; timerecord is not visible in
// this view — TODO confirm.
tr := timerecord.NewTimeRecorder("decodeSearchResults")
log.Debug("decodeSearchResults", zap.Any("lenOfSearchResults", len(searchResults)))
results := make([]*schemapb.SearchResultData, 0)
// Open question from the author: is it worth decoding these in parallel?
for i, partialSearchResult := range searchResults {
log.Debug("decodeSearchResultsSerial", zap.Any("i", i), zap.Any("len(SlicedBob)", len(partialSearchResult.SlicedBlob)))
log.Debug("decodeSearchResults", zap.Any("i", i), zap.Any("len(SlicedBob)", len(partialSearchResult.SlicedBlob)))
// Entries without a payload are dropped, so the returned slice can be
// shorter than searchResults.
if partialSearchResult.SlicedBlob == nil {
continue
}
var partialResultData schemapb.SearchResultData
err := proto.Unmarshal(partialSearchResult.SlicedBlob, &partialResultData)
// Logged before the error check, so this line is emitted on success (with a
// nil error) as well as on failure.
log.Debug("decodeSearchResultsSerial, Unmarshal partitalSearchResult.SliceBlob", zap.Error(err))
log.Debug("decodeSearchResults, Unmarshal partitalSearchResult.SliceBlob", zap.Error(err))
if err != nil {
return nil, err
}
results = append(results, &partialResultData)
}
log.Debug("reduceSearchResultDataParallel", zap.Any("lenOfResults", len(results)))
log.Debug("decodeSearchResults", zap.Any("lenOfResults", len(results)))
// Only reached on the success path; the early error return above skips it.
tr.Elapse("done")
return results, nil
}
// decodeSearchResults forwards searchResults unchanged to
// decodeSearchResultsSerial and hands back whatever it returns through the
// named results res and err.
//
// NOTE(review): a function with this same name is also declared earlier in
// this view and the commit summary reports more deletions than additions, so
// this wrapper is presumably the pre-change version — confirm against the
// real file.
func decodeSearchResults(searchResults []*internalpb.SearchResults) (res []*schemapb.SearchResultData, err error) {
// Presumably times the decode call; timerecord is not visible in this view —
// TODO confirm.
tr := timerecord.NewTimeRecorder("decodeSearchResults")
res, err = decodeSearchResultsSerial(searchResults)
// Disabled alternative decoder; its definition is not visible in this view.
// res, err = decodeSearchResultsParallelByCPU(searchResults)
tr.Elapse("done")
// Naked return: yields the res and err assigned above.
return
}
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64) error {
if data.NumQueries != nq {
return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
@ -1892,17 +1886,14 @@ func copySearchResultData(dst *schemapb.SearchResultData, src *schemapb.SearchRe
// }
//}
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64,
nq int64, topk int64, metricType string) (*milvuspb.SearchResults, error) {
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string) (*milvuspb.SearchResults, error) {
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
defer func() {
tr.Elapse("done")
}()
log.Debug("reduceSearchResultData",
zap.Int("len(searchResultData)", len(searchResultData)),
zap.Int64("availableQueryNodeNum", availableQueryNodeNum),
log.Debug("reduceSearchResultData", zap.Int("len(searchResultData)", len(searchResultData)),
zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType))
ret := &milvuspb.SearchResults{
@ -1939,7 +1930,7 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, avail
var realTopK int64 = -1
for i := int64(0); i < nq; i++ {
offsets := make([]int64, availableQueryNodeNum)
offsets := make([]int64, len(searchResultData))
var prevIDSet = make(map[int64]struct{})
var prevScore float32 = math.MaxFloat32
@ -2032,11 +2023,11 @@ func (st *searchTask) PostExecute(ctx context.Context) error {
return fmt.Errorf("searchTask:wait to finish failed, timeout: %d", st.ID())
case searchResults := <-st.resultBuf:
// fmt.Println("searchResults: ", searchResults)
filterSearchResult := make([]*internalpb.SearchResults, 0)
filterSearchResults := make([]*internalpb.SearchResults, 0)
var filterReason string
for _, partialSearchResult := range searchResults {
if partialSearchResult.Status.ErrorCode == commonpb.ErrorCode_Success {
filterSearchResult = append(filterSearchResult, partialSearchResult)
filterSearchResults = append(filterSearchResults, partialSearchResult)
// For debugging, please don't delete.
// printSearchResult(partialSearchResult)
} else {
@ -2044,7 +2035,7 @@ func (st *searchTask) PostExecute(ctx context.Context) error {
}
}
availableQueryNodeNum := len(filterSearchResult)
availableQueryNodeNum := len(filterSearchResults)
log.Debug("Proxy Search PostExecute stage1",
zap.Any("availableQueryNodeNum", availableQueryNodeNum))
tr.Record("Proxy Search PostExecute stage1 done")
@ -2058,19 +2049,17 @@ func (st *searchTask) PostExecute(ctx context.Context) error {
return fmt.Errorf("No Available Query node result, filter reason %s: id %d", filterReason, st.ID())
}
availableQueryNodeNum = 0
for _, partialSearchResult := range filterSearchResult {
if partialSearchResult.SlicedBlob == nil {
filterReason += "empty search result\n"
} else {
availableQueryNodeNum++
}
validSearchResults, err := decodeSearchResults(filterSearchResults)
if err != nil {
return err
}
log.Debug("Proxy Search PostExecute stage2", zap.Any("availableQueryNodeNum", availableQueryNodeNum))
if availableQueryNodeNum <= 0 {
if len(validSearchResults) <= 0 {
log.Debug("Proxy Search PostExecute stage2 failed", zap.Any("filterReason", filterReason))
filterReason += "empty search result\n"
st.result = &milvuspb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
@ -2084,13 +2073,7 @@ func (st *searchTask) PostExecute(ctx context.Context) error {
return nil
}
results, err := decodeSearchResults(filterSearchResult)
if err != nil {
return err
}
st.result, err = reduceSearchResultData(results, int64(availableQueryNodeNum),
searchResults[0].NumQueries, searchResults[0].TopK, searchResults[0].MetricType)
st.result, err = reduceSearchResultData(validSearchResults, searchResults[0].NumQueries, searchResults[0].TopK, searchResults[0].MetricType)
if err != nil {
return err
}

View File

@ -3093,7 +3093,7 @@ func TestSearchTask_Reduce(t *testing.T) {
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
res, err := reduceSearchResultData(dataArray, 2, nq, topk, metricType)
res, err := reduceSearchResultData(dataArray, nq, topk, metricType)
assert.Nil(t, err)
assert.Equal(t, ids, res.Results.Ids.GetIntId().Data)
assert.Equal(t, []float32{1.0, 2.0, 3.0, 4.0}, res.Results.Scores)
@ -3108,7 +3108,7 @@ func TestSearchTask_Reduce(t *testing.T) {
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
res, err := reduceSearchResultData(dataArray, 2, nq, topk, metricType)
res, err := reduceSearchResultData(dataArray, nq, topk, metricType)
assert.Nil(t, err)
assert.ElementsMatch(t, []int64{1, 5, 2, 3}, res.Results.Ids.GetIntId().Data)
})