diff --git a/internal/proxy/task.go b/internal/proxy/task.go index d95dce85ee..4ec6fb7440 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -1659,23 +1659,23 @@ func decodeSearchResults(searchResults []*internalpb.SearchResults) ([]*schemapb // return decodeSearchResultsParallelByCPU(searchResults) } -func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string, maxParallel int) (*milvuspb.SearchResults, error) { - log.Debug("reduceSearchResultDataParallel", zap.Any("lenOfsearchResultData", len(searchResultData)), - zap.Any("nq", nq), zap.Any("availableQueryNodeNum", availableQueryNodeNum), - zap.Any("topk", topk), zap.Any("metricType", metricType), - zap.Any("maxParallel", maxParallel)) +func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64, metricType string, maxParallel int) (*milvuspb.SearchResults, error) { + nq := searchResultData[0].NumQueries + topk := searchResultData[0].TopK - for i, sData := range searchResultData { - log.Debug("reduceSearchResultDataParallel", zap.Any("i", i), zap.Any("len(FieldsData)", len(sData.FieldsData))) - } + log.Debug("reduceSearchResultDataParallel", + zap.Int("len(searchResultData)", len(searchResultData)), + zap.Int64("availableQueryNodeNum", availableQueryNodeNum), + zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType), + zap.Int("maxParallel", maxParallel)) ret := &milvuspb.SearchResults{ Status: &commonpb.Status{ ErrorCode: 0, }, Results: &schemapb.SearchResultData{ - NumQueries: int64(nq), - TopK: int64(topk), + NumQueries: nq, + TopK: topk, FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)), Scores: make([]float32, 0), Ids: &schemapb.IDs{ @@ -1689,14 +1689,36 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat }, } + for i, sData := range searchResultData { + log.Debug("reduceSearchResultDataParallel", + zap.Int("i", i), + zap.Int64("nq", sData.NumQueries), + zap.Int64("topk", sData.TopK), + zap.Any("len(FieldsData)", len(sData.FieldsData))) + if sData.NumQueries != nq { + return ret, fmt.Errorf("search result's nq(%d) mis-match with %d", sData.NumQueries, nq) + } + if sData.TopK != topk { + return ret, fmt.Errorf("search result's topk(%d) mis-match with %d", sData.TopK, topk) + } + if len(sData.Ids.GetIntId().Data) != (int)(nq*topk) { + return ret, fmt.Errorf("search result's id length %d invalid", len(sData.Ids.GetIntId().Data)) + } + if len(sData.Scores) != (int)(nq*topk) { + return ret, fmt.Errorf("search result's score length %d invalid", len(sData.Scores)) + } + } + const minFloat32 = -1 * float32(math.MaxFloat32) // TODO(yukun): Use parallel function - realTopK := -1 - for idx := 0; idx < nq; idx++ { - locs := make([]int, availableQueryNodeNum) + var realTopK int64 = -1 + var idx int64 + var j int64 + for idx = 0; idx < nq; idx++ { + locs := make([]int64, availableQueryNodeNum) - j := 0 + j = 0 for ; j < topk; j++ { valid := true choice, maxDistance := 0, minFloat32 @@ -1823,22 +1845,22 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat case *schemapb.VectorField_BinaryVector: if ret.Results.FieldsData[k].GetVectors().GetBinaryVector() == nil { bvec := &schemapb.VectorField_BinaryVector{ - BinaryVector: vectorType.BinaryVector[curIdx*int((dim/8)) : (curIdx+1)*int((dim/8))], + BinaryVector: vectorType.BinaryVector[curIdx*(dim/8) : (curIdx+1)*(dim/8)], } ret.Results.FieldsData[k].GetVectors().Data = bvec } else { - ret.Results.FieldsData[k].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector = append(ret.Results.FieldsData[k].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector, vectorType.BinaryVector[curIdx*int((dim/8)):(curIdx+1)*int((dim/8))]...) + ret.Results.FieldsData[k].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector = append(ret.Results.FieldsData[k].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector, vectorType.BinaryVector[curIdx*(dim/8):(curIdx+1)*(dim/8)]...) } case *schemapb.VectorField_FloatVector: if ret.Results.FieldsData[k].GetVectors().GetFloatVector() == nil { fvec := &schemapb.VectorField_FloatVector{ FloatVector: &schemapb.FloatArray{ - Data: vectorType.FloatVector.Data[curIdx*int(dim) : (curIdx+1)*int(dim)], + Data: vectorType.FloatVector.Data[curIdx*dim : (curIdx+1)*dim], }, } ret.Results.FieldsData[k].GetVectors().Data = fvec } else { - ret.Results.FieldsData[k].GetVectors().GetFloatVector().Data = append(ret.Results.FieldsData[k].GetVectors().GetFloatVector().Data, vectorType.FloatVector.Data[curIdx*int(dim):(curIdx+1)*int(dim)]...) + ret.Results.FieldsData[k].GetVectors().GetFloatVector().Data = append(ret.Results.FieldsData[k].GetVectors().GetFloatVector().Data, vectorType.FloatVector.Data[curIdx*dim:(curIdx+1)*dim]...) } } } @@ -1851,10 +1873,10 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat // return nil, errors.New("the length (topk) between all result of query is different") } realTopK = j - ret.Results.Topks = append(ret.Results.Topks, int64(realTopK)) + ret.Results.Topks = append(ret.Results.Topks, realTopK) } - ret.Results.TopK = int64(realTopK) + ret.Results.TopK = realTopK if metricType != "IP" { for k := range ret.Results.Scores { @@ -1865,12 +1887,12 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat return ret, nil } -func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string) (*milvuspb.SearchResults, error) { +func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, availableQueryNodeNum int64, metricType string) (*milvuspb.SearchResults, error) { t := time.Now() defer func() { log.Debug("reduceSearchResults", zap.Any("time cost", time.Since(t))) }() - return reduceSearchResultDataParallel(searchResultData, nq, availableQueryNodeNum, topk, metricType, runtime.NumCPU()) + return reduceSearchResultDataParallel(searchResultData, availableQueryNodeNum, metricType, runtime.NumCPU()) } func printSearchResult(partialSearchResult *internalpb.SearchResults) { @@ -1950,22 +1972,7 @@ func (st *SearchTask) PostExecute(ctx context.Context) error { return err } - nq := results[0].NumQueries - topk := 0 - for _, partialResult := range results { - topk = getMax(topk, int(partialResult.TopK)) - } - if nq <= 0 { - st.result = &milvuspb.SearchResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: filterReason, - }, - } - return nil - } - - st.result, err = reduceSearchResultData(results, int(nq), availableQueryNodeNum, topk, searchResults[0].MetricType) + st.result, err = reduceSearchResultData(results, int64(availableQueryNodeNum), searchResults[0].MetricType) if err != nil { return err } diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index ece6b02b12..a881afd78f 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -470,10 +470,10 @@ func (c *Core) setMsgStreams() error { Timestamps: pt, DefaultTimestamp: t, } - log.Debug("update timetick", - zap.Any("DefaultTs", t), - zap.Any("sourceID", c.session.ServerID), - zap.Any("reason", reason)) + //log.Debug("update timetick", + // zap.Any("DefaultTs", t), + // zap.Any("sourceID", c.session.ServerID), + // zap.Any("reason", reason)) return c.chanTimeTick.UpdateTimeTick(&ttMsg, reason) } diff --git a/internal/rootcoord/timeticksync.go b/internal/rootcoord/timeticksync.go index d9e2ed3f68..dbdf559a13 100644 --- a/internal/rootcoord/timeticksync.go +++ b/internal/rootcoord/timeticksync.go @@ -83,13 +83,13 @@ func newTimeTickSync(core *Core) *timetickSync { // sendToChannel send all channels' timetick to sendChan // lock is needed by the invoker -func (t *timetickSync) sendToChannel() error { +func (t *timetickSync) sendToChannel() { if len(t.proxyTimeTick) == 0 { - return fmt.Errorf("proxyTimeTick empty") + return } for _, v := range t.proxyTimeTick { if v == nil { - return fmt.Errorf("proxyTimeTick has not been fulfilled") + return } } // clear proxyTimeTick and send a clone @@ -99,7 +99,6 @@ func (t *timetickSync) sendToChannel() error { t.proxyTimeTick[k] = nil } t.sendChan <- ptt - return nil } // AddDmlTimeTick add ts into ddlTimetickInfos[sourceID], @@ -191,12 +190,10 @@ func (t *timetickSync) UpdateTimeTick(in *internalpb.ChannelTimeTickMsg, reason } t.proxyTimeTick[in.Base.SourceID] = newChannelTimeTickMsg(in) - log.Debug("update proxyTimeTick", zap.Int64("source id", in.Base.SourceID), - zap.Uint64("inTs", in.DefaultTimestamp), zap.String("reason", reason)) + //log.Debug("update proxyTimeTick", zap.Int64("source id", in.Base.SourceID), + // zap.Uint64("inTs", in.DefaultTimestamp), zap.String("reason", reason)) - if err := t.sendToChannel(); err != nil { - log.Debug("sendToChannel fail", zap.Any("err", err.Error())) - } + t.sendToChannel() return nil }