feat: support segment pruner (#31003)

related: #30376
pull/31487/head
Chun Han 2024-03-22 13:57:06 +08:00 committed by GitHub
parent c27db43ba7
commit c3264ca3e3
29 changed files with 1844 additions and 291 deletions

View File

@ -370,6 +370,7 @@ queryNode:
serverMaxRecvSize: 268435456
clientMaxSendSize: 268435456
clientMaxRecvSize: 536870912
enableSegmentPrune: false # use partition stats to prune segment data in search/query on the shard delegator
indexCoord:
bindIndexNodeMode:
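
The new knob sits under the queryNode section of milvus.yaml and defaults to false. A minimal sketch of enabling it, assuming the surrounding defaults stay untouched:

queryNode:
  enableSegmentPrune: true # let the shard delegator prune sealed segments using partition stats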

View File

@ -42,13 +42,14 @@ func (v *ParserVisitor) translateIdentifier(identifier string) (*ExprWithType, e
Expr: &planpb.Expr_ColumnExpr{
ColumnExpr: &planpb.ColumnExpr{
Info: &planpb.ColumnInfo{
FieldId: field.FieldID,
DataType: field.DataType,
IsPrimaryKey: field.IsPrimaryKey,
IsAutoID: field.AutoID,
NestedPath: nestedPath,
IsPartitionKey: field.IsPartitionKey,
ElementType: field.GetElementType(),
FieldId: field.FieldID,
DataType: field.DataType,
IsPrimaryKey: field.IsPrimaryKey,
IsAutoID: field.AutoID,
NestedPath: nestedPath,
IsPartitionKey: field.IsPartitionKey,
IsClusteringKey: field.IsClusteringKey,
ElementType: field.GetElementType(),
},
},
},

View File

@ -71,6 +71,7 @@ message ColumnInfo {
repeated string nested_path = 5;
bool is_partition_key = 6;
schema.DataType element_type = 7;
bool is_clustering_key = 8;
}
message ColumnExpr {

View File

@ -1,114 +0,0 @@
package proxy
import (
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/internal/proto/planpb"
)
func ParseExprFromPlan(plan *planpb.PlanNode) (*planpb.Expr, error) {
node := plan.GetNode()
if node == nil {
return nil, errors.New("can't get expr from empty plan node")
}
var expr *planpb.Expr
switch node := node.(type) {
case *planpb.PlanNode_VectorAnns:
expr = node.VectorAnns.GetPredicates()
case *planpb.PlanNode_Query:
expr = node.Query.GetPredicates()
default:
return nil, errors.New("unsupported plan node type")
}
return expr, nil
}
func ParsePartitionKeysFromBinaryExpr(expr *planpb.BinaryExpr) ([]*planpb.GenericValue, bool) {
leftRes, leftInRange := ParsePartitionKeysFromExpr(expr.Left)
RightRes, rightInRange := ParsePartitionKeysFromExpr(expr.Right)
if expr.Op == planpb.BinaryExpr_LogicalAnd {
// case: partition_key_field in [7, 8] && partition_key > 8
if len(leftRes)+len(RightRes) > 0 {
leftRes = append(leftRes, RightRes...)
return leftRes, false
}
// case: other_field > 10 && partition_key_field > 8
return nil, leftInRange || rightInRange
}
if expr.Op == planpb.BinaryExpr_LogicalOr {
// case: partition_key_field in [7, 8] or partition_key > 8
if leftInRange || rightInRange {
return nil, true
}
// case: partition_key_field in [7, 8] or other_field > 10
leftRes = append(leftRes, RightRes...)
return leftRes, false
}
return nil, false
}
func ParsePartitionKeysFromUnaryExpr(expr *planpb.UnaryExpr) ([]*planpb.GenericValue, bool) {
res, partitionInRange := ParsePartitionKeysFromExpr(expr.GetChild())
if expr.Op == planpb.UnaryExpr_Not {
// case: partition_key_field not in [7, 8]
if len(res) != 0 {
return nil, true
}
// case: other_field not in [10]
return nil, partitionInRange
}
// UnaryOp only includes "Not" for now
return res, partitionInRange
}
func ParsePartitionKeysFromTermExpr(expr *planpb.TermExpr) ([]*planpb.GenericValue, bool) {
if expr.GetColumnInfo().GetIsPartitionKey() {
return expr.GetValues(), false
}
return nil, false
}
func ParsePartitionKeysFromUnaryRangeExpr(expr *planpb.UnaryRangeExpr) ([]*planpb.GenericValue, bool) {
if expr.GetColumnInfo().GetIsPartitionKey() && expr.GetOp() == planpb.OpType_Equal {
return []*planpb.GenericValue{expr.Value}, false
}
return nil, true
}
func ParsePartitionKeysFromExpr(expr *planpb.Expr) ([]*planpb.GenericValue, bool) {
var res []*planpb.GenericValue
partitionKeyInRange := false
switch expr := expr.GetExpr().(type) {
case *planpb.Expr_BinaryExpr:
res, partitionKeyInRange = ParsePartitionKeysFromBinaryExpr(expr.BinaryExpr)
case *planpb.Expr_UnaryExpr:
res, partitionKeyInRange = ParsePartitionKeysFromUnaryExpr(expr.UnaryExpr)
case *planpb.Expr_TermExpr:
res, partitionKeyInRange = ParsePartitionKeysFromTermExpr(expr.TermExpr)
case *planpb.Expr_UnaryRangeExpr:
res, partitionKeyInRange = ParsePartitionKeysFromUnaryRangeExpr(expr.UnaryRangeExpr)
}
return res, partitionKeyInRange
}
func ParsePartitionKeys(expr *planpb.Expr) []*planpb.GenericValue {
res, partitionKeyInRange := ParsePartitionKeysFromExpr(expr)
if partitionKeyInRange {
res = nil
}
return res
}

View File

@ -1,143 +0,0 @@
package proxy
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func TestParsePartitionKeys(t *testing.T) {
prefix := "TestParsePartitionKeys"
collectionName := prefix + funcutil.GenRandomStr()
fieldName2Type := make(map[string]schemapb.DataType)
fieldName2Type["int64_field"] = schemapb.DataType_Int64
fieldName2Type["varChar_field"] = schemapb.DataType_VarChar
fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Type, "int64_field", false)
partitionKeyField := &schemapb.FieldSchema{
Name: "partition_key_field",
DataType: schemapb.DataType_Int64,
IsPartitionKey: true,
}
schema.Fields = append(schema.Fields, partitionKeyField)
schemaHelper, err := typeutil.CreateSchemaHelper(schema)
require.NoError(t, err)
fieldID := common.StartOfUserFieldID
for _, field := range schema.Fields {
field.FieldID = int64(fieldID)
fieldID++
}
queryInfo := &planpb.QueryInfo{
Topk: 10,
MetricType: "L2",
SearchParams: "",
RoundDecimal: -1,
}
type testCase struct {
name string
expr string
expected int
validPartitionKeys []int64
invalidPartitionKeys []int64
}
cases := []testCase{
{
name: "binary_expr_and with term",
expr: "partition_key_field in [7, 8] && int64_field >= 10",
expected: 2,
validPartitionKeys: []int64{7, 8},
invalidPartitionKeys: []int64{},
},
{
name: "binary_expr_and with equal",
expr: "partition_key_field == 7 && int64_field >= 10",
expected: 1,
validPartitionKeys: []int64{7},
invalidPartitionKeys: []int64{},
},
{
name: "binary_expr_and with term2",
expr: "partition_key_field in [7, 8] && int64_field == 10",
expected: 2,
validPartitionKeys: []int64{7, 8},
invalidPartitionKeys: []int64{10},
},
{
name: "binary_expr_and with partition key in range",
expr: "partition_key_field in [7, 8] && partition_key_field > 9",
expected: 2,
validPartitionKeys: []int64{7, 8},
invalidPartitionKeys: []int64{9},
},
{
name: "binary_expr_and with partition key in range2",
expr: "int64_field == 10 && partition_key_field > 9",
expected: 0,
validPartitionKeys: []int64{},
invalidPartitionKeys: []int64{},
},
{
name: "binary_expr_and with term and not",
expr: "partition_key_field in [7, 8] && partition_key_field not in [10, 20]",
expected: 2,
validPartitionKeys: []int64{7, 8},
invalidPartitionKeys: []int64{10, 20},
},
{
name: "binary_expr_or with term and not",
expr: "partition_key_field in [7, 8] or partition_key_field not in [10, 20]",
expected: 0,
validPartitionKeys: []int64{},
invalidPartitionKeys: []int64{},
},
{
name: "binary_expr_or with term and not 2",
expr: "partition_key_field in [7, 8] or int64_field not in [10, 20]",
expected: 2,
validPartitionKeys: []int64{7, 8},
invalidPartitionKeys: []int64{10, 20},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
// test search plan
searchPlan, err := planparserv2.CreateSearchPlan(schemaHelper, tc.expr, "fvec_field", queryInfo)
assert.NoError(t, err)
expr, err := ParseExprFromPlan(searchPlan)
assert.NoError(t, err)
partitionKeys := ParsePartitionKeys(expr)
assert.Equal(t, tc.expected, len(partitionKeys))
for _, key := range partitionKeys {
int64Val := key.Val.(*planpb.GenericValue_Int64Val).Int64Val
assert.Contains(t, tc.validPartitionKeys, int64Val)
assert.NotContains(t, tc.invalidPartitionKeys, int64Val)
}
// test query plan
queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, tc.expr)
assert.NoError(t, err)
expr, err = ParseExprFromPlan(queryPlan)
assert.NoError(t, err)
partitionKeys = ParsePartitionKeys(expr)
assert.Equal(t, tc.expected, len(partitionKeys))
for _, key := range partitionKeys {
int64Val := key.Val.(*planpb.GenericValue_Int64Val).Int64Val
assert.Contains(t, tc.validPartitionKeys, int64Val)
assert.NotContains(t, tc.invalidPartitionKeys, int64Val)
}
})
}
}

View File

@ -12,6 +12,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/util/exprutil"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
@ -92,12 +93,12 @@ func initSearchRequest(ctx context.Context, t *searchTask) error {
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
if t.partitionKeyMode {
expr, err := ParseExprFromPlan(plan)
expr, err := exprutil.ParseExprFromPlan(plan)
if err != nil {
log.Warn("failed to parse expr", zap.Error(err))
return err
}
partitionKeys := ParsePartitionKeys(expr)
partitionKeys := exprutil.ParseKeys(expr, exprutil.PartitionKey)
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.collectionName, partitionKeys)
if err != nil {
log.Warn("failed to assign partition keys", zap.Error(err))

View File

@ -21,6 +21,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/exprutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
@ -356,11 +357,11 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe
// optimize query when partition key mode is on
if dr.partitionKeyMode {
expr, err := ParseExprFromPlan(plan)
expr, err := exprutil.ParseExprFromPlan(plan)
if err != nil {
return err
}
partitionKeys := ParsePartitionKeys(expr)
partitionKeys := exprutil.ParseKeys(expr, exprutil.PartitionKey)
hashedPartitionNames, err := assignPartitionKeys(ctx, dr.req.GetDbName(), dr.req.GetCollectionName(), partitionKeys)
if err != nil {
return err

View File

@ -19,6 +19,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/exprutil"
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
@ -358,11 +359,11 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
if !t.reQuery {
partitionNames := t.request.GetPartitionNames()
if t.partitionKeyMode {
expr, err := ParseExprFromPlan(t.plan)
expr, err := exprutil.ParseExprFromPlan(t.plan)
if err != nil {
return err
}
partitionKeys := ParsePartitionKeys(expr)
partitionKeys := exprutil.ParseKeys(expr, exprutil.PartitionKey)
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.request.CollectionName, partitionKeys)
if err != nil {
return err

View File

@ -20,6 +20,8 @@ package delegator
import (
"context"
"fmt"
"path"
"strconv"
"sync"
"time"
@ -42,6 +44,7 @@ import (
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/streamrpc"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
@ -49,6 +52,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/lifetime"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metautil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
@ -115,7 +119,9 @@ type shardDelegator struct {
tsCond *sync.Cond
latestTsafe *atomic.Uint64
// queryHook
queryHook optimizers.QueryHook
queryHook optimizers.QueryHook
partitionStats map[UniqueID]*storage.PartitionStatsSnapshot
chunkManager storage.ChunkManager
}
// getLogger returns the zap logger with pre-defined shard attributes.
@ -203,6 +209,9 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest
log.Warn("failed to optimize search params", zap.Error(err))
return nil, err
}
if paramtable.Get().QueryNodeCfg.EnableSegmentPrune.GetAsBool() {
PruneSegments(ctx, sd.partitionStats, req.GetReq(), nil, sd.collection.Schema(), sealed, PruneInfo{filterRatio: defaultFilterRatio})
}
tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest)
if err != nil {
@ -485,12 +494,17 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not serviceable")
}
defer sd.distribution.Unpin(version)
existPartitions := sd.collection.GetPartitions()
growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
return funcutil.SliceContain(existPartitions, segment.PartitionID)
})
if req.Req.IgnoreGrowing {
growing = []SegmentEntry{}
} else {
existPartitions := sd.collection.GetPartitions()
growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
return funcutil.SliceContain(existPartitions, segment.PartitionID)
})
}
if paramtable.Get().QueryNodeCfg.EnableSegmentPrune.GetAsBool() {
PruneSegments(ctx, sd.partitionStats, nil, req.GetReq(), sd.collection.Schema(), sealed, PruneInfo{defaultFilterRatio})
}
sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) })
@ -774,10 +788,72 @@ func (sd *shardDelegator) Close() {
sd.lifetime.Wait()
}
// As partition stats are an optimization for search/query and are not mandatory for a milvus instance,
// loading partitionStats is a best-effort process: errors are logged and skipped
// rather than returned as an error status
func (sd *shardDelegator) maybeReloadPartitionStats(ctx context.Context, partIDs ...UniqueID) {
var partsToReload []UniqueID
if len(partIDs) > 0 {
partsToReload = partIDs
} else {
partsToReload = append(partsToReload, sd.collection.GetPartitions()...)
}
colID := sd.Collection()
findMaxVersion := func(filePaths []string) (int64, string) {
maxVersion := int64(-1)
maxVersionFilePath := ""
for _, filePath := range filePaths {
versionStr := path.Base(filePath)
version, err := strconv.ParseInt(versionStr, 10, 64)
if err != nil {
continue
}
if version > maxVersion {
maxVersion = version
maxVersionFilePath = filePath
}
}
return maxVersion, maxVersionFilePath
}
for _, partID := range partsToReload {
idPath := metautil.JoinIDPath(colID, partID)
idPath = path.Join(idPath, sd.vchannelName)
statsPathPrefix := path.Join(sd.chunkManager.RootPath(), common.PartitionStatsPath, idPath)
filePaths, _, err := sd.chunkManager.ListWithPrefix(ctx, statsPathPrefix, true)
if err != nil {
log.Error("Skip initializing partition stats for failing to list files with prefix",
zap.String("statsPathPrefix", statsPathPrefix))
continue
}
maxVersion, maxVersionFilePath := findMaxVersion(filePaths)
if maxVersion < 0 {
log.Info("failed to find valid partition stats file for partition", zap.Int64("partitionID", partID))
continue
}
partStats, exists := sd.partitionStats[partID]
if !exists || (exists && partStats.GetVersion() < maxVersion) {
statsBytes, err := sd.chunkManager.Read(ctx, maxVersionFilePath)
if err != nil {
log.Error("failed to read stats file from object storage", zap.String("path", maxVersionFilePath))
continue
}
partStats, err := storage.DeserializePartitionsStatsSnapshot(statsBytes)
if err != nil {
log.Error("failed to parse partition stats from bytes", zap.Int("bytes_length", len(statsBytes)))
continue
}
sd.partitionStats[partID] = partStats
partStats.SetVersion(maxVersion)
log.Info("Updated partitionStats for partition", zap.Int64("partitionID", partID))
}
}
}
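For reference, a self-contained sketch of the object-store layout this loader scans; the literal folder name for common.PartitionStatsPath is an assumption here, but the collectionID/partitionID/vchannel/version nesting and the pick-the-largest-version rule follow the code above.

package main

import (
	"fmt"
	"path"
	"strconv"
)

// statsKey mirrors the prefix built by maybeReloadPartitionStats:
// {rootPath}/{PartitionStatsPath}/{collectionID}/{partitionID}/{vchannel}/{version}
func statsKey(rootPath string, collectionID, partitionID int64, vchannel string, version int64) string {
	const partitionStatsPath = "part_stats" // assumed value of common.PartitionStatsPath
	return path.Join(rootPath, partitionStatsPath,
		strconv.FormatInt(collectionID, 10),
		strconv.FormatInt(partitionID, 10),
		vchannel,
		strconv.FormatInt(version, 10))
}

func main() {
	// Two versions exist; findMaxVersion keeps only the base-10 maximum (2).
	fmt.Println(statsKey("files", 100, 1001, "by-dev-dml_0_100v0", 1))
	fmt.Println(statsKey("files", 100, 1001, "by-dev-dml_0_100v0", 2))
}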
// NewShardDelegator creates a new ShardDelegator instance with all fields initialized.
func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID UniqueID, channel string, version int64,
workerManager cluster.Manager, manager *segments.Manager, tsafeManager tsafe.Manager, loader segments.Loader,
factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook,
factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook, chunkManager storage.ChunkManager,
) (ShardDelegator, error) {
log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID),
zap.Int64("replicaID", replicaID),
@ -812,6 +888,8 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
loader: loader,
factory: factory,
queryHook: queryHook,
chunkManager: chunkManager,
partitionStats: make(map[UniqueID]*storage.PartitionStatsSnapshot),
}
m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
@ -819,5 +897,6 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
go sd.watchTSafe()
}
log.Info("finish build new shardDelegator")
sd.maybeReloadPartitionStats(ctx)
return sd, nil
}

View File

@ -475,6 +475,12 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg
// alter distribution
sd.distribution.AddDistributions(entries...)
partStatsToReload := make([]UniqueID, 0)
lo.ForEach(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) {
partStatsToReload = append(partStatsToReload, info.PartitionID)
})
sd.maybeReloadPartitionStats(ctx, partStatsToReload...)
return nil
}
@ -850,7 +856,14 @@ func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.Rele
if hasLevel0 {
sd.GenerateLevel0DeletionCache()
}
partitionsToReload := make([]UniqueID, 0)
lo.ForEach(req.GetSegmentIDs(), func(segmentID int64, _ int) {
segment := sd.segmentManager.Get(segmentID)
if segment != nil {
partitionsToReload = append(partitionsToReload, segment.Partition())
}
})
sd.maybeReloadPartitionStats(ctx, partitionsToReload...)
return nil
}

View File

@ -18,6 +18,8 @@ package delegator
import (
"context"
"path"
"strconv"
"testing"
bloom "github.com/bits-and-blooms/bloom/v3"
@ -41,6 +43,7 @@ import (
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metautil"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
@ -58,7 +61,9 @@ type DelegatorDataSuite struct {
loader *segments.MockLoader
mq *msgstream.MockMsgStream
delegator *shardDelegator
delegator *shardDelegator
rootPath string
chunkManager storage.ChunkManager
}
func (s *DelegatorDataSuite) SetupSuite() {
@ -126,16 +131,19 @@ func (s *DelegatorDataSuite) SetupTest() {
},
},
}, &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection,
LoadType: querypb.LoadType_LoadCollection,
PartitionIDs: []int64{1001, 1002},
})
s.mq = &msgstream.MockMsgStream{}
s.rootPath = s.Suite.T().Name()
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background())
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil)
}, 10000, nil, s.chunkManager)
s.Require().NoError(err)
sd, ok := delegator.(*shardDelegator)
s.Require().True(ok)
@ -609,7 +617,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil)
}, 10000, nil, nil)
s.NoError(err)
growing0 := segments.NewMockSegment(s.T())
@ -968,6 +976,78 @@ func (s *DelegatorDataSuite) TestReleaseSegment() {
s.NoError(err)
}
func (s *DelegatorDataSuite) TestLoadPartitionStats() {
segStats := make(map[UniqueID]storage.SegmentStats)
centroid := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}
var segID int64 = 1
rows := 1990
{
// p1 stats
fieldStats := make([]storage.FieldStats, 0)
fieldStat1 := storage.FieldStats{
FieldID: 1,
Type: schemapb.DataType_Int64,
Max: storage.NewInt64FieldValue(200),
Min: storage.NewInt64FieldValue(100),
}
fieldStat2 := storage.FieldStats{
FieldID: 2,
Type: schemapb.DataType_Int64,
Max: storage.NewInt64FieldValue(400),
Min: storage.NewInt64FieldValue(300),
}
fieldStat3 := storage.FieldStats{
FieldID: 3,
Type: schemapb.DataType_FloatVector,
Centroids: []storage.VectorFieldValue{
&storage.FloatVectorFieldValue{
Value: centroid,
},
&storage.FloatVectorFieldValue{
Value: centroid,
},
},
}
fieldStats = append(fieldStats, fieldStat1)
fieldStats = append(fieldStats, fieldStat2)
fieldStats = append(fieldStats, fieldStat3)
segStats[segID] = *storage.NewSegmentStats(fieldStats, rows)
}
partitionStats1 := &storage.PartitionStatsSnapshot{
SegmentStats: segStats,
}
statsData1, err := storage.SerializePartitionStatsSnapshot(partitionStats1)
s.NoError(err)
partitionID1 := int64(1001)
idPath1 := metautil.JoinIDPath(s.collectionID, partitionID1)
idPath1 = path.Join(idPath1, s.delegator.vchannelName)
statsPath1 := path.Join(s.chunkManager.RootPath(), common.PartitionStatsPath, idPath1, strconv.Itoa(1))
s.chunkManager.Write(context.Background(), statsPath1, statsData1)
defer s.chunkManager.Remove(context.Background(), statsPath1)
// reload and check partition stats
s.delegator.maybeReloadPartitionStats(context.Background())
s.Equal(1, len(s.delegator.partitionStats))
s.NotNil(s.delegator.partitionStats[partitionID1])
p1Stats := s.delegator.partitionStats[partitionID1]
s.Equal(int64(1), p1Stats.GetVersion())
s.Equal(rows, p1Stats.SegmentStats[segID].NumRows)
s.Equal(3, len(p1Stats.SegmentStats[segID].FieldStats))
// judge vector stats
vecFieldStats := p1Stats.SegmentStats[segID].FieldStats[2]
s.Equal(2, len(vecFieldStats.Centroids))
s.Equal(8, len(vecFieldStats.Centroids[0].GetValue().([]float32)))
// judge scalar stats
fieldStats1 := p1Stats.SegmentStats[segID].FieldStats[0]
s.Equal(int64(100), fieldStats1.Min.GetValue().(int64))
s.Equal(int64(200), fieldStats1.Max.GetValue().(int64))
fieldStats2 := p1Stats.SegmentStats[segID].FieldStats[1]
s.Equal(int64(300), fieldStats2.Min.GetValue().(int64))
s.Equal(int64(400), fieldStats2.Max.GetValue().(int64))
}
func (s *DelegatorDataSuite) TestSyncTargetVersion() {
for i := int64(0); i < 5; i++ {
ms := &segments.MockSegment{}

View File

@ -40,6 +40,7 @@ import (
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/streamrpc"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
@ -64,7 +65,9 @@ type DelegatorSuite struct {
loader *segments.MockLoader
mq *msgstream.MockMsgStream
delegator ShardDelegator
delegator ShardDelegator
chunkManager storage.ChunkManager
rootPath string
}
func (s *DelegatorSuite) SetupSuite() {
@ -154,6 +157,11 @@ func (s *DelegatorSuite) SetupTest() {
})
s.mq = &msgstream.MockMsgStream{}
s.rootPath = "delegator_test"
// init chunkManager
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background())
var err error
// s.delegator, err = NewShardDelegator(s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader)
@ -161,7 +169,7 @@ func (s *DelegatorSuite) SetupTest() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil)
}, 10000, nil, s.chunkManager)
s.Require().NoError(err)
}

View File

@ -0,0 +1,228 @@
package delegator
import (
"context"
"sort"
"strconv"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/clustering"
"github.com/milvus-io/milvus/internal/util/exprutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
)
const defaultFilterRatio float64 = 0.5
type PruneInfo struct {
filterRatio float64
}
func PruneSegments(ctx context.Context,
partitionStats map[UniqueID]*storage.PartitionStatsSnapshot,
searchReq *internalpb.SearchRequest,
queryReq *internalpb.RetrieveRequest,
schema *schemapb.CollectionSchema,
sealedSegments []SnapshotItem,
info PruneInfo,
) {
log := log.Ctx(ctx)
// 1. calculate filtered segments
filteredSegments := make(map[UniqueID]struct{})
clusteringKeyField := typeutil.GetClusteringKeyField(schema.Fields)
if clusteringKeyField == nil {
return
}
if searchReq != nil {
// parse searched vectors
var vectorsHolder commonpb.PlaceholderGroup
err := proto.Unmarshal(searchReq.GetPlaceholderGroup(), &vectorsHolder)
if err != nil || len(vectorsHolder.GetPlaceholders()) == 0 {
return
}
vectorsBytes := vectorsHolder.GetPlaceholders()[0].GetValues()
// parse dim
dimStr, err := funcutil.GetAttrByKeyFromRepeatedKV(common.DimKey, clusteringKeyField.GetTypeParams())
if err != nil {
return
}
dimValue, err := strconv.ParseInt(dimStr, 10, 64)
if err != nil {
return
}
for _, partID := range searchReq.GetPartitionIDs() {
partStats := partitionStats[partID]
FilterSegmentsByVector(partStats, searchReq, vectorsBytes, dimValue, clusteringKeyField, filteredSegments, info.filterRatio)
}
} else if queryReq != nil {
// 0. parse expr from plan
plan := planpb.PlanNode{}
err := proto.Unmarshal(queryReq.GetSerializedExprPlan(), &plan)
if err != nil {
log.Error("failed to unmarshall serialized expr from bytes, failed the operation")
return
}
expr, err := exprutil.ParseExprFromPlan(&plan)
if err != nil {
log.Error("failed to parse expr from plan, failed the operation")
return
}
targetRanges, matchALL := exprutil.ParseRanges(expr, exprutil.ClusteringKey)
if matchALL || targetRanges == nil {
return
}
for _, partID := range queryReq.GetPartitionIDs() {
partStats := partitionStats[partID]
FilterSegmentsOnScalarField(partStats, targetRanges, clusteringKeyField, filteredSegments)
}
}
// 2. remove filtered segments from sealed segment list
if len(filteredSegments) > 0 {
totalSegNum := 0
for idx, item := range sealedSegments {
newSegments := make([]SegmentEntry, 0)
totalSegNum += len(item.Segments)
for _, segment := range item.Segments {
if _, ok := filteredSegments[segment.SegmentID]; !ok {
newSegments = append(newSegments, segment)
}
}
item.Segments = newSegments
sealedSegments[idx] = item
}
log.Debug("Pruned segment for search/query",
zap.Int("pruned_segment_num", len(filteredSegments)),
zap.Int("total_segment_num", totalSegNum),
)
}
}
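Restating the two delegator call sites from earlier in this diff side by side (a sketch, with surrounding error handling elided): the pruner runs in place on the sealed snapshot and is a no-op when the schema has no clustering key.

// search path: the SearchRequest carries the placeholder vectors, so the RetrieveRequest argument is nil
PruneSegments(ctx, sd.partitionStats, req.GetReq(), nil, sd.collection.Schema(), sealed, PruneInfo{filterRatio: defaultFilterRatio})
// query path: the RetrieveRequest carries the serialized expr plan, so the SearchRequest argument is nil
PruneSegments(ctx, sd.partitionStats, nil, req.GetReq(), sd.collection.Schema(), sealed, PruneInfo{filterRatio: defaultFilterRatio})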
type segmentDisStruct struct {
segmentID UniqueID
distance float32
rows int // to keep track of whether enough rows are covered for topK
}
func FilterSegmentsByVector(partitionStats *storage.PartitionStatsSnapshot,
searchReq *internalpb.SearchRequest,
vectorBytes [][]byte,
dim int64,
keyField *schemapb.FieldSchema,
filteredSegments map[UniqueID]struct{},
filterRatio float64,
) {
// 1. calculate vectors' distances
neededSegments := make(map[UniqueID]struct{})
for _, vecBytes := range vectorBytes {
segmentsToSearch := make([]segmentDisStruct, 0)
for segId, segStats := range partitionStats.SegmentStats {
// here, we do not skip segments already marked as needed by a former query vector,
// so the distance computation is repeated per vector: the larger nq is,
// the more segments get included and the weaker the pruning effect becomes
// 1. calculate distances from centroids
for _, fieldStat := range segStats.FieldStats {
if fieldStat.FieldID == keyField.GetFieldID() {
if len(fieldStat.Centroids) == 0 {
neededSegments[segId] = struct{}{}
break
}
var dis []float32
var disErr error
switch keyField.GetDataType() {
case schemapb.DataType_FloatVector:
dis, disErr = clustering.CalcVectorDistance(dim, keyField.GetDataType(),
vecBytes, fieldStat.Centroids[0].GetValue().([]float32), searchReq.GetMetricType())
default:
neededSegments[segId] = struct{}{}
disErr = merr.WrapErrParameterInvalid(schemapb.DataType_FloatVector, keyField.GetDataType(),
"Currently, pruning by cluster only support float_vector type")
}
// currently, we only support float vectors and a single centroid per segment
if disErr != nil {
neededSegments[segId] = struct{}{}
break
}
segmentsToSearch = append(segmentsToSearch, segmentDisStruct{
segmentID: segId,
distance: dis[0],
rows: segStats.NumRows,
})
break
}
}
}
// 2. sort the distances
switch searchReq.GetMetricType() {
case distance.L2:
sort.SliceStable(segmentsToSearch, func(i, j int) bool {
return segmentsToSearch[i].distance < segmentsToSearch[j].distance
})
case distance.IP, distance.COSINE:
sort.SliceStable(segmentsToSearch, func(i, j int) bool {
return segmentsToSearch[i].distance > segmentsToSearch[j].distance
})
}
// 3. filter out non-target segments
segmentCount := len(segmentsToSearch)
targetSegNum := int(float64(segmentCount) * filterRatio)
optimizedRowCount := 0
// once topK rows are covered and at least targetSegNum segments are kept, the remaining segments are filtered out
for i := 0; i < segmentCount; i++ {
optimizedRowCount += segmentsToSearch[i].rows
neededSegments[segmentsToSearch[i].segmentID] = struct{}{}
if int64(optimizedRowCount) >= searchReq.GetTopk() && i >= targetSegNum {
break
}
}
}
// 2. mark segments not needed by any query vector as filtered
for segId := range partitionStats.SegmentStats {
if _, ok := neededSegments[segId]; !ok {
filteredSegments[segId] = struct{}{}
}
}
}
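A minimal, self-contained walk-through of the stopping rule in the keep-loop above, using the numbers from the vector test below (four segments of 80 rows each, filterRatio 0.25, topk 100):

package main

import "fmt"

// stopAfter reproduces the keep-loop: walk segments in ascending distance
// order and stop once at least topk rows are covered AND at least
// targetSegNum segments have been kept.
func stopAfter(rows []int, filterRatio float64, topk int) int {
	targetSegNum := int(float64(len(rows)) * filterRatio)
	covered, kept := 0, 0
	for i := 0; i < len(rows); i++ {
		covered += rows[i]
		kept++
		if covered >= topk && i >= targetSegNum {
			break
		}
	}
	return kept
}

func main() {
	// 80+80 rows cover topk=100 at i=1, and 1 >= targetSegNum(1),
	// so exactly the 2 nearest segments survive per query vector.
	fmt.Println(stopAfter([]int{80, 80, 80, 80}, 0.25, 100)) // 2
}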
func FilterSegmentsOnScalarField(partitionStats *storage.PartitionStatsSnapshot,
targetRanges []*exprutil.PlanRange,
keyField *schemapb.FieldSchema,
filteredSegments map[UniqueID]struct{},
) {
// 1. try to filter segments
overlap := func(min storage.ScalarFieldValue, max storage.ScalarFieldValue) bool {
for _, tRange := range targetRanges {
switch keyField.DataType {
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32, schemapb.DataType_Int64:
targetRange := tRange.ToIntRange()
statRange := exprutil.NewIntRange(min.GetValue().(int64), max.GetValue().(int64), true, true)
return exprutil.IntRangeOverlap(targetRange, statRange)
case schemapb.DataType_String, schemapb.DataType_VarChar:
targetRange := tRange.ToStrRange()
statRange := exprutil.NewStrRange(min.GetValue().(string), max.GetValue().(string), true, true)
return exprutil.StrRangeOverlap(targetRange, statRange)
}
}
return false
}
for segID, segStats := range partitionStats.SegmentStats {
for _, fieldStat := range segStats.FieldStats {
if keyField.FieldID == fieldStat.FieldID && !overlap(fieldStat.Min, fieldStat.Max) {
filteredSegments[segID] = struct{}{}
}
}
}
}
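As a concrete instance of the overlap test: the suite below queries age>=500 and age<=550 against a segment whose clustering-key stats span [600, 900]. Two closed int ranges overlap iff max(lowers) <= min(uppers); a sketch using the Go 1.21 min/max builtins:

package main

import "fmt"

func main() {
	leftBound := max(int64(500), int64(600))  // 600
	rightBound := min(int64(550), int64(900)) // 550
	fmt.Println(leftBound <= rightBound) // false -> the segment is pruned
}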

View File

@ -0,0 +1,422 @@
package delegator
import (
"context"
"testing"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/clustering"
"github.com/milvus-io/milvus/internal/util/testutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type SegmentPrunerSuite struct {
suite.Suite
partitionStats map[UniqueID]*storage.PartitionStatsSnapshot
schema *schemapb.CollectionSchema
collectionName string
primaryFieldName string
clusterKeyFieldName string
autoID bool
targetPartition int64
dim int
sealedSegments []SnapshotItem
}
func (sps *SegmentPrunerSuite) SetupForClustering(clusterKeyFieldName string,
clusterKeyFieldType schemapb.DataType,
) {
sps.collectionName = "test_segment_prune"
sps.primaryFieldName = "pk"
sps.clusterKeyFieldName = clusterKeyFieldName
sps.autoID = true
sps.dim = 8
fieldName2DataType := make(map[string]schemapb.DataType)
fieldName2DataType[sps.primaryFieldName] = schemapb.DataType_Int64
fieldName2DataType[sps.clusterKeyFieldName] = clusterKeyFieldType
fieldName2DataType["info"] = schemapb.DataType_VarChar
fieldName2DataType["age"] = schemapb.DataType_Int32
fieldName2DataType["vec"] = schemapb.DataType_FloatVector
sps.schema = testutil.ConstructCollectionSchemaWithKeys(sps.collectionName,
fieldName2DataType,
sps.primaryFieldName,
"",
sps.clusterKeyFieldName,
false,
sps.dim)
var clusteringKeyFieldID int64 = 0
for _, field := range sps.schema.GetFields() {
if field.IsClusteringKey {
clusteringKeyFieldID = field.FieldID
break
}
}
centroids1 := []storage.VectorFieldValue{
&storage.FloatVectorFieldValue{
Value: []float32{0.6951474, 0.45225978, 0.51508516, 0.24968886, 0.6085484, 0.964968, 0.32239532, 0.7771577},
},
}
centroids2 := []storage.VectorFieldValue{
&storage.FloatVectorFieldValue{
Value: []float32{0.12345678, 0.23456789, 0.34567890, 0.45678901, 0.56789012, 0.67890123, 0.78901234, 0.89012345},
},
}
centroids3 := []storage.VectorFieldValue{
&storage.FloatVectorFieldValue{
Value: []float32{0.98765432, 0.87654321, 0.76543210, 0.65432109, 0.54321098, 0.43210987, 0.32109876, 0.21098765},
},
}
centroids4 := []storage.VectorFieldValue{
&storage.FloatVectorFieldValue{
Value: []float32{0.11111111, 0.22222222, 0.33333333, 0.44444444, 0.55555555, 0.66666666, 0.77777777, 0.88888888},
},
}
// init partition stats
// for convenience, we set both min/max and centroids on the same struct;
// in real use cases, a field stat contains either min&max or centroids
segStats := make(map[UniqueID]storage.SegmentStats)
switch clusterKeyFieldType {
case schemapb.DataType_Int64, schemapb.DataType_Int32, schemapb.DataType_Int16, schemapb.DataType_Int8:
{
fieldStats := make([]storage.FieldStats, 0)
fieldStat1 := storage.FieldStats{
FieldID: clusteringKeyFieldID,
Type: schemapb.DataType_Int64,
Min: storage.NewInt64FieldValue(100),
Max: storage.NewInt64FieldValue(200),
Centroids: centroids1,
}
fieldStats = append(fieldStats, fieldStat1)
segStats[1] = *storage.NewSegmentStats(fieldStats, 80)
}
{
fieldStats := make([]storage.FieldStats, 0)
fieldStat1 := storage.FieldStats{
FieldID: clusteringKeyFieldID,
Type: schemapb.DataType_Int64,
Min: storage.NewInt64FieldValue(100),
Max: storage.NewInt64FieldValue(400),
Centroids: centroids2,
}
fieldStats = append(fieldStats, fieldStat1)
segStats[2] = *storage.NewSegmentStats(fieldStats, 80)
}
{
fieldStats := make([]storage.FieldStats, 0)
fieldStat1 := storage.FieldStats{
FieldID: clusteringKeyFieldID,
Type: schemapb.DataType_Int64,
Min: storage.NewInt64FieldValue(600),
Max: storage.NewInt64FieldValue(900),
Centroids: centroids3,
}
fieldStats = append(fieldStats, fieldStat1)
segStats[3] = *storage.NewSegmentStats(fieldStats, 80)
}
{
fieldStats := make([]storage.FieldStats, 0)
fieldStat1 := storage.FieldStats{
FieldID: clusteringKeyFieldID,
Type: schemapb.DataType_Int64,
Min: storage.NewInt64FieldValue(500),
Max: storage.NewInt64FieldValue(1000),
Centroids: centroids4,
}
fieldStats = append(fieldStats, fieldStat1)
segStats[4] = *storage.NewSegmentStats(fieldStats, 80)
}
default:
{
fieldStats := make([]storage.FieldStats, 0)
fieldStat1 := storage.FieldStats{
FieldID: clusteringKeyFieldID,
Type: schemapb.DataType_VarChar,
Min: storage.NewStringFieldValue("ab"),
Max: storage.NewStringFieldValue("bbc"),
Centroids: centroids1,
}
fieldStats = append(fieldStats, fieldStat1)
segStats[1] = *storage.NewSegmentStats(fieldStats, 80)
}
{
fieldStats := make([]storage.FieldStats, 0)
fieldStat1 := storage.FieldStats{
FieldID: clusteringKeyFieldID,
Type: schemapb.DataType_VarChar,
Min: storage.NewStringFieldValue("hhh"),
Max: storage.NewStringFieldValue("jjx"),
Centroids: centroids2,
}
fieldStats = append(fieldStats, fieldStat1)
segStats[2] = *storage.NewSegmentStats(fieldStats, 80)
}
{
fieldStats := make([]storage.FieldStats, 0)
fieldStat1 := storage.FieldStats{
FieldID: clusteringKeyFieldID,
Type: schemapb.DataType_VarChar,
Min: storage.NewStringFieldValue("kkk"),
Max: storage.NewStringFieldValue("lmn"),
Centroids: centroids3,
}
fieldStats = append(fieldStats, fieldStat1)
segStats[3] = *storage.NewSegmentStats(fieldStats, 80)
}
{
fieldStats := make([]storage.FieldStats, 0)
fieldStat1 := storage.FieldStats{
FieldID: clusteringKeyFieldID,
Type: schemapb.DataType_VarChar,
Min: storage.NewStringFieldValue("oo2"),
Max: storage.NewStringFieldValue("pptt"),
Centroids: centroids4,
}
fieldStats = append(fieldStats, fieldStat1)
segStats[4] = *storage.NewSegmentStats(fieldStats, 80)
}
}
sps.partitionStats = make(map[UniqueID]*storage.PartitionStatsSnapshot)
sps.targetPartition = 11111
sps.partitionStats[sps.targetPartition] = &storage.PartitionStatsSnapshot{
SegmentStats: segStats,
}
sealedSegments := make([]SnapshotItem, 0)
item1 := SnapshotItem{
NodeID: 1,
Segments: []SegmentEntry{
{
NodeID: 1,
SegmentID: 1,
},
{
NodeID: 1,
SegmentID: 2,
},
},
}
item2 := SnapshotItem{
NodeID: 2,
Segments: []SegmentEntry{
{
NodeID: 2,
SegmentID: 3,
},
{
NodeID: 2,
SegmentID: 4,
},
},
}
sealedSegments = append(sealedSegments, item1)
sealedSegments = append(sealedSegments, item2)
sps.sealedSegments = sealedSegments
}
func (sps *SegmentPrunerSuite) TestPruneSegmentsByScalarIntField() {
sps.SetupForClustering("age", schemapb.DataType_Int32)
targetPartitions := make([]UniqueID, 0)
targetPartitions = append(targetPartitions, sps.targetPartition)
{
// test for exact values
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "age==156"
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio})
sps.Equal(2, len(testSegments[0].Segments))
sps.Equal(0, len(testSegments[1].Segments))
}
{
// test for a single-sided range expr
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "age>=700"
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio})
sps.Equal(0, len(testSegments[0].Segments))
sps.Equal(2, len(testSegments[1].Segments))
}
{
// test for a contradictory binary range
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "age>=700 and age<=500"
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio})
sps.Equal(2, len(testSegments[0].Segments))
sps.Equal(2, len(testSegments[1].Segments))
}
{
// test for a narrow valid binary range
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := "age>=500 and age<=550"
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio})
sps.Equal(0, len(testSegments[0].Segments))
sps.Equal(1, len(testSegments[1].Segments))
}
}
func (sps *SegmentPrunerSuite) TestPruneSegmentsByScalarStrField() {
sps.SetupForClustering("info", schemapb.DataType_VarChar)
targetPartitions := make([]UniqueID, 0)
targetPartitions = append(targetPartitions, sps.targetPartition)
{
// test for exact str values
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := `info=="rag"`
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio})
sps.Equal(0, len(testSegments[0].Segments))
sps.Equal(0, len(testSegments[1].Segments))
// no segment's [min, max] range can contain info=="rag"
}
{
// test for exact str values
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := `info=="kpl"`
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio})
sps.Equal(0, len(testSegments[0].Segments))
sps.Equal(1, len(testSegments[1].Segments))
// only the segment covering [kkk, lmn] can contain info=="kpl"
}
{
// test for unary str values
testSegments := make([]SnapshotItem, len(sps.sealedSegments))
copy(testSegments, sps.sealedSegments)
exprStr := `info<="less"`
schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
sps.NoError(err)
serializedPlan, _ := proto.Marshal(planNode)
queryReq := &internalpb.RetrieveRequest{
SerializedExprPlan: serializedPlan,
PartitionIDs: targetPartitions,
}
PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio})
sps.Equal(2, len(testSegments[0].Segments))
sps.Equal(1, len(testSegments[1].Segments))
// only segments whose min value is <= "less" can match info<="less"
}
}
func vector2Placeholder(vectors [][]float32) *commonpb.PlaceholderValue {
ph := &commonpb.PlaceholderValue{
Tag: "$0",
Values: make([][]byte, 0, len(vectors)),
}
if len(vectors) == 0 {
return ph
}
ph.Type = commonpb.PlaceholderType_FloatVector
for _, vector := range vectors {
ph.Values = append(ph.Values, clustering.SerializeFloatVector(vector))
}
return ph
}
func (sps *SegmentPrunerSuite) TestPruneSegmentsByVectorField() {
sps.SetupForClustering("vec", schemapb.DataType_FloatVector)
vector1 := []float32{0.8877872002188053, 0.6131822285635065, 0.8476814632326242, 0.6645877829359371, 0.9962627712600025, 0.8976183052440327, 0.41941169325798844, 0.7554387854258499}
vector2 := []float32{0.8644394874390322, 0.023327886647378615, 0.08330118483461302, 0.7068040179963112, 0.6983994910799851, 0.5562075958994153, 0.3288536247938002, 0.07077341010237759}
vectors := [][]float32{vector1, vector2}
phg := &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
vector2Placeholder(vectors),
},
}
bs, _ := proto.Marshal(phg)
// test for L2 metrics
req := &internalpb.SearchRequest{
MetricType: "L2",
PlaceholderGroup: bs,
PartitionIDs: []UniqueID{sps.targetPartition},
Topk: 100,
}
PruneSegments(context.TODO(), sps.partitionStats, req, nil, sps.schema, sps.sealedSegments, PruneInfo{0.25})
sps.Equal(1, len(sps.sealedSegments[0].Segments))
sps.Equal(int64(1), sps.sealedSegments[0].Segments[0].SegmentID)
sps.Equal(1, len(sps.sealedSegments[1].Segments))
sps.Equal(int64(3), sps.sealedSegments[1].Segments[0].SegmentID)
// test for IP metrics
req = &internalpb.SearchRequest{
MetricType: "IP",
PlaceholderGroup: bs,
PartitionIDs: []UniqueID{sps.targetPartition},
Topk: 100,
}
PruneSegments(context.TODO(), sps.partitionStats, req, nil, sps.schema, sps.sealedSegments, PruneInfo{0.25})
sps.Equal(1, len(sps.sealedSegments[0].Segments))
sps.Equal(int64(1), sps.sealedSegments[0].Segments[0].SegmentID)
sps.Equal(1, len(sps.sealedSegments[1].Segments))
sps.Equal(int64(3), sps.sealedSegments[1].Segments[0].SegmentID)
}
func TestSegmentPrunerSuite(t *testing.T) {
suite.Run(t, new(SegmentPrunerSuite))
}

View File

@ -60,7 +60,7 @@ func (suite *ReduceSuite) SetupTest() {
msgLength := 100
suite.rootPath = suite.T().Name()
chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)
initcore.InitRemoteChunkManager(paramtable.Get())

View File

@ -61,7 +61,7 @@ func (suite *RetrieveSuite) SetupTest() {
msgLength := 100
suite.rootPath = suite.T().Name()
chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)
initcore.InitRemoteChunkManager(paramtable.Get())

View File

@ -78,7 +78,7 @@ func (suite *SegmentLoaderSuite) SetupTest() {
// TODO:: cpp chunk manager not support local chunk manager
// suite.chunkManager = storage.NewLocalChunkManager(storage.RootPath(
// fmt.Sprintf("/tmp/milvus-ut/%d", rand.Int63())))
chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)
suite.loader = NewLoader(suite.manager, suite.chunkManager)
initcore.InitRemoteChunkManager(paramtable.Get())
@ -678,7 +678,7 @@ func (suite *SegmentLoaderDetailSuite) SetupTest() {
}
ctx := context.Background()
chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)
suite.loader = NewLoader(suite.manager, suite.chunkManager)
initcore.InitRemoteChunkManager(paramtable.Get())
@ -847,7 +847,7 @@ func (suite *SegmentLoaderV2Suite) SetupTest() {
// TODO:: cpp chunk manager not support local chunk manager
// suite.chunkManager = storage.NewLocalChunkManager(storage.RootPath(
// fmt.Sprintf("/tmp/milvus-ut/%d", rand.Int63())))
chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)
suite.loader = NewLoaderV2(suite.manager, suite.chunkManager)
initcore.InitRemoteChunkManager(paramtable.Get())

View File

@ -39,7 +39,7 @@ func (suite *SegmentSuite) SetupTest() {
msgLength := 100
suite.rootPath = suite.T().Name()
chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)
initcore.InitRemoteChunkManager(paramtable.Get())

View File

@ -254,6 +254,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
node.factory,
channel.GetSeekPosition().GetTimestamp(),
node.queryHook,
node.chunkManager,
)
if err != nil {
log.Warn("failed to create shard delegator", zap.Error(err))

View File

@ -116,7 +116,7 @@ func (suite *ServiceSuite) SetupTest() {
suite.msgStream = msgstream.NewMockMsgStream(suite.T())
// TODO:: cpp chunk manager not support local chunk manager
// suite.chunkManagerFactory = storage.NewChunkManagerFactory("local", storage.RootPath("/tmp/milvus-test"))
suite.chunkManagerFactory = segments.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
suite.chunkManagerFactory = storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
suite.factory.EXPECT().Init(mock.Anything).Return()
suite.factory.EXPECT().NewPersistentStorageChunkManager(mock.Anything).Return(suite.chunkManagerFactory.NewPersistentStorageChunkManager(ctx))

View File

@ -20,6 +20,14 @@ import "encoding/json"
type SegmentStats struct {
FieldStats []FieldStats `json:"fieldStats"`
NumRows int
}
func NewSegmentStats(fieldStats []FieldStats, rows int) *SegmentStats {
return &SegmentStats{
FieldStats: fieldStats,
NumRows: rows,
}
}
type PartitionStatsSnapshot struct {

View File

@ -38,6 +38,7 @@ import (
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -1247,3 +1248,17 @@ func Min(a, b int64) int64 {
}
return b
}
func NewTestChunkManagerFactory(params *paramtable.ComponentParam, rootPath string) *ChunkManagerFactory {
return NewChunkManagerFactory("minio",
RootPath(rootPath),
Address(params.MinioCfg.Address.GetValue()),
AccessKeyID(params.MinioCfg.AccessKeyID.GetValue()),
SecretAccessKeyID(params.MinioCfg.SecretAccessKey.GetValue()),
UseSSL(params.MinioCfg.UseSSL.GetAsBool()),
BucketName(params.MinioCfg.BucketName.GetValue()),
UseIAM(params.MinioCfg.UseIAM.GetAsBool()),
CloudProvider(params.MinioCfg.CloudProvider.GetValue()),
IAMEndpoint(params.MinioCfg.IAMEndpoint.GetValue()),
CreateBucket(true))
}

View File

@ -0,0 +1,50 @@
package clustering
import (
"encoding/binary"
"math"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/merr"
)
func CalcVectorDistance(dim int64, dataType schemapb.DataType, left []byte, right []float32, metric string) ([]float32, error) {
switch dataType {
case schemapb.DataType_FloatVector:
distance, err := distance.CalcFloatDistance(dim, DeserializeFloatVector(left), right, metric)
if err != nil {
return nil, err
}
return distance, nil
// todo support other vector type
case schemapb.DataType_BinaryVector:
case schemapb.DataType_Float16Vector:
case schemapb.DataType_BFloat16Vector:
default:
return nil, merr.ErrParameterInvalid
}
return nil, nil
}
func DeserializeFloatVector(data []byte) []float32 {
vectorLen := len(data) / 4 // Each float32 occupies 4 bytes
fv := make([]float32, vectorLen)
for i := 0; i < vectorLen; i++ {
bits := binary.LittleEndian.Uint32(data[i*4 : (i+1)*4])
fv[i] = math.Float32frombits(bits)
}
return fv
}
func SerializeFloatVector(fv []float32) []byte {
data := make([]byte, 0, 4*len(fv)) // float32 occupies 4 bytes
buf := make([]byte, 4)
for _, f := range fv {
binary.LittleEndian.PutUint32(buf, math.Float32bits(f))
data = append(data, buf...)
}
return data
}
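A quick round-trip check of the two helpers (a sketch; runnable from inside the milvus module, since the clustering package is internal):

package main

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/util/clustering"
)

func main() {
	vec := []float32{0.1, 0.2, 0.3}
	data := clustering.SerializeFloatVector(vec) // 12 little-endian bytes
	fmt.Println(clustering.DeserializeFloatVector(data)) // [0.1 0.2 0.3]
}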

View File

@ -0,0 +1,511 @@
package exprutil
import (
"math"
"strings"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/planpb"
)
type KeyType int64
const (
PartitionKey KeyType = iota
ClusteringKey KeyType = PartitionKey + 1
)
func ParseExprFromPlan(plan *planpb.PlanNode) (*planpb.Expr, error) {
node := plan.GetNode()
if node == nil {
return nil, errors.New("can't get expr from empty plan node")
}
var expr *planpb.Expr
switch node := node.(type) {
case *planpb.PlanNode_VectorAnns:
expr = node.VectorAnns.GetPredicates()
case *planpb.PlanNode_Query:
expr = node.Query.GetPredicates()
default:
return nil, errors.New("unsupported plan node type")
}
return expr, nil
}
func ParsePartitionKeysFromBinaryExpr(expr *planpb.BinaryExpr, keyType KeyType) ([]*planpb.GenericValue, bool) {
leftRes, leftInRange := ParseKeysFromExpr(expr.Left, keyType)
rightRes, rightInRange := ParseKeysFromExpr(expr.Right, keyType)
if expr.Op == planpb.BinaryExpr_LogicalAnd {
// case: partition_key_field in [7, 8] && partition_key > 8
if len(leftRes)+len(rightRes) > 0 {
leftRes = append(leftRes, rightRes...)
return leftRes, false
}
// case: other_field > 10 && partition_key_field > 8
return nil, leftInRange || rightInRange
}
if expr.Op == planpb.BinaryExpr_LogicalOr {
// case: partition_key_field in [7, 8] or partition_key > 8
if leftInRange || rightInRange {
return nil, true
}
// case: partition_key_field in [7, 8] or other_field > 10
leftRes = append(leftRes, rightRes...)
return leftRes, false
}
return nil, false
}
func ParsePartitionKeysFromUnaryExpr(expr *planpb.UnaryExpr, keyType KeyType) ([]*planpb.GenericValue, bool) {
res, partitionInRange := ParseKeysFromExpr(expr.GetChild(), keyType)
if expr.Op == planpb.UnaryExpr_Not {
// case: partition_key_field not in [7, 8]
if len(res) != 0 {
return nil, true
}
// case: other_field not in [10]
return nil, partitionInRange
}
// UnaryOp only includes "Not" for now
return res, partitionInRange
}
func ParsePartitionKeysFromTermExpr(expr *planpb.TermExpr, keyType KeyType) ([]*planpb.GenericValue, bool) {
if keyType == PartitionKey && expr.GetColumnInfo().GetIsPartitionKey() {
return expr.GetValues(), false
} else if keyType == ClusteringKey && expr.GetColumnInfo().GetIsClusteringKey() {
return expr.GetValues(), false
}
return nil, false
}
func ParsePartitionKeysFromUnaryRangeExpr(expr *planpb.UnaryRangeExpr, keyType KeyType) ([]*planpb.GenericValue, bool) {
if expr.GetOp() == planpb.OpType_Equal {
if expr.GetColumnInfo().GetIsPartitionKey() && keyType == PartitionKey ||
expr.GetColumnInfo().GetIsClusteringKey() && keyType == ClusteringKey {
return []*planpb.GenericValue{expr.Value}, false
}
}
return nil, true
}
func ParseKeysFromExpr(expr *planpb.Expr, keyType KeyType) ([]*planpb.GenericValue, bool) {
var res []*planpb.GenericValue
keyInRange := false
switch expr := expr.GetExpr().(type) {
case *planpb.Expr_BinaryExpr:
res, keyInRange = ParsePartitionKeysFromBinaryExpr(expr.BinaryExpr, keyType)
case *planpb.Expr_UnaryExpr:
res, keyInRange = ParsePartitionKeysFromUnaryExpr(expr.UnaryExpr, keyType)
case *planpb.Expr_TermExpr:
res, keyInRange = ParsePartitionKeysFromTermExpr(expr.TermExpr, keyType)
case *planpb.Expr_UnaryRangeExpr:
res, keyInRange = ParsePartitionKeysFromUnaryRangeExpr(expr.UnaryRangeExpr, keyType)
}
return res, keyInRange
}
func ParseKeys(expr *planpb.Expr, kType KeyType) []*planpb.GenericValue {
res, keyInRange := ParseKeysFromExpr(expr, kType)
if keyInRange {
res = nil
}
return res
}
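The proxy call sites earlier in this diff all consume this pair of functions the same way; condensed here as a sketch with error handling trimmed (resolveKeys is a hypothetical helper, not part of the commit):

// resolveKeys sketches how search/query/delete tasks use the new API.
func resolveKeys(plan *planpb.PlanNode) []*planpb.GenericValue {
	expr, err := exprutil.ParseExprFromPlan(plan)
	if err != nil {
		return nil // callers fall back to fanning out to all partitions
	}
	// a nil result means the filter does not pin the key to exact values
	return exprutil.ParseKeys(expr, exprutil.PartitionKey)
}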
type PlanRange struct {
lower *planpb.GenericValue
upper *planpb.GenericValue
includeLower bool
includeUpper bool
}
func (planRange *PlanRange) ToIntRange() *IntRange {
iRange := &IntRange{}
if planRange.lower == nil {
iRange.lower = math.MinInt64
iRange.includeLower = false
} else {
iRange.lower = planRange.lower.GetInt64Val()
iRange.includeLower = planRange.includeLower
}
if planRange.upper == nil {
iRange.upper = math.MaxInt64
iRange.includeUpper = false
} else {
iRange.upper = planRange.upper.GetInt64Val()
iRange.includeUpper = planRange.includeUpper
}
return iRange
}
func (planRange *PlanRange) ToStrRange() *StrRange {
sRange := &StrRange{}
if planRange.lower == nil {
sRange.lower = ""
sRange.includeLower = false
} else {
sRange.lower = planRange.lower.GetStringVal()
sRange.includeLower = planRange.includeLower
}
if planRange.upper == nil {
sRange.upper = ""
sRange.includeUpper = false
} else {
sRange.upper = planRange.upper.GetStringVal()
sRange.includeUpper = planRange.includeUpper
}
return sRange
}
type IntRange struct {
lower int64
upper int64
includeLower bool
includeUpper bool
}
func NewIntRange(l int64, r int64, includeL bool, includeR bool) *IntRange {
return &IntRange{
lower: l,
upper: r,
includeLower: includeL,
includeUpper: includeR,
}
}
func IntRangeOverlap(range1 *IntRange, range2 *IntRange) bool {
var leftBound int64
if range1.lower < range2.lower {
leftBound = range2.lower
} else {
leftBound = range1.lower
}
var rightBound int64
if range1.upper < range2.upper {
rightBound = range1.upper
} else {
rightBound = range2.upper
}
return leftBound <= rightBound
}
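Note that IntRangeOverlap compares bounds only and ignores the include flags, so touching ranges such as [1, 5) and [5, 9] still report an overlap; for pruning this errs on the safe side, since a false positive merely keeps an extra segment. A quick sketch (helper name illustrative):

func exampleIntOverlap() bool {
	a := NewIntRange(1, 5, true, false) // [1, 5)
	b := NewIntRange(5, 9, true, true)  // [5, 9]
	// max of lowers = 5 <= min of uppers = 5, so this returns true even
	// though 'a' excludes 5; at worst one extra segment gets scanned
	return IntRangeOverlap(a, b)
}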
type StrRange struct {
lower string
upper string
includeLower bool
includeUpper bool
}
func NewStrRange(l string, r string, includeL bool, includeR bool) *StrRange {
return &StrRange{
lower: l,
upper: r,
includeLower: includeL,
includeUpper: includeR,
}
}
func StrRangeOverlap(range1 *StrRange, range2 *StrRange) bool {
	var leftBound string
	if range1.lower < range2.lower {
		leftBound = range2.lower
	} else {
		leftBound = range1.lower
	}
	// an empty upper bound stands for "+infinity" (see ToStrRange), so an
	// unbounded side must never win the min-comparison below
	var rightBound string
	switch {
	case range1.upper == "":
		rightBound = range2.upper
	case range2.upper == "":
		rightBound = range1.upper
	case range1.upper < range2.upper:
		rightBound = range1.upper
	default:
		rightBound = range2.upper
	}
	// both ranges unbounded above: they always overlap
	if rightBound == "" {
		return true
	}
	return leftBound <= rightBound
}
/*
principles for range parsing
1. unary exprs like 'NOT' are not handled
2. 'or' exprs are not handled: whether the clustering key is involved or not, all possible pruning terminates
3. for any illogical 'and' expr, we check and terminate right away
4. Term and Range exprs are not handled at the same time
(see the usage sketch after ParseRanges below)
*/
func ParseRanges(expr *planpb.Expr, kType KeyType) ([]*PlanRange, bool) {
var res []*PlanRange
matchALL := true
switch expr := expr.GetExpr().(type) {
case *planpb.Expr_BinaryExpr:
res, matchALL = ParseRangesFromBinaryExpr(expr.BinaryExpr, kType)
case *planpb.Expr_UnaryRangeExpr:
res, matchALL = ParseRangesFromUnaryRangeExpr(expr.UnaryRangeExpr, kType)
case *planpb.Expr_TermExpr:
res, matchALL = ParseRangesFromTermExpr(expr.TermExpr, kType)
	case *planpb.Expr_UnaryExpr:
		// NOT is not handled; treat it as unable to parse any range
		res, matchALL = nil, true
	}
return res, matchALL
}
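Under these principles, a conjunctive filter on the clustering key collapses to a single range. A minimal sketch of the expected outcome, assuming expr was obtained via ParseExprFromPlan from a plan compiled for `cluster_key_field > 50 and cluster_key_field <= 100` (the helper name is illustrative):

func exampleParseRanges(expr *planpb.Expr) {
	ranges, matchALL := ParseRanges(expr, ClusteringKey)
	// expected: matchALL == false, len(ranges) == 1, and ranges[0] is (50, 100]
	_, _ = ranges, matchALL
}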
func ParseRangesFromBinaryExpr(expr *planpb.BinaryExpr, kType KeyType) ([]*PlanRange, bool) {
if expr.Op == planpb.BinaryExpr_LogicalOr {
return nil, true
}
_, leftIsTerm := expr.GetLeft().GetExpr().(*planpb.Expr_TermExpr)
_, rightIsTerm := expr.GetRight().GetExpr().(*planpb.Expr_TermExpr)
	if leftIsTerm || rightIsTerm {
		// either the left or the right child is a term query like `x in [1, 2, 3]`,
		// so terminate the prune process
		return nil, true
	}
}
leftRanges, leftALL := ParseRanges(expr.Left, kType)
rightRanges, rightALL := ParseRanges(expr.Right, kType)
if leftALL && rightALL {
return nil, true
} else if leftALL && !rightALL {
return rightRanges, rightALL
} else if rightALL && !leftALL {
return leftRanges, leftALL
}
	// only unary ranges or nested binary ranges are left;
	// calculate the intersection and return the resulting range.
	// each child is expected to yield exactly one range
	if len(leftRanges) != 1 || len(rightRanges) != 1 {
		return nil, true
	}
	intersected := Intersect(leftRanges[0], rightRanges[0])
	if intersected == nil {
		// empty intersection: give up pruning for this expr
		return nil, true
	}
	return []*PlanRange{intersected}, false
}
func ParseRangesFromUnaryRangeExpr(expr *planpb.UnaryRangeExpr, kType KeyType) ([]*PlanRange, bool) {
if expr.GetColumnInfo().GetIsPartitionKey() && kType == PartitionKey ||
expr.GetColumnInfo().GetIsClusteringKey() && kType == ClusteringKey {
switch expr.GetOp() {
case planpb.OpType_Equal:
{
return []*PlanRange{
{
lower: expr.Value,
upper: expr.Value,
includeLower: true,
includeUpper: true,
},
}, false
}
case planpb.OpType_GreaterThan:
{
return []*PlanRange{
{
lower: expr.Value,
upper: nil,
includeLower: false,
includeUpper: false,
},
}, false
}
case planpb.OpType_GreaterEqual:
{
return []*PlanRange{
{
lower: expr.Value,
upper: nil,
includeLower: true,
includeUpper: false,
},
}, false
}
case planpb.OpType_LessThan:
{
return []*PlanRange{
{
lower: nil,
upper: expr.Value,
includeLower: false,
includeUpper: false,
},
}, false
}
case planpb.OpType_LessEqual:
{
return []*PlanRange{
{
lower: nil,
upper: expr.Value,
includeLower: false,
includeUpper: true,
},
}, false
}
}
}
return nil, true
}
func ParseRangesFromTermExpr(expr *planpb.TermExpr, kType KeyType) ([]*PlanRange, bool) {
if expr.GetColumnInfo().GetIsPartitionKey() && kType == PartitionKey ||
expr.GetColumnInfo().GetIsClusteringKey() && kType == ClusteringKey {
res := make([]*PlanRange, 0)
for _, value := range expr.GetValues() {
res = append(res, &PlanRange{
lower: value,
upper: value,
includeLower: true,
includeUpper: true,
})
}
return res, false
}
return nil, true
}
var minusInfiniteInt = &planpb.GenericValue{
Val: &planpb.GenericValue_Int64Val{
Int64Val: math.MinInt64,
},
}
var positiveInfiniteInt = &planpb.GenericValue{
Val: &planpb.GenericValue_Int64Val{
Int64Val: math.MaxInt64,
},
}
var minStrVal = &planpb.GenericValue{
Val: &planpb.GenericValue_StringVal{
StringVal: "",
},
}
// maxStrVal carries a nil Val and acts as the "+infinity" sentinel for
// strings: compareGenericValue reports any concrete string as smaller than it
var maxStrVal = &planpb.GenericValue{}
func complementPlanRange(pr *PlanRange, dataType schemapb.DataType) *PlanRange {
if dataType == schemapb.DataType_Int64 {
if pr.lower == nil {
pr.lower = minusInfiniteInt
}
if pr.upper == nil {
pr.upper = positiveInfiniteInt
}
} else {
if pr.lower == nil {
pr.lower = minStrVal
}
if pr.upper == nil {
pr.upper = maxStrVal
}
}
return pr
}
func GetCommonDataType(a *PlanRange, b *PlanRange) schemapb.DataType {
var bound *planpb.GenericValue
if a.lower != nil {
bound = a.lower
} else if a.upper != nil {
bound = a.upper
}
if bound == nil {
if b.lower != nil {
bound = b.lower
} else if b.upper != nil {
bound = b.upper
}
}
if bound == nil {
return schemapb.DataType_None
}
switch bound.Val.(type) {
case *planpb.GenericValue_Int64Val:
{
return schemapb.DataType_Int64
}
case *planpb.GenericValue_StringVal:
{
return schemapb.DataType_VarChar
}
}
return schemapb.DataType_None
}
func Intersect(a *PlanRange, b *PlanRange) *PlanRange {
dataType := GetCommonDataType(a, b)
complementPlanRange(a, dataType)
complementPlanRange(b, dataType)
	// check whether 'a' and 'b' overlap at all
rightBound := minGenericValue(a.upper, b.upper)
leftBound := maxGenericValue(a.lower, b.lower)
if compareGenericValue(leftBound, rightBound) > 0 {
return nil
}
// Check if 'a' range ends exactly where 'b' range starts
if !a.includeUpper && !b.includeLower && (compareGenericValue(a.upper, b.lower) == 0) {
return nil
}
// Check if 'b' range ends exactly where 'a' range starts
if !b.includeUpper && !a.includeLower && (compareGenericValue(b.upper, a.lower) == 0) {
return nil
}
	return &PlanRange{
		lower: leftBound,
		upper: rightBound,
		// ORing the include flags can only widen the bounds; for pruning a
		// wider range is the safe direction, it never drops a matching segment
		includeLower: a.includeLower || b.includeLower,
		includeUpper: a.includeUpper || b.includeUpper,
	}
}
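A worked example of Intersect merging a lower-bounded and an upper-bounded range (values and helper name are illustrative):

func exampleIntersect() *PlanRange {
	lo := &planpb.GenericValue{Val: &planpb.GenericValue_Int64Val{Int64Val: 50}}
	hi := &planpb.GenericValue{Val: &planpb.GenericValue_Int64Val{Int64Val: 100}}
	a := &PlanRange{lower: lo, includeLower: true}  // key >= 50
	b := &PlanRange{upper: hi, includeUpper: false} // key < 100
	// the missing bounds are first complemented to +/- infinity, then the
	// tighter bound wins on each side, giving [50, 100)
	return Intersect(a, b)
}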
func compareGenericValue(left *planpb.GenericValue, right *planpb.GenericValue) int64 {
	// a missing operand compares as smaller; minGenericValue and
	// maxGenericValue rely on this for open bounds
	if right == nil || left == nil {
		return -1
	}
switch left.Val.(type) {
case *planpb.GenericValue_Int64Val:
if left.GetInt64Val() == right.GetInt64Val() {
return 0
} else if left.GetInt64Val() < right.GetInt64Val() {
return -1
} else {
return 1
}
case *planpb.GenericValue_StringVal:
if right.Val == nil {
return -1
}
return int64(strings.Compare(left.GetStringVal(), right.GetStringVal()))
}
return 0
}
func minGenericValue(left *planpb.GenericValue, right *planpb.GenericValue) *planpb.GenericValue {
if compareGenericValue(left, right) < 0 {
return left
}
return right
}
func maxGenericValue(left *planpb.GenericValue, right *planpb.GenericValue) *planpb.GenericValue {
if compareGenericValue(left, right) >= 0 {
return left
}
return right
}

View File

@ -0,0 +1,279 @@
package exprutil
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/util/testutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func TestParsePartitionKeys(t *testing.T) {
prefix := "TestParsePartitionKeys"
collectionName := prefix + funcutil.GenRandomStr()
fieldName2Type := make(map[string]schemapb.DataType)
fieldName2Type["int64_field"] = schemapb.DataType_Int64
fieldName2Type["varChar_field"] = schemapb.DataType_VarChar
fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector
schema := testutil.ConstructCollectionSchemaByDataType(collectionName, fieldName2Type,
"int64_field", false, 8)
partitionKeyField := &schemapb.FieldSchema{
Name: "partition_key_field",
DataType: schemapb.DataType_Int64,
IsPartitionKey: true,
}
schema.Fields = append(schema.Fields, partitionKeyField)
fieldID := common.StartOfUserFieldID
for _, field := range schema.Fields {
field.FieldID = int64(fieldID)
fieldID++
}
schemaHelper, err := typeutil.CreateSchemaHelper(schema)
require.NoError(t, err)
queryInfo := &planpb.QueryInfo{
Topk: 10,
MetricType: "L2",
SearchParams: "",
RoundDecimal: -1,
}
type testCase struct {
name string
expr string
expected int
validPartitionKeys []int64
invalidPartitionKeys []int64
}
cases := []testCase{
{
name: "binary_expr_and with term",
expr: "partition_key_field in [7, 8] && int64_field >= 10",
expected: 2,
validPartitionKeys: []int64{7, 8},
invalidPartitionKeys: []int64{},
},
{
name: "binary_expr_and with equal",
expr: "partition_key_field == 7 && int64_field >= 10",
expected: 1,
validPartitionKeys: []int64{7},
invalidPartitionKeys: []int64{},
},
{
name: "binary_expr_and with term2",
expr: "partition_key_field in [7, 8] && int64_field == 10",
expected: 2,
validPartitionKeys: []int64{7, 8},
invalidPartitionKeys: []int64{10},
},
{
name: "binary_expr_and with partition key in range",
expr: "partition_key_field in [7, 8] && partition_key_field > 9",
expected: 2,
validPartitionKeys: []int64{7, 8},
invalidPartitionKeys: []int64{9},
},
{
name: "binary_expr_and with partition key in range2",
expr: "int64_field == 10 && partition_key_field > 9",
expected: 0,
validPartitionKeys: []int64{},
invalidPartitionKeys: []int64{},
},
{
name: "binary_expr_and with term and not",
expr: "partition_key_field in [7, 8] && partition_key_field not in [10, 20]",
expected: 2,
validPartitionKeys: []int64{7, 8},
invalidPartitionKeys: []int64{10, 20},
},
{
name: "binary_expr_or with term and not",
expr: "partition_key_field in [7, 8] or partition_key_field not in [10, 20]",
expected: 0,
validPartitionKeys: []int64{},
invalidPartitionKeys: []int64{},
},
{
name: "binary_expr_or with term and not 2",
expr: "partition_key_field in [7, 8] or int64_field not in [10, 20]",
expected: 2,
validPartitionKeys: []int64{7, 8},
invalidPartitionKeys: []int64{10, 20},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
// test search plan
searchPlan, err := planparserv2.CreateSearchPlan(schemaHelper, tc.expr, "fvec_field", queryInfo)
assert.NoError(t, err)
expr, err := ParseExprFromPlan(searchPlan)
assert.NoError(t, err)
partitionKeys := ParseKeys(expr, PartitionKey)
assert.Equal(t, tc.expected, len(partitionKeys))
for _, key := range partitionKeys {
int64Val := key.Val.(*planpb.GenericValue_Int64Val).Int64Val
assert.Contains(t, tc.validPartitionKeys, int64Val)
assert.NotContains(t, tc.invalidPartitionKeys, int64Val)
}
// test query plan
queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, tc.expr)
assert.NoError(t, err)
expr, err = ParseExprFromPlan(queryPlan)
assert.NoError(t, err)
partitionKeys = ParseKeys(expr, PartitionKey)
assert.Equal(t, tc.expected, len(partitionKeys))
for _, key := range partitionKeys {
int64Val := key.Val.(*planpb.GenericValue_Int64Val).Int64Val
assert.Contains(t, tc.validPartitionKeys, int64Val)
assert.NotContains(t, tc.invalidPartitionKeys, int64Val)
}
})
}
}
func TestParseIntRanges(t *testing.T) {
prefix := "TestParseIntRanges"
clusterKeyField := "cluster_key_field"
collectionName := prefix + funcutil.GenRandomStr()
fieldName2Type := make(map[string]schemapb.DataType)
fieldName2Type["int64_field"] = schemapb.DataType_Int64
fieldName2Type["varChar_field"] = schemapb.DataType_VarChar
fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector
schema := testutil.ConstructCollectionSchemaByDataType(collectionName, fieldName2Type,
"int64_field", false, 8)
clusterKeyFieldSchema := &schemapb.FieldSchema{
Name: clusterKeyField,
DataType: schemapb.DataType_Int64,
IsClusteringKey: true,
}
schema.Fields = append(schema.Fields, clusterKeyFieldSchema)
fieldID := common.StartOfUserFieldID
for _, field := range schema.Fields {
field.FieldID = int64(fieldID)
fieldID++
}
schemaHelper, err := typeutil.CreateSchemaHelper(schema)
require.NoError(t, err)
// test query plan
{
expr := "cluster_key_field > 50"
queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr)
assert.NoError(t, err)
planExpr, err := ParseExprFromPlan(queryPlan)
assert.NoError(t, err)
parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey)
assert.False(t, matchALL)
assert.Equal(t, 1, len(parsedRanges))
range0 := parsedRanges[0]
assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_Int64Val).Int64Val, int64(50))
assert.Nil(t, range0.upper)
assert.Equal(t, range0.includeLower, false)
assert.Equal(t, range0.includeUpper, false)
}
// test binary query plan
{
expr := "cluster_key_field > 50 and cluster_key_field <= 100"
queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr)
assert.NoError(t, err)
planExpr, err := ParseExprFromPlan(queryPlan)
assert.NoError(t, err)
parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey)
assert.False(t, matchALL)
assert.Equal(t, 1, len(parsedRanges))
range0 := parsedRanges[0]
assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_Int64Val).Int64Val, int64(50))
assert.Equal(t, false, range0.includeLower)
assert.Equal(t, true, range0.includeUpper)
}
// test binary query plan
{
expr := "cluster_key_field >= 50 and cluster_key_field < 100"
queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr)
assert.NoError(t, err)
planExpr, err := ParseExprFromPlan(queryPlan)
assert.NoError(t, err)
parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey)
assert.False(t, matchALL)
assert.Equal(t, 1, len(parsedRanges))
range0 := parsedRanges[0]
assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_Int64Val).Int64Val, int64(50))
assert.Equal(t, true, range0.includeLower)
assert.Equal(t, false, range0.includeUpper)
}
// test term query plan
{
expr := "cluster_key_field in [100]"
queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr)
assert.NoError(t, err)
planExpr, err := ParseExprFromPlan(queryPlan)
assert.NoError(t, err)
parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey)
assert.False(t, matchALL)
assert.Equal(t, 1, len(parsedRanges))
range0 := parsedRanges[0]
assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_Int64Val).Int64Val, int64(100))
assert.Equal(t, true, range0.includeLower)
assert.Equal(t, true, range0.includeUpper)
}
}
func TestParseStrRanges(t *testing.T) {
prefix := "TestParseStrRanges"
clusterKeyField := "cluster_key_field"
collectionName := prefix + funcutil.GenRandomStr()
fieldName2Type := make(map[string]schemapb.DataType)
fieldName2Type["int64_field"] = schemapb.DataType_Int64
fieldName2Type["varChar_field"] = schemapb.DataType_VarChar
fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector
schema := testutil.ConstructCollectionSchemaByDataType(collectionName, fieldName2Type,
"int64_field", false, 8)
clusterKeyFieldSchema := &schemapb.FieldSchema{
Name: clusterKeyField,
DataType: schemapb.DataType_VarChar,
IsClusteringKey: true,
}
schema.Fields = append(schema.Fields, clusterKeyFieldSchema)
fieldID := common.StartOfUserFieldID
for _, field := range schema.Fields {
field.FieldID = int64(fieldID)
fieldID++
}
schemaHelper, err := typeutil.CreateSchemaHelper(schema)
require.NoError(t, err)
// test query plan
{
expr := "cluster_key_field >= \"aaa\""
queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr)
assert.NoError(t, err)
planExpr, err := ParseExprFromPlan(queryPlan)
assert.NoError(t, err)
parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey)
assert.False(t, matchALL)
assert.Equal(t, 1, len(parsedRanges))
range0 := parsedRanges[0]
assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_StringVal).StringVal, "aaa")
assert.Nil(t, range0.upper)
assert.Equal(t, range0.includeLower, true)
assert.Equal(t, range0.includeUpper, false)
}
}

View File

@ -0,0 +1,90 @@
package testutil
import (
"strconv"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
)
const (
testMaxVarCharLength = 100
)
func ConstructCollectionSchemaWithKeys(collectionName string,
fieldName2DataType map[string]schemapb.DataType,
primaryFieldName string,
partitionKeyFieldName string,
clusteringKeyFieldName string,
autoID bool,
dim int,
) *schemapb.CollectionSchema {
schema := ConstructCollectionSchemaByDataType(collectionName,
fieldName2DataType,
primaryFieldName,
autoID,
dim)
for _, field := range schema.Fields {
if field.Name == partitionKeyFieldName {
field.IsPartitionKey = true
}
if field.Name == clusteringKeyFieldName {
field.IsClusteringKey = true
}
}
return schema
}
func isVectorType(dataType schemapb.DataType) bool {
return dataType == schemapb.DataType_FloatVector ||
dataType == schemapb.DataType_BinaryVector ||
dataType == schemapb.DataType_Float16Vector ||
dataType == schemapb.DataType_BFloat16Vector
}
func ConstructCollectionSchemaByDataType(collectionName string,
fieldName2DataType map[string]schemapb.DataType,
primaryFieldName string,
autoID bool,
dim int,
) *schemapb.CollectionSchema {
fieldsSchema := make([]*schemapb.FieldSchema, 0)
fieldIdx := int64(0)
for fieldName, dataType := range fieldName2DataType {
fieldSchema := &schemapb.FieldSchema{
Name: fieldName,
DataType: dataType,
FieldID: fieldIdx,
}
fieldIdx += 1
if isVectorType(dataType) {
fieldSchema.TypeParams = []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: strconv.Itoa(dim),
},
}
}
if dataType == schemapb.DataType_VarChar {
fieldSchema.TypeParams = []*commonpb.KeyValuePair{
{
Key: common.MaxLengthKey,
Value: strconv.Itoa(testMaxVarCharLength),
},
}
}
if fieldName == primaryFieldName {
fieldSchema.IsPrimaryKey = true
fieldSchema.AutoID = autoID
}
fieldsSchema = append(fieldsSchema, fieldSchema)
}
return &schemapb.CollectionSchema{
Name: collectionName,
Fields: fieldsSchema,
}
}
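A usage sketch of the new helper (collection and field names are illustrative):

func exampleSchemaWithKeys() *schemapb.CollectionSchema {
	return ConstructCollectionSchemaWithKeys("demo_collection",
		map[string]schemapb.DataType{
			"pk_field":      schemapb.DataType_Int64,
			"part_field":    schemapb.DataType_Int64,
			"cluster_field": schemapb.DataType_VarChar,
			"vec_field":     schemapb.DataType_FloatVector,
		},
		"pk_field",      // primary key
		"part_field",    // becomes IsPartitionKey = true
		"cluster_field", // becomes IsClusteringKey = true
		false,           // autoID
		8,               // vector dim
	)
}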

View File

@ -119,3 +119,12 @@ func convertToArrowType(dataType schemapb.DataType) (arrow.DataType, error) {
return nil, merr.WrapErrParameterInvalidMsg("unknown type %v", dataType.String())
}
}
func GetClusteringKeyField(fields []*schemapb.FieldSchema) *schemapb.FieldSchema {
for _, field := range fields {
if field.IsClusteringKey {
return field
}
}
return nil
}

View File

@ -88,6 +88,9 @@ const (
// SegmentIndexPath storage path const for segment index files.
SegmentIndexPath = `index_files`
// PartitionStatsPath storage path const for partition stats files
PartitionStatsPath = `part_stats`
)
// Search, Index parameter keys

View File

@ -2039,6 +2039,7 @@ type queryNodeConfig struct {
FlowGraphMaxParallelism ParamItem `refreshable:"false"`
MemoryIndexLoadPredictMemoryUsageFactor ParamItem `refreshable:"true"`
EnableSegmentPrune ParamItem `refreshable:"false"`
}
func (p *queryNodeConfig) init(base *BaseTable) {
@ -2512,6 +2513,13 @@ Max read concurrency must greater than or equal to 1, and less than or equal to
Doc: "memory usage prediction factor for memory index loaded",
}
p.MemoryIndexLoadPredictMemoryUsageFactor.Init(base.mgr)
p.EnableSegmentPrune = ParamItem{
Key: "queryNode.enableSegmentPrune",
Version: "2.3.4",
DefaultValue: "false",
Doc: "use partition prune function on shard delegator",
}
p.EnableSegmentPrune.Init(base.mgr)
}
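On the caller side the switch would typically be read through the usual ParamItem accessor; a hedged sketch of how a shard delegator might gate the pruner (GetAsBool is the standard accessor, but the surrounding delegator code is not part of this hunk):

// illustrative only: gate segment pruning on the new config switch
if paramtable.Get().QueryNodeCfg.EnableSegmentPrune.GetAsBool() {
	// run partition-stats based segment pruning before dispatching
	// the request to individual segments
}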
// /////////////////////////////////////////////////////////////////////////////