mirror of https://github.com/milvus-io/milvus.git
enhance: Refactor hybrid search (#31742)
issue: https://github.com/milvus-io/milvus/issues/25639 https://github.com/milvus-io/milvus/issues/31368
pr: https://github.com/milvus-io/milvus/pull/32020
Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
branch: pull/32038/head
parent 39d988cf8d
commit 4c07304790
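Note (illustrative, not part of the diff): after this refactor the proxy no longer keeps a separate hybridSearchTask; a HybridSearchRequest is converted into one "advanced" SearchRequest whose SubReqs carry the individual ANN searches and whose SearchParams carry the rank params (see convertHybridSearchToSearch further down). A minimal sketch of the resulting request shape, where dbName, collectionName, expr, plgBytes, rankParams and annParams1/annParams2 are placeholders assumed to exist in the caller:

    req := &milvuspb.SearchRequest{
        DbName:         dbName,
        CollectionName: collectionName,
        SearchParams:   rankParams, // rank params (e.g. the RRF k) ride on the top-level request
        SubReqs: []*milvuspb.SubSearchRequest{
            {Dsl: expr, PlaceholderGroup: plgBytes, DslType: commonpb.DslType_BoolExprV1, SearchParams: annParams1},
            {Dsl: expr, PlaceholderGroup: plgBytes, DslType: commonpb.DslType_BoolExprV1, SearchParams: annParams2},
        },
    }
    resp, err := proxy.Search(ctx, req) // the old proxy.HybridSearch path now goes through Search
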
@@ -332,10 +332,3 @@ func (c *Client) Delete(ctx context.Context, req *querypb.DeleteRequest, _ ...gr
        return client.Delete(ctx, req)
    })
}

// HybridSearch performs replica hybrid search tasks in QueryNode.
func (c *Client) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest, _ ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
    return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*querypb.HybridSearchResult, error) {
        return client.HybridSearch(ctx, req)
    })
}

@@ -387,8 +387,3 @@ func (s *Server) SyncDistribution(ctx context.Context, req *querypb.SyncDistribu
func (s *Server) Delete(ctx context.Context, req *querypb.DeleteRequest) (*commonpb.Status, error) {
    return s.querynode.Delete(ctx, req)
}

// HybridSearch performs hybrid search of streaming/historical replica on QueryNode.
func (s *Server) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
    return s.querynode.HybridSearch(ctx, req)
}

@@ -1258,8 +1258,8 @@ type DataCoordCatalog_SaveChannelCheckpoints_Call struct {
}

// SaveChannelCheckpoints is a helper method to define mock.On call
// - ctx context.Context
// - positions []*msgpb.MsgPosition
// - ctx context.Context
// - positions []*msgpb.MsgPosition
func (_e *DataCoordCatalog_Expecter) SaveChannelCheckpoints(ctx interface{}, positions interface{}) *DataCoordCatalog_SaveChannelCheckpoints_Call {
    return &DataCoordCatalog_SaveChannelCheckpoints_Call{Call: _e.mock.On("SaveChannelCheckpoints", ctx, positions)}
}

@ -552,61 +552,6 @@ func (_c *MockQueryNode_GetTimeTickChannel_Call) RunAndReturn(run func(context.C
|
|||
return _c
|
||||
}
|
||||
|
||||
// HybridSearch provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryNode) HybridSearch(_a0 context.Context, _a1 *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 *querypb.HybridSearchResult
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok {
|
||||
return rf(_a0, _a1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*querypb.HybridSearchResult)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryNode_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
|
||||
type MockQueryNode_HybridSearch_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// HybridSearch is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *querypb.HybridSearchRequest
|
||||
func (_e *MockQueryNode_Expecter) HybridSearch(_a0 interface{}, _a1 interface{}) *MockQueryNode_HybridSearch_Call {
|
||||
return &MockQueryNode_HybridSearch_Call{Call: _e.mock.On("HybridSearch", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryNode_HybridSearch_Call) Run(run func(_a0 context.Context, _a1 *querypb.HybridSearchRequest)) *MockQueryNode_HybridSearch_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryNode_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNode_HybridSearch_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryNode_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockQueryNode_HybridSearch_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Init provides a mock function with given fields:
|
||||
func (_m *MockQueryNode) Init() error {
|
||||
ret := _m.Called()
|
||||
|
|
|
@ -632,76 +632,6 @@ func (_c *MockQueryNodeClient_GetTimeTickChannel_Call) RunAndReturn(run func(con
|
|||
return _c
|
||||
}
|
||||
|
||||
// HybridSearch provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryNodeClient) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
for _i := range opts {
|
||||
_va[_i] = opts[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx, in)
|
||||
_ca = append(_ca, _va...)
|
||||
ret := _m.Called(_ca...)
|
||||
|
||||
var r0 *querypb.HybridSearchResult
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) (*querypb.HybridSearchResult, error)); ok {
|
||||
return rf(ctx, in, opts...)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) *querypb.HybridSearchResult); ok {
|
||||
r0 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*querypb.HybridSearchResult)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) error); ok {
|
||||
r1 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryNodeClient_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
|
||||
type MockQueryNodeClient_HybridSearch_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// HybridSearch is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *querypb.HybridSearchRequest
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *MockQueryNodeClient_Expecter) HybridSearch(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_HybridSearch_Call {
|
||||
return &MockQueryNodeClient_HybridSearch_Call{Call: _e.mock.On("HybridSearch",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryNodeClient_HybridSearch_Call) Run(run func(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_HybridSearch_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]grpc.CallOption, len(args)-2)
|
||||
for i, a := range args[2:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(grpc.CallOption)
|
||||
}
|
||||
}
|
||||
run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryNodeClient_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNodeClient_HybridSearch_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryNodeClient_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) (*querypb.HybridSearchResult, error)) *MockQueryNodeClient_HybridSearch_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// LoadPartitions provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockQueryNodeClient) LoadPartitions(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
|
|
|
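Note (illustrative, not part of the diff): the removed HybridSearch stubs above follow mockery's expecter pattern, so a test would typically have wired them like the sketch below; the constructor name and the exact call sites are assumptions here, not taken from this commit.

    qn := NewMockQueryNodeClient(t) // assumed mockery-generated constructor
    qn.EXPECT().
        HybridSearch(mock.Anything, mock.Anything).
        Return(&querypb.HybridSearchResult{Status: merr.Success()}, nil)
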
@@ -82,6 +82,20 @@ message CreateIndexRequest {
  repeated common.KeyValuePair extra_params = 8;
}

message SubSearchRequest {
  string dsl = 1;
  // serialized `PlaceholderGroup`
  bytes placeholder_group = 2;
  common.DslType dsl_type = 3;
  bytes serialized_expr_plan = 4;
  int64 nq = 5;
  repeated int64 partitionIDs = 6;
  int64 topk = 7;
  int64 offset = 8;
  string metricType = 9;
}

message SearchRequest {
  common.MsgBase base = 1;
  int64 reqID = 2;
@@ -102,18 +116,22 @@ message SearchRequest {
  string metricType = 16;
  bool ignoreGrowing = 17; // Optional
  string username = 18;
  repeated SubSearchRequest sub_reqs = 19;
  bool is_advanced = 20;
  int64 offset = 21;
  common.ConsistencyLevel consistency_level = 22;
}

message HybridSearchRequest {
  common.MsgBase base = 1;
  int64 reqID = 2;
  int64 dbID = 3;
  int64 collectionID = 4;
  repeated int64 partitionIDs = 5;
  repeated SearchRequest reqs = 6;
  uint64 mvcc_timestamp = 11;
  uint64 guarantee_timestamp = 12;
  uint64 timeout_timestamp = 13;
message SubSearchResults {
  string metric_type = 1;
  int64 num_queries = 2;
  int64 top_k = 3;
  // schema.SearchResultsData inside
  bytes sliced_blob = 4;
  int64 sliced_num_count = 5;
  int64 sliced_offset = 6;
  // to indicate it belongs to which sub request
  int64 req_index = 7;
}

message SearchResults {

@@ -134,6 +152,8 @@ message SearchResults {
  // search request cost
  CostAggregation costAggregation = 13;
  map<string, uint64> channels_mvcc = 14;
  repeated SubSearchResults sub_results = 15;
  bool is_advanced = 16;
}

message CostAggregation {

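Note (illustrative, not part of the diff): with sub_results and is_advanced, a single internal.SearchResults can carry one slice of results per sub request, keyed back by req_index. A hedged sketch of a reply using these fields (Go field names assume the default protoc-gen-go mapping; blob0/blob1 stand for serialized schema.SearchResultsData, as the proto comment notes):

    res := &internalpb.SearchResults{
        IsAdvanced: true,
        SubResults: []*internalpb.SubSearchResults{
            {ReqIndex: 0, MetricType: "L2", NumQueries: 1, TopK: 10, SlicedBlob: blob0},
            {ReqIndex: 1, MetricType: "IP", NumQueries: 1, TopK: 10, SlicedBlob: blob1},
        },
    }
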
@@ -139,8 +139,6 @@ service QueryNode {
  }
  rpc Search(SearchRequest) returns (internal.SearchResults) {
  }
  rpc HybridSearch(HybridSearchRequest) returns (HybridSearchResult) {
  }
  rpc SearchSegments(SearchRequest) returns (internal.SearchResults) {
  }
  rpc Query(QueryRequest) returns (internal.RetrieveResults) {

@@ -416,20 +414,6 @@ message SearchRequest {
  int32 total_channel_num = 6;
}

message HybridSearchRequest {
  internal.HybridSearchRequest req = 1;
  repeated string dml_channels = 2;
  int32 total_channel_num = 3;
}

message HybridSearchResult {
  common.MsgBase base = 1;
  common.Status status = 2;
  repeated internal.SearchResults results = 3;
  internal.CostAggregation costAggregation = 4;
  map<string, uint64> channels_mvcc = 5;
}

message QueryRequest {
  internal.RetrieveRequest req = 1;
  repeated string dml_channels = 2;

@@ -2825,18 +2825,19 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
    ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch")
    defer sp.End()

    qt := &hybridSearchTask{
    newSearchReq := convertHybridSearchToSearch(request)
    qt := &searchTask{
        ctx:       ctx,
        Condition: NewTaskCondition(ctx),
        HybridSearchRequest: &internalpb.HybridSearchRequest{
        SearchRequest: &internalpb.SearchRequest{
            Base: commonpbutil.NewMsgBase(
                commonpbutil.WithMsgType(commonpb.MsgType_Search),
                commonpbutil.WithSourceID(paramtable.GetNodeID()),
            ),
            ReqID: paramtable.GetNodeID(),
        },
        request: request,
        tr:      timerecord.NewTimeRecorder(method),
        request: newSearchReq,
        tr:      timerecord.NewTimeRecorder("search"),
        qc:      node.queryCoord,
        node:    node,
        lb:      node.lbPolicy,

@@ -2921,7 +2922,7 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
        metrics.SuccessLabel,
    ).Inc()

    metrics.ProxySearchVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(len(qt.request.GetRequests())))
    metrics.ProxySearchVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(len(request.GetRequests())))

    searchDur := tr.ElapseSpan().Milliseconds()
    metrics.ProxySQLatency.WithLabelValues(

@@ -1562,6 +1562,31 @@ func TestProxy(t *testing.T) {
        }
    }

    constructSubSearchRequest := func(nq int) *milvuspb.SubSearchRequest {
        plg := constructVectorsPlaceholderGroup(nq)
        plgBs, err := proto.Marshal(plg)
        assert.NoError(t, err)

        params := make(map[string]string)
        params["nprobe"] = strconv.Itoa(nprobe)
        b, err := json.Marshal(params)
        assert.NoError(t, err)
        searchParams := []*commonpb.KeyValuePair{
            {Key: MetricTypeKey, Value: metric.L2},
            {Key: SearchParamsKey, Value: string(b)},
            {Key: AnnsFieldKey, Value: floatVecField},
            {Key: TopKKey, Value: strconv.Itoa(topk)},
            {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)},
        }

        return &milvuspb.SubSearchRequest{
            Dsl: expr,
            PlaceholderGroup: plgBs,
            DslType: commonpb.DslType_BoolExprV1,
            SearchParams: searchParams,
        }
    }

    wg.Add(1)
    t.Run("search", func(t *testing.T) {
        defer wg.Done()

@@ -1572,7 +1597,7 @@ func TestProxy(t *testing.T) {
        assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
    })

    constructHybridSearchRequest := func(reqs []*milvuspb.SearchRequest) *milvuspb.HybridSearchRequest {
    constructAdvancedSearchRequest := func() *milvuspb.SearchRequest {
        params := make(map[string]float64)
        params[RRFParamsKey] = 60
        b, err := json.Marshal(params)

@@ -1584,32 +1609,33 @@ func TestProxy(t *testing.T) {
            {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)},
        }

        return &milvuspb.HybridSearchRequest{
        req1 := constructSubSearchRequest(nq)
        req2 := constructSubSearchRequest(nq)
        ret := &milvuspb.SearchRequest{
            Base: nil,
            DbName: dbName,
            CollectionName: collectionName,
            Requests: reqs,
            PartitionNames: nil,
            OutputFields: nil,
            RankParams: rankParams,
            SearchParams: rankParams,
            TravelTimestamp: 0,
            GuaranteeTimestamp: 0,
        }
        ret.SubReqs = append(ret.SubReqs, req1)
        ret.SubReqs = append(ret.SubReqs, req2)
        return ret
    }

    wg.Add(1)
    nq = 1
    t.Run("hybrid search", func(t *testing.T) {
    t.Run("advanced search", func(t *testing.T) {
        defer wg.Done()
        req1 := constructSearchRequest(nq)
        req2 := constructSearchRequest(nq)

        resp, err := proxy.HybridSearch(ctx, constructHybridSearchRequest([]*milvuspb.SearchRequest{req1, req2}))
        req := constructAdvancedSearchRequest()
        resp, err := proxy.Search(ctx, req)
        assert.NoError(t, err)
        assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
    })
    nq = 10

    nq = 10
    constructPrimaryKeysPlaceholderGroup := func() *commonpb.PlaceholderGroup {
        expr := fmt.Sprintf("%v in [%v]", int64Field, insertedIds[0])
        exprBytes := []byte(expr)

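Note (illustrative, not part of the diff): the test above exercises the RRF ranker (params[RRFParamsKey] = 60). The weighted ranker added in rescorer.go is driven the same way through rank params; a hedged sketch, where the key strings "weighted" and "weights" are assumptions here (the diff only shows the corresponding *Key constants and the weights validation):

    params := map[string]interface{}{"weights": []float64{0.6, 0.4}} // one weight per sub request, each in [0, 1]
    b, err := json.Marshal(params)
    assert.NoError(t, err)
    rankParams := []*commonpb.KeyValuePair{
        {Key: RankTypeKey, Value: "weighted"},
        {Key: RankParamsKey, Value: string(b)},
        {Key: LimitKey, Value: strconv.Itoa(topk)},
    }
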
@@ -3,7 +3,9 @@ package proxy
import (
    "encoding/json"
    "fmt"
    "math"
    "reflect"
    "strings"

    "github.com/cockroachdb/errors"
    "go.uber.org/zap"

@@ -13,6 +15,7 @@ import (
    "github.com/milvus-io/milvus/pkg/log"
    "github.com/milvus-io/milvus/pkg/util/funcutil"
    "github.com/milvus-io/milvus/pkg/util/merr"
    "github.com/milvus-io/milvus/pkg/util/metric"
)

type rankType int

@@ -35,16 +38,27 @@ type reScorer interface {
    name() string
    scorerType() rankType
    reScore(input *milvuspb.SearchResults)
    setMetricType(metricType string)
    getMetricType() string
}

type baseScorer struct {
    scorerName string
    metricType string
}

func (bs *baseScorer) name() string {
    return bs.scorerName
}

func (bs *baseScorer) setMetricType(metricType string) {
    bs.metricType = metricType
}

func (bs *baseScorer) getMetricType() string {
    return bs.metricType
}

type rrfScorer struct {
    baseScorer
    k float32

@@ -65,9 +79,36 @@ type weightedScorer struct {
    weight float32
}

type activateFunc func(float32) float32

func (ws *weightedScorer) getActivateFunc() activateFunc {
    mUpper := strings.ToUpper(ws.getMetricType())
    isCosine := mUpper == strings.ToUpper(metric.COSINE)
    isIP := mUpper == strings.ToUpper(metric.IP)
    if isCosine {
        f := func(distance float32) float32 {
            return (1 + distance) * 0.5
        }
        return f
    }

    if isIP {
        f := func(distance float32) float32 {
            return 0.5 + float32(math.Atan(float64(distance)))/math.Pi
        }
        return f
    }

    f := func(distance float32) float32 {
        return 1.0 - 2*float32(math.Atan(float64(distance)))/math.Pi
    }
    return f
}

func (ws *weightedScorer) reScore(input *milvuspb.SearchResults) {
    for i, score := range input.Results.GetScores() {
        input.Results.Scores[i] = ws.weight * score
    activateF := ws.getActivateFunc()
    for i, distance := range input.Results.GetScores() {
        input.Results.Scores[i] = ws.weight * activateF(distance)
    }
}

@@ -75,13 +116,17 @@ func (ws *weightedScorer) scorerType() rankType {
    return weightedRankType
}

func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValuePair) ([]reScorer, error) {
    res := make([]reScorer, len(reqs))
func NewReScorers(reqCnt int, rankParams []*commonpb.KeyValuePair) ([]reScorer, error) {
    if reqCnt == 0 {
        return []reScorer{}, nil
    }

    res := make([]reScorer, reqCnt)
    rankTypeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankTypeKey, rankParams)
    if err != nil {
        log.Info("rank strategy not specified, use rrf instead")
        // if not set rank strategy, use rrf rank as default
        for i := range reqs {
        for i := 0; i < reqCnt; i++ {
            res[i] = &rrfScorer{
                baseScorer: baseScorer{
                    scorerName: "rrf",

@@ -123,7 +168,7 @@ func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValue
        return nil, errors.New(fmt.Sprintf("The rank params k should be in range (0, %d)", maxRRFParamsValue))
    }
    log.Debug("rrf params", zap.Float64("k", k))
    for i := range reqs {
    for i := 0; i < reqCnt; i++ {
        res[i] = &rrfScorer{
            baseScorer: baseScorer{
                scorerName: "rrf",

@@ -156,10 +201,10 @@ func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValue
    }

    log.Debug("weights params", zap.Any("weights", weights))
    if len(reqs) != len(weights) {
        return nil, merr.WrapErrParameterInvalid(fmt.Sprint(len(reqs)), fmt.Sprint(len(weights)), "the length of weights param mismatch with ann search requests")
    if reqCnt != len(weights) {
        return nil, merr.WrapErrParameterInvalid(fmt.Sprint(reqCnt), fmt.Sprint(len(weights)), "the length of weights param mismatch with ann search requests")
    }
    for i := range reqs {
    for i := 0; i < reqCnt; i++ {
        res[i] = &weightedScorer{
            baseScorer: baseScorer{
                scorerName: "weighted",

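Worked example for the activation functions above (plain arithmetic, illustrative values only), taking a weight of 0.4:

    COSINE distance 0.8  ->  (1 + 0.8) * 0.5     = 0.90   weighted score 0.36
    IP     distance 2.0  ->  0.5 + atan(2.0)/pi  ≈ 0.85   weighted score ≈ 0.34
    L2     distance 0.5  ->  1 - 2*atan(0.5)/pi  ≈ 0.70   weighted score ≈ 0.28

Each sub-result is first mapped into a comparable 0-to-1 range before the weight is applied, instead of multiplying the raw distances as the removed reScore body did.
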
@ -7,12 +7,11 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
)
|
||||
|
||||
func TestRescorer(t *testing.T) {
|
||||
t.Run("default scorer", func(t *testing.T) {
|
||||
rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, nil)
|
||||
rescorers, err := NewReScorers(2, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(rescorers))
|
||||
assert.Equal(t, rrfRankType, rescorers[0].scorerType())
|
||||
|
@ -27,7 +26,7 @@ func TestRescorer(t *testing.T) {
|
|||
{Key: RankParamsKey, Value: string(b)},
|
||||
}
|
||||
|
||||
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
||||
_, err = NewReScorers(2, rankParams)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "k not found in rank_params")
|
||||
})
|
||||
|
@ -42,7 +41,7 @@ func TestRescorer(t *testing.T) {
|
|||
{Key: RankParamsKey, Value: string(b)},
|
||||
}
|
||||
|
||||
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
||||
_, err = NewReScorers(2, rankParams)
|
||||
assert.Error(t, err)
|
||||
|
||||
params[RRFParamsKey] = maxRRFParamsValue + 1
|
||||
|
@ -53,7 +52,7 @@ func TestRescorer(t *testing.T) {
|
|||
{Key: RankParamsKey, Value: string(b)},
|
||||
}
|
||||
|
||||
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
||||
_, err = NewReScorers(2, rankParams)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
|
@ -67,7 +66,7 @@ func TestRescorer(t *testing.T) {
|
|||
{Key: RankParamsKey, Value: string(b)},
|
||||
}
|
||||
|
||||
rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
||||
rescorers, err := NewReScorers(2, rankParams)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(rescorers))
|
||||
assert.Equal(t, rrfRankType, rescorers[0].scorerType())
|
||||
|
@ -83,7 +82,7 @@ func TestRescorer(t *testing.T) {
|
|||
{Key: RankParamsKey, Value: string(b)},
|
||||
}
|
||||
|
||||
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
||||
_, err = NewReScorers(2, rankParams)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found in rank_params")
|
||||
})
|
||||
|
@ -99,7 +98,7 @@ func TestRescorer(t *testing.T) {
|
|||
{Key: RankParamsKey, Value: string(b)},
|
||||
}
|
||||
|
||||
_, err = NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
||||
_, err = NewReScorers(2, rankParams)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "rank param weight should be in range [0, 1]")
|
||||
})
|
||||
|
@ -115,7 +114,7 @@ func TestRescorer(t *testing.T) {
|
|||
{Key: RankParamsKey, Value: string(b)},
|
||||
}
|
||||
|
||||
rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams)
|
||||
rescorers, err := NewReScorers(2, rankParams)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(rescorers))
|
||||
assert.Equal(t, weightedRankType, rescorers[0].scorerType())
|
||||
|
|
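Note (illustrative, not part of the diff): callers of the renamed factory now pass only the number of sub requests. Mirroring the "default scorer" case above, and assuming the caller is responsible for pushing the metric type into each scorer before reScore runs:

    rescorers, err := NewReScorers(2, nil) // no rank params -> falls back to RRF, one scorer per sub request
    assert.NoError(t, err)
    for _, rs := range rescorers {
        rs.setMetricType(metric.L2)
    }
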
|
@@ -3,6 +3,8 @@ package proxy
import (
    "context"
    "fmt"
    "math"
    "sort"

    "github.com/cockroachdb/errors"
    "go.uber.org/zap"

@@ -379,3 +381,134 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
    }
    return ret, nil
}

func rankSearchResultData(ctx context.Context,
    nq int64,
    params *rankParams,
    pkType schemapb.DataType,
    searchResults []*milvuspb.SearchResults,
) (*milvuspb.SearchResults, error) {
    tr := timerecord.NewTimeRecorder("rankSearchResultData")
    defer func() {
        tr.CtxElapse(ctx, "done")
    }()

    offset := params.offset
    limit := params.limit
    topk := limit + offset
    roundDecimal := params.roundDecimal
    log.Ctx(ctx).Debug("rankSearchResultData",
        zap.Int("len(searchResults)", len(searchResults)),
        zap.Int64("nq", nq),
        zap.Int64("offset", offset),
        zap.Int64("limit", limit))

    ret := &milvuspb.SearchResults{
        Status: merr.Success(),
        Results: &schemapb.SearchResultData{
            NumQueries: nq,
            TopK: limit,
            FieldsData: make([]*schemapb.FieldData, 0),
            Scores: []float32{},
            Ids: &schemapb.IDs{},
            Topks: []int64{},
        },
    }

    switch pkType {
    case schemapb.DataType_Int64:
        ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
            IntId: &schemapb.LongArray{
                Data: make([]int64, 0),
            },
        }
    case schemapb.DataType_VarChar:
        ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
            StrId: &schemapb.StringArray{
                Data: make([]string, 0),
            },
        }
    default:
        return nil, errors.New("unsupported pk type")
    }

    // []map[id]score
    accumulatedScores := make([]map[interface{}]float32, nq)
    for i := int64(0); i < nq; i++ {
        accumulatedScores[i] = make(map[interface{}]float32)
    }

    for _, result := range searchResults {
        scores := result.GetResults().GetScores()
        start := int64(0)
        for i := int64(0); i < nq; i++ {
            realTopk := result.GetResults().Topks[i]
            for j := start; j < start+realTopk; j++ {
                id := typeutil.GetPK(result.GetResults().GetIds(), j)
                accumulatedScores[i][id] += scores[j]
            }
            start += realTopk
        }
    }

    for i := int64(0); i < nq; i++ {
        idSet := accumulatedScores[i]
        keys := make([]interface{}, 0)
        for key := range idSet {
            keys = append(keys, key)
        }
        if int64(len(keys)) <= offset {
            ret.Results.Topks = append(ret.Results.Topks, 0)
            continue
        }

        compareKeys := func(keyI, keyJ interface{}) bool {
            switch keyI.(type) {
            case int64:
                return keyI.(int64) < keyJ.(int64)
            case string:
                return keyI.(string) < keyJ.(string)
            }
            return false
        }

        // sort id by score
        big := func(i, j int) bool {
            if idSet[keys[i]] == idSet[keys[j]] {
                return compareKeys(keys[i], keys[j])
            }
            return idSet[keys[i]] > idSet[keys[j]]
        }

        sort.Slice(keys, big)

        if int64(len(keys)) > topk {
            keys = keys[:topk]
        }

        // set real topk
        ret.Results.Topks = append(ret.Results.Topks, int64(len(keys))-offset)
        // append id and score
        for index := offset; index < int64(len(keys)); index++ {
            typeutil.AppendPKs(ret.Results.Ids, keys[index])
            score := idSet[keys[index]]
            if roundDecimal != -1 {
                multiplier := math.Pow(10.0, float64(roundDecimal))
                score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
            }
            ret.Results.Scores = append(ret.Results.Scores, score)
        }
    }

    return ret, nil
}

func fillInEmptyResult(numQueries int64) *milvuspb.SearchResults {
    return &milvuspb.SearchResults{
        Status: merr.Success("search result is empty"),
        Results: &schemapb.SearchResultData{
            NumQueries: numQueries,
            Topks: make([]int64, numQueries),
        },
    }
}

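Worked example for rankSearchResultData above (illustrative numbers only): take nq = 1, offset = 0, limit = 2, roundDecimal = 1, and two re-scored sub results for the same query, A = {id 7: 0.90, id 3: 0.40} and B = {id 3: 0.55, id 9: 0.30}. The accumulated map is {7: 0.90, 3: 0.95, 9: 0.30}; sorting ids by total score gives [3, 7, 9]; topk = limit + offset = 2 keeps [3, 7]; Topks becomes [2]; and the emitted scores are [1.0, 0.9] after rounding, since floor(0.95*10 + 0.5)/10 = 1.0 and floor(0.90*10 + 0.5)/10 = 0.9.
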
@ -3,163 +3,309 @@ package proxy
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
func initSearchRequest(ctx context.Context, t *searchTask, isHybrid bool) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init search request")
|
||||
defer sp.End()
|
||||
type rankParams struct {
|
||||
limit int64
|
||||
offset int64
|
||||
roundDecimal int64
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
|
||||
// fetch search_growing from search param
|
||||
var ignoreGrowing bool
|
||||
var err error
|
||||
for i, kv := range t.request.GetSearchParams() {
|
||||
if kv.GetKey() == IgnoreGrowingKey {
|
||||
ignoreGrowing, err = strconv.ParseBool(kv.GetValue())
|
||||
if err != nil {
|
||||
return errors.New("parse search growing failed")
|
||||
}
|
||||
t.request.SearchParams = append(t.request.GetSearchParams()[:i], t.request.GetSearchParams()[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
t.SearchRequest.IgnoreGrowing = ignoreGrowing
|
||||
|
||||
// Manually update nq if not set.
|
||||
nq, err := getNq(t.request)
|
||||
// parseSearchInfo returns QueryInfo and offset
|
||||
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, ignoreOffset bool) (*planpb.QueryInfo, int64, error) {
|
||||
// 1. parse offset and real topk
|
||||
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
|
||||
if err != nil {
|
||||
log.Warn("failed to get nq", zap.Error(err))
|
||||
return err
|
||||
return nil, 0, errors.New(TopKKey + " not found in search_params")
|
||||
}
|
||||
// Check if nq is valid:
|
||||
// https://milvus.io/docs/limitations.md
|
||||
if err := validateNQLimit(nq); err != nil {
|
||||
return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
|
||||
}
|
||||
t.SearchRequest.Nq = nq
|
||||
log = log.With(zap.Int64("nq", nq))
|
||||
|
||||
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
|
||||
topK, err := strconv.ParseInt(topKStr, 0, 64)
|
||||
if err != nil {
|
||||
log.Warn("fail to get output field ids", zap.Error(err))
|
||||
return err
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
|
||||
}
|
||||
if err := validateTopKLimit(topK); err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
|
||||
}
|
||||
t.SearchRequest.OutputFieldsId = outputFieldIDs
|
||||
|
||||
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
|
||||
annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
|
||||
if err != nil || len(annsFieldName) == 0 {
|
||||
vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
|
||||
if len(vecFields) == 0 {
|
||||
return errors.New(AnnsFieldKey + " not found in schema")
|
||||
}
|
||||
|
||||
if enableMultipleVectorFields && len(vecFields) > 1 {
|
||||
return errors.New("multiple anns_fields exist, please specify a anns_field in search_params")
|
||||
}
|
||||
|
||||
annsFieldName = vecFields[0].Name
|
||||
}
|
||||
queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams(), t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
annField := typeutil.GetFieldByName(t.schema.CollectionSchema, annsFieldName)
|
||||
if queryInfo.GetGroupByFieldId() != -1 && isHybrid {
|
||||
return errors.New("not support search_group_by operation in the hybrid search")
|
||||
}
|
||||
if queryInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
|
||||
return errors.New("not support search_group_by operation based on binary vector column")
|
||||
}
|
||||
|
||||
t.offset = offset
|
||||
|
||||
plan, err := planparserv2.CreateSearchPlan(t.schema.schemaHelper, t.request.Dsl, annsFieldName, queryInfo)
|
||||
if err != nil {
|
||||
log.Warn("failed to create query plan", zap.Error(err),
|
||||
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
|
||||
zap.String("anns field", annsFieldName), zap.Any("query info", queryInfo))
|
||||
return merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", err)
|
||||
}
|
||||
log.Debug("create query plan",
|
||||
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
|
||||
zap.String("anns field", annsFieldName), zap.Any("query info", queryInfo))
|
||||
|
||||
if t.partitionKeyMode {
|
||||
expr, err := ParseExprFromPlan(plan)
|
||||
var offset int64
|
||||
if !ignoreOffset {
|
||||
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, searchParamsPair)
|
||||
if err == nil {
|
||||
offset, err = strconv.ParseInt(offsetStr, 0, 64)
|
||||
if err != nil {
|
||||
log.Warn("failed to parse expr", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
partitionKeys := ParsePartitionKeys(expr)
|
||||
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.collectionName, partitionKeys)
|
||||
if err != nil {
|
||||
log.Warn("failed to assign partition keys", zap.Error(err))
|
||||
return err
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
|
||||
}
|
||||
|
||||
if len(hashedPartitionNames) > 0 {
|
||||
// translate partition name to partition ids. Use regex-pattern to match partition name.
|
||||
t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.collectionName, hashedPartitionNames)
|
||||
if err != nil {
|
||||
log.Warn("failed to get partition ids", zap.Error(err))
|
||||
return err
|
||||
if offset != 0 {
|
||||
if err := validateTopKLimit(offset); err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
plan.OutputFieldIds = outputFieldIDs
|
||||
|
||||
t.SearchRequest.Topk = queryInfo.GetTopk()
|
||||
t.SearchRequest.MetricType = queryInfo.GetMetricType()
|
||||
t.queryInfo = queryInfo
|
||||
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
|
||||
|
||||
estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk)
|
||||
if err != nil {
|
||||
log.Warn("failed to estimate result size", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if estimateSize >= requeryThreshold {
|
||||
t.requery = true
|
||||
plan.OutputFieldIds = nil
|
||||
}
|
||||
|
||||
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("proxy init search request",
|
||||
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
||||
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
||||
}
|
||||
|
||||
if deadline, ok := t.TraceCtx().Deadline(); ok {
|
||||
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
|
||||
queryTopK := topK + offset
|
||||
if err := validateTopKLimit(queryTopK); err != nil {
|
||||
return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)
|
||||
}
|
||||
|
||||
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
|
||||
|
||||
// Set username of this search request for feature like task scheduling.
|
||||
if username, _ := GetCurUserFromContext(ctx); username != "" {
|
||||
t.SearchRequest.Username = username
|
||||
// 2. parse metrics type
|
||||
metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.MetricTypeKey, searchParamsPair)
|
||||
if err != nil {
|
||||
metricType = ""
|
||||
}
|
||||
|
||||
return nil
|
||||
// 3. parse round decimal
|
||||
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, searchParamsPair)
|
||||
if err != nil {
|
||||
roundDecimalStr = "-1"
|
||||
}
|
||||
|
||||
roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
}
|
||||
|
||||
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
}
|
||||
|
||||
// 4. parse search param str
|
||||
searchParamStr, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, searchParamsPair)
|
||||
if err != nil {
|
||||
searchParamStr = ""
|
||||
}
|
||||
|
||||
err = checkRangeSearchParams(searchParamStr, metricType)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 5. parse group by field
|
||||
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
|
||||
if err != nil {
|
||||
groupByFieldName = ""
|
||||
}
|
||||
var groupByFieldId int64 = -1
|
||||
if groupByFieldName != "" {
|
||||
fields := schema.GetFields()
|
||||
for _, field := range fields {
|
||||
if field.Name == groupByFieldName {
|
||||
groupByFieldId = field.FieldID
|
||||
break
|
||||
}
|
||||
}
|
||||
if groupByFieldId == -1 {
|
||||
return nil, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
|
||||
}
|
||||
}
|
||||
|
||||
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
|
||||
isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
|
||||
if isIterator == "True" && groupByFieldId > 0 {
|
||||
return nil, 0, merr.WrapErrParameterInvalid("", "",
|
||||
"Not allowed to do groupBy when doing iteration")
|
||||
}
|
||||
if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 {
|
||||
return nil, 0, merr.WrapErrParameterInvalid("", "",
|
||||
"Not allowed to do range-search when doing search-group-by")
|
||||
}
|
||||
|
||||
return &planpb.QueryInfo{
|
||||
Topk: queryTopK,
|
||||
MetricType: metricType,
|
||||
SearchParams: searchParamStr,
|
||||
RoundDecimal: roundDecimal,
|
||||
GroupByFieldId: groupByFieldId,
|
||||
}, offset, nil
|
||||
}
|
||||
|
||||
func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) {
|
||||
outputFieldIDs = make([]UniqueID, 0, len(outputFields))
|
||||
for _, name := range outputFields {
|
||||
id, ok := schema.MapFieldID(name)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Field %s not exist", name)
|
||||
}
|
||||
outputFieldIDs = append(outputFieldIDs, id)
|
||||
}
|
||||
return outputFieldIDs, nil
|
||||
}
|
||||
|
||||
func getNqFromSubSearch(req *milvuspb.SubSearchRequest) (int64, error) {
|
||||
if req.GetNq() == 0 {
|
||||
// keep compatible with older client version.
|
||||
x := &commonpb.PlaceholderGroup{}
|
||||
err := proto.Unmarshal(req.GetPlaceholderGroup(), x)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
total := int64(0)
|
||||
for _, h := range x.GetPlaceholders() {
|
||||
total += int64(len(h.Values))
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
return req.GetNq(), nil
|
||||
}
|
||||
|
||||
func getNq(req *milvuspb.SearchRequest) (int64, error) {
|
||||
if req.GetNq() == 0 {
|
||||
// keep compatible with older client version.
|
||||
x := &commonpb.PlaceholderGroup{}
|
||||
err := proto.Unmarshal(req.GetPlaceholderGroup(), x)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
total := int64(0)
|
||||
for _, h := range x.GetPlaceholders() {
|
||||
total += int64(len(h.Values))
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
return req.GetNq(), nil
|
||||
}
|
||||
|
||||
func getPartitionIDs(ctx context.Context, dbName string, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
|
||||
for _, tag := range partitionNames {
|
||||
if err := validatePartitionTag(tag, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
partitionsMap, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
useRegexp := Params.ProxyCfg.PartitionNameRegexp.GetAsBool()
|
||||
|
||||
partitionsSet := typeutil.NewSet[int64]()
|
||||
for _, partitionName := range partitionNames {
|
||||
if useRegexp {
|
||||
// Legacy feature, use partition name as regexp
|
||||
pattern := fmt.Sprintf("^%s$", partitionName)
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid partition: %s", partitionName)
|
||||
}
|
||||
var found bool
|
||||
for name, pID := range partitionsMap {
|
||||
if re.MatchString(name) {
|
||||
partitionsSet.Insert(pID)
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, fmt.Errorf("partition name %s not found", partitionName)
|
||||
}
|
||||
} else {
|
||||
partitionID, found := partitionsMap[partitionName]
|
||||
if !found {
|
||||
// TODO change after testcase updated: return nil, merr.WrapErrPartitionNotFound(partitionName)
|
||||
return nil, fmt.Errorf("partition name %s not found", partitionName)
|
||||
}
|
||||
if !partitionsSet.Contain(partitionID) {
|
||||
partitionsSet.Insert(partitionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
return partitionsSet.Collect(), nil
|
||||
}
|
||||
|
||||
// parseRankParams get limit and offset from rankParams, both are optional.
|
||||
func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, error) {
|
||||
var (
|
||||
limit int64
|
||||
offset int64
|
||||
roundDecimal int64
|
||||
err error
|
||||
)
|
||||
|
||||
limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, rankParamsPair)
|
||||
if err != nil {
|
||||
return nil, errors.New(LimitKey + " not found in rank_params")
|
||||
}
|
||||
limit, err = strconv.ParseInt(limitStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s [%s] is invalid", LimitKey, limitStr)
|
||||
}
|
||||
|
||||
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, rankParamsPair)
|
||||
if err == nil {
|
||||
offset, err = strconv.ParseInt(offsetStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
|
||||
}
|
||||
}
|
||||
|
||||
// validate max result window.
|
||||
if err = validateMaxQueryResultWindow(offset, limit); err != nil {
|
||||
return nil, fmt.Errorf("invalid max query result window, %w", err)
|
||||
}
|
||||
|
||||
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, rankParamsPair)
|
||||
if err != nil {
|
||||
roundDecimalStr = "-1"
|
||||
}
|
||||
|
||||
roundDecimal, err = strconv.ParseInt(roundDecimalStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
}
|
||||
|
||||
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
|
||||
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
}
|
||||
|
||||
return &rankParams{
|
||||
limit: limit,
|
||||
offset: offset,
|
||||
roundDecimal: roundDecimal,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
    ret := &milvuspb.SearchRequest{
        Base: req.GetBase(),
        DbName: req.GetDbName(),
        CollectionName: req.GetCollectionName(),
        PartitionNames: req.GetPartitionNames(),
        OutputFields: req.GetOutputFields(),
        SearchParams: req.GetRankParams(),
        TravelTimestamp: req.GetTravelTimestamp(),
        GuaranteeTimestamp: req.GetGuaranteeTimestamp(),
        Nq: 0,
        NotReturnAllMeta: req.GetNotReturnAllMeta(),
        ConsistencyLevel: req.GetConsistencyLevel(),
        UseDefaultConsistency: req.GetUseDefaultConsistency(),
        SearchByPrimaryKeys: false,
        SubReqs: nil,
    }

    for _, sub := range req.GetRequests() {
        subReq := &milvuspb.SubSearchRequest{
            Dsl: sub.GetDsl(),
            PlaceholderGroup: sub.GetPlaceholderGroup(),
            DslType: sub.GetDslType(),
            SearchParams: sub.GetSearchParams(),
            Nq: sub.GetNq(),
        }
        ret.SubReqs = append(ret.SubReqs, subReq)
    }
    return ret
}

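Note (illustrative, not part of the diff): a hedged sketch of the conversion above in action, where hybridReq is a hypothetical *milvuspb.HybridSearchRequest with two ANN sub requests and RRF rank params:

    searchReq := convertHybridSearchToSearch(hybridReq)
    // searchReq.SearchParams now holds hybridReq's rank params, and
    // len(searchReq.GetSubReqs()) == len(hybridReq.GetRequests()) == 2,
    // so the proxy can run the whole hybrid query as one advanced searchTask.
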
@ -1,659 +0,0 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
const (
|
||||
HybridSearchTaskName = "HybridSearchTask"
|
||||
)
|
||||
|
||||
type hybridSearchTask struct {
|
||||
baseTask
|
||||
Condition
|
||||
ctx context.Context
|
||||
*internalpb.HybridSearchRequest
|
||||
|
||||
result *milvuspb.SearchResults
|
||||
request *milvuspb.HybridSearchRequest
|
||||
searchTasks []*searchTask
|
||||
|
||||
tr *timerecord.TimeRecorder
|
||||
schema *schemaInfo
|
||||
requery bool
|
||||
partitionKeyMode bool
|
||||
|
||||
userOutputFields []string
|
||||
|
||||
qc types.QueryCoordClient
|
||||
node types.ProxyComponent
|
||||
lb LBPolicy
|
||||
|
||||
resultBuf *typeutil.ConcurrentSet[*querypb.HybridSearchResult]
|
||||
multipleRecallResults *typeutil.ConcurrentSet[*milvuspb.SearchResults]
|
||||
partitionIDsSet *typeutil.ConcurrentSet[UniqueID]
|
||||
|
||||
reScorers []reScorer
|
||||
queryChannelsTs map[string]Timestamp
|
||||
rankParams *rankParams
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-PreExecute")
|
||||
defer sp.End()
|
||||
|
||||
if len(t.request.Requests) <= 0 {
|
||||
return errors.New("minimum of ann search requests is 1")
|
||||
}
|
||||
|
||||
if len(t.request.Requests) > defaultMaxSearchRequest {
|
||||
return errors.New(fmt.Sprintf("maximum of ann search requests is %d", defaultMaxSearchRequest))
|
||||
}
|
||||
for _, req := range t.request.GetRequests() {
|
||||
nq, err := getNq(req)
|
||||
if err != nil {
|
||||
log.Debug("failed to get nq", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if nq != 1 {
|
||||
err = merr.WrapErrParameterInvalid("1", fmt.Sprint(nq), "nq should be equal to 1")
|
||||
log.Debug(err.Error())
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
t.Base.MsgType = commonpb.MsgType_Search
|
||||
t.Base.SourceID = paramtable.GetNodeID()
|
||||
|
||||
collectionName := t.request.CollectionName
|
||||
collID, err := globalMetaCache.GetCollectionID(ctx, t.request.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.CollectionID = collID
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName))
|
||||
t.schema, err = globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
log.Warn("get collection schema failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
t.partitionKeyMode, err = isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
log.Warn("is partition key mode failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if t.partitionKeyMode {
|
||||
if len(t.request.GetPartitionNames()) != 0 {
|
||||
return errors.New("not support manually specifying the partition names if partition key mode is used")
|
||||
}
|
||||
t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]()
|
||||
}
|
||||
|
||||
if !t.partitionKeyMode && len(t.request.GetPartitionNames()) > 0 {
|
||||
// translate partition name to partition ids. Use regex-pattern to match partition name.
|
||||
t.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, t.request.GetPartitionNames())
|
||||
if err != nil {
|
||||
log.Warn("failed to get partition ids", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, false)
|
||||
if err != nil {
|
||||
log.Warn("translate output fields failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
log.Debug("translate output fields",
|
||||
zap.Strings("output fields", t.request.GetOutputFields()))
|
||||
|
||||
if len(t.request.OutputFields) > 0 {
|
||||
t.requery = true
|
||||
}
|
||||
|
||||
collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName, t.CollectionID)
|
||||
if err2 != nil {
|
||||
log.Warn("Proxy::hybridSearchTask::PreExecute failed to GetCollectionInfo from cache",
|
||||
zap.String("collectionName", collectionName), zap.Int64("collectionID", t.CollectionID), zap.Error(err2))
|
||||
return err2
|
||||
}
|
||||
guaranteeTs := t.request.GetGuaranteeTimestamp()
|
||||
var consistencyLevel commonpb.ConsistencyLevel
|
||||
useDefaultConsistency := t.request.GetUseDefaultConsistency()
|
||||
if useDefaultConsistency {
|
||||
consistencyLevel = collectionInfo.consistencyLevel
|
||||
guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel)
|
||||
} else {
|
||||
consistencyLevel = t.request.GetConsistencyLevel()
|
||||
// Compatibility logic, parse guarantee timestamp
|
||||
if consistencyLevel == 0 && guaranteeTs > 0 {
|
||||
guaranteeTs = parseGuaranteeTs(guaranteeTs, t.BeginTs())
|
||||
} else {
|
||||
// parse from guarantee timestamp and user input consistency level
|
||||
guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel)
|
||||
}
|
||||
}
|
||||
|
||||
t.reScorers, err = NewReScorer(t.request.GetRequests(), t.request.GetRankParams())
|
||||
if err != nil {
|
||||
log.Info("generate reScorer failed", zap.Any("rank params", t.request.GetRankParams()), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
t.HybridSearchRequest.GuaranteeTimestamp = guaranteeTs
|
||||
t.searchTasks = make([]*searchTask, len(t.request.GetRequests()))
|
||||
for index := range t.request.Requests {
|
||||
searchReq := t.request.Requests[index]
|
||||
|
||||
if len(searchReq.GetCollectionName()) == 0 {
|
||||
searchReq.CollectionName = t.request.GetCollectionName()
|
||||
} else if searchReq.GetCollectionName() != t.request.GetCollectionName() {
|
||||
return errors.New(fmt.Sprintf("inconsistent collection name in hybrid search request, "+
|
||||
"expect %s, actual %s", searchReq.GetCollectionName(), t.request.GetCollectionName()))
|
||||
}
|
||||
|
||||
searchReq.PartitionNames = t.request.GetPartitionNames()
|
||||
searchReq.ConsistencyLevel = consistencyLevel
|
||||
searchReq.GuaranteeTimestamp = guaranteeTs
|
||||
searchReq.UseDefaultConsistency = useDefaultConsistency
|
||||
searchReq.OutputFields = nil
|
||||
|
||||
t.searchTasks[index] = &searchTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
collectionName: collectionName,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Search),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
ReqID: paramtable.GetNodeID(),
|
||||
DbID: 0, // todo
|
||||
CollectionID: collID,
|
||||
PartitionIDs: t.GetPartitionIDs(),
|
||||
},
|
||||
request: searchReq,
|
||||
schema: t.schema,
|
||||
tr: timerecord.NewTimeRecorder("hybrid search"),
|
||||
qc: t.qc,
|
||||
node: t.node,
|
||||
lb: t.lb,
|
||||
|
||||
partitionKeyMode: t.partitionKeyMode,
|
||||
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
|
||||
}
|
||||
err := initSearchRequest(ctx, t.searchTasks[index], true)
|
||||
if err != nil {
|
||||
log.Debug("init hybrid search request failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if t.partitionKeyMode {
|
||||
t.partitionIDsSet.Upsert(t.searchTasks[index].GetPartitionIDs()...)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("hybrid search preExecute done.",
|
||||
zap.Uint64("guarantee_ts", guaranteeTs),
|
||||
zap.Bool("use_default_consistency", t.request.GetUseDefaultConsistency()),
|
||||
zap.Any("consistency level", t.request.GetConsistencyLevel()))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) hybridSearchShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error {
|
||||
hybridSearchReq := typeutil.Clone(t.HybridSearchRequest)
|
||||
hybridSearchReq.GetBase().TargetID = nodeID
|
||||
if t.partitionKeyMode {
|
||||
t.PartitionIDs = t.partitionIDsSet.Collect()
|
||||
}
|
||||
req := &querypb.HybridSearchRequest{
|
||||
Req: hybridSearchReq,
|
||||
DmlChannels: []string{channel},
|
||||
TotalChannelNum: int32(1),
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collection", t.GetCollectionID()),
|
||||
zap.Int64s("partitionIDs", t.GetPartitionIDs()),
|
||||
zap.Int64("nodeID", nodeID),
|
||||
zap.String("channel", channel))
|
||||
|
||||
var result *querypb.HybridSearchResult
|
||||
var err error
|
||||
|
||||
result, err = qn.HybridSearch(ctx, req)
|
||||
if err != nil {
|
||||
log.Warn("QueryNode hybrid search return error", zap.Error(err))
|
||||
globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.request.GetCollectionName())
|
||||
return err
|
||||
}
|
||||
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
|
||||
log.Warn("QueryNode is not shardLeader")
|
||||
globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.request.GetCollectionName())
|
||||
return errInvalidShardLeaders
|
||||
}
|
||||
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("QueryNode hybrid search result error",
|
||||
zap.String("reason", result.GetStatus().GetReason()))
|
||||
return errors.Wrapf(merr.Error(result.GetStatus()), "fail to hybrid search on QueryNode %d", nodeID)
|
||||
}
|
||||
t.resultBuf.Insert(result)
|
||||
t.lb.UpdateCostMetrics(nodeID, result.CostAggregation)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *hybridSearchTask) Execute(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-Execute")
|
||||
defer sp.End()
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collID", t.CollectionID), zap.String("collName", t.request.GetCollectionName()))
|
||||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute hybrid search %d", t.ID()))
|
||||
defer tr.CtxElapse(ctx, "done")
|
||||
|
||||
for _, searchTask := range t.searchTasks {
|
||||
t.HybridSearchRequest.Reqs = append(t.HybridSearchRequest.Reqs, searchTask.SearchRequest)
|
||||
}
|
||||
|
||||
t.resultBuf = typeutil.NewConcurrentSet[*querypb.HybridSearchResult]()
|
||||
err := t.lb.Execute(ctx, CollectionWorkLoad{
|
||||
db: t.request.GetDbName(),
|
||||
collectionID: t.CollectionID,
|
||||
collectionName: t.request.GetCollectionName(),
|
||||
nq: 1,
|
||||
exec: t.hybridSearchShard,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn("hybrid search execute failed", zap.Error(err))
|
||||
return errors.Wrap(err, "failed to hybrid search")
|
||||
}
|
||||
|
||||
log.Debug("hybrid search execute done.")
|
||||
return nil
|
||||
}
|
||||
|
||||
type rankParams struct {
|
||||
limit int64
|
||||
offset int64
|
||||
roundDecimal int64
|
||||
}
|
||||
|
||||
// parseRankParams gets limit and offset from rankParams; both are optional.
func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, error) {
	var (
		limit        int64
		offset       int64
		roundDecimal int64
		err          error
	)

	limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, rankParamsPair)
	if err != nil {
		return nil, errors.New(LimitKey + " not found in rank_params")
	}
	limit, err = strconv.ParseInt(limitStr, 0, 64)
	if err != nil {
		return nil, fmt.Errorf("%s [%s] is invalid", LimitKey, limitStr)
	}

	offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, rankParamsPair)
	if err == nil {
		offset, err = strconv.ParseInt(offsetStr, 0, 64)
		if err != nil {
			return nil, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
		}
	}

	// validate max result window.
	if err = validateMaxQueryResultWindow(offset, limit); err != nil {
		return nil, fmt.Errorf("invalid max query result window, %w", err)
	}

	roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, rankParamsPair)
	if err != nil {
		roundDecimalStr = "-1"
	}

	roundDecimal, err = strconv.ParseInt(roundDecimalStr, 0, 64)
	if err != nil {
		return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
	}

	if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
		return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
	}

	return &rankParams{
		limit:        limit,
		offset:       offset,
		roundDecimal: roundDecimal,
	}, nil
}
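
// Example (illustrative only; the key constants are the ones used above, the
// literal values are invented): parseRankParams consumes the rank_params
// key/value pairs attached to a HybridSearchRequest.
//
//	rp, err := parseRankParams([]*commonpb.KeyValuePair{
//		{Key: LimitKey, Value: "10"},
//		{Key: OffsetKey, Value: "5"},
//		{Key: RoundDecimalKey, Value: "2"},
//	})
//	// on success: rp.limit == 10, rp.offset == 5, rp.roundDecimal == 2;
//	// omitting offset or round_decimal falls back to 0 and -1 respectively.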

func (t *hybridSearchTask) collectHybridSearchResults(ctx context.Context) error {
	select {
	case <-t.TraceCtx().Done():
		log.Ctx(ctx).Warn("hybrid search task wait to finish timeout!")
		return fmt.Errorf("hybrid search task wait to finish timeout, msgID=%d", t.ID())
	default:
		log.Ctx(ctx).Debug("all hybrid searches are finished or canceled")
		t.resultBuf.Range(func(res *querypb.HybridSearchResult) bool {
			for index, searchResult := range res.GetResults() {
				t.searchTasks[index].resultBuf.Insert(searchResult)
			}
			log.Ctx(ctx).Debug("proxy receives one hybrid search result",
				zap.Int64("sourceID", res.GetBase().GetSourceID()))
			return true
		})

		t.multipleRecallResults = typeutil.NewConcurrentSet[*milvuspb.SearchResults]()
		for i, searchTask := range t.searchTasks {
			err := searchTask.PostExecute(ctx)
			if err != nil {
				return err
			}
			t.reScorers[i].reScore(searchTask.result)
			t.multipleRecallResults.Insert(searchTask.result)
		}

		return nil
	}
}

func (t *hybridSearchTask) PostExecute(ctx context.Context) error {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-PostExecute")
	defer sp.End()

	log := log.Ctx(ctx).With(zap.Int64("collID", t.CollectionID), zap.String("collName", t.request.GetCollectionName()))
	tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy postExecute hybrid search %d", t.ID()))
	defer func() {
		tr.CtxElapse(ctx, "done")
	}()

	err := t.collectHybridSearchResults(ctx)
	if err != nil {
		log.Warn("failed to collect hybrid search results", zap.Error(err))
		return err
	}

	metricType := ""
	t.queryChannelsTs = make(map[string]uint64)
	for _, r := range t.resultBuf.Collect() {
		metricType = r.GetResults()[0].GetMetricType()
		for ch, ts := range r.GetChannelsMvcc() {
			t.queryChannelsTs[ch] = ts
		}
	}

	primaryFieldSchema, err := t.schema.GetPkField()
	if err != nil {
		log.Warn("failed to get primary field schema", zap.Error(err))
		return err
	}

	t.rankParams, err = parseRankParams(t.request.GetRankParams())
	if err != nil {
		return err
	}

	t.result, err = rankSearchResultData(ctx, 1,
		t.rankParams,
		primaryFieldSchema.GetDataType(),
		metricType,
		t.multipleRecallResults.Collect())
	if err != nil {
		log.Warn("rank search result failed", zap.Error(err))
		return err
	}

	t.result.CollectionName = t.request.GetCollectionName()
	t.fillInFieldInfo()

	if t.requery {
		err := t.Requery()
		if err != nil {
			log.Warn("failed to requery", zap.Error(err))
			return err
		}
	}
	t.result.Results.OutputFields = t.userOutputFields

	log.Debug("hybrid search post execute done")
	return nil
}

func (t *hybridSearchTask) Requery() error {
	queryReq := &milvuspb.QueryRequest{
		Base: &commonpb.MsgBase{
			MsgType: commonpb.MsgType_Retrieve,
		},
		DbName:                t.request.GetDbName(),
		CollectionName:        t.request.GetCollectionName(),
		Expr:                  "",
		OutputFields:          t.request.GetOutputFields(),
		PartitionNames:        t.request.GetPartitionNames(),
		GuaranteeTimestamp:    t.request.GetGuaranteeTimestamp(),
		TravelTimestamp:       t.request.GetTravelTimestamp(),
		NotReturnAllMeta:      t.request.GetNotReturnAllMeta(),
		ConsistencyLevel:      t.request.GetConsistencyLevel(),
		UseDefaultConsistency: t.request.GetUseDefaultConsistency(),
		QueryParams: []*commonpb.KeyValuePair{
			{
				Key:   LimitKey,
				Value: strconv.FormatInt(t.rankParams.limit, 10),
			},
		},
	}

	return doRequery(t.ctx, t.CollectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, t.GetPartitionIDs())
}

func rankSearchResultData(ctx context.Context,
	nq int64,
	params *rankParams,
	pkType schemapb.DataType,
	metricType string,
	searchResults []*milvuspb.SearchResults,
) (*milvuspb.SearchResults, error) {
	tr := timerecord.NewTimeRecorder("rankSearchResultData")
	defer func() {
		tr.CtxElapse(ctx, "done")
	}()

	offset := params.offset
	limit := params.limit
	topk := limit + offset
	roundDecimal := params.roundDecimal
	log.Ctx(ctx).Debug("rankSearchResultData",
		zap.Int("len(searchResults)", len(searchResults)),
		zap.Int64("nq", nq),
		zap.Int64("offset", offset),
		zap.Int64("limit", limit),
		zap.String("metric type", metricType))

	ret := &milvuspb.SearchResults{
		Status: merr.Success(),
		Results: &schemapb.SearchResultData{
			NumQueries: nq,
			TopK:       limit,
			FieldsData: make([]*schemapb.FieldData, 0),
			Scores:     []float32{},
			Ids:        &schemapb.IDs{},
			Topks:      []int64{},
		},
	}

	switch pkType {
	case schemapb.DataType_Int64:
		ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
			IntId: &schemapb.LongArray{
				Data: make([]int64, 0),
			},
		}
	case schemapb.DataType_VarChar:
		ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
			StrId: &schemapb.StringArray{
				Data: make([]string, 0),
			},
		}
	default:
		return nil, errors.New("unsupported pk type")
	}

	// []map[id]score
	accumulatedScores := make([]map[interface{}]float32, nq)
	for i := int64(0); i < nq; i++ {
		accumulatedScores[i] = make(map[interface{}]float32)
	}

	for _, result := range searchResults {
		scores := result.GetResults().GetScores()
		start := int64(0)
		for i := int64(0); i < nq; i++ {
			realTopk := result.GetResults().Topks[i]
			for j := start; j < start+realTopk; j++ {
				id := typeutil.GetPK(result.GetResults().GetIds(), j)
				accumulatedScores[i][id] += scores[j]
			}
			start += realTopk
		}
	}

	for i := int64(0); i < nq; i++ {
		idSet := accumulatedScores[i]
		keys := make([]interface{}, 0)
		for key := range idSet {
			keys = append(keys, key)
		}

		if int64(len(keys)) <= offset {
			ret.Results.Topks = append(ret.Results.Topks, 0)
			continue
		}

		compareKeys := func(keyI, keyJ interface{}) bool {
			switch keyI.(type) {
			case int64:
				return keyI.(int64) < keyJ.(int64)
			case string:
				return keyI.(string) < keyJ.(string)
			}
			return false
		}

		// sort id by score
		var less func(i, j int) bool
		if metric.PositivelyRelated(metricType) {
			less = func(i, j int) bool {
				if idSet[keys[i]] == idSet[keys[j]] {
					return compareKeys(keys[i], keys[j])
				}
				return idSet[keys[i]] > idSet[keys[j]]
			}
		} else {
			less = func(i, j int) bool {
				if idSet[keys[i]] == idSet[keys[j]] {
					return compareKeys(keys[i], keys[j])
				}
				return idSet[keys[i]] < idSet[keys[j]]
			}
		}

		sort.Slice(keys, less)

		if int64(len(keys)) > topk {
			keys = keys[:topk]
		}

		// set real topk
		ret.Results.Topks = append(ret.Results.Topks, int64(len(keys))-offset)
		// append id and score
		for index := offset; index < int64(len(keys)); index++ {
			typeutil.AppendPKs(ret.Results.Ids, keys[index])
			score := idSet[keys[index]]
			if roundDecimal != -1 {
				multiplier := math.Pow(10.0, float64(roundDecimal))
				score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
			}
			ret.Results.Scores = append(ret.Results.Scores, score)
		}
	}

	return ret, nil
}
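
// Worked example (illustrative only, values invented): for one query (nq = 1),
// if two recall result sets score id 7 as 0.9 and 0.4 and id 3 as 0.8 in only
// one of them, the accumulated scores are {7: 1.3, 3: 0.8}. With a positively
// related metric (e.g. IP) the fused order is [7, 3]; offset = 1 and limit = 1
// would then return only id 3, with Topks[0] == 1.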

func (t *hybridSearchTask) fillInFieldInfo() {
	if len(t.request.OutputFields) != 0 && len(t.result.Results.FieldsData) != 0 {
		for i, name := range t.request.OutputFields {
			for _, field := range t.schema.Fields {
				if t.result.Results.FieldsData[i] != nil && field.Name == name {
					t.result.Results.FieldsData[i].FieldName = field.Name
					t.result.Results.FieldsData[i].FieldId = field.FieldID
					t.result.Results.FieldsData[i].Type = field.DataType
					t.result.Results.FieldsData[i].IsDynamic = field.IsDynamic
				}
			}
		}
	}
}

func (t *hybridSearchTask) TraceCtx() context.Context {
	return t.ctx
}

func (t *hybridSearchTask) ID() UniqueID {
	return t.Base.MsgID
}

func (t *hybridSearchTask) SetID(uid UniqueID) {
	t.Base.MsgID = uid
}

func (t *hybridSearchTask) Name() string {
	return HybridSearchTaskName
}

func (t *hybridSearchTask) Type() commonpb.MsgType {
	return t.Base.MsgType
}

func (t *hybridSearchTask) BeginTs() Timestamp {
	return t.Base.Timestamp
}

func (t *hybridSearchTask) EndTs() Timestamp {
	return t.Base.Timestamp
}

func (t *hybridSearchTask) SetTs(ts Timestamp) {
	t.Base.Timestamp = ts
}

func (t *hybridSearchTask) OnEnqueue() error {
	t.Base = commonpbutil.NewMsgBase()
	t.Base.MsgType = commonpb.MsgType_Search
	t.Base.SourceID = paramtable.GetNodeID()
	return nil
}
@ -1,363 +0,0 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
func createCollWithMultiVecField(t *testing.T, name string, rc types.RootCoordClient) {
|
||||
schema := genCollectionSchema(name)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
require.NoError(t, err)
|
||||
ctx := context.TODO()
|
||||
|
||||
createColT := &createCollectionTask{
|
||||
Condition: NewTaskCondition(context.TODO()),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
CollectionName: name,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: common.DefaultShardsNum,
|
||||
},
|
||||
ctx: ctx,
|
||||
rootCoord: rc,
|
||||
}
|
||||
|
||||
require.NoError(t, createColT.OnEnqueue())
|
||||
require.NoError(t, createColT.PreExecute(ctx))
|
||||
require.NoError(t, createColT.Execute(ctx))
|
||||
require.NoError(t, createColT.PostExecute(ctx))
|
||||
}
|
||||
|
||||
func TestHybridSearchTask_PreExecute(t *testing.T) {
|
||||
var err error
|
||||
|
||||
var (
|
||||
rc = NewRootCoordMock()
|
||||
qc = mocks.NewMockQueryCoordClient(t)
|
||||
ctx = context.TODO()
|
||||
)
|
||||
|
||||
defer rc.Close()
|
||||
require.NoError(t, err)
|
||||
mgr := newShardClientMgr()
|
||||
err = InitMetaCache(ctx, rc, qc, mgr)
|
||||
require.NoError(t, err)
|
||||
|
||||
genHybridSearchTaskWithNq := func(t *testing.T, collName string, reqs []*milvuspb.SearchRequest) *hybridSearchTask {
|
||||
task := &hybridSearchTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
HybridSearchRequest: &internalpb.HybridSearchRequest{},
|
||||
request: &milvuspb.HybridSearchRequest{
|
||||
CollectionName: collName,
|
||||
Requests: reqs,
|
||||
},
|
||||
qc: qc,
|
||||
tr: timerecord.NewTimeRecorder("test-hybrid-search"),
|
||||
}
|
||||
require.NoError(t, task.OnEnqueue())
|
||||
return task
|
||||
}
|
||||
|
||||
t.Run("bad nq 0", func(t *testing.T) {
|
||||
collName := "test_bad_nq0_error" + funcutil.GenRandomStr()
|
||||
createCollWithMultiVecField(t, collName, rc)
|
||||
// Nq must be 1.
|
||||
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 0}})
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("bad req num 0", func(t *testing.T) {
|
||||
collName := "test_bad_req_num0_error" + funcutil.GenRandomStr()
|
||||
createCollWithMultiVecField(t, collName, rc)
|
||||
// num of reqs must be [1, 1024].
|
||||
task := genHybridSearchTaskWithNq(t, collName, nil)
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("bad req num 1025", func(t *testing.T) {
|
||||
collName := "test_bad_req_num1025_error" + funcutil.GenRandomStr()
|
||||
createCollWithMultiVecField(t, collName, rc)
|
||||
// num of reqs must be [1, 1024].
|
||||
reqs := make([]*milvuspb.SearchRequest, 0)
|
||||
for i := 0; i <= defaultMaxSearchRequest; i++ {
|
||||
reqs = append(reqs, &milvuspb.SearchRequest{
|
||||
CollectionName: collName,
|
||||
Nq: 1,
|
||||
})
|
||||
}
|
||||
task := genHybridSearchTaskWithNq(t, collName, reqs)
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("collection not exist", func(t *testing.T) {
|
||||
collName := "test_collection_not_exist" + funcutil.GenRandomStr()
|
||||
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 1}})
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("hybrid search with timeout", func(t *testing.T) {
|
||||
collName := "hybrid_search_with_timeout" + funcutil.GenRandomStr()
|
||||
createCollWithMultiVecField(t, collName, rc)
|
||||
|
||||
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 1}})
|
||||
|
||||
ctxTimeout, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
|
||||
task.ctx = ctxTimeout
|
||||
task.request.OutputFields = []string{testFloatVecField}
|
||||
assert.NoError(t, task.PreExecute(ctx))
|
||||
})
|
||||
|
||||
t.Run("hybrid search with group_by", func(t *testing.T) {
|
||||
collName := "hybrid_search_with_group_by" + funcutil.GenRandomStr()
|
||||
createCollWithMultiVecField(t, collName, rc)
|
||||
|
||||
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{
|
||||
{Nq: 1, DslType: commonpb.DslType_BoolExprV1, SearchParams: []*commonpb.KeyValuePair{
|
||||
{Key: AnnsFieldKey, Value: "fvec"},
|
||||
{Key: TopKKey, Value: "10"},
|
||||
{Key: GroupByFieldKey, Value: "bool"},
|
||||
}},
|
||||
})
|
||||
|
||||
ctxTimeout, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
|
||||
task.ctx = ctxTimeout
|
||||
task.request.OutputFields = []string{testFloatVecField}
|
||||
err := task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "not support search_group_by operation in the hybrid search", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestHybridSearchTask_ErrExecute(t *testing.T) {
|
||||
var (
|
||||
err error
|
||||
ctx = context.TODO()
|
||||
|
||||
rc = NewRootCoordMock()
|
||||
qc = getQueryCoordClient()
|
||||
qn = getQueryNodeClient()
|
||||
|
||||
collectionName = t.Name() + funcutil.GenRandomStr()
|
||||
)
|
||||
|
||||
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
|
||||
mgr := NewMockShardClientManager(t)
|
||||
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
|
||||
mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
|
||||
lb := NewLBPolicyImpl(mgr)
|
||||
|
||||
factory := dependency.NewDefaultFactory(true)
|
||||
node, err := NewProxy(ctx, factory)
|
||||
assert.NoError(t, err)
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
node.tsoAllocator = ×tampAllocator{
|
||||
tso: newMockTimestampAllocatorInterface(),
|
||||
}
|
||||
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator, factory)
|
||||
assert.NoError(t, err)
|
||||
node.sched = scheduler
|
||||
err = node.sched.Start()
|
||||
assert.NoError(t, err)
|
||||
err = node.initRateCollector()
|
||||
assert.NoError(t, err)
|
||||
node.rootCoord = rc
|
||||
node.queryCoord = qc
|
||||
|
||||
defer qc.Close()
|
||||
|
||||
err = InitMetaCache(ctx, rc, qc, mgr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
createCollWithMultiVecField(t, collectionName, rc)
|
||||
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
|
||||
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil)
|
||||
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
|
||||
Status: successStatus,
|
||||
Shards: []*querypb.ShardLeadersList{
|
||||
{
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1},
|
||||
NodeAddrs: []string{"localhost:9000"},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
|
||||
Status: successStatus,
|
||||
CollectionIDs: []int64{collectionID},
|
||||
InMemoryPercentages: []int64{100},
|
||||
}, nil)
|
||||
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_LoadCollection,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionID: collectionID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
||||
|
||||
vectorFields := typeutil.GetVectorFieldSchemas(schema.CollectionSchema)
|
||||
vectorFieldNames := make([]string, len(vectorFields))
|
||||
for i, field := range vectorFields {
|
||||
vectorFieldNames[i] = field.GetName()
|
||||
}
|
||||
|
||||
// test begins
|
||||
task := &hybridSearchTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
ctx: ctx,
|
||||
result: &milvuspb.SearchResults{
|
||||
Status: merr.Success(),
|
||||
},
|
||||
HybridSearchRequest: &internalpb.HybridSearchRequest{},
|
||||
request: &milvuspb.HybridSearchRequest{
|
||||
CollectionName: collectionName,
|
||||
Requests: []*milvuspb.SearchRequest{
|
||||
{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionName: collectionName,
|
||||
Nq: 1,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
SearchParams: []*commonpb.KeyValuePair{
|
||||
{Key: AnnsFieldKey, Value: testFloatVecField},
|
||||
{Key: TopKKey, Value: "10"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionName: collectionName,
|
||||
Nq: 1,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
SearchParams: []*commonpb.KeyValuePair{
|
||||
{Key: AnnsFieldKey, Value: testBinaryVecField},
|
||||
{Key: TopKKey, Value: "10"},
|
||||
},
|
||||
},
|
||||
},
|
||||
OutputFields: vectorFieldNames,
|
||||
},
|
||||
qc: qc,
|
||||
lb: lb,
|
||||
node: node,
|
||||
}
|
||||
|
||||
assert.NoError(t, task.OnEnqueue())
|
||||
task.ctx = ctx
|
||||
assert.NoError(t, task.PreExecute(ctx))
|
||||
|
||||
qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
|
||||
assert.Error(t, task.Execute(ctx))
|
||||
|
||||
qn.ExpectedCalls = nil
|
||||
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&querypb.HybridSearchResult{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
},
|
||||
}, nil)
|
||||
assert.Error(t, task.Execute(ctx))
|
||||
}
|
||||
|
||||
func TestHybridSearchTask_PostExecute(t *testing.T) {
|
||||
var (
|
||||
rc = NewRootCoordMock()
|
||||
qc = getQueryCoordClient()
|
||||
qn = getQueryNodeClient()
|
||||
collectionName = t.Name() + funcutil.GenRandomStr()
|
||||
)
|
||||
|
||||
defer rc.Close()
|
||||
mgr := NewMockShardClientManager(t)
|
||||
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
|
||||
mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
|
||||
qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&querypb.HybridSearchResult{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
Status: merr.Success(),
|
||||
}, nil)
|
||||
|
||||
t.Run("Test empty result", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
err := InitMetaCache(ctx, rc, qc, mgr)
|
||||
assert.NoError(t, err)
|
||||
createCollWithMultiVecField(t, collectionName, rc)
|
||||
|
||||
schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
{Key: LimitKey, Value: strconv.Itoa(3)},
|
||||
{Key: OffsetKey, Value: strconv.Itoa(2)},
|
||||
}
|
||||
qt := &hybridSearchTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(context.TODO()),
|
||||
qc: nil,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
schema: schema,
|
||||
HybridSearchRequest: &internalpb.HybridSearchRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
},
|
||||
request: &milvuspb.HybridSearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
},
|
||||
CollectionName: collectionName,
|
||||
RankParams: rankParams,
|
||||
},
|
||||
resultBuf: typeutil.NewConcurrentSet[*querypb.HybridSearchResult](),
|
||||
multipleRecallResults: typeutil.NewConcurrentSet[*milvuspb.SearchResults](),
|
||||
}
|
||||
|
||||
err = qt.PostExecute(context.TODO())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, qt.result.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
|
||||
})
|
||||
}
|
|
@ -5,9 +5,7 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
@ -23,7 +21,6 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
|
@ -32,6 +29,7 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
|
@ -50,8 +48,8 @@ const (
|
|||
|
||||
type searchTask struct {
|
||||
Condition
|
||||
*internalpb.SearchRequest
|
||||
ctx context.Context
|
||||
*internalpb.SearchRequest
|
||||
|
||||
result *milvuspb.SearchResults
|
||||
request *milvuspb.SearchRequest
|
||||
|
@ -64,196 +62,18 @@ type searchTask struct {
|
|||
|
||||
userOutputFields []string
|
||||
|
||||
offset int64
|
||||
resultBuf *typeutil.ConcurrentSet[*internalpb.SearchResults]
|
||||
|
||||
partitionIDsSet *typeutil.ConcurrentSet[UniqueID]
|
||||
|
||||
qc types.QueryCoordClient
|
||||
node types.ProxyComponent
|
||||
lb LBPolicy
|
||||
queryChannelsTs map[string]Timestamp
|
||||
queryInfo *planpb.QueryInfo
|
||||
}
|
||||
queryInfos []*planpb.QueryInfo
|
||||
|
||||
func getPartitionIDs(ctx context.Context, dbName string, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
|
||||
for _, tag := range partitionNames {
|
||||
if err := validatePartitionTag(tag, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
partitionsMap, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
useRegexp := Params.ProxyCfg.PartitionNameRegexp.GetAsBool()
|
||||
|
||||
partitionsSet := typeutil.NewSet[int64]()
|
||||
for _, partitionName := range partitionNames {
|
||||
if useRegexp {
|
||||
// Legacy feature, use partition name as regexp
|
||||
pattern := fmt.Sprintf("^%s$", partitionName)
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid partition: %s", partitionName)
|
||||
}
|
||||
var found bool
|
||||
for name, pID := range partitionsMap {
|
||||
if re.MatchString(name) {
|
||||
partitionsSet.Insert(pID)
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, fmt.Errorf("partition name %s not found", partitionName)
|
||||
}
|
||||
} else {
|
||||
partitionID, found := partitionsMap[partitionName]
|
||||
if !found {
|
||||
// TODO change after testcase updated: return nil, merr.WrapErrPartitionNotFound(partitionName)
|
||||
return nil, fmt.Errorf("partition name %s not found", partitionName)
|
||||
}
|
||||
if !partitionsSet.Contain(partitionID) {
|
||||
partitionsSet.Insert(partitionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
return partitionsSet.Collect(), nil
|
||||
}
|
||||
|
||||
// parseSearchInfo returns QueryInfo and offset
|
||||
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) (*planpb.QueryInfo, int64, error) {
|
||||
// 1. parse offset and real topk
|
||||
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
|
||||
if err != nil {
|
||||
return nil, 0, errors.New(TopKKey + " not found in search_params")
|
||||
}
|
||||
topK, err := strconv.ParseInt(topKStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
|
||||
}
|
||||
if err := validateTopKLimit(topK); err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
|
||||
}
|
||||
|
||||
var offset int64
|
||||
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, searchParamsPair)
|
||||
if err == nil {
|
||||
offset, err = strconv.ParseInt(offsetStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
|
||||
}
|
||||
|
||||
if offset != 0 {
|
||||
if err := validateTopKLimit(offset); err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
queryTopK := topK + offset
|
||||
if err := validateTopKLimit(queryTopK); err != nil {
|
||||
return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)
|
||||
}
|
||||
|
||||
// 2. parse metrics type
|
||||
metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.MetricTypeKey, searchParamsPair)
|
||||
if err != nil {
|
||||
metricType = ""
|
||||
}
|
||||
|
||||
// 3. parse round decimal
|
||||
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, searchParamsPair)
|
||||
if err != nil {
|
||||
roundDecimalStr = "-1"
|
||||
}
|
||||
|
||||
roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
}
|
||||
|
||||
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
}
|
||||
|
||||
// 4. parse search param str
|
||||
searchParamStr, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, searchParamsPair)
|
||||
if err != nil {
|
||||
searchParamStr = ""
|
||||
}
|
||||
|
||||
err = checkRangeSearchParams(searchParamStr, metricType)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 5. parse group by field
|
||||
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
|
||||
if err != nil {
|
||||
groupByFieldName = ""
|
||||
}
|
||||
var groupByFieldId int64 = -1
|
||||
if groupByFieldName != "" {
|
||||
fields := schema.GetFields()
|
||||
for _, field := range fields {
|
||||
if field.Name == groupByFieldName {
|
||||
groupByFieldId = field.FieldID
|
||||
break
|
||||
}
|
||||
}
|
||||
if groupByFieldId == -1 {
|
||||
return nil, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
|
||||
}
|
||||
}
|
||||
|
||||
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
|
||||
isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
|
||||
if isIterator == "True" && groupByFieldId > 0 {
|
||||
return nil, 0, merr.WrapErrParameterInvalid("", "",
|
||||
"Not allowed to do groupBy when doing iteration")
|
||||
}
|
||||
if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 {
|
||||
return nil, 0, merr.WrapErrParameterInvalid("", "",
|
||||
"Not allowed to do range-search when doing search-group-by")
|
||||
}
|
||||
|
||||
return &planpb.QueryInfo{
|
||||
Topk: queryTopK,
|
||||
MetricType: metricType,
|
||||
SearchParams: searchParamStr,
|
||||
RoundDecimal: roundDecimal,
|
||||
GroupByFieldId: groupByFieldId,
|
||||
}, offset, nil
|
||||
}
|
||||
|
||||
func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) {
|
||||
outputFieldIDs = make([]UniqueID, 0, len(outputFields))
|
||||
for _, name := range outputFields {
|
||||
id, ok := schema.MapFieldID(name)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Field %s not exist", name)
|
||||
}
|
||||
outputFieldIDs = append(outputFieldIDs, id)
|
||||
}
|
||||
return outputFieldIDs, nil
|
||||
}
|
||||
|
||||
func getNq(req *milvuspb.SearchRequest) (int64, error) {
|
||||
if req.GetNq() == 0 {
|
||||
// keep compatible with older client version.
|
||||
x := &commonpb.PlaceholderGroup{}
|
||||
err := proto.Unmarshal(req.GetPlaceholderGroup(), x)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
total := int64(0)
|
||||
for _, h := range x.GetPlaceholders() {
|
||||
total += int64(len(h.Values))
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
return req.GetNq(), nil
|
||||
reScorers []reScorer
|
||||
rankParams *rankParams
|
||||
}
|
||||
|
||||
func (t *searchTask) CanSkipAllocTimestamp() bool {
|
||||
|
@ -284,7 +104,7 @@ func (t *searchTask) CanSkipAllocTimestamp() bool {
func (t *searchTask) PreExecute(ctx context.Context) error {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-PreExecute")
	defer sp.End()

	t.SearchRequest.IsAdvanced = len(t.request.GetSubReqs()) > 0
	t.Base.MsgType = commonpb.MsgType_Search
	t.Base.SourceID = paramtable.GetNodeID()

@ -315,7 +135,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {

	if !t.partitionKeyMode && len(t.request.GetPartitionNames()) > 0 {
		// translate partition name to partition ids. Use regex-pattern to match partition name.
		t.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, t.request.GetPartitionNames())
		t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, t.request.GetPartitionNames())
		if err != nil {
			log.Warn("failed to get partition ids", zap.Error(err))
			return err

@ -330,7 +150,59 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
	log.Debug("translate output fields",
		zap.Strings("output fields", t.request.GetOutputFields()))

	err = initSearchRequest(ctx, t, false)
	if t.SearchRequest.GetIsAdvanced() {
		if len(t.request.GetSubReqs()) > defaultMaxSearchRequest {
			return errors.New(fmt.Sprintf("maximum of ann search requests is %d", defaultMaxSearchRequest))
		}
	}
	if t.SearchRequest.GetIsAdvanced() {
		t.rankParams, err = parseRankParams(t.request.GetSearchParams())
		if err != nil {
			return err
		}
	}
	// Manually update nq if not set.
	nq, err := t.checkNq(ctx)
	if err != nil {
		log.Info("failed to check nq", zap.Error(err))
		return err
	}
	t.SearchRequest.Nq = nq

	var ignoreGrowing bool
	// parse common search params
	for i, kv := range t.request.GetSearchParams() {
		if kv.GetKey() == IgnoreGrowingKey {
			ignoreGrowing, err = strconv.ParseBool(kv.GetValue())
			if err != nil {
				return errors.New("parse search growing failed")
			}
			t.request.SearchParams = append(t.request.GetSearchParams()[:i], t.request.GetSearchParams()[i+1:]...)
			break
		}
	}
	t.SearchRequest.IgnoreGrowing = ignoreGrowing

	outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
	if err != nil {
		log.Info("fail to get output field ids", zap.Error(err))
		return err
	}
	t.SearchRequest.OutputFieldsId = outputFieldIDs

	// Currently, we get vectors by requery. Once we support getting vectors from search,
	// searches with small result size could no longer need requery.
	vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
		return lo.Contains(t.request.GetOutputFields(), field.GetName()) && typeutil.IsVectorType(field.GetDataType())
	})

	if t.SearchRequest.GetIsAdvanced() {
		t.requery = len(t.request.OutputFields) > 0
		err = t.initAdvancedSearchRequest(ctx)
	} else {
		t.requery = len(vectorOutputFields) > 0
		err = t.initSearchRequest(ctx)
	}
	if err != nil {
		log.Debug("init search request failed", zap.Error(err))
		return err

@ -359,6 +231,16 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
		}
	}
	t.SearchRequest.GuaranteeTimestamp = guaranteeTs
	t.SearchRequest.ConsistencyLevel = consistencyLevel

	if deadline, ok := t.TraceCtx().Deadline(); ok {
		t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
	}

	// Set username of this search request for feature like task scheduling.
	if username, _ := GetCurUserFromContext(ctx); username != "" {
		t.SearchRequest.Username = username
	}

	log.Debug("search PreExecute done.",
		zap.Uint64("guarantee_ts", guaranteeTs),
@ -368,16 +250,229 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
	return nil
}

func (t *searchTask) checkNq(ctx context.Context) (int64, error) {
	var nq int64
	if t.SearchRequest.GetIsAdvanced() {
		// In the context of Advanced Search, it is essential to verify that the number of vectors
		// for each individual search, denoted as nq, remains consistent.
		nq = t.request.GetNq()
		for _, req := range t.request.GetSubReqs() {
			subNq, err := getNqFromSubSearch(req)
			if err != nil {
				return 0, err
			}
			req.Nq = subNq
			if nq == 0 {
				nq = subNq
				continue
			}
			if subNq != nq {
				err = merr.WrapErrParameterInvalid(nq, subNq, "sub search request nq should be the same")
				return 0, err
			}
		}
		t.request.Nq = nq
	} else {
		var err error
		nq, err = getNq(t.request)
		if err != nil {
			return 0, err
		}
		t.request.Nq = nq
	}

	// Check if nq is valid:
	// https://milvus.io/docs/limitations.md
	if err := validateNQLimit(nq); err != nil {
		return 0, fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
	}
	return nq, nil
}

func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init advanced search request")
	defer sp.End()

	t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]()

	log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
	// fetch search_growing from search param
	t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
	t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
	for index, subReq := range t.request.GetSubReqs() {
		plan, queryInfo, offset, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl(), true)
		if err != nil {
			return err
		}
		if queryInfo.GetGroupByFieldId() != -1 {
			return errors.New("not support search_group_by operation in the hybrid search")
		}
		internalSubReq := &internalpb.SubSearchRequest{
			Dsl:                subReq.GetDsl(),
			PlaceholderGroup:   subReq.GetPlaceholderGroup(),
			DslType:            subReq.GetDslType(),
			SerializedExprPlan: nil,
			Nq:                 subReq.GetNq(),
			PartitionIDs:       nil,
			Topk:               queryInfo.GetTopk(),
			Offset:             offset,
		}

		// set PartitionIDs for sub search
		if t.partitionKeyMode {
			partitionIDs, err2 := t.tryParsePartitionIDsFromPlan(plan)
			if err2 != nil {
				return err2
			}
			if len(partitionIDs) > 0 {
				internalSubReq.PartitionIDs = partitionIDs
				t.partitionIDsSet.Upsert(partitionIDs...)
			}
		} else {
			internalSubReq.PartitionIDs = t.SearchRequest.GetPartitionIDs()
		}

		if t.requery {
			plan.OutputFieldIds = nil
		} else {
			plan.OutputFieldIds = t.SearchRequest.OutputFieldsId
		}

		internalSubReq.SerializedExprPlan, err = proto.Marshal(plan)
		if err != nil {
			return err
		}
		t.SearchRequest.SubReqs[index] = internalSubReq
		t.queryInfos[index] = queryInfo
		log.Debug("proxy init search request",
			zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
			zap.Stringer("plan", plan)) // may be very large if large term passed.
	}
	var err error
	t.reScorers, err = NewReScorers(len(t.request.GetSubReqs()), t.request.GetSearchParams())
	if err != nil {
		log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err))
		return err
	}
	return nil
}

func (t *searchTask) initSearchRequest(ctx context.Context) error {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init search request")
	defer sp.End()

	log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
	// fetch search_growing from search param

	plan, queryInfo, offset, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl(), false)
	if err != nil {
		return err
	}

	t.SearchRequest.Offset = offset

	if t.partitionKeyMode {
		partitionIDs, err2 := t.tryParsePartitionIDsFromPlan(plan)
		if err2 != nil {
			return err2
		}
		if len(partitionIDs) > 0 {
			t.SearchRequest.PartitionIDs = partitionIDs
		}
	}

	if t.requery {
		plan.OutputFieldIds = nil
	} else {
		plan.OutputFieldIds = t.SearchRequest.OutputFieldsId
	}

	t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
	if err != nil {
		return err
	}

	t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
	t.SearchRequest.Topk = queryInfo.GetTopk()
	t.SearchRequest.MetricType = queryInfo.GetMetricType()
	t.queryInfos = append(t.queryInfos, queryInfo)
	t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
	log.Debug("proxy init search request",
		zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
		zap.Stringer("plan", plan)) // may be very large if large term passed.

	return nil
}

func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string, ignoreOffset bool) (*planpb.PlanNode, *planpb.QueryInfo, int64, error) {
	annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, params)
	if err != nil || len(annsFieldName) == 0 {
		vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
		if len(vecFields) == 0 {
			return nil, nil, 0, errors.New(AnnsFieldKey + " not found in schema")
		}

		if enableMultipleVectorFields && len(vecFields) > 1 {
			return nil, nil, 0, errors.New("multiple anns_fields exist, please specify a anns_field in search_params")
		}
		annsFieldName = vecFields[0].Name
	}
	queryInfo, offset, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, ignoreOffset)
	if parseErr != nil {
		return nil, nil, 0, parseErr
	}
	annField := typeutil.GetFieldByName(t.schema.CollectionSchema, annsFieldName)
	if queryInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
		return nil, nil, 0, errors.New("not support search_group_by operation based on binary vector column")
	}
	plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, queryInfo)
	if planErr != nil {
		log.Warn("failed to create query plan", zap.Error(planErr),
			zap.String("dsl", dsl), // may be very large if large term passed.
			zap.String("anns field", annsFieldName), zap.Any("query info", queryInfo))
		return nil, nil, 0, merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", planErr)
	}
	log.Debug("create query plan",
		zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
		zap.String("anns field", annsFieldName), zap.Any("query info", queryInfo))
	return plan, queryInfo, offset, nil
}
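
// Illustrative sketch (the key constants are the ones referenced above, the
// values are invented): each ANN request carries its knobs as search_params
// key/value pairs, which tryGeneratePlan and parseSearchInfo read before
// building the plan.
//
//	params := []*commonpb.KeyValuePair{
//		{Key: AnnsFieldKey, Value: "float_vector"},
//		{Key: TopKKey, Value: "10"},
//		{Key: common.MetricTypeKey, Value: "L2"},
//		{Key: SearchParamsKey, Value: `{"nprobe": 16}`},
//	}
//	plan, queryInfo, offset, err := t.tryGeneratePlan(params, "id > 0", true)
//	// queryInfo carries the parsed topk/metric type, offset is returned
//	// separately, and plan wraps the compiled DSL expression.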

func (t *searchTask) tryParsePartitionIDsFromPlan(plan *planpb.PlanNode) ([]int64, error) {
	expr, err := ParseExprFromPlan(plan)
	if err != nil {
		log.Warn("failed to parse expr", zap.Error(err))
		return nil, err
	}
	partitionKeys := ParsePartitionKeys(expr)
	hashedPartitionNames, err := assignPartitionKeys(t.ctx, t.request.GetDbName(), t.collectionName, partitionKeys)
	if err != nil {
		log.Warn("failed to assign partition keys", zap.Error(err))
		return nil, err
	}

	if len(hashedPartitionNames) > 0 {
		// translate partition name to partition ids. Use regex-pattern to match partition name.
		PartitionIDs, err2 := getPartitionIDs(t.ctx, t.request.GetDbName(), t.collectionName, hashedPartitionNames)
		if err2 != nil {
			log.Warn("failed to get partition ids", zap.Error(err2))
			return nil, err2
		}
		return PartitionIDs, nil
	}
	return nil, nil
}

func (t *searchTask) Execute(ctx context.Context) error {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-Execute")
	defer sp.End()
	log := log.Ctx(ctx).With(zap.Int64("nq", t.SearchRequest.GetNq()))

	t.resultBuf = typeutil.NewConcurrentSet[*internalpb.SearchResults]()

	tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute search %d", t.ID()))
	defer tr.CtxElapse(ctx, "done")

	t.resultBuf = typeutil.NewConcurrentSet[*internalpb.SearchResults]()

	err := t.lb.Execute(ctx, CollectionWorkLoad{
		db:           t.request.GetDbName(),
		collectionID: t.SearchRequest.CollectionID,

@ -396,6 +491,43 @@ func (t *searchTask) Execute(ctx context.Context) error {
	return nil
}

func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, queryInfo *planpb.QueryInfo) (*milvuspb.SearchResults, error) {
	metricType := ""
	if len(toReduceResults) >= 1 {
		metricType = toReduceResults[0].GetMetricType()
	}

	// Decode all search results
	validSearchResults, err := decodeSearchResults(ctx, toReduceResults)
	if err != nil {
		log.Warn("failed to decode search results", zap.Error(err))
		return nil, err
	}

	if len(validSearchResults) <= 0 {
		return fillInEmptyResult(nq), nil
	}

	// Reduce all search results
	log.Debug("proxy search post execute reduce",
		zap.Int64("collection", t.GetCollectionID()),
		zap.Int64s("partitionIDs", t.GetPartitionIDs()),
		zap.Int("number of valid search results", len(validSearchResults)))
	primaryFieldSchema, err := t.schema.GetPkField()
	if err != nil {
		log.Warn("failed to get primary field schema", zap.Error(err))
		return nil, err
	}
	var result *milvuspb.SearchResults
	result, err = reduceSearchResult(ctx, NewReduceSearchResultInfo(validSearchResults, nq, topK,
		metricType, primaryFieldSchema.DataType, offset, queryInfo))
	if err != nil {
		log.Warn("failed to reduce search results", zap.Error(err))
		return nil, err
	}
	return result, nil
}

func (t *searchTask) PostExecute(ctx context.Context) error {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-PostExecute")
	defer sp.End()

@ -406,11 +538,6 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
	}()
	log := log.Ctx(ctx).With(zap.Int64("nq", t.SearchRequest.GetNq()))

	var (
		Nq         = t.SearchRequest.GetNq()
		Topk       = t.SearchRequest.GetTopk()
		MetricType = t.SearchRequest.GetMetricType()
	)
	toReduceResults, err := t.collectSearchResults(ctx)
	if err != nil {
		log.Warn("failed to collect search results", zap.Error(err))

@ -424,45 +551,65 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
		}
	}

	if len(toReduceResults) >= 1 {
		MetricType = toReduceResults[0].GetMetricType()
	}

	// Decode all search results
	tr.CtxRecord(ctx, "decodeResultStart")
	validSearchResults, err := decodeSearchResults(ctx, toReduceResults)
	if err != nil {
		log.Warn("failed to decode search results", zap.Error(err))
		return err
	}
	metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10),
		metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))

	if len(validSearchResults) <= 0 {
		t.fillInEmptyResult(Nq)
		return nil
	}

	// Reduce all search results
	log.Debug("proxy search post execute reduce",
		zap.Int64("collection", t.GetCollectionID()),
		zap.Int64s("partitionIDs", t.GetPartitionIDs()),
		zap.Int("number of valid search results", len(validSearchResults)))
	tr.CtxRecord(ctx, "reduceResultStart")
	primaryFieldSchema, err := t.schema.GetPkField()
	if err != nil {
		log.Warn("failed to get primary field schema", zap.Error(err))
		return err
	}

	t.result, err = reduceSearchResult(ctx, NewReduceSearchResultInfo(validSearchResults, Nq, Topk,
		MetricType, primaryFieldSchema.DataType, t.offset, t.queryInfo))
	if err != nil {
		log.Warn("failed to reduce search results", zap.Error(err))
		return err
	}
	if t.SearchRequest.GetIsAdvanced() {
		multipleInternalResults := make([][]*internalpb.SearchResults, len(t.SearchRequest.GetSubReqs()))
		for _, searchResult := range toReduceResults {
			// if get a non-advanced result, skip all
			if !searchResult.GetIsAdvanced() {
				continue
			}
			for _, subResult := range searchResult.GetSubResults() {
				// shallow copy
				internalResults := &internalpb.SearchResults{
					MetricType:     subResult.GetMetricType(),
					NumQueries:     subResult.GetNumQueries(),
					TopK:           subResult.GetTopK(),
					SlicedBlob:     subResult.GetSlicedBlob(),
					SlicedNumCount: subResult.GetSlicedNumCount(),
					SlicedOffset:   subResult.GetSlicedOffset(),
					IsAdvanced:     false,
				}
				reqIndex := subResult.GetReqIndex()
				multipleInternalResults[reqIndex] = append(multipleInternalResults[reqIndex], internalResults)
			}
		}

	metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
		multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs()))
		for index, internalResults := range multipleInternalResults {
			subReq := t.SearchRequest.GetSubReqs()[index]

			metricType := ""
			if len(internalResults) >= 1 {
				metricType = internalResults[0].GetMetricType()
			}
			result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), t.queryInfos[index])
			if err != nil {
				return err
			}
			t.reScorers[index].setMetricType(metricType)
			t.reScorers[index].reScore(result)
			multipleMilvusResults[index] = result
		}
		t.result, err = rankSearchResultData(ctx, t.SearchRequest.GetNq(),
			t.rankParams,
			primaryFieldSchema.GetDataType(),
			multipleMilvusResults)
		if err != nil {
			log.Warn("rank search result failed", zap.Error(err))
			return err
		}
	} else {
		t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), t.queryInfos[0])
		if err != nil {
			return err
		}
	}

	t.result.CollectionName = t.collectionName
	t.fillInFieldInfo()

@ -475,6 +622,9 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
		}
	}
	t.result.Results.OutputFields = t.userOutputFields
	t.result.CollectionName = t.request.GetCollectionName()

	metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))

	log.Debug("Search post execute done",
		zap.Int64("collection", t.GetCollectionID()),

@ -550,29 +700,30 @@ func (t *searchTask) Requery() error {
			MsgType:   commonpb.MsgType_Retrieve,
			Timestamp: t.BeginTs(),
		},
		DbName:             t.request.GetDbName(),
		CollectionName:     t.request.GetCollectionName(),
		Expr:               "",
		OutputFields:       t.request.GetOutputFields(),
		PartitionNames:     t.request.GetPartitionNames(),
		GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(),
		QueryParams:        t.request.GetSearchParams(),
		DbName:                t.request.GetDbName(),
		CollectionName:        t.request.GetCollectionName(),
		ConsistencyLevel:      t.SearchRequest.GetConsistencyLevel(),
		NotReturnAllMeta:      t.request.GetNotReturnAllMeta(),
		Expr:                  "",
		OutputFields:          t.request.GetOutputFields(),
		PartitionNames:        t.request.GetPartitionNames(),
		UseDefaultConsistency: false,
		GuaranteeTimestamp:    t.SearchRequest.GuaranteeTimestamp,
	}

	if t.SearchRequest.GetIsAdvanced() {
		queryReq.QueryParams = []*commonpb.KeyValuePair{
			{
				Key:   LimitKey,
				Value: strconv.FormatInt(t.rankParams.limit, 10),
			},
		}
	} else {
		queryReq.QueryParams = t.request.GetSearchParams()
	}
	return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, t.GetPartitionIDs())
}

func (t *searchTask) fillInEmptyResult(numQueries int64) {
	t.result = &milvuspb.SearchResults{
		Status:         merr.Success("search result is empty"),
		CollectionName: t.collectionName,
		Results: &schemapb.SearchResultData{
			NumQueries: numQueries,
			Topks:      make([]int64, numQueries),
		},
	}
}

func (t *searchTask) fillInFieldInfo() {
	if len(t.request.OutputFields) != 0 && len(t.result.Results.FieldsData) != 0 {
		for i, name := range t.request.OutputFields {

@ -49,28 +49,46 @@ import (
|
|||
)
|
||||
|
||||
func TestSearchTask_PostExecute(t *testing.T) {
|
||||
t.Run("Test empty result", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
var err error
|
||||
|
||||
qt := &searchTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(context.TODO()),
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
var (
|
||||
rc = NewRootCoordMock()
|
||||
qc = mocks.NewMockQueryCoordClient(t)
|
||||
ctx = context.TODO()
|
||||
)
|
||||
|
||||
defer rc.Close()
|
||||
require.NoError(t, err)
|
||||
mgr := newShardClientMgr()
|
||||
err = InitMetaCache(ctx, rc, qc, mgr)
|
||||
require.NoError(t, err)
|
||||
|
||||
getSearchTask := func(t *testing.T, collName string) *searchTask {
|
||||
task := &searchTask{
|
||||
ctx: ctx,
|
||||
collectionName: collName,
|
||||
SearchRequest: &internalpb.SearchRequest{},
|
||||
request: &milvuspb.SearchRequest{
|
||||
CollectionName: collName,
|
||||
Nq: 1,
|
||||
SearchParams: getBaseSearchParams(),
|
||||
},
|
||||
request: nil,
|
||||
qc: nil,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
|
||||
resultBuf: &typeutil.ConcurrentSet[*internalpb.SearchResults]{},
|
||||
qc: qc,
|
||||
tr: timerecord.NewTimeRecorder("test-search"),
|
||||
}
|
||||
// no result
|
||||
qt.resultBuf.Insert(&internalpb.SearchResults{})
|
||||
require.NoError(t, task.OnEnqueue())
|
||||
return task
|
||||
}
|
||||
t.Run("Test empty result", func(t *testing.T) {
|
||||
collName := "test_collection_empty_result" + funcutil.GenRandomStr()
|
||||
createColl(t, collName, rc)
|
||||
qt := getSearchTask(t, collName)
|
||||
|
||||
err = qt.PreExecute(qt.ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
qt.resultBuf = typeutil.NewConcurrentSet[*internalpb.SearchResults]()
|
||||
qt.resultBuf.Insert(&internalpb.SearchResults{})
|
||||
err := qt.PostExecute(context.TODO())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, qt.result.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
|
||||
|
@ -1988,7 +2006,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {

	for _, test := range tests {
		t.Run(test.description, func(t *testing.T) {
			info, offset, err := parseSearchInfo(test.validParams, nil)
			info, offset, err := parseSearchInfo(test.validParams, nil, false)
			assert.NoError(t, err)
			assert.NotNil(t, info)
			if test.description == "offsetParam" {

@ -2077,7 +2095,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {

	for _, test := range tests {
		t.Run(test.description, func(t *testing.T) {
			info, offset, err := parseSearchInfo(test.invalidParams, nil)
			info, offset, err := parseSearchInfo(test.invalidParams, nil, false)
			assert.Error(t, err)
			assert.Nil(t, info)
			assert.Zero(t, offset)

@ -2155,7 +2173,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {

	for _, test := range tests {
		t.Run(test.description, func(t *testing.T) {
			info, _, err := parseSearchInfo(test.validParams, nil)
			info, _, err := parseSearchInfo(test.validParams, nil, false)
			assert.NoError(t, err)
			assert.NotNil(t, info)
		})

@ -2171,7 +2189,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {

	for _, test := range tests {
		t.Run(test.description, func(t *testing.T) {
			info, _, err := parseSearchInfo(test.validParams, nil)
			info, _, err := parseSearchInfo(test.validParams, nil, false)
			assert.ErrorIs(t, err, merr.ErrParameterInvalid)
			assert.Nil(t, info)
		})

@ -2189,7 +2207,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {

	for _, test := range tests {
		t.Run(test.description, func(t *testing.T) {
			info, _, err := parseSearchInfo(test.validParams, nil)
			info, _, err := parseSearchInfo(test.validParams, nil, false)
			assert.ErrorIs(t, err, merr.ErrParameterInvalid)
			assert.Nil(t, info)
		})

@ -2213,7 +2231,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
		schema := &schemapb.CollectionSchema{
			Fields: fields,
		}
		info, _, err := parseSearchInfo(normalParam, schema)
		info, _, err := parseSearchInfo(normalParam, schema, false)
		assert.Nil(t, info)
		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
	})

@ -2232,7 +2250,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
		schema := &schemapb.CollectionSchema{
			Fields: fields,
		}
		info, _, err := parseSearchInfo(normalParam, schema)
		info, _, err := parseSearchInfo(normalParam, schema, false)
		assert.Nil(t, info)
		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
	})
@ -2504,9 +2522,9 @@ func TestSearchTask_Requery(t *testing.T) {
		qt.resultBuf.Insert(&internalpb.SearchResults{
			SlicedBlob: bytes,
		})
		qt.queryInfo = &planpb.QueryInfo{
		qt.queryInfos = []*planpb.QueryInfo{{
			GroupByFieldId: -1,
		}
		}}
		err = qt.PostExecute(ctx)
		t.Logf("err = %s", err)
		assert.Error(t, err)
@ -469,61 +469,6 @@ func (_c *MockQueryNodeServer_GetTimeTickChannel_Call) RunAndReturn(run func(con
	return _c
}

// HybridSearch provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) HybridSearch(_a0 context.Context, _a1 *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
	ret := _m.Called(_a0, _a1)

	var r0 *querypb.HybridSearchResult
	var r1 error
	if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok {
		return rf(_a0, _a1)
	}
	if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok {
		r0 = rf(_a0, _a1)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*querypb.HybridSearchResult)
		}
	}

	if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok {
		r1 = rf(_a0, _a1)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

// MockQueryNodeServer_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
type MockQueryNodeServer_HybridSearch_Call struct {
	*mock.Call
}

// HybridSearch is a helper method to define mock.On call
//   - _a0 context.Context
//   - _a1 *querypb.HybridSearchRequest
func (_e *MockQueryNodeServer_Expecter) HybridSearch(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_HybridSearch_Call {
	return &MockQueryNodeServer_HybridSearch_Call{Call: _e.mock.On("HybridSearch", _a0, _a1)}
}

func (_c *MockQueryNodeServer_HybridSearch_Call) Run(run func(_a0 context.Context, _a1 *querypb.HybridSearchRequest)) *MockQueryNodeServer_HybridSearch_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest))
	})
	return _c
}

func (_c *MockQueryNodeServer_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNodeServer_HybridSearch_Call {
	_c.Call.Return(_a0, _a1)
	return _c
}

func (_c *MockQueryNodeServer_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockQueryNodeServer_HybridSearch_Call {
	_c.Call.Return(run)
	return _c
}

// LoadPartitions provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) LoadPartitions(_a0 context.Context, _a1 *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
	ret := _m.Called(_a0, _a1)
@ -62,7 +62,6 @@ type ShardDelegator interface {
	GetSegmentInfo(readable bool) (sealed []SnapshotItem, growing []SegmentEntry)
	SyncDistribution(ctx context.Context, entries ...SegmentEntry)
	Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error)
	HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)
	Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error)
	QueryStream(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error
	GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error)
@ -267,115 +266,81 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
		return funcutil.SliceContain(existPartitions, segment.PartitionID)
	})

	if req.GetReq().GetIsAdvanced() {
		futures := make([]*conc.Future[*internalpb.SearchResults], len(req.GetReq().GetSubReqs()))
		for index, subReq := range req.GetReq().GetSubReqs() {
			newRequest := &internalpb.SearchRequest{
				Base:               req.GetReq().GetBase(),
				ReqID:              req.GetReq().GetReqID(),
				DbID:               req.GetReq().GetDbID(),
				CollectionID:       req.GetReq().GetCollectionID(),
				PartitionIDs:       subReq.GetPartitionIDs(),
				Dsl:                subReq.GetDsl(),
				PlaceholderGroup:   subReq.GetPlaceholderGroup(),
				DslType:            subReq.GetDslType(),
				SerializedExprPlan: subReq.GetSerializedExprPlan(),
				OutputFieldsId:     req.GetReq().GetOutputFieldsId(),
				MvccTimestamp:      req.GetReq().GetMvccTimestamp(),
				GuaranteeTimestamp: req.GetReq().GetGuaranteeTimestamp(),
				TimeoutTimestamp:   req.GetReq().GetTimeoutTimestamp(),
				Nq:                 subReq.GetNq(),
				Topk:               subReq.GetTopk(),
				MetricType:         subReq.GetMetricType(),
				IgnoreGrowing:      req.GetReq().GetIgnoreGrowing(),
				Username:           req.GetReq().GetUsername(),
				IsAdvanced:         false,
			}
			future := conc.Go(func() (*internalpb.SearchResults, error) {
				searchReq := &querypb.SearchRequest{
					Req:             newRequest,
					DmlChannels:     req.GetDmlChannels(),
					TotalChannelNum: req.GetTotalChannelNum(),
					FromShardLeader: true,
				}
				searchReq.Req.GuaranteeTimestamp = req.GetReq().GetGuaranteeTimestamp()
				searchReq.Req.TimeoutTimestamp = req.GetReq().GetTimeoutTimestamp()
				if searchReq.GetReq().GetMvccTimestamp() == 0 {
					searchReq.GetReq().MvccTimestamp = tSafe
				}

				results, err := sd.search(ctx, searchReq, sealed, growing)
				if err != nil {
					return nil, err
				}

				return segments.ReduceSearchResults(ctx,
					results,
					searchReq.Req.GetNq(),
					searchReq.Req.GetTopk(),
					searchReq.Req.GetMetricType())
			})
			futures[index] = future
		}

		err = conc.AwaitAll(futures...)
		if err != nil {
			return nil, err
		}
		results := make([]*internalpb.SearchResults, len(futures))
		for i, future := range futures {
			result := future.Value()
			if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
				log.Debug("delegator hybrid search failed",
					zap.String("reason", result.GetStatus().GetReason()))
				return nil, merr.Error(result.GetStatus())
			}
			results[i] = result
		}
		var ret *internalpb.SearchResults
		ret, err = segments.MergeToAdvancedResults(ctx, results)
		if err != nil {
			return nil, err
		}
		return []*internalpb.SearchResults{ret}, nil
	}
	return sd.search(ctx, req, sealed, growing)
}
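The advanced branch above fans each sub-request out as its own concurrent task, reduces each sub-request's per-segment results, and only then merges everything into a single advanced result whose cross-request ranking is deferred to the proxy. Below is a minimal sketch of that fan-out pattern, using only the conc.Go/conc.AwaitAll semantics already shown above; runSubSearch is a hypothetical callback standing in for building the per-sub-request querypb.SearchRequest and reducing its results, not part of the delegator API.

// Sketch only: runSubSearch is a hypothetical callback, not part of the delegator API.
func fanOutSubSearches(ctx context.Context, subReqs []*internalpb.SearchRequest,
	runSubSearch func(context.Context, *internalpb.SearchRequest) (*internalpb.SearchResults, error),
) ([]*internalpb.SearchResults, error) {
	futures := make([]*conc.Future[*internalpb.SearchResults], len(subReqs))
	for i, sub := range subReqs {
		sub := sub // capture the loop variable for the goroutine below
		futures[i] = conc.Go(func() (*internalpb.SearchResults, error) {
			return runSubSearch(ctx, sub)
		})
	}
	// Any failing sub-search fails the whole advanced search.
	if err := conc.AwaitAll(futures...); err != nil {
		return nil, err
	}
	results := make([]*internalpb.SearchResults, len(futures))
	for i, f := range futures {
		results[i] = f.Value()
	}
	return results, nil
}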
// HybridSearch performs the hybrid search operation on a shard.
func (sd *shardDelegator) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
	log := sd.getLogger(ctx)
	if err := sd.lifetime.Add(lifetime.IsWorking); err != nil {
		return nil, err
	}
	defer sd.lifetime.Done()

	if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
		log.Warn("delegator received hybrid search request that does not belong to it",
			zap.Strings("reqChannels", req.GetDmlChannels()),
		)
		return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
	}

	partitions := req.GetReq().GetPartitionIDs()
	if !sd.collection.ExistPartition(partitions...) {
		return nil, merr.WrapErrPartitionNotLoaded(partitions)
	}

	// wait tsafe
	waitTr := timerecord.NewTimeRecorder("wait tSafe")
	tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
	if err != nil {
		log.Warn("delegator hybrid search failed to wait tsafe", zap.Error(err))
		return nil, err
	}
	if req.GetReq().GetMvccTimestamp() == 0 {
		req.Req.MvccTimestamp = tSafe
	}
	metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues(
		fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel).
		Observe(float64(waitTr.ElapseSpan().Milliseconds()))

	sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
	if err != nil {
		log.Warn("delegator failed to hybrid search, current distribution is not serviceable")
		return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not serviceable")
	}
	defer sd.distribution.Unpin(version)
	existPartitions := sd.collection.GetPartitions()
	growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
		return funcutil.SliceContain(existPartitions, segment.PartitionID)
	})

	futures := make([]*conc.Future[*internalpb.SearchResults], len(req.GetReq().GetReqs()))
	for index := range req.GetReq().GetReqs() {
		request := req.GetReq().Reqs[index]
		future := conc.Go(func() (*internalpb.SearchResults, error) {
			searchReq := &querypb.SearchRequest{
				Req:             request,
				DmlChannels:     req.GetDmlChannels(),
				TotalChannelNum: req.GetTotalChannelNum(),
				FromShardLeader: true,
			}
			searchReq.Req.GuaranteeTimestamp = req.GetReq().GetGuaranteeTimestamp()
			searchReq.Req.TimeoutTimestamp = req.GetReq().GetTimeoutTimestamp()
			if searchReq.GetReq().GetMvccTimestamp() == 0 {
				searchReq.GetReq().MvccTimestamp = tSafe
			}

			results, err := sd.search(ctx, searchReq, sealed, growing)
			if err != nil {
				return nil, err
			}

			return segments.ReduceSearchResults(ctx,
				results,
				searchReq.Req.GetNq(),
				searchReq.Req.GetTopk(),
				searchReq.Req.GetMetricType())
		})
		futures[index] = future
	}

	err = conc.AwaitAll(futures...)
	if err != nil {
		return nil, err
	}

	ret := &querypb.HybridSearchResult{
		Status:  merr.Success(),
		Results: make([]*internalpb.SearchResults, len(futures)),
	}

	channelsMvcc := make(map[string]uint64)
	for i, future := range futures {
		result := future.Value()
		if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
			log.Debug("delegator hybrid search failed",
				zap.String("reason", result.GetStatus().GetReason()))
			return nil, merr.Error(result.GetStatus())
		}

		ret.Results[i] = result
		for ch, ts := range result.GetChannelsMvcc() {
			channelsMvcc[ch] = ts
		}
	}
	ret.ChannelsMvcc = channelsMvcc

	log.Debug("Delegator hybrid search done")

	return ret, nil
}

func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error {
	log := sd.getLogger(ctx)
	if !sd.Serviceable() {
@ -469,251 +469,6 @@ func (s *DelegatorSuite) TestSearch() {
	})
}

func (s *DelegatorSuite) TestHybridSearch() {
	s.delegator.Start()
	paramtable.SetNodeID(1)
	s.initSegments()
	s.Run("normal", func() {
		defer func() {
			s.workerManager.ExpectedCalls = nil
		}()
		workers := make(map[int64]*cluster.MockWorker)
		worker1 := &cluster.MockWorker{}
		worker2 := &cluster.MockWorker{}

		workers[1] = worker1
		workers[2] = worker2

		worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
			Run(func(_ context.Context, req *querypb.SearchRequest) {
				s.EqualValues(1, req.Req.GetBase().GetTargetID())
				s.True(req.GetFromShardLeader())
				if req.GetScope() == querypb.DataScope_Streaming {
					s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
					s.ElementsMatch([]int64{1004}, req.GetSegmentIDs())
				}
				if req.GetScope() == querypb.DataScope_Historical {
					s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
					s.ElementsMatch([]int64{1000, 1001}, req.GetSegmentIDs())
				}
			}).Return(&internalpb.SearchResults{}, nil)
		worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
			Run(func(_ context.Context, req *querypb.SearchRequest) {
				s.EqualValues(2, req.Req.GetBase().GetTargetID())
				s.True(req.GetFromShardLeader())
				s.Equal(querypb.DataScope_Historical, req.GetScope())
				s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
				s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
			}).Return(&internalpb.SearchResults{}, nil)

		s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
			return workers[nodeID]
		}, nil)

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		results, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base: commonpbutil.NewMsgBase(),
				Reqs: []*internalpb.SearchRequest{
					{Base: commonpbutil.NewMsgBase()},
					{Base: commonpbutil.NewMsgBase()},
				},
			},
			DmlChannels: []string{s.vchannelName},
		})

		s.NoError(err)
		s.Equal(2, len(results.Results))
	})

	s.Run("partition_not_loaded", func() {
		defer func() {
			s.workerManager.ExpectedCalls = nil
		}()

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base: commonpbutil.NewMsgBase(),
				// partition -1 is not loaded, so an error is expected
				PartitionIDs: []int64{-1},
			},
			DmlChannels: []string{s.vchannelName},
		})

		s.True(errors.Is(err, merr.ErrPartitionNotLoaded))
	})

	s.Run("worker_return_error", func() {
		defer func() {
			s.workerManager.ExpectedCalls = nil
		}()
		workers := make(map[int64]*cluster.MockWorker)
		worker1 := &cluster.MockWorker{}
		worker2 := &cluster.MockWorker{}

		workers[1] = worker1
		workers[2] = worker2

		worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).Return(nil, errors.New("mock error"))
		worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
			Run(func(_ context.Context, req *querypb.SearchRequest) {
				s.EqualValues(2, req.Req.GetBase().GetTargetID())
				s.True(req.GetFromShardLeader())
				s.Equal(querypb.DataScope_Historical, req.GetScope())
				s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
				s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
			}).Return(&internalpb.SearchResults{}, nil)

		s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
			return workers[nodeID]
		}, nil)

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base: commonpbutil.NewMsgBase(),
				Reqs: []*internalpb.SearchRequest{
					{
						Base: commonpbutil.NewMsgBase(),
					},
				},
			},
			DmlChannels: []string{s.vchannelName},
		})

		s.Error(err)
	})

	s.Run("worker_return_failure_code", func() {
		defer func() {
			s.workerManager.ExpectedCalls = nil
		}()
		workers := make(map[int64]*cluster.MockWorker)
		worker1 := &cluster.MockWorker{}
		worker2 := &cluster.MockWorker{}

		workers[1] = worker1
		workers[2] = worker2

		worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).Return(&internalpb.SearchResults{
			Status: &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_UnexpectedError,
				Reason:    "mocked error",
			},
		}, nil)
		worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
			Run(func(_ context.Context, req *querypb.SearchRequest) {
				s.EqualValues(2, req.Req.GetBase().GetTargetID())
				s.True(req.GetFromShardLeader())
				s.Equal(querypb.DataScope_Historical, req.GetScope())
				s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
				s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
			}).Return(&internalpb.SearchResults{}, nil)

		s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
			return workers[nodeID]
		}, nil)

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base: commonpbutil.NewMsgBase(),
				Reqs: []*internalpb.SearchRequest{
					{
						Base: commonpbutil.NewMsgBase(),
					},
				},
			},
			DmlChannels: []string{s.vchannelName},
		})

		s.Error(err)
	})

	s.Run("wrong_channel", func() {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base: commonpbutil.NewMsgBase(),
			},
			DmlChannels: []string{"non_exist_channel"},
		})

		s.Error(err)
	})

	s.Run("wait_tsafe_timeout", func() {
		ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
		defer cancel()

		_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base:               commonpbutil.NewMsgBase(),
				GuaranteeTimestamp: 10100,
			},
			DmlChannels: []string{s.vchannelName},
		})

		s.Error(err)
	})

	s.Run("tsafe_behind_max_lag", func() {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base:               commonpbutil.NewMsgBase(),
				GuaranteeTimestamp: uint64(paramtable.Get().QueryNodeCfg.MaxTimestampLag.GetAsDuration(time.Second)) + 10001,
			},
			DmlChannels: []string{s.vchannelName},
		})

		s.Error(err)
	})

	s.Run("distribution_not_serviceable", func() {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		sd, ok := s.delegator.(*shardDelegator)
		s.Require().True(ok)
		sd.distribution.AddOfflines(1001)

		_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base: commonpbutil.NewMsgBase(),
			},
			DmlChannels: []string{s.vchannelName},
		})

		s.Error(err)
	})

	s.Run("cluster_not_serviceable", func() {
		s.delegator.Close()

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base: commonpbutil.NewMsgBase(),
			},
			DmlChannels: []string{s.vchannelName},
		})

		s.Error(err)
	})
}

func (s *DelegatorSuite) TestQuery() {
	s.delegator.Start()
	paramtable.SetNodeID(1)
@ -253,61 +253,6 @@ func (_c *MockShardDelegator_GetTargetVersion_Call) RunAndReturn(run func() int6
	return _c
}

// HybridSearch provides a mock function with given fields: ctx, req
func (_m *MockShardDelegator) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
	ret := _m.Called(ctx, req)

	var r0 *querypb.HybridSearchResult
	var r1 error
	if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok {
		return rf(ctx, req)
	}
	if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok {
		r0 = rf(ctx, req)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*querypb.HybridSearchResult)
		}
	}

	if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok {
		r1 = rf(ctx, req)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

// MockShardDelegator_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
type MockShardDelegator_HybridSearch_Call struct {
	*mock.Call
}

// HybridSearch is a helper method to define mock.On call
//   - ctx context.Context
//   - req *querypb.HybridSearchRequest
func (_e *MockShardDelegator_Expecter) HybridSearch(ctx interface{}, req interface{}) *MockShardDelegator_HybridSearch_Call {
	return &MockShardDelegator_HybridSearch_Call{Call: _e.mock.On("HybridSearch", ctx, req)}
}

func (_c *MockShardDelegator_HybridSearch_Call) Run(run func(ctx context.Context, req *querypb.HybridSearchRequest)) *MockShardDelegator_HybridSearch_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest))
	})
	return _c
}

func (_c *MockShardDelegator_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockShardDelegator_HybridSearch_Call {
	_c.Call.Return(_a0, _a1)
	return _c
}

func (_c *MockShardDelegator_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockShardDelegator_HybridSearch_Call {
	_c.Call.Return(run)
	return _c
}

// LoadGrowing provides a mock function with given fields: ctx, infos, version
func (_m *MockShardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error {
	ret := _m.Called(ctx, infos, version)
@ -383,7 +383,12 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
		req.GetSegmentIDs(),
	))

	resp, err := segments.ReduceSearchResults(ctx, results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
	var resp *internalpb.SearchResults
	if req.GetReq().GetIsAdvanced() {
		resp, err = segments.ReduceAdvancedSearchResults(ctx, results, req.Req.GetNq())
	} else {
		resp, err = segments.ReduceSearchResults(ctx, results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
	}
	if err != nil {
		return nil, err
	}
@ -402,63 +407,6 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
	return resp, nil
}

func (node *QueryNode) hybridSearchChannel(ctx context.Context, req *querypb.HybridSearchRequest, channel string) (*querypb.HybridSearchResult, error) {
	log := log.Ctx(ctx).With(
		zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
		zap.Int64("collectionID", req.Req.GetCollectionID()),
		zap.String("channel", channel),
	)
	traceID := trace.SpanFromContext(ctx).SpanContext().TraceID()

	if err := node.lifetime.Add(merr.IsHealthy); err != nil {
		return nil, err
	}
	defer node.lifetime.Done()

	var err error
	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.TotalLabel, metrics.Leader).Inc()
	defer func() {
		if err != nil {
			metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.FailLabel, metrics.Leader).Inc()
		}
	}()

	log.Debug("start to search channel")
	searchCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	// From Proxy
	tr := timerecord.NewTimeRecorder("hybridSearchDelegator")
	// get delegator
	sd, ok := node.delegators.Get(channel)
	if !ok {
		err := merr.WrapErrChannelNotFound(channel)
		log.Warn("Query failed, failed to get shard delegator for search", zap.Error(err))
		return nil, err
	}
	// do hybrid search
	result, err := sd.HybridSearch(searchCtx, req)
	if err != nil {
		log.Warn("failed to hybrid search on delegator", zap.Error(err))
		return nil, err
	}

	tr.CtxElapse(ctx, fmt.Sprintf("do search with channel done , traceID = %s, vChannel = %s",
		traceID,
		channel,
	))

	// update metric to prometheus
	latency := tr.ElapseSpan()
	metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.SuccessLabel, metrics.Leader).Inc()
	for _, searchReq := range req.GetReq().GetReqs() {
		metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(node.GetNodeID())).Observe(float64(searchReq.GetNq()))
		metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(node.GetNodeID())).Observe(float64(searchReq.GetTopk()))
	}
	return result, nil
}

func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.GetStatisticsRequest, channel string) (*internalpb.GetStatisticsResponse, error) {
	log := log.Ctx(ctx).With(
		zap.Int64("collectionID", req.Req.GetCollectionID()),
@ -102,6 +102,88 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
	return searchResults, nil
}

func ReduceAdvancedSearchResults(ctx context.Context, results []*internalpb.SearchResults, nq int64) (*internalpb.SearchResults, error) {
	if len(results) == 1 {
		return results[0], nil
	}

	channelsMvcc := make(map[string]uint64)
	for _, r := range results {
		for ch, ts := range r.GetChannelsMvcc() {
			channelsMvcc[ch] = ts
		}
	}
	searchResults := &internalpb.SearchResults{
		IsAdvanced:   true,
		ChannelsMvcc: channelsMvcc,
	}

	for _, result := range results {
		if !result.GetIsAdvanced() {
			continue
		}
		// we just append here, no need to split subResult and reduce
		// defer this reduce to proxy
		searchResults.SubResults = append(searchResults.SubResults, result.GetSubResults()...)
		searchResults.NumQueries = result.GetNumQueries()
	}
	requestCosts := lo.FilterMap(results, func(result *internalpb.SearchResults, _ int) (*internalpb.CostAggregation, bool) {
		if paramtable.Get().QueryNodeCfg.EnableWorkerSQCostMetrics.GetAsBool() {
			return result.GetCostAggregation(), true
		}

		if result.GetBase().GetSourceID() == paramtable.GetNodeID() {
			return result.GetCostAggregation(), true
		}

		return nil, false
	})
	searchResults.CostAggregation = mergeRequestCost(requestCosts)
	return searchResults, nil
}

func MergeToAdvancedResults(ctx context.Context, results []*internalpb.SearchResults) (*internalpb.SearchResults, error) {
	searchResults := &internalpb.SearchResults{
		IsAdvanced: true,
	}

	channelsMvcc := make(map[string]uint64)
	for _, r := range results {
		for ch, ts := range r.GetChannelsMvcc() {
			channelsMvcc[ch] = ts
		}
	}
	searchResults.ChannelsMvcc = channelsMvcc
	for index, result := range results {
		// we just append here, no need to split subResult and reduce
		// defer this reduce to proxy
		subResult := &internalpb.SubSearchResults{
			MetricType:     result.GetMetricType(),
			NumQueries:     result.GetNumQueries(),
			TopK:           result.GetTopK(),
			SlicedBlob:     result.GetSlicedBlob(),
			SlicedNumCount: result.GetSlicedNumCount(),
			SlicedOffset:   result.GetSlicedOffset(),
			ReqIndex:       int64(index),
		}
		searchResults.NumQueries = result.GetNumQueries()
		searchResults.SubResults = append(searchResults.SubResults, subResult)
	}
	requestCosts := lo.FilterMap(results, func(result *internalpb.SearchResults, _ int) (*internalpb.CostAggregation, bool) {
		if paramtable.Get().QueryNodeCfg.EnableWorkerSQCostMetrics.GetAsBool() {
			return result.GetCostAggregation(), true
		}

		if result.GetBase().GetSourceID() == paramtable.GetNodeID() {
			return result.GetCostAggregation(), true
		}

		return nil, false
	})
	searchResults.CostAggregation = mergeRequestCost(requestCosts)
	return searchResults, nil
}

func ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, nq int64, topk int64) (*schemapb.SearchResultData, error) {
	log := log.Ctx(ctx)

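Both helpers above intentionally skip cross-sub-request fusion: ReduceAdvancedSearchResults concatenates SubResults that workers already produced, while MergeToAdvancedResults wraps one already-reduced result per sub-request into a SubSearchResults entry tagged with its ReqIndex, so the proxy can re-associate each blob with its originating sub-request before applying the ranker. A rough usage sketch, where perSubReqResults is a hypothetical slice holding one reduced result per sub-request:

	// Sketch only: perSubReqResults is a hypothetical input slice, one entry per sub-request.
	merged, err := MergeToAdvancedResults(ctx, perSubReqResults)
	if err != nil {
		return nil, err
	}
	for _, sub := range merged.GetSubResults() {
		// ReqIndex identifies which sub-request this blob answers; the proxy
		// uses it to line the blob up with the matching ranker input.
		_ = sub.GetReqIndex()
	}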
@ -751,7 +751,7 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (

	for i, ch := range req.GetDmlChannels() {
		ch := ch
		req := &querypb.SearchRequest{
		channelReq := &querypb.SearchRequest{
			Req:         req.Req,
			DmlChannels: []string{ch},
			SegmentIDs:  req.SegmentIDs,

@ -762,7 +762,7 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (

		i := i
		runningGp.Go(func() error {
			ret, err := node.searchChannel(runningCtx, req, ch)
			ret, err := node.searchChannel(runningCtx, channelReq, ch)
			if err != nil {
				return err
			}

@ -779,12 +779,20 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
	}

	tr.RecordSpan()
	result, err := segments.ReduceSearchResults(ctx, toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
	if err != nil {
		log.Warn("failed to reduce search results", zap.Error(err))
		resp.Status = merr.Status(err)
	var result *internalpb.SearchResults
	var err2 error
	if req.GetReq().GetIsAdvanced() {
		result, err2 = segments.ReduceAdvancedSearchResults(ctx, toReduceResults, req.Req.GetNq())
	} else {
		result, err2 = segments.ReduceSearchResults(ctx, toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
	}

	if err2 != nil {
		log.Warn("failed to reduce search results", zap.Error(err2))
		resp.Status = merr.Status(err2)
		return resp, nil
	}

	reduceLatency := tr.RecordSpan()
	metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards).
		Observe(float64(reduceLatency.Milliseconds()))
@ -800,103 +808,6 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
	return result, nil
}

// HybridSearch performs replica search tasks.
func (node *QueryNode) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
	log := log.Ctx(ctx).With(
		zap.Int64("collectionID", req.GetReq().GetCollectionID()),
		zap.Strings("channels", req.GetDmlChannels()))

	log.Debug("Received HybridSearchRequest",
		zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()),
		zap.Uint64("mvccTimestamp", req.GetReq().GetMvccTimestamp()))

	tr := timerecord.NewTimeRecorderWithTrace(ctx, "HybridSearchRequest")

	if err := node.lifetime.Add(merr.IsHealthy); err != nil {
		return &querypb.HybridSearchResult{
			Base: &commonpb.MsgBase{
				SourceID: node.GetNodeID(),
			},
			Status: merr.Status(err),
		}, nil
	}
	defer node.lifetime.Done()

	resp := &querypb.HybridSearchResult{
		Base: &commonpb.MsgBase{
			SourceID: node.GetNodeID(),
		},
		Status: merr.Success(),
	}
	collection := node.manager.Collection.Get(req.GetReq().GetCollectionID())
	if collection == nil {
		resp.Status = merr.Status(merr.WrapErrCollectionNotFound(req.GetReq().GetCollectionID()))
		return resp, nil
	}

	MultipleResults := make([]*querypb.HybridSearchResult, len(req.GetDmlChannels()))
	runningGp, runningCtx := errgroup.WithContext(ctx)

	for i, ch := range req.GetDmlChannels() {
		ch := ch
		req := &querypb.HybridSearchRequest{
			Req:             req.Req,
			DmlChannels:     []string{ch},
			TotalChannelNum: 1,
		}

		i := i
		runningGp.Go(func() error {
			ret, err := node.hybridSearchChannel(runningCtx, req, ch)
			if err != nil {
				return err
			}
			if err := merr.Error(ret.GetStatus()); err != nil {
				return err
			}
			MultipleResults[i] = ret
			return nil
		})
	}
	if err := runningGp.Wait(); err != nil {
		resp.Status = merr.Status(err)
		return resp, nil
	}

	tr.RecordSpan()
	channelsMvcc := make(map[string]uint64)
	for i, searchReq := range req.GetReq().GetReqs() {
		toReduceResults := make([]*internalpb.SearchResults, len(MultipleResults))
		for index, hs := range MultipleResults {
			toReduceResults[index] = hs.Results[i]
		}
		result, err := segments.ReduceSearchResults(ctx, toReduceResults, searchReq.GetNq(), searchReq.GetTopk(), searchReq.GetMetricType())
		if err != nil {
			log.Warn("failed to reduce search results", zap.Error(err))
			resp.Status = merr.Status(err)
			return resp, nil
		}
		for ch, ts := range result.GetChannelsMvcc() {
			channelsMvcc[ch] = ts
		}
		resp.Results = append(resp.Results, result)
	}
	resp.ChannelsMvcc = channelsMvcc

	reduceLatency := tr.RecordSpan()
	metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.ReduceShards).
		Observe(float64(reduceLatency.Milliseconds()))

	collector.Rate.Add(metricsinfo.SearchThroughput, float64(proto.Size(req)))
	metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(node.GetNodeID(), 10), metrics.HybridSearchLabel).
		Add(float64(proto.Size(req)))

	if resp.GetCostAggregation() != nil {
		resp.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
	}
	return resp, nil
}

// only used for delegator query segments from worker
func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
	resp := &internalpb.RetrieveResults{
@ -1337,47 +1337,6 @@ func (suite *ServiceSuite) TestSearchSegments_Failed() {
	suite.Equal(commonpb.ErrorCode_UnexpectedError, rsp.GetStatus().GetErrorCode())
}

func (suite *ServiceSuite) TestHybridSearch_Concurrent() {
	ctx := context.Background()
	// pre
	suite.TestWatchDmChannelsInt64()
	suite.TestLoadSegments_Int64()

	concurrency := 16
	futures := make([]*conc.Future[*querypb.HybridSearchResult], 0, concurrency)
	for i := 0; i < concurrency; i++ {
		future := conc.Go(func() (*querypb.HybridSearchResult, error) {
			creq1, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType)
			suite.NoError(err)
			creq2, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType)
			suite.NoError(err)
			req := &querypb.HybridSearchRequest{
				Req: &internalpb.HybridSearchRequest{
					Base: &commonpb.MsgBase{
						MsgID:    rand.Int63(),
						TargetID: suite.node.session.ServerID,
					},
					CollectionID:  suite.collectionID,
					PartitionIDs:  suite.partitionIDs,
					MvccTimestamp: typeutil.MaxTimestamp,
					Reqs:          []*internalpb.SearchRequest{creq1, creq2},
				},
				DmlChannels: []string{suite.vchannel},
			}

			return suite.node.HybridSearch(ctx, req)
		})
		futures = append(futures, future)
	}

	err := conc.AwaitAll(futures...)
	suite.NoError(err)

	for i := range futures {
		suite.True(merr.Ok(futures[i].Value().GetStatus()))
	}
}

func (suite *ServiceSuite) TestSearchSegments_Normal() {
	ctx := context.Background()
	// pre
@ -82,10 +82,6 @@ func (m *GrpcQueryNodeClient) Search(ctx context.Context, in *querypb.SearchRequ
	return &internalpb.SearchResults{}, m.Err
}

func (m *GrpcQueryNodeClient) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
	return &querypb.HybridSearchResult{}, m.Err
}

func (m *GrpcQueryNodeClient) SearchSegments(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) {
	return &internalpb.SearchResults{}, m.Err
}

@ -93,10 +93,6 @@ func (qn *qnServerWrapper) Search(ctx context.Context, in *querypb.SearchRequest
	return qn.QueryNode.Search(ctx, in)
}

func (qn *qnServerWrapper) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
	return qn.QueryNode.HybridSearch(ctx, in)
}

func (qn *qnServerWrapper) SearchSegments(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) {
	return qn.QueryNode.SearchSegments(ctx, in)
}
@ -1790,26 +1790,72 @@ def extract_vector_field_name_list(collection_w):

    return vector_name_list


def get_activate_func_from_metric_type(metric_type):
    activate_function = lambda x: x
    if metric_type == "COSINE":
        activate_function = lambda x: (1 + x) * 0.5
    elif metric_type == "IP":
        activate_function = lambda x: 0.5 + math.atan(x) / math.pi
    else:
        activate_function = lambda x: 1.0 - 2 * math.atan(x) / math.pi
    return activate_function

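Restating the lambdas above as formulas: before weighting, each raw distance x is mapped onto a comparable scale depending on the metric type,

    f_COSINE(x) = (1 + x) / 2
    f_IP(x)     = 1/2 + arctan(x) / pi
    f_other(x)  = 1 - 2 * arctan(x) / pi

so the weighted baseline computed further below scores a candidate id as sum_i weights[i] * f_i(distance_i), which is what the tests expect WeightedRanker to produce.
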
def get_hybrid_search_base_results(search_res_dict_array):
def get_hybrid_search_base_results_rrf(search_res_dict_array, round_decimal=-1):
    """
    merge the elements in the dicts array
    search_res_dict_array: the dict array whose elements are to be merged
    return: the sorted ids and scores
    """
    # calculate hybrid search base line

    search_res_dict_merge = {}
    ids_answer = []
    score_answer = []
    for i in range(len(search_res_dict_array) - 1):
        for key in search_res_dict_array[i]:
            if search_res_dict_array[i + 1].get(key):
                search_res_dict_merge[key] = search_res_dict_array[i][key] + search_res_dict_array[i + 1][key]
            else:
                search_res_dict_merge[key] = search_res_dict_array[i][key]
        for key in search_res_dict_array[i + 1]:
            if not search_res_dict_array[i].get(key):
                search_res_dict_merge[key] = search_res_dict_array[i + 1][key]

    for i, result in enumerate(search_res_dict_array, 0):
        for key, distance in result.items():
            search_res_dict_merge[key] = search_res_dict_merge.get(key, 0) + distance

    if round_decimal != -1:
        for k, v in search_res_dict_merge.items():
            multiplier = math.pow(10.0, round_decimal)
            v = math.floor(v * multiplier + 0.5) / multiplier
            search_res_dict_merge[k] = v

    sorted_list = sorted(search_res_dict_merge.items(), key=lambda x: x[1], reverse=True)

    for sort in sorted_list:
        ids_answer.append(int(sort[0]))
        score_answer.append(float(sort[1]))

    return ids_answer, score_answer

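For the RRF baseline, the tests feed this helper dicts already populated with reciprocal-rank scores, 1/(j + 60 + 1) or 1/(j + k + 1) for the zero-based rank j, so the merged score of a candidate id reduces to the usual reciprocal rank fusion form,

    score(id) = sum_i 1 / (k + rank_i(id) + 1)

summed over the sub-searches in which the id appears, with k = 60 in the default RRFRanker() case exercised below.
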
def get_hybrid_search_base_results(search_res_dict_array, weights, metric_types, round_decimal=-1):
    """
    merge the elements in the dicts array
    search_res_dict_array: the dict array whose elements are to be merged
    return: the sorted ids and scores
    """
    # calculate hybrid search base line

    search_res_dict_merge = {}
    ids_answer = []
    score_answer = []

    for i, result in enumerate(search_res_dict_array, 0):
        activate_function = get_activate_func_from_metric_type(metric_types[i])
        for key, distance in result.items():
            activate_distance = activate_function(distance)
            weight = weights[i]
            search_res_dict_merge[key] = search_res_dict_merge.get(key, 0) + activate_function(distance) * weights[i]

    if round_decimal != -1:
        for k, v in search_res_dict_merge.items():
            multiplier = math.pow(10.0, round_decimal)
            v = math.floor(v * multiplier + 0.5) / multiplier
            search_res_dict_merge[k] = v

    sorted_list = sorted(search_res_dict_merge.items(), key=lambda x: x[1], reverse=True)

    for sort in sorted_list:
@ -15,6 +15,7 @@ import decimal
import multiprocessing
import numbers
import random
import math
import numpy
import threading
import pytest
@ -10495,6 +10496,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
        # 3. prepare search params
        req_list = []
        weights = [0.2, 0.3, 0.5]
        metrics = []
        search_res_dict_array = []
        for i in range(len(vector_name_list)):
            # 4. generate search data

@ -10508,22 +10510,22 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                            "expr": "int64 > 0"}
            req = AnnSearchRequest(**search_param)
            req_list.append(req)
            metrics.append("COSINE")
            # 5. search to get the base line of hybrid_search
            search_res = collection_w.search(vectors[:1], vector_name_list[i],
                                             default_search_params, default_limit,
                                             default_search_exp,
                                             offset = offset,
                                             check_task=CheckTasks.check_search_results,
                                             check_items={"nq": 1,
                                                          "ids": insert_ids,
                                                          "limit": default_limit})[0]
            ids = search_res[0].ids
            distance_array = [distance_single * weights[i] for distance_single in search_res[0].distances]
            distance_array = search_res[0].distances
            for j in range(len(ids)):
                search_res_dict[ids[j]] = distance_array[j]
            search_res_dict_array.append(search_res_dict)
        # 6. calculate hybrid search base line
        ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
        ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array, weights, metrics)
        # 7. hybrid search
        hybrid_res = collection_w.hybrid_search(req_list, WeightedRanker(*weights), default_limit,
                                                offset = offset,
@ -11236,7 +11238,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                search_res_dict[ids[j]] = 1/(j + 60 +1)
            search_res_dict_array.append(search_res_dict)
        # 4. calculate hybrid search base line for RRFRanker
        ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
        ids_answer, score_answer = cf.get_hybrid_search_base_results_rrf(search_res_dict_array)
        # 5. hybrid search
        hybrid_search_0 = collection_w.hybrid_search(req_list, RRFRanker(), default_limit,
                                                     check_task=CheckTasks.check_search_results,

@ -11296,7 +11298,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
            # search for get the base line of hybrid_search
            search_res = collection_w.search(vectors[:1], vector_name_list[i],
                                             default_search_params, default_limit,
                                             default_search_exp, offset=offset,
                                             default_search_exp, offset=0,
                                             check_task=CheckTasks.check_search_results,
                                             check_items={"nq": 1,
                                                          "ids": insert_ids,

@ -11306,7 +11308,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                search_res_dict[ids[j]] = 1/(j + k +1)
            search_res_dict_array.append(search_res_dict)
        # 4. calculate hybrid search base line for RRFRanker
        ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
        ids_answer, score_answer = cf.get_hybrid_search_base_results_rrf(search_res_dict_array)
        # 5. hybrid search
        hybrid_res = collection_w.hybrid_search(req_list, RRFRanker(k), default_limit,
                                                offset=offset,

@ -11444,7 +11446,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                search_res_dict[ids[j]] = 1/(j + k +1)
            search_res_dict_array.append(search_res_dict)
        # 4. calculate hybrid search base line for RRFRanker
        ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
        ids_answer, score_answer = cf.get_hybrid_search_base_results_rrf(search_res_dict_array)
        # 5. hybrid search
        hybrid_res = collection_w.hybrid_search(req_list, RRFRanker(k), default_limit,
                                                check_task=CheckTasks.check_search_results,

@ -11453,7 +11455,8 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                                                             "limit": default_limit})[0]
        # 6. compare results through the re-calculated distances
        for i in range(len(score_answer[:default_limit])):
            assert score_answer[i] - hybrid_res[0].distances[i] < hybrid_search_epsilon
            delta = math.fabs(score_answer[i] - hybrid_res[0].distances[i])
            assert delta < hybrid_search_epsilon

    @pytest.mark.tags(CaseLabel.L2)
    @pytest.mark.parametrize("limit", [1, 100, 16384])
@ -11477,6 +11480,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
        search_res_dict_array = []
        if limit > default_nb:
            limit = default_limit
        metrics = []
        for i in range(len(vector_name_list)):
            vectors = [[random.random() for _ in range(default_dim)] for _ in range(1)]
            search_res_dict = {}

@ -11488,6 +11492,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                            "expr": "int64 > 0"}
            req = AnnSearchRequest(**search_param)
            req_list.append(req)
            metrics.append("COSINE")
            # search to get the base line of hybrid_search
            search_res = collection_w.search(vectors[:1], vector_name_list[i],
                                             default_search_params, limit,

@ -11497,12 +11502,12 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                                                          "ids": insert_ids,
                                                          "limit": limit})[0]
            ids = search_res[0].ids
            distance_array = [distance_single * weights[i] for distance_single in search_res[0].distances]
            distance_array = search_res[0].distances
            for j in range(len(ids)):
                search_res_dict[ids[j]] = distance_array[j]
            search_res_dict_array.append(search_res_dict)
        # 4. calculate hybrid search base line
        ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
        ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array, weights, metrics, 5)
        # 5. hybrid search
        hybrid_res = collection_w.hybrid_search(req_list, WeightedRanker(*weights), limit,
                                                round_decimal=5,

@ -11512,7 +11517,8 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                                                             "limit": limit})[0]
        # 6. compare results through the re-calculated distances
        for i in range(len(score_answer[:limit])):
            assert score_answer[i] - hybrid_res[0].distances[i] < hybrid_search_epsilon
            delta = math.fabs(score_answer[i] - hybrid_res[0].distances[i])
            assert delta < hybrid_search_epsilon

    @pytest.mark.tags(CaseLabel.L1)
    def test_hybrid_search_limit_out_of_range_max(self):
@ -11598,6 +11604,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
        req_list = []
        weights = [0.2, 0.3, 0.5]
        search_res_dict_array = []
        metrics = []
        for i in range(len(vector_name_list)):
            vectors = [[random.random() for _ in range(default_dim)] for _ in range(1)]
            search_res_dict = {}

@ -11609,6 +11616,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                            "expr": "int64 > 0"}
            req = AnnSearchRequest(**search_param)
            req_list.append(req)
            metrics.append("COSINE")
            # search to get the base line of hybrid_search
            search_res = collection_w.search(vectors[:1], vector_name_list[i],
                                             default_search_params, default_limit,

@ -11618,12 +11626,12 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                                                          "ids": insert_ids,
                                                          "limit": default_limit})[0]
            ids = search_res[0].ids
            distance_array = [distance_single * weights[i] for distance_single in search_res[0].distances]
            distance_array = search_res[0].distances
            for j in range(len(ids)):
                search_res_dict[ids[j]] = distance_array[j]
            search_res_dict_array.append(search_res_dict)
        # 4. calculate hybrid search base line
        ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
        ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array, weights, metrics)
        # 5. hybrid search
        output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name,
                         default_json_field_name]

@ -11637,7 +11645,8 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                                                             "output_fields": output_fields})[0]
        # 6. compare results through the re-calculated distances
        for i in range(len(score_answer[:default_limit])):
            assert score_answer[i] - hybrid_res[0].distances[i] < hybrid_search_epsilon
            delta = math.fabs(score_answer[i] - hybrid_res[0].distances[i])
            assert delta < hybrid_search_epsilon

    @pytest.mark.tags(CaseLabel.L2)
    def test_hybrid_search_with_output_fields_all_fields_wildcard(self):
@ -11656,6 +11665,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
        req_list = []
        weights = [0.2, 0.3, 0.5]
        search_res_dict_array = []
        metrics = []
        for i in range(len(vector_name_list)):
            vectors = [[random.random() for _ in range(default_dim)] for _ in range(1)]
            search_res_dict = {}

@ -11667,6 +11677,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                            "expr": "int64 > 0"}
            req = AnnSearchRequest(**search_param)
            req_list.append(req)
            metrics.append("COSINE")
            # search to get the base line of hybrid_search
            search_res = collection_w.search(vectors[:1], vector_name_list[i],
                                             default_search_params, default_limit,

@ -11676,12 +11687,12 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                                                          "ids": insert_ids,
                                                          "limit": default_limit})[0]
            ids = search_res[0].ids
            distance_array = [distance_single * weights[i] for distance_single in search_res[0].distances]
            distance_array = search_res[0].distances
            for j in range(len(ids)):
                search_res_dict[ids[j]] = distance_array[j]
            search_res_dict_array.append(search_res_dict)
        # 4. calculate hybrid search base line
        ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
        ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array, weights, metrics)
        # 5. hybrid search
        output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name,
                         default_json_field_name]

@ -11695,7 +11706,8 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                                                             "output_fields": output_fields})[0]
        # 6. compare results through the re-calculated distances
        for i in range(len(score_answer[:default_limit])):
            assert score_answer[i] - hybrid_res[0].distances[i] < hybrid_search_epsilon
            delta = math.fabs(score_answer[i] - hybrid_res[0].distances[i])
            assert delta < hybrid_search_epsilon

    @pytest.mark.tags(CaseLabel.L2)
    @pytest.mark.parametrize("output_fields", [[default_search_field], [default_search_field, default_int64_field_name]])
@ -11717,6 +11729,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
|
|||
req_list = []
|
||||
weights = [0.2, 0.3, 0.5]
|
||||
search_res_dict_array = []
|
||||
metrics = []
|
||||
for i in range(len(vector_name_list)):
|
||||
vectors = [[random.random() for _ in range(default_dim)] for _ in range(1)]
|
||||
search_res_dict = {}
|
||||
|
@ -11728,6 +11741,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                            "expr": "int64 > 0"}
            req = AnnSearchRequest(**search_param)
            req_list.append(req)
+           metrics.append("COSINE")
            # search to get the base line of hybrid_search
            search_res = collection_w.search(vectors[:1], vector_name_list[i],
                                             default_search_params, default_limit,
@ -11742,12 +11756,12 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                search_res.done()
                search_res = search_res.result()
            ids = search_res[0].ids
-           distance_array = [distance_single * weights[i] for distance_single in search_res[0].distances]
+           distance_array = search_res[0].distances
            for j in range(len(ids)):
                search_res_dict[ids[j]] = distance_array[j]
            search_res_dict_array.append(search_res_dict)
        # 4. calculate hybrid search base line
-       ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
+       ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array, weights, metrics)
        # 5. hybrid search
        hybrid_res = collection_w.hybrid_search(req_list, WeightedRanker(*weights), default_limit, _async = _async,
                                                output_fields = output_fields,
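This hunk exercises the asynchronous path: the underlying searches and the hybrid search are issued with _async and the test waits on the returned futures (done()/result()) before reading results. A hedged usage sketch against the pymilvus ORM; the collection name and field names are placeholders, a running Milvus and a populated collection are assumed, and whether the installed pymilvus version accepts _async on hybrid_search should be verified:

    import random
    from pymilvus import AnnSearchRequest, Collection, WeightedRanker

    # Placeholder setup: existing connection, existing collection, illustrative field names.
    collection = Collection("hybrid_demo")
    req_list = [AnnSearchRequest(data=[[random.random() for _ in range(128)]],
                                 anns_field=field,
                                 param={"metric_type": "COSINE"},
                                 limit=10)
                for field in ("vec_a", "vec_b", "vec_c")]

    # _async=True returns a future instead of blocking on the result.
    future = collection.hybrid_search(req_list, WeightedRanker(0.2, 0.3, 0.5),
                                      limit=10, _async=True)
    # ... other work can run here while the server executes the request ...
    result = future.result()          # block until the reply arrives, then use it as usual
    print(result[0].ids, result[0].distances)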
@ -11762,7 +11776,8 @@ class TestCollectionHybridSearchValid(TestcaseBase):
            hybrid_res = hybrid_res.result()
        # 6. compare results through the re-calculated distances
        for i in range(len(score_answer[:default_limit])):
-           assert score_answer[i] - hybrid_res[0].distances[i] < hybrid_search_epsilon
+           delta = math.fabs(score_answer[i] - hybrid_res[0].distances[i])
+           assert delta < hybrid_search_epsilon

    @pytest.mark.tags(CaseLabel.L2)
    @pytest.mark.parametrize("rerank", [RRFRanker(), WeightedRanker(0.1, 0.9, 1)])
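The parametrization above runs the case with both rerankers. WeightedRanker fuses the per-request scores with the supplied weights, whereas RRFRanker uses reciprocal rank fusion, which looks only at each hit's rank within every request. A small standalone sketch of plain RRF scoring for intuition; k = 60 is the conventional default, and the exact normalization Milvus applies server-side is not shown in this diff:

    def rrf_scores(ranked_id_lists, k=60):
        """Reciprocal Rank Fusion: score(id) = sum over result lists of 1 / (k + rank)."""
        scores = {}
        for id_list in ranked_id_lists:
            for rank, pk in enumerate(id_list, start=1):
                scores[pk] = scores.get(pk, 0.0) + 1.0 / (k + rank)
        return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)

    # id 7 is ranked first in both lists, so it gets the highest fused score
    print(rrf_scores([[7, 3, 5], [7, 9, 3]]))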
@ -11825,6 +11840,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
        search_res_dict_array = []
        if limit > default_nb:
            limit = default_limit
+       metrics = []
        for i in range(len(vector_name_list)):
            vectors = [[random.random() for _ in range(default_dim)] for _ in range(1)]
            search_res_dict = {}
@ -11836,6 +11852,7 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                            "expr": "int64 > 0"}
            req = AnnSearchRequest(**search_param)
            req_list.append(req)
+           metrics.append("COSINE")
            # search to get the base line of hybrid_search
            search_res = collection_w.search(vectors[:1], vector_name_list[i],
                                             default_search_params, limit,
@ -11845,12 +11862,12 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                                                          "ids": insert_ids,
                                                          "limit": limit})[0]
            ids = search_res[0].ids
-           distance_array = [distance_single * weights[i] for distance_single in search_res[0].distances]
+           distance_array = search_res[0].distances
            for j in range(len(ids)):
                search_res_dict[ids[j]] = distance_array[j]
            search_res_dict_array.append(search_res_dict)
        # 4. calculate hybrid search base line
-       ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array)
+       ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array, weights, metrics, 5)
        # 5. hybrid search
        hybrid_res = collection_w.hybrid_search(req_list, WeightedRanker(*weights), limit,
                                                round_decimal=5,
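Here the baseline helper is additionally asked to round its scores to 5 decimal places (the trailing 5 argument), matching the round_decimal=5 passed to hybrid_search; if only one side were rounded, an otherwise correct result could drift past the comparison tolerance. A tiny worked example of that effect; the epsilon below is illustrative:

    import math

    raw_baseline = 0.7654321987
    server_score = round(raw_baseline, 5)        # what a round_decimal=5 result looks like: 0.76543

    epsilon = 1e-6                               # illustrative tolerance
    # Comparing the rounded server score against the unrounded baseline can exceed epsilon:
    assert math.fabs(raw_baseline - server_score) > epsilon
    # Rounding the baseline the same way keeps the delta within tolerance:
    assert math.fabs(round(raw_baseline, 5) - server_score) < epsilon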
@ -11860,7 +11877,8 @@ class TestCollectionHybridSearchValid(TestcaseBase):
                                                             "limit": limit})[0]
        # 6. compare results through the re-calculated distances
        for i in range(len(score_answer[:limit])):
-           assert score_answer[i] - hybrid_res[0].distances[i] < hybrid_search_epsilon
+           delta = math.fabs(score_answer[i] - hybrid_res[0].distances[i])
+           assert delta < hybrid_search_epsilon

    @pytest.mark.tags(CaseLabel.L1)
    def test_hybrid_search_result_L2_order(self):
@ -11896,8 +11914,8 @@ class TestCollectionHybridSearchValid(TestcaseBase):
            req_list.append(req)
        # 4. hybrid search
        res = collection_w.hybrid_search(req_list, WeightedRanker(*weights), 10)[0]
-       is_sorted_ascend = lambda lst: all(lst[i] <= lst[i+1] for i in range(len(lst)-1))
-       assert is_sorted_ascend(res[0].distances)
+       is_sorted_descend = lambda lst: all(lst[i] >= lst[i+1] for i in range(len(lst)-1))
+       assert is_sorted_descend(res[0].distances)
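The order check above is flipped from ascending to descending: after the refactor the hybrid search result carries reranked scores ordered from best to worst, so the distances list is expected to be non-increasing. A standalone version of the check:

    def is_sorted_descend(lst):
        """True if the list is non-increasing (each element >= the next one)."""
        return all(lst[i] >= lst[i + 1] for i in range(len(lst) - 1))

    assert is_sorted_descend([0.98, 0.75, 0.75, 0.31])
    assert not is_sorted_descend([0.31, 0.75])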

    @pytest.mark.tags(CaseLabel.L1)
    def test_hybrid_search_result_order(self):