package tasks import ( "bytes" "context" "fmt" "github.com/golang/protobuf/proto" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/util" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" ) type Task interface { Execute() error Done(err error) Canceled() error Wait() error } type SearchTask struct { ctx context.Context collection *segments.Collection segmentManager *segments.Manager req *querypb.SearchRequest result *internalpb.SearchResults topk int64 nq int64 placeholderGroup []byte originTopks []int64 originNqs []int64 others []*SearchTask notifier chan error tr *timerecord.TimeRecorder } func NewSearchTask(ctx context.Context, collection *segments.Collection, manager *segments.Manager, req *querypb.SearchRequest, ) *SearchTask { return &SearchTask{ ctx: ctx, collection: collection, segmentManager: manager, req: req, topk: req.GetReq().GetTopk(), nq: req.GetReq().GetNq(), placeholderGroup: req.GetReq().GetPlaceholderGroup(), originTopks: []int64{req.GetReq().GetTopk()}, originNqs: []int64{req.GetReq().GetNq()}, notifier: make(chan error, 1), tr: timerecord.NewTimeRecorderWithTrace(ctx, "searchTask"), } } func (t *SearchTask) Execute() error { log := log.Ctx(t.ctx).With( zap.Int64("collectionID", t.collection.ID()), zap.String("shard", t.req.GetDmlChannels()[0]), ) req := t.req t.combinePlaceHolderGroups() searchReq, err := segments.NewSearchRequest(t.collection, req, t.placeholderGroup) if err != nil { return err } defer searchReq.Delete() var results []*segments.SearchResult if req.GetScope() == querypb.DataScope_Historical { results, _, _, err = segments.SearchHistorical( t.ctx, t.segmentManager, searchReq, req.GetReq().GetCollectionID(), nil, req.GetSegmentIDs(), ) } else if req.GetScope() == querypb.DataScope_Streaming { results, _, _, err = segments.SearchStreaming( t.ctx, t.segmentManager, searchReq, req.GetReq().GetCollectionID(), nil, req.GetSegmentIDs(), ) } if err != nil { return err } defer segments.DeleteSearchResults(results) if len(results) == 0 { for i := range t.originNqs { var task *SearchTask if i == 0 { task = t } else { task = t.others[i-1] } task.result = &internalpb.SearchResults{ Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, MetricType: req.GetReq().GetMetricType(), NumQueries: t.originNqs[i], TopK: t.originTopks[i], SlicedOffset: 1, SlicedNumCount: 1, } } return nil } tr := timerecord.NewTimeRecorderWithTrace(t.ctx, "searchTaskReduce") blobs, err := segments.ReduceSearchResultsAndFillData( searchReq.Plan(), results, int64(len(results)), t.originNqs, t.originTopks, ) if err != nil { log.Warn("failed to reduce search results", zap.Error(err)) return err } defer segments.DeleteSearchResultDataBlobs(blobs) for i := range t.originNqs { blob, err := segments.GetSearchResultDataBlob(blobs, i) if err != nil { return err } var task *SearchTask if i == 0 { task = t } else { task = t.others[i-1] } // Note: blob is unsafe because get from C bs := make([]byte, len(blob)) copy(bs, blob) metrics.QueryNodeReduceLatency.WithLabelValues( fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel). Observe(float64(tr.ElapseSpan().Milliseconds())) task.result = &internalpb.SearchResults{ Status: util.WrapStatus(commonpb.ErrorCode_Success, ""), MetricType: req.GetReq().GetMetricType(), NumQueries: t.originNqs[i], TopK: t.originTopks[i], SlicedBlob: bs, SlicedOffset: 1, SlicedNumCount: 1, } } return nil } func (t *SearchTask) Merge(other *SearchTask) bool { var ( nq = t.nq topk = t.topk otherNq = other.req.GetReq().GetNq() otherTopk = other.req.GetReq().GetTopk() ) diffTopk := topk != otherTopk pre := funcutil.Min(nq*topk, otherNq*otherTopk) maxTopk := funcutil.Max(topk, otherTopk) after := (nq + otherNq) * maxTopk ratio := float64(after) / float64(pre) // Check mergeable if t.req.GetReq().GetDbID() != other.req.GetReq().GetDbID() || t.req.GetReq().GetCollectionID() != other.req.GetReq().GetCollectionID() || t.req.GetReq().GetTravelTimestamp() != other.req.GetReq().GetTravelTimestamp() || t.req.GetReq().GetDslType() != other.req.GetReq().GetDslType() || t.req.GetDmlChannels()[0] != other.req.GetDmlChannels()[0] || nq+otherNq > paramtable.Get().QueryNodeCfg.MaxGroupNQ.GetAsInt64() || diffTopk && ratio > paramtable.Get().QueryNodeCfg.TopKMergeRatio.GetAsFloat() || !funcutil.SliceSetEqual(t.req.GetReq().GetPartitionIDs(), other.req.GetReq().GetPartitionIDs()) || !funcutil.SliceSetEqual(t.req.GetSegmentIDs(), other.req.GetSegmentIDs()) || !bytes.Equal(t.req.GetReq().GetSerializedExprPlan(), other.req.GetReq().GetSerializedExprPlan()) { return false } // Merge t.topk = maxTopk t.nq += otherNq t.originTopks = append(t.originTopks, other.originTopks...) t.originNqs = append(t.originNqs, other.originNqs...) t.others = append(t.others, other) return true } func (t *SearchTask) Done(err error) { if len(t.others) > 0 { metrics.QueryNodeSearchGroupSize.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(len(t.others) + 1)) metrics.QueryNodeSearchGroupNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(t.originNqs[0])) metrics.QueryNodeSearchGroupTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(t.originTopks[0])) } select { case t.notifier <- err: default: } for _, other := range t.others { other.Done(err) } } func (t *SearchTask) Canceled() error { return t.ctx.Err() } func (t *SearchTask) Wait() error { return <-t.notifier } func (t *SearchTask) Result() *internalpb.SearchResults { return t.result } // combinePlaceHolderGroups combine all the placeholder groups. func (t *SearchTask) combinePlaceHolderGroups() { if len(t.others) > 0 { ret := &commonpb.PlaceholderGroup{} _ = proto.Unmarshal(t.placeholderGroup, ret) for _, t := range t.others { x := &commonpb.PlaceholderGroup{} _ = proto.Unmarshal(t.placeholderGroup, x) ret.Placeholders[0].Values = append(ret.Placeholders[0].Values, x.Placeholders[0].Values...) } t.placeholderGroup, _ = proto.Marshal(ret) } } type QueryTask struct { }