Support Count(*) on querynodev2 (#23321)

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
pull/23358/head
Jiquan Long 2023-04-11 15:40:31 +08:00 committed by GitHub
parent 383915cfcd
commit ad7f3d4c3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 287 additions and 5 deletions

View File

@ -197,7 +197,9 @@ func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryReque
return failRet, nil
}
ret, err := segments.MergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, req.Req.GetLimit(), req.GetReq().GetOutputFieldsId(), collection.Schema())
reducer := segments.CreateInternalReducer(req, collection.Schema())
ret, err := reducer.Reduce(ctx, results)
if err != nil {
failRet.Status.Reason = err.Error()
return failRet, nil
@ -243,7 +245,9 @@ func (node *QueryNode) querySegments(ctx context.Context, req *querypb.QueryRequ
return nil, err
}
reducedResult, err := segments.MergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, results, req.Req.GetLimit(), req.Req.GetOutputFieldsId(), collection.Schema())
reducer := segments.CreateSegCoreReducer(req, collection.Schema())
reducedResult, err := reducer.Reduce(ctx, results)
if err != nil {
return nil, err
}

View File

@ -0,0 +1,39 @@
package segments
import (
"context"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/util/funcutil"
)
type cntReducer struct {
}
func (r *cntReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) {
cnt := int64(0)
for _, res := range results {
c, err := funcutil.CntOfInternalResult(res)
if err != nil {
return nil, err
}
cnt += c
}
return funcutil.WrapCntToInternalResult(cnt), nil
}
type cntReducerSegCore struct {
}
func (r *cntReducerSegCore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) {
cnt := int64(0)
for _, res := range results {
c, err := funcutil.CntOfSegCoreResult(res)
if err != nil {
return nil, err
}
cnt += c
}
return funcutil.WrapCntToSegCoreResult(cnt), nil
}

View File

@ -0,0 +1,98 @@
package segments
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
)
type InternalCntReducerSuite struct {
suite.Suite
r *cntReducer
}
func (suite *InternalCntReducerSuite) SetupTest() {
suite.r = &cntReducer{}
}
func (suite *InternalCntReducerSuite) TearDownTest() {}
func TestInternalCntReducerSuite(t *testing.T) {
suite.Run(t, new(InternalCntReducerSuite))
}
func (suite *InternalCntReducerSuite) TestInvalid() {
results := []*internalpb.RetrieveResults{
{
FieldsData: []*schemapb.FieldData{nil, nil},
},
}
_, err := suite.r.Reduce(context.TODO(), results)
suite.Error(err)
}
func (suite *InternalCntReducerSuite) TestNormalCase() {
results := []*internalpb.RetrieveResults{
funcutil.WrapCntToInternalResult(1),
funcutil.WrapCntToInternalResult(2),
funcutil.WrapCntToInternalResult(3),
funcutil.WrapCntToInternalResult(4),
}
res, err := suite.r.Reduce(context.TODO(), results)
suite.NoError(err)
total, err := funcutil.CntOfInternalResult(res)
suite.NoError(err)
suite.Equal(int64(1+2+3+4), total)
}
type SegCoreCntReducerSuite struct {
suite.Suite
r *cntReducerSegCore
}
func (suite *SegCoreCntReducerSuite) SetupTest() {
suite.r = &cntReducerSegCore{}
}
func (suite *SegCoreCntReducerSuite) TearDownTest() {}
func TestSegCoreCntReducerSuite(t *testing.T) {
suite.Run(t, new(SegCoreCntReducerSuite))
}
func (suite *SegCoreCntReducerSuite) TestInvalid() {
results := []*segcorepb.RetrieveResults{
{
FieldsData: []*schemapb.FieldData{nil, nil},
},
}
_, err := suite.r.Reduce(context.TODO(), results)
suite.Error(err)
}
func (suite *SegCoreCntReducerSuite) TestNormalCase() {
results := []*segcorepb.RetrieveResults{
funcutil.WrapCntToSegCoreResult(1),
funcutil.WrapCntToSegCoreResult(2),
funcutil.WrapCntToSegCoreResult(3),
funcutil.WrapCntToSegCoreResult(4),
}
res, err := suite.r.Reduce(context.TODO(), results)
suite.NoError(err)
total, err := funcutil.CntOfSegCoreResult(res)
suite.NoError(err)
suite.Equal(int64(1+2+3+4), total)
}

View File

@ -0,0 +1,42 @@
package segments
import (
"context"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
)
type defaultLimitReducer struct {
req *querypb.QueryRequest
schema *schemapb.CollectionSchema
}
func (r *defaultLimitReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) {
return mergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), r.schema)
}
func newDefaultLimitReducer(req *querypb.QueryRequest, schema *schemapb.CollectionSchema) *defaultLimitReducer {
return &defaultLimitReducer{
req: req,
schema: schema,
}
}
type defaultLimitReducerSegcore struct {
req *querypb.QueryRequest
schema *schemapb.CollectionSchema
}
func (r *defaultLimitReducerSegcore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) {
return mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, results, r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), r.schema)
}
func newDefaultLimitReducerSegcore(req *querypb.QueryRequest, schema *schemapb.CollectionSchema) *defaultLimitReducerSegcore {
return &defaultLimitReducerSegcore{
req: req,
schema: schema,
}
}

View File

@ -0,0 +1,33 @@
package segments
import (
"context"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
)
type internalReducer interface {
Reduce(context.Context, []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error)
}
func CreateInternalReducer(req *querypb.QueryRequest, schema *schemapb.CollectionSchema) internalReducer {
if req.GetReq().GetIsCount() {
return &cntReducer{}
}
return newDefaultLimitReducer(req, schema)
}
type segCoreReducer interface {
Reduce(context.Context, []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error)
}
func CreateSegCoreReducer(req *querypb.QueryRequest, schema *schemapb.CollectionSchema) segCoreReducer {
if req.GetReq().GetIsCount() {
return &cntReducerSegCore{}
}
return newDefaultLimitReducerSegcore(req, schema)
}

View File

@ -0,0 +1,62 @@
package segments
import (
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
)
type ReducerFactorySuite struct {
suite.Suite
ir internalReducer
sr segCoreReducer
ok bool
}
func (suite *ReducerFactorySuite) SetupTest() {}
func (suite *ReducerFactorySuite) TearDownTest() {}
func TestReducerFactorySuite(t *testing.T) {
suite.Run(t, new(ReducerFactorySuite))
}
func (suite *ReducerFactorySuite) TestCreateInternalReducer() {
req := &querypb.QueryRequest{
Req: &internalpb.RetrieveRequest{
IsCount: false,
},
}
suite.ir = CreateInternalReducer(req, nil)
_, suite.ok = suite.ir.(*defaultLimitReducer)
suite.True(suite.ok)
req.Req.IsCount = true
suite.ir = CreateInternalReducer(req, nil)
_, suite.ok = suite.ir.(*cntReducer)
suite.True(suite.ok)
}
func (suite *ReducerFactorySuite) TestCreateSegCoreReducer() {
req := &querypb.QueryRequest{
Req: &internalpb.RetrieveRequest{
IsCount: false,
},
}
suite.sr = CreateSegCoreReducer(req, nil)
_, suite.ok = suite.sr.(*defaultLimitReducerSegcore)
suite.True(suite.ok)
req.Req.IsCount = true
suite.sr = CreateSegCoreReducer(req, nil)
_, suite.ok = suite.sr.(*cntReducerSegCore)
suite.True(suite.ok)
}

View File

@ -352,7 +352,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
return ret, nil
}
func MergeInternalRetrieveResultsAndFillIfEmpty(
func mergeInternalRetrieveResultsAndFillIfEmpty(
ctx context.Context,
retrieveResults []*internalpb.RetrieveResults,
limit int64,
@ -372,7 +372,7 @@ func MergeInternalRetrieveResultsAndFillIfEmpty(
return mergedResult, nil
}
func MergeSegcoreRetrieveResultsAndFillIfEmpty(
func mergeSegcoreRetrieveResultsAndFillIfEmpty(
ctx context.Context,
retrieveResults []*segcorepb.RetrieveResults,
limit int64,

View File

@ -678,6 +678,7 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
zap.Int64s("segmentIDs", req.GetSegmentIDs()),
zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()),
zap.Uint64("travelTimestamp", req.GetReq().GetTravelTimestamp()),
zap.Bool("isCount", req.GetReq().GetIsCount()),
)
if !node.lifetime.Add(commonpbutil.IsHealthy) {
@ -725,7 +726,10 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
if err := runningGp.Wait(); err != nil {
return WrapRetrieveResult(commonpb.ErrorCode_UnexpectedError, "failed to query channel", err), nil
}
ret, err := segments.MergeInternalRetrieveResult(ctx, toMergeResults, req.GetReq().GetLimit())
reducer := segments.CreateInternalReducer(req, node.manager.Collection.Get(req.GetReq().GetCollectionID()).Schema())
ret, err := reducer.Reduce(ctx, toMergeResults)
if err != nil {
return WrapRetrieveResult(commonpb.ErrorCode_UnexpectedError, "failed to query channel", err), nil
}