From ad7f3d4c3ec458cf6220d39563e90ca876953fa0 Mon Sep 17 00:00:00 2001 From: Jiquan Long Date: Tue, 11 Apr 2023 15:40:31 +0800 Subject: [PATCH] Support Count(*) on querynodev2 (#23321) Signed-off-by: longjiquan --- internal/querynodev2/handlers.go | 8 +- .../querynodev2/segments/count_reducer.go | 39 ++++++++ .../segments/count_reducer_test.go | 98 +++++++++++++++++++ .../segments/default_limit_reducer.go | 42 ++++++++ internal/querynodev2/segments/reducer.go | 33 +++++++ internal/querynodev2/segments/reducer_test.go | 62 ++++++++++++ internal/querynodev2/segments/result.go | 4 +- internal/querynodev2/services.go | 6 +- 8 files changed, 287 insertions(+), 5 deletions(-) create mode 100644 internal/querynodev2/segments/count_reducer.go create mode 100644 internal/querynodev2/segments/count_reducer_test.go create mode 100644 internal/querynodev2/segments/default_limit_reducer.go create mode 100644 internal/querynodev2/segments/reducer.go create mode 100644 internal/querynodev2/segments/reducer_test.go diff --git a/internal/querynodev2/handlers.go b/internal/querynodev2/handlers.go index 147c27eb4d..d775228597 100644 --- a/internal/querynodev2/handlers.go +++ b/internal/querynodev2/handlers.go @@ -197,7 +197,9 @@ func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryReque return failRet, nil } - ret, err := segments.MergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, req.Req.GetLimit(), req.GetReq().GetOutputFieldsId(), collection.Schema()) + reducer := segments.CreateInternalReducer(req, collection.Schema()) + + ret, err := reducer.Reduce(ctx, results) if err != nil { failRet.Status.Reason = err.Error() return failRet, nil @@ -243,7 +245,9 @@ func (node *QueryNode) querySegments(ctx context.Context, req *querypb.QueryRequ return nil, err } - reducedResult, err := segments.MergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, results, req.Req.GetLimit(), req.Req.GetOutputFieldsId(), collection.Schema()) + reducer := segments.CreateSegCoreReducer(req, collection.Schema()) + + reducedResult, err := reducer.Reduce(ctx, results) if err != nil { return nil, err } diff --git a/internal/querynodev2/segments/count_reducer.go b/internal/querynodev2/segments/count_reducer.go new file mode 100644 index 0000000000..758030b34e --- /dev/null +++ b/internal/querynodev2/segments/count_reducer.go @@ -0,0 +1,39 @@ +package segments + +import ( + "context" + + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/segcorepb" + "github.com/milvus-io/milvus/internal/util/funcutil" +) + +type cntReducer struct { +} + +func (r *cntReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) { + cnt := int64(0) + for _, res := range results { + c, err := funcutil.CntOfInternalResult(res) + if err != nil { + return nil, err + } + cnt += c + } + return funcutil.WrapCntToInternalResult(cnt), nil +} + +type cntReducerSegCore struct { +} + +func (r *cntReducerSegCore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) { + cnt := int64(0) + for _, res := range results { + c, err := funcutil.CntOfSegCoreResult(res) + if err != nil { + return nil, err + } + cnt += c + } + return funcutil.WrapCntToSegCoreResult(cnt), nil +} diff --git a/internal/querynodev2/segments/count_reducer_test.go b/internal/querynodev2/segments/count_reducer_test.go new file mode 100644 index 0000000000..99e58ce526 --- /dev/null +++ b/internal/querynodev2/segments/count_reducer_test.go @@ -0,0 +1,98 @@ +package segments + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/internal/proto/segcorepb" + "github.com/milvus-io/milvus/internal/util/funcutil" + + "github.com/milvus-io/milvus-proto/go-api/schemapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" +) + +type InternalCntReducerSuite struct { + suite.Suite + r *cntReducer +} + +func (suite *InternalCntReducerSuite) SetupTest() { + suite.r = &cntReducer{} +} + +func (suite *InternalCntReducerSuite) TearDownTest() {} + +func TestInternalCntReducerSuite(t *testing.T) { + suite.Run(t, new(InternalCntReducerSuite)) +} + +func (suite *InternalCntReducerSuite) TestInvalid() { + results := []*internalpb.RetrieveResults{ + { + FieldsData: []*schemapb.FieldData{nil, nil}, + }, + } + + _, err := suite.r.Reduce(context.TODO(), results) + suite.Error(err) +} + +func (suite *InternalCntReducerSuite) TestNormalCase() { + results := []*internalpb.RetrieveResults{ + funcutil.WrapCntToInternalResult(1), + funcutil.WrapCntToInternalResult(2), + funcutil.WrapCntToInternalResult(3), + funcutil.WrapCntToInternalResult(4), + } + + res, err := suite.r.Reduce(context.TODO(), results) + suite.NoError(err) + + total, err := funcutil.CntOfInternalResult(res) + suite.NoError(err) + suite.Equal(int64(1+2+3+4), total) +} + +type SegCoreCntReducerSuite struct { + suite.Suite + r *cntReducerSegCore +} + +func (suite *SegCoreCntReducerSuite) SetupTest() { + suite.r = &cntReducerSegCore{} +} + +func (suite *SegCoreCntReducerSuite) TearDownTest() {} + +func TestSegCoreCntReducerSuite(t *testing.T) { + suite.Run(t, new(SegCoreCntReducerSuite)) +} + +func (suite *SegCoreCntReducerSuite) TestInvalid() { + results := []*segcorepb.RetrieveResults{ + { + FieldsData: []*schemapb.FieldData{nil, nil}, + }, + } + + _, err := suite.r.Reduce(context.TODO(), results) + suite.Error(err) +} + +func (suite *SegCoreCntReducerSuite) TestNormalCase() { + results := []*segcorepb.RetrieveResults{ + funcutil.WrapCntToSegCoreResult(1), + funcutil.WrapCntToSegCoreResult(2), + funcutil.WrapCntToSegCoreResult(3), + funcutil.WrapCntToSegCoreResult(4), + } + + res, err := suite.r.Reduce(context.TODO(), results) + suite.NoError(err) + + total, err := funcutil.CntOfSegCoreResult(res) + suite.NoError(err) + suite.Equal(int64(1+2+3+4), total) +} diff --git a/internal/querynodev2/segments/default_limit_reducer.go b/internal/querynodev2/segments/default_limit_reducer.go new file mode 100644 index 0000000000..a0ba17a002 --- /dev/null +++ b/internal/querynodev2/segments/default_limit_reducer.go @@ -0,0 +1,42 @@ +package segments + +import ( + "context" + + "github.com/milvus-io/milvus-proto/go-api/schemapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/segcorepb" +) + +type defaultLimitReducer struct { + req *querypb.QueryRequest + schema *schemapb.CollectionSchema +} + +func (r *defaultLimitReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) { + return mergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), r.schema) +} + +func newDefaultLimitReducer(req *querypb.QueryRequest, schema *schemapb.CollectionSchema) *defaultLimitReducer { + return &defaultLimitReducer{ + req: req, + schema: schema, + } +} + +type defaultLimitReducerSegcore struct { + req *querypb.QueryRequest + schema *schemapb.CollectionSchema +} + +func (r *defaultLimitReducerSegcore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) { + return mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, results, r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), r.schema) +} + +func newDefaultLimitReducerSegcore(req *querypb.QueryRequest, schema *schemapb.CollectionSchema) *defaultLimitReducerSegcore { + return &defaultLimitReducerSegcore{ + req: req, + schema: schema, + } +} diff --git a/internal/querynodev2/segments/reducer.go b/internal/querynodev2/segments/reducer.go new file mode 100644 index 0000000000..102b554d12 --- /dev/null +++ b/internal/querynodev2/segments/reducer.go @@ -0,0 +1,33 @@ +package segments + +import ( + "context" + + "github.com/milvus-io/milvus/internal/proto/segcorepb" + + "github.com/milvus-io/milvus-proto/go-api/schemapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/querypb" +) + +type internalReducer interface { + Reduce(context.Context, []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) +} + +func CreateInternalReducer(req *querypb.QueryRequest, schema *schemapb.CollectionSchema) internalReducer { + if req.GetReq().GetIsCount() { + return &cntReducer{} + } + return newDefaultLimitReducer(req, schema) +} + +type segCoreReducer interface { + Reduce(context.Context, []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) +} + +func CreateSegCoreReducer(req *querypb.QueryRequest, schema *schemapb.CollectionSchema) segCoreReducer { + if req.GetReq().GetIsCount() { + return &cntReducerSegCore{} + } + return newDefaultLimitReducerSegcore(req, schema) +} diff --git a/internal/querynodev2/segments/reducer_test.go b/internal/querynodev2/segments/reducer_test.go new file mode 100644 index 0000000000..832b9c94b9 --- /dev/null +++ b/internal/querynodev2/segments/reducer_test.go @@ -0,0 +1,62 @@ +package segments + +import ( + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/querypb" +) + +type ReducerFactorySuite struct { + suite.Suite + ir internalReducer + sr segCoreReducer + ok bool +} + +func (suite *ReducerFactorySuite) SetupTest() {} + +func (suite *ReducerFactorySuite) TearDownTest() {} + +func TestReducerFactorySuite(t *testing.T) { + suite.Run(t, new(ReducerFactorySuite)) +} + +func (suite *ReducerFactorySuite) TestCreateInternalReducer() { + req := &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + IsCount: false, + }, + } + + suite.ir = CreateInternalReducer(req, nil) + _, suite.ok = suite.ir.(*defaultLimitReducer) + suite.True(suite.ok) + + req.Req.IsCount = true + + suite.ir = CreateInternalReducer(req, nil) + _, suite.ok = suite.ir.(*cntReducer) + suite.True(suite.ok) +} + +func (suite *ReducerFactorySuite) TestCreateSegCoreReducer() { + + req := &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + IsCount: false, + }, + } + + suite.sr = CreateSegCoreReducer(req, nil) + _, suite.ok = suite.sr.(*defaultLimitReducerSegcore) + suite.True(suite.ok) + + req.Req.IsCount = true + + suite.sr = CreateSegCoreReducer(req, nil) + _, suite.ok = suite.sr.(*cntReducerSegCore) + suite.True(suite.ok) +} diff --git a/internal/querynodev2/segments/result.go b/internal/querynodev2/segments/result.go index 7adb8b871d..c1e0baf6a9 100644 --- a/internal/querynodev2/segments/result.go +++ b/internal/querynodev2/segments/result.go @@ -352,7 +352,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore return ret, nil } -func MergeInternalRetrieveResultsAndFillIfEmpty( +func mergeInternalRetrieveResultsAndFillIfEmpty( ctx context.Context, retrieveResults []*internalpb.RetrieveResults, limit int64, @@ -372,7 +372,7 @@ func MergeInternalRetrieveResultsAndFillIfEmpty( return mergedResult, nil } -func MergeSegcoreRetrieveResultsAndFillIfEmpty( +func mergeSegcoreRetrieveResultsAndFillIfEmpty( ctx context.Context, retrieveResults []*segcorepb.RetrieveResults, limit int64, diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index 73b71f005e..1fea2253c2 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -678,6 +678,7 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i zap.Int64s("segmentIDs", req.GetSegmentIDs()), zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()), zap.Uint64("travelTimestamp", req.GetReq().GetTravelTimestamp()), + zap.Bool("isCount", req.GetReq().GetIsCount()), ) if !node.lifetime.Add(commonpbutil.IsHealthy) { @@ -725,7 +726,10 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i if err := runningGp.Wait(); err != nil { return WrapRetrieveResult(commonpb.ErrorCode_UnexpectedError, "failed to query channel", err), nil } - ret, err := segments.MergeInternalRetrieveResult(ctx, toMergeResults, req.GetReq().GetLimit()) + + reducer := segments.CreateInternalReducer(req, node.manager.Collection.Get(req.GetReq().GetCollectionID()).Schema()) + + ret, err := reducer.Reduce(ctx, toMergeResults) if err != nil { return WrapRetrieveResult(commonpb.ErrorCode_UnexpectedError, "failed to query channel", err), nil }