Improve readability of the proxy reduce code (#10537)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/10541/head
Cai Yudong 2021-10-25 14:29:12 +08:00 committed by GitHub
parent be57e8fdd8
commit f6802589eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 23 additions and 28 deletions

View File

@ -20,7 +20,6 @@ import (
"math"
"reflect"
"regexp"
"runtime"
"sort"
"strconv"
"strings"
@ -1752,19 +1751,19 @@ func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64
return nil
}
func selectSearchResultData(dataArray []*schemapb.SearchResultData, offsets []int64, topk int64, idx int64) int {
func selectSearchResultData(dataArray []*schemapb.SearchResultData, offsets []int64, topk int64, qi int64) int {
sel := -1
maxDistance := minFloat32
for q, loc := range offsets { // query num, the number of ways to merge
if loc >= topk {
for i, offset := range offsets { // query num, the number of ways to merge
if offset >= topk {
continue
}
offset := idx*topk + loc
id := dataArray[q].Ids.GetIntId().Data[offset]
idx := qi*topk + offset
id := dataArray[i].Ids.GetIntId().Data[idx]
if id != -1 {
distance := dataArray[q].Scores[offset]
distance := dataArray[i].Scores[idx]
if distance > maxDistance {
sel = q
sel = i
maxDistance = distance
}
}
@ -1905,14 +1904,18 @@ func copySearchResultData(dst *schemapb.SearchResultData, src *schemapb.SearchRe
// }
//}
func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64,
nq int64, topk int64, metricType string, maxParallel int) (*milvuspb.SearchResults, error) {
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64,
nq int64, topk int64, metricType string) (*milvuspb.SearchResults, error) {
log.Debug("reduceSearchResultDataParallel",
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
defer func() {
tr.Elapse("done")
}()
log.Debug("reduceSearchResultData",
zap.Int("len(searchResultData)", len(searchResultData)),
zap.Int64("availableQueryNodeNum", availableQueryNodeNum),
zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType),
zap.Int("maxParallel", maxParallel))
zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType))
ret := &milvuspb.SearchResults{
Status: &commonpb.Status{
@ -1935,7 +1938,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
}
for i, sData := range searchResultData {
log.Debug("reduceSearchResultDataParallel",
log.Debug("reduceSearchResultData",
zap.Int("i", i),
zap.Int64("nq", sData.NumQueries),
zap.Int64("topk", sData.TopK),
@ -1953,8 +1956,8 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
var prevIDSet = make(map[int64]struct{})
var prevScore float32 = math.MaxFloat32
var loc int64
for loc = 0; loc < topk; {
var j int64
for j = 0; j < topk; {
sel := selectSearchResultData(searchResultData, offsets, topk, i)
if sel == -1 {
break
@ -1975,7 +1978,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
ret.Results.Scores = append(ret.Results.Scores, score)
prevScore = score
prevIDSet = map[int64]struct{}{id: {}}
loc++
j++
} else {
// To handle this case:
// e1: [100, 0.99]
@ -1986,7 +1989,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
ret.Results.Ids.GetIntId().Data = append(ret.Results.Ids.GetIntId().Data, id)
ret.Results.Scores = append(ret.Results.Scores, score)
prevIDSet[id] = struct{}{}
loc++
j++
} else {
// entity with same id and same score must be duplicated
log.Debug("skip duplicated search result",
@ -1997,11 +2000,11 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
}
offsets[sel]++
}
if realTopK != -1 && realTopK != loc {
if realTopK != -1 && realTopK != j {
log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
// return nil, errors.New("the length (topk) between all result of query is different")
}
realTopK = loc
realTopK = j
ret.Results.Topks = append(ret.Results.Topks, realTopK)
}
@ -2016,14 +2019,6 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
return ret, nil
}
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64,
nq int64, topk int64, metricType string) (res *milvuspb.SearchResults, err error) {
tr := timerecord.NewTimeRecorder("reduceSearchResults")
res, err = reduceSearchResultDataParallel(searchResultData, availableQueryNodeNum, nq, topk, metricType, runtime.NumCPU())
tr.Elapse("done")
return
}
//func printSearchResult(partialSearchResult *internalpb.SearchResults) {
// for i := 0; i < len(partialSearchResult.Hits); i++ {
// testHits := milvuspb.Hits{}