enhance: Refactor hybrid search (#31742)

issue: https://github.com/milvus-io/milvus/issues/25639
https://github.com/milvus-io/milvus/issues/31368
pr: https://github.com/milvus-io/milvus/pull/32020

Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
pull/32038/head
zhenshan.cao 2024-04-09 10:15:18 +08:00 committed by GitHub
parent 39d988cf8d
commit 4c07304790
29 changed files with 1261 additions and 2331 deletions
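
Note on the overall change: this diff removes the dedicated HybridSearch RPC/task and routes hybrid ("advanced") search through the regular Search path, with the per-vector-field ANN requests carried as sub-requests. A minimal client-side sketch of the new request shape follows; the helper name and values are illustrative, while the milvuspb fields (SubReqs, SearchParams, SubSearchRequest) are the ones added or used later in this diff (see convertHybridSearchToSearch and the updated TestProxy).

// Minimal sketch, not the PR's own code: build one SearchRequest that carries
// several ANN sub-requests plus the rank/fusion parameters.
package example

import (
	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
)

func buildAdvancedSearchRequest(collection string, subs []*milvuspb.SubSearchRequest,
	rankParams []*commonpb.KeyValuePair,
) *milvuspb.SearchRequest {
	return &milvuspb.SearchRequest{
		CollectionName: collection,
		SearchParams:   rankParams, // rank/fusion params (strategy, limit, ...) ride in SearchParams
		SubReqs:        subs,       // one entry per ANN sub-request, replacing HybridSearchRequest.Requests
	}
}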

View File

@@ -332,10 +332,3 @@ func (c *Client) Delete(ctx context.Context, req *querypb.DeleteRequest, _ ...gr
return client.Delete(ctx, req)
})
}
// HybridSearch performs replica hybrid search tasks in QueryNode.
func (c *Client) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest, _ ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*querypb.HybridSearchResult, error) {
return client.HybridSearch(ctx, req)
})
}

View File

@@ -387,8 +387,3 @@ func (s *Server) SyncDistribution(ctx context.Context, req *querypb.SyncDistribu
func (s *Server) Delete(ctx context.Context, req *querypb.DeleteRequest) (*commonpb.Status, error) {
return s.querynode.Delete(ctx, req)
}
// HybridSearch performs hybrid search of streaming/historical replica on QueryNode.
func (s *Server) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
return s.querynode.HybridSearch(ctx, req)
}

View File

@@ -1258,8 +1258,8 @@ type DataCoordCatalog_SaveChannelCheckpoints_Call struct {
}
// SaveChannelCheckpoints is a helper method to define mock.On call
// - ctx context.Context
// - positions []*msgpb.MsgPosition
// - ctx context.Context
// - positions []*msgpb.MsgPosition
func (_e *DataCoordCatalog_Expecter) SaveChannelCheckpoints(ctx interface{}, positions interface{}) *DataCoordCatalog_SaveChannelCheckpoints_Call {
return &DataCoordCatalog_SaveChannelCheckpoints_Call{Call: _e.mock.On("SaveChannelCheckpoints", ctx, positions)}
}

View File

@@ -552,61 +552,6 @@ func (_c *MockQueryNode_GetTimeTickChannel_Call) RunAndReturn(run func(context.C
return _c
}
// HybridSearch provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNode) HybridSearch(_a0 context.Context, _a1 *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
ret := _m.Called(_a0, _a1)
var r0 *querypb.HybridSearchResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*querypb.HybridSearchResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryNode_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
type MockQueryNode_HybridSearch_Call struct {
*mock.Call
}
// HybridSearch is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.HybridSearchRequest
func (_e *MockQueryNode_Expecter) HybridSearch(_a0 interface{}, _a1 interface{}) *MockQueryNode_HybridSearch_Call {
return &MockQueryNode_HybridSearch_Call{Call: _e.mock.On("HybridSearch", _a0, _a1)}
}
func (_c *MockQueryNode_HybridSearch_Call) Run(run func(_a0 context.Context, _a1 *querypb.HybridSearchRequest)) *MockQueryNode_HybridSearch_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest))
})
return _c
}
func (_c *MockQueryNode_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNode_HybridSearch_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryNode_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockQueryNode_HybridSearch_Call {
_c.Call.Return(run)
return _c
}
// Init provides a mock function with given fields:
func (_m *MockQueryNode) Init() error {
ret := _m.Called()

View File

@@ -632,76 +632,6 @@ func (_c *MockQueryNodeClient_GetTimeTickChannel_Call) RunAndReturn(run func(con
return _c
}
// HybridSearch provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryNodeClient) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *querypb.HybridSearchResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) (*querypb.HybridSearchResult, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) *querypb.HybridSearchResult); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*querypb.HybridSearchResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryNodeClient_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
type MockQueryNodeClient_HybridSearch_Call struct {
*mock.Call
}
// HybridSearch is a helper method to define mock.On call
// - ctx context.Context
// - in *querypb.HybridSearchRequest
// - opts ...grpc.CallOption
func (_e *MockQueryNodeClient_Expecter) HybridSearch(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_HybridSearch_Call {
return &MockQueryNodeClient_HybridSearch_Call{Call: _e.mock.On("HybridSearch",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockQueryNodeClient_HybridSearch_Call) Run(run func(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_HybridSearch_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest), variadicArgs...)
})
return _c
}
func (_c *MockQueryNodeClient_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNodeClient_HybridSearch_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryNodeClient_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) (*querypb.HybridSearchResult, error)) *MockQueryNodeClient_HybridSearch_Call {
_c.Call.Return(run)
return _c
}
// LoadPartitions provides a mock function with given fields: ctx, in, opts
func (_m *MockQueryNodeClient) LoadPartitions(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
_va := make([]interface{}, len(opts))

View File

@@ -82,6 +82,20 @@ message CreateIndexRequest {
repeated common.KeyValuePair extra_params = 8;
}
message SubSearchRequest {
string dsl = 1;
// serialized `PlaceholderGroup`
bytes placeholder_group = 2;
common.DslType dsl_type = 3;
bytes serialized_expr_plan = 4;
int64 nq = 5;
repeated int64 partitionIDs = 6;
int64 topk = 7;
int64 offset = 8;
string metricType = 9;
}
message SearchRequest {
common.MsgBase base = 1;
int64 reqID = 2;
@@ -102,18 +116,22 @@ message SearchRequest {
string metricType = 16;
bool ignoreGrowing = 17; // Optional
string username = 18;
repeated SubSearchRequest sub_reqs = 19;
bool is_advanced = 20;
int64 offset = 21;
common.ConsistencyLevel consistency_level = 22;
}
message HybridSearchRequest {
common.MsgBase base = 1;
int64 reqID = 2;
int64 dbID = 3;
int64 collectionID = 4;
repeated int64 partitionIDs = 5;
repeated SearchRequest reqs = 6;
uint64 mvcc_timestamp = 11;
uint64 guarantee_timestamp = 12;
uint64 timeout_timestamp = 13;
message SubSearchResults {
string metric_type = 1;
int64 num_queries = 2;
int64 top_k = 3;
// schema.SearchResultsData inside
bytes sliced_blob = 4;
int64 sliced_num_count = 5;
int64 sliced_offset = 6;
// to indicate it belongs to which sub request
int64 req_index = 7;
}
message SearchResults {
@@ -134,6 +152,8 @@ message SearchResults {
// search request cost
CostAggregation costAggregation = 13;
map<string, uint64> channels_mvcc = 14;
repeated SubSearchResults sub_results = 15;
bool is_advanced = 16;
}
message CostAggregation {
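
A minimal sketch of how the new per-sub-request payloads could be demultiplexed on the receiving side, grouping SubSearchResults by req_index ("to indicate it belongs to which sub request"). Getter names assume standard protoc-gen-go generation for the fields added above (is_advanced, sub_results, req_index); this is an illustration, not the reducer this PR ships.

// Sketch only: split an advanced SearchResults into per-sub-request groups.
package example

import "github.com/milvus-io/milvus/internal/proto/internalpb"

func groupSubResults(res *internalpb.SearchResults) map[int64][]*internalpb.SubSearchResults {
	grouped := make(map[int64][]*internalpb.SubSearchResults)
	if !res.GetIsAdvanced() {
		return grouped // plain search: no sub results to split
	}
	for _, sub := range res.GetSubResults() {
		grouped[sub.GetReqIndex()] = append(grouped[sub.GetReqIndex()], sub)
	}
	return grouped
}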

View File

@@ -139,8 +139,6 @@ service QueryNode {
}
rpc Search(SearchRequest) returns (internal.SearchResults) {
}
rpc HybridSearch(HybridSearchRequest) returns (HybridSearchResult) {
}
rpc SearchSegments(SearchRequest) returns (internal.SearchResults) {
}
rpc Query(QueryRequest) returns (internal.RetrieveResults) {
@@ -416,20 +414,6 @@ message SearchRequest {
int32 total_channel_num = 6;
}
message HybridSearchRequest {
internal.HybridSearchRequest req = 1;
repeated string dml_channels = 2;
int32 total_channel_num = 3;
}
message HybridSearchResult {
common.MsgBase base = 1;
common.Status status = 2;
repeated internal.SearchResults results = 3;
internal.CostAggregation costAggregation = 4;
map<string, uint64> channels_mvcc = 5;
}
message QueryRequest {
internal.RetrieveRequest req = 1;
repeated string dml_channels = 2;

View File

@@ -2825,18 +2825,19 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch")
defer sp.End()
qt := &hybridSearchTask{
newSearchReq := convertHybridSearchToSearch(request)
qt := &searchTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
HybridSearchRequest: &internalpb.HybridSearchRequest{
SearchRequest: &internalpb.SearchRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Search),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
ReqID: paramtable.GetNodeID(),
},
request: request,
tr: timerecord.NewTimeRecorder(method),
request: newSearchReq,
tr: timerecord.NewTimeRecorder("search"),
qc: node.queryCoord,
node: node,
lb: node.lbPolicy,
@@ -2921,7 +2922,7 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
metrics.SuccessLabel,
).Inc()
metrics.ProxySearchVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(len(qt.request.GetRequests())))
metrics.ProxySearchVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(len(request.GetRequests())))
searchDur := tr.ElapseSpan().Milliseconds()
metrics.ProxySQLatency.WithLabelValues(

View File

@@ -1562,6 +1562,31 @@ func TestProxy(t *testing.T) {
}
}
constructSubSearchRequest := func(nq int) *milvuspb.SubSearchRequest {
plg := constructVectorsPlaceholderGroup(nq)
plgBs, err := proto.Marshal(plg)
assert.NoError(t, err)
params := make(map[string]string)
params["nprobe"] = strconv.Itoa(nprobe)
b, err := json.Marshal(params)
assert.NoError(t, err)
searchParams := []*commonpb.KeyValuePair{
{Key: MetricTypeKey, Value: metric.L2},
{Key: SearchParamsKey, Value: string(b)},
{Key: AnnsFieldKey, Value: floatVecField},
{Key: TopKKey, Value: strconv.Itoa(topk)},
{Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)},
}
return &milvuspb.SubSearchRequest{
Dsl: expr,
PlaceholderGroup: plgBs,
DslType: commonpb.DslType_BoolExprV1,
SearchParams: searchParams,
}
}
wg.Add(1)
t.Run("search", func(t *testing.T) {
defer wg.Done()
@@ -1572,7 +1597,7 @@ func TestProxy(t *testing.T) {
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
constructHybridSearchRequest := func(reqs []*milvuspb.SearchRequest) *milvuspb.HybridSearchRequest {
constructAdvancedSearchRequest := func() *milvuspb.SearchRequest {
params := make(map[string]float64)
params[RRFParamsKey] = 60
b, err := json.Marshal(params)
@@ -1584,32 +1609,33 @@ func TestProxy(t *testing.T) {
{Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)},
}
return &milvuspb.HybridSearchRequest{
req1 := constructSubSearchRequest(nq)
req2 := constructSubSearchRequest(nq)
ret := &milvuspb.SearchRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Requests: reqs,
PartitionNames: nil,
OutputFields: nil,
RankParams: rankParams,
SearchParams: rankParams,
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
}
ret.SubReqs = append(ret.SubReqs, req1)
ret.SubReqs = append(ret.SubReqs, req2)
return ret
}
wg.Add(1)
nq = 1
t.Run("hybrid search", func(t *testing.T) {
t.Run("advanced search", func(t *testing.T) {
defer wg.Done()
req1 := constructSearchRequest(nq)
req2 := constructSearchRequest(nq)
resp, err := proxy.HybridSearch(ctx, constructHybridSearchRequest([]*milvuspb.SearchRequest{req1, req2}))
req := constructAdvancedSearchRequest()
resp, err := proxy.Search(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
nq = 10
nq = 10
constructPrimaryKeysPlaceholderGroup := func() *commonpb.PlaceholderGroup {
expr := fmt.Sprintf("%v in [%v]", int64Field, insertedIds[0])
exprBytes := []byte(expr)

View File

@@ -3,7 +3,9 @@ package proxy
import (
"encoding/json"
"fmt"
"math"
"reflect"
"strings"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
@@ -13,6 +15,7 @@ import (
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
)
type rankType int
@@ -35,16 +38,27 @@ type reScorer interface {
name() string
scorerType() rankType
reScore(input *milvuspb.SearchResults)
setMetricType(metricType string)
getMetricType() string
}
type baseScorer struct {
scorerName string
metricType string
}
func (bs *baseScorer) name() string {
return bs.scorerName
}
func (bs *baseScorer) setMetricType(metricType string) {
bs.metricType = metricType
}
func (bs *baseScorer) getMetricType() string {
return bs.metricType
}
type rrfScorer struct {
baseScorer
k float32
@@ -65,9 +79,36 @@ type weightedScorer struct {
weight float32
}
type activateFunc func(float32) float32
func (ws *weightedScorer) getActivateFunc() activateFunc {
mUpper := strings.ToUpper(ws.getMetricType())
isCosine := mUpper == strings.ToUpper(metric.COSINE)
isIP := mUpper == strings.ToUpper(metric.IP)
if isCosine {
f := func(distance float32) float32 {
return (1 + distance) * 0.5
}
return f
}
if isIP {
f := func(distance float32) float32 {
return 0.5 + float32(math.Atan(float64(distance)))/math.Pi
}
return f
}
f := func(distance float32) float32 {
return 1.0 - 2*float32(math.Atan(float64(distance)))/math.Pi
}
return f
}
func (ws *weightedScorer) reScore(input *milvuspb.SearchResults) {
for i, score := range input.Results.GetScores() {
input.Results.Scores[i] = ws.weight * score
activateF := ws.getActivateFunc()
for i, distance := range input.Results.GetScores() {
input.Results.Scores[i] = ws.weight * activateF(distance)
}
}
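
For reference, the three normalizations getActivateFunc applies before weighting: cosine similarity in [-1, 1] is mapped to [0, 1], inner product is squashed with arctan, and distance-like metrics (e.g. L2) are inverted so smaller distances score higher. A standalone sketch with illustrative inputs, not the proxy code itself:

// Standalone illustration of the activation functions introduced above.
package main

import (
	"fmt"
	"math"
)

func activate(metricType string, distance float32) float32 {
	switch metricType {
	case "COSINE": // similarity in [-1, 1] -> score in [0, 1]
		return (1 + distance) * 0.5
	case "IP": // unbounded similarity -> score in (0, 1)
		return 0.5 + float32(math.Atan(float64(distance)))/math.Pi
	default: // distance metrics (e.g. L2): smaller distance -> higher score
		return 1.0 - 2*float32(math.Atan(float64(distance)))/math.Pi
	}
}

func main() {
	fmt.Println(activate("COSINE", 0.8)) // 0.9
	fmt.Println(activate("IP", 3.2))     // ~0.90
	fmt.Println(activate("L2", 0.4))     // ~0.76
}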
@@ -75,13 +116,17 @@ func (ws *weightedScorer) scorerType() rankType {
return weightedRankType
}
func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValuePair) ([]reScorer, error) {
res := make([]reScorer, len(reqs))
func NewReScorers(reqCnt int, rankParams []*commonpb.KeyValuePair) ([]reScorer, error) {
if reqCnt == 0 {
return []reScorer{}, nil
}
res := make([]reScorer, reqCnt)
rankTypeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankTypeKey, rankParams)
if err != nil {
log.Info("rank strategy not specified, use rrf instead")
// if not set rank strategy, use rrf rank as default
for i := range reqs {
for i := 0; i < reqCnt; i++ {
res[i] = &rrfScorer{
baseScorer: baseScorer{
scorerName: "rrf",
@@ -123,7 +168,7 @@ func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValue
return nil, errors.New(fmt.Sprintf("The rank params k should be in range (0, %d)", maxRRFParamsValue))
}
log.Debug("rrf params", zap.Float64("k", k))
for i := range reqs {
for i := 0; i < reqCnt; i++ {
res[i] = &rrfScorer{
baseScorer: baseScorer{
scorerName: "rrf",
@@ -156,10 +201,10 @@ func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValue
}
log.Debug("weights params", zap.Any("weights", weights))
if len(reqs) != len(weights) {
return nil, merr.WrapErrParameterInvalid(fmt.Sprint(len(reqs)), fmt.Sprint(len(weights)), "the length of weights param mismatch with ann search requests")
if reqCnt != len(weights) {
return nil, merr.WrapErrParameterInvalid(fmt.Sprint(reqCnt), fmt.Sprint(len(weights)), "the length of weights param mismatch with ann search requests")
}
for i := range reqs {
for i := 0; i < reqCnt; i++ {
res[i] = &weightedScorer{
baseScorer: baseScorer{
scorerName: "weighted",

View File

@@ -7,12 +7,11 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
)
func TestRescorer(t *testing.T) {
t.Run("default scorer", func(t *testing.T) {
rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, nil)
rescorers, err := NewReScorers(2, nil)
assert.NoError(t, err)
assert.Equal(t, 2, len(rescorers))
assert.Equal(t, rrfRankType, rescorers[0].scorerType())
@@ -27,7 +26,7 @@ func TestRescorer(t *testing.T) {
{Key: RankParamsKey, Value: string(b)},
}
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
_, err = NewReScorers(2, rankParams)
assert.Error(t, err)
assert.Contains(t, err.Error(), "k not found in rank_params")
})
@@ -42,7 +41,7 @@ func TestRescorer(t *testing.T) {
{Key: RankParamsKey, Value: string(b)},
}
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
_, err = NewReScorers(2, rankParams)
assert.Error(t, err)
params[RRFParamsKey] = maxRRFParamsValue + 1
@@ -53,7 +52,7 @@ func TestRescorer(t *testing.T) {
{Key: RankParamsKey, Value: string(b)},
}
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
_, err = NewReScorers(2, rankParams)
assert.Error(t, err)
})
@@ -67,7 +66,7 @@ func TestRescorer(t *testing.T) {
{Key: RankParamsKey, Value: string(b)},
}
rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
rescorers, err := NewReScorers(2, rankParams)
assert.NoError(t, err)
assert.Equal(t, 2, len(rescorers))
assert.Equal(t, rrfRankType, rescorers[0].scorerType())
@@ -83,7 +82,7 @@ func TestRescorer(t *testing.T) {
{Key: RankParamsKey, Value: string(b)},
}
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
_, err = NewReScorers(2, rankParams)
assert.Error(t, err)
assert.Contains(t, err.Error(), "not found in rank_params")
})
@@ -99,7 +98,7 @@ func TestRescorer(t *testing.T) {
{Key: RankParamsKey, Value: string(b)},
}
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
_, err = NewReScorers(2, rankParams)
assert.Error(t, err)
assert.Contains(t, err.Error(), "rank param weight should be in range [0, 1]")
})
@@ -115,7 +114,7 @@ func TestRescorer(t *testing.T) {
{Key: RankParamsKey, Value: string(b)},
}
rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
rescorers, err := NewReScorers(2, rankParams)
assert.NoError(t, err)
assert.Equal(t, 2, len(rescorers))
assert.Equal(t, weightedRankType, rescorers[0].scorerType())

View File

@@ -3,6 +3,8 @@ package proxy
import (
"context"
"fmt"
"math"
"sort"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
@@ -379,3 +381,134 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
}
return ret, nil
}
func rankSearchResultData(ctx context.Context,
nq int64,
params *rankParams,
pkType schemapb.DataType,
searchResults []*milvuspb.SearchResults,
) (*milvuspb.SearchResults, error) {
tr := timerecord.NewTimeRecorder("rankSearchResultData")
defer func() {
tr.CtxElapse(ctx, "done")
}()
offset := params.offset
limit := params.limit
topk := limit + offset
roundDecimal := params.roundDecimal
log.Ctx(ctx).Debug("rankSearchResultData",
zap.Int("len(searchResults)", len(searchResults)),
zap.Int64("nq", nq),
zap.Int64("offset", offset),
zap.Int64("limit", limit))
ret := &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: nq,
TopK: limit,
FieldsData: make([]*schemapb.FieldData, 0),
Scores: []float32{},
Ids: &schemapb.IDs{},
Topks: []int64{},
},
}
switch pkType {
case schemapb.DataType_Int64:
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: make([]int64, 0),
},
}
case schemapb.DataType_VarChar:
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: make([]string, 0),
},
}
default:
return nil, errors.New("unsupported pk type")
}
// []map[id]score
accumulatedScores := make([]map[interface{}]float32, nq)
for i := int64(0); i < nq; i++ {
accumulatedScores[i] = make(map[interface{}]float32)
}
for _, result := range searchResults {
scores := result.GetResults().GetScores()
start := int64(0)
for i := int64(0); i < nq; i++ {
realTopk := result.GetResults().Topks[i]
for j := start; j < start+realTopk; j++ {
id := typeutil.GetPK(result.GetResults().GetIds(), j)
accumulatedScores[i][id] += scores[j]
}
start += realTopk
}
}
for i := int64(0); i < nq; i++ {
idSet := accumulatedScores[i]
keys := make([]interface{}, 0)
for key := range idSet {
keys = append(keys, key)
}
if int64(len(keys)) <= offset {
ret.Results.Topks = append(ret.Results.Topks, 0)
continue
}
compareKeys := func(keyI, keyJ interface{}) bool {
switch keyI.(type) {
case int64:
return keyI.(int64) < keyJ.(int64)
case string:
return keyI.(string) < keyJ.(string)
}
return false
}
// sort id by score
big := func(i, j int) bool {
if idSet[keys[i]] == idSet[keys[j]] {
return compareKeys(keys[i], keys[j])
}
return idSet[keys[i]] > idSet[keys[j]]
}
sort.Slice(keys, big)
if int64(len(keys)) > topk {
keys = keys[:topk]
}
// set real topk
ret.Results.Topks = append(ret.Results.Topks, int64(len(keys))-offset)
// append id and score
for index := offset; index < int64(len(keys)); index++ {
typeutil.AppendPKs(ret.Results.Ids, keys[index])
score := idSet[keys[index]]
if roundDecimal != -1 {
multiplier := math.Pow(10.0, float64(roundDecimal))
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
}
ret.Results.Scores = append(ret.Results.Scores, score)
}
}
return ret, nil
}
func fillInEmptyResult(numQueries int64) *milvuspb.SearchResults {
return &milvuspb.SearchResults{
Status: merr.Success("search result is empty"),
Results: &schemapb.SearchResultData{
NumQueries: numQueries,
Topks: make([]int64, numQueries),
},
}
}
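
A simplified, self-contained illustration of the per-query fusion step in rankSearchResultData above: the already rescored scores are summed per primary key across the sub-results, candidates are sorted by fused score with ties broken by key, and the [offset, offset+limit) window is kept (the proxy code computes topk = limit + offset and reports len(keys) - offset as the per-query topk). The fuse helper and plain-map types here are illustrative, not the proxy API.

// Sketch only: fuse per-sub-result scores for one query.
package main

import (
	"fmt"
	"sort"
)

func fuse(subScores []map[int64]float32, offset, limit int) []int64 {
	// accumulate scores per primary key across all sub-results
	acc := make(map[int64]float32)
	for _, scores := range subScores {
		for id, s := range scores {
			acc[id] += s
		}
	}
	ids := make([]int64, 0, len(acc))
	for id := range acc {
		ids = append(ids, id)
	}
	// higher fused score first; ties broken by key, as in the proxy code
	sort.Slice(ids, func(i, j int) bool {
		if acc[ids[i]] == acc[ids[j]] {
			return ids[i] < ids[j]
		}
		return acc[ids[i]] > acc[ids[j]]
	})
	if len(ids) > offset+limit {
		ids = ids[:offset+limit]
	}
	if len(ids) <= offset {
		return nil
	}
	return ids[offset:]
}

func main() {
	a := map[int64]float32{1: 0.9, 2: 0.5}
	b := map[int64]float32{2: 0.7, 3: 0.4}
	fmt.Println(fuse([]map[int64]float32{a, b}, 0, 2)) // [2 1]
}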

View File

@@ -3,163 +3,309 @@ package proxy
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"go.opentelemetry.io/otel"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func initSearchRequest(ctx context.Context, t *searchTask, isHybrid bool) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init search request")
defer sp.End()
type rankParams struct {
limit int64
offset int64
roundDecimal int64
}
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
// fetch search_growing from search param
var ignoreGrowing bool
var err error
for i, kv := range t.request.GetSearchParams() {
if kv.GetKey() == IgnoreGrowingKey {
ignoreGrowing, err = strconv.ParseBool(kv.GetValue())
if err != nil {
return errors.New("parse search growing failed")
}
t.request.SearchParams = append(t.request.GetSearchParams()[:i], t.request.GetSearchParams()[i+1:]...)
break
}
}
t.SearchRequest.IgnoreGrowing = ignoreGrowing
// Manually update nq if not set.
nq, err := getNq(t.request)
// parseSearchInfo returns QueryInfo and offset
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, ignoreOffset bool) (*planpb.QueryInfo, int64, error) {
// 1. parse offset and real topk
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
if err != nil {
log.Warn("failed to get nq", zap.Error(err))
return err
return nil, 0, errors.New(TopKKey + " not found in search_params")
}
// Check if nq is valid:
// https://milvus.io/docs/limitations.md
if err := validateNQLimit(nq); err != nil {
return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
}
t.SearchRequest.Nq = nq
log = log.With(zap.Int64("nq", nq))
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
topK, err := strconv.ParseInt(topKStr, 0, 64)
if err != nil {
log.Warn("fail to get output field ids", zap.Error(err))
return err
return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
}
if err := validateTopKLimit(topK); err != nil {
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
}
t.SearchRequest.OutputFieldsId = outputFieldIDs
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
if err != nil || len(annsFieldName) == 0 {
vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
if len(vecFields) == 0 {
return errors.New(AnnsFieldKey + " not found in schema")
}
if enableMultipleVectorFields && len(vecFields) > 1 {
return errors.New("multiple anns_fields exist, please specify a anns_field in search_params")
}
annsFieldName = vecFields[0].Name
}
queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams(), t.schema.CollectionSchema)
if err != nil {
return err
}
annField := typeutil.GetFieldByName(t.schema.CollectionSchema, annsFieldName)
if queryInfo.GetGroupByFieldId() != -1 && isHybrid {
return errors.New("not support search_group_by operation in the hybrid search")
}
if queryInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
return errors.New("not support search_group_by operation based on binary vector column")
}
t.offset = offset
plan, err := planparserv2.CreateSearchPlan(t.schema.schemaHelper, t.request.Dsl, annsFieldName, queryInfo)
if err != nil {
log.Warn("failed to create query plan", zap.Error(err),
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
zap.String("anns field", annsFieldName), zap.Any("query info", queryInfo))
return merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", err)
}
log.Debug("create query plan",
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
zap.String("anns field", annsFieldName), zap.Any("query info", queryInfo))
if t.partitionKeyMode {
expr, err := ParseExprFromPlan(plan)
var offset int64
if !ignoreOffset {
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, searchParamsPair)
if err == nil {
offset, err = strconv.ParseInt(offsetStr, 0, 64)
if err != nil {
log.Warn("failed to parse expr", zap.Error(err))
return err
}
partitionKeys := ParsePartitionKeys(expr)
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.collectionName, partitionKeys)
if err != nil {
log.Warn("failed to assign partition keys", zap.Error(err))
return err
return nil, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
}
if len(hashedPartitionNames) > 0 {
// translate partition name to partition ids. Use regex-pattern to match partition name.
t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.collectionName, hashedPartitionNames)
if err != nil {
log.Warn("failed to get partition ids", zap.Error(err))
return err
if offset != 0 {
if err := validateTopKLimit(offset); err != nil {
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)
}
}
}
plan.OutputFieldIds = outputFieldIDs
t.SearchRequest.Topk = queryInfo.GetTopk()
t.SearchRequest.MetricType = queryInfo.GetMetricType()
t.queryInfo = queryInfo
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk)
if err != nil {
log.Warn("failed to estimate result size", zap.Error(err))
return err
}
if estimateSize >= requeryThreshold {
t.requery = true
plan.OutputFieldIds = nil
}
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
if err != nil {
return err
}
log.Debug("proxy init search request",
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
zap.Stringer("plan", plan)) // may be very large if large term passed.
}
if deadline, ok := t.TraceCtx().Deadline(); ok {
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
queryTopK := topK + offset
if err := validateTopKLimit(queryTopK); err != nil {
return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)
}
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
// Set username of this search request for feature like task scheduling.
if username, _ := GetCurUserFromContext(ctx); username != "" {
t.SearchRequest.Username = username
// 2. parse metrics type
metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.MetricTypeKey, searchParamsPair)
if err != nil {
metricType = ""
}
return nil
// 3. parse round decimal
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, searchParamsPair)
if err != nil {
roundDecimalStr = "-1"
}
roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
if err != nil {
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}
// 4. parse search param str
searchParamStr, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, searchParamsPair)
if err != nil {
searchParamStr = ""
}
err = checkRangeSearchParams(searchParamStr, metricType)
if err != nil {
return nil, 0, err
}
// 5. parse group by field
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
if err != nil {
groupByFieldName = ""
}
var groupByFieldId int64 = -1
if groupByFieldName != "" {
fields := schema.GetFields()
for _, field := range fields {
if field.Name == groupByFieldName {
groupByFieldId = field.FieldID
break
}
}
if groupByFieldId == -1 {
return nil, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
}
}
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
if isIterator == "True" && groupByFieldId > 0 {
return nil, 0, merr.WrapErrParameterInvalid("", "",
"Not allowed to do groupBy when doing iteration")
}
if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 {
return nil, 0, merr.WrapErrParameterInvalid("", "",
"Not allowed to do range-search when doing search-group-by")
}
return &planpb.QueryInfo{
Topk: queryTopK,
MetricType: metricType,
SearchParams: searchParamStr,
RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId,
}, offset, nil
}
func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) {
outputFieldIDs = make([]UniqueID, 0, len(outputFields))
for _, name := range outputFields {
id, ok := schema.MapFieldID(name)
if !ok {
return nil, fmt.Errorf("Field %s not exist", name)
}
outputFieldIDs = append(outputFieldIDs, id)
}
return outputFieldIDs, nil
}
func getNqFromSubSearch(req *milvuspb.SubSearchRequest) (int64, error) {
if req.GetNq() == 0 {
// keep compatible with older client version.
x := &commonpb.PlaceholderGroup{}
err := proto.Unmarshal(req.GetPlaceholderGroup(), x)
if err != nil {
return 0, err
}
total := int64(0)
for _, h := range x.GetPlaceholders() {
total += int64(len(h.Values))
}
return total, nil
}
return req.GetNq(), nil
}
func getNq(req *milvuspb.SearchRequest) (int64, error) {
if req.GetNq() == 0 {
// keep compatible with older client version.
x := &commonpb.PlaceholderGroup{}
err := proto.Unmarshal(req.GetPlaceholderGroup(), x)
if err != nil {
return 0, err
}
total := int64(0)
for _, h := range x.GetPlaceholders() {
total += int64(len(h.Values))
}
return total, nil
}
return req.GetNq(), nil
}
func getPartitionIDs(ctx context.Context, dbName string, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
for _, tag := range partitionNames {
if err := validatePartitionTag(tag, false); err != nil {
return nil, err
}
}
partitionsMap, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName)
if err != nil {
return nil, err
}
useRegexp := Params.ProxyCfg.PartitionNameRegexp.GetAsBool()
partitionsSet := typeutil.NewSet[int64]()
for _, partitionName := range partitionNames {
if useRegexp {
// Legacy feature, use partition name as regexp
pattern := fmt.Sprintf("^%s$", partitionName)
re, err := regexp.Compile(pattern)
if err != nil {
return nil, fmt.Errorf("invalid partition: %s", partitionName)
}
var found bool
for name, pID := range partitionsMap {
if re.MatchString(name) {
partitionsSet.Insert(pID)
found = true
}
}
if !found {
return nil, fmt.Errorf("partition name %s not found", partitionName)
}
} else {
partitionID, found := partitionsMap[partitionName]
if !found {
// TODO change after testcase updated: return nil, merr.WrapErrPartitionNotFound(partitionName)
return nil, fmt.Errorf("partition name %s not found", partitionName)
}
if !partitionsSet.Contain(partitionID) {
partitionsSet.Insert(partitionID)
}
}
}
return partitionsSet.Collect(), nil
}
// parseRankParams get limit and offset from rankParams, both are optional.
func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, error) {
var (
limit int64
offset int64
roundDecimal int64
err error
)
limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, rankParamsPair)
if err != nil {
return nil, errors.New(LimitKey + " not found in rank_params")
}
limit, err = strconv.ParseInt(limitStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("%s [%s] is invalid", LimitKey, limitStr)
}
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, rankParamsPair)
if err == nil {
offset, err = strconv.ParseInt(offsetStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
}
}
// validate max result window.
if err = validateMaxQueryResultWindow(offset, limit); err != nil {
return nil, fmt.Errorf("invalid max query result window, %w", err)
}
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, rankParamsPair)
if err != nil {
roundDecimalStr = "-1"
}
roundDecimal, err = strconv.ParseInt(roundDecimalStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}
return &rankParams{
limit: limit,
offset: offset,
roundDecimal: roundDecimal,
}, nil
}
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
ret := &milvuspb.SearchRequest{
Base: req.GetBase(),
DbName: req.GetDbName(),
CollectionName: req.GetCollectionName(),
PartitionNames: req.GetPartitionNames(),
OutputFields: req.GetOutputFields(),
SearchParams: req.GetRankParams(),
TravelTimestamp: req.GetTravelTimestamp(),
GuaranteeTimestamp: req.GetGuaranteeTimestamp(),
Nq: 0,
NotReturnAllMeta: req.GetNotReturnAllMeta(),
ConsistencyLevel: req.GetConsistencyLevel(),
UseDefaultConsistency: req.GetUseDefaultConsistency(),
SearchByPrimaryKeys: false,
SubReqs: nil,
}
for _, sub := range req.GetRequests() {
subReq := &milvuspb.SubSearchRequest{
Dsl: sub.GetDsl(),
PlaceholderGroup: sub.GetPlaceholderGroup(),
DslType: sub.GetDslType(),
SearchParams: sub.GetSearchParams(),
Nq: sub.GetNq(),
}
ret.SubReqs = append(ret.SubReqs, subReq)
}
return ret
}

View File

@@ -1,659 +0,0 @@
package proxy
import (
"context"
"fmt"
"math"
"sort"
"strconv"
"github.com/cockroachdb/errors"
"go.opentelemetry.io/otel"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
const (
HybridSearchTaskName = "HybridSearchTask"
)
type hybridSearchTask struct {
baseTask
Condition
ctx context.Context
*internalpb.HybridSearchRequest
result *milvuspb.SearchResults
request *milvuspb.HybridSearchRequest
searchTasks []*searchTask
tr *timerecord.TimeRecorder
schema *schemaInfo
requery bool
partitionKeyMode bool
userOutputFields []string
qc types.QueryCoordClient
node types.ProxyComponent
lb LBPolicy
resultBuf *typeutil.ConcurrentSet[*querypb.HybridSearchResult]
multipleRecallResults *typeutil.ConcurrentSet[*milvuspb.SearchResults]
partitionIDsSet *typeutil.ConcurrentSet[UniqueID]
reScorers []reScorer
queryChannelsTs map[string]Timestamp
rankParams *rankParams
}
func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-PreExecute")
defer sp.End()
if len(t.request.Requests) <= 0 {
return errors.New("minimum of ann search requests is 1")
}
if len(t.request.Requests) > defaultMaxSearchRequest {
return errors.New(fmt.Sprintf("maximum of ann search requests is %d", defaultMaxSearchRequest))
}
for _, req := range t.request.GetRequests() {
nq, err := getNq(req)
if err != nil {
log.Debug("failed to get nq", zap.Error(err))
return err
}
if nq != 1 {
err = merr.WrapErrParameterInvalid("1", fmt.Sprint(nq), "nq should be equal to 1")
log.Debug(err.Error())
return err
}
}
t.Base.MsgType = commonpb.MsgType_Search
t.Base.SourceID = paramtable.GetNodeID()
collectionName := t.request.CollectionName
collID, err := globalMetaCache.GetCollectionID(ctx, t.request.GetDbName(), collectionName)
if err != nil {
return err
}
t.CollectionID = collID
log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName))
t.schema, err = globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName)
if err != nil {
log.Warn("get collection schema failed", zap.Error(err))
return err
}
t.partitionKeyMode, err = isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
if err != nil {
log.Warn("is partition key mode failed", zap.Error(err))
return err
}
if t.partitionKeyMode {
if len(t.request.GetPartitionNames()) != 0 {
return errors.New("not support manually specifying the partition names if partition key mode is used")
}
t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]()
}
if !t.partitionKeyMode && len(t.request.GetPartitionNames()) > 0 {
// translate partition name to partition ids. Use regex-pattern to match partition name.
t.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, t.request.GetPartitionNames())
if err != nil {
log.Warn("failed to get partition ids", zap.Error(err))
return err
}
}
t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, false)
if err != nil {
log.Warn("translate output fields failed", zap.Error(err))
return err
}
log.Debug("translate output fields",
zap.Strings("output fields", t.request.GetOutputFields()))
if len(t.request.OutputFields) > 0 {
t.requery = true
}
collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName, t.CollectionID)
if err2 != nil {
log.Warn("Proxy::hybridSearchTask::PreExecute failed to GetCollectionInfo from cache",
zap.String("collectionName", collectionName), zap.Int64("collectionID", t.CollectionID), zap.Error(err2))
return err2
}
guaranteeTs := t.request.GetGuaranteeTimestamp()
var consistencyLevel commonpb.ConsistencyLevel
useDefaultConsistency := t.request.GetUseDefaultConsistency()
if useDefaultConsistency {
consistencyLevel = collectionInfo.consistencyLevel
guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel)
} else {
consistencyLevel = t.request.GetConsistencyLevel()
// Compatibility logic, parse guarantee timestamp
if consistencyLevel == 0 && guaranteeTs > 0 {
guaranteeTs = parseGuaranteeTs(guaranteeTs, t.BeginTs())
} else {
// parse from guarantee timestamp and user input consistency level
guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel)
}
}
t.reScorers, err = NewReScorer(t.request.GetRequests(), t.request.GetRankParams())
if err != nil {
log.Info("generate reScorer failed", zap.Any("rank params", t.request.GetRankParams()), zap.Error(err))
return err
}
t.HybridSearchRequest.GuaranteeTimestamp = guaranteeTs
t.searchTasks = make([]*searchTask, len(t.request.GetRequests()))
for index := range t.request.Requests {
searchReq := t.request.Requests[index]
if len(searchReq.GetCollectionName()) == 0 {
searchReq.CollectionName = t.request.GetCollectionName()
} else if searchReq.GetCollectionName() != t.request.GetCollectionName() {
return errors.New(fmt.Sprintf("inconsistent collection name in hybrid search request, "+
"expect %s, actual %s", searchReq.GetCollectionName(), t.request.GetCollectionName()))
}
searchReq.PartitionNames = t.request.GetPartitionNames()
searchReq.ConsistencyLevel = consistencyLevel
searchReq.GuaranteeTimestamp = guaranteeTs
searchReq.UseDefaultConsistency = useDefaultConsistency
searchReq.OutputFields = nil
t.searchTasks[index] = &searchTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
collectionName: collectionName,
SearchRequest: &internalpb.SearchRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Search),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
ReqID: paramtable.GetNodeID(),
DbID: 0, // todo
CollectionID: collID,
PartitionIDs: t.GetPartitionIDs(),
},
request: searchReq,
schema: t.schema,
tr: timerecord.NewTimeRecorder("hybrid search"),
qc: t.qc,
node: t.node,
lb: t.lb,
partitionKeyMode: t.partitionKeyMode,
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
}
err := initSearchRequest(ctx, t.searchTasks[index], true)
if err != nil {
log.Debug("init hybrid search request failed", zap.Error(err))
return err
}
if t.partitionKeyMode {
t.partitionIDsSet.Upsert(t.searchTasks[index].GetPartitionIDs()...)
}
}
log.Debug("hybrid search preExecute done.",
zap.Uint64("guarantee_ts", guaranteeTs),
zap.Bool("use_default_consistency", t.request.GetUseDefaultConsistency()),
zap.Any("consistency level", t.request.GetConsistencyLevel()))
return nil
}
func (t *hybridSearchTask) hybridSearchShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error {
hybridSearchReq := typeutil.Clone(t.HybridSearchRequest)
hybridSearchReq.GetBase().TargetID = nodeID
if t.partitionKeyMode {
t.PartitionIDs = t.partitionIDsSet.Collect()
}
req := &querypb.HybridSearchRequest{
Req: hybridSearchReq,
DmlChannels: []string{channel},
TotalChannelNum: int32(1),
}
log := log.Ctx(ctx).With(zap.Int64("collection", t.GetCollectionID()),
zap.Int64s("partitionIDs", t.GetPartitionIDs()),
zap.Int64("nodeID", nodeID),
zap.String("channel", channel))
var result *querypb.HybridSearchResult
var err error
result, err = qn.HybridSearch(ctx, req)
if err != nil {
log.Warn("QueryNode hybrid search return error", zap.Error(err))
globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.request.GetCollectionName())
return err
}
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Warn("QueryNode is not shardLeader")
globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.request.GetCollectionName())
return errInvalidShardLeaders
}
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("QueryNode hybrid search result error",
zap.String("reason", result.GetStatus().GetReason()))
return errors.Wrapf(merr.Error(result.GetStatus()), "fail to hybrid search on QueryNode %d", nodeID)
}
t.resultBuf.Insert(result)
t.lb.UpdateCostMetrics(nodeID, result.CostAggregation)
return nil
}
func (t *hybridSearchTask) Execute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-Execute")
defer sp.End()
log := log.Ctx(ctx).With(zap.Int64("collID", t.CollectionID), zap.String("collName", t.request.GetCollectionName()))
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute hybrid search %d", t.ID()))
defer tr.CtxElapse(ctx, "done")
for _, searchTask := range t.searchTasks {
t.HybridSearchRequest.Reqs = append(t.HybridSearchRequest.Reqs, searchTask.SearchRequest)
}
t.resultBuf = typeutil.NewConcurrentSet[*querypb.HybridSearchResult]()
err := t.lb.Execute(ctx, CollectionWorkLoad{
db: t.request.GetDbName(),
collectionID: t.CollectionID,
collectionName: t.request.GetCollectionName(),
nq: 1,
exec: t.hybridSearchShard,
})
if err != nil {
log.Warn("hybrid search execute failed", zap.Error(err))
return errors.Wrap(err, "failed to hybrid search")
}
log.Debug("hybrid search execute done.")
return nil
}
type rankParams struct {
limit int64
offset int64
roundDecimal int64
}
// parseRankParams get limit and offset from rankParams, both are optional.
func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, error) {
var (
limit int64
offset int64
roundDecimal int64
err error
)
limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, rankParamsPair)
if err != nil {
return nil, errors.New(LimitKey + " not found in rank_params")
}
limit, err = strconv.ParseInt(limitStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("%s [%s] is invalid", LimitKey, limitStr)
}
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, rankParamsPair)
if err == nil {
offset, err = strconv.ParseInt(offsetStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
}
}
// validate max result window.
if err = validateMaxQueryResultWindow(offset, limit); err != nil {
return nil, fmt.Errorf("invalid max query result window, %w", err)
}
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, rankParamsPair)
if err != nil {
roundDecimalStr = "-1"
}
roundDecimal, err = strconv.ParseInt(roundDecimalStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}
return &rankParams{
limit: limit,
offset: offset,
roundDecimal: roundDecimal,
}, nil
}
func (t *hybridSearchTask) collectHybridSearchResults(ctx context.Context) error {
select {
case <-t.TraceCtx().Done():
log.Ctx(ctx).Warn("hybrid search task wait to finish timeout!")
return fmt.Errorf("hybrid search task wait to finish timeout, msgID=%d", t.ID())
default:
log.Ctx(ctx).Debug("all hybrid searches are finished or canceled")
t.resultBuf.Range(func(res *querypb.HybridSearchResult) bool {
for index, searchResult := range res.GetResults() {
t.searchTasks[index].resultBuf.Insert(searchResult)
}
log.Ctx(ctx).Debug("proxy receives one hybrid search result",
zap.Int64("sourceID", res.GetBase().GetSourceID()))
return true
})
t.multipleRecallResults = typeutil.NewConcurrentSet[*milvuspb.SearchResults]()
for i, searchTask := range t.searchTasks {
err := searchTask.PostExecute(ctx)
if err != nil {
return err
}
t.reScorers[i].reScore(searchTask.result)
t.multipleRecallResults.Insert(searchTask.result)
}
return nil
}
}
func (t *hybridSearchTask) PostExecute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-PostExecute")
defer sp.End()
log := log.Ctx(ctx).With(zap.Int64("collID", t.CollectionID), zap.String("collName", t.request.GetCollectionName()))
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy postExecute hybrid search %d", t.ID()))
defer func() {
tr.CtxElapse(ctx, "done")
}()
err := t.collectHybridSearchResults(ctx)
if err != nil {
log.Warn("failed to collect hybrid search results", zap.Error(err))
return err
}
metricType := ""
t.queryChannelsTs = make(map[string]uint64)
for _, r := range t.resultBuf.Collect() {
metricType = r.GetResults()[0].GetMetricType()
for ch, ts := range r.GetChannelsMvcc() {
t.queryChannelsTs[ch] = ts
}
}
primaryFieldSchema, err := t.schema.GetPkField()
if err != nil {
log.Warn("failed to get primary field schema", zap.Error(err))
return err
}
t.rankParams, err = parseRankParams(t.request.GetRankParams())
if err != nil {
return err
}
t.result, err = rankSearchResultData(ctx, 1,
t.rankParams,
primaryFieldSchema.GetDataType(),
metricType,
t.multipleRecallResults.Collect())
if err != nil {
log.Warn("rank search result failed", zap.Error(err))
return err
}
t.result.CollectionName = t.request.GetCollectionName()
t.fillInFieldInfo()
if t.requery {
err := t.Requery()
if err != nil {
log.Warn("failed to requery", zap.Error(err))
return err
}
}
t.result.Results.OutputFields = t.userOutputFields
log.Debug("hybrid search post execute done")
return nil
}
func (t *hybridSearchTask) Requery() error {
queryReq := &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
},
DbName: t.request.GetDbName(),
CollectionName: t.request.GetCollectionName(),
Expr: "",
OutputFields: t.request.GetOutputFields(),
PartitionNames: t.request.GetPartitionNames(),
GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(),
TravelTimestamp: t.request.GetTravelTimestamp(),
NotReturnAllMeta: t.request.GetNotReturnAllMeta(),
ConsistencyLevel: t.request.GetConsistencyLevel(),
UseDefaultConsistency: t.request.GetUseDefaultConsistency(),
QueryParams: []*commonpb.KeyValuePair{
{
Key: LimitKey,
Value: strconv.FormatInt(t.rankParams.limit, 10),
},
},
}
return doRequery(t.ctx, t.CollectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, t.GetPartitionIDs())
}
func rankSearchResultData(ctx context.Context,
nq int64,
params *rankParams,
pkType schemapb.DataType,
metricType string,
searchResults []*milvuspb.SearchResults,
) (*milvuspb.SearchResults, error) {
tr := timerecord.NewTimeRecorder("rankSearchResultData")
defer func() {
tr.CtxElapse(ctx, "done")
}()
offset := params.offset
limit := params.limit
topk := limit + offset
roundDecimal := params.roundDecimal
log.Ctx(ctx).Debug("rankSearchResultData",
zap.Int("len(searchResults)", len(searchResults)),
zap.Int64("nq", nq),
zap.Int64("offset", offset),
zap.Int64("limit", limit),
zap.String("metric type", metricType))
ret := &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: nq,
TopK: limit,
FieldsData: make([]*schemapb.FieldData, 0),
Scores: []float32{},
Ids: &schemapb.IDs{},
Topks: []int64{},
},
}
switch pkType {
case schemapb.DataType_Int64:
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: make([]int64, 0),
},
}
case schemapb.DataType_VarChar:
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: make([]string, 0),
},
}
default:
return nil, errors.New("unsupported pk type")
}
// []map[id]score
accumulatedScores := make([]map[interface{}]float32, nq)
for i := int64(0); i < nq; i++ {
accumulatedScores[i] = make(map[interface{}]float32)
}
for _, result := range searchResults {
scores := result.GetResults().GetScores()
start := int64(0)
for i := int64(0); i < nq; i++ {
realTopk := result.GetResults().Topks[i]
for j := start; j < start+realTopk; j++ {
id := typeutil.GetPK(result.GetResults().GetIds(), j)
accumulatedScores[i][id] += scores[j]
}
start += realTopk
}
}
for i := int64(0); i < nq; i++ {
idSet := accumulatedScores[i]
keys := make([]interface{}, 0)
for key := range idSet {
keys = append(keys, key)
}
if int64(len(keys)) <= offset {
ret.Results.Topks = append(ret.Results.Topks, 0)
continue
}
compareKeys := func(keyI, keyJ interface{}) bool {
switch keyI.(type) {
case int64:
return keyI.(int64) < keyJ.(int64)
case string:
return keyI.(string) < keyJ.(string)
}
return false
}
// sort id by score
var less func(i, j int) bool
if metric.PositivelyRelated(metricType) {
less = func(i, j int) bool {
if idSet[keys[i]] == idSet[keys[j]] {
return compareKeys(keys[i], keys[j])
}
return idSet[keys[i]] > idSet[keys[j]]
}
} else {
less = func(i, j int) bool {
if idSet[keys[i]] == idSet[keys[j]] {
return compareKeys(keys[i], keys[j])
}
return idSet[keys[i]] < idSet[keys[j]]
}
}
sort.Slice(keys, less)
if int64(len(keys)) > topk {
keys = keys[:topk]
}
// set real topk
ret.Results.Topks = append(ret.Results.Topks, int64(len(keys))-offset)
// append id and score
for index := offset; index < int64(len(keys)); index++ {
typeutil.AppendPKs(ret.Results.Ids, keys[index])
score := idSet[keys[index]]
if roundDecimal != -1 {
multiplier := math.Pow(10.0, float64(roundDecimal))
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
}
ret.Results.Scores = append(ret.Results.Scores, score)
}
}
return ret, nil
}
func (t *hybridSearchTask) fillInFieldInfo() {
if len(t.request.OutputFields) != 0 && len(t.result.Results.FieldsData) != 0 {
for i, name := range t.request.OutputFields {
for _, field := range t.schema.Fields {
if t.result.Results.FieldsData[i] != nil && field.Name == name {
t.result.Results.FieldsData[i].FieldName = field.Name
t.result.Results.FieldsData[i].FieldId = field.FieldID
t.result.Results.FieldsData[i].Type = field.DataType
t.result.Results.FieldsData[i].IsDynamic = field.IsDynamic
}
}
}
}
}
func (t *hybridSearchTask) TraceCtx() context.Context {
return t.ctx
}
func (t *hybridSearchTask) ID() UniqueID {
return t.Base.MsgID
}
func (t *hybridSearchTask) SetID(uid UniqueID) {
t.Base.MsgID = uid
}
func (t *hybridSearchTask) Name() string {
return HybridSearchTaskName
}
func (t *hybridSearchTask) Type() commonpb.MsgType {
return t.Base.MsgType
}
func (t *hybridSearchTask) BeginTs() Timestamp {
return t.Base.Timestamp
}
func (t *hybridSearchTask) EndTs() Timestamp {
return t.Base.Timestamp
}
func (t *hybridSearchTask) SetTs(ts Timestamp) {
t.Base.Timestamp = ts
}
func (t *hybridSearchTask) OnEnqueue() error {
t.Base = commonpbutil.NewMsgBase()
t.Base.MsgType = commonpb.MsgType_Search
t.Base.SourceID = paramtable.GetNodeID()
return nil
}

View File

@@ -1,363 +0,0 @@
package proxy
import (
"context"
"strconv"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func createCollWithMultiVecField(t *testing.T, name string, rc types.RootCoordClient) {
schema := genCollectionSchema(name)
marshaledSchema, err := proto.Marshal(schema)
require.NoError(t, err)
ctx := context.TODO()
createColT := &createCollectionTask{
Condition: NewTaskCondition(context.TODO()),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
CollectionName: name,
Schema: marshaledSchema,
ShardsNum: common.DefaultShardsNum,
},
ctx: ctx,
rootCoord: rc,
}
require.NoError(t, createColT.OnEnqueue())
require.NoError(t, createColT.PreExecute(ctx))
require.NoError(t, createColT.Execute(ctx))
require.NoError(t, createColT.PostExecute(ctx))
}
func TestHybridSearchTask_PreExecute(t *testing.T) {
var err error
var (
rc = NewRootCoordMock()
qc = mocks.NewMockQueryCoordClient(t)
ctx = context.TODO()
)
defer rc.Close()
require.NoError(t, err)
mgr := newShardClientMgr()
err = InitMetaCache(ctx, rc, qc, mgr)
require.NoError(t, err)
genHybridSearchTaskWithNq := func(t *testing.T, collName string, reqs []*milvuspb.SearchRequest) *hybridSearchTask {
task := &hybridSearchTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
HybridSearchRequest: &internalpb.HybridSearchRequest{},
request: &milvuspb.HybridSearchRequest{
CollectionName: collName,
Requests: reqs,
},
qc: qc,
tr: timerecord.NewTimeRecorder("test-hybrid-search"),
}
require.NoError(t, task.OnEnqueue())
return task
}
t.Run("bad nq 0", func(t *testing.T) {
collName := "test_bad_nq0_error" + funcutil.GenRandomStr()
createCollWithMultiVecField(t, collName, rc)
// Nq must be 1.
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 0}})
err = task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("bad req num 0", func(t *testing.T) {
collName := "test_bad_req_num0_error" + funcutil.GenRandomStr()
createCollWithMultiVecField(t, collName, rc)
// num of reqs must be [1, 1024].
task := genHybridSearchTaskWithNq(t, collName, nil)
err = task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("bad req num 1025", func(t *testing.T) {
collName := "test_bad_req_num1025_error" + funcutil.GenRandomStr()
createCollWithMultiVecField(t, collName, rc)
// num of reqs must be [1, 1024].
reqs := make([]*milvuspb.SearchRequest, 0)
for i := 0; i <= defaultMaxSearchRequest; i++ {
reqs = append(reqs, &milvuspb.SearchRequest{
CollectionName: collName,
Nq: 1,
})
}
task := genHybridSearchTaskWithNq(t, collName, reqs)
err = task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("collection not exist", func(t *testing.T) {
collName := "test_collection_not_exist" + funcutil.GenRandomStr()
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 1}})
err = task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("hybrid search with timeout", func(t *testing.T) {
collName := "hybrid_search_with_timeout" + funcutil.GenRandomStr()
createCollWithMultiVecField(t, collName, rc)
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 1}})
ctxTimeout, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
task.ctx = ctxTimeout
task.request.OutputFields = []string{testFloatVecField}
assert.NoError(t, task.PreExecute(ctx))
})
t.Run("hybrid search with group_by", func(t *testing.T) {
collName := "hybrid_search_with_group_by" + funcutil.GenRandomStr()
createCollWithMultiVecField(t, collName, rc)
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{
{Nq: 1, DslType: commonpb.DslType_BoolExprV1, SearchParams: []*commonpb.KeyValuePair{
{Key: AnnsFieldKey, Value: "fvec"},
{Key: TopKKey, Value: "10"},
{Key: GroupByFieldKey, Value: "bool"},
}},
})
ctxTimeout, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
task.ctx = ctxTimeout
task.request.OutputFields = []string{testFloatVecField}
err := task.PreExecute(ctx)
assert.Error(t, err)
assert.Equal(t, "not support search_group_by operation in the hybrid search", err.Error())
})
}
func TestHybridSearchTask_ErrExecute(t *testing.T) {
var (
err error
ctx = context.TODO()
rc = NewRootCoordMock()
qc = getQueryCoordClient()
qn = getQueryNodeClient()
collectionName = t.Name() + funcutil.GenRandomStr()
)
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
mgr := NewMockShardClientManager(t)
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
lb := NewLBPolicyImpl(mgr)
factory := dependency.NewDefaultFactory(true)
node, err := NewProxy(ctx, factory)
assert.NoError(t, err)
node.UpdateStateCode(commonpb.StateCode_Healthy)
node.tsoAllocator = &timestampAllocator{
tso: newMockTimestampAllocatorInterface(),
}
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator, factory)
assert.NoError(t, err)
node.sched = scheduler
err = node.sched.Start()
assert.NoError(t, err)
err = node.initRateCollector()
assert.NoError(t, err)
node.rootCoord = rc
node.queryCoord = qc
defer qc.Close()
err = InitMetaCache(ctx, rc, qc, mgr)
assert.NoError(t, err)
createCollWithMultiVecField(t, collectionName, rc)
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
assert.NoError(t, err)
schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
assert.NoError(t, err)
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil)
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
Status: successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1},
NodeAddrs: []string{"localhost:9000"},
},
},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: successStatus,
CollectionIDs: []int64{collectionID},
InMemoryPercentages: []int64{100},
}, nil)
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
SourceID: paramtable.GetNodeID(),
},
CollectionID: collectionID,
})
require.NoError(t, err)
require.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
vectorFields := typeutil.GetVectorFieldSchemas(schema.CollectionSchema)
vectorFieldNames := make([]string, len(vectorFields))
for i, field := range vectorFields {
vectorFieldNames[i] = field.GetName()
}
// test begins
task := &hybridSearchTask{
Condition: NewTaskCondition(ctx),
ctx: ctx,
result: &milvuspb.SearchResults{
Status: merr.Success(),
},
HybridSearchRequest: &internalpb.HybridSearchRequest{},
request: &milvuspb.HybridSearchRequest{
CollectionName: collectionName,
Requests: []*milvuspb.SearchRequest{
{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
CollectionName: collectionName,
Nq: 1,
DslType: commonpb.DslType_BoolExprV1,
SearchParams: []*commonpb.KeyValuePair{
{Key: AnnsFieldKey, Value: testFloatVecField},
{Key: TopKKey, Value: "10"},
},
},
{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
CollectionName: collectionName,
Nq: 1,
DslType: commonpb.DslType_BoolExprV1,
SearchParams: []*commonpb.KeyValuePair{
{Key: AnnsFieldKey, Value: testBinaryVecField},
{Key: TopKKey, Value: "10"},
},
},
},
OutputFields: vectorFieldNames,
},
qc: qc,
lb: lb,
node: node,
}
assert.NoError(t, task.OnEnqueue())
task.ctx = ctx
assert.NoError(t, task.PreExecute(ctx))
qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
assert.Error(t, task.Execute(ctx))
qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&querypb.HybridSearchResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
},
}, nil)
assert.Error(t, task.Execute(ctx))
}
func TestHybridSearchTask_PostExecute(t *testing.T) {
var (
rc = NewRootCoordMock()
qc = getQueryCoordClient()
qn = getQueryNodeClient()
collectionName = t.Name() + funcutil.GenRandomStr()
)
defer rc.Close()
mgr := NewMockShardClientManager(t)
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&querypb.HybridSearchResult{
Base: commonpbutil.NewMsgBase(),
Status: merr.Success(),
}, nil)
t.Run("Test empty result", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := InitMetaCache(ctx, rc, qc, mgr)
assert.NoError(t, err)
createCollWithMultiVecField(t, collectionName, rc)
schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
assert.NoError(t, err)
rankParams := []*commonpb.KeyValuePair{
{Key: LimitKey, Value: strconv.Itoa(3)},
{Key: OffsetKey, Value: strconv.Itoa(2)},
}
qt := &hybridSearchTask{
ctx: ctx,
Condition: NewTaskCondition(context.TODO()),
qc: nil,
tr: timerecord.NewTimeRecorder("search"),
schema: schema,
HybridSearchRequest: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
},
request: &milvuspb.HybridSearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
},
CollectionName: collectionName,
RankParams: rankParams,
},
resultBuf: typeutil.NewConcurrentSet[*querypb.HybridSearchResult](),
multipleRecallResults: typeutil.NewConcurrentSet[*milvuspb.SearchResults](),
}
err = qt.PostExecute(context.TODO())
assert.NoError(t, err)
assert.Equal(t, qt.result.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
})
}

View File

@ -5,9 +5,7 @@ import (
"encoding/json"
"fmt"
"math"
"regexp"
"strconv"
"strings"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
@ -23,7 +21,6 @@ import (
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
@ -32,6 +29,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -50,8 +48,8 @@ const (
type searchTask struct {
Condition
*internalpb.SearchRequest
ctx context.Context
*internalpb.SearchRequest
result *milvuspb.SearchResults
request *milvuspb.SearchRequest
@ -64,196 +62,18 @@ type searchTask struct {
userOutputFields []string
offset int64
resultBuf *typeutil.ConcurrentSet[*internalpb.SearchResults]
partitionIDsSet *typeutil.ConcurrentSet[UniqueID]
qc types.QueryCoordClient
node types.ProxyComponent
lb LBPolicy
queryChannelsTs map[string]Timestamp
queryInfo *planpb.QueryInfo
}
queryInfos []*planpb.QueryInfo
func getPartitionIDs(ctx context.Context, dbName string, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
for _, tag := range partitionNames {
if err := validatePartitionTag(tag, false); err != nil {
return nil, err
}
}
partitionsMap, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName)
if err != nil {
return nil, err
}
useRegexp := Params.ProxyCfg.PartitionNameRegexp.GetAsBool()
partitionsSet := typeutil.NewSet[int64]()
for _, partitionName := range partitionNames {
if useRegexp {
// Legacy feature, use partition name as regexp
pattern := fmt.Sprintf("^%s$", partitionName)
re, err := regexp.Compile(pattern)
if err != nil {
return nil, fmt.Errorf("invalid partition: %s", partitionName)
}
var found bool
for name, pID := range partitionsMap {
if re.MatchString(name) {
partitionsSet.Insert(pID)
found = true
}
}
if !found {
return nil, fmt.Errorf("partition name %s not found", partitionName)
}
} else {
partitionID, found := partitionsMap[partitionName]
if !found {
// TODO change after testcase updated: return nil, merr.WrapErrPartitionNotFound(partitionName)
return nil, fmt.Errorf("partition name %s not found", partitionName)
}
if !partitionsSet.Contain(partitionID) {
partitionsSet.Insert(partitionID)
}
}
}
return partitionsSet.Collect(), nil
}
// parseSearchInfo returns QueryInfo and offset
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) (*planpb.QueryInfo, int64, error) {
// 1. parse offset and real topk
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
if err != nil {
return nil, 0, errors.New(TopKKey + " not found in search_params")
}
topK, err := strconv.ParseInt(topKStr, 0, 64)
if err != nil {
return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
}
if err := validateTopKLimit(topK); err != nil {
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
}
var offset int64
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, searchParamsPair)
if err == nil {
offset, err = strconv.ParseInt(offsetStr, 0, 64)
if err != nil {
return nil, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
}
if offset != 0 {
if err := validateTopKLimit(offset); err != nil {
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)
}
}
}
queryTopK := topK + offset
if err := validateTopKLimit(queryTopK); err != nil {
return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)
}
// 2. parse metrics type
metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.MetricTypeKey, searchParamsPair)
if err != nil {
metricType = ""
}
// 3. parse round decimal
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, searchParamsPair)
if err != nil {
roundDecimalStr = "-1"
}
roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
if err != nil {
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}
// 4. parse search param str
searchParamStr, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, searchParamsPair)
if err != nil {
searchParamStr = ""
}
err = checkRangeSearchParams(searchParamStr, metricType)
if err != nil {
return nil, 0, err
}
// 5. parse group by field
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
if err != nil {
groupByFieldName = ""
}
var groupByFieldId int64 = -1
if groupByFieldName != "" {
fields := schema.GetFields()
for _, field := range fields {
if field.Name == groupByFieldName {
groupByFieldId = field.FieldID
break
}
}
if groupByFieldId == -1 {
return nil, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
}
}
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
if isIterator == "True" && groupByFieldId > 0 {
return nil, 0, merr.WrapErrParameterInvalid("", "",
"Not allowed to do groupBy when doing iteration")
}
if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 {
return nil, 0, merr.WrapErrParameterInvalid("", "",
"Not allowed to do range-search when doing search-group-by")
}
return &planpb.QueryInfo{
Topk: queryTopK,
MetricType: metricType,
SearchParams: searchParamStr,
RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId,
}, offset, nil
}
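// getOutputFieldIDs resolves the requested output field names to field IDs via the
// collection schema; an unknown field name is an error.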
func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) {
outputFieldIDs = make([]UniqueID, 0, len(outputFields))
for _, name := range outputFields {
id, ok := schema.MapFieldID(name)
if !ok {
return nil, fmt.Errorf("Field %s not exist", name)
}
outputFieldIDs = append(outputFieldIDs, id)
}
return outputFieldIDs, nil
}
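// getNq returns the nq declared in the request; when an older client leaves it at 0,
// nq is derived by unmarshalling the placeholder group and counting its values.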
func getNq(req *milvuspb.SearchRequest) (int64, error) {
if req.GetNq() == 0 {
// keep compatible with older client version.
x := &commonpb.PlaceholderGroup{}
err := proto.Unmarshal(req.GetPlaceholderGroup(), x)
if err != nil {
return 0, err
}
total := int64(0)
for _, h := range x.GetPlaceholders() {
total += int64(len(h.Values))
}
return total, nil
}
return req.GetNq(), nil
reScorers []reScorer
rankParams *rankParams
}
func (t *searchTask) CanSkipAllocTimestamp() bool {
@ -284,7 +104,7 @@ func (t *searchTask) CanSkipAllocTimestamp() bool {
func (t *searchTask) PreExecute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-PreExecute")
defer sp.End()
t.SearchRequest.IsAdvanced = len(t.request.GetSubReqs()) > 0
t.Base.MsgType = commonpb.MsgType_Search
t.Base.SourceID = paramtable.GetNodeID()
@ -315,7 +135,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
if !t.partitionKeyMode && len(t.request.GetPartitionNames()) > 0 {
// translate partition name to partition ids. Use regex-pattern to match partition name.
t.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, t.request.GetPartitionNames())
t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, t.request.GetPartitionNames())
if err != nil {
log.Warn("failed to get partition ids", zap.Error(err))
return err
@ -330,7 +150,59 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
log.Debug("translate output fields",
zap.Strings("output fields", t.request.GetOutputFields()))
err = initSearchRequest(ctx, t, false)
if t.SearchRequest.GetIsAdvanced() {
if len(t.request.GetSubReqs()) > defaultMaxSearchRequest {
return errors.New(fmt.Sprintf("maximum of ann search requests is %d", defaultMaxSearchRequest))
}
}
if t.SearchRequest.GetIsAdvanced() {
t.rankParams, err = parseRankParams(t.request.GetSearchParams())
if err != nil {
return err
}
}
// Validate nq, deriving it when the client did not set it.
nq, err := t.checkNq(ctx)
if err != nil {
log.Info("failed to check nq", zap.Error(err))
return err
}
t.SearchRequest.Nq = nq
var ignoreGrowing bool
// parse common search params
for i, kv := range t.request.GetSearchParams() {
if kv.GetKey() == IgnoreGrowingKey {
ignoreGrowing, err = strconv.ParseBool(kv.GetValue())
if err != nil {
return errors.New("parse search growing failed")
}
t.request.SearchParams = append(t.request.GetSearchParams()[:i], t.request.GetSearchParams()[i+1:]...)
break
}
}
t.SearchRequest.IgnoreGrowing = ignoreGrowing
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
if err != nil {
log.Info("fail to get output field ids", zap.Error(err))
return err
}
t.SearchRequest.OutputFieldsId = outputFieldIDs
// Currently, we get vectors by requery. Once we support returning vectors directly from search,
// searches with a small result size will no longer need a requery.
vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
return lo.Contains(t.request.GetOutputFields(), field.GetName()) && typeutil.IsVectorType(field.GetDataType())
})
if t.SearchRequest.GetIsAdvanced() {
t.requery = len(t.request.OutputFields) > 0
err = t.initAdvancedSearchRequest(ctx)
} else {
t.requery = len(vectorOutputFields) > 0
err = t.initSearchRequest(ctx)
}
if err != nil {
log.Debug("init search request failed", zap.Error(err))
return err
@ -359,6 +231,16 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
}
}
t.SearchRequest.GuaranteeTimestamp = guaranteeTs
t.SearchRequest.ConsistencyLevel = consistencyLevel
if deadline, ok := t.TraceCtx().Deadline(); ok {
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
}
// Set username of this search request for feature like task scheduling.
if username, _ := GetCurUserFromContext(ctx); username != "" {
t.SearchRequest.Username = username
}
log.Debug("search PreExecute done.",
zap.Uint64("guarantee_ts", guaranteeTs),
@ -368,16 +250,229 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
return nil
}
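// checkNq determines and validates nq. For advanced (hybrid) requests every sub request
// must report the same nq; otherwise nq is taken from the request itself (or derived from
// the placeholder group). The result is checked against the nq limit.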
func (t *searchTask) checkNq(ctx context.Context) (int64, error) {
var nq int64
if t.SearchRequest.GetIsAdvanced() {
// In the context of Advanced Search, it is essential to verify that the number of vectors
// for each individual search, denoted as nq, remains consistent.
nq = t.request.GetNq()
for _, req := range t.request.GetSubReqs() {
subNq, err := getNqFromSubSearch(req)
if err != nil {
return 0, err
}
req.Nq = subNq
if nq == 0 {
nq = subNq
continue
}
if subNq != nq {
err = merr.WrapErrParameterInvalid(nq, subNq, "sub search request nq should be the same")
return 0, err
}
}
t.request.Nq = nq
} else {
var err error
nq, err = getNq(t.request)
if err != nil {
return 0, err
}
t.request.Nq = nq
}
// Check if nq is valid:
// https://milvus.io/docs/limitations.md
if err := validateNQLimit(nq); err != nil {
return 0, fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
}
return nq, nil
}
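// initAdvancedSearchRequest builds one internal sub search request per AnnSearchRequest:
// it generates a plan for each sub request (group_by is rejected for hybrid search),
// resolves partition IDs (from the plan in partition-key mode, otherwise from the
// top-level request), marshals the plan, and finally creates the re-scorers that
// PostExecute applies to each sub request's results.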
func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init advanced search request")
defer sp.End()
t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]()
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
// fetch search_growing from search param
t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
for index, subReq := range t.request.GetSubReqs() {
plan, queryInfo, offset, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl(), true)
if err != nil {
return err
}
if queryInfo.GetGroupByFieldId() != -1 {
return errors.New("not support search_group_by operation in the hybrid search")
}
internalSubReq := &internalpb.SubSearchRequest{
Dsl: subReq.GetDsl(),
PlaceholderGroup: subReq.GetPlaceholderGroup(),
DslType: subReq.GetDslType(),
SerializedExprPlan: nil,
Nq: subReq.GetNq(),
PartitionIDs: nil,
Topk: queryInfo.GetTopk(),
Offset: offset,
}
// set PartitionIDs for sub search
if t.partitionKeyMode {
partitionIDs, err2 := t.tryParsePartitionIDsFromPlan(plan)
if err2 != nil {
return err2
}
if len(partitionIDs) > 0 {
internalSubReq.PartitionIDs = partitionIDs
t.partitionIDsSet.Upsert(partitionIDs...)
}
} else {
internalSubReq.PartitionIDs = t.SearchRequest.GetPartitionIDs()
}
if t.requery {
plan.OutputFieldIds = nil
} else {
plan.OutputFieldIds = t.SearchRequest.OutputFieldsId
}
internalSubReq.SerializedExprPlan, err = proto.Marshal(plan)
if err != nil {
return err
}
t.SearchRequest.SubReqs[index] = internalSubReq
t.queryInfos[index] = queryInfo
log.Debug("proxy init search request",
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
zap.Stringer("plan", plan)) // may be very large if large term passed.
}
var err error
t.reScorers, err = NewReScorers(len(t.request.GetSubReqs()), t.request.GetSearchParams())
if err != nil {
log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err))
return err
}
return nil
}
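// initSearchRequest handles the single-vector path: it generates one plan from the
// request-level DSL and search params, fills offset/topk/metric type on the internal
// request, and records the parsed QueryInfo.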
func (t *searchTask) initSearchRequest(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init search request")
defer sp.End()
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
// fetch search_growing from search param
plan, queryInfo, offset, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl(), false)
if err != nil {
return err
}
t.SearchRequest.Offset = offset
if t.partitionKeyMode {
partitionIDs, err2 := t.tryParsePartitionIDsFromPlan(plan)
if err2 != nil {
return err2
}
if len(partitionIDs) > 0 {
t.SearchRequest.PartitionIDs = partitionIDs
}
}
if t.requery {
plan.OutputFieldIds = nil
} else {
plan.OutputFieldIds = t.SearchRequest.OutputFieldsId
}
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
if err != nil {
return err
}
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
t.SearchRequest.Topk = queryInfo.GetTopk()
t.SearchRequest.MetricType = queryInfo.GetMetricType()
t.queryInfos = append(t.queryInfos, queryInfo)
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
log.Debug("proxy init search request",
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
zap.Stringer("plan", plan)) // may be very large if large term passed.
return nil
}
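// tryGeneratePlan resolves the anns field (falling back to the only vector field when
// none is specified), parses the search info, rejects group-by on binary vector fields,
// and creates the search plan from the DSL.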
func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string, ignoreOffset bool) (*planpb.PlanNode, *planpb.QueryInfo, int64, error) {
annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, params)
if err != nil || len(annsFieldName) == 0 {
vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
if len(vecFields) == 0 {
return nil, nil, 0, errors.New(AnnsFieldKey + " not found in schema")
}
if enableMultipleVectorFields && len(vecFields) > 1 {
return nil, nil, 0, errors.New("multiple anns_fields exist, please specify a anns_field in search_params")
}
annsFieldName = vecFields[0].Name
}
queryInfo, offset, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, ignoreOffset)
if parseErr != nil {
return nil, nil, 0, parseErr
}
annField := typeutil.GetFieldByName(t.schema.CollectionSchema, annsFieldName)
if queryInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
return nil, nil, 0, errors.New("not support search_group_by operation based on binary vector column")
}
plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, queryInfo)
if planErr != nil {
log.Warn("failed to create query plan", zap.Error(planErr),
zap.String("dsl", dsl), // may be very large if large term passed.
zap.String("anns field", annsFieldName), zap.Any("query info", queryInfo))
return nil, nil, 0, merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", planErr)
}
log.Debug("create query plan",
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
zap.String("anns field", annsFieldName), zap.Any("query info", queryInfo))
return plan, queryInfo, offset, nil
}
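// tryParsePartitionIDsFromPlan extracts partition-key values from the plan expression,
// hashes them to partition names, and translates those names to partition IDs; it
// returns nil when the expression carries no partition keys.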
func (t *searchTask) tryParsePartitionIDsFromPlan(plan *planpb.PlanNode) ([]int64, error) {
expr, err := ParseExprFromPlan(plan)
if err != nil {
log.Warn("failed to parse expr", zap.Error(err))
return nil, err
}
partitionKeys := ParsePartitionKeys(expr)
hashedPartitionNames, err := assignPartitionKeys(t.ctx, t.request.GetDbName(), t.collectionName, partitionKeys)
if err != nil {
log.Warn("failed to assign partition keys", zap.Error(err))
return nil, err
}
if len(hashedPartitionNames) > 0 {
// translate partition name to partition ids. Use regex-pattern to match partition name.
PartitionIDs, err2 := getPartitionIDs(t.ctx, t.request.GetDbName(), t.collectionName, hashedPartitionNames)
if err2 != nil {
log.Warn("failed to get partition ids", zap.Error(err2))
return nil, err2
}
return PartitionIDs, nil
}
return nil, nil
}
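// Execute dispatches the search to shard leaders through the LB policy and collects the
// per-shard results into resultBuf.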
func (t *searchTask) Execute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-Execute")
defer sp.End()
log := log.Ctx(ctx).With(zap.Int64("nq", t.SearchRequest.GetNq()))
t.resultBuf = typeutil.NewConcurrentSet[*internalpb.SearchResults]()
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute search %d", t.ID()))
defer tr.CtxElapse(ctx, "done")
t.resultBuf = typeutil.NewConcurrentSet[*internalpb.SearchResults]()
err := t.lb.Execute(ctx, CollectionWorkLoad{
db: t.request.GetDbName(),
collectionID: t.SearchRequest.CollectionID,
@ -396,6 +491,43 @@ func (t *searchTask) Execute(ctx context.Context) error {
return nil
}
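// reduceResults decodes the collected shard results and reduces them into a single
// milvuspb.SearchResults, returning an empty result when nothing valid came back.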
func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, queryInfo *planpb.QueryInfo) (*milvuspb.SearchResults, error) {
metricType := ""
if len(toReduceResults) >= 1 {
metricType = toReduceResults[0].GetMetricType()
}
// Decode all search results
validSearchResults, err := decodeSearchResults(ctx, toReduceResults)
if err != nil {
log.Warn("failed to decode search results", zap.Error(err))
return nil, err
}
if len(validSearchResults) <= 0 {
return fillInEmptyResult(nq), nil
}
// Reduce all search results
log.Debug("proxy search post execute reduce",
zap.Int64("collection", t.GetCollectionID()),
zap.Int64s("partitionIDs", t.GetPartitionIDs()),
zap.Int("number of valid search results", len(validSearchResults)))
primaryFieldSchema, err := t.schema.GetPkField()
if err != nil {
log.Warn("failed to get primary field schema", zap.Error(err))
return nil, err
}
var result *milvuspb.SearchResults
result, err = reduceSearchResult(ctx, NewReduceSearchResultInfo(validSearchResults, nq, topK,
metricType, primaryFieldSchema.DataType, offset, queryInfo))
if err != nil {
log.Warn("failed to reduce search results", zap.Error(err))
return nil, err
}
return result, nil
}
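// PostExecute merges the shard-level results: advanced requests regroup sub results by
// request index, reduce and re-score each, and rank them into one final result; plain
// requests take a single reduce path.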
func (t *searchTask) PostExecute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-PostExecute")
defer sp.End()
@ -406,11 +538,6 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
}()
log := log.Ctx(ctx).With(zap.Int64("nq", t.SearchRequest.GetNq()))
var (
Nq = t.SearchRequest.GetNq()
Topk = t.SearchRequest.GetTopk()
MetricType = t.SearchRequest.GetMetricType()
)
toReduceResults, err := t.collectSearchResults(ctx)
if err != nil {
log.Warn("failed to collect search results", zap.Error(err))
@ -424,45 +551,65 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
}
}
if len(toReduceResults) >= 1 {
MetricType = toReduceResults[0].GetMetricType()
}
// Decode all search results
tr.CtxRecord(ctx, "decodeResultStart")
validSearchResults, err := decodeSearchResults(ctx, toReduceResults)
if err != nil {
log.Warn("failed to decode search results", zap.Error(err))
return err
}
metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
if len(validSearchResults) <= 0 {
t.fillInEmptyResult(Nq)
return nil
}
// Reduce all search results
log.Debug("proxy search post execute reduce",
zap.Int64("collection", t.GetCollectionID()),
zap.Int64s("partitionIDs", t.GetPartitionIDs()),
zap.Int("number of valid search results", len(validSearchResults)))
tr.CtxRecord(ctx, "reduceResultStart")
primaryFieldSchema, err := t.schema.GetPkField()
if err != nil {
log.Warn("failed to get primary field schema", zap.Error(err))
return err
}
t.result, err = reduceSearchResult(ctx, NewReduceSearchResultInfo(validSearchResults, Nq, Topk,
MetricType, primaryFieldSchema.DataType, t.offset, t.queryInfo))
if err != nil {
log.Warn("failed to reduce search results", zap.Error(err))
return err
}
if t.SearchRequest.GetIsAdvanced() {
multipleInternalResults := make([][]*internalpb.SearchResults, len(t.SearchRequest.GetSubReqs()))
for _, searchResult := range toReduceResults {
// skip any result that is not advanced
if !searchResult.GetIsAdvanced() {
continue
}
for _, subResult := range searchResult.GetSubResults() {
// shallow copy
internalResults := &internalpb.SearchResults{
MetricType: subResult.GetMetricType(),
NumQueries: subResult.GetNumQueries(),
TopK: subResult.GetTopK(),
SlicedBlob: subResult.GetSlicedBlob(),
SlicedNumCount: subResult.GetSlicedNumCount(),
SlicedOffset: subResult.GetSlicedOffset(),
IsAdvanced: false,
}
reqIndex := subResult.GetReqIndex()
multipleInternalResults[reqIndex] = append(multipleInternalResults[reqIndex], internalResults)
}
}
metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs()))
for index, internalResults := range multipleInternalResults {
subReq := t.SearchRequest.GetSubReqs()[index]
metricType := ""
if len(internalResults) >= 1 {
metricType = internalResults[0].GetMetricType()
}
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), t.queryInfos[index])
if err != nil {
return err
}
t.reScorers[index].setMetricType(metricType)
t.reScorers[index].reScore(result)
multipleMilvusResults[index] = result
}
t.result, err = rankSearchResultData(ctx, t.SearchRequest.GetNq(),
t.rankParams,
primaryFieldSchema.GetDataType(),
multipleMilvusResults)
if err != nil {
log.Warn("rank search result failed", zap.Error(err))
return err
}
} else {
t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), t.queryInfos[0])
if err != nil {
return err
}
}
t.result.CollectionName = t.collectionName
t.fillInFieldInfo()
@ -475,6 +622,9 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
}
}
t.result.Results.OutputFields = t.userOutputFields
t.result.CollectionName = t.request.GetCollectionName()
metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
log.Debug("Search post execute done",
zap.Int64("collection", t.GetCollectionID()),
@ -550,29 +700,30 @@ func (t *searchTask) Requery() error {
MsgType: commonpb.MsgType_Retrieve,
Timestamp: t.BeginTs(),
},
DbName: t.request.GetDbName(),
CollectionName: t.request.GetCollectionName(),
Expr: "",
OutputFields: t.request.GetOutputFields(),
PartitionNames: t.request.GetPartitionNames(),
GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(),
QueryParams: t.request.GetSearchParams(),
DbName: t.request.GetDbName(),
CollectionName: t.request.GetCollectionName(),
ConsistencyLevel: t.SearchRequest.GetConsistencyLevel(),
NotReturnAllMeta: t.request.GetNotReturnAllMeta(),
Expr: "",
OutputFields: t.request.GetOutputFields(),
PartitionNames: t.request.GetPartitionNames(),
UseDefaultConsistency: false,
GuaranteeTimestamp: t.SearchRequest.GuaranteeTimestamp,
}
if t.SearchRequest.GetIsAdvanced() {
queryReq.QueryParams = []*commonpb.KeyValuePair{
{
Key: LimitKey,
Value: strconv.FormatInt(t.rankParams.limit, 10),
},
}
} else {
queryReq.QueryParams = t.request.GetSearchParams()
}
return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, t.GetPartitionIDs())
}
func (t *searchTask) fillInEmptyResult(numQueries int64) {
t.result = &milvuspb.SearchResults{
Status: merr.Success("search result is empty"),
CollectionName: t.collectionName,
Results: &schemapb.SearchResultData{
NumQueries: numQueries,
Topks: make([]int64, numQueries),
},
}
}
func (t *searchTask) fillInFieldInfo() {
if len(t.request.OutputFields) != 0 && len(t.result.Results.FieldsData) != 0 {
for i, name := range t.request.OutputFields {

View File

@ -49,28 +49,46 @@ import (
)
func TestSearchTask_PostExecute(t *testing.T) {
t.Run("Test empty result", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var err error
qt := &searchTask{
ctx: ctx,
Condition: NewTaskCondition(context.TODO()),
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
var (
rc = NewRootCoordMock()
qc = mocks.NewMockQueryCoordClient(t)
ctx = context.TODO()
)
defer rc.Close()
require.NoError(t, err)
mgr := newShardClientMgr()
err = InitMetaCache(ctx, rc, qc, mgr)
require.NoError(t, err)
getSearchTask := func(t *testing.T, collName string) *searchTask {
task := &searchTask{
ctx: ctx,
collectionName: collName,
SearchRequest: &internalpb.SearchRequest{},
request: &milvuspb.SearchRequest{
CollectionName: collName,
Nq: 1,
SearchParams: getBaseSearchParams(),
},
request: nil,
qc: nil,
tr: timerecord.NewTimeRecorder("search"),
resultBuf: &typeutil.ConcurrentSet[*internalpb.SearchResults]{},
qc: qc,
tr: timerecord.NewTimeRecorder("test-search"),
}
// no result
qt.resultBuf.Insert(&internalpb.SearchResults{})
require.NoError(t, task.OnEnqueue())
return task
}
t.Run("Test empty result", func(t *testing.T) {
collName := "test_collection_empty_result" + funcutil.GenRandomStr()
createColl(t, collName, rc)
qt := getSearchTask(t, collName)
err = qt.PreExecute(qt.ctx)
assert.NoError(t, err)
qt.resultBuf = typeutil.NewConcurrentSet[*internalpb.SearchResults]()
qt.resultBuf.Insert(&internalpb.SearchResults{})
err := qt.PostExecute(context.TODO())
assert.NoError(t, err)
assert.Equal(t, qt.result.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
@ -1988,7 +2006,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, offset, err := parseSearchInfo(test.validParams, nil)
info, offset, err := parseSearchInfo(test.validParams, nil, false)
assert.NoError(t, err)
assert.NotNil(t, info)
if test.description == "offsetParam" {
@ -2077,7 +2095,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, offset, err := parseSearchInfo(test.invalidParams, nil)
info, offset, err := parseSearchInfo(test.invalidParams, nil, false)
assert.Error(t, err)
assert.Nil(t, info)
assert.Zero(t, offset)
@ -2155,7 +2173,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, _, err := parseSearchInfo(test.validParams, nil)
info, _, err := parseSearchInfo(test.validParams, nil, false)
assert.NoError(t, err)
assert.NotNil(t, info)
})
@ -2171,7 +2189,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, _, err := parseSearchInfo(test.validParams, nil)
info, _, err := parseSearchInfo(test.validParams, nil, false)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
assert.Nil(t, info)
})
@ -2189,7 +2207,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, _, err := parseSearchInfo(test.validParams, nil)
info, _, err := parseSearchInfo(test.validParams, nil, false)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
assert.Nil(t, info)
})
@ -2213,7 +2231,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema)
info, _, err := parseSearchInfo(normalParam, schema, false)
assert.Nil(t, info)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
})
@ -2232,7 +2250,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema)
info, _, err := parseSearchInfo(normalParam, schema, false)
assert.Nil(t, info)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
})
@ -2504,9 +2522,9 @@ func TestSearchTask_Requery(t *testing.T) {
qt.resultBuf.Insert(&internalpb.SearchResults{
SlicedBlob: bytes,
})
qt.queryInfo = &planpb.QueryInfo{
qt.queryInfos = []*planpb.QueryInfo{{
GroupByFieldId: -1,
}
}}
err = qt.PostExecute(ctx)
t.Logf("err = %s", err)
assert.Error(t, err)

View File

@ -469,61 +469,6 @@ func (_c *MockQueryNodeServer_GetTimeTickChannel_Call) RunAndReturn(run func(con
return _c
}
// HybridSearch provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) HybridSearch(_a0 context.Context, _a1 *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
ret := _m.Called(_a0, _a1)
var r0 *querypb.HybridSearchResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*querypb.HybridSearchResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryNodeServer_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
type MockQueryNodeServer_HybridSearch_Call struct {
*mock.Call
}
// HybridSearch is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.HybridSearchRequest
func (_e *MockQueryNodeServer_Expecter) HybridSearch(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_HybridSearch_Call {
return &MockQueryNodeServer_HybridSearch_Call{Call: _e.mock.On("HybridSearch", _a0, _a1)}
}
func (_c *MockQueryNodeServer_HybridSearch_Call) Run(run func(_a0 context.Context, _a1 *querypb.HybridSearchRequest)) *MockQueryNodeServer_HybridSearch_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest))
})
return _c
}
func (_c *MockQueryNodeServer_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNodeServer_HybridSearch_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockQueryNodeServer_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockQueryNodeServer_HybridSearch_Call {
_c.Call.Return(run)
return _c
}
// LoadPartitions provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) LoadPartitions(_a0 context.Context, _a1 *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)

View File

@ -62,7 +62,6 @@ type ShardDelegator interface {
GetSegmentInfo(readable bool) (sealed []SnapshotItem, growing []SegmentEntry)
SyncDistribution(ctx context.Context, entries ...SegmentEntry)
Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error)
HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)
Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error)
QueryStream(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error
GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error)
@ -267,115 +266,81 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
return funcutil.SliceContain(existPartitions, segment.PartitionID)
})
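// Advanced requests fan out one concurrent search per sub request against the pinned
// sealed/growing segments; each sub result is reduced on its own and the per-request
// results are then merged into a single advanced result.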
if req.GetReq().GetIsAdvanced() {
futures := make([]*conc.Future[*internalpb.SearchResults], len(req.GetReq().GetSubReqs()))
for index, subReq := range req.GetReq().GetSubReqs() {
newRequest := &internalpb.SearchRequest{
Base: req.GetReq().GetBase(),
ReqID: req.GetReq().GetReqID(),
DbID: req.GetReq().GetDbID(),
CollectionID: req.GetReq().GetCollectionID(),
PartitionIDs: subReq.GetPartitionIDs(),
Dsl: subReq.GetDsl(),
PlaceholderGroup: subReq.GetPlaceholderGroup(),
DslType: subReq.GetDslType(),
SerializedExprPlan: subReq.GetSerializedExprPlan(),
OutputFieldsId: req.GetReq().GetOutputFieldsId(),
MvccTimestamp: req.GetReq().GetMvccTimestamp(),
GuaranteeTimestamp: req.GetReq().GetGuaranteeTimestamp(),
TimeoutTimestamp: req.GetReq().GetTimeoutTimestamp(),
Nq: subReq.GetNq(),
Topk: subReq.GetTopk(),
MetricType: subReq.GetMetricType(),
IgnoreGrowing: req.GetReq().GetIgnoreGrowing(),
Username: req.GetReq().GetUsername(),
IsAdvanced: false,
}
future := conc.Go(func() (*internalpb.SearchResults, error) {
searchReq := &querypb.SearchRequest{
Req: newRequest,
DmlChannels: req.GetDmlChannels(),
TotalChannelNum: req.GetTotalChannelNum(),
FromShardLeader: true,
}
searchReq.Req.GuaranteeTimestamp = req.GetReq().GetGuaranteeTimestamp()
searchReq.Req.TimeoutTimestamp = req.GetReq().GetTimeoutTimestamp()
if searchReq.GetReq().GetMvccTimestamp() == 0 {
searchReq.GetReq().MvccTimestamp = tSafe
}
results, err := sd.search(ctx, searchReq, sealed, growing)
if err != nil {
return nil, err
}
return segments.ReduceSearchResults(ctx,
results,
searchReq.Req.GetNq(),
searchReq.Req.GetTopk(),
searchReq.Req.GetMetricType())
})
futures[index] = future
}
err = conc.AwaitAll(futures...)
if err != nil {
return nil, err
}
results := make([]*internalpb.SearchResults, len(futures))
for i, future := range futures {
result := future.Value()
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Debug("delegator hybrid search failed",
zap.String("reason", result.GetStatus().GetReason()))
return nil, merr.Error(result.GetStatus())
}
results[i] = result
}
var ret *internalpb.SearchResults
ret, err = segments.MergeToAdvancedResults(ctx, results)
if err != nil {
return nil, err
}
return []*internalpb.SearchResults{ret}, nil
}
return sd.search(ctx, req, sealed, growing)
}
// HybridSearch performs a hybrid search operation on the shard.
func (sd *shardDelegator) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
log := sd.getLogger(ctx)
if err := sd.lifetime.Add(lifetime.IsWorking); err != nil {
return nil, err
}
defer sd.lifetime.Done()
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
log.Warn("deletgator received hybrid search request not belongs to it",
zap.Strings("reqChannels", req.GetDmlChannels()),
)
return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
}
partitions := req.GetReq().GetPartitionIDs()
if !sd.collection.ExistPartition(partitions...) {
return nil, merr.WrapErrPartitionNotLoaded(partitions)
}
// wait tsafe
waitTr := timerecord.NewTimeRecorder("wait tSafe")
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
if err != nil {
log.Warn("delegator hybrid search failed to wait tsafe", zap.Error(err))
return nil, err
}
if req.GetReq().GetMvccTimestamp() == 0 {
req.Req.MvccTimestamp = tSafe
}
metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel).
Observe(float64(waitTr.ElapseSpan().Milliseconds()))
sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
if err != nil {
log.Warn("delegator failed to hybrid search, current distribution is not serviceable")
return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not serviceable")
}
defer sd.distribution.Unpin(version)
existPartitions := sd.collection.GetPartitions()
growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
return funcutil.SliceContain(existPartitions, segment.PartitionID)
})
futures := make([]*conc.Future[*internalpb.SearchResults], len(req.GetReq().GetReqs()))
for index := range req.GetReq().GetReqs() {
request := req.GetReq().Reqs[index]
future := conc.Go(func() (*internalpb.SearchResults, error) {
searchReq := &querypb.SearchRequest{
Req: request,
DmlChannels: req.GetDmlChannels(),
TotalChannelNum: req.GetTotalChannelNum(),
FromShardLeader: true,
}
searchReq.Req.GuaranteeTimestamp = req.GetReq().GetGuaranteeTimestamp()
searchReq.Req.TimeoutTimestamp = req.GetReq().GetTimeoutTimestamp()
if searchReq.GetReq().GetMvccTimestamp() == 0 {
searchReq.GetReq().MvccTimestamp = tSafe
}
results, err := sd.search(ctx, searchReq, sealed, growing)
if err != nil {
return nil, err
}
return segments.ReduceSearchResults(ctx,
results,
searchReq.Req.GetNq(),
searchReq.Req.GetTopk(),
searchReq.Req.GetMetricType())
})
futures[index] = future
}
err = conc.AwaitAll(futures...)
if err != nil {
return nil, err
}
ret := &querypb.HybridSearchResult{
Status: merr.Success(),
Results: make([]*internalpb.SearchResults, len(futures)),
}
channelsMvcc := make(map[string]uint64)
for i, future := range futures {
result := future.Value()
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Debug("delegator hybrid search failed",
zap.String("reason", result.GetStatus().GetReason()))
return nil, merr.Error(result.GetStatus())
}
ret.Results[i] = result
for ch, ts := range result.GetChannelsMvcc() {
channelsMvcc[ch] = ts
}
}
ret.ChannelsMvcc = channelsMvcc
log.Debug("Delegator hybrid search done")
return ret, nil
}
func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error {
log := sd.getLogger(ctx)
if !sd.Serviceable() {

View File

@ -469,251 +469,6 @@ func (s *DelegatorSuite) TestSearch() {
})
}
func (s *DelegatorSuite) TestHybridSearch() {
s.delegator.Start()
paramtable.SetNodeID(1)
s.initSegments()
s.Run("normal", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
}()
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
worker2 := &cluster.MockWorker{}
workers[1] = worker1
workers[2] = worker2
worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
Run(func(_ context.Context, req *querypb.SearchRequest) {
s.EqualValues(1, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
if req.GetScope() == querypb.DataScope_Streaming {
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1004}, req.GetSegmentIDs())
}
if req.GetScope() == querypb.DataScope_Historical {
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1000, 1001}, req.GetSegmentIDs())
}
}).Return(&internalpb.SearchResults{}, nil)
worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
Run(func(_ context.Context, req *querypb.SearchRequest) {
s.EqualValues(2, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
s.Equal(querypb.DataScope_Historical, req.GetScope())
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.SearchResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
results, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
Reqs: []*internalpb.SearchRequest{
{Base: commonpbutil.NewMsgBase()},
{Base: commonpbutil.NewMsgBase()},
},
},
DmlChannels: []string{s.vchannelName},
})
s.NoError(err)
s.Equal(2, len(results.Results))
})
s.Run("partition_not_loaded", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
// partition -1 is not loaded, so this returns an error
PartitionIDs: []int64{-1},
},
DmlChannels: []string{s.vchannelName},
})
s.True(errors.Is(err, merr.ErrPartitionNotLoaded))
})
s.Run("worker_return_error", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
}()
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
worker2 := &cluster.MockWorker{}
workers[1] = worker1
workers[2] = worker2
worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).Return(nil, errors.New("mock error"))
worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
Run(func(_ context.Context, req *querypb.SearchRequest) {
s.EqualValues(2, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
s.Equal(querypb.DataScope_Historical, req.GetScope())
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.SearchResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
Reqs: []*internalpb.SearchRequest{
{
Base: commonpbutil.NewMsgBase(),
},
},
},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("worker_return_failure_code", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
}()
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
worker2 := &cluster.MockWorker{}
workers[1] = worker1
workers[2] = worker2
worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).Return(&internalpb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mocked error",
},
}, nil)
worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
Run(func(_ context.Context, req *querypb.SearchRequest) {
s.EqualValues(2, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
s.Equal(querypb.DataScope_Historical, req.GetScope())
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.SearchResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
Reqs: []*internalpb.SearchRequest{
{
Base: commonpbutil.NewMsgBase(),
},
},
},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("wrong_channel", func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
},
DmlChannels: []string{"non_exist_channel"},
})
s.Error(err)
})
s.Run("wait_tsafe_timeout", func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
defer cancel()
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
GuaranteeTimestamp: 10100,
},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("tsafe_behind_max_lag", func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
GuaranteeTimestamp: uint64(paramtable.Get().QueryNodeCfg.MaxTimestampLag.GetAsDuration(time.Second)) + 10001,
},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("distribution_not_serviceable", func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sd, ok := s.delegator.(*shardDelegator)
s.Require().True(ok)
sd.distribution.AddOfflines(1001)
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("cluster_not_serviceable", func() {
s.delegator.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: commonpbutil.NewMsgBase(),
},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
}
func (s *DelegatorSuite) TestQuery() {
s.delegator.Start()
paramtable.SetNodeID(1)

View File

@ -253,61 +253,6 @@ func (_c *MockShardDelegator_GetTargetVersion_Call) RunAndReturn(run func() int6
return _c
}
// HybridSearch provides a mock function with given fields: ctx, req
func (_m *MockShardDelegator) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
ret := _m.Called(ctx, req)
var r0 *querypb.HybridSearchResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok {
return rf(ctx, req)
}
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok {
r0 = rf(ctx, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*querypb.HybridSearchResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok {
r1 = rf(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockShardDelegator_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
type MockShardDelegator_HybridSearch_Call struct {
*mock.Call
}
// HybridSearch is a helper method to define mock.On call
// - ctx context.Context
// - req *querypb.HybridSearchRequest
func (_e *MockShardDelegator_Expecter) HybridSearch(ctx interface{}, req interface{}) *MockShardDelegator_HybridSearch_Call {
return &MockShardDelegator_HybridSearch_Call{Call: _e.mock.On("HybridSearch", ctx, req)}
}
func (_c *MockShardDelegator_HybridSearch_Call) Run(run func(ctx context.Context, req *querypb.HybridSearchRequest)) *MockShardDelegator_HybridSearch_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest))
})
return _c
}
func (_c *MockShardDelegator_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockShardDelegator_HybridSearch_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockShardDelegator_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockShardDelegator_HybridSearch_Call {
_c.Call.Return(run)
return _c
}
// LoadGrowing provides a mock function with given fields: ctx, infos, version
func (_m *MockShardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error {
ret := _m.Called(ctx, infos, version)

View File

@ -383,7 +383,12 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
req.GetSegmentIDs(),
))
resp, err := segments.ReduceSearchResults(ctx, results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
var resp *internalpb.SearchResults
if req.GetReq().GetIsAdvanced() {
resp, err = segments.ReduceAdvancedSearchResults(ctx, results, req.Req.GetNq())
} else {
resp, err = segments.ReduceSearchResults(ctx, results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
}
if err != nil {
return nil, err
}
@ -402,63 +407,6 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
return resp, nil
}
func (node *QueryNode) hybridSearchChannel(ctx context.Context, req *querypb.HybridSearchRequest, channel string) (*querypb.HybridSearchResult, error) {
log := log.Ctx(ctx).With(
zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
zap.Int64("collectionID", req.Req.GetCollectionID()),
zap.String("channel", channel),
)
traceID := trace.SpanFromContext(ctx).SpanContext().TraceID()
if err := node.lifetime.Add(merr.IsHealthy); err != nil {
return nil, err
}
defer node.lifetime.Done()
var err error
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.TotalLabel, metrics.Leader).Inc()
defer func() {
if err != nil {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.FailLabel, metrics.Leader).Inc()
}
}()
log.Debug("start to search channel")
searchCtx, cancel := context.WithCancel(ctx)
defer cancel()
// From Proxy
tr := timerecord.NewTimeRecorder("hybridSearchDelegator")
// get delegator
sd, ok := node.delegators.Get(channel)
if !ok {
err := merr.WrapErrChannelNotFound(channel)
log.Warn("Query failed, failed to get shard delegator for search", zap.Error(err))
return nil, err
}
// do hybrid search
result, err := sd.HybridSearch(searchCtx, req)
if err != nil {
log.Warn("failed to hybrid search on delegator", zap.Error(err))
return nil, err
}
tr.CtxElapse(ctx, fmt.Sprintf("do search with channel done , traceID = %s, vChannel = %s",
traceID,
channel,
))
// update metric to prometheus
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.SuccessLabel, metrics.Leader).Inc()
for _, searchReq := range req.GetReq().GetReqs() {
metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(node.GetNodeID())).Observe(float64(searchReq.GetNq()))
metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(node.GetNodeID())).Observe(float64(searchReq.GetTopk()))
}
return result, nil
}
func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.GetStatisticsRequest, channel string) (*internalpb.GetStatisticsResponse, error) {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.Req.GetCollectionID()),

View File

@ -102,6 +102,88 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
return searchResults, nil
}
func ReduceAdvancedSearchResults(ctx context.Context, results []*internalpb.SearchResults, nq int64) (*internalpb.SearchResults, error) {
if len(results) == 1 {
return results[0], nil
}
channelsMvcc := make(map[string]uint64)
for _, r := range results {
for ch, ts := range r.GetChannelsMvcc() {
channelsMvcc[ch] = ts
}
}
searchResults := &internalpb.SearchResults{
IsAdvanced: true,
ChannelsMvcc: channelsMvcc,
}
for _, result := range results {
if !result.GetIsAdvanced() {
continue
}
// we just append here, no need to split subResult and reduce
// defer this reduce to proxy
searchResults.SubResults = append(searchResults.SubResults, result.GetSubResults()...)
searchResults.NumQueries = result.GetNumQueries()
}
requestCosts := lo.FilterMap(results, func(result *internalpb.SearchResults, _ int) (*internalpb.CostAggregation, bool) {
if paramtable.Get().QueryNodeCfg.EnableWorkerSQCostMetrics.GetAsBool() {
return result.GetCostAggregation(), true
}
if result.GetBase().GetSourceID() == paramtable.GetNodeID() {
return result.GetCostAggregation(), true
}
return nil, false
})
searchResults.CostAggregation = mergeRequestCost(requestCosts)
return searchResults, nil
}
func MergeToAdvancedResults(ctx context.Context, results []*internalpb.SearchResults) (*internalpb.SearchResults, error) {
searchResults := &internalpb.SearchResults{
IsAdvanced: true,
}
channelsMvcc := make(map[string]uint64)
for _, r := range results {
for ch, ts := range r.GetChannelsMvcc() {
channelsMvcc[ch] = ts
}
}
searchResults.ChannelsMvcc = channelsMvcc
for index, result := range results {
// we just append here, no need to split subResult and reduce
// defer this reduce to proxy
subResult := &internalpb.SubSearchResults{
MetricType: result.GetMetricType(),
NumQueries: result.GetNumQueries(),
TopK: result.GetTopK(),
SlicedBlob: result.GetSlicedBlob(),
SlicedNumCount: result.GetSlicedNumCount(),
SlicedOffset: result.GetSlicedOffset(),
ReqIndex: int64(index),
}
searchResults.NumQueries = result.GetNumQueries()
searchResults.SubResults = append(searchResults.SubResults, subResult)
}
requestCosts := lo.FilterMap(results, func(result *internalpb.SearchResults, _ int) (*internalpb.CostAggregation, bool) {
if paramtable.Get().QueryNodeCfg.EnableWorkerSQCostMetrics.GetAsBool() {
return result.GetCostAggregation(), true
}
if result.GetBase().GetSourceID() == paramtable.GetNodeID() {
return result.GetCostAggregation(), true
}
return nil, false
})
searchResults.CostAggregation = mergeRequestCost(requestCosts)
return searchResults, nil
}
func ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, nq int64, topk int64) (*schemapb.SearchResultData, error) {
log := log.Ctx(ctx)

View File

@ -751,7 +751,7 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
for i, ch := range req.GetDmlChannels() {
ch := ch
req := &querypb.SearchRequest{
channelReq := &querypb.SearchRequest{
Req: req.Req,
DmlChannels: []string{ch},
SegmentIDs: req.SegmentIDs,
@ -762,7 +762,7 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
i := i
runningGp.Go(func() error {
ret, err := node.searchChannel(runningCtx, req, ch)
ret, err := node.searchChannel(runningCtx, channelReq, ch)
if err != nil {
return err
}
@ -779,12 +779,20 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
}
tr.RecordSpan()
result, err := segments.ReduceSearchResults(ctx, toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
if err != nil {
log.Warn("failed to reduce search results", zap.Error(err))
resp.Status = merr.Status(err)
var result *internalpb.SearchResults
var err2 error
if req.GetReq().GetIsAdvanced() {
result, err2 = segments.ReduceAdvancedSearchResults(ctx, toReduceResults, req.Req.GetNq())
} else {
result, err2 = segments.ReduceSearchResults(ctx, toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
}
if err2 != nil {
log.Warn("failed to reduce search results", zap.Error(err2))
resp.Status = merr.Status(err2)
return resp, nil
}
reduceLatency := tr.RecordSpan()
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards).
Observe(float64(reduceLatency.Milliseconds()))
@ -800,103 +808,6 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
return result, nil
}
// HybridSearch performs replica search tasks.
func (node *QueryNode) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.GetReq().GetCollectionID()),
zap.Strings("channels", req.GetDmlChannels()))
log.Debug("Received HybridSearchRequest",
zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()),
zap.Uint64("mvccTimestamp", req.GetReq().GetMvccTimestamp()))
tr := timerecord.NewTimeRecorderWithTrace(ctx, "HybridSearchRequest")
if err := node.lifetime.Add(merr.IsHealthy); err != nil {
return &querypb.HybridSearchResult{
Base: &commonpb.MsgBase{
SourceID: node.GetNodeID(),
},
Status: merr.Status(err),
}, nil
}
defer node.lifetime.Done()
resp := &querypb.HybridSearchResult{
Base: &commonpb.MsgBase{
SourceID: node.GetNodeID(),
},
Status: merr.Success(),
}
collection := node.manager.Collection.Get(req.GetReq().GetCollectionID())
if collection == nil {
resp.Status = merr.Status(merr.WrapErrCollectionNotFound(req.GetReq().GetCollectionID()))
return resp, nil
}
MultipleResults := make([]*querypb.HybridSearchResult, len(req.GetDmlChannels()))
runningGp, runningCtx := errgroup.WithContext(ctx)
for i, ch := range req.GetDmlChannels() {
ch := ch
req := &querypb.HybridSearchRequest{
Req: req.Req,
DmlChannels: []string{ch},
TotalChannelNum: 1,
}
i := i
runningGp.Go(func() error {
ret, err := node.hybridSearchChannel(runningCtx, req, ch)
if err != nil {
return err
}
if err := merr.Error(ret.GetStatus()); err != nil {
return err
}
MultipleResults[i] = ret
return nil
})
}
if err := runningGp.Wait(); err != nil {
resp.Status = merr.Status(err)
return resp, nil
}
tr.RecordSpan()
channelsMvcc := make(map[string]uint64)
for i, searchReq := range req.GetReq().GetReqs() {
toReduceResults := make([]*internalpb.SearchResults, len(MultipleResults))
for index, hs := range MultipleResults {
toReduceResults[index] = hs.Results[i]
}
result, err := segments.ReduceSearchResults(ctx, toReduceResults, searchReq.GetNq(), searchReq.GetTopk(), searchReq.GetMetricType())
if err != nil {
log.Warn("failed to reduce search results", zap.Error(err))
resp.Status = merr.Status(err)
return resp, nil
}
for ch, ts := range result.GetChannelsMvcc() {
channelsMvcc[ch] = ts
}
resp.Results = append(resp.Results, result)
}
resp.ChannelsMvcc = channelsMvcc
reduceLatency := tr.RecordSpan()
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.ReduceShards).
Observe(float64(reduceLatency.Milliseconds()))
collector.Rate.Add(metricsinfo.SearchThroughput, float64(proto.Size(req)))
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(node.GetNodeID(), 10), metrics.HybridSearchLabel).
Add(float64(proto.Size(req)))
if resp.GetCostAggregation() != nil {
resp.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
}
return resp, nil
}
// only used for delegator query segments from worker
func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
resp := &internalpb.RetrieveResults{

View File

@ -1337,47 +1337,6 @@ func (suite *ServiceSuite) TestSearchSegments_Failed() {
suite.Equal(commonpb.ErrorCode_UnexpectedError, rsp.GetStatus().GetErrorCode())
}
func (suite *ServiceSuite) TestHybridSearch_Concurrent() {
ctx := context.Background()
// pre
suite.TestWatchDmChannelsInt64()
suite.TestLoadSegments_Int64()
concurrency := 16
futures := make([]*conc.Future[*querypb.HybridSearchResult], 0, concurrency)
for i := 0; i < concurrency; i++ {
future := conc.Go(func() (*querypb.HybridSearchResult, error) {
creq1, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType)
suite.NoError(err)
creq2, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType)
suite.NoError(err)
req := &querypb.HybridSearchRequest{
Req: &internalpb.HybridSearchRequest{
Base: &commonpb.MsgBase{
MsgID: rand.Int63(),
TargetID: suite.node.session.ServerID,
},
CollectionID: suite.collectionID,
PartitionIDs: suite.partitionIDs,
MvccTimestamp: typeutil.MaxTimestamp,
Reqs: []*internalpb.SearchRequest{creq1, creq2},
},
DmlChannels: []string{suite.vchannel},
}
return suite.node.HybridSearch(ctx, req)
})
futures = append(futures, future)
}
err := conc.AwaitAll(futures...)
suite.NoError(err)
for i := range futures {
suite.True(merr.Ok(futures[i].Value().GetStatus()))
}
}
func (suite *ServiceSuite) TestSearchSegments_Normal() {
ctx := context.Background()
// pre

View File

@ -82,10 +82,6 @@ func (m *GrpcQueryNodeClient) Search(ctx context.Context, in *querypb.SearchRequ
return &internalpb.SearchResults{}, m.Err
}
func (m *GrpcQueryNodeClient) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
return &querypb.HybridSearchResult{}, m.Err
}
func (m *GrpcQueryNodeClient) SearchSegments(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) {
return &internalpb.SearchResults{}, m.Err
}

View File

@ -93,10 +93,6 @@ func (qn *qnServerWrapper) Search(ctx context.Context, in *querypb.SearchRequest
return qn.QueryNode.Search(ctx, in)
}
func (qn *qnServerWrapper) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
return qn.QueryNode.HybridSearch(ctx, in)
}
func (qn *qnServerWrapper) SearchSegments(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) {
return qn.QueryNode.SearchSegments(ctx, in)
}

View File

@ -1790,26 +1790,72 @@ def extract_vector_field_name_list(collection_w):
return vector_name_list
def get_activate_func_from_metric_type(metric_type):
activate_function = lambda x: x
if metric_type == "COSINE":
activate_function = lambda x: (1 + x) * 0.5
elif metric_type == "IP":
activate_function = lambda x: 0.5 + math.atan(x)/ math.pi
else:
activate_function = lambda x: 1.0 - 2*math.atan(x) / math.pi
return activate_function
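
For reference, a minimal self-contained sketch of the normalization this new helper encodes; the activate_func name and the sample scores below are illustrative stand-ins, not repo code:

import math

# Sketch mirroring get_activate_func_from_metric_type above (not an import of the repo module):
# map a metric type to the activation applied to raw scores before ranker weights are applied.
def activate_func(metric_type):
    if metric_type == "COSINE":
        return lambda x: (1 + x) * 0.5                      # cosine similarity in [-1, 1] -> [0, 1]
    elif metric_type == "IP":
        return lambda x: 0.5 + math.atan(x) / math.pi       # inner product -> (0, 1)
    else:
        return lambda x: 1.0 - 2 * math.atan(x) / math.pi   # distance metrics such as L2

# Illustrative values only.
print(activate_func("COSINE")(0.8))   # 0.9
print(activate_func("L2")(1.5))       # ~0.37; a smaller distance yields a larger activated score
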
def get_hybrid_search_base_results(search_res_dict_array):
def get_hybrid_search_base_results_rrf(search_res_dict_array, round_decimal=-1):
"""
merge the elements of the dict array by reciprocal-rank-fusion style summation
search_res_dict_array: the array of dicts whose entries are to be merged
return: the sorted ids and scores of the merged baseline
"""
# calculate hybrid search base line
search_res_dict_merge = {}
ids_answer = []
score_answer = []
for i in range(len(search_res_dict_array) - 1):
for key in search_res_dict_array[i]:
if search_res_dict_array[i + 1].get(key):
search_res_dict_merge[key] = search_res_dict_array[i][key] + search_res_dict_array[i + 1][key]
else:
search_res_dict_merge[key] = search_res_dict_array[i][key]
for key in search_res_dict_array[i + 1]:
if not search_res_dict_array[i].get(key):
search_res_dict_merge[key] = search_res_dict_array[i + 1][key]
for i, result in enumerate(search_res_dict_array, 0):
for key, distance in result.items():
search_res_dict_merge[key] = search_res_dict_merge.get(key, 0) + distance
if round_decimal != -1 :
for k, v in search_res_dict_merge.items():
multiplier = math.pow(10.0, round_decimal)
v = math.floor(v*multiplier+0.5) / multiplier
search_res_dict_merge[k] = v
sorted_list = sorted(search_res_dict_merge.items(), key=lambda x: x[1], reverse=True)
for sort in sorted_list:
ids_answer.append(int(sort[0]))
score_answer.append(float(sort[1]))
return ids_answer, score_answer
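
A hedged, self-contained sketch of the RRF baseline that get_hybrid_search_base_results_rrf merges; the ids and ranks below are invented, and the inline loop mirrors the summation in the helper rather than importing it:

# Each per-request dict maps id -> 1/(rank + k + 1), as the tests build them.
k = 60
res_a = {101: 1 / (0 + k + 1), 102: 1 / (1 + k + 1)}   # request A
res_b = {102: 1 / (0 + k + 1), 103: 1 / (1 + k + 1)}   # request B

merged = {}
for result in (res_a, res_b):                 # same summation the helper performs
    for key, score in result.items():
        merged[key] = merged.get(key, 0) + score
ranked = sorted(merged.items(), key=lambda kv: kv[1], reverse=True)
print(ranked)   # id 102 ranks first because it appears in both sub-searches
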
def get_hybrid_search_base_results(search_res_dict_array, weights, metric_types, round_decimal=-1):
"""
merge the elements of the dict array into a weighted baseline using metric-based activation
search_res_dict_array: the array of dicts whose entries are to be merged
return: the sorted ids and scores of the merged baseline
"""
# calculate hybrid search base line
search_res_dict_merge = {}
ids_answer = []
score_answer = []
for i, result in enumerate(search_res_dict_array, 0):
activate_function = get_activate_func_from_metric_type(metric_types[i])
for key, distance in result.items():
activate_distance = activate_function(distance)
weight = weights[i]
search_res_dict_merge[key] = search_res_dict_merge.get(key, 0) + activate_distance * weight
if round_decimal != -1 :
for k, v in search_res_dict_merge.items():
multiplier = math.pow(10.0, round_decimal)
v = math.floor(v*multiplier+0.5) / multiplier
search_res_dict_merge[k] = v
sorted_list = sorted(search_res_dict_merge.items(), key=lambda x: x[1], reverse=True)
for sort in sorted_list:

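For the weighted variant above, a rough self-contained sketch of the merge it performs; the ids, distances, weights and metric types are invented, and the inline activate mirrors get_activate_func_from_metric_type rather than importing it:

import math

def activate(metric_type, x):
    # same normalization as the helper above
    if metric_type == "COSINE":
        return (1 + x) * 0.5
    if metric_type == "IP":
        return 0.5 + math.atan(x) / math.pi
    return 1.0 - 2 * math.atan(x) / math.pi

results = [{7: 0.92, 8: 0.75}, {8: 0.81, 9: 0.60}]   # id -> raw distance per request
weights = [0.4, 0.6]
metric_types = ["COSINE", "COSINE"]

merged = {}
for i, result in enumerate(results):
    for key, distance in result.items():
        merged[key] = merged.get(key, 0) + activate(metric_types[i], distance) * weights[i]
ranked = sorted(merged.items(), key=lambda kv: kv[1], reverse=True)
print(ranked)   # id 8 scores highest because it contributes under both weights
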
View File

@ -15,6 +15,7 @@ import decimal
import multiprocessing
import numbers
import random
import math
import numpy
import threading
import pytest
@ -10495,6 +10496,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
# 3. prepare search params
req_list = []
weights = [0.2, 0.3, 0.5]
metrics = []
search_res_dict_array = []
for i in range(len(vector_name_list)):
# 4. generate search data
@ -10508,22 +10510,22 @@ class TestCollectionHybridSearchValid(TestcaseBase):
"expr": "int64 > 0"}
req = AnnSearchRequest(**search_param)
req_list.append(req)
metrics.append("COSINE")
# 5. search to get the base line of hybrid_search
search_res = collection_w.search(vectors[:1], vector_name_list[i],
default_search_params, default_limit,
default_search_exp,
offset = offset,
check_task=CheckTasks.check_search_results,
check_items={"nq": 1,
"ids": insert_ids,
"limit": default_limit})[0]
ids = search_res[0].ids
distance_array = [distance_single * weights[i] for distance_single in search_res[0].distances]
distance_array = search_res[0].distances
for j in range(len(ids)):
search_res_dict[ids[j]] = distance_array[j]
search_res_dict_array.append(search_res_dict)
# 6. calculate hybrid search base line
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array, weights, metrics)
# 7. hybrid search
hybrid_res = collection_w.hybrid_search(req_list, WeightedRanker(*weights), default_limit,
offset = offset,
@ -11236,7 +11238,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
search_res_dict[ids[j]] = 1/(j + 60 +1)
search_res_dict_array.append(search_res_dict)
# 4. calculate hybrid search base line for RRFRanker
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
ids_answer, score_answer = cf.get_hybrid_search_base_results_rrf(search_res_dict_array)
# 5. hybrid search
hybrid_search_0 = collection_w.hybrid_search(req_list, RRFRanker(), default_limit,
check_task=CheckTasks.check_search_results,
@ -11296,7 +11298,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
# search for get the base line of hybrid_search
search_res = collection_w.search(vectors[:1], vector_name_list[i],
default_search_params, default_limit,
default_search_exp, offset=offset,
default_search_exp, offset=0,
check_task=CheckTasks.check_search_results,
check_items={"nq": 1,
"ids": insert_ids,
@ -11306,7 +11308,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
search_res_dict[ids[j]] = 1/(j + k +1)
search_res_dict_array.append(search_res_dict)
# 4. calculate hybrid search base line for RRFRanker
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
ids_answer, score_answer = cf.get_hybrid_search_base_results_rrf(search_res_dict_array)
# 5. hybrid search
hybrid_res = collection_w.hybrid_search(req_list, RRFRanker(k), default_limit,
offset=offset,
@ -11444,7 +11446,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
search_res_dict[ids[j]] = 1/(j + k +1)
search_res_dict_array.append(search_res_dict)
# 4. calculate hybrid search base line for RRFRanker
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
ids_answer, score_answer = cf.get_hybrid_search_base_results_rrf(search_res_dict_array)
# 5. hybrid search
hybrid_res = collection_w.hybrid_search(req_list, RRFRanker(k), default_limit,
check_task=CheckTasks.check_search_results,
@ -11453,7 +11455,8 @@ class TestCollectionHybridSearchValid(TestcaseBase):
"limit": default_limit})[0]
# 6. compare results through the re-calculated distances
for i in range(len(score_answer[:default_limit])):
assert score_answer[i] - hybrid_res[0].distances[i] < hybrid_search_epsilon
delta = math.fabs(score_answer[i] - hybrid_res[0].distances[i])
assert delta < hybrid_search_epsilon
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("limit", [1, 100, 16384])
@ -11477,6 +11480,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
search_res_dict_array = []
if limit > default_nb:
limit = default_limit
metrics = []
for i in range(len(vector_name_list)):
vectors = [[random.random() for _ in range(default_dim)] for _ in range(1)]
search_res_dict = {}
@ -11488,6 +11492,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
"expr": "int64 > 0"}
req = AnnSearchRequest(**search_param)
req_list.append(req)
metrics.append("COSINE")
# search to get the base line of hybrid_search
search_res = collection_w.search(vectors[:1], vector_name_list[i],
default_search_params, limit,
@ -11497,12 +11502,12 @@ class TestCollectionHybridSearchValid(TestcaseBase):
"ids": insert_ids,
"limit": limit})[0]
ids = search_res[0].ids
distance_array = [distance_single * weights[i] for distance_single in search_res[0].distances]
distance_array = search_res[0].distances
for j in range(len(ids)):
search_res_dict[ids[j]] = distance_array[j]
search_res_dict_array.append(search_res_dict)
# 4. calculate hybrid search base line
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array, weights, metrics, 5)
# 5. hybrid search
hybrid_res = collection_w.hybrid_search(req_list, WeightedRanker(*weights), limit,
round_decimal=5,
@ -11512,7 +11517,8 @@ class TestCollectionHybridSearchValid(TestcaseBase):
"limit": limit})[0]
# 6. compare results through the re-calculated distances
for i in range(len(score_answer[:limit])):
assert score_answer[i] - hybrid_res[0].distances[i] < hybrid_search_epsilon
delta = math.fabs(score_answer[i] - hybrid_res[0].distances[i])
assert delta < hybrid_search_epsilon
@pytest.mark.tags(CaseLabel.L1)
def test_hybrid_search_limit_out_of_range_max(self):
@ -11598,6 +11604,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
req_list = []
weights = [0.2, 0.3, 0.5]
search_res_dict_array = []
metrics = []
for i in range(len(vector_name_list)):
vectors = [[random.random() for _ in range(default_dim)] for _ in range(1)]
search_res_dict = {}
@ -11609,6 +11616,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
"expr": "int64 > 0"}
req = AnnSearchRequest(**search_param)
req_list.append(req)
metrics.append("COSINE")
# search to get the base line of hybrid_search
search_res = collection_w.search(vectors[:1], vector_name_list[i],
default_search_params, default_limit,
@ -11618,12 +11626,12 @@ class TestCollectionHybridSearchValid(TestcaseBase):
"ids": insert_ids,
"limit": default_limit})[0]
ids = search_res[0].ids
distance_array = [distance_single * weights[i] for distance_single in search_res[0].distances]
distance_array = search_res[0].distances
for j in range(len(ids)):
search_res_dict[ids[j]] = distance_array[j]
search_res_dict_array.append(search_res_dict)
# 4. calculate hybrid search base line
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array, weights, metrics)
# 5. hybrid search
output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name,
default_json_field_name]
@ -11637,7 +11645,8 @@ class TestCollectionHybridSearchValid(TestcaseBase):
"output_fields": output_fields})[0]
# 6. compare results through the re-calculated distances
for i in range(len(score_answer[:default_limit])):
assert score_answer[i] - hybrid_res[0].distances[i] < hybrid_search_epsilon
delta = math.fabs(score_answer[i] - hybrid_res[0].distances[i])
assert delta < hybrid_search_epsilon
@pytest.mark.tags(CaseLabel.L2)
def test_hybrid_search_with_output_fields_all_fields_wildcard(self):
@ -11656,6 +11665,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
req_list = []
weights = [0.2, 0.3, 0.5]
search_res_dict_array = []
metrics = []
for i in range(len(vector_name_list)):
vectors = [[random.random() for _ in range(default_dim)] for _ in range(1)]
search_res_dict = {}
@ -11667,6 +11677,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
"expr": "int64 > 0"}
req = AnnSearchRequest(**search_param)
req_list.append(req)
metrics.append("COSINE")
# search to get the base line of hybrid_search
search_res = collection_w.search(vectors[:1], vector_name_list[i],
default_search_params, default_limit,
@ -11676,12 +11687,12 @@ class TestCollectionHybridSearchValid(TestcaseBase):
"ids": insert_ids,
"limit": default_limit})[0]
ids = search_res[0].ids
distance_array = [distance_single * weights[i] for distance_single in search_res[0].distances]
distance_array = search_res[0].distances
for j in range(len(ids)):
search_res_dict[ids[j]] = distance_array[j]
search_res_dict_array.append(search_res_dict)
# 4. calculate hybrid search base line
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array, weights, metrics)
# 5. hybrid search
output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name,
default_json_field_name]
@ -11695,7 +11706,8 @@ class TestCollectionHybridSearchValid(TestcaseBase):
"output_fields": output_fields})[0]
# 6. compare results through the re-calculated distances
for i in range(len(score_answer[:default_limit])):
assert score_answer[i] - hybrid_res[0].distances[i] < hybrid_search_epsilon
delta = math.fabs(score_answer[i] - hybrid_res[0].distances[i])
assert delta < hybrid_search_epsilon
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("output_fields", [[default_search_field], [default_search_field, default_int64_field_name]])
@ -11717,6 +11729,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
req_list = []
weights = [0.2, 0.3, 0.5]
search_res_dict_array = []
metrics = []
for i in range(len(vector_name_list)):
vectors = [[random.random() for _ in range(default_dim)] for _ in range(1)]
search_res_dict = {}
@ -11728,6 +11741,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
"expr": "int64 > 0"}
req = AnnSearchRequest(**search_param)
req_list.append(req)
metrics.append("COSINE")
# search to get the base line of hybrid_search
search_res = collection_w.search(vectors[:1], vector_name_list[i],
default_search_params, default_limit,
@ -11742,12 +11756,12 @@ class TestCollectionHybridSearchValid(TestcaseBase):
search_res.done()
search_res = search_res.result()
ids = search_res[0].ids
distance_array = [distance_single * weights[i] for distance_single in search_res[0].distances]
distance_array = search_res[0].distances
for j in range(len(ids)):
search_res_dict[ids[j]] = distance_array[j]
search_res_dict_array.append(search_res_dict)
# 4. calculate hybrid search base line
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array, weights, metrics)
# 5. hybrid search
hybrid_res = collection_w.hybrid_search(req_list, WeightedRanker(*weights), default_limit, _async = _async,
output_fields = output_fields,
@ -11762,7 +11776,8 @@ class TestCollectionHybridSearchValid(TestcaseBase):
hybrid_res = hybrid_res.result()
# 6. compare results through the re-calculated distances
for i in range(len(score_answer[:default_limit])):
assert score_answer[i] - hybrid_res[0].distances[i] < hybrid_search_epsilon
delta = math.fabs(score_answer[i] - hybrid_res[0].distances[i])
assert delta < hybrid_search_epsilon
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("rerank", [RRFRanker(), WeightedRanker(0.1, 0.9, 1)])
@ -11825,6 +11840,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
search_res_dict_array = []
if limit > default_nb:
limit = default_limit
metrics = []
for i in range(len(vector_name_list)):
vectors = [[random.random() for _ in range(default_dim)] for _ in range(1)]
search_res_dict = {}
@ -11836,6 +11852,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
"expr": "int64 > 0"}
req = AnnSearchRequest(**search_param)
req_list.append(req)
metrics.append("COSINE")
# search to get the base line of hybrid_search
search_res = collection_w.search(vectors[:1], vector_name_list[i],
default_search_params, limit,
@ -11845,12 +11862,12 @@ class TestCollectionHybridSearchValid(TestcaseBase):
"ids": insert_ids,
"limit": limit})[0]
ids = search_res[0].ids
distance_array = [distance_single * weights[i] for distance_single in search_res[0].distances]
distance_array = search_res[0].distances
for j in range(len(ids)):
search_res_dict[ids[j]] = distance_array[j]
search_res_dict_array.append(search_res_dict)
# 4. calculate hybrid search base line
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array, weights, metrics, 5)
# 5. hybrid search
hybrid_res = collection_w.hybrid_search(req_list, WeightedRanker(*weights), limit,
round_decimal=5,
@ -11860,7 +11877,8 @@ class TestCollectionHybridSearchValid(TestcaseBase):
"limit": limit})[0]
# 6. compare results through the re-calculated distances
for i in range(len(score_answer[:limit])):
assert score_answer[i] - hybrid_res[0].distances[i] < hybrid_search_epsilon
delta = math.fabs(score_answer[i] - hybrid_res[0].distances[i])
assert delta < hybrid_search_epsilon
@pytest.mark.tags(CaseLabel.L1)
def test_hybrid_search_result_L2_order(self):
@ -11896,8 +11914,8 @@ class TestCollectionHybridSearchValid(TestcaseBase):
req_list.append(req)
# 4. hybrid search
res = collection_w.hybrid_search(req_list, WeightedRanker(*weights), 10)[0]
is_sorted_ascend = lambda lst: all(lst[i] <= lst[i+1] for i in range(len(lst)-1))
assert is_sorted_ascend(res[0].distances)
is_sorted_descend = lambda lst: all(lst[i] >= lst[i+1] for i in range(len(lst)-1))
assert is_sorted_descend(res[0].distances)
@pytest.mark.tags(CaseLabel.L1)
def test_hybrid_search_result_order(self):