milvus/internal/querynodev2/optimizers/query_hook_test.go

177 lines
4.6 KiB
Go

package optimizers
import (
"context"
"testing"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
// QueryHookSuite is the testify suite covering OptimizeSearchParams.
type QueryHookSuite struct {
	suite.Suite

	// queryHook is the hook each sub-test passes to OptimizeSearchParams;
	// it is nil when no hook is installed.
	queryHook QueryHook
}
// SetupTest is intentionally a no-op: every sub-test installs the hook it
// needs on its own.
func (suite *QueryHookSuite) SetupTest() {}
// TearDownTest drops whatever hook the finished test left on the suite so a
// stale mock is never visible to a later test.
func (suite *QueryHookSuite) TearDownTest() {
	suite.queryHook = nil
}
// TestOptimizeSearchParam drives OptimizeSearchParams through its main paths
// with AutoIndexConfig.EnableOptimize switched on:
//   - normal_run: the hook's topk and search params end up in the returned plan;
//   - no_hook: with a nil hook the plan keeps its original topk and params;
//   - other_plannode: a plan without a vector ANNS node comes back byte-identical;
//   - no_serialized_plan: a request carrying no plan is expected to fail;
//   - hook_run_error: an error returned by the hook is expected to surface.
func (suite *QueryHookSuite) TestOptimizeSearchParam() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	paramtable.Init()
	enableKey := paramtable.Get().AutoIndexConfig.EnableOptimize.Key
	paramtable.Get().Save(enableKey, "true")
	// The parameter table is process-wide; undo the override on exit so it
	// does not leak into other tests of this package.
	defer paramtable.Get().Reset(enableKey)

	// annsPlanBytes serializes the vector-search plan shared by the sub-tests
	// (topk 100, search params {"param": 1}).
	annsPlanBytes := func() []byte {
		plan := &planpb.PlanNode{
			Node: &planpb.PlanNode_VectorAnns{
				VectorAnns: &planpb.VectorANNS{
					QueryInfo: &planpb.QueryInfo{
						Topk:         100,
						SearchParams: `{"param": 1}`,
					},
				},
			},
		}
		bs, err := proto.Marshal(plan)
		suite.Require().NoError(err)
		return bs
	}

	// rewriteParams is the mocked hook behavior: it overrides topk and the
	// search params in the map handed to the hook.
	rewriteParams := func(params map[string]any) {
		params[common.TopKKey] = int64(50)
		params[common.SearchParamKey] = `{"param": 2}`
	}

	suite.Run("normal_run", func() {
		mockHook := NewMockQueryHook(suite.T())
		mockHook.EXPECT().Run(mock.Anything).Run(rewriteParams).Return(nil)
		suite.queryHook = mockHook
		defer func() { suite.queryHook = nil }()

		req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
			Req: &internalpb.SearchRequest{
				SerializedExprPlan: annsPlanBytes(),
			},
			TotalChannelNum: 2,
		}, suite.queryHook, 2)
		suite.NoError(err)
		suite.verifyQueryInfo(req, 50, `{"param": 2}`)
	})

	suite.Run("no_hook", func() {
		suite.queryHook = nil

		req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
			Req: &internalpb.SearchRequest{
				SerializedExprPlan: annsPlanBytes(),
			},
			TotalChannelNum: 2,
		}, suite.queryHook, 2)
		suite.NoError(err)
		suite.verifyQueryInfo(req, 100, `{"param": 1}`)
	})

	suite.Run("other_plannode", func() {
		mockHook := NewMockQueryHook(suite.T())
		// Maybe(): the hook is allowed, but not required, to be invoked here.
		mockHook.EXPECT().Run(mock.Anything).Run(rewriteParams).Return(nil).Maybe()
		suite.queryHook = mockHook
		defer func() { suite.queryHook = nil }()

		plan := &planpb.PlanNode{
			Node: &planpb.PlanNode_Query{},
		}
		bs, err := proto.Marshal(plan)
		suite.Require().NoError(err)

		req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
			Req: &internalpb.SearchRequest{
				SerializedExprPlan: bs,
			},
			TotalChannelNum: 2,
		}, suite.queryHook, 2)
		suite.NoError(err)
		suite.Equal(bs, req.GetReq().GetSerializedExprPlan())
	})

	suite.Run("no_serialized_plan", func() {
		// No expectation is registered on the mock for this case.
		mockHook := NewMockQueryHook(suite.T())
		suite.queryHook = mockHook
		defer func() { suite.queryHook = nil }()

		_, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
			Req:             &internalpb.SearchRequest{},
			TotalChannelNum: 2,
		}, suite.queryHook, 2)
		suite.Error(err)
	})

	suite.Run("hook_run_error", func() {
		mockHook := NewMockQueryHook(suite.T())
		mockHook.EXPECT().Run(mock.Anything).Run(rewriteParams).Return(merr.WrapErrServiceInternal("mocked"))
		suite.queryHook = mockHook
		defer func() { suite.queryHook = nil }()

		_, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
			Req: &internalpb.SearchRequest{
				SerializedExprPlan: annsPlanBytes(),
			},
		}, suite.queryHook, 2)
		suite.Error(err)
	})
}
// verifyQueryInfo decodes the serialized plan carried by req and asserts that
// its vector-ANNS query info holds the expected topk and search params.
// A plan that fails to unmarshal aborts the current test.
func (suite *QueryHookSuite) verifyQueryInfo(req *querypb.SearchRequest, topK int64, param string) {
	decoded := &planpb.PlanNode{}
	suite.Require().NoError(proto.Unmarshal(req.GetReq().GetSerializedExprPlan(), decoded))

	info := decoded.GetVectorAnns().GetQueryInfo()
	suite.Equal(topK, info.GetTopk())
	suite.Equal(param, info.GetSearchParams())
}
// TestOptimizeSearchParam is the go test entry point that runs QueryHookSuite.
func TestOptimizeSearchParam(t *testing.T) {
	suite.Run(t, &QueryHookSuite{})
}