fix: Validate PlaceholderGroups before combine them (#32016)

See also #32015

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/32048/head
congqixia 2024-04-09 11:33:17 +08:00 committed by GitHub
parent 08bfb431b7
commit 1f7f3993a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 174 additions and 11 deletions

View File

@ -137,7 +137,10 @@ func (t *SearchTask) Execute() error {
tr := timerecord.NewTimeRecorderWithTrace(t.ctx, "SearchTask")
req := t.req
t.combinePlaceHolderGroups()
err := t.combinePlaceHolderGroups()
if err != nil {
return err
}
searchReq, err := segments.NewSearchRequest(t.ctx, t.collection, req, t.placeholderGroup)
if err != nil {
return err
@ -343,15 +346,28 @@ func (t *SearchTask) MergeWith(other Task) bool {
}
// combinePlaceHolderGroups combine all the placeholder groups.
func (t *SearchTask) combinePlaceHolderGroups() {
if len(t.others) > 0 {
ret := &commonpb.PlaceholderGroup{}
_ = proto.Unmarshal(t.placeholderGroup, ret)
for _, t := range t.others {
x := &commonpb.PlaceholderGroup{}
_ = proto.Unmarshal(t.placeholderGroup, x)
ret.Placeholders[0].Values = append(ret.Placeholders[0].Values, x.Placeholders[0].Values...)
}
t.placeholderGroup, _ = proto.Marshal(ret)
func (t *SearchTask) combinePlaceHolderGroups() error {
if len(t.others) == 0 {
return nil
}
ret := &commonpb.PlaceholderGroup{}
if err := proto.Unmarshal(t.placeholderGroup, ret); err != nil {
return merr.WrapErrParameterInvalidMsg("invalid search vector placeholder: %v", err)
}
if len(ret.GetPlaceholders()) == 0 {
return merr.WrapErrParameterInvalidMsg("empty search vector is not allowed")
}
for _, t := range t.others {
x := &commonpb.PlaceholderGroup{}
if err := proto.Unmarshal(t.placeholderGroup, x); err != nil {
return merr.WrapErrParameterInvalidMsg("invalid search vector placeholder: %v", err)
}
if len(x.GetPlaceholders()) == 0 {
return merr.WrapErrParameterInvalidMsg("empty search vector is not allowed")
}
ret.Placeholders[0].Values = append(ret.Placeholders[0].Values, x.Placeholders[0].Values...)
}
t.placeholderGroup, _ = proto.Marshal(ret)
return nil
}

View File

@ -0,0 +1,147 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tasks
import (
"bytes"
"encoding/binary"
"math/rand"
"testing"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/pkg/common"
)
type SearchTaskSuite struct {
suite.Suite
}
func (s *SearchTaskSuite) composePlaceholderGroup(nq int, dim int) []byte {
placeHolderGroup := &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
{
Tag: "$0",
Type: commonpb.PlaceholderType_FloatVector,
Values: lo.RepeatBy(nq, func(_ int) []byte {
bs := make([]byte, 0, dim*4)
for j := 0; j < dim; j++ {
var buffer bytes.Buffer
f := rand.Float32()
err := binary.Write(&buffer, common.Endian, f)
s.Require().NoError(err)
bs = append(bs, buffer.Bytes()...)
}
return bs
}),
},
},
}
bs, err := proto.Marshal(placeHolderGroup)
s.Require().NoError(err)
return bs
}
func (s *SearchTaskSuite) composeEmptyPlaceholderGroup() []byte {
placeHolderGroup := &commonpb.PlaceholderGroup{}
bs, err := proto.Marshal(placeHolderGroup)
s.Require().NoError(err)
return bs
}
func (s *SearchTaskSuite) TestCombinePlaceHolderGroups() {
s.Run("normal", func() {
task := &SearchTask{
placeholderGroup: s.composePlaceholderGroup(1, 128),
others: []*SearchTask{
{
placeholderGroup: s.composePlaceholderGroup(1, 128),
},
},
}
task.combinePlaceHolderGroups()
})
s.Run("tasked_not_merged", func() {
task := &SearchTask{}
err := task.combinePlaceHolderGroups()
s.NoError(err)
})
s.Run("empty_placeholdergroup", func() {
task := &SearchTask{
placeholderGroup: s.composeEmptyPlaceholderGroup(),
others: []*SearchTask{
{
placeholderGroup: s.composePlaceholderGroup(1, 128),
},
},
}
err := task.combinePlaceHolderGroups()
s.Error(err)
task = &SearchTask{
placeholderGroup: s.composePlaceholderGroup(1, 128),
others: []*SearchTask{
{
placeholderGroup: s.composeEmptyPlaceholderGroup(),
},
},
}
err = task.combinePlaceHolderGroups()
s.Error(err)
})
s.Run("unmarshal_fail", func() {
task := &SearchTask{
placeholderGroup: []byte{0x12, 0x34},
others: []*SearchTask{
{
placeholderGroup: s.composePlaceholderGroup(1, 128),
},
},
}
err := task.combinePlaceHolderGroups()
s.Error(err)
task = &SearchTask{
placeholderGroup: s.composePlaceholderGroup(1, 128),
others: []*SearchTask{
{
placeholderGroup: []byte{0x12, 0x34},
},
},
}
err = task.combinePlaceHolderGroups()
s.Error(err)
})
}
func TestSearchTask(t *testing.T) {
suite.Run(t, new(SearchTaskSuite))
}