mirror of https://github.com/milvus-io/milvus.git
Simplify the merge logic of searchTask (#17194)
Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>pull/17198/head
parent
a4ea2fb18a
commit
ec1103ca27
|
@ -219,7 +219,12 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||
|
||||
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
|
||||
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.SearchRequest.Topk = int64(topK)
|
||||
err = validateTopK(int64(topK))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -55,6 +55,14 @@ func isNumber(c uint8) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func validateTopK(topK int64) error {
|
||||
// TODO make this configurable
|
||||
if topK <= 0 || topK >= 16385 {
|
||||
return fmt.Errorf("limit should be in range [1, 16385], but got %d", topK)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateCollectionNameOrAlias(entity, entityType string) error {
|
||||
entity = strings.TrimSpace(entity)
|
||||
|
||||
|
|
|
@ -44,25 +44,20 @@ type searchTask struct {
|
|||
iReq *internalpb.SearchRequest
|
||||
req *querypb.SearchRequest
|
||||
|
||||
MetricType string
|
||||
PlaceholderGroup []byte
|
||||
OrigPlaceHolderGroups [][]byte
|
||||
NQ int64
|
||||
OrigNQs []int64
|
||||
TopK int64
|
||||
OrigTopKs []int64
|
||||
Ret *internalpb.SearchResults
|
||||
originTasks []*searchTask
|
||||
cpuOnce sync.Once
|
||||
plan *planpb.PlanNode
|
||||
qInfo *planpb.QueryInfo
|
||||
MetricType string
|
||||
PlaceholderGroup []byte
|
||||
NQ int64
|
||||
OrigNQs []int64
|
||||
TopK int64
|
||||
OrigTopKs []int64
|
||||
Ret *internalpb.SearchResults
|
||||
otherTasks []*searchTask
|
||||
cpuOnce sync.Once
|
||||
plan *planpb.PlanNode
|
||||
qInfo *planpb.QueryInfo
|
||||
}
|
||||
|
||||
func (s *searchTask) PreExecute(ctx context.Context) error {
|
||||
topK := s.TopK
|
||||
if topK <= 0 || topK >= 16385 {
|
||||
return fmt.Errorf("limit should be in range [1, 16385], but got %d", topK)
|
||||
}
|
||||
s.combinePlaceHolderGroups()
|
||||
return nil
|
||||
}
|
||||
|
@ -162,11 +157,8 @@ func (s *searchTask) Execute(ctx context.Context) error {
|
|||
|
||||
func (s *searchTask) Notify(err error) {
|
||||
s.done <- err
|
||||
for i := 1; i < len(s.originTasks); i++ {
|
||||
s.originTasks[i].Notify(err)
|
||||
}
|
||||
if len(s.originTasks) > 0 {
|
||||
s.originTasks[0] = nil
|
||||
for i := 0; i < len(s.otherTasks); i++ {
|
||||
s.otherTasks[i].Notify(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -193,6 +185,8 @@ func (s *searchTask) CPUUsage() int32 {
|
|||
// reduceResults reduce search results
|
||||
func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchResult) error {
|
||||
isEmpty := len(results) == 0
|
||||
cnt := 1 + len(s.otherTasks)
|
||||
var t *searchTask
|
||||
if !isEmpty {
|
||||
sInfo := parseSliceInfo(s.OrigNQs, s.OrigTopKs, s.NQ)
|
||||
numSegment := int64(len(results))
|
||||
|
@ -205,7 +199,7 @@ func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchRe
|
|||
log.Debug("marshal for historical results error", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
for i := 0; i < len(s.originTasks); i++ {
|
||||
for i := 0; i < cnt; i++ {
|
||||
blob, err := getSearchResultDataBlob(blobs, i)
|
||||
if err != nil {
|
||||
log.Debug("getSearchResultDataBlob for historical results error", zap.Error(err))
|
||||
|
@ -213,7 +207,12 @@ func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchRe
|
|||
}
|
||||
bs := make([]byte, len(blob))
|
||||
copy(bs, blob)
|
||||
s.originTasks[i].Ret = &internalpb.SearchResults{
|
||||
if i == 0 {
|
||||
t = s
|
||||
} else {
|
||||
t = s.otherTasks[i-1]
|
||||
}
|
||||
t.Ret = &internalpb.SearchResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
MetricType: s.MetricType,
|
||||
NumQueries: s.OrigNQs[i],
|
||||
|
@ -224,8 +223,13 @@ func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchRe
|
|||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < len(s.originTasks); i++ {
|
||||
s.originTasks[i].Ret = &internalpb.SearchResults{
|
||||
for i := 0; i < cnt; i++ {
|
||||
if i == 0 {
|
||||
t = s
|
||||
} else {
|
||||
t = s.otherTasks[i-1]
|
||||
}
|
||||
t.Ret = &internalpb.SearchResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
MetricType: s.MetricType,
|
||||
NumQueries: s.OrigNQs[i],
|
||||
|
@ -336,23 +340,18 @@ func (s *searchTask) Merge(t readTask) {
|
|||
s.TopK = newTopK
|
||||
s.OrigTopKs = append(s.OrigTopKs, src.OrigTopKs...)
|
||||
s.OrigNQs = append(s.OrigNQs, src.OrigNQs...)
|
||||
s.OrigPlaceHolderGroups = append(s.OrigPlaceHolderGroups, src.OrigPlaceHolderGroups...)
|
||||
s.NQ += src.NQ
|
||||
s.originTasks = append(s.originTasks, src)
|
||||
s.otherTasks = append(s.otherTasks, src)
|
||||
}
|
||||
|
||||
// combinePlaceHolderGroups combine all the placeholder groups.
|
||||
func (s *searchTask) combinePlaceHolderGroups() {
|
||||
if len(s.OrigPlaceHolderGroups) > 1 {
|
||||
if len(s.otherTasks) > 0 {
|
||||
ret := &commonpb.PlaceholderGroup{}
|
||||
//retValues := ret.Placeholders[0].GetValues()
|
||||
_ = proto.Unmarshal(s.PlaceholderGroup, ret)
|
||||
for i, grp := range s.OrigPlaceHolderGroups {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
for _, t := range s.otherTasks {
|
||||
x := &commonpb.PlaceholderGroup{}
|
||||
_ = proto.Unmarshal(grp, x)
|
||||
_ = proto.Unmarshal(t.PlaceholderGroup, x)
|
||||
ret.Placeholders[0].Values = append(ret.Placeholders[0].Values, x.Placeholders[0].Values...)
|
||||
}
|
||||
s.PlaceholderGroup, _ = proto.Marshal(ret)
|
||||
|
@ -376,20 +375,18 @@ func newSearchTask(ctx context.Context, src *querypb.SearchRequest) (*searchTask
|
|||
tr: timerecord.NewTimeRecorder("searchTask"),
|
||||
DataScope: src.GetScope(),
|
||||
},
|
||||
iReq: src.Req,
|
||||
req: src,
|
||||
TopK: src.Req.GetTopk(),
|
||||
OrigTopKs: []int64{src.Req.GetTopk()},
|
||||
NQ: src.Req.GetNq(),
|
||||
OrigNQs: []int64{src.Req.GetNq()},
|
||||
OrigPlaceHolderGroups: [][]byte{src.Req.GetPlaceholderGroup()},
|
||||
PlaceholderGroup: src.Req.GetPlaceholderGroup(),
|
||||
MetricType: src.Req.GetMetricType(),
|
||||
iReq: src.Req,
|
||||
req: src,
|
||||
TopK: src.Req.GetTopk(),
|
||||
OrigTopKs: []int64{src.Req.GetTopk()},
|
||||
NQ: src.Req.GetNq(),
|
||||
OrigNQs: []int64{src.Req.GetNq()},
|
||||
PlaceholderGroup: src.Req.GetPlaceholderGroup(),
|
||||
MetricType: src.Req.GetMetricType(),
|
||||
}
|
||||
err := target.init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
target.originTasks = append(target.originTasks, target)
|
||||
return target, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue