From f6802589eb379cbcb9d06902eed8704caf1a5988 Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Mon, 25 Oct 2021 14:29:12 +0800 Subject: [PATCH] Optimize proxy reduce code readability (#10537) Signed-off-by: yudong.cai --- internal/proxy/task.go | 51 +++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 28 deletions(-) diff --git a/internal/proxy/task.go b/internal/proxy/task.go index fda1274f6c..5d8f848b9a 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -20,7 +20,6 @@ import ( "math" "reflect" "regexp" - "runtime" "sort" "strconv" "strings" @@ -1752,19 +1751,19 @@ func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64 return nil } -func selectSearchResultData(dataArray []*schemapb.SearchResultData, offsets []int64, topk int64, idx int64) int { +func selectSearchResultData(dataArray []*schemapb.SearchResultData, offsets []int64, topk int64, qi int64) int { sel := -1 maxDistance := minFloat32 - for q, loc := range offsets { // query num, the number of ways to merge - if loc >= topk { + for i, offset := range offsets { // query num, the number of ways to merge + if offset >= topk { continue } - offset := idx*topk + loc - id := dataArray[q].Ids.GetIntId().Data[offset] + idx := qi*topk + offset + id := dataArray[i].Ids.GetIntId().Data[idx] if id != -1 { - distance := dataArray[q].Scores[offset] + distance := dataArray[i].Scores[idx] if distance > maxDistance { - sel = q + sel = i maxDistance = distance } } @@ -1905,14 +1904,18 @@ func copySearchResultData(dst *schemapb.SearchResultData, src *schemapb.SearchRe // } //} -func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64, - nq int64, topk int64, metricType string, maxParallel int) (*milvuspb.SearchResults, error) { +func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64, + nq int64, topk int64, metricType string) (*milvuspb.SearchResults, error) { - log.Debug("reduceSearchResultDataParallel", + tr := timerecord.NewTimeRecorder("reduceSearchResultData") + defer func() { + tr.Elapse("done") + }() + + log.Debug("reduceSearchResultData", zap.Int("len(searchResultData)", len(searchResultData)), zap.Int64("availableQueryNodeNum", availableQueryNodeNum), - zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType), - zap.Int("maxParallel", maxParallel)) + zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType)) ret := &milvuspb.SearchResults{ Status: &commonpb.Status{ @@ -1935,7 +1938,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat } for i, sData := range searchResultData { - log.Debug("reduceSearchResultDataParallel", + log.Debug("reduceSearchResultData", zap.Int("i", i), zap.Int64("nq", sData.NumQueries), zap.Int64("topk", sData.TopK), @@ -1953,8 +1956,8 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat var prevIDSet = make(map[int64]struct{}) var prevScore float32 = math.MaxFloat32 - var loc int64 - for loc = 0; loc < topk; { + var j int64 + for j = 0; j < topk; { sel := selectSearchResultData(searchResultData, offsets, topk, i) if sel == -1 { break @@ -1975,7 +1978,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat ret.Results.Scores = append(ret.Results.Scores, score) prevScore = score prevIDSet = map[int64]struct{}{id: {}} - loc++ + j++ } else { // To handle this case: // e1: [100, 0.99] @@ -1986,7 +1989,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat ret.Results.Ids.GetIntId().Data = append(ret.Results.Ids.GetIntId().Data, id) ret.Results.Scores = append(ret.Results.Scores, score) prevIDSet[id] = struct{}{} - loc++ + j++ } else { // entity with same id and same score must be duplicated log.Debug("skip duplicated search result", @@ -1997,11 +2000,11 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat } offsets[sel]++ } - if realTopK != -1 && realTopK != loc { + if realTopK != -1 && realTopK != j { log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different"))) // return nil, errors.New("the length (topk) between all result of query is different") } - realTopK = loc + realTopK = j ret.Results.Topks = append(ret.Results.Topks, realTopK) } @@ -2016,14 +2019,6 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat return ret, nil } -func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64, - nq int64, topk int64, metricType string) (res *milvuspb.SearchResults, err error) { - tr := timerecord.NewTimeRecorder("reduceSearchResults") - res, err = reduceSearchResultDataParallel(searchResultData, availableQueryNodeNum, nq, topk, metricType, runtime.NumCPU()) - tr.Elapse("done") - return -} - //func printSearchResult(partialSearchResult *internalpb.SearchResults) { // for i := 0; i < len(partialSearchResult.Hits); i++ { // testHits := milvuspb.Hits{}