// Mirror of https://github.com/milvus-io/milvus.git (177 lines, 4.6 KiB, Go).

package optimizers
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/suite"
|
|
|
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
|
"github.com/milvus-io/milvus/internal/proto/planpb"
|
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
|
"github.com/milvus-io/milvus/pkg/common"
|
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
|
)
|
|
|
|
type QueryHookSuite struct {
|
|
suite.Suite
|
|
queryHook QueryHook
|
|
}
|
|
|
|
func (suite *QueryHookSuite) SetupTest() {
|
|
}
|
|
|
|
func (suite *QueryHookSuite) TearDownTest() {
|
|
suite.queryHook = nil
|
|
}
|
|
|
|
func (suite *QueryHookSuite) TestOptimizeSearchParam() {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
paramtable.Init()
|
|
paramtable.Get().Save(paramtable.Get().AutoIndexConfig.EnableOptimize.Key, "true")
|
|
|
|
suite.Run("normal_run", func() {
|
|
mockHook := NewMockQueryHook(suite.T())
|
|
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
|
|
params[common.TopKKey] = int64(50)
|
|
params[common.SearchParamKey] = `{"param": 2}`
|
|
}).Return(nil)
|
|
suite.queryHook = mockHook
|
|
defer func() { suite.queryHook = nil }()
|
|
|
|
plan := &planpb.PlanNode{
|
|
Node: &planpb.PlanNode_VectorAnns{
|
|
VectorAnns: &planpb.VectorANNS{
|
|
QueryInfo: &planpb.QueryInfo{
|
|
Topk: 100,
|
|
SearchParams: `{"param": 1}`,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
bs, err := proto.Marshal(plan)
|
|
suite.Require().NoError(err)
|
|
|
|
req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
|
|
Req: &internalpb.SearchRequest{
|
|
SerializedExprPlan: bs,
|
|
},
|
|
TotalChannelNum: 2,
|
|
}, suite.queryHook, 2)
|
|
suite.NoError(err)
|
|
suite.verifyQueryInfo(req, 50, `{"param": 2}`)
|
|
})
|
|
|
|
suite.Run("no_hook", func() {
|
|
suite.queryHook = nil
|
|
plan := &planpb.PlanNode{
|
|
Node: &planpb.PlanNode_VectorAnns{
|
|
VectorAnns: &planpb.VectorANNS{
|
|
QueryInfo: &planpb.QueryInfo{
|
|
Topk: 100,
|
|
SearchParams: `{"param": 1}`,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
bs, err := proto.Marshal(plan)
|
|
suite.Require().NoError(err)
|
|
|
|
req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
|
|
Req: &internalpb.SearchRequest{
|
|
SerializedExprPlan: bs,
|
|
},
|
|
TotalChannelNum: 2,
|
|
}, suite.queryHook, 2)
|
|
suite.NoError(err)
|
|
suite.verifyQueryInfo(req, 100, `{"param": 1}`)
|
|
})
|
|
|
|
suite.Run("other_plannode", func() {
|
|
mockHook := NewMockQueryHook(suite.T())
|
|
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
|
|
params[common.TopKKey] = int64(50)
|
|
params[common.SearchParamKey] = `{"param": 2}`
|
|
}).Return(nil).Maybe()
|
|
suite.queryHook = mockHook
|
|
defer func() { suite.queryHook = nil }()
|
|
|
|
plan := &planpb.PlanNode{
|
|
Node: &planpb.PlanNode_Query{},
|
|
}
|
|
bs, err := proto.Marshal(plan)
|
|
suite.Require().NoError(err)
|
|
|
|
req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
|
|
Req: &internalpb.SearchRequest{
|
|
SerializedExprPlan: bs,
|
|
},
|
|
TotalChannelNum: 2,
|
|
}, suite.queryHook, 2)
|
|
suite.NoError(err)
|
|
suite.Equal(bs, req.GetReq().GetSerializedExprPlan())
|
|
})
|
|
|
|
suite.Run("no_serialized_plan", func() {
|
|
mockHook := NewMockQueryHook(suite.T())
|
|
suite.queryHook = mockHook
|
|
defer func() { suite.queryHook = nil }()
|
|
|
|
_, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
|
|
Req: &internalpb.SearchRequest{},
|
|
TotalChannelNum: 2,
|
|
}, suite.queryHook, 2)
|
|
suite.Error(err)
|
|
})
|
|
|
|
suite.Run("hook_run_error", func() {
|
|
mockHook := NewMockQueryHook(suite.T())
|
|
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
|
|
params[common.TopKKey] = int64(50)
|
|
params[common.SearchParamKey] = `{"param": 2}`
|
|
}).Return(merr.WrapErrServiceInternal("mocked"))
|
|
suite.queryHook = mockHook
|
|
defer func() { suite.queryHook = nil }()
|
|
|
|
plan := &planpb.PlanNode{
|
|
Node: &planpb.PlanNode_VectorAnns{
|
|
VectorAnns: &planpb.VectorANNS{
|
|
QueryInfo: &planpb.QueryInfo{
|
|
Topk: 100,
|
|
SearchParams: `{"param": 1}`,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
bs, err := proto.Marshal(plan)
|
|
suite.Require().NoError(err)
|
|
|
|
_, err = OptimizeSearchParams(ctx, &querypb.SearchRequest{
|
|
Req: &internalpb.SearchRequest{
|
|
SerializedExprPlan: bs,
|
|
},
|
|
}, suite.queryHook, 2)
|
|
suite.Error(err)
|
|
})
|
|
}
|
|
|
|
func (suite *QueryHookSuite) verifyQueryInfo(req *querypb.SearchRequest, topK int64, param string) {
|
|
planBytes := req.GetReq().GetSerializedExprPlan()
|
|
|
|
plan := planpb.PlanNode{}
|
|
err := proto.Unmarshal(planBytes, &plan)
|
|
suite.Require().NoError(err)
|
|
|
|
queryInfo := plan.GetVectorAnns().GetQueryInfo()
|
|
suite.Equal(topK, queryInfo.GetTopk())
|
|
suite.Equal(param, queryInfo.GetSearchParams())
|
|
}
|
|
|
|
func TestOptimizeSearchParam(t *testing.T) {
|
|
suite.Run(t, new(QueryHookSuite))
|
|
}
|