mirror of https://github.com/milvus-io/milvus.git
Optimize proxy reduce code readability (#10537)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>pull/10541/head
parent
be57e8fdd8
commit
f6802589eb
|
@ -20,7 +20,6 @@ import (
|
|||
"math"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -1752,19 +1751,19 @@ func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64
|
|||
return nil
|
||||
}
|
||||
|
||||
func selectSearchResultData(dataArray []*schemapb.SearchResultData, offsets []int64, topk int64, idx int64) int {
|
||||
func selectSearchResultData(dataArray []*schemapb.SearchResultData, offsets []int64, topk int64, qi int64) int {
|
||||
sel := -1
|
||||
maxDistance := minFloat32
|
||||
for q, loc := range offsets { // query num, the number of ways to merge
|
||||
if loc >= topk {
|
||||
for i, offset := range offsets { // query num, the number of ways to merge
|
||||
if offset >= topk {
|
||||
continue
|
||||
}
|
||||
offset := idx*topk + loc
|
||||
id := dataArray[q].Ids.GetIntId().Data[offset]
|
||||
idx := qi*topk + offset
|
||||
id := dataArray[i].Ids.GetIntId().Data[idx]
|
||||
if id != -1 {
|
||||
distance := dataArray[q].Scores[offset]
|
||||
distance := dataArray[i].Scores[idx]
|
||||
if distance > maxDistance {
|
||||
sel = q
|
||||
sel = i
|
||||
maxDistance = distance
|
||||
}
|
||||
}
|
||||
|
@ -1905,14 +1904,18 @@ func copySearchResultData(dst *schemapb.SearchResultData, src *schemapb.SearchRe
|
|||
// }
|
||||
//}
|
||||
|
||||
func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64,
|
||||
nq int64, topk int64, metricType string, maxParallel int) (*milvuspb.SearchResults, error) {
|
||||
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64,
|
||||
nq int64, topk int64, metricType string) (*milvuspb.SearchResults, error) {
|
||||
|
||||
log.Debug("reduceSearchResultDataParallel",
|
||||
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
|
||||
defer func() {
|
||||
tr.Elapse("done")
|
||||
}()
|
||||
|
||||
log.Debug("reduceSearchResultData",
|
||||
zap.Int("len(searchResultData)", len(searchResultData)),
|
||||
zap.Int64("availableQueryNodeNum", availableQueryNodeNum),
|
||||
zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType),
|
||||
zap.Int("maxParallel", maxParallel))
|
||||
zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType))
|
||||
|
||||
ret := &milvuspb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -1935,7 +1938,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
|
|||
}
|
||||
|
||||
for i, sData := range searchResultData {
|
||||
log.Debug("reduceSearchResultDataParallel",
|
||||
log.Debug("reduceSearchResultData",
|
||||
zap.Int("i", i),
|
||||
zap.Int64("nq", sData.NumQueries),
|
||||
zap.Int64("topk", sData.TopK),
|
||||
|
@ -1953,8 +1956,8 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
|
|||
|
||||
var prevIDSet = make(map[int64]struct{})
|
||||
var prevScore float32 = math.MaxFloat32
|
||||
var loc int64
|
||||
for loc = 0; loc < topk; {
|
||||
var j int64
|
||||
for j = 0; j < topk; {
|
||||
sel := selectSearchResultData(searchResultData, offsets, topk, i)
|
||||
if sel == -1 {
|
||||
break
|
||||
|
@ -1975,7 +1978,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
|
|||
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||
prevScore = score
|
||||
prevIDSet = map[int64]struct{}{id: {}}
|
||||
loc++
|
||||
j++
|
||||
} else {
|
||||
// To handle this case:
|
||||
// e1: [100, 0.99]
|
||||
|
@ -1986,7 +1989,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
|
|||
ret.Results.Ids.GetIntId().Data = append(ret.Results.Ids.GetIntId().Data, id)
|
||||
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||
prevIDSet[id] = struct{}{}
|
||||
loc++
|
||||
j++
|
||||
} else {
|
||||
// entity with same id and same score must be duplicated
|
||||
log.Debug("skip duplicated search result",
|
||||
|
@ -1997,11 +2000,11 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
|
|||
}
|
||||
offsets[sel]++
|
||||
}
|
||||
if realTopK != -1 && realTopK != loc {
|
||||
if realTopK != -1 && realTopK != j {
|
||||
log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
|
||||
// return nil, errors.New("the length (topk) between all result of query is different")
|
||||
}
|
||||
realTopK = loc
|
||||
realTopK = j
|
||||
ret.Results.Topks = append(ret.Results.Topks, realTopK)
|
||||
}
|
||||
|
||||
|
@ -2016,14 +2019,6 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64,
|
||||
nq int64, topk int64, metricType string) (res *milvuspb.SearchResults, err error) {
|
||||
tr := timerecord.NewTimeRecorder("reduceSearchResults")
|
||||
res, err = reduceSearchResultDataParallel(searchResultData, availableQueryNodeNum, nq, topk, metricType, runtime.NumCPU())
|
||||
tr.Elapse("done")
|
||||
return
|
||||
}
|
||||
|
||||
//func printSearchResult(partialSearchResult *internalpb.SearchResults) {
|
||||
// for i := 0; i < len(partialSearchResult.Hits); i++ {
|
||||
// testHits := milvuspb.Hits{}
|
||||
|
|
Loading…
Reference in New Issue