mirror of https://github.com/milvus-io/milvus.git
Calculate the real topk in proxy (#6132)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>
pull/6169/head^2
parent
6f4ad331c8
commit
d3c503f3aa
|
@ -1568,7 +1568,7 @@ func reduceSearchResultsParallel(hits [][]*milvuspb.Hits, nq, availableQueryNode
|
|||
return ret
|
||||
}
|
||||
|
||||
func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string, maxParallel int) *milvuspb.SearchResults {
|
||||
func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string, maxParallel int) (*milvuspb.SearchResults, error) {
|
||||
log.Debug("reduceSearchResultDataParallel", zap.Any("NumOfGoRoutines", maxParallel))
|
||||
|
||||
ret := &milvuspb.SearchResults{
|
||||
|
@ -1593,10 +1593,12 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
|
|||
const minFloat32 = -1 * float32(math.MaxFloat32)
|
||||
|
||||
// TODO(yukun): Use parallel function
|
||||
realTopK := -1
|
||||
for idx := 0; idx < nq; idx++ {
|
||||
locs := make([]int, availableQueryNodeNum)
|
||||
|
||||
for j := 0; j < topk; j++ {
|
||||
j := 0
|
||||
for ; j < topk; j++ {
|
||||
valid := false
|
||||
choice, maxDistance := 0, minFloat32
|
||||
for q, loc := range locs { // query num, the number of ways to merge
|
||||
|
@ -1696,7 +1698,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
|
|||
}
|
||||
default:
|
||||
log.Debug("Not supported field type")
|
||||
return nil
|
||||
return nil, fmt.Errorf("not supported field type: %s", fieldData.Type.String())
|
||||
}
|
||||
case *schemapb.FieldData_Vectors:
|
||||
dim := fieldType.Vectors.Dim
|
||||
|
@ -1729,9 +1731,15 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
|
|||
ret.Results.Scores = append(ret.Results.Scores, searchResultData[choice].Scores[idx*topk+choiceOffset])
|
||||
locs[choice]++
|
||||
}
|
||||
|
||||
if realTopK != -1 && realTopK != j {
|
||||
log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
|
||||
// return nil, errors.New("the length (topk) between all result of query is different")
|
||||
}
|
||||
realTopK = j
|
||||
}
|
||||
|
||||
ret.Results.TopK = int64(realTopK)
|
||||
|
||||
if metricType != "IP" {
|
||||
for k := range ret.Results.Scores {
|
||||
ret.Results.Scores[k] *= -1
|
||||
|
@ -1742,7 +1750,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
|
|||
// return nil
|
||||
// }
|
||||
|
||||
return ret
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func reduceSearchResultsSerial(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
|
||||
|
@ -1767,7 +1775,7 @@ func reduceSearchResults(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, top
|
|||
return reduceSearchResultsParallelByCPU(hits, nq, availableQueryNodeNum, topk, metricType)
|
||||
}
|
||||
|
||||
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
|
||||
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string) (*milvuspb.SearchResults, error) {
|
||||
t := time.Now()
|
||||
defer func() {
|
||||
log.Debug("reduceSearchResults", zap.Any("time cost", time.Since(t)))
|
||||
|
@ -1853,7 +1861,10 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
|
|||
}
|
||||
|
||||
nq := results[0].NumQueries
|
||||
topk := results[0].TopK
|
||||
topk := 0
|
||||
for _, partialResult := range results {
|
||||
topk = getMax(topk, int(partialResult.TopK))
|
||||
}
|
||||
if nq <= 0 {
|
||||
st.result = &milvuspb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -1864,7 +1875,10 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
st.result = reduceSearchResultData(results, int(nq), availableQueryNodeNum, int(topk), searchResults[0].MetricType)
|
||||
st.result, err = reduceSearchResultData(results, int(nq), availableQueryNodeNum, topk, searchResults[0].MetricType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
schema, err := globalMetaCache.GetCollectionSchema(ctx, st.query.CollectionName)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue