mirror of https://github.com/milvus-io/milvus.git
fix: Support mvcc with hybrid serach (#30114)
issue: https://github.com/milvus-io/milvus/issues/29656 /kind bug Signed-off-by: xige-16 <xi.ge@zilliz.com> --------- Signed-off-by: xige-16 <xi.ge@zilliz.com>pull/30398/head
parent
32914a3ddf
commit
060c8603a3
|
@ -329,3 +329,10 @@ func (c *Client) Delete(ctx context.Context, req *querypb.DeleteRequest, _ ...gr
|
|||
return client.Delete(ctx, req)
|
||||
})
|
||||
}
|
||||
|
||||
// HybridSearch performs replica hybrid search tasks in QueryNode.
|
||||
func (c *Client) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest, _ ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
|
||||
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*querypb.HybridSearchResult, error) {
|
||||
return client.HybridSearch(ctx, req)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -374,3 +374,8 @@ func (s *Server) SyncDistribution(ctx context.Context, req *querypb.SyncDistribu
|
|||
func (s *Server) Delete(ctx context.Context, req *querypb.DeleteRequest) (*commonpb.Status, error) {
|
||||
return s.querynode.Delete(ctx, req)
|
||||
}
|
||||
|
||||
// HybridSearch performs hybrid search of streaming/historical replica on QueryNode.
|
||||
func (s *Server) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
|
||||
return s.querynode.HybridSearch(ctx, req)
|
||||
}
|
||||
|
|
|
@ -511,6 +511,61 @@ func (_c *MockQueryNode_GetTimeTickChannel_Call) RunAndReturn(run func(context.C
|
|||
return _c
|
||||
}
|
||||
|
||||
// HybridSearch provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryNode) HybridSearch(_a0 context.Context, _a1 *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 *querypb.HybridSearchResult
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok {
|
||||
return rf(_a0, _a1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*querypb.HybridSearchResult)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryNode_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
|
||||
type MockQueryNode_HybridSearch_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// HybridSearch is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *querypb.HybridSearchRequest
|
||||
func (_e *MockQueryNode_Expecter) HybridSearch(_a0 interface{}, _a1 interface{}) *MockQueryNode_HybridSearch_Call {
|
||||
return &MockQueryNode_HybridSearch_Call{Call: _e.mock.On("HybridSearch", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryNode_HybridSearch_Call) Run(run func(_a0 context.Context, _a1 *querypb.HybridSearchRequest)) *MockQueryNode_HybridSearch_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryNode_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNode_HybridSearch_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryNode_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockQueryNode_HybridSearch_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Init provides a mock function with given fields:
|
||||
func (_m *MockQueryNode) Init() error {
|
||||
ret := _m.Called()
|
||||
|
|
|
@ -632,6 +632,76 @@ func (_c *MockQueryNodeClient_GetTimeTickChannel_Call) RunAndReturn(run func(con
|
|||
return _c
|
||||
}
|
||||
|
||||
// HybridSearch provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryNodeClient) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
for _i := range opts {
|
||||
_va[_i] = opts[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx, in)
|
||||
_ca = append(_ca, _va...)
|
||||
ret := _m.Called(_ca...)
|
||||
|
||||
var r0 *querypb.HybridSearchResult
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) (*querypb.HybridSearchResult, error)); ok {
|
||||
return rf(ctx, in, opts...)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) *querypb.HybridSearchResult); ok {
|
||||
r0 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*querypb.HybridSearchResult)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) error); ok {
|
||||
r1 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryNodeClient_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
|
||||
type MockQueryNodeClient_HybridSearch_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// HybridSearch is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *querypb.HybridSearchRequest
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *MockQueryNodeClient_Expecter) HybridSearch(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_HybridSearch_Call {
|
||||
return &MockQueryNodeClient_HybridSearch_Call{Call: _e.mock.On("HybridSearch",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryNodeClient_HybridSearch_Call) Run(run func(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_HybridSearch_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]grpc.CallOption, len(args)-2)
|
||||
for i, a := range args[2:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(grpc.CallOption)
|
||||
}
|
||||
}
|
||||
run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryNodeClient_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNodeClient_HybridSearch_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryNodeClient_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) (*querypb.HybridSearchResult, error)) *MockQueryNodeClient_HybridSearch_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// LoadPartitions provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryNodeClient) LoadPartitions(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
|
|
|
@ -104,6 +104,18 @@ message SearchRequest {
|
|||
string username = 18;
|
||||
}
|
||||
|
||||
message HybridSearchRequest {
|
||||
common.MsgBase base = 1;
|
||||
int64 reqID = 2;
|
||||
int64 dbID = 3;
|
||||
int64 collectionID = 4;
|
||||
repeated int64 partitionIDs = 5;
|
||||
repeated SearchRequest reqs = 6;
|
||||
uint64 mvcc_timestamp = 11;
|
||||
uint64 guarantee_timestamp = 12;
|
||||
uint64 timeout_timestamp = 13;
|
||||
}
|
||||
|
||||
message SearchResults {
|
||||
common.MsgBase base = 1;
|
||||
common.Status status = 2;
|
||||
|
|
|
@ -71,6 +71,7 @@ service QueryNode {
|
|||
|
||||
rpc GetStatistics(GetStatisticsRequest) returns (internal.GetStatisticsResponse) {}
|
||||
rpc Search(SearchRequest) returns (internal.SearchResults) {}
|
||||
rpc HybridSearch(HybridSearchRequest) returns (HybridSearchResult) {}
|
||||
rpc SearchSegments(SearchRequest) returns (internal.SearchResults) {}
|
||||
rpc Query(QueryRequest) returns (internal.RetrieveResults) {}
|
||||
rpc QueryStream(QueryRequest) returns (stream internal.RetrieveResults){}
|
||||
|
@ -328,6 +329,20 @@ message SearchRequest {
|
|||
int32 total_channel_num = 6;
|
||||
}
|
||||
|
||||
message HybridSearchRequest {
|
||||
internal.HybridSearchRequest req = 1;
|
||||
repeated string dml_channels = 2;
|
||||
int32 total_channel_num = 3;
|
||||
}
|
||||
|
||||
message HybridSearchResult {
|
||||
common.MsgBase base = 1;
|
||||
common.Status status = 2;
|
||||
repeated internal.SearchResults results = 3;
|
||||
internal.CostAggregation costAggregation = 4;
|
||||
map<string, uint64> channels_mvcc = 5;
|
||||
}
|
||||
|
||||
message QueryRequest {
|
||||
internal.RetrieveRequest req = 1;
|
||||
repeated string dml_channels = 2;
|
||||
|
|
|
@ -2784,11 +2784,18 @@ func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSea
|
|||
qt := &hybridSearchTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
request: request,
|
||||
tr: timerecord.NewTimeRecorder(method),
|
||||
qc: node.queryCoord,
|
||||
node: node,
|
||||
lb: node.lbPolicy,
|
||||
HybridSearchRequest: &internalpb.HybridSearchRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Search),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
ReqID: paramtable.GetNodeID(),
|
||||
},
|
||||
request: request,
|
||||
tr: timerecord.NewTimeRecorder(method),
|
||||
qc: node.queryCoord,
|
||||
node: node,
|
||||
lb: node.lbPolicy,
|
||||
}
|
||||
|
||||
guaranteeTs := request.GuaranteeTimestamp
|
||||
|
@ -2831,7 +2838,7 @@ func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSea
|
|||
|
||||
log.Debug(
|
||||
rpcEnqueued(method),
|
||||
zap.Uint64("timestamp", qt.request.Base.Timestamp),
|
||||
zap.Uint64("timestamp", qt.Base.Timestamp),
|
||||
)
|
||||
|
||||
if err := qt.WaitToFinish(); err != nil {
|
||||
|
|
|
@ -120,7 +120,7 @@ func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValue
|
|||
return nil, errors.New("The type of rank param k should be float")
|
||||
}
|
||||
if k <= 0 || k >= maxRRFParamsValue {
|
||||
return nil, errors.New("The rank params k should be in range (0, 16384)")
|
||||
return nil, errors.New(fmt.Sprintf("The rank params k should be in range (0, %d)", maxRRFParamsValue))
|
||||
}
|
||||
log.Debug("rrf params", zap.Float64("k", k))
|
||||
for i := range reqs {
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
func initSearchRequest(ctx context.Context, t *searchTask) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init search request")
|
||||
defer sp.End()
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
|
||||
// fetch search_growing from search param
|
||||
var ignoreGrowing bool
|
||||
var err error
|
||||
for i, kv := range t.request.GetSearchParams() {
|
||||
if kv.GetKey() == IgnoreGrowingKey {
|
||||
ignoreGrowing, err = strconv.ParseBool(kv.GetValue())
|
||||
if err != nil {
|
||||
return errors.New("parse search growing failed")
|
||||
}
|
||||
t.request.SearchParams = append(t.request.GetSearchParams()[:i], t.request.GetSearchParams()[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
t.SearchRequest.IgnoreGrowing = ignoreGrowing
|
||||
|
||||
// Manually update nq if not set.
|
||||
nq, err := getNq(t.request)
|
||||
if err != nil {
|
||||
log.Warn("failed to get nq", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
// Check if nq is valid:
|
||||
// https://milvus.io/docs/limitations.md
|
||||
if err := validateNQLimit(nq); err != nil {
|
||||
return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
|
||||
}
|
||||
t.SearchRequest.Nq = nq
|
||||
log = log.With(zap.Int64("nq", nq))
|
||||
|
||||
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
|
||||
if err != nil {
|
||||
log.Warn("fail to get output field ids", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
t.SearchRequest.OutputFieldsId = outputFieldIDs
|
||||
|
||||
partitionNames := t.request.GetPartitionNames()
|
||||
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
|
||||
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
|
||||
if err != nil || len(annsField) == 0 {
|
||||
vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
|
||||
if len(vecFields) == 0 {
|
||||
return errors.New(AnnsFieldKey + " not found in schema")
|
||||
}
|
||||
|
||||
if enableMultipleVectorFields && len(vecFields) > 1 {
|
||||
return errors.New("multiple anns_fields exist, please specify a anns_field in search_params")
|
||||
}
|
||||
|
||||
annsField = vecFields[0].Name
|
||||
}
|
||||
queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams(), t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if queryInfo.GroupByFieldId != 0 {
|
||||
t.SearchRequest.IgnoreGrowing = true
|
||||
// for group by operation, currently, we ignore growing segments
|
||||
}
|
||||
t.offset = offset
|
||||
|
||||
plan, err := planparserv2.CreateSearchPlan(t.schema.CollectionSchema, t.request.Dsl, annsField, queryInfo)
|
||||
if err != nil {
|
||||
log.Warn("failed to create query plan", zap.Error(err),
|
||||
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
|
||||
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
|
||||
return merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", err)
|
||||
}
|
||||
log.Debug("create query plan",
|
||||
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
|
||||
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
|
||||
|
||||
if t.partitionKeyMode {
|
||||
expr, err := ParseExprFromPlan(plan)
|
||||
if err != nil {
|
||||
log.Warn("failed to parse expr", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
partitionKeys := ParsePartitionKeys(expr)
|
||||
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.collectionName, partitionKeys)
|
||||
if err != nil {
|
||||
log.Warn("failed to assign partition keys", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
partitionNames = append(partitionNames, hashedPartitionNames...)
|
||||
}
|
||||
|
||||
plan.OutputFieldIds = outputFieldIDs
|
||||
|
||||
t.SearchRequest.Topk = queryInfo.GetTopk()
|
||||
t.SearchRequest.MetricType = queryInfo.GetMetricType()
|
||||
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
|
||||
|
||||
estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk)
|
||||
if err != nil {
|
||||
log.Warn("failed to estimate result size", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if estimateSize >= requeryThreshold {
|
||||
t.requery = true
|
||||
plan.OutputFieldIds = nil
|
||||
}
|
||||
|
||||
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("proxy init search request",
|
||||
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
||||
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
||||
}
|
||||
|
||||
// translate partition name to partition ids. Use regex-pattern to match partition name.
|
||||
t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.collectionName, partitionNames)
|
||||
if err != nil {
|
||||
log.Warn("failed to get partition ids", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
if deadline, ok := t.TraceCtx().Deadline(); ok {
|
||||
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
|
||||
}
|
||||
|
||||
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
|
||||
|
||||
// Set username of this search request for feature like task scheduling.
|
||||
if username, _ := GetCurUserFromContext(ctx); username != "" {
|
||||
t.SearchRequest.Username = username
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -14,10 +14,11 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/conc"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
|
@ -32,9 +33,11 @@ const (
|
|||
type hybridSearchTask struct {
|
||||
Condition
|
||||
ctx context.Context
|
||||
*internalpb.HybridSearchRequest
|
||||
|
||||
result *milvuspb.SearchResults
|
||||
request *milvuspb.HybridSearchRequest
|
||||
result *milvuspb.SearchResults
|
||||
request *milvuspb.HybridSearchRequest
|
||||
searchTasks []*searchTask
|
||||
|
||||
tr *timerecord.TimeRecorder
|
||||
schema *schemaInfo
|
||||
|
@ -42,15 +45,14 @@ type hybridSearchTask struct {
|
|||
|
||||
userOutputFields []string
|
||||
|
||||
qc types.QueryCoordClient
|
||||
node types.ProxyComponent
|
||||
lb LBPolicy
|
||||
queryChannelsTs map[string]Timestamp
|
||||
|
||||
collectionID UniqueID
|
||||
qc types.QueryCoordClient
|
||||
node types.ProxyComponent
|
||||
lb LBPolicy
|
||||
|
||||
resultBuf *typeutil.ConcurrentSet[*querypb.HybridSearchResult]
|
||||
multipleRecallResults *typeutil.ConcurrentSet[*milvuspb.SearchResults]
|
||||
reScorers []reScorer
|
||||
queryChannelsTs map[string]Timestamp
|
||||
rankParams *rankParams
|
||||
}
|
||||
|
||||
|
@ -63,7 +65,7 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
|
||||
if len(t.request.Requests) > defaultMaxSearchRequest {
|
||||
return errors.New("maximum of ann search requests is 1024")
|
||||
return errors.New(fmt.Sprintf("maximum of ann search requests is %d", defaultMaxSearchRequest))
|
||||
}
|
||||
for _, req := range t.request.GetRequests() {
|
||||
nq, err := getNq(req)
|
||||
|
@ -78,12 +80,15 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
t.Base.MsgType = commonpb.MsgType_Search
|
||||
t.Base.SourceID = paramtable.GetNodeID()
|
||||
|
||||
collectionName := t.request.CollectionName
|
||||
collID, err := globalMetaCache.GetCollectionID(ctx, t.request.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.collectionID = collID
|
||||
t.CollectionID = collID
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName))
|
||||
t.schema, err = globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName)
|
||||
|
@ -113,6 +118,82 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
|
|||
t.requery = true
|
||||
}
|
||||
|
||||
collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName, t.CollectionID)
|
||||
if err2 != nil {
|
||||
log.Warn("Proxy::hybridSearchTask::PreExecute failed to GetCollectionInfo from cache",
|
||||
zap.String("collectionName", collectionName), zap.Int64("collectionID", t.CollectionID), zap.Error(err2))
|
||||
return err2
|
||||
}
|
||||
guaranteeTs := t.request.GetGuaranteeTimestamp()
|
||||
var consistencyLevel commonpb.ConsistencyLevel
|
||||
useDefaultConsistency := t.request.GetUseDefaultConsistency()
|
||||
if useDefaultConsistency {
|
||||
consistencyLevel = collectionInfo.consistencyLevel
|
||||
guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel)
|
||||
} else {
|
||||
consistencyLevel = t.request.GetConsistencyLevel()
|
||||
// Compatibility logic, parse guarantee timestamp
|
||||
if consistencyLevel == 0 && guaranteeTs > 0 {
|
||||
guaranteeTs = parseGuaranteeTs(guaranteeTs, t.BeginTs())
|
||||
} else {
|
||||
// parse from guarantee timestamp and user input consistency level
|
||||
guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel)
|
||||
}
|
||||
}
|
||||
|
||||
t.reScorers, err = NewReScorer(t.request.GetRequests(), t.request.GetRankParams())
|
||||
if err != nil {
|
||||
log.Info("generate reScorer failed", zap.Any("rank params", t.request.GetRankParams()), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
t.searchTasks = make([]*searchTask, len(t.request.GetRequests()))
|
||||
for index := range t.request.Requests {
|
||||
searchReq := t.request.Requests[index]
|
||||
|
||||
if len(searchReq.GetCollectionName()) == 0 {
|
||||
searchReq.CollectionName = t.request.GetCollectionName()
|
||||
} else if searchReq.GetCollectionName() != t.request.GetCollectionName() {
|
||||
return errors.New(fmt.Sprintf("inconsistent collection name in hybrid search request, "+
|
||||
"expect %s, actual %s", searchReq.GetCollectionName(), t.request.GetCollectionName()))
|
||||
}
|
||||
|
||||
searchReq.PartitionNames = t.request.GetPartitionNames()
|
||||
searchReq.ConsistencyLevel = consistencyLevel
|
||||
searchReq.GuaranteeTimestamp = guaranteeTs
|
||||
searchReq.UseDefaultConsistency = useDefaultConsistency
|
||||
searchReq.OutputFields = nil
|
||||
|
||||
t.searchTasks[index] = &searchTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
collectionName: collectionName,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Search),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
ReqID: paramtable.GetNodeID(),
|
||||
DbID: 0, // todo
|
||||
CollectionID: collID,
|
||||
},
|
||||
request: searchReq,
|
||||
schema: t.schema,
|
||||
tr: timerecord.NewTimeRecorder("hybrid search"),
|
||||
qc: t.qc,
|
||||
node: t.node,
|
||||
lb: t.lb,
|
||||
|
||||
partitionKeyMode: partitionKeyMode,
|
||||
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
|
||||
}
|
||||
err := initSearchRequest(ctx, t.searchTasks[index])
|
||||
if err != nil {
|
||||
log.Debug("init hybrid search request failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("hybrid search preExecute done.",
|
||||
zap.Uint64("guarantee_ts", t.request.GetGuaranteeTimestamp()),
|
||||
zap.Bool("use_default_consistency", t.request.GetUseDefaultConsistency()),
|
||||
|
@ -121,56 +202,65 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) hybridSearchShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error {
|
||||
for _, searchTask := range t.searchTasks {
|
||||
t.HybridSearchRequest.Reqs = append(t.HybridSearchRequest.Reqs, searchTask.SearchRequest)
|
||||
}
|
||||
hybridSearchReq := typeutil.Clone(t.HybridSearchRequest)
|
||||
hybridSearchReq.GetBase().TargetID = nodeID
|
||||
req := &querypb.HybridSearchRequest{
|
||||
Req: hybridSearchReq,
|
||||
DmlChannels: []string{channel},
|
||||
TotalChannelNum: int32(1),
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collection", t.GetCollectionID()),
|
||||
zap.Int64s("partitionIDs", t.GetPartitionIDs()),
|
||||
zap.Int64("nodeID", nodeID),
|
||||
zap.String("channel", channel))
|
||||
|
||||
var result *querypb.HybridSearchResult
|
||||
var err error
|
||||
|
||||
result, err = qn.HybridSearch(ctx, req)
|
||||
if err != nil {
|
||||
log.Warn("QueryNode hybrid search return error", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
|
||||
log.Warn("QueryNode is not shardLeader")
|
||||
return errInvalidShardLeaders
|
||||
}
|
||||
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("QueryNode hybrid search result error",
|
||||
zap.String("reason", result.GetStatus().GetReason()))
|
||||
return errors.Wrapf(merr.Error(result.GetStatus()), "fail to hybrid search on QueryNode %d", nodeID)
|
||||
}
|
||||
t.resultBuf.Insert(result)
|
||||
t.lb.UpdateCostMetrics(nodeID, result.CostAggregation)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) Execute(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-Execute")
|
||||
defer sp.End()
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", t.collectionID), zap.String("collName", t.request.GetCollectionName()))
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", t.CollectionID), zap.String("collName", t.request.GetCollectionName()))
|
||||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute hybrid search %d", t.ID()))
|
||||
defer tr.CtxElapse(ctx, "done")
|
||||
|
||||
futures := make([]*conc.Future[*milvuspb.SearchResults], len(t.request.Requests))
|
||||
for index := range t.request.Requests {
|
||||
searchReq := t.request.Requests[index]
|
||||
future := conc.Go(func() (*milvuspb.SearchResults, error) {
|
||||
searchReq.TravelTimestamp = t.request.GetTravelTimestamp()
|
||||
searchReq.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp()
|
||||
searchReq.NotReturnAllMeta = t.request.GetNotReturnAllMeta()
|
||||
searchReq.ConsistencyLevel = t.request.GetConsistencyLevel()
|
||||
searchReq.UseDefaultConsistency = t.request.GetUseDefaultConsistency()
|
||||
searchReq.OutputFields = nil
|
||||
|
||||
return t.node.Search(ctx, searchReq)
|
||||
})
|
||||
futures[index] = future
|
||||
}
|
||||
|
||||
err := conc.AwaitAll(futures...)
|
||||
t.resultBuf = typeutil.NewConcurrentSet[*querypb.HybridSearchResult]()
|
||||
err := t.lb.Execute(ctx, CollectionWorkLoad{
|
||||
db: t.request.GetDbName(),
|
||||
collectionID: t.CollectionID,
|
||||
collectionName: t.request.GetCollectionName(),
|
||||
nq: 1,
|
||||
exec: t.hybridSearchShard,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.reScorers, err = NewReScorer(t.request.GetRequests(), t.request.GetRankParams())
|
||||
if err != nil {
|
||||
log.Info("generate reScorer failed", zap.Any("rank params", t.request.GetRankParams()), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
t.multipleRecallResults = typeutil.NewConcurrentSet[*milvuspb.SearchResults]()
|
||||
for i, future := range futures {
|
||||
err = future.Err()
|
||||
if err != nil {
|
||||
log.Debug("QueryNode search result error", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
result := futures[i].Value()
|
||||
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Debug("QueryNode search result error",
|
||||
zap.String("reason", result.GetStatus().GetReason()))
|
||||
return merr.Error(result.GetStatus())
|
||||
}
|
||||
|
||||
t.reScorers[i].reScore(result)
|
||||
t.multipleRecallResults.Insert(result)
|
||||
log.Warn("hybrid search execute failed", zap.Error(err))
|
||||
return errors.Wrap(err, "failed to hybrid search")
|
||||
}
|
||||
|
||||
log.Debug("hybrid search execute done.")
|
||||
|
@ -194,7 +284,7 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro
|
|||
|
||||
limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, rankParamsPair)
|
||||
if err != nil {
|
||||
return nil, errors.New(LimitKey + " not found in search_params")
|
||||
return nil, errors.New(LimitKey + " not found in rank_params")
|
||||
}
|
||||
limit, err = strconv.ParseInt(limitStr, 0, 64)
|
||||
if err != nil {
|
||||
|
@ -235,16 +325,59 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) collectHybridSearchResults(ctx context.Context) error {
|
||||
select {
|
||||
case <-t.TraceCtx().Done():
|
||||
log.Ctx(ctx).Warn("hybrid search task wait to finish timeout!")
|
||||
return fmt.Errorf("hybrid search task wait to finish timeout, msgID=%d", t.ID())
|
||||
default:
|
||||
log.Ctx(ctx).Debug("all hybrid searches are finished or canceled")
|
||||
t.resultBuf.Range(func(res *querypb.HybridSearchResult) bool {
|
||||
for index, searchResult := range res.GetResults() {
|
||||
t.searchTasks[index].resultBuf.Insert(searchResult)
|
||||
}
|
||||
log.Ctx(ctx).Debug("proxy receives one hybrid search result",
|
||||
zap.Int64("sourceID", res.GetBase().GetSourceID()))
|
||||
return true
|
||||
})
|
||||
|
||||
t.multipleRecallResults = typeutil.NewConcurrentSet[*milvuspb.SearchResults]()
|
||||
for i, searchTask := range t.searchTasks {
|
||||
err := searchTask.PostExecute(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.reScorers[i].reScore(searchTask.result)
|
||||
t.multipleRecallResults.Insert(searchTask.result)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) PostExecute(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-PostExecute")
|
||||
defer sp.End()
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", t.collectionID), zap.String("collName", t.request.GetCollectionName()))
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", t.CollectionID), zap.String("collName", t.request.GetCollectionName()))
|
||||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy postExecute hybrid search %d", t.ID()))
|
||||
defer func() {
|
||||
tr.CtxElapse(ctx, "done")
|
||||
}()
|
||||
|
||||
err := t.collectHybridSearchResults(ctx)
|
||||
if err != nil {
|
||||
log.Warn("failed to collect hybrid search results", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
t.queryChannelsTs = make(map[string]uint64)
|
||||
for _, r := range t.resultBuf.Collect() {
|
||||
for ch, ts := range r.GetChannelsMvcc() {
|
||||
t.queryChannelsTs[ch] = ts
|
||||
}
|
||||
}
|
||||
|
||||
primaryFieldSchema, err := t.schema.GetPkField()
|
||||
if err != nil {
|
||||
log.Warn("failed to get primary field schema", zap.Error(err))
|
||||
|
@ -304,9 +437,8 @@ func (t *hybridSearchTask) Requery() error {
|
|||
},
|
||||
}
|
||||
|
||||
// TODO:Xige-16 refine the mvcc functionality of hybrid search
|
||||
// TODO:silverxia move partitionIDs to hybrid search level
|
||||
return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, []int64{})
|
||||
return doRequery(t.ctx, t.CollectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, []int64{})
|
||||
}
|
||||
|
||||
func rankSearchResultData(ctx context.Context,
|
||||
|
@ -436,11 +568,11 @@ func (t *hybridSearchTask) TraceCtx() context.Context {
|
|||
}
|
||||
|
||||
func (t *hybridSearchTask) ID() UniqueID {
|
||||
return t.request.Base.MsgID
|
||||
return t.Base.MsgID
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) SetID(uid UniqueID) {
|
||||
t.request.Base.MsgID = uid
|
||||
t.Base.MsgID = uid
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) Name() string {
|
||||
|
@ -448,24 +580,24 @@ func (t *hybridSearchTask) Name() string {
|
|||
}
|
||||
|
||||
func (t *hybridSearchTask) Type() commonpb.MsgType {
|
||||
return t.request.Base.MsgType
|
||||
return t.Base.MsgType
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) BeginTs() Timestamp {
|
||||
return t.request.Base.Timestamp
|
||||
return t.Base.Timestamp
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) EndTs() Timestamp {
|
||||
return t.request.Base.Timestamp
|
||||
return t.Base.Timestamp
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) SetTs(ts Timestamp) {
|
||||
t.request.Base.Timestamp = ts
|
||||
t.Base.Timestamp = ts
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) OnEnqueue() error {
|
||||
t.request.Base = commonpbutil.NewMsgBase()
|
||||
t.request.Base.MsgType = commonpb.MsgType_Search
|
||||
t.request.Base.SourceID = paramtable.GetNodeID()
|
||||
t.Base = commonpbutil.NewMsgBase()
|
||||
t.Base.MsgType = commonpb.MsgType_Search
|
||||
t.Base.SourceID = paramtable.GetNodeID()
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
|
@ -67,8 +68,9 @@ func TestHybridSearchTask_PreExecute(t *testing.T) {
|
|||
|
||||
genHybridSearchTaskWithNq := func(t *testing.T, collName string, reqs []*milvuspb.SearchRequest) *hybridSearchTask {
|
||||
task := &hybridSearchTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
HybridSearchRequest: &internalpb.HybridSearchRequest{},
|
||||
request: &milvuspb.HybridSearchRequest{
|
||||
CollectionName: collName,
|
||||
Requests: reqs,
|
||||
|
@ -225,6 +227,7 @@ func TestHybridSearchTask_ErrExecute(t *testing.T) {
|
|||
result: &milvuspb.SearchResults{
|
||||
Status: merr.Success(),
|
||||
},
|
||||
HybridSearchRequest: &internalpb.HybridSearchRequest{},
|
||||
request: &milvuspb.HybridSearchRequest{
|
||||
CollectionName: collectionName,
|
||||
Requests: []*milvuspb.SearchRequest{
|
||||
|
@ -266,12 +269,12 @@ func TestHybridSearchTask_ErrExecute(t *testing.T) {
|
|||
task.ctx = ctx
|
||||
assert.NoError(t, task.PreExecute(ctx))
|
||||
|
||||
qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
|
||||
qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
|
||||
assert.Error(t, task.Execute(ctx))
|
||||
|
||||
qn.ExpectedCalls = nil
|
||||
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{
|
||||
qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&querypb.HybridSearchResult{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
},
|
||||
|
@ -291,6 +294,10 @@ func TestHybridSearchTask_PostExecute(t *testing.T) {
|
|||
mgr := NewMockShardClientManager(t)
|
||||
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
|
||||
mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
|
||||
qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&querypb.HybridSearchResult{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
Status: merr.Success(),
|
||||
}, nil)
|
||||
|
||||
t.Run("Test empty result", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -313,6 +320,9 @@ func TestHybridSearchTask_PostExecute(t *testing.T) {
|
|||
qc: nil,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
schema: schema,
|
||||
HybridSearchRequest: &internalpb.HybridSearchRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
},
|
||||
request: &milvuspb.HybridSearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
|
@ -320,6 +330,7 @@ func TestHybridSearchTask_PostExecute(t *testing.T) {
|
|||
CollectionName: collectionName,
|
||||
RankParams: rankParams,
|
||||
},
|
||||
resultBuf: typeutil.NewConcurrentSet[*querypb.HybridSearchResult](),
|
||||
multipleRecallResults: typeutil.NewConcurrentSet[*milvuspb.SearchResults](),
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,6 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
|
@ -53,10 +52,11 @@ type searchTask struct {
|
|||
result *milvuspb.SearchResults
|
||||
request *milvuspb.SearchRequest
|
||||
|
||||
tr *timerecord.TimeRecorder
|
||||
collectionName string
|
||||
schema *schemaInfo
|
||||
requery bool
|
||||
tr *timerecord.TimeRecorder
|
||||
collectionName string
|
||||
schema *schemaInfo
|
||||
requery bool
|
||||
partitionKeyMode bool
|
||||
|
||||
userOutputFields []string
|
||||
|
||||
|
@ -250,22 +250,21 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName))
|
||||
|
||||
t.SearchRequest.DbID = 0 // todo
|
||||
t.SearchRequest.CollectionID = collID
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName))
|
||||
t.schema, err = globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
log.Warn("get collection schema failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
partitionKeyMode, err := isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
|
||||
t.partitionKeyMode, err = isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
log.Warn("is partition key mode failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if partitionKeyMode && len(t.request.GetPartitionNames()) != 0 {
|
||||
if t.partitionKeyMode && len(t.request.GetPartitionNames()) != 0 {
|
||||
return errors.New("not support manually specifying the partition names if partition key mode is used")
|
||||
}
|
||||
|
||||
|
@ -277,123 +276,9 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||
log.Debug("translate output fields",
|
||||
zap.Strings("output fields", t.request.GetOutputFields()))
|
||||
|
||||
// fetch search_growing from search param
|
||||
var ignoreGrowing bool
|
||||
for i, kv := range t.request.GetSearchParams() {
|
||||
if kv.GetKey() == IgnoreGrowingKey {
|
||||
ignoreGrowing, err = strconv.ParseBool(kv.GetValue())
|
||||
if err != nil {
|
||||
return errors.New("parse search growing failed")
|
||||
}
|
||||
t.request.SearchParams = append(t.request.GetSearchParams()[:i], t.request.GetSearchParams()[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
t.SearchRequest.IgnoreGrowing = ignoreGrowing
|
||||
|
||||
// Manually update nq if not set.
|
||||
nq, err := getNq(t.request)
|
||||
err = initSearchRequest(ctx, t)
|
||||
if err != nil {
|
||||
log.Warn("failed to get nq", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
// Check if nq is valid:
|
||||
// https://milvus.io/docs/limitations.md
|
||||
if err := validateNQLimit(nq); err != nil {
|
||||
return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
|
||||
}
|
||||
t.SearchRequest.Nq = nq
|
||||
log = log.With(zap.Int64("nq", nq))
|
||||
|
||||
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
|
||||
if err != nil {
|
||||
log.Warn("fail to get output field ids", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
t.SearchRequest.OutputFieldsId = outputFieldIDs
|
||||
|
||||
partitionNames := t.request.GetPartitionNames()
|
||||
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
|
||||
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
|
||||
if err != nil || len(annsField) == 0 {
|
||||
vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
|
||||
if len(vecFields) == 0 {
|
||||
return errors.New(AnnsFieldKey + " not found in schema")
|
||||
}
|
||||
|
||||
if enableMultipleVectorFields && len(vecFields) > 1 {
|
||||
return errors.New("multiple anns_fields exist, please specify a anns_field in search_params")
|
||||
}
|
||||
|
||||
annsField = vecFields[0].Name
|
||||
}
|
||||
queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams(), t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if queryInfo.GroupByFieldId != 0 {
|
||||
t.SearchRequest.IgnoreGrowing = true
|
||||
// for group by operation, currently, we ignore growing segments
|
||||
}
|
||||
t.offset = offset
|
||||
|
||||
plan, err := planparserv2.CreateSearchPlan(t.schema.CollectionSchema, t.request.Dsl, annsField, queryInfo)
|
||||
if err != nil {
|
||||
log.Warn("failed to create query plan", zap.Error(err),
|
||||
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
|
||||
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
|
||||
return merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", err)
|
||||
}
|
||||
log.Debug("create query plan",
|
||||
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
|
||||
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
|
||||
|
||||
if partitionKeyMode {
|
||||
expr, err := ParseExprFromPlan(plan)
|
||||
if err != nil {
|
||||
log.Warn("failed to parse expr", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
partitionKeys := ParsePartitionKeys(expr)
|
||||
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), collectionName, partitionKeys)
|
||||
if err != nil {
|
||||
log.Warn("failed to assign partition keys", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
partitionNames = append(partitionNames, hashedPartitionNames...)
|
||||
}
|
||||
|
||||
plan.OutputFieldIds = outputFieldIDs
|
||||
|
||||
t.SearchRequest.Topk = queryInfo.GetTopk()
|
||||
t.SearchRequest.MetricType = queryInfo.GetMetricType()
|
||||
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
|
||||
|
||||
estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk)
|
||||
if err != nil {
|
||||
log.Warn("failed to estimate result size", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if estimateSize >= requeryThreshold {
|
||||
t.requery = true
|
||||
plan.OutputFieldIds = nil
|
||||
}
|
||||
|
||||
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("Proxy::searchTask::PreExecute",
|
||||
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
||||
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
||||
}
|
||||
|
||||
// translate partition name to partition ids. Use regex-pattern to match partition name.
|
||||
t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, partitionNames)
|
||||
if err != nil {
|
||||
log.Warn("failed to get partition ids", zap.Error(err))
|
||||
log.Debug("init search request failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -421,17 +306,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
t.SearchRequest.GuaranteeTimestamp = guaranteeTs
|
||||
|
||||
if deadline, ok := t.TraceCtx().Deadline(); ok {
|
||||
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
|
||||
}
|
||||
|
||||
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
|
||||
|
||||
// Set username of this search request for feature like task scheduling.
|
||||
if username, _ := GetCurUserFromContext(ctx); username != "" {
|
||||
t.SearchRequest.Username = username
|
||||
}
|
||||
|
||||
log.Debug("search PreExecute done.",
|
||||
zap.Uint64("guarantee_ts", guaranteeTs),
|
||||
zap.Bool("use_default_consistency", useDefaultConsistency),
|
||||
|
|
|
@ -469,6 +469,61 @@ func (_c *MockQueryNodeServer_GetTimeTickChannel_Call) RunAndReturn(run func(con
|
|||
return _c
|
||||
}
|
||||
|
||||
// HybridSearch provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryNodeServer) HybridSearch(_a0 context.Context, _a1 *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 *querypb.HybridSearchResult
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok {
|
||||
return rf(_a0, _a1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*querypb.HybridSearchResult)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryNodeServer_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
|
||||
type MockQueryNodeServer_HybridSearch_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// HybridSearch is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *querypb.HybridSearchRequest
|
||||
func (_e *MockQueryNodeServer_Expecter) HybridSearch(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_HybridSearch_Call {
|
||||
return &MockQueryNodeServer_HybridSearch_Call{Call: _e.mock.On("HybridSearch", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryNodeServer_HybridSearch_Call) Run(run func(_a0 context.Context, _a1 *querypb.HybridSearchRequest)) *MockQueryNodeServer_HybridSearch_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryNodeServer_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNodeServer_HybridSearch_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryNodeServer_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockQueryNodeServer_HybridSearch_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// LoadPartitions provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryNodeServer) LoadPartitions(_a0 context.Context, _a1 *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
|
|
@ -62,6 +62,7 @@ type ShardDelegator interface {
|
|||
GetSegmentInfo(readable bool) (sealed []SnapshotItem, growing []SegmentEntry)
|
||||
SyncDistribution(ctx context.Context, entries ...SegmentEntry)
|
||||
Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error)
|
||||
HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)
|
||||
Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error)
|
||||
QueryStream(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error
|
||||
GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error)
|
||||
|
@ -184,6 +185,44 @@ func (sd *shardDelegator) modifyQueryRequest(req *querypb.QueryRequest, scope qu
|
|||
return nodeReq
|
||||
}
|
||||
|
||||
// Search preforms search operation on shard.
|
||||
func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest, sealed []SnapshotItem, growing []SegmentEntry) ([]*internalpb.SearchResults, error) {
|
||||
log := sd.getLogger(ctx)
|
||||
if req.Req.IgnoreGrowing {
|
||||
growing = []SegmentEntry{}
|
||||
}
|
||||
|
||||
sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) })
|
||||
log.Debug("search segments...",
|
||||
zap.Int("sealedNum", sealedNum),
|
||||
zap.Int("growingNum", len(growing)),
|
||||
)
|
||||
|
||||
req, err := optimizers.OptimizeSearchParams(ctx, req, sd.queryHook, sealedNum)
|
||||
if err != nil {
|
||||
log.Warn("failed to optimize search params", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest)
|
||||
if err != nil {
|
||||
log.Warn("Search organizeSubTask failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) {
|
||||
return worker.SearchSegments(ctx, req)
|
||||
}, "Search", log)
|
||||
if err != nil {
|
||||
log.Warn("Delegator search failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug("Delegator search done")
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Search preforms search operation on shard.
|
||||
func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
|
||||
log := sd.getLogger(ctx)
|
||||
|
@ -229,39 +268,113 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
|
|||
return funcutil.SliceContain(existPartitions, segment.PartitionID)
|
||||
})
|
||||
|
||||
if req.Req.IgnoreGrowing {
|
||||
growing = []SegmentEntry{}
|
||||
return sd.search(ctx, req, sealed, growing)
|
||||
}
|
||||
|
||||
// HybridSearch preforms hybrid search operation on shard.
|
||||
func (sd *shardDelegator) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
|
||||
log := sd.getLogger(ctx)
|
||||
if err := sd.lifetime.Add(lifetime.IsWorking); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer sd.lifetime.Done()
|
||||
|
||||
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
|
||||
log.Warn("deletgator received hybrid search request not belongs to it",
|
||||
zap.Strings("reqChannels", req.GetDmlChannels()),
|
||||
)
|
||||
return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
|
||||
}
|
||||
|
||||
sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) })
|
||||
log.Debug("search segments...",
|
||||
zap.Int("sealedNum", sealedNum),
|
||||
zap.Int("growingNum", len(growing)),
|
||||
)
|
||||
partitions := req.GetReq().GetPartitionIDs()
|
||||
if !sd.collection.ExistPartition(partitions...) {
|
||||
return nil, merr.WrapErrPartitionNotLoaded(partitions)
|
||||
}
|
||||
|
||||
req, err = optimizers.OptimizeSearchParams(ctx, req, sd.queryHook, sealedNum)
|
||||
// wait tsafe
|
||||
waitTr := timerecord.NewTimeRecorder("wait tSafe")
|
||||
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
|
||||
if err != nil {
|
||||
log.Warn("delegator hybrid search failed to wait tsafe", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
if req.GetReq().GetMvccTimestamp() == 0 {
|
||||
req.Req.MvccTimestamp = tSafe
|
||||
}
|
||||
metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues(
|
||||
fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel).
|
||||
Observe(float64(waitTr.ElapseSpan().Milliseconds()))
|
||||
|
||||
sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
|
||||
if err != nil {
|
||||
log.Warn("delegator failed to hybrid search, current distribution is not serviceable")
|
||||
return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
|
||||
}
|
||||
defer sd.distribution.Unpin(version)
|
||||
existPartitions := sd.collection.GetPartitions()
|
||||
growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
|
||||
return funcutil.SliceContain(existPartitions, segment.PartitionID)
|
||||
})
|
||||
|
||||
futures := make([]*conc.Future[*internalpb.SearchResults], len(req.GetReq().GetReqs()))
|
||||
for index := range req.GetReq().GetReqs() {
|
||||
request := req.GetReq().Reqs[index]
|
||||
future := conc.Go(func() (*internalpb.SearchResults, error) {
|
||||
searchReq := &querypb.SearchRequest{
|
||||
Req: request,
|
||||
DmlChannels: req.GetDmlChannels(),
|
||||
TotalChannelNum: req.GetTotalChannelNum(),
|
||||
FromShardLeader: true,
|
||||
}
|
||||
searchReq.Req.GuaranteeTimestamp = req.GetReq().GetGuaranteeTimestamp()
|
||||
searchReq.Req.TimeoutTimestamp = req.GetReq().GetTimeoutTimestamp()
|
||||
if searchReq.GetReq().GetMvccTimestamp() == 0 {
|
||||
searchReq.GetReq().MvccTimestamp = tSafe
|
||||
}
|
||||
|
||||
results, err := sd.search(ctx, searchReq, sealed, growing)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return segments.ReduceSearchResults(ctx,
|
||||
results,
|
||||
searchReq.Req.GetNq(),
|
||||
searchReq.Req.GetTopk(),
|
||||
searchReq.Req.GetMetricType())
|
||||
})
|
||||
futures[index] = future
|
||||
}
|
||||
|
||||
err = conc.AwaitAll(futures...)
|
||||
if err != nil {
|
||||
log.Warn("failed to optimize search params", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest)
|
||||
if err != nil {
|
||||
log.Warn("Search organizeSubTask failed", zap.Error(err))
|
||||
return nil, err
|
||||
ret := &querypb.HybridSearchResult{
|
||||
Status: merr.Success(),
|
||||
Results: make([]*internalpb.SearchResults, len(futures)),
|
||||
}
|
||||
|
||||
results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) {
|
||||
return worker.SearchSegments(ctx, req)
|
||||
}, "Search", log)
|
||||
if err != nil {
|
||||
log.Warn("Delegator search failed", zap.Error(err))
|
||||
return nil, err
|
||||
channelsMvcc := make(map[string]uint64)
|
||||
for i, future := range futures {
|
||||
result := future.Value()
|
||||
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Debug("delegator hybrid search failed",
|
||||
zap.String("reason", result.GetStatus().GetReason()))
|
||||
return nil, merr.Error(result.GetStatus())
|
||||
}
|
||||
|
||||
ret.Results[i] = result
|
||||
for ch, ts := range result.GetChannelsMvcc() {
|
||||
channelsMvcc[ch] = ts
|
||||
}
|
||||
}
|
||||
ret.ChannelsMvcc = channelsMvcc
|
||||
|
||||
log.Debug("Delegator search done")
|
||||
log.Debug("Delegator hybrid search done")
|
||||
|
||||
return results, nil
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error {
|
||||
|
|
|
@ -469,6 +469,251 @@ func (s *DelegatorSuite) TestSearch() {
|
|||
})
|
||||
}
|
||||
|
||||
func (s *DelegatorSuite) TestHybridSearch() {
|
||||
s.delegator.Start()
|
||||
paramtable.SetNodeID(1)
|
||||
s.initSegments()
|
||||
s.Run("normal", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
}()
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
worker2 := &cluster.MockWorker{}
|
||||
|
||||
workers[1] = worker1
|
||||
workers[2] = worker2
|
||||
|
||||
worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
|
||||
Run(func(_ context.Context, req *querypb.SearchRequest) {
|
||||
s.EqualValues(1, req.Req.GetBase().GetTargetID())
|
||||
s.True(req.GetFromShardLeader())
|
||||
if req.GetScope() == querypb.DataScope_Streaming {
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1004}, req.GetSegmentIDs())
|
||||
}
|
||||
if req.GetScope() == querypb.DataScope_Historical {
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1000, 1001}, req.GetSegmentIDs())
|
||||
}
|
||||
}).Return(&internalpb.SearchResults{}, nil)
|
||||
worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
|
||||
Run(func(_ context.Context, req *querypb.SearchRequest) {
|
||||
s.EqualValues(2, req.Req.GetBase().GetTargetID())
|
||||
s.True(req.GetFromShardLeader())
|
||||
s.Equal(querypb.DataScope_Historical, req.GetScope())
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
|
||||
}).Return(&internalpb.SearchResults{}, nil)
|
||||
|
||||
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
results, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
|
||||
Req: &internalpb.HybridSearchRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
Reqs: []*internalpb.SearchRequest{
|
||||
{Base: commonpbutil.NewMsgBase()},
|
||||
{Base: commonpbutil.NewMsgBase()},
|
||||
},
|
||||
},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.NoError(err)
|
||||
s.Equal(2, len(results.Results))
|
||||
})
|
||||
|
||||
s.Run("partition_not_loaded", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
|
||||
Req: &internalpb.HybridSearchRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
// not load partation -1,will return error
|
||||
PartitionIDs: []int64{-1},
|
||||
},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.True(errors.Is(err, merr.ErrPartitionNotLoaded))
|
||||
})
|
||||
|
||||
s.Run("worker_return_error", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
}()
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
worker2 := &cluster.MockWorker{}
|
||||
|
||||
workers[1] = worker1
|
||||
workers[2] = worker2
|
||||
|
||||
worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).Return(nil, errors.New("mock error"))
|
||||
worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
|
||||
Run(func(_ context.Context, req *querypb.SearchRequest) {
|
||||
s.EqualValues(2, req.Req.GetBase().GetTargetID())
|
||||
s.True(req.GetFromShardLeader())
|
||||
s.Equal(querypb.DataScope_Historical, req.GetScope())
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
|
||||
}).Return(&internalpb.SearchResults{}, nil)
|
||||
|
||||
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
|
||||
Req: &internalpb.HybridSearchRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
Reqs: []*internalpb.SearchRequest{
|
||||
{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
},
|
||||
},
|
||||
},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("worker_return_failure_code", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
}()
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
worker2 := &cluster.MockWorker{}
|
||||
|
||||
workers[1] = worker1
|
||||
workers[2] = worker2
|
||||
|
||||
worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).Return(&internalpb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "mocked error",
|
||||
},
|
||||
}, nil)
|
||||
worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
|
||||
Run(func(_ context.Context, req *querypb.SearchRequest) {
|
||||
s.EqualValues(2, req.Req.GetBase().GetTargetID())
|
||||
s.True(req.GetFromShardLeader())
|
||||
s.Equal(querypb.DataScope_Historical, req.GetScope())
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
|
||||
}).Return(&internalpb.SearchResults{}, nil)
|
||||
|
||||
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
|
||||
Req: &internalpb.HybridSearchRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
Reqs: []*internalpb.SearchRequest{
|
||||
{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
},
|
||||
},
|
||||
},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("wrong_channel", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
|
||||
Req: &internalpb.HybridSearchRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
},
|
||||
DmlChannels: []string{"non_exist_channel"},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("wait_tsafe_timeout", func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
|
||||
defer cancel()
|
||||
|
||||
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
|
||||
Req: &internalpb.HybridSearchRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
GuaranteeTimestamp: 10100,
|
||||
},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("tsafe_behind_max_lag", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
|
||||
Req: &internalpb.HybridSearchRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
GuaranteeTimestamp: uint64(paramtable.Get().QueryNodeCfg.MaxTimestampLag.GetAsDuration(time.Second)) + 10001,
|
||||
},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("distribution_not_serviceable", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
sd, ok := s.delegator.(*shardDelegator)
|
||||
s.Require().True(ok)
|
||||
sd.distribution.AddOfflines(1001)
|
||||
|
||||
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
|
||||
Req: &internalpb.HybridSearchRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("cluster_not_serviceable", func() {
|
||||
s.delegator.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
|
||||
Req: &internalpb.HybridSearchRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *DelegatorSuite) TestQuery() {
|
||||
s.delegator.Start()
|
||||
paramtable.SetNodeID(1)
|
||||
|
|
|
@ -253,6 +253,61 @@ func (_c *MockShardDelegator_GetTargetVersion_Call) RunAndReturn(run func() int6
|
|||
return _c
|
||||
}
|
||||
|
||||
// HybridSearch provides a mock function with given fields: ctx, req
|
||||
func (_m *MockShardDelegator) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
|
||||
ret := _m.Called(ctx, req)
|
||||
|
||||
var r0 *querypb.HybridSearchResult
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok {
|
||||
return rf(ctx, req)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok {
|
||||
r0 = rf(ctx, req)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*querypb.HybridSearchResult)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok {
|
||||
r1 = rf(ctx, req)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockShardDelegator_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
|
||||
type MockShardDelegator_HybridSearch_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// HybridSearch is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - req *querypb.HybridSearchRequest
|
||||
func (_e *MockShardDelegator_Expecter) HybridSearch(ctx interface{}, req interface{}) *MockShardDelegator_HybridSearch_Call {
|
||||
return &MockShardDelegator_HybridSearch_Call{Call: _e.mock.On("HybridSearch", ctx, req)}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_HybridSearch_Call) Run(run func(ctx context.Context, req *querypb.HybridSearchRequest)) *MockShardDelegator_HybridSearch_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockShardDelegator_HybridSearch_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockShardDelegator_HybridSearch_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// LoadGrowing provides a mock function with given fields: ctx, infos, version
|
||||
func (_m *MockShardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error {
|
||||
ret := _m.Called(ctx, infos, version)
|
||||
|
|
|
@ -401,6 +401,63 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
|
|||
return resp, nil
|
||||
}
|
||||
|
||||
func (node *QueryNode) hybridSearchChannel(ctx context.Context, req *querypb.HybridSearchRequest, channel string) (*querypb.HybridSearchResult, error) {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
|
||||
zap.Int64("collectionID", req.Req.GetCollectionID()),
|
||||
zap.String("channel", channel),
|
||||
)
|
||||
traceID := trace.SpanFromContext(ctx).SpanContext().TraceID()
|
||||
|
||||
if err := node.lifetime.Add(merr.IsHealthy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer node.lifetime.Done()
|
||||
|
||||
var err error
|
||||
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.TotalLabel, metrics.Leader).Inc()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.FailLabel, metrics.Leader).Inc()
|
||||
}
|
||||
}()
|
||||
|
||||
log.Debug("start to search channel")
|
||||
searchCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// From Proxy
|
||||
tr := timerecord.NewTimeRecorder("hybridSearchDelegator")
|
||||
// get delegator
|
||||
sd, ok := node.delegators.Get(channel)
|
||||
if !ok {
|
||||
err := merr.WrapErrChannelNotFound(channel)
|
||||
log.Warn("Query failed, failed to get shard delegator for search", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
// do hybrid search
|
||||
result, err := sd.HybridSearch(searchCtx, req)
|
||||
if err != nil {
|
||||
log.Warn("failed to hybrid search on delegator", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tr.CtxElapse(ctx, fmt.Sprintf("do search with channel done , traceID = %s, vChannel = %s",
|
||||
traceID,
|
||||
channel,
|
||||
))
|
||||
|
||||
// update metric to prometheus
|
||||
latency := tr.ElapseSpan()
|
||||
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
|
||||
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.SuccessLabel, metrics.Leader).Inc()
|
||||
for _, searchReq := range req.GetReq().GetReqs() {
|
||||
metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(searchReq.GetNq()))
|
||||
metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(searchReq.GetTopk()))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.GetStatisticsRequest, channel string) (*internalpb.GetStatisticsResponse, error) {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", req.Req.GetCollectionID()),
|
||||
|
|
|
@ -821,6 +821,114 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// HybridSearch performs replica search tasks.
|
||||
func (node *QueryNode) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", req.GetReq().GetCollectionID()),
|
||||
zap.Strings("channels", req.GetDmlChannels()))
|
||||
|
||||
log.Debug("Received HybridSearchRequest",
|
||||
zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()),
|
||||
zap.Uint64("mvccTimestamp", req.GetReq().GetMvccTimestamp()))
|
||||
|
||||
tr := timerecord.NewTimeRecorderWithTrace(ctx, "HybridSearchRequest")
|
||||
|
||||
if err := node.lifetime.Add(merr.IsHealthy); err != nil {
|
||||
return &querypb.HybridSearchResult{
|
||||
Base: &commonpb.MsgBase{
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
Status: merr.Status(err),
|
||||
}, nil
|
||||
}
|
||||
defer node.lifetime.Done()
|
||||
|
||||
err := merr.CheckTargetID(req.GetReq().GetBase())
|
||||
if err != nil {
|
||||
log.Warn("target ID check failed", zap.Error(err))
|
||||
return &querypb.HybridSearchResult{
|
||||
Base: &commonpb.MsgBase{
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
Status: merr.Status(err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
resp := &querypb.HybridSearchResult{
|
||||
Base: &commonpb.MsgBase{
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
Status: merr.Success(),
|
||||
}
|
||||
collection := node.manager.Collection.Get(req.GetReq().GetCollectionID())
|
||||
if collection == nil {
|
||||
resp.Status = merr.Status(merr.WrapErrCollectionNotFound(req.GetReq().GetCollectionID()))
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
MultipleResults := make([]*querypb.HybridSearchResult, len(req.GetDmlChannels()))
|
||||
runningGp, runningCtx := errgroup.WithContext(ctx)
|
||||
|
||||
for i, ch := range req.GetDmlChannels() {
|
||||
ch := ch
|
||||
req := &querypb.HybridSearchRequest{
|
||||
Req: req.Req,
|
||||
DmlChannels: []string{ch},
|
||||
TotalChannelNum: 1,
|
||||
}
|
||||
|
||||
i := i
|
||||
runningGp.Go(func() error {
|
||||
ret, err := node.hybridSearchChannel(runningCtx, req, ch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := merr.Error(ret.GetStatus()); err != nil {
|
||||
return err
|
||||
}
|
||||
MultipleResults[i] = ret
|
||||
return nil
|
||||
})
|
||||
}
|
||||
if err := runningGp.Wait(); err != nil {
|
||||
resp.Status = merr.Status(err)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
tr.RecordSpan()
|
||||
channelsMvcc := make(map[string]uint64)
|
||||
for i, searchReq := range req.GetReq().GetReqs() {
|
||||
toReduceResults := make([]*internalpb.SearchResults, len(MultipleResults))
|
||||
for index, hs := range MultipleResults {
|
||||
toReduceResults[index] = hs.Results[i]
|
||||
}
|
||||
result, err := segments.ReduceSearchResults(ctx, toReduceResults, searchReq.GetNq(), searchReq.GetTopk(), searchReq.GetMetricType())
|
||||
if err != nil {
|
||||
log.Warn("failed to reduce search results", zap.Error(err))
|
||||
resp.Status = merr.Status(err)
|
||||
return resp, nil
|
||||
}
|
||||
for ch, ts := range result.GetChannelsMvcc() {
|
||||
channelsMvcc[ch] = ts
|
||||
}
|
||||
resp.Results = append(resp.Results, result)
|
||||
}
|
||||
resp.ChannelsMvcc = channelsMvcc
|
||||
|
||||
reduceLatency := tr.RecordSpan()
|
||||
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.ReduceShards).
|
||||
Observe(float64(reduceLatency.Milliseconds()))
|
||||
|
||||
collector.Rate.Add(metricsinfo.SearchThroughput, float64(proto.Size(req)))
|
||||
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.HybridSearchLabel).
|
||||
Add(float64(proto.Size(req)))
|
||||
|
||||
if resp.GetCostAggregation() != nil {
|
||||
resp.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// only used for delegator query segments from worker
|
||||
func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
|
||||
resp := &internalpb.RetrieveResults{
|
||||
|
|
|
@ -1323,6 +1323,47 @@ func (suite *ServiceSuite) TestSearchSegments_Failed() {
|
|||
suite.Equal(commonpb.ErrorCode_UnexpectedError, rsp.GetStatus().GetErrorCode())
|
||||
}
|
||||
|
||||
func (suite *ServiceSuite) TestHybridSearch_Concurrent() {
|
||||
ctx := context.Background()
|
||||
// pre
|
||||
suite.TestWatchDmChannelsInt64()
|
||||
suite.TestLoadSegments_Int64()
|
||||
|
||||
concurrency := 16
|
||||
futures := make([]*conc.Future[*querypb.HybridSearchResult], 0, concurrency)
|
||||
for i := 0; i < concurrency; i++ {
|
||||
future := conc.Go(func() (*querypb.HybridSearchResult, error) {
|
||||
creq1, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType)
|
||||
suite.NoError(err)
|
||||
creq2, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType)
|
||||
suite.NoError(err)
|
||||
req := &querypb.HybridSearchRequest{
|
||||
Req: &internalpb.HybridSearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgID: rand.Int63(),
|
||||
TargetID: suite.node.session.ServerID,
|
||||
},
|
||||
CollectionID: suite.collectionID,
|
||||
PartitionIDs: suite.partitionIDs,
|
||||
MvccTimestamp: typeutil.MaxTimestamp,
|
||||
Reqs: []*internalpb.SearchRequest{creq1, creq2},
|
||||
},
|
||||
DmlChannels: []string{suite.vchannel},
|
||||
}
|
||||
|
||||
return suite.node.HybridSearch(ctx, req)
|
||||
})
|
||||
futures = append(futures, future)
|
||||
}
|
||||
|
||||
err := conc.AwaitAll(futures...)
|
||||
suite.NoError(err)
|
||||
|
||||
for i := range futures {
|
||||
suite.True(merr.Ok(futures[i].Value().GetStatus()))
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *ServiceSuite) TestSearchSegments_Normal() {
|
||||
ctx := context.Background()
|
||||
// pre
|
||||
|
|
|
@ -82,6 +82,10 @@ func (m *GrpcQueryNodeClient) Search(ctx context.Context, in *querypb.SearchRequ
|
|||
return &internalpb.SearchResults{}, m.Err
|
||||
}
|
||||
|
||||
func (m *GrpcQueryNodeClient) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
|
||||
return &querypb.HybridSearchResult{}, m.Err
|
||||
}
|
||||
|
||||
func (m *GrpcQueryNodeClient) SearchSegments(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) {
|
||||
return &internalpb.SearchResults{}, m.Err
|
||||
}
|
||||
|
|
|
@ -93,6 +93,10 @@ func (qn *qnServerWrapper) Search(ctx context.Context, in *querypb.SearchRequest
|
|||
return qn.QueryNode.Search(ctx, in)
|
||||
}
|
||||
|
||||
func (qn *qnServerWrapper) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
|
||||
return qn.QueryNode.HybridSearch(ctx, in)
|
||||
}
|
||||
|
||||
func (qn *qnServerWrapper) SearchSegments(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) {
|
||||
return qn.QueryNode.SearchSegments(ctx, in)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue