Simplify the merge logic of searchTask (#17194)

Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
pull/17198/head
zhenshan.cao 2022-05-24 21:27:59 +08:00 committed by GitHub
parent a4ea2fb18a
commit ec1103ca27
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 54 additions and 44 deletions

View File

@ -219,7 +219,12 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
if err != nil {
return err
}
t.SearchRequest.Topk = int64(topK)
err = validateTopK(int64(topK))
if err != nil {
return err
}

View File

@ -55,6 +55,14 @@ func isNumber(c uint8) bool {
return true
}
// validateTopK checks that topK is an acceptable search limit.
//
// It returns nil when topK lies in the inclusive range [1, 16384]; otherwise
// it returns an error that names the accepted range and the rejected value.
func validateTopK(topK int64) error {
	// TODO make this configurable
	// maxTopK is the largest accepted value. The error message is built from
	// it so the reported range cannot drift from the range actually enforced.
	const maxTopK = 16384
	if topK <= 0 || topK > maxTopK {
		return fmt.Errorf("limit should be in range [1, %d], but got %d", maxTopK, topK)
	}
	return nil
}
func validateCollectionNameOrAlias(entity, entityType string) error {
entity = strings.TrimSpace(entity)

View File

@ -44,25 +44,20 @@ type searchTask struct {
iReq *internalpb.SearchRequest
req *querypb.SearchRequest
MetricType string
PlaceholderGroup []byte
OrigPlaceHolderGroups [][]byte
NQ int64
OrigNQs []int64
TopK int64
OrigTopKs []int64
Ret *internalpb.SearchResults
originTasks []*searchTask
cpuOnce sync.Once
plan *planpb.PlanNode
qInfo *planpb.QueryInfo
MetricType string
PlaceholderGroup []byte
NQ int64
OrigNQs []int64
TopK int64
OrigTopKs []int64
Ret *internalpb.SearchResults
otherTasks []*searchTask
cpuOnce sync.Once
plan *planpb.PlanNode
qInfo *planpb.QueryInfo
}
// PreExecute checks that s.TopK is positive and below 16385, returning an
// error when it is not, and otherwise calls combinePlaceHolderGroups and
// returns nil. The ctx argument is not used in the body.
func (s *searchTask) PreExecute(ctx context.Context) error {
topK := s.TopK
// NOTE(review): the condition rejects 16385, yet the message below presents
// 16385 as the inclusive upper bound — confirm which one is intended.
if topK <= 0 || topK >= 16385 {
return fmt.Errorf("limit should be in range [1, 16385], but got %d", topK)
}
s.combinePlaceHolderGroups()
return nil
}
@ -162,11 +157,8 @@ func (s *searchTask) Execute(ctx context.Context) error {
func (s *searchTask) Notify(err error) {
s.done <- err
for i := 1; i < len(s.originTasks); i++ {
s.originTasks[i].Notify(err)
}
if len(s.originTasks) > 0 {
s.originTasks[0] = nil
for i := 0; i < len(s.otherTasks); i++ {
s.otherTasks[i].Notify(err)
}
}
@ -193,6 +185,8 @@ func (s *searchTask) CPUUsage() int32 {
// reduceResults reduces the search results.
func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchResult) error {
isEmpty := len(results) == 0
cnt := 1 + len(s.otherTasks)
var t *searchTask
if !isEmpty {
sInfo := parseSliceInfo(s.OrigNQs, s.OrigTopKs, s.NQ)
numSegment := int64(len(results))
@ -205,7 +199,7 @@ func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchRe
log.Debug("marshal for historical results error", zap.Error(err))
return err
}
for i := 0; i < len(s.originTasks); i++ {
for i := 0; i < cnt; i++ {
blob, err := getSearchResultDataBlob(blobs, i)
if err != nil {
log.Debug("getSearchResultDataBlob for historical results error", zap.Error(err))
@ -213,7 +207,12 @@ func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchRe
}
bs := make([]byte, len(blob))
copy(bs, blob)
s.originTasks[i].Ret = &internalpb.SearchResults{
if i == 0 {
t = s
} else {
t = s.otherTasks[i-1]
}
t.Ret = &internalpb.SearchResults{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
MetricType: s.MetricType,
NumQueries: s.OrigNQs[i],
@ -224,8 +223,13 @@ func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchRe
}
}
} else {
for i := 0; i < len(s.originTasks); i++ {
s.originTasks[i].Ret = &internalpb.SearchResults{
for i := 0; i < cnt; i++ {
if i == 0 {
t = s
} else {
t = s.otherTasks[i-1]
}
t.Ret = &internalpb.SearchResults{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
MetricType: s.MetricType,
NumQueries: s.OrigNQs[i],
@ -336,23 +340,18 @@ func (s *searchTask) Merge(t readTask) {
s.TopK = newTopK
s.OrigTopKs = append(s.OrigTopKs, src.OrigTopKs...)
s.OrigNQs = append(s.OrigNQs, src.OrigNQs...)
s.OrigPlaceHolderGroups = append(s.OrigPlaceHolderGroups, src.OrigPlaceHolderGroups...)
s.NQ += src.NQ
s.originTasks = append(s.originTasks, src)
s.otherTasks = append(s.otherTasks, src)
}
// combinePlaceHolderGroups combines all the placeholder groups.
func (s *searchTask) combinePlaceHolderGroups() {
if len(s.OrigPlaceHolderGroups) > 1 {
if len(s.otherTasks) > 0 {
ret := &commonpb.PlaceholderGroup{}
//retValues := ret.Placeholders[0].GetValues()
_ = proto.Unmarshal(s.PlaceholderGroup, ret)
for i, grp := range s.OrigPlaceHolderGroups {
if i == 0 {
continue
}
for _, t := range s.otherTasks {
x := &commonpb.PlaceholderGroup{}
_ = proto.Unmarshal(grp, x)
_ = proto.Unmarshal(t.PlaceholderGroup, x)
ret.Placeholders[0].Values = append(ret.Placeholders[0].Values, x.Placeholders[0].Values...)
}
s.PlaceholderGroup, _ = proto.Marshal(ret)
@ -376,20 +375,18 @@ func newSearchTask(ctx context.Context, src *querypb.SearchRequest) (*searchTask
tr: timerecord.NewTimeRecorder("searchTask"),
DataScope: src.GetScope(),
},
iReq: src.Req,
req: src,
TopK: src.Req.GetTopk(),
OrigTopKs: []int64{src.Req.GetTopk()},
NQ: src.Req.GetNq(),
OrigNQs: []int64{src.Req.GetNq()},
OrigPlaceHolderGroups: [][]byte{src.Req.GetPlaceholderGroup()},
PlaceholderGroup: src.Req.GetPlaceholderGroup(),
MetricType: src.Req.GetMetricType(),
iReq: src.Req,
req: src,
TopK: src.Req.GetTopk(),
OrigTopKs: []int64{src.Req.GetTopk()},
NQ: src.Req.GetNq(),
OrigNQs: []int64{src.Req.GetNq()},
PlaceholderGroup: src.Req.GetPlaceholderGroup(),
MetricType: src.Req.GetMetricType(),
}
err := target.init()
if err != nil {
return nil, err
}
target.originTasks = append(target.originTasks, target)
return target, nil
}