mirror of https://github.com/milvus-io/milvus.git
Optimize decodeSearchResults (#10728)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/10733/head
parent
555912f40e
commit
c51155a542
|
@ -1690,39 +1690,33 @@ func (st *searchTask) Execute(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func decodeSearchResultsSerial(searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
|
||||
log.Debug("reduceSearchResultDataParallel", zap.Any("lenOfSearchResults", len(searchResults)))
|
||||
func decodeSearchResults(searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
|
||||
tr := timerecord.NewTimeRecorder("decodeSearchResults")
|
||||
log.Debug("decodeSearchResults", zap.Any("lenOfSearchResults", len(searchResults)))
|
||||
|
||||
results := make([]*schemapb.SearchResultData, 0)
|
||||
// necessary to parallel this?
|
||||
for i, partialSearchResult := range searchResults {
|
||||
log.Debug("decodeSearchResultsSerial", zap.Any("i", i), zap.Any("len(SlicedBob)", len(partialSearchResult.SlicedBlob)))
|
||||
log.Debug("decodeSearchResults", zap.Any("i", i), zap.Any("len(SlicedBob)", len(partialSearchResult.SlicedBlob)))
|
||||
if partialSearchResult.SlicedBlob == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var partialResultData schemapb.SearchResultData
|
||||
err := proto.Unmarshal(partialSearchResult.SlicedBlob, &partialResultData)
|
||||
log.Debug("decodeSearchResultsSerial, Unmarshal partitalSearchResult.SliceBlob", zap.Error(err))
|
||||
log.Debug("decodeSearchResults, Unmarshal partitalSearchResult.SliceBlob", zap.Error(err))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
results = append(results, &partialResultData)
|
||||
}
|
||||
log.Debug("reduceSearchResultDataParallel", zap.Any("lenOfResults", len(results)))
|
||||
log.Debug("decodeSearchResults", zap.Any("lenOfResults", len(results)))
|
||||
tr.Elapse("done")
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func decodeSearchResults(searchResults []*internalpb.SearchResults) (res []*schemapb.SearchResultData, err error) {
|
||||
tr := timerecord.NewTimeRecorder("decodeSearchResults")
|
||||
res, err = decodeSearchResultsSerial(searchResults)
|
||||
// res, err = decodeSearchResultsParallelByCPU(searchResults)
|
||||
tr.Elapse("done")
|
||||
return
|
||||
}
|
||||
|
||||
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64) error {
|
||||
if data.NumQueries != nq {
|
||||
return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
|
||||
|
@ -1892,17 +1886,14 @@ func copySearchResultData(dst *schemapb.SearchResultData, src *schemapb.SearchRe
|
|||
// }
|
||||
//}
|
||||
|
||||
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64,
|
||||
nq int64, topk int64, metricType string) (*milvuspb.SearchResults, error) {
|
||||
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string) (*milvuspb.SearchResults, error) {
|
||||
|
||||
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
|
||||
defer func() {
|
||||
tr.Elapse("done")
|
||||
}()
|
||||
|
||||
log.Debug("reduceSearchResultData",
|
||||
zap.Int("len(searchResultData)", len(searchResultData)),
|
||||
zap.Int64("availableQueryNodeNum", availableQueryNodeNum),
|
||||
log.Debug("reduceSearchResultData", zap.Int("len(searchResultData)", len(searchResultData)),
|
||||
zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType))
|
||||
|
||||
ret := &milvuspb.SearchResults{
|
||||
|
@ -1939,7 +1930,7 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, avail
|
|||
|
||||
var realTopK int64 = -1
|
||||
for i := int64(0); i < nq; i++ {
|
||||
offsets := make([]int64, availableQueryNodeNum)
|
||||
offsets := make([]int64, len(searchResultData))
|
||||
|
||||
var prevIDSet = make(map[int64]struct{})
|
||||
var prevScore float32 = math.MaxFloat32
|
||||
|
@ -2032,11 +2023,11 @@ func (st *searchTask) PostExecute(ctx context.Context) error {
|
|||
return fmt.Errorf("searchTask:wait to finish failed, timeout: %d", st.ID())
|
||||
case searchResults := <-st.resultBuf:
|
||||
// fmt.Println("searchResults: ", searchResults)
|
||||
filterSearchResult := make([]*internalpb.SearchResults, 0)
|
||||
filterSearchResults := make([]*internalpb.SearchResults, 0)
|
||||
var filterReason string
|
||||
for _, partialSearchResult := range searchResults {
|
||||
if partialSearchResult.Status.ErrorCode == commonpb.ErrorCode_Success {
|
||||
filterSearchResult = append(filterSearchResult, partialSearchResult)
|
||||
filterSearchResults = append(filterSearchResults, partialSearchResult)
|
||||
// For debugging, please don't delete.
|
||||
// printSearchResult(partialSearchResult)
|
||||
} else {
|
||||
|
@ -2044,7 +2035,7 @@ func (st *searchTask) PostExecute(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
availableQueryNodeNum := len(filterSearchResult)
|
||||
availableQueryNodeNum := len(filterSearchResults)
|
||||
log.Debug("Proxy Search PostExecute stage1",
|
||||
zap.Any("availableQueryNodeNum", availableQueryNodeNum))
|
||||
tr.Record("Proxy Search PostExecute stage1 done")
|
||||
|
@ -2058,19 +2049,17 @@ func (st *searchTask) PostExecute(ctx context.Context) error {
|
|||
return fmt.Errorf("No Available Query node result, filter reason %s: id %d", filterReason, st.ID())
|
||||
}
|
||||
|
||||
availableQueryNodeNum = 0
|
||||
for _, partialSearchResult := range filterSearchResult {
|
||||
if partialSearchResult.SlicedBlob == nil {
|
||||
filterReason += "empty search result\n"
|
||||
} else {
|
||||
availableQueryNodeNum++
|
||||
}
|
||||
validSearchResults, err := decodeSearchResults(filterSearchResults)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("Proxy Search PostExecute stage2", zap.Any("availableQueryNodeNum", availableQueryNodeNum))
|
||||
|
||||
if availableQueryNodeNum <= 0 {
|
||||
if len(validSearchResults) <= 0 {
|
||||
log.Debug("Proxy Search PostExecute stage2 failed", zap.Any("filterReason", filterReason))
|
||||
|
||||
filterReason += "empty search result\n"
|
||||
st.result = &milvuspb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
|
@ -2084,13 +2073,7 @@ func (st *searchTask) PostExecute(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
results, err := decodeSearchResults(filterSearchResult)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
st.result, err = reduceSearchResultData(results, int64(availableQueryNodeNum),
|
||||
searchResults[0].NumQueries, searchResults[0].TopK, searchResults[0].MetricType)
|
||||
st.result, err = reduceSearchResultData(validSearchResults, searchResults[0].NumQueries, searchResults[0].TopK, searchResults[0].MetricType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -3093,7 +3093,7 @@ func TestSearchTask_Reduce(t *testing.T) {
|
|||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
res, err := reduceSearchResultData(dataArray, 2, nq, topk, metricType)
|
||||
res, err := reduceSearchResultData(dataArray, nq, topk, metricType)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, ids, res.Results.Ids.GetIntId().Data)
|
||||
assert.Equal(t, []float32{1.0, 2.0, 3.0, 4.0}, res.Results.Scores)
|
||||
|
@ -3108,7 +3108,7 @@ func TestSearchTask_Reduce(t *testing.T) {
|
|||
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||
dataArray = append(dataArray, data1)
|
||||
dataArray = append(dataArray, data2)
|
||||
res, err := reduceSearchResultData(dataArray, 2, nq, topk, metricType)
|
||||
res, err := reduceSearchResultData(dataArray, nq, topk, metricType)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int64{1, 5, 2, 3}, res.Results.Ids.GetIntId().Data)
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue