enhance: reconstruct the scalar part of the segment-pruner code (#30376) (#34346)

related: #30376
1. support more complex exprs
2. add more unit tests for unrelated fields
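
For illustration, a minimal sketch of the new pruning API (names are taken from this change; segmentStats and segmentIDs are assumed to be collected per partition, as PruneSegments now does):

	// prune segments whose clustering-key min/max cannot intersect "age >= 500 and age <= 550"
	lower := storage.NewInt64FieldValue(500)
	upper := storage.NewInt64FieldValue(550)
	expr := NewBinaryRangeExpr(lower, upper, true, true)
	filteredSegments := make(map[UniqueID]struct{})
	PruneByScalarField(expr, segmentStats, segmentIDs, filteredSegments)
	// IDs in filteredSegments can now be dropped from the sealed segment list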

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
Chun Han 2024-07-04 04:36:09 -04:00 committed by GitHub
parent 0b404bff22
commit fcafdb6d5f
4 changed files with 571 additions and 27 deletions


@@ -0,0 +1,264 @@
package delegator

import (
	"github.com/bits-and-blooms/bitset"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/proto/planpb"
	"github.com/milvus-io/milvus/internal/storage"
)

type EvalCtx struct {
	segmentStats  []storage.SegmentStats
	size          uint
	allTrueBitSet *bitset.BitSet
}

func NewEvalCtx(segStats []storage.SegmentStats, size uint, allTrueBst *bitset.BitSet) *EvalCtx {
	return &EvalCtx{segStats, size, allTrueBst}
}

type Expr interface {
	Inputs() []Expr
	Eval(evalCtx *EvalCtx) *bitset.BitSet
}
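
// PruneByScalarField evaluates expr against per-segment min/max statistics:
// Eval sets a bit for every segment that may contain matching rows, so flipping
// the result yields the segments that provably cannot match, which are then
// recorded in filteredSegments.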
func PruneByScalarField(expr Expr, segmentStats []storage.SegmentStats, segmentIDs []UniqueID, filteredSegments map[UniqueID]struct{}) {
	if expr != nil {
		size := uint(len(segmentIDs))
		allTrueBst := bitset.New(size)
		allTrueBst.FlipRange(0, size)
		resBst := expr.Eval(NewEvalCtx(segmentStats, size, allTrueBst))
		resBst.FlipRange(0, resBst.Len())
		for i, e := resBst.NextSet(0); e; i, e = resBst.NextSet(i + 1) {
			filteredSegments[segmentIDs[i]] = struct{}{}
		}
	}
	// for a nil input expr, nothing happens and no segment is filtered
}

type LogicalBinaryExpr struct {
	left  Expr
	right Expr
	op    planpb.BinaryExpr_BinaryOp
}

func NewLogicalBinaryExpr(l Expr, r Expr, op planpb.BinaryExpr_BinaryOp) *LogicalBinaryExpr {
	return &LogicalBinaryExpr{left: l, right: r, op: op}
}

func (lbe *LogicalBinaryExpr) Eval(evalCtx *EvalCtx) *bitset.BitSet {
	// 1. eval left
	leftExpr := lbe.Inputs()[0]
	var leftRes *bitset.BitSet
	if leftExpr != nil {
		leftRes = leftExpr.Eval(evalCtx)
	}
	// 2. eval right
	rightExpr := lbe.Inputs()[1]
	var rightRes *bitset.BitSet
	if rightExpr != nil {
		rightRes = rightExpr.Eval(evalCtx)
	}
	// 3. treat a nil (unprunable) sub-expr as matching all segments
	if leftRes == nil {
		leftRes = evalCtx.allTrueBitSet
	}
	if rightRes == nil {
		rightRes = evalCtx.allTrueBitSet
	}
	// 4. and/or the left and right results
	if lbe.op == planpb.BinaryExpr_LogicalAnd {
		leftRes.InPlaceIntersection(rightRes)
	} else if lbe.op == planpb.BinaryExpr_LogicalOr {
		leftRes.InPlaceUnion(rightRes)
	}
	return leftRes
}

func (lbe *LogicalBinaryExpr) Inputs() []Expr {
	return []Expr{lbe.left, lbe.right}
}

type PhysicalExpr struct {
	Expr
}

func (lbe *PhysicalExpr) Inputs() []Expr {
	return nil
}

type BinaryRangeExpr struct {
	PhysicalExpr
	lowerVal     storage.ScalarFieldValue
	upperVal     storage.ScalarFieldValue
	includeLower bool
	includeUpper bool
}

func NewBinaryRangeExpr(lower storage.ScalarFieldValue,
	upper storage.ScalarFieldValue, inLower bool, inUpper bool,
) *BinaryRangeExpr {
	return &BinaryRangeExpr{lowerVal: lower, upperVal: upper, includeLower: inLower, includeUpper: inUpper}
}
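
// Eval for BinaryRangeExpr keeps a segment when its [min, max] statistics
// overlap the query range [lowerVal, upperVal]; the inclusivity flags are not
// consulted here, so the check is conservative: it may keep an extra segment
// but never drops a matching one.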
func (bre *BinaryRangeExpr) Eval(evalCtx *EvalCtx) *bitset.BitSet {
	localBst := bitset.New(evalCtx.size)
	for i, segStat := range evalCtx.segmentStats {
		fieldStat := &(segStat.FieldStats[0])
		idx := uint(i)
		commonMin := storage.MaxScalar(fieldStat.Min, bre.lowerVal)
		commonMax := storage.MinScalar(fieldStat.Max, bre.upperVal)
		if !commonMin.GT(commonMax) {
			localBst.Set(idx)
		}
	}
	return localBst
}

type UnaryRangeExpr struct {
	PhysicalExpr
	op  planpb.OpType
	val storage.ScalarFieldValue
}

func NewUnaryRangeExpr(value storage.ScalarFieldValue, op planpb.OpType) *UnaryRangeExpr {
	return &UnaryRangeExpr{op: op, val: value}
}
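
// Eval for UnaryRangeExpr keeps a segment unless its min/max statistics prove
// that no row can satisfy the comparison; unsupported operators fall back to
// the all-true bitset so that no segment gets pruned.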
func (ure *UnaryRangeExpr) Eval(evalCtx *EvalCtx) *bitset.BitSet {
	localBst := bitset.New(evalCtx.size)
	for i, segStat := range evalCtx.segmentStats {
		fieldStat := &(segStat.FieldStats[0])
		idx := uint(i)
		val := ure.val
		switch ure.op {
		case planpb.OpType_Equal:
			if val.GE(fieldStat.Min) && val.LE(fieldStat.Max) {
				localBst.Set(idx)
			}
		case planpb.OpType_LessEqual:
			if !val.LT(fieldStat.Min) {
				localBst.Set(idx)
			}
		case planpb.OpType_LessThan:
			if !val.LE(fieldStat.Min) {
				localBst.Set(idx)
			}
		case planpb.OpType_GreaterEqual:
			if !val.GT(fieldStat.Max) {
				localBst.Set(idx)
			}
		case planpb.OpType_GreaterThan:
			if !val.GE(fieldStat.Max) {
				localBst.Set(idx)
			}
		default:
			return evalCtx.allTrueBitSet
		}
	}
	return localBst
}

type TermExpr struct {
	PhysicalExpr
	vals []storage.ScalarFieldValue
}

func NewTermExpr(values []storage.ScalarFieldValue) *TermExpr {
	return &TermExpr{vals: values}
}

func (te *TermExpr) Eval(evalCtx *EvalCtx) *bitset.BitSet {
	localBst := bitset.New(evalCtx.size)
	for i, segStat := range evalCtx.segmentStats {
		fieldStat := &(segStat.FieldStats[0])
		for _, val := range te.vals {
			if val.GT(fieldStat.Max) {
				// the values inside the expr were sorted before execution, so once the
				// current value exceeds the segment max there is no need to iterate
				// over the remaining values
				break
			}
			if fieldStat.Min.LE(val) && val.LE(fieldStat.Max) {
				localBst.Set(uint(i))
				break
			}
		}
	}
	return localBst
}

type ParseContext struct {
	keyFieldIDToPrune FieldID
	dataType          schemapb.DataType
}

func NewParseContext(keyField FieldID, dType schemapb.DataType) *ParseContext {
	return &ParseContext{keyField, dType}
}
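
// ParseExpr recursively translates a planpb expr into a prunable Expr tree;
// sub-exprs that do not reference the clustering-key field, or that cannot be
// pruned by min/max stats (NOT, not-equal), come back as nil and are treated
// as matching all segments.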
func ParseExpr(exprPb *planpb.Expr, parseCtx *ParseContext) Expr {
	var res Expr
	switch exp := exprPb.GetExpr().(type) {
	case *planpb.Expr_BinaryExpr:
		res = ParseLogicalBinaryExpr(exp.BinaryExpr, parseCtx)
	case *planpb.Expr_UnaryExpr:
		res = ParseLogicalUnaryExpr(exp.UnaryExpr, parseCtx)
	case *planpb.Expr_BinaryRangeExpr:
		res = ParseBinaryRangeExpr(exp.BinaryRangeExpr, parseCtx)
	case *planpb.Expr_UnaryRangeExpr:
		res = ParseUnaryRangeExpr(exp.UnaryRangeExpr, parseCtx)
	case *planpb.Expr_TermExpr:
		res = ParseTermExpr(exp.TermExpr, parseCtx)
	}
	return res
}

func ParseLogicalBinaryExpr(exprPb *planpb.BinaryExpr, parseCtx *ParseContext) Expr {
	leftExpr := ParseExpr(exprPb.Left, parseCtx)
	rightExpr := ParseExpr(exprPb.Right, parseCtx)
	return NewLogicalBinaryExpr(leftExpr, rightExpr, exprPb.GetOp())
}

func ParseLogicalUnaryExpr(exprPb *planpb.UnaryExpr, parseCtx *ParseContext) Expr {
	// NOT exprs are not handled for now; this branch is kept for structural completeness
	return nil
}

func ParseBinaryRangeExpr(exprPb *planpb.BinaryRangeExpr, parseCtx *ParseContext) Expr {
	if exprPb.GetColumnInfo().GetFieldId() != parseCtx.keyFieldIDToPrune {
		return nil
	}
	lower := storage.NewScalarFieldValueFromGenericValue(parseCtx.dataType, exprPb.GetLowerValue())
	upper := storage.NewScalarFieldValueFromGenericValue(parseCtx.dataType, exprPb.GetUpperValue())
	return NewBinaryRangeExpr(lower, upper, exprPb.LowerInclusive, exprPb.UpperInclusive)
}

func ParseUnaryRangeExpr(exprPb *planpb.UnaryRangeExpr, parseCtx *ParseContext) Expr {
	if exprPb.GetColumnInfo().GetFieldId() != parseCtx.keyFieldIDToPrune {
		return nil
	}
	if exprPb.GetOp() == planpb.OpType_NotEqual {
		// segment pruning based on min-max stats cannot support not-equal semantics
		return nil
	}
	innerVal := storage.NewScalarFieldValueFromGenericValue(parseCtx.dataType, exprPb.GetValue())
	return NewUnaryRangeExpr(innerVal, exprPb.GetOp())
}

func ParseTermExpr(exprPb *planpb.TermExpr, parseCtx *ParseContext) Expr {
	if exprPb.GetColumnInfo().GetFieldId() != parseCtx.keyFieldIDToPrune {
		return nil
	}
	scalarVals := make([]storage.ScalarFieldValue, 0)
	for _, val := range exprPb.GetValues() {
		scalarVals = append(scalarVals, storage.NewScalarFieldValueFromGenericValue(parseCtx.dataType, val))
	}
	return NewTermExpr(scalarVals)
}


@@ -42,8 +42,7 @@ func PruneSegments(ctx context.Context,
) {
_, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "segmentPrune")
defer span.End()
// 1. calculate filtered segments
filteredSegments := make(map[UniqueID]struct{}, 0)
// 1. select collection, partitions and expr
clusteringKeyField := clustering.GetClusteringKeyField(schema)
if clusteringKeyField == nil {
// no need to prune
@@ -52,14 +51,18 @@
tr := timerecord.NewTimeRecorder("PruneSegments")
var collectionID int64
var expr []byte
var partitionIDs []int64
if searchReq != nil {
collectionID = searchReq.CollectionID
expr = searchReq.GetSerializedExprPlan()
partitionIDs = searchReq.GetPartitionIDs()
} else {
collectionID = queryReq.CollectionID
expr = queryReq.GetSerializedExprPlan()
partitionIDs = queryReq.GetPartitionIDs()
}
filteredSegments := make(map[UniqueID]struct{}, 0)
// currently we only prune based on one column
if typeutil.IsVectorType(clusteringKeyField.GetDataType()) {
// parse searched vectors
@@ -89,18 +92,26 @@
log.Ctx(ctx).Error("failed to unmarshall serialized expr from bytes, failed the operation")
return
}
expr, err := exprutil.ParseExprFromPlan(&plan)
exprPb, err := exprutil.ParseExprFromPlan(&plan)
if err != nil {
log.Ctx(ctx).Error("failed to parse expr from plan, failed the operation")
return
}
targetRanges, matchALL := exprutil.ParseRanges(expr, exprutil.ClusteringKey)
if matchALL || targetRanges == nil {
return
}
for _, partStats := range partitionStats {
FilterSegmentsOnScalarField(partStats, targetRanges, clusteringKeyField, filteredSegments)
// 1. parse expr for prune
expr := ParseExpr(exprPb, NewParseContext(clusteringKeyField.GetFieldID(), clusteringKeyField.GetDataType()))
// 2. prune segments by scalar field
targetSegmentStats := make([]storage.SegmentStats, 0, 32)
targetSegmentIDs := make([]int64, 0, 32)
for _, partID := range partitionIDs {
partStats := partitionStats[partID]
for segID, segStat := range partStats.SegmentStats {
targetSegmentIDs = append(targetSegmentIDs, segID)
targetSegmentStats = append(targetSegmentStats, segStat)
}
}
PruneByScalarField(expr, targetSegmentStats, targetSegmentIDs, filteredSegments)
}
// 2. remove filtered segments from sealed segment list


@@ -248,6 +248,58 @@ func (sps *SegmentPrunerSuite) TestPruneSegmentsByScalarIntField() {
sps.Equal(2, len(testSegments[0].Segments))
sps.Equal(0, len(testSegments[1].Segments))
}
{
// test for not-equal operator, which is unsupported
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "age!=156"
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
sps.Equal(2, len(testSegments[0].Segments))
sps.Equal(2, len(testSegments[1].Segments))
}
{
// test for term operator
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "age in [100,200,300]"
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
sps.Equal(2, len(testSegments[0].Segments))
sps.Equal(0, len(testSegments[1].Segments))
}
{
// test for the not operator; segment pruning doesn't support it,
// so all segments are expected to be kept here
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "age not in [100,200,300]"
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
sps.Equal(2, len(testSegments[0].Segments))
sps.Equal(2, len(testSegments[1].Segments))
}
{
// test for a one-sided range expr
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
@@ -266,24 +318,6 @@ func (sps *SegmentPrunerSuite) TestPruneSegmentsByScalarIntField() {
sps.Equal(2, len(testSegments[1].Segments))
}
{
// test for a contradictory binary range
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "age>=700 and age<=500"
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
sps.Equal(2, len(testSegments[0].Segments))
sps.Equal(2, len(testSegments[1].Segments))
}
{
// test for a satisfiable binary range
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "age>=500 and age<=550"
@@ -299,6 +333,209 @@ func (sps *SegmentPrunerSuite) TestPruneSegmentsByScalarIntField() {
sps.Equal(0, len(testSegments[0].Segments))
sps.Equal(1, len(testSegments[1].Segments))
}
{
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "500<=age<=550"
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
sps.Equal(0, len(testSegments[0].Segments))
sps.Equal(1, len(testSegments[1].Segments))
}
{
// test for multiple ranges connected with or operator
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "(age>=500 and age<=550) or (age>800 and age<950) or (age>300 and age<330)"
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
sps.Equal(1, len(testSegments[0].Segments))
sps.Equal(2, len(testSegments[1].Segments))
}
{
// test for multiple ranges connected with or operator
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "(age>=500 and age<=550) or (age>800 and age<950) or (age>300 and age<330) or age < 150"
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
sps.Equal(2, len(testSegments[0].Segments))
sps.Equal(2, len(testSegments[1].Segments))
}
{
// test for multiple ranges connected with or operator
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "age > 600 or age < 300"
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
sps.Equal(2, len(testSegments[0].Segments))
sps.Equal(2, len(testSegments[1].Segments))
}
{
// test for multiple ranges connected with or operator
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "age > 600 or age < 30"
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
sps.Equal(0, len(testSegments[0].Segments))
sps.Equal(2, len(testSegments[1].Segments))
}
}
func (sps *SegmentPrunerSuite) TestPruneSegmentsWithUnrelatedField() {
sps.SetupForClustering("age", schemapb.DataType_Int32)
paramtable.Init()
targetPartitions := make([]UniqueID, 0)
targetPartitions = append(targetPartitions, sps.targetPartition)
{
// test for unrelated fields
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "age>=500 and age<=550 and info != 'xxx'"
// info is not the clustering-key field, so and-ing one more condition on it does not influence the pruned result
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
sps.Equal(0, len(testSegments[0].Segments))
sps.Equal(1, len(testSegments[1].Segments))
}
{
// test for unrelated fields
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "age>=500 and info != 'xxx' and age<=550"
// info is not the clustering-key field, so and-ing one more condition on it does not influence the pruned result
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
sps.Equal(0, len(testSegments[0].Segments))
sps.Equal(1, len(testSegments[1].Segments))
}
{
// test for unrelated fields
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "age>=500 and age<=550 or info != 'xxx'"
// info is not the clustering-key field, so or-ing one more condition on it makes it impossible to prune any segment
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
sps.Equal(2, len(testSegments[0].Segments))
sps.Equal(2, len(testSegments[1].Segments))
}
{
// test for multiple ranges + unrelated field + or connector
// since info is not the clustering key and is or-connected, pruning cannot work and all segments have to be searched in this case
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "(age>=500 and age<=550) or info != 'xxx' or (age>800 and age<950) or (age>300 and age<330) or age < 50"
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
sps.Equal(2, len(testSegments[0].Segments))
sps.Equal(2, len(testSegments[1].Segments))
}
{
// test for multiple ranges + unrelated field + and connector
// since the info condition is and-connected, the remaining prune conditions on the clustering key still work
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "(age>=500 and age<=550) and info != 'xxx' or (age>800 and age<950) or (age>300 and age<330) or age < 50"
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
sps.Equal(1, len(testSegments[0].Segments))
sps.Equal(2, len(testSegments[1].Segments))
}
{
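// a term expr on a field other than the clustering key cannot prune any segment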
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "info in ['aa','bb','cc']"
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
sps.Equal(2, len(testSegments[0].Segments))
sps.Equal(2, len(testSegments[1].Segments))
}
}
func (sps *SegmentPrunerSuite) TestPruneSegmentsByScalarStrField() {


@@ -22,6 +22,7 @@ import (
"strings"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/pkg/log"
)
@@ -39,6 +40,20 @@ type ScalarFieldValue interface {
Size() int64
}
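// MaxScalar/MinScalar return the larger/smaller of two scalar field values;
// the segment pruner uses them for range-overlap checks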
func MaxScalar(val1 ScalarFieldValue, val2 ScalarFieldValue) ScalarFieldValue {
if val1.GE(val2) {
return val1
}
return val2
}
func MinScalar(val1 ScalarFieldValue, val2 ScalarFieldValue) ScalarFieldValue {
if val1.LE(val2) {
return val1
}
return val2
}
// DataType_Int8
type Int8FieldValue struct {
Value int8 `json:"value"`
@@ -1014,6 +1029,23 @@ func (ifv *FloatVectorFieldValue) Size() int64 {
return int64(len(ifv.Value) * 8)
}
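// NewScalarFieldValueFromGenericValue converts a planpb.GenericValue from a
// parsed expr plan into a typed ScalarFieldValue; only Int64, Float and
// String/VarChar are handled here, any other type panics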
func NewScalarFieldValueFromGenericValue(dtype schemapb.DataType, gVal *planpb.GenericValue) ScalarFieldValue {
switch dtype {
case schemapb.DataType_Int64:
i64Val := gVal.Val.(*planpb.GenericValue_Int64Val)
return NewInt64FieldValue(i64Val.Int64Val)
case schemapb.DataType_Float:
floatVal := gVal.Val.(*planpb.GenericValue_FloatVal)
return NewFloatFieldValue(float32(floatVal.FloatVal))
case schemapb.DataType_String, schemapb.DataType_VarChar:
strVal := gVal.Val.(*planpb.GenericValue_StringVal)
return NewStringFieldValue(strVal.StringVal)
default:
// should not be reached
panic(fmt.Sprintf("not supported datatype: %s", dtype.String()))
}
}
func NewScalarFieldValue(dtype schemapb.DataType, data interface{}) ScalarFieldValue {
switch dtype {
case schemapb.DataType_Int8: