mirror of https://github.com/milvus-io/milvus.git
parent c27db43ba7
commit c3264ca3e3
@@ -370,6 +370,7 @@ queryNode:
    serverMaxRecvSize: 268435456
    clientMaxSendSize: 268435456
    clientMaxRecvSize: 536870912
  enableSegmentPrune: false # use partition prune function on shard delegator

indexCoord:
  bindIndexNodeMode:
@@ -42,13 +42,14 @@ func (v *ParserVisitor) translateIdentifier(identifier string) (*ExprWithType, e
        Expr: &planpb.Expr_ColumnExpr{
            ColumnExpr: &planpb.ColumnExpr{
                Info: &planpb.ColumnInfo{
                    FieldId: field.FieldID,
                    DataType: field.DataType,
                    IsPrimaryKey: field.IsPrimaryKey,
                    IsAutoID: field.AutoID,
                    NestedPath: nestedPath,
                    IsPartitionKey: field.IsPartitionKey,
                    ElementType: field.GetElementType(),
                    FieldId: field.FieldID,
                    DataType: field.DataType,
                    IsPrimaryKey: field.IsPrimaryKey,
                    IsAutoID: field.AutoID,
                    NestedPath: nestedPath,
                    IsPartitionKey: field.IsPartitionKey,
                    IsClusteringKey: field.IsClusteringKey,
                    ElementType: field.GetElementType(),
                },
            },
        },

@@ -71,6 +71,7 @@ message ColumnInfo {
    repeated string nested_path = 5;
    bool is_partition_key = 6;
    schema.DataType element_type = 7;
    bool is_clustering_key = 8;
}

message ColumnExpr {
@@ -1,114 +0,0 @@
package proxy

import (
    "github.com/cockroachdb/errors"

    "github.com/milvus-io/milvus/internal/proto/planpb"
)

func ParseExprFromPlan(plan *planpb.PlanNode) (*planpb.Expr, error) {
    node := plan.GetNode()

    if node == nil {
        return nil, errors.New("can't get expr from empty plan node")
    }

    var expr *planpb.Expr
    switch node := node.(type) {
    case *planpb.PlanNode_VectorAnns:
        expr = node.VectorAnns.GetPredicates()
    case *planpb.PlanNode_Query:
        expr = node.Query.GetPredicates()
    default:
        return nil, errors.New("unsupported plan node type")
    }

    return expr, nil
}

func ParsePartitionKeysFromBinaryExpr(expr *planpb.BinaryExpr) ([]*planpb.GenericValue, bool) {
    leftRes, leftInRange := ParsePartitionKeysFromExpr(expr.Left)
    RightRes, rightInRange := ParsePartitionKeysFromExpr(expr.Right)

    if expr.Op == planpb.BinaryExpr_LogicalAnd {
        // case: partition_key_field in [7, 8] && partition_key > 8
        if len(leftRes)+len(RightRes) > 0 {
            leftRes = append(leftRes, RightRes...)
            return leftRes, false
        }

        // case: other_field > 10 && partition_key_field > 8
        return nil, leftInRange || rightInRange
    }

    if expr.Op == planpb.BinaryExpr_LogicalOr {
        // case: partition_key_field in [7, 8] or partition_key > 8
        if leftInRange || rightInRange {
            return nil, true
        }

        // case: partition_key_field in [7, 8] or other_field > 10
        leftRes = append(leftRes, RightRes...)
        return leftRes, false
    }

    return nil, false
}

func ParsePartitionKeysFromUnaryExpr(expr *planpb.UnaryExpr) ([]*planpb.GenericValue, bool) {
    res, partitionInRange := ParsePartitionKeysFromExpr(expr.GetChild())
    if expr.Op == planpb.UnaryExpr_Not {
        // case: partition_key_field not in [7, 8]
        if len(res) != 0 {
            return nil, true
        }

        // case: other_field not in [10]
        return nil, partitionInRange
    }

    // UnaryOp only includes "Not" for now
    return res, partitionInRange
}

func ParsePartitionKeysFromTermExpr(expr *planpb.TermExpr) ([]*planpb.GenericValue, bool) {
    if expr.GetColumnInfo().GetIsPartitionKey() {
        return expr.GetValues(), false
    }

    return nil, false
}

func ParsePartitionKeysFromUnaryRangeExpr(expr *planpb.UnaryRangeExpr) ([]*planpb.GenericValue, bool) {
    if expr.GetColumnInfo().GetIsPartitionKey() && expr.GetOp() == planpb.OpType_Equal {
        return []*planpb.GenericValue{expr.Value}, false
    }

    return nil, true
}

func ParsePartitionKeysFromExpr(expr *planpb.Expr) ([]*planpb.GenericValue, bool) {
    var res []*planpb.GenericValue
    partitionKeyInRange := false
    switch expr := expr.GetExpr().(type) {
    case *planpb.Expr_BinaryExpr:
        res, partitionKeyInRange = ParsePartitionKeysFromBinaryExpr(expr.BinaryExpr)
    case *planpb.Expr_UnaryExpr:
        res, partitionKeyInRange = ParsePartitionKeysFromUnaryExpr(expr.UnaryExpr)
    case *planpb.Expr_TermExpr:
        res, partitionKeyInRange = ParsePartitionKeysFromTermExpr(expr.TermExpr)
    case *planpb.Expr_UnaryRangeExpr:
        res, partitionKeyInRange = ParsePartitionKeysFromUnaryRangeExpr(expr.UnaryRangeExpr)
    }

    return res, partitionKeyInRange
}

func ParsePartitionKeys(expr *planpb.Expr) []*planpb.GenericValue {
    res, partitionKeyInRange := ParsePartitionKeysFromExpr(expr)
    if partitionKeyInRange {
        res = nil
    }

    return res
}

@@ -1,143 +0,0 @@
package proxy

import (
    "testing"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"

    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    "github.com/milvus-io/milvus/internal/parser/planparserv2"
    "github.com/milvus-io/milvus/internal/proto/planpb"
    "github.com/milvus-io/milvus/pkg/common"
    "github.com/milvus-io/milvus/pkg/util/funcutil"
    "github.com/milvus-io/milvus/pkg/util/typeutil"
)

func TestParsePartitionKeys(t *testing.T) {
    prefix := "TestParsePartitionKeys"
    collectionName := prefix + funcutil.GenRandomStr()

    fieldName2Type := make(map[string]schemapb.DataType)
    fieldName2Type["int64_field"] = schemapb.DataType_Int64
    fieldName2Type["varChar_field"] = schemapb.DataType_VarChar
    fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector
    schema := constructCollectionSchemaByDataType(collectionName, fieldName2Type, "int64_field", false)
    partitionKeyField := &schemapb.FieldSchema{
        Name: "partition_key_field",
        DataType: schemapb.DataType_Int64,
        IsPartitionKey: true,
    }
    schema.Fields = append(schema.Fields, partitionKeyField)

    schemaHelper, err := typeutil.CreateSchemaHelper(schema)
    require.NoError(t, err)
    fieldID := common.StartOfUserFieldID
    for _, field := range schema.Fields {
        field.FieldID = int64(fieldID)
        fieldID++
    }

    queryInfo := &planpb.QueryInfo{
        Topk: 10,
        MetricType: "L2",
        SearchParams: "",
        RoundDecimal: -1,
    }

    type testCase struct {
        name string
        expr string
        expected int
        validPartitionKeys []int64
        invalidPartitionKeys []int64
    }
    cases := []testCase{
        {
            name: "binary_expr_and with term",
            expr: "partition_key_field in [7, 8] && int64_field >= 10",
            expected: 2,
            validPartitionKeys: []int64{7, 8},
            invalidPartitionKeys: []int64{},
        },
        {
            name: "binary_expr_and with equal",
            expr: "partition_key_field == 7 && int64_field >= 10",
            expected: 1,
            validPartitionKeys: []int64{7},
            invalidPartitionKeys: []int64{},
        },
        {
            name: "binary_expr_and with term2",
            expr: "partition_key_field in [7, 8] && int64_field == 10",
            expected: 2,
            validPartitionKeys: []int64{7, 8},
            invalidPartitionKeys: []int64{10},
        },
        {
            name: "binary_expr_and with partition key in range",
            expr: "partition_key_field in [7, 8] && partition_key_field > 9",
            expected: 2,
            validPartitionKeys: []int64{7, 8},
            invalidPartitionKeys: []int64{9},
        },
        {
            name: "binary_expr_and with partition key in range2",
            expr: "int64_field == 10 && partition_key_field > 9",
            expected: 0,
            validPartitionKeys: []int64{},
            invalidPartitionKeys: []int64{},
        },
        {
            name: "binary_expr_and with term and not",
            expr: "partition_key_field in [7, 8] && partition_key_field not in [10, 20]",
            expected: 2,
            validPartitionKeys: []int64{7, 8},
            invalidPartitionKeys: []int64{10, 20},
        },
        {
            name: "binary_expr_or with term and not",
            expr: "partition_key_field in [7, 8] or partition_key_field not in [10, 20]",
            expected: 0,
            validPartitionKeys: []int64{},
            invalidPartitionKeys: []int64{},
        },
        {
            name: "binary_expr_or with term and not 2",
            expr: "partition_key_field in [7, 8] or int64_field not in [10, 20]",
            expected: 2,
            validPartitionKeys: []int64{7, 8},
            invalidPartitionKeys: []int64{10, 20},
        },
    }

    for _, tc := range cases {
        t.Run(tc.name, func(t *testing.T) {
            // test search plan
            searchPlan, err := planparserv2.CreateSearchPlan(schemaHelper, tc.expr, "fvec_field", queryInfo)
            assert.NoError(t, err)
            expr, err := ParseExprFromPlan(searchPlan)
            assert.NoError(t, err)
            partitionKeys := ParsePartitionKeys(expr)
            assert.Equal(t, tc.expected, len(partitionKeys))
            for _, key := range partitionKeys {
                int64Val := key.Val.(*planpb.GenericValue_Int64Val).Int64Val
                assert.Contains(t, tc.validPartitionKeys, int64Val)
                assert.NotContains(t, tc.invalidPartitionKeys, int64Val)
            }

            // test query plan
            queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, tc.expr)
            assert.NoError(t, err)
            expr, err = ParseExprFromPlan(queryPlan)
            assert.NoError(t, err)
            partitionKeys = ParsePartitionKeys(expr)
            assert.Equal(t, tc.expected, len(partitionKeys))
            for _, key := range partitionKeys {
                int64Val := key.Val.(*planpb.GenericValue_Int64Val).Int64Val
                assert.Contains(t, tc.validPartitionKeys, int64Val)
                assert.NotContains(t, tc.invalidPartitionKeys, int64Val)
            }
        })
    }
}
@@ -12,6 +12,7 @@ import (
    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus/internal/parser/planparserv2"
    "github.com/milvus-io/milvus/internal/util/exprutil"
    "github.com/milvus-io/milvus/pkg/log"
    "github.com/milvus-io/milvus/pkg/util/funcutil"
    "github.com/milvus-io/milvus/pkg/util/merr"

@@ -92,12 +93,12 @@ func initSearchRequest(ctx context.Context, t *searchTask) error {
        zap.String("anns field", annsField), zap.Any("query info", queryInfo))

    if t.partitionKeyMode {
        expr, err := ParseExprFromPlan(plan)
        expr, err := exprutil.ParseExprFromPlan(plan)
        if err != nil {
            log.Warn("failed to parse expr", zap.Error(err))
            return err
        }
        partitionKeys := ParsePartitionKeys(expr)
        partitionKeys := exprutil.ParseKeys(expr, exprutil.PartitionKey)
        hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.collectionName, partitionKeys)
        if err != nil {
            log.Warn("failed to assign partition keys", zap.Error(err))
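
The search, delete, and query task hunks in this commit all switch to the same exprutil-based pattern. A minimal sketch of that shared pattern, assuming only the exprutil helpers shown in this diff and the existing assignPartitionKeys proxy helper (the wrapper name prunePartitionsByKey is hypothetical, logging trimmed):

func prunePartitionsByKey(ctx context.Context, dbName, collectionName string, plan *planpb.PlanNode) ([]string, error) {
    // pull the predicate expression out of the serialized search/query plan
    expr, err := exprutil.ParseExprFromPlan(plan)
    if err != nil {
        return nil, err
    }
    // collect the partition-key values the expression pins down, if any
    partitionKeys := exprutil.ParseKeys(expr, exprutil.PartitionKey)
    // map those key values onto the hashed physical partition names
    return assignPartitionKeys(ctx, dbName, collectionName, partitionKeys)
}
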
@@ -21,6 +21,7 @@ import (
    "github.com/milvus-io/milvus/internal/proto/planpb"
    "github.com/milvus-io/milvus/internal/proto/querypb"
    "github.com/milvus-io/milvus/internal/types"
    "github.com/milvus-io/milvus/internal/util/exprutil"
    "github.com/milvus-io/milvus/pkg/common"
    "github.com/milvus-io/milvus/pkg/log"
    "github.com/milvus-io/milvus/pkg/mq/msgstream"

@@ -356,11 +357,11 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe

    // optimize query when partitionKey on
    if dr.partitionKeyMode {
        expr, err := ParseExprFromPlan(plan)
        expr, err := exprutil.ParseExprFromPlan(plan)
        if err != nil {
            return err
        }
        partitionKeys := ParsePartitionKeys(expr)
        partitionKeys := exprutil.ParseKeys(expr, exprutil.PartitionKey)
        hashedPartitionNames, err := assignPartitionKeys(ctx, dr.req.GetDbName(), dr.req.GetCollectionName(), partitionKeys)
        if err != nil {
            return err
@@ -19,6 +19,7 @@ import (
    "github.com/milvus-io/milvus/internal/proto/planpb"
    "github.com/milvus-io/milvus/internal/proto/querypb"
    "github.com/milvus-io/milvus/internal/types"
    "github.com/milvus-io/milvus/internal/util/exprutil"
    typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
    "github.com/milvus-io/milvus/pkg/common"
    "github.com/milvus-io/milvus/pkg/log"

@@ -358,11 +359,11 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
    if !t.reQuery {
        partitionNames := t.request.GetPartitionNames()
        if t.partitionKeyMode {
            expr, err := ParseExprFromPlan(t.plan)
            expr, err := exprutil.ParseExprFromPlan(t.plan)
            if err != nil {
                return err
            }
            partitionKeys := ParsePartitionKeys(expr)
            partitionKeys := exprutil.ParseKeys(expr, exprutil.PartitionKey)
            hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.request.CollectionName, partitionKeys)
            if err != nil {
                return err
@@ -20,6 +20,8 @@ package delegator

import (
    "context"
    "fmt"
    "path"
    "strconv"
    "sync"
    "time"

@@ -42,6 +44,7 @@ import (
    "github.com/milvus-io/milvus/internal/querynodev2/tsafe"
    "github.com/milvus-io/milvus/internal/storage"
    "github.com/milvus-io/milvus/internal/util/streamrpc"
    "github.com/milvus-io/milvus/pkg/common"
    "github.com/milvus-io/milvus/pkg/log"
    "github.com/milvus-io/milvus/pkg/metrics"
    "github.com/milvus-io/milvus/pkg/mq/msgstream"

@@ -49,6 +52,7 @@ import (
    "github.com/milvus-io/milvus/pkg/util/funcutil"
    "github.com/milvus-io/milvus/pkg/util/lifetime"
    "github.com/milvus-io/milvus/pkg/util/merr"
    "github.com/milvus-io/milvus/pkg/util/metautil"
    "github.com/milvus-io/milvus/pkg/util/paramtable"
    "github.com/milvus-io/milvus/pkg/util/timerecord"
    "github.com/milvus-io/milvus/pkg/util/tsoutil"
@@ -115,7 +119,9 @@ type shardDelegator struct {
    tsCond      *sync.Cond
    latestTsafe *atomic.Uint64
    // queryHook
    queryHook optimizers.QueryHook
    queryHook      optimizers.QueryHook
    partitionStats map[UniqueID]*storage.PartitionStatsSnapshot
    chunkManager   storage.ChunkManager
}

// getLogger returns the zap logger with pre-defined shard attributes.

@@ -203,6 +209,9 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest
        log.Warn("failed to optimize search params", zap.Error(err))
        return nil, err
    }
    if paramtable.Get().QueryNodeCfg.EnableSegmentPrune.GetAsBool() {
        PruneSegments(ctx, sd.partitionStats, req.GetReq(), nil, sd.collection.Schema(), sealed, PruneInfo{filterRatio: defaultFilterRatio})
    }

    tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest)
    if err != nil {

@@ -485,12 +494,17 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
        return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
    }
    defer sd.distribution.Unpin(version)
    existPartitions := sd.collection.GetPartitions()
    growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
        return funcutil.SliceContain(existPartitions, segment.PartitionID)
    })
    if req.Req.IgnoreGrowing {
        growing = []SegmentEntry{}
    } else {
        existPartitions := sd.collection.GetPartitions()
        growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
            return funcutil.SliceContain(existPartitions, segment.PartitionID)
        })
    }

    if paramtable.Get().QueryNodeCfg.EnableSegmentPrune.GetAsBool() {
        PruneSegments(ctx, sd.partitionStats, nil, req.GetReq(), sd.collection.Schema(), sealed, PruneInfo{defaultFilterRatio})
    }

    sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) })

@@ -774,10 +788,72 @@ func (sd *shardDelegator) Close() {
    sd.lifetime.Wait()
}

// As partition stats is an optimization for search/query which is not mandatory for milvus instance,
// loading partitionStats will be a try-best process and will skip+logError when running across errors rather than
// return an error status
func (sd *shardDelegator) maybeReloadPartitionStats(ctx context.Context, partIDs ...UniqueID) {
    var partsToReload []UniqueID
    if len(partIDs) > 0 {
        partsToReload = partIDs
    } else {
        partsToReload = append(partsToReload, sd.collection.GetPartitions()...)
    }

    colID := sd.Collection()
    findMaxVersion := func(filePaths []string) (int64, string) {
        maxVersion := int64(-1)
        maxVersionFilePath := ""
        for _, filePath := range filePaths {
            versionStr := path.Base(filePath)
            version, err := strconv.ParseInt(versionStr, 10, 64)
            if err != nil {
                continue
            }
            if version > maxVersion {
                maxVersion = version
                maxVersionFilePath = filePath
            }
        }
        return maxVersion, maxVersionFilePath
    }
    for _, partID := range partsToReload {
        idPath := metautil.JoinIDPath(colID, partID)
        idPath = path.Join(idPath, sd.vchannelName)
        statsPathPrefix := path.Join(sd.chunkManager.RootPath(), common.PartitionStatsPath, idPath)
        filePaths, _, err := sd.chunkManager.ListWithPrefix(ctx, statsPathPrefix, true)
        if err != nil {
            log.Error("Skip initializing partition stats for failing to list files with prefix",
                zap.String("statsPathPrefix", statsPathPrefix))
            continue
        }
        maxVersion, maxVersionFilePath := findMaxVersion(filePaths)
        if maxVersion < 0 {
            log.Info("failed to find valid partition stats file for partition", zap.Int64("partitionID", partID))
            continue
        }
        partStats, exists := sd.partitionStats[partID]
        if !exists || (exists && partStats.GetVersion() < maxVersion) {
            statsBytes, err := sd.chunkManager.Read(ctx, maxVersionFilePath)
            if err != nil {
                log.Error("failed to read stats file from object storage", zap.String("path", maxVersionFilePath))
                continue
            }
            partStats, err := storage.DeserializePartitionsStatsSnapshot(statsBytes)
            if err != nil {
                log.Error("failed to parse partition stats from bytes", zap.Int("bytes_length", len(statsBytes)))
                continue
            }
            sd.partitionStats[partID] = partStats
            partStats.SetVersion(maxVersion)
            log.Info("Updated partitionStats for partition", zap.Int64("partitionID", partID))
        }
    }
}

// NewShardDelegator creates a new ShardDelegator instance with all fields initialized.
func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID UniqueID, channel string, version int64,
    workerManager cluster.Manager, manager *segments.Manager, tsafeManager tsafe.Manager, loader segments.Loader,
    factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook,
    factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook, chunkManager storage.ChunkManager,
) (ShardDelegator, error) {
    log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID),
        zap.Int64("replicaID", replicaID),

@@ -812,6 +888,8 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
        loader: loader,
        factory: factory,
        queryHook: queryHook,
        chunkManager: chunkManager,
        partitionStats: make(map[UniqueID]*storage.PartitionStatsSnapshot),
    }
    m := sync.Mutex{}
    sd.tsCond = sync.NewCond(&m)

@@ -819,5 +897,6 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
        go sd.watchTSafe()
    }
    log.Info("finish build new shardDelegator")
    sd.maybeReloadPartitionStats(ctx)
    return sd, nil
}
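
The stats files that the reload loop scans live under a per-partition prefix whose final path element is the stats version number, which is why findMaxVersion parses path.Base of each listed file as an int64. A minimal sketch of that prefix, assembled from the same calls used above (the helper name partitionStatsPrefix is hypothetical):

// partitionStatsPrefix returns the object-storage prefix that
// maybeReloadPartitionStats lists; individual stats files sit under it,
// each named by its version number.
func partitionStatsPrefix(cm storage.ChunkManager, collectionID, partitionID UniqueID, vchannel string) string {
    idPath := metautil.JoinIDPath(collectionID, partitionID)
    idPath = path.Join(idPath, vchannel)
    return path.Join(cm.RootPath(), common.PartitionStatsPath, idPath)
}
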
@@ -475,6 +475,12 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg
    // alter distribution
    sd.distribution.AddDistributions(entries...)

    partStatsToReload := make([]UniqueID, 0)
    lo.ForEach(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) {
        partStatsToReload = append(partStatsToReload, info.PartitionID)
    })
    sd.maybeReloadPartitionStats(ctx, partStatsToReload...)

    return nil
}

@@ -850,7 +856,14 @@ func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.Rele
    if hasLevel0 {
        sd.GenerateLevel0DeletionCache()
    }

    partitionsToReload := make([]UniqueID, 0)
    lo.ForEach(req.GetSegmentIDs(), func(segmentID int64, _ int) {
        segment := sd.segmentManager.Get(segmentID)
        if segment != nil {
            partitionsToReload = append(partitionsToReload, segment.Partition())
        }
    })
    sd.maybeReloadPartitionStats(ctx, partitionsToReload...)
    return nil
}
@@ -18,6 +18,8 @@ package delegator

import (
    "context"
    "path"
    "strconv"
    "testing"

    bloom "github.com/bits-and-blooms/bloom/v3"

@@ -41,6 +43,7 @@ import (
    "github.com/milvus-io/milvus/pkg/mq/msgstream"
    "github.com/milvus-io/milvus/pkg/util/commonpbutil"
    "github.com/milvus-io/milvus/pkg/util/merr"
    "github.com/milvus-io/milvus/pkg/util/metautil"
    "github.com/milvus-io/milvus/pkg/util/metric"
    "github.com/milvus-io/milvus/pkg/util/paramtable"
)

@@ -58,7 +61,9 @@ type DelegatorDataSuite struct {
    loader *segments.MockLoader
    mq     *msgstream.MockMsgStream

    delegator *shardDelegator
    delegator    *shardDelegator
    rootPath     string
    chunkManager storage.ChunkManager
}

func (s *DelegatorDataSuite) SetupSuite() {

@@ -126,16 +131,19 @@ func (s *DelegatorDataSuite) SetupTest() {
            },
        },
    }, &querypb.LoadMetaInfo{
        LoadType: querypb.LoadType_LoadCollection,
        LoadType:     querypb.LoadType_LoadCollection,
        PartitionIDs: []int64{1001, 1002},
    })

    s.mq = &msgstream.MockMsgStream{}

    s.rootPath = s.Suite.T().Name()
    chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
    s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background())
    delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{
        NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
            return s.mq, nil
        },
    }, 10000, nil)
    }, 10000, nil, s.chunkManager)
    s.Require().NoError(err)
    sd, ok := delegator.(*shardDelegator)
    s.Require().True(ok)

@@ -609,7 +617,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
        NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
            return s.mq, nil
        },
    }, 10000, nil)
    }, 10000, nil, nil)
    s.NoError(err)

    growing0 := segments.NewMockSegment(s.T())
@@ -968,6 +976,78 @@ func (s *DelegatorDataSuite) TestReleaseSegment() {
    s.NoError(err)
}

func (s *DelegatorDataSuite) TestLoadPartitionStats() {
    segStats := make(map[UniqueID]storage.SegmentStats)
    centroid := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}
    var segID int64 = 1
    rows := 1990
    {
        // p1 stats
        fieldStats := make([]storage.FieldStats, 0)
        fieldStat1 := storage.FieldStats{
            FieldID: 1,
            Type: schemapb.DataType_Int64,
            Max: storage.NewInt64FieldValue(200),
            Min: storage.NewInt64FieldValue(100),
        }
        fieldStat2 := storage.FieldStats{
            FieldID: 2,
            Type: schemapb.DataType_Int64,
            Max: storage.NewInt64FieldValue(400),
            Min: storage.NewInt64FieldValue(300),
        }
        fieldStat3 := storage.FieldStats{
            FieldID: 3,
            Type: schemapb.DataType_FloatVector,
            Centroids: []storage.VectorFieldValue{
                &storage.FloatVectorFieldValue{
                    Value: centroid,
                },
                &storage.FloatVectorFieldValue{
                    Value: centroid,
                },
            },
        }
        fieldStats = append(fieldStats, fieldStat1)
        fieldStats = append(fieldStats, fieldStat2)
        fieldStats = append(fieldStats, fieldStat3)
        segStats[segID] = *storage.NewSegmentStats(fieldStats, rows)
    }
    partitionStats1 := &storage.PartitionStatsSnapshot{
        SegmentStats: segStats,
    }
    statsData1, err := storage.SerializePartitionStatsSnapshot(partitionStats1)
    s.NoError(err)
    partitionID1 := int64(1001)
    idPath1 := metautil.JoinIDPath(s.collectionID, partitionID1)
    idPath1 = path.Join(idPath1, s.delegator.vchannelName)
    statsPath1 := path.Join(s.chunkManager.RootPath(), common.PartitionStatsPath, idPath1, strconv.Itoa(1))
    s.chunkManager.Write(context.Background(), statsPath1, statsData1)
    defer s.chunkManager.Remove(context.Background(), statsPath1)

    // reload and check partition stats
    s.delegator.maybeReloadPartitionStats(context.Background())
    s.Equal(1, len(s.delegator.partitionStats))
    s.NotNil(s.delegator.partitionStats[partitionID1])
    p1Stats := s.delegator.partitionStats[partitionID1]
    s.Equal(int64(1), p1Stats.GetVersion())
    s.Equal(rows, p1Stats.SegmentStats[segID].NumRows)
    s.Equal(3, len(p1Stats.SegmentStats[segID].FieldStats))

    // judge vector stats
    vecFieldStats := p1Stats.SegmentStats[segID].FieldStats[2]
    s.Equal(2, len(vecFieldStats.Centroids))
    s.Equal(8, len(vecFieldStats.Centroids[0].GetValue().([]float32)))

    // judge scalar stats
    fieldStats1 := p1Stats.SegmentStats[segID].FieldStats[0]
    s.Equal(int64(100), fieldStats1.Min.GetValue().(int64))
    s.Equal(int64(200), fieldStats1.Max.GetValue().(int64))
    fieldStats2 := p1Stats.SegmentStats[segID].FieldStats[1]
    s.Equal(int64(300), fieldStats2.Min.GetValue().(int64))
    s.Equal(int64(400), fieldStats2.Max.GetValue().(int64))
}

func (s *DelegatorDataSuite) TestSyncTargetVersion() {
    for i := int64(0); i < 5; i++ {
        ms := &segments.MockSegment{}
@@ -40,6 +40,7 @@ import (
    "github.com/milvus-io/milvus/internal/querynodev2/cluster"
    "github.com/milvus-io/milvus/internal/querynodev2/segments"
    "github.com/milvus-io/milvus/internal/querynodev2/tsafe"
    "github.com/milvus-io/milvus/internal/storage"
    "github.com/milvus-io/milvus/internal/util/streamrpc"
    "github.com/milvus-io/milvus/pkg/common"
    "github.com/milvus-io/milvus/pkg/mq/msgstream"

@@ -64,7 +65,9 @@ type DelegatorSuite struct {
    loader *segments.MockLoader
    mq     *msgstream.MockMsgStream

    delegator ShardDelegator
    delegator    ShardDelegator
    chunkManager storage.ChunkManager
    rootPath     string
}

func (s *DelegatorSuite) SetupSuite() {

@@ -154,6 +157,11 @@ func (s *DelegatorSuite) SetupTest() {
    })

    s.mq = &msgstream.MockMsgStream{}
    s.rootPath = "delegator_test"

    // init chunkManager
    chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
    s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background())

    var err error
    // s.delegator, err = NewShardDelegator(s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader)

@@ -161,7 +169,7 @@ func (s *DelegatorSuite) SetupTest() {
        NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
            return s.mq, nil
        },
    }, 10000, nil)
    }, 10000, nil, s.chunkManager)
    s.Require().NoError(err)
}
@@ -0,0 +1,228 @@
package delegator

import (
    "context"
    "sort"
    "strconv"

    "github.com/golang/protobuf/proto"
    "go.uber.org/zap"

    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    "github.com/milvus-io/milvus/internal/proto/internalpb"
    "github.com/milvus-io/milvus/internal/proto/planpb"
    "github.com/milvus-io/milvus/internal/storage"
    "github.com/milvus-io/milvus/internal/util/clustering"
    "github.com/milvus-io/milvus/internal/util/exprutil"
    "github.com/milvus-io/milvus/internal/util/typeutil"
    "github.com/milvus-io/milvus/pkg/common"
    "github.com/milvus-io/milvus/pkg/log"
    "github.com/milvus-io/milvus/pkg/util/distance"
    "github.com/milvus-io/milvus/pkg/util/funcutil"
    "github.com/milvus-io/milvus/pkg/util/merr"
)

const defaultFilterRatio float64 = 0.5

type PruneInfo struct {
    filterRatio float64
}

func PruneSegments(ctx context.Context,
    partitionStats map[UniqueID]*storage.PartitionStatsSnapshot,
    searchReq *internalpb.SearchRequest,
    queryReq *internalpb.RetrieveRequest,
    schema *schemapb.CollectionSchema,
    sealedSegments []SnapshotItem,
    info PruneInfo,
) {
    log := log.Ctx(ctx)
    // 1. calculate filtered segments
    filteredSegments := make(map[UniqueID]struct{}, 0)
    clusteringKeyField := typeutil.GetClusteringKeyField(schema.Fields)
    if clusteringKeyField == nil {
        return
    }
    if searchReq != nil {
        // parse searched vectors
        var vectorsHolder commonpb.PlaceholderGroup
        err := proto.Unmarshal(searchReq.GetPlaceholderGroup(), &vectorsHolder)
        if err != nil || len(vectorsHolder.GetPlaceholders()) == 0 {
            return
        }
        vectorsBytes := vectorsHolder.GetPlaceholders()[0].GetValues()
        // parse dim
        dimStr, err := funcutil.GetAttrByKeyFromRepeatedKV(common.DimKey, clusteringKeyField.GetTypeParams())
        if err != nil {
            return
        }
        dimValue, err := strconv.ParseInt(dimStr, 10, 64)
        if err != nil {
            return
        }
        for _, partID := range searchReq.GetPartitionIDs() {
            partStats := partitionStats[partID]
            FilterSegmentsByVector(partStats, searchReq, vectorsBytes, dimValue, clusteringKeyField, filteredSegments, info.filterRatio)
        }
    } else if queryReq != nil {
        // 0. parse expr from plan
        plan := planpb.PlanNode{}
        err := proto.Unmarshal(queryReq.GetSerializedExprPlan(), &plan)
        if err != nil {
            log.Error("failed to unmarshall serialized expr from bytes, failed the operation")
            return
        }
        expr, err := exprutil.ParseExprFromPlan(&plan)
        if err != nil {
            log.Error("failed to parse expr from plan, failed the operation")
            return
        }
        targetRanges, matchALL := exprutil.ParseRanges(expr, exprutil.ClusteringKey)
        if matchALL || targetRanges == nil {
            return
        }
        for _, partID := range queryReq.GetPartitionIDs() {
            partStats := partitionStats[partID]
            FilterSegmentsOnScalarField(partStats, targetRanges, clusteringKeyField, filteredSegments)
        }
    }

    // 2. remove filtered segments from sealed segment list
    if len(filteredSegments) > 0 {
        totalSegNum := 0
        for idx, item := range sealedSegments {
            newSegments := make([]SegmentEntry, 0)
            totalSegNum += len(item.Segments)
            for _, segment := range item.Segments {
                if _, ok := filteredSegments[segment.SegmentID]; !ok {
                    newSegments = append(newSegments, segment)
                }
            }
            item.Segments = newSegments
            sealedSegments[idx] = item
        }
        log.Debug("Pruned segment for search/query",
            zap.Int("pruned_segment_num", len(filteredSegments)),
            zap.Int("total_segment_num", totalSegNum),
        )
    }
}

type segmentDisStruct struct {
    segmentID UniqueID
    distance  float32
    rows      int // for keep track of sufficiency of topK
}

func FilterSegmentsByVector(partitionStats *storage.PartitionStatsSnapshot,
    searchReq *internalpb.SearchRequest,
    vectorBytes [][]byte,
    dim int64,
    keyField *schemapb.FieldSchema,
    filteredSegments map[UniqueID]struct{},
    filterRatio float64,
) {
    // 1. calculate vectors' distances
    neededSegments := make(map[UniqueID]struct{})
    for _, vecBytes := range vectorBytes {
        segmentsToSearch := make([]segmentDisStruct, 0)
        for segId, segStats := range partitionStats.SegmentStats {
            // here, we do not skip needed segments required by former query vector
            // meaning that repeated calculation will be carried and the larger the nq is
            // the more segments have to be included and prune effect will decline
            // 1. calculate distances from centroids
            for _, fieldStat := range segStats.FieldStats {
                if fieldStat.FieldID == keyField.GetFieldID() {
                    if fieldStat.Centroids == nil || len(fieldStat.Centroids) == 0 {
                        neededSegments[segId] = struct{}{}
                        break
                    }
                    var dis []float32
                    var disErr error
                    switch keyField.GetDataType() {
                    case schemapb.DataType_FloatVector:
                        dis, disErr = clustering.CalcVectorDistance(dim, keyField.GetDataType(),
                            vecBytes, fieldStat.Centroids[0].GetValue().([]float32), searchReq.GetMetricType())
                    default:
                        neededSegments[segId] = struct{}{}
                        disErr = merr.WrapErrParameterInvalid(schemapb.DataType_FloatVector, keyField.GetDataType(),
                            "Currently, pruning by cluster only support float_vector type")
                    }
                    // currently, we only support float vector and only one center one segment
                    if disErr != nil {
                        neededSegments[segId] = struct{}{}
                        break
                    }
                    segmentsToSearch = append(segmentsToSearch, segmentDisStruct{
                        segmentID: segId,
                        distance:  dis[0],
                        rows:      segStats.NumRows,
                    })
                    break
                }
            }
        }
        // 2. sort the distances
        switch searchReq.GetMetricType() {
        case distance.L2:
            sort.SliceStable(segmentsToSearch, func(i, j int) bool {
                return segmentsToSearch[i].distance < segmentsToSearch[j].distance
            })
        case distance.IP, distance.COSINE:
            sort.SliceStable(segmentsToSearch, func(i, j int) bool {
                return segmentsToSearch[i].distance > segmentsToSearch[j].distance
            })
        }

        // 3. filtered non-target segments
        segmentCount := len(segmentsToSearch)
        targetSegNum := int(float64(segmentCount) * filterRatio)
        optimizedRowCount := 0
        // set the last n - targetSegNum as being filtered
        for i := 0; i < segmentCount; i++ {
            optimizedRowCount += segmentsToSearch[i].rows
            neededSegments[segmentsToSearch[i].segmentID] = struct{}{}
            if int64(optimizedRowCount) >= searchReq.GetTopk() && i >= targetSegNum {
                break
            }
        }
    }

    // 3. set not needed segments as removed
    for segId := range partitionStats.SegmentStats {
        if _, ok := neededSegments[segId]; !ok {
            filteredSegments[segId] = struct{}{}
        }
    }
}

func FilterSegmentsOnScalarField(partitionStats *storage.PartitionStatsSnapshot,
    targetRanges []*exprutil.PlanRange,
    keyField *schemapb.FieldSchema,
    filteredSegments map[UniqueID]struct{},
) {
    // 1. try to filter segments
    overlap := func(min storage.ScalarFieldValue, max storage.ScalarFieldValue) bool {
        for _, tRange := range targetRanges {
            switch keyField.DataType {
            case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32, schemapb.DataType_Int64:
                targetRange := tRange.ToIntRange()
                statRange := exprutil.NewIntRange(min.GetValue().(int64), max.GetValue().(int64), true, true)
                return exprutil.IntRangeOverlap(targetRange, statRange)
            case schemapb.DataType_String, schemapb.DataType_VarChar:
                targetRange := tRange.ToStrRange()
                statRange := exprutil.NewStrRange(min.GetValue().(string), max.GetValue().(string), true, true)
                return exprutil.StrRangeOverlap(targetRange, statRange)
            }
        }
        return false
    }
    for segID, segStats := range partitionStats.SegmentStats {
        for _, fieldStat := range segStats.FieldStats {
            if keyField.FieldID == fieldStat.FieldID && !overlap(fieldStat.Min, fieldStat.Max) {
                filteredSegments[segID] = struct{}{}
            }
        }
    }
}
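
For orientation, the shard delegator hunks earlier in this commit call the pruner with exactly one of the two request arguments set, and only when the queryNode.enableSegmentPrune setting is on. A minimal sketch under those assumptions (the wrapper name pruneForSearch is hypothetical):

// pruneForSearch drops sealed segments whose clustering-key stats cannot
// match the request before the search is fanned out to workers. With the
// default filterRatio of 0.5 and four candidate segments, at least the two
// segments with the closest centroids are kept, plus more until the kept
// rows cover the request topk.
func pruneForSearch(ctx context.Context, sd *shardDelegator, req *internalpb.SearchRequest, sealed []SnapshotItem) {
    if paramtable.Get().QueryNodeCfg.EnableSegmentPrune.GetAsBool() {
        PruneSegments(ctx, sd.partitionStats, req, nil, sd.collection.Schema(), sealed,
            PruneInfo{filterRatio: defaultFilterRatio})
    }
}
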
@@ -0,0 +1,422 @@
package delegator

import (
    "context"
    "testing"

    "github.com/golang/protobuf/proto"
    "github.com/stretchr/testify/suite"

    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    "github.com/milvus-io/milvus/internal/parser/planparserv2"
    "github.com/milvus-io/milvus/internal/proto/internalpb"
    "github.com/milvus-io/milvus/internal/storage"
    "github.com/milvus-io/milvus/internal/util/clustering"
    "github.com/milvus-io/milvus/internal/util/testutil"
    "github.com/milvus-io/milvus/pkg/util/typeutil"
)

type SegmentPrunerSuite struct {
    suite.Suite
    partitionStats      map[UniqueID]*storage.PartitionStatsSnapshot
    schema              *schemapb.CollectionSchema
    collectionName      string
    primaryFieldName    string
    clusterKeyFieldName string
    autoID              bool
    targetPartition     int64
    dim                 int
    sealedSegments      []SnapshotItem
}

func (sps *SegmentPrunerSuite) SetupForClustering(clusterKeyFieldName string,
    clusterKeyFieldType schemapb.DataType,
) {
    sps.collectionName = "test_segment_prune"
    sps.primaryFieldName = "pk"
    sps.clusterKeyFieldName = clusterKeyFieldName
    sps.autoID = true
    sps.dim = 8

    fieldName2DataType := make(map[string]schemapb.DataType)
    fieldName2DataType[sps.primaryFieldName] = schemapb.DataType_Int64
    fieldName2DataType[sps.clusterKeyFieldName] = clusterKeyFieldType
    fieldName2DataType["info"] = schemapb.DataType_VarChar
    fieldName2DataType["age"] = schemapb.DataType_Int32
    fieldName2DataType["vec"] = schemapb.DataType_FloatVector

    sps.schema = testutil.ConstructCollectionSchemaWithKeys(sps.collectionName,
        fieldName2DataType,
        sps.primaryFieldName,
        "",
        sps.clusterKeyFieldName,
        false,
        sps.dim)

    var clusteringKeyFieldID int64 = 0
    for _, field := range sps.schema.GetFields() {
        if field.IsClusteringKey {
            clusteringKeyFieldID = field.FieldID
            break
        }
    }
    centroids1 := []storage.VectorFieldValue{
        &storage.FloatVectorFieldValue{
            Value: []float32{0.6951474, 0.45225978, 0.51508516, 0.24968886, 0.6085484, 0.964968, 0.32239532, 0.7771577},
        },
    }
    centroids2 := []storage.VectorFieldValue{
        &storage.FloatVectorFieldValue{
            Value: []float32{0.12345678, 0.23456789, 0.34567890, 0.45678901, 0.56789012, 0.67890123, 0.78901234, 0.89012345},
        },
    }
    centroids3 := []storage.VectorFieldValue{
        &storage.FloatVectorFieldValue{
            Value: []float32{0.98765432, 0.87654321, 0.76543210, 0.65432109, 0.54321098, 0.43210987, 0.32109876, 0.21098765},
        },
    }
    centroids4 := []storage.VectorFieldValue{
        &storage.FloatVectorFieldValue{
            Value: []float32{0.11111111, 0.22222222, 0.33333333, 0.44444444, 0.55555555, 0.66666666, 0.77777777, 0.88888888},
        },
    }

    // init partition stats
    // here, for convenience, we set up both min/max and Centroids
    // into the same struct, in the real user cases, a field stat
    // can either contain min&&max or centroids
    segStats := make(map[UniqueID]storage.SegmentStats)
    switch clusterKeyFieldType {
    case schemapb.DataType_Int64, schemapb.DataType_Int32, schemapb.DataType_Int16, schemapb.DataType_Int8:
        {
            fieldStats := make([]storage.FieldStats, 0)
            fieldStat1 := storage.FieldStats{
                FieldID: clusteringKeyFieldID,
                Type: schemapb.DataType_Int64,
                Min: storage.NewInt64FieldValue(100),
                Max: storage.NewInt64FieldValue(200),
                Centroids: centroids1,
            }
            fieldStats = append(fieldStats, fieldStat1)
            segStats[1] = *storage.NewSegmentStats(fieldStats, 80)
        }
        {
            fieldStats := make([]storage.FieldStats, 0)
            fieldStat1 := storage.FieldStats{
                FieldID: clusteringKeyFieldID,
                Type: schemapb.DataType_Int64,
                Min: storage.NewInt64FieldValue(100),
                Max: storage.NewInt64FieldValue(400),
                Centroids: centroids2,
            }
            fieldStats = append(fieldStats, fieldStat1)
            segStats[2] = *storage.NewSegmentStats(fieldStats, 80)
        }
        {
            fieldStats := make([]storage.FieldStats, 0)
            fieldStat1 := storage.FieldStats{
                FieldID: clusteringKeyFieldID,
                Type: schemapb.DataType_Int64,
                Min: storage.NewInt64FieldValue(600),
                Max: storage.NewInt64FieldValue(900),
                Centroids: centroids3,
            }
            fieldStats = append(fieldStats, fieldStat1)
            segStats[3] = *storage.NewSegmentStats(fieldStats, 80)
        }
        {
            fieldStats := make([]storage.FieldStats, 0)
            fieldStat1 := storage.FieldStats{
                FieldID: clusteringKeyFieldID,
                Type: schemapb.DataType_Int64,
                Min: storage.NewInt64FieldValue(500),
                Max: storage.NewInt64FieldValue(1000),
                Centroids: centroids4,
            }
            fieldStats = append(fieldStats, fieldStat1)
            segStats[4] = *storage.NewSegmentStats(fieldStats, 80)
        }
    default:
        {
            fieldStats := make([]storage.FieldStats, 0)
            fieldStat1 := storage.FieldStats{
                FieldID: clusteringKeyFieldID,
                Type: schemapb.DataType_VarChar,
                Min: storage.NewStringFieldValue("ab"),
                Max: storage.NewStringFieldValue("bbc"),
                Centroids: centroids1,
            }
            fieldStats = append(fieldStats, fieldStat1)
            segStats[1] = *storage.NewSegmentStats(fieldStats, 80)
        }
        {
            fieldStats := make([]storage.FieldStats, 0)
            fieldStat1 := storage.FieldStats{
                FieldID: clusteringKeyFieldID,
                Type: schemapb.DataType_VarChar,
                Min: storage.NewStringFieldValue("hhh"),
                Max: storage.NewStringFieldValue("jjx"),
                Centroids: centroids2,
            }
            fieldStats = append(fieldStats, fieldStat1)
            segStats[2] = *storage.NewSegmentStats(fieldStats, 80)
        }
        {
            fieldStats := make([]storage.FieldStats, 0)
            fieldStat1 := storage.FieldStats{
                FieldID: clusteringKeyFieldID,
                Type: schemapb.DataType_VarChar,
                Min: storage.NewStringFieldValue("kkk"),
                Max: storage.NewStringFieldValue("lmn"),
                Centroids: centroids3,
            }
            fieldStats = append(fieldStats, fieldStat1)
            segStats[3] = *storage.NewSegmentStats(fieldStats, 80)
        }
        {
            fieldStats := make([]storage.FieldStats, 0)
            fieldStat1 := storage.FieldStats{
                FieldID: clusteringKeyFieldID,
                Type: schemapb.DataType_VarChar,
                Min: storage.NewStringFieldValue("oo2"),
                Max: storage.NewStringFieldValue("pptt"),
                Centroids: centroids4,
            }
            fieldStats = append(fieldStats, fieldStat1)
            segStats[4] = *storage.NewSegmentStats(fieldStats, 80)
        }
    }
    sps.partitionStats = make(map[UniqueID]*storage.PartitionStatsSnapshot)
    sps.targetPartition = 11111
    sps.partitionStats[sps.targetPartition] = &storage.PartitionStatsSnapshot{
        SegmentStats: segStats,
    }

    sealedSegments := make([]SnapshotItem, 0)
    item1 := SnapshotItem{
        NodeID: 1,
        Segments: []SegmentEntry{
            {
                NodeID: 1,
                SegmentID: 1,
            },
            {
                NodeID: 1,
                SegmentID: 2,
            },
        },
    }
    item2 := SnapshotItem{
        NodeID: 2,
        Segments: []SegmentEntry{
            {
                NodeID: 2,
                SegmentID: 3,
            },
            {
                NodeID: 2,
                SegmentID: 4,
            },
        },
    }
    sealedSegments = append(sealedSegments, item1)
    sealedSegments = append(sealedSegments, item2)
    sps.sealedSegments = sealedSegments
}

func (sps *SegmentPrunerSuite) TestPruneSegmentsByScalarIntField() {
    sps.SetupForClustering("age", schemapb.DataType_Int32)
    targetPartitions := make([]UniqueID, 0)
    targetPartitions = append(targetPartitions, sps.targetPartition)
    {
        // test for exact values
        testSegments := make([]SnapshotItem, len(sps.sealedSegments))
        copy(testSegments, sps.sealedSegments)
        exprStr := "age==156"
        schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
        planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
        sps.NoError(err)
        serializedPlan, _ := proto.Marshal(planNode)
        queryReq := &internalpb.RetrieveRequest{
            SerializedExprPlan: serializedPlan,
            PartitionIDs: targetPartitions,
        }
        PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio})
        sps.Equal(2, len(testSegments[0].Segments))
        sps.Equal(0, len(testSegments[1].Segments))
    }
    {
        // test for range one expr part
        testSegments := make([]SnapshotItem, len(sps.sealedSegments))
        copy(testSegments, sps.sealedSegments)
        exprStr := "age>=700"
        schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
        planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
        sps.NoError(err)
        serializedPlan, _ := proto.Marshal(planNode)
        queryReq := &internalpb.RetrieveRequest{
            SerializedExprPlan: serializedPlan,
            PartitionIDs: targetPartitions,
        }
        PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio})
        sps.Equal(0, len(testSegments[0].Segments))
        sps.Equal(2, len(testSegments[1].Segments))
    }
    {
        // test for unlogical binary range
        testSegments := make([]SnapshotItem, len(sps.sealedSegments))
        copy(testSegments, sps.sealedSegments)
        exprStr := "age>=700 and age<=500"
        schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
        planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
        sps.NoError(err)
        serializedPlan, _ := proto.Marshal(planNode)
        queryReq := &internalpb.RetrieveRequest{
            SerializedExprPlan: serializedPlan,
            PartitionIDs: targetPartitions,
        }
        PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio})
        sps.Equal(2, len(testSegments[0].Segments))
        sps.Equal(2, len(testSegments[1].Segments))
    }
    {
        // test for unlogical binary range
        testSegments := make([]SnapshotItem, len(sps.sealedSegments))
        copy(testSegments, sps.sealedSegments)
        exprStr := "age>=500 and age<=550"
        schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
        planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
        sps.NoError(err)
        serializedPlan, _ := proto.Marshal(planNode)
        queryReq := &internalpb.RetrieveRequest{
            SerializedExprPlan: serializedPlan,
            PartitionIDs: targetPartitions,
        }
        PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio})
        sps.Equal(0, len(testSegments[0].Segments))
        sps.Equal(1, len(testSegments[1].Segments))
    }
}

func (sps *SegmentPrunerSuite) TestPruneSegmentsByScalarStrField() {
    sps.SetupForClustering("info", schemapb.DataType_VarChar)
    targetPartitions := make([]UniqueID, 0)
    targetPartitions = append(targetPartitions, sps.targetPartition)
    {
        // test for exact str values
        testSegments := make([]SnapshotItem, len(sps.sealedSegments))
        copy(testSegments, sps.sealedSegments)
        exprStr := `info=="rag"`
        schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
        planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
        sps.NoError(err)
        serializedPlan, _ := proto.Marshal(planNode)
        queryReq := &internalpb.RetrieveRequest{
            SerializedExprPlan: serializedPlan,
            PartitionIDs: targetPartitions,
        }
        PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio})
        sps.Equal(0, len(testSegments[0].Segments))
        sps.Equal(0, len(testSegments[1].Segments))
        // there should be no segments fulfilling the info=="rag"
    }
    {
        // test for exact str values
        testSegments := make([]SnapshotItem, len(sps.sealedSegments))
        copy(testSegments, sps.sealedSegments)
        exprStr := `info=="kpl"`
        schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
        planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
        sps.NoError(err)
        serializedPlan, _ := proto.Marshal(planNode)
        queryReq := &internalpb.RetrieveRequest{
            SerializedExprPlan: serializedPlan,
            PartitionIDs: targetPartitions,
        }
        PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio})
        sps.Equal(0, len(testSegments[0].Segments))
        sps.Equal(1, len(testSegments[1].Segments))
        // there should be no segments fulfilling the info=="rag"
    }
    {
        // test for unary str values
        testSegments := make([]SnapshotItem, len(sps.sealedSegments))
        copy(testSegments, sps.sealedSegments)
        exprStr := `info<="less"`
        schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema)
        planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr)
        sps.NoError(err)
        serializedPlan, _ := proto.Marshal(planNode)
        queryReq := &internalpb.RetrieveRequest{
            SerializedExprPlan: serializedPlan,
            PartitionIDs: targetPartitions,
        }
        PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio})
        sps.Equal(2, len(testSegments[0].Segments))
        sps.Equal(1, len(testSegments[1].Segments))
        // there should be no segments fulfilling the info=="rag"
    }
}

func vector2Placeholder(vectors [][]float32) *commonpb.PlaceholderValue {
    ph := &commonpb.PlaceholderValue{
        Tag: "$0",
        Values: make([][]byte, 0, len(vectors)),
    }
    if len(vectors) == 0 {
        return ph
    }

    ph.Type = commonpb.PlaceholderType_FloatVector
    for _, vector := range vectors {
        ph.Values = append(ph.Values, clustering.SerializeFloatVector(vector))
    }
    return ph
}

func (sps *SegmentPrunerSuite) TestPruneSegmentsByVectorField() {
    sps.SetupForClustering("vec", schemapb.DataType_FloatVector)

    vector1 := []float32{0.8877872002188053, 0.6131822285635065, 0.8476814632326242, 0.6645877829359371, 0.9962627712600025, 0.8976183052440327, 0.41941169325798844, 0.7554387854258499}
    vector2 := []float32{0.8644394874390322, 0.023327886647378615, 0.08330118483461302, 0.7068040179963112, 0.6983994910799851, 0.5562075958994153, 0.3288536247938002, 0.07077341010237759}
    vectors := [][]float32{vector1, vector2}

    phg := &commonpb.PlaceholderGroup{
        Placeholders: []*commonpb.PlaceholderValue{
            vector2Placeholder(vectors),
        },
    }
    bs, _ := proto.Marshal(phg)
    // test for L2 metrics
    req := &internalpb.SearchRequest{
        MetricType: "L2",
        PlaceholderGroup: bs,
        PartitionIDs: []UniqueID{sps.targetPartition},
        Topk: 100,
    }

    PruneSegments(context.TODO(), sps.partitionStats, req, nil, sps.schema, sps.sealedSegments, PruneInfo{0.25})
    sps.Equal(1, len(sps.sealedSegments[0].Segments))
    sps.Equal(int64(1), sps.sealedSegments[0].Segments[0].SegmentID)
    sps.Equal(1, len(sps.sealedSegments[1].Segments))
    sps.Equal(int64(3), sps.sealedSegments[1].Segments[0].SegmentID)

    // test for IP metrics
    req = &internalpb.SearchRequest{
        MetricType: "IP",
        PlaceholderGroup: bs,
        PartitionIDs: []UniqueID{sps.targetPartition},
        Topk: 100,
    }

    PruneSegments(context.TODO(), sps.partitionStats, req, nil, sps.schema, sps.sealedSegments, PruneInfo{0.25})
    sps.Equal(1, len(sps.sealedSegments[0].Segments))
    sps.Equal(int64(1), sps.sealedSegments[0].Segments[0].SegmentID)
    sps.Equal(1, len(sps.sealedSegments[1].Segments))
    sps.Equal(int64(3), sps.sealedSegments[1].Segments[0].SegmentID)
}

func TestSegmentPrunerSuite(t *testing.T) {
    suite.Run(t, new(SegmentPrunerSuite))
}
@ -60,7 +60,7 @@ func (suite *ReduceSuite) SetupTest() {
|
|||
msgLength := 100
|
||||
|
||||
suite.rootPath = suite.T().Name()
|
||||
chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
|
||||
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
|
||||
suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)
|
||||
initcore.InitRemoteChunkManager(paramtable.Get())
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ func (suite *RetrieveSuite) SetupTest() {
|
|||
msgLength := 100
|
||||
|
||||
suite.rootPath = suite.T().Name()
|
||||
chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
|
||||
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
|
||||
suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)
|
||||
initcore.InitRemoteChunkManager(paramtable.Get())
|
||||
|
||||
|
|
|
@ -78,7 +78,7 @@ func (suite *SegmentLoaderSuite) SetupTest() {
|
|||
// TODO:: cpp chunk manager not support local chunk manager
|
||||
// suite.chunkManager = storage.NewLocalChunkManager(storage.RootPath(
|
||||
// fmt.Sprintf("/tmp/milvus-ut/%d", rand.Int63())))
|
||||
chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
|
||||
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
|
||||
suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)
|
||||
suite.loader = NewLoader(suite.manager, suite.chunkManager)
|
||||
initcore.InitRemoteChunkManager(paramtable.Get())
|
||||
|
@ -678,7 +678,7 @@ func (suite *SegmentLoaderDetailSuite) SetupTest() {
|
|||
}
|
||||
|
||||
ctx := context.Background()
|
||||
chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
|
||||
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
|
||||
suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)
|
||||
suite.loader = NewLoader(suite.manager, suite.chunkManager)
|
||||
initcore.InitRemoteChunkManager(paramtable.Get())
|
||||
|
@@ -847,7 +847,7 @@ func (suite *SegmentLoaderV2Suite) SetupTest() {
	// TODO:: cpp chunk manager not support local chunk manager
	// suite.chunkManager = storage.NewLocalChunkManager(storage.RootPath(
	// 	fmt.Sprintf("/tmp/milvus-ut/%d", rand.Int63())))
	chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
	chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
	suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)
	suite.loader = NewLoaderV2(suite.manager, suite.chunkManager)
	initcore.InitRemoteChunkManager(paramtable.Get())
@@ -39,7 +39,7 @@ func (suite *SegmentSuite) SetupTest() {
	msgLength := 100

	suite.rootPath = suite.T().Name()
	chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
	chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
	suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)
	initcore.InitRemoteChunkManager(paramtable.Get())
@@ -254,6 +254,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
		node.factory,
		channel.GetSeekPosition().GetTimestamp(),
		node.queryHook,
		node.chunkManager,
	)
	if err != nil {
		log.Warn("failed to create shard delegator", zap.Error(err))
@@ -116,7 +116,7 @@ func (suite *ServiceSuite) SetupTest() {
	suite.msgStream = msgstream.NewMockMsgStream(suite.T())
	// TODO:: cpp chunk manager not support local chunk manager
	// suite.chunkManagerFactory = storage.NewChunkManagerFactory("local", storage.RootPath("/tmp/milvus-test"))
	suite.chunkManagerFactory = segments.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
	suite.chunkManagerFactory = storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
	suite.factory.EXPECT().Init(mock.Anything).Return()
	suite.factory.EXPECT().NewPersistentStorageChunkManager(mock.Anything).Return(suite.chunkManagerFactory.NewPersistentStorageChunkManager(ctx))
@@ -20,6 +20,14 @@ import "encoding/json"

type SegmentStats struct {
	FieldStats []FieldStats `json:"fieldStats"`
	NumRows    int
}

func NewSegmentStats(fieldStats []FieldStats, rows int) *SegmentStats {
	return &SegmentStats{
		FieldStats: fieldStats,
		NumRows:    rows,
	}
}

type PartitionStatsSnapshot struct {
@@ -38,6 +38,7 @@ import (
	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@@ -1247,3 +1248,17 @@ func Min(a, b int64) int64 {
	}
	return b
}

func NewTestChunkManagerFactory(params *paramtable.ComponentParam, rootPath string) *ChunkManagerFactory {
	return NewChunkManagerFactory("minio",
		RootPath(rootPath),
		Address(params.MinioCfg.Address.GetValue()),
		AccessKeyID(params.MinioCfg.AccessKeyID.GetValue()),
		SecretAccessKeyID(params.MinioCfg.SecretAccessKey.GetValue()),
		UseSSL(params.MinioCfg.UseSSL.GetAsBool()),
		BucketName(params.MinioCfg.BucketName.GetValue()),
		UseIAM(params.MinioCfg.UseIAM.GetAsBool()),
		CloudProvider(params.MinioCfg.CloudProvider.GetValue()),
		IAMEndpoint(params.MinioCfg.IAMEndpoint.GetValue()),
		CreateBucket(true))
}
@@ -0,0 +1,50 @@
package clustering

import (
	"encoding/binary"
	"math"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/pkg/util/distance"
	"github.com/milvus-io/milvus/pkg/util/merr"
)

func CalcVectorDistance(dim int64, dataType schemapb.DataType, left []byte, right []float32, metric string) ([]float32, error) {
	switch dataType {
	case schemapb.DataType_FloatVector:
		distance, err := distance.CalcFloatDistance(dim, DeserializeFloatVector(left), right, metric)
		if err != nil {
			return nil, err
		}
		return distance, nil
	// todo support other vector type
	case schemapb.DataType_BinaryVector:
	case schemapb.DataType_Float16Vector:
	case schemapb.DataType_BFloat16Vector:
	default:
		return nil, merr.ErrParameterInvalid
	}
	return nil, nil
}

func DeserializeFloatVector(data []byte) []float32 {
	vectorLen := len(data) / 4 // Each float32 occupies 4 bytes
	fv := make([]float32, vectorLen)

	for i := 0; i < vectorLen; i++ {
		bits := binary.LittleEndian.Uint32(data[i*4 : (i+1)*4])
		fv[i] = math.Float32frombits(bits)
	}

	return fv
}

func SerializeFloatVector(fv []float32) []byte {
	data := make([]byte, 0, 4*len(fv)) // float32 occupies 4 bytes
	buf := make([]byte, 4)
	for _, f := range fv {
		binary.LittleEndian.PutUint32(buf, math.Float32bits(f))
		data = append(data, buf...)
	}
	return data
}
@@ -0,0 +1,511 @@
package exprutil

import (
	"math"
	"strings"

	"github.com/cockroachdb/errors"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/proto/planpb"
)

type KeyType int64

const (
	PartitionKey  KeyType = iota
	ClusteringKey KeyType = PartitionKey + 1
)

func ParseExprFromPlan(plan *planpb.PlanNode) (*planpb.Expr, error) {
	node := plan.GetNode()

	if node == nil {
		return nil, errors.New("can't get expr from empty plan node")
	}

	var expr *planpb.Expr
	switch node := node.(type) {
	case *planpb.PlanNode_VectorAnns:
		expr = node.VectorAnns.GetPredicates()
	case *planpb.PlanNode_Query:
		expr = node.Query.GetPredicates()
	default:
		return nil, errors.New("unsupported plan node type")
	}

	return expr, nil
}

func ParsePartitionKeysFromBinaryExpr(expr *planpb.BinaryExpr, keyType KeyType) ([]*planpb.GenericValue, bool) {
	leftRes, leftInRange := ParseKeysFromExpr(expr.Left, keyType)
	rightRes, rightInRange := ParseKeysFromExpr(expr.Right, keyType)

	if expr.Op == planpb.BinaryExpr_LogicalAnd {
		// case: partition_key_field in [7, 8] && partition_key > 8
		if len(leftRes)+len(rightRes) > 0 {
			leftRes = append(leftRes, rightRes...)
			return leftRes, false
		}

		// case: other_field > 10 && partition_key_field > 8
		return nil, leftInRange || rightInRange
	}

	if expr.Op == planpb.BinaryExpr_LogicalOr {
		// case: partition_key_field in [7, 8] or partition_key > 8
		if leftInRange || rightInRange {
			return nil, true
		}

		// case: partition_key_field in [7, 8] or other_field > 10
		leftRes = append(leftRes, rightRes...)
		return leftRes, false
	}

	return nil, false
}

func ParsePartitionKeysFromUnaryExpr(expr *planpb.UnaryExpr, keyType KeyType) ([]*planpb.GenericValue, bool) {
	res, partitionInRange := ParseKeysFromExpr(expr.GetChild(), keyType)
	if expr.Op == planpb.UnaryExpr_Not {
		// case: partition_key_field not in [7, 8]
		if len(res) != 0 {
			return nil, true
		}

		// case: other_field not in [10]
		return nil, partitionInRange
	}

	// UnaryOp only includes "Not" for now
	return res, partitionInRange
}

func ParsePartitionKeysFromTermExpr(expr *planpb.TermExpr, keyType KeyType) ([]*planpb.GenericValue, bool) {
	if keyType == PartitionKey && expr.GetColumnInfo().GetIsPartitionKey() {
		return expr.GetValues(), false
	} else if keyType == ClusteringKey && expr.GetColumnInfo().GetIsClusteringKey() {
		return expr.GetValues(), false
	}
	return nil, false
}

func ParsePartitionKeysFromUnaryRangeExpr(expr *planpb.UnaryRangeExpr, keyType KeyType) ([]*planpb.GenericValue, bool) {
	if expr.GetOp() == planpb.OpType_Equal {
		if expr.GetColumnInfo().GetIsPartitionKey() && keyType == PartitionKey ||
			expr.GetColumnInfo().GetIsClusteringKey() && keyType == ClusteringKey {
			return []*planpb.GenericValue{expr.Value}, false
		}
	}
	return nil, true
}

func ParseKeysFromExpr(expr *planpb.Expr, keyType KeyType) ([]*planpb.GenericValue, bool) {
	var res []*planpb.GenericValue
	keyInRange := false
	switch expr := expr.GetExpr().(type) {
	case *planpb.Expr_BinaryExpr:
		res, keyInRange = ParsePartitionKeysFromBinaryExpr(expr.BinaryExpr, keyType)
	case *planpb.Expr_UnaryExpr:
		res, keyInRange = ParsePartitionKeysFromUnaryExpr(expr.UnaryExpr, keyType)
	case *planpb.Expr_TermExpr:
		res, keyInRange = ParsePartitionKeysFromTermExpr(expr.TermExpr, keyType)
	case *planpb.Expr_UnaryRangeExpr:
		res, keyInRange = ParsePartitionKeysFromUnaryRangeExpr(expr.UnaryRangeExpr, keyType)
	}

	return res, keyInRange
}

func ParseKeys(expr *planpb.Expr, kType KeyType) []*planpb.GenericValue {
	res, keyInRange := ParseKeysFromExpr(expr, kType)
	if keyInRange {
		res = nil
	}

	return res
}

type PlanRange struct {
	lower        *planpb.GenericValue
	upper        *planpb.GenericValue
	includeLower bool
	includeUpper bool
}

func (planRange *PlanRange) ToIntRange() *IntRange {
	iRange := &IntRange{}
	if planRange.lower == nil {
		iRange.lower = math.MinInt64
		iRange.includeLower = false
	} else {
		iRange.lower = planRange.lower.GetInt64Val()
		iRange.includeLower = planRange.includeLower
	}

	if planRange.upper == nil {
		iRange.upper = math.MaxInt64
		iRange.includeUpper = false
	} else {
		iRange.upper = planRange.upper.GetInt64Val()
		iRange.includeUpper = planRange.includeUpper
	}
	return iRange
}

func (planRange *PlanRange) ToStrRange() *StrRange {
	sRange := &StrRange{}
	if planRange.lower == nil {
		sRange.lower = ""
		sRange.includeLower = false
	} else {
		sRange.lower = planRange.lower.GetStringVal()
		sRange.includeLower = planRange.includeLower
	}

	if planRange.upper == nil {
		sRange.upper = ""
		sRange.includeUpper = false
	} else {
		sRange.upper = planRange.upper.GetStringVal()
		sRange.includeUpper = planRange.includeUpper
	}
	return sRange
}

type IntRange struct {
	lower        int64
	upper        int64
	includeLower bool
	includeUpper bool
}

func NewIntRange(l int64, r int64, includeL bool, includeR bool) *IntRange {
	return &IntRange{
		lower:        l,
		upper:        r,
		includeLower: includeL,
		includeUpper: includeR,
	}
}

func IntRangeOverlap(range1 *IntRange, range2 *IntRange) bool {
	var leftBound int64
	if range1.lower < range2.lower {
		leftBound = range2.lower
	} else {
		leftBound = range1.lower
	}
	var rightBound int64
	if range1.upper < range2.upper {
		rightBound = range1.upper
	} else {
		rightBound = range2.upper
	}
	return leftBound <= rightBound
}

type StrRange struct {
	lower        string
	upper        string
	includeLower bool
	includeUpper bool
}

func NewStrRange(l string, r string, includeL bool, includeR bool) *StrRange {
	return &StrRange{
		lower:        l,
		upper:        r,
		includeLower: includeL,
		includeUpper: includeR,
	}
}

func StrRangeOverlap(range1 *StrRange, range2 *StrRange) bool {
	var leftBound string
	if range1.lower < range2.lower {
		leftBound = range2.lower
	} else {
		leftBound = range1.lower
	}
	var rightBound string
	if range1.upper < range2.upper || range2.upper == "" {
		rightBound = range1.upper
	} else {
		rightBound = range2.upper
	}
	return leftBound <= rightBound
}

/*
	principles for range parsing
	1. no handling unary expr like 'NOT'
	2. no handling 'or' expr, no matter on clusteringKey or not, just terminate all possible prune
	3. for any unlogical 'and' expr, we check and terminate right away
	4. no handling Term and Range at the same time
*/

func ParseRanges(expr *planpb.Expr, kType KeyType) ([]*PlanRange, bool) {
	var res []*PlanRange
	matchALL := true
	switch expr := expr.GetExpr().(type) {
	case *planpb.Expr_BinaryExpr:
		res, matchALL = ParseRangesFromBinaryExpr(expr.BinaryExpr, kType)
	case *planpb.Expr_UnaryRangeExpr:
		res, matchALL = ParseRangesFromUnaryRangeExpr(expr.UnaryRangeExpr, kType)
	case *planpb.Expr_TermExpr:
		res, matchALL = ParseRangesFromTermExpr(expr.TermExpr, kType)
	case *planpb.Expr_UnaryExpr:
		res, matchALL = nil, true
		// we don't handle NOT operation, just consider as unable_to_parse_range
	}
	return res, matchALL
}

func ParseRangesFromBinaryExpr(expr *planpb.BinaryExpr, kType KeyType) ([]*PlanRange, bool) {
	if expr.Op == planpb.BinaryExpr_LogicalOr {
		return nil, true
	}
	_, leftIsTerm := expr.GetLeft().GetExpr().(*planpb.Expr_TermExpr)
	_, rightIsTerm := expr.GetRight().GetExpr().(*planpb.Expr_TermExpr)
	if leftIsTerm || rightIsTerm {
		// either of left or right is a term query like x IN [1,2,3]
		// we will terminate the prune process
		return nil, true
	}
	leftRanges, leftALL := ParseRanges(expr.Left, kType)
	rightRanges, rightALL := ParseRanges(expr.Right, kType)
	if leftALL && rightALL {
		return nil, true
	} else if leftALL && !rightALL {
		return rightRanges, rightALL
	} else if rightALL && !leftALL {
		return leftRanges, leftALL
	}
	// only unary ranges or further binary ranges are left
	// calculate the intersection and return the resulting ranges
	// it's expected that only a single range can be returned from the left and right child
	if len(leftRanges) != 1 || len(rightRanges) != 1 {
		return nil, true
	}
	intersected := Intersect(leftRanges[0], rightRanges[0])
	matchALL := intersected == nil
	return []*PlanRange{intersected}, matchALL
}

func ParseRangesFromUnaryRangeExpr(expr *planpb.UnaryRangeExpr, kType KeyType) ([]*PlanRange, bool) {
	if expr.GetColumnInfo().GetIsPartitionKey() && kType == PartitionKey ||
		expr.GetColumnInfo().GetIsClusteringKey() && kType == ClusteringKey {
		switch expr.GetOp() {
		case planpb.OpType_Equal:
			{
				return []*PlanRange{
					{
						lower:        expr.Value,
						upper:        expr.Value,
						includeLower: true,
						includeUpper: true,
					},
				}, false
			}
		case planpb.OpType_GreaterThan:
			{
				return []*PlanRange{
					{
						lower:        expr.Value,
						upper:        nil,
						includeLower: false,
						includeUpper: false,
					},
				}, false
			}
		case planpb.OpType_GreaterEqual:
			{
				return []*PlanRange{
					{
						lower:        expr.Value,
						upper:        nil,
						includeLower: true,
						includeUpper: false,
					},
				}, false
			}
		case planpb.OpType_LessThan:
			{
				return []*PlanRange{
					{
						lower:        nil,
						upper:        expr.Value,
						includeLower: false,
						includeUpper: false,
					},
				}, false
			}
		case planpb.OpType_LessEqual:
			{
				return []*PlanRange{
					{
						lower:        nil,
						upper:        expr.Value,
						includeLower: false,
						includeUpper: true,
					},
				}, false
			}
		}
	}
	return nil, true
}

func ParseRangesFromTermExpr(expr *planpb.TermExpr, kType KeyType) ([]*PlanRange, bool) {
	if expr.GetColumnInfo().GetIsPartitionKey() && kType == PartitionKey ||
		expr.GetColumnInfo().GetIsClusteringKey() && kType == ClusteringKey {
		res := make([]*PlanRange, 0)
		for _, value := range expr.GetValues() {
			res = append(res, &PlanRange{
				lower:        value,
				upper:        value,
				includeLower: true,
				includeUpper: true,
			})
		}
		return res, false
	}
	return nil, true
}

var minusInfiniteInt = &planpb.GenericValue{
	Val: &planpb.GenericValue_Int64Val{
		Int64Val: math.MinInt64,
	},
}

var positiveInfiniteInt = &planpb.GenericValue{
	Val: &planpb.GenericValue_Int64Val{
		Int64Val: math.MaxInt64,
	},
}

var minStrVal = &planpb.GenericValue{
	Val: &planpb.GenericValue_StringVal{
		StringVal: "",
	},
}

var maxStrVal = &planpb.GenericValue{}

func complementPlanRange(pr *PlanRange, dataType schemapb.DataType) *PlanRange {
	if dataType == schemapb.DataType_Int64 {
		if pr.lower == nil {
			pr.lower = minusInfiniteInt
		}
		if pr.upper == nil {
			pr.upper = positiveInfiniteInt
		}
	} else {
		if pr.lower == nil {
			pr.lower = minStrVal
		}
		if pr.upper == nil {
			pr.upper = maxStrVal
		}
	}

	return pr
}

func GetCommonDataType(a *PlanRange, b *PlanRange) schemapb.DataType {
	var bound *planpb.GenericValue
	if a.lower != nil {
		bound = a.lower
	} else if a.upper != nil {
		bound = a.upper
	}
	if bound == nil {
		if b.lower != nil {
			bound = b.lower
		} else if b.upper != nil {
			bound = b.upper
		}
	}
	if bound == nil {
		return schemapb.DataType_None
	}
	switch bound.Val.(type) {
	case *planpb.GenericValue_Int64Val:
		{
			return schemapb.DataType_Int64
		}
	case *planpb.GenericValue_StringVal:
		{
			return schemapb.DataType_VarChar
		}
	}
	return schemapb.DataType_None
}

func Intersect(a *PlanRange, b *PlanRange) *PlanRange {
	dataType := GetCommonDataType(a, b)
	complementPlanRange(a, dataType)
	complementPlanRange(b, dataType)

	// Check if 'a' and 'b' are non-overlapping at all
	rightBound := minGenericValue(a.upper, b.upper)
	leftBound := maxGenericValue(a.lower, b.lower)
	if compareGenericValue(leftBound, rightBound) > 0 {
		return nil
	}

	// Check if 'a' range ends exactly where 'b' range starts
	if !a.includeUpper && !b.includeLower && (compareGenericValue(a.upper, b.lower) == 0) {
		return nil
	}
	// Check if 'b' range ends exactly where 'a' range starts
	if !b.includeUpper && !a.includeLower && (compareGenericValue(b.upper, a.lower) == 0) {
		return nil
	}

	return &PlanRange{
		lower:        leftBound,
		upper:        rightBound,
		includeLower: a.includeLower || b.includeLower,
		includeUpper: a.includeUpper || b.includeUpper,
	}
}

func compareGenericValue(left *planpb.GenericValue, right *planpb.GenericValue) int64 {
	if right == nil || left == nil {
		return -1
	}
	switch left.Val.(type) {
	case *planpb.GenericValue_Int64Val:
		if left.GetInt64Val() == right.GetInt64Val() {
			return 0
		} else if left.GetInt64Val() < right.GetInt64Val() {
			return -1
		} else {
			return 1
		}
	case *planpb.GenericValue_StringVal:
		if right.Val == nil {
			return -1
		}
		return int64(strings.Compare(left.GetStringVal(), right.GetStringVal()))
	}
	return 0
}

func minGenericValue(left *planpb.GenericValue, right *planpb.GenericValue) *planpb.GenericValue {
	if compareGenericValue(left, right) < 0 {
		return left
	}
	return right
}

func maxGenericValue(left *planpb.GenericValue, right *planpb.GenericValue) *planpb.GenericValue {
	if compareGenericValue(left, right) >= 0 {
		return left
	}
	return right
}
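Illustrative sketch, not from this commit: how a caller inside the exprutil package might combine a parsed clustering-key range with a segment's min/max statistics to decide whether the segment can be skipped. The segment bounds and the helper name canSkipSegment are made up; only ParseRanges, ToIntRange, NewIntRange, and IntRangeOverlap come from the file above.

	// canSkipSegment reports whether a sealed segment whose clustering-key values
	// all fall in [segMin, segMax] can be pruned for the given filter expression.
	func canSkipSegment(expr *planpb.Expr, segMin, segMax int64) bool {
		ranges, matchALL := ParseRanges(expr, ClusteringKey)
		if matchALL || len(ranges) != 1 {
			return false // the expression gives no usable bound, so the segment must be kept
		}
		segRange := NewIntRange(segMin, segMax, true, true)
		// No overlap between the filter range and the segment's value range means it can be pruned.
		return !IntRangeOverlap(ranges[0].ToIntRange(), segRange)
	}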
@@ -0,0 +1,279 @@
package exprutil

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/parser/planparserv2"
	"github.com/milvus-io/milvus/internal/proto/planpb"
	"github.com/milvus-io/milvus/internal/util/testutil"
	"github.com/milvus-io/milvus/pkg/common"
	"github.com/milvus-io/milvus/pkg/util/funcutil"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
)

func TestParsePartitionKeys(t *testing.T) {
	prefix := "TestParsePartitionKeys"
	collectionName := prefix + funcutil.GenRandomStr()

	fieldName2Type := make(map[string]schemapb.DataType)
	fieldName2Type["int64_field"] = schemapb.DataType_Int64
	fieldName2Type["varChar_field"] = schemapb.DataType_VarChar
	fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector
	schema := testutil.ConstructCollectionSchemaByDataType(collectionName, fieldName2Type,
		"int64_field", false, 8)
	partitionKeyField := &schemapb.FieldSchema{
		Name:           "partition_key_field",
		DataType:       schemapb.DataType_Int64,
		IsPartitionKey: true,
	}
	schema.Fields = append(schema.Fields, partitionKeyField)

	fieldID := common.StartOfUserFieldID
	for _, field := range schema.Fields {
		field.FieldID = int64(fieldID)
		fieldID++
	}
	schemaHelper, err := typeutil.CreateSchemaHelper(schema)
	require.NoError(t, err)

	queryInfo := &planpb.QueryInfo{
		Topk:         10,
		MetricType:   "L2",
		SearchParams: "",
		RoundDecimal: -1,
	}

	type testCase struct {
		name                 string
		expr                 string
		expected             int
		validPartitionKeys   []int64
		invalidPartitionKeys []int64
	}
	cases := []testCase{
		{
			name:                 "binary_expr_and with term",
			expr:                 "partition_key_field in [7, 8] && int64_field >= 10",
			expected:             2,
			validPartitionKeys:   []int64{7, 8},
			invalidPartitionKeys: []int64{},
		},
		{
			name:                 "binary_expr_and with equal",
			expr:                 "partition_key_field == 7 && int64_field >= 10",
			expected:             1,
			validPartitionKeys:   []int64{7},
			invalidPartitionKeys: []int64{},
		},
		{
			name:                 "binary_expr_and with term2",
			expr:                 "partition_key_field in [7, 8] && int64_field == 10",
			expected:             2,
			validPartitionKeys:   []int64{7, 8},
			invalidPartitionKeys: []int64{10},
		},
		{
			name:                 "binary_expr_and with partition key in range",
			expr:                 "partition_key_field in [7, 8] && partition_key_field > 9",
			expected:             2,
			validPartitionKeys:   []int64{7, 8},
			invalidPartitionKeys: []int64{9},
		},
		{
			name:                 "binary_expr_and with partition key in range2",
			expr:                 "int64_field == 10 && partition_key_field > 9",
			expected:             0,
			validPartitionKeys:   []int64{},
			invalidPartitionKeys: []int64{},
		},
		{
			name:                 "binary_expr_and with term and not",
			expr:                 "partition_key_field in [7, 8] && partition_key_field not in [10, 20]",
			expected:             2,
			validPartitionKeys:   []int64{7, 8},
			invalidPartitionKeys: []int64{10, 20},
		},
		{
			name:                 "binary_expr_or with term and not",
			expr:                 "partition_key_field in [7, 8] or partition_key_field not in [10, 20]",
			expected:             0,
			validPartitionKeys:   []int64{},
			invalidPartitionKeys: []int64{},
		},
		{
			name:                 "binary_expr_or with term and not 2",
			expr:                 "partition_key_field in [7, 8] or int64_field not in [10, 20]",
			expected:             2,
			validPartitionKeys:   []int64{7, 8},
			invalidPartitionKeys: []int64{10, 20},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// test search plan
			searchPlan, err := planparserv2.CreateSearchPlan(schemaHelper, tc.expr, "fvec_field", queryInfo)
			assert.NoError(t, err)
			expr, err := ParseExprFromPlan(searchPlan)
			assert.NoError(t, err)
			partitionKeys := ParseKeys(expr, PartitionKey)
			assert.Equal(t, tc.expected, len(partitionKeys))
			for _, key := range partitionKeys {
				int64Val := key.Val.(*planpb.GenericValue_Int64Val).Int64Val
				assert.Contains(t, tc.validPartitionKeys, int64Val)
				assert.NotContains(t, tc.invalidPartitionKeys, int64Val)
			}

			// test query plan
			queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, tc.expr)
			assert.NoError(t, err)
			expr, err = ParseExprFromPlan(queryPlan)
			assert.NoError(t, err)
			partitionKeys = ParseKeys(expr, PartitionKey)
			assert.Equal(t, tc.expected, len(partitionKeys))
			for _, key := range partitionKeys {
				int64Val := key.Val.(*planpb.GenericValue_Int64Val).Int64Val
				assert.Contains(t, tc.validPartitionKeys, int64Val)
				assert.NotContains(t, tc.invalidPartitionKeys, int64Val)
			}
		})
	}
}

func TestParseIntRanges(t *testing.T) {
	prefix := "TestParseRanges"
	clusterKeyField := "cluster_key_field"
	collectionName := prefix + funcutil.GenRandomStr()

	fieldName2Type := make(map[string]schemapb.DataType)
	fieldName2Type["int64_field"] = schemapb.DataType_Int64
	fieldName2Type["varChar_field"] = schemapb.DataType_VarChar
	fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector
	schema := testutil.ConstructCollectionSchemaByDataType(collectionName, fieldName2Type,
		"int64_field", false, 8)
	clusterKeyFieldSchema := &schemapb.FieldSchema{
		Name:            clusterKeyField,
		DataType:        schemapb.DataType_Int64,
		IsClusteringKey: true,
	}
	schema.Fields = append(schema.Fields, clusterKeyFieldSchema)

	fieldID := common.StartOfUserFieldID
	for _, field := range schema.Fields {
		field.FieldID = int64(fieldID)
		fieldID++
	}
	schemaHelper, err := typeutil.CreateSchemaHelper(schema)
	require.NoError(t, err)
	// test query plan
	{
		expr := "cluster_key_field > 50"
		queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr)
		assert.NoError(t, err)
		planExpr, err := ParseExprFromPlan(queryPlan)
		assert.NoError(t, err)
		parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey)
		assert.False(t, matchALL)
		assert.Equal(t, 1, len(parsedRanges))
		range0 := parsedRanges[0]
		assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_Int64Val).Int64Val, int64(50))
		assert.Nil(t, range0.upper)
		assert.Equal(t, range0.includeLower, false)
		assert.Equal(t, range0.includeUpper, false)
	}

	// test binary query plan
	{
		expr := "cluster_key_field > 50 and cluster_key_field <= 100"
		queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr)
		assert.NoError(t, err)
		planExpr, err := ParseExprFromPlan(queryPlan)
		assert.NoError(t, err)
		parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey)
		assert.False(t, matchALL)
		assert.Equal(t, 1, len(parsedRanges))
		range0 := parsedRanges[0]
		assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_Int64Val).Int64Val, int64(50))
		assert.Equal(t, false, range0.includeLower)
		assert.Equal(t, true, range0.includeUpper)
	}

	// test binary query plan
	{
		expr := "cluster_key_field >= 50 and cluster_key_field < 100"
		queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr)
		assert.NoError(t, err)
		planExpr, err := ParseExprFromPlan(queryPlan)
		assert.NoError(t, err)
		parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey)
		assert.False(t, matchALL)
		assert.Equal(t, 1, len(parsedRanges))
		range0 := parsedRanges[0]
		assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_Int64Val).Int64Val, int64(50))
		assert.Equal(t, true, range0.includeLower)
		assert.Equal(t, false, range0.includeUpper)
	}

	// test binary query plan
	{
		expr := "cluster_key_field in [100]"
		queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr)
		assert.NoError(t, err)
		planExpr, err := ParseExprFromPlan(queryPlan)
		assert.NoError(t, err)
		parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey)
		assert.False(t, matchALL)
		assert.Equal(t, 1, len(parsedRanges))
		range0 := parsedRanges[0]
		assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_Int64Val).Int64Val, int64(100))
		assert.Equal(t, true, range0.includeLower)
		assert.Equal(t, true, range0.includeUpper)
	}
}

func TestParseStrRanges(t *testing.T) {
	prefix := "TestParseRanges"
	clusterKeyField := "cluster_key_field"
	collectionName := prefix + funcutil.GenRandomStr()

	fieldName2Type := make(map[string]schemapb.DataType)
	fieldName2Type["int64_field"] = schemapb.DataType_Int64
	fieldName2Type["varChar_field"] = schemapb.DataType_VarChar
	fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector
	schema := testutil.ConstructCollectionSchemaByDataType(collectionName, fieldName2Type,
		"int64_field", false, 8)
	clusterKeyFieldSchema := &schemapb.FieldSchema{
		Name:            clusterKeyField,
		DataType:        schemapb.DataType_VarChar,
		IsClusteringKey: true,
	}
	schema.Fields = append(schema.Fields, clusterKeyFieldSchema)

	fieldID := common.StartOfUserFieldID
	for _, field := range schema.Fields {
		field.FieldID = int64(fieldID)
		fieldID++
	}
	schemaHelper, err := typeutil.CreateSchemaHelper(schema)
	require.NoError(t, err)
	// test query plan
	{
		expr := "cluster_key_field >= \"aaa\""
		queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr)
		assert.NoError(t, err)
		planExpr, err := ParseExprFromPlan(queryPlan)
		assert.NoError(t, err)
		parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey)
		assert.False(t, matchALL)
		assert.Equal(t, 1, len(parsedRanges))
		range0 := parsedRanges[0]
		assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_StringVal).StringVal, "aaa")
		assert.Nil(t, range0.upper)
		assert.Equal(t, range0.includeLower, true)
		assert.Equal(t, range0.includeUpper, false)
	}
}
@@ -0,0 +1,90 @@
package testutil

import (
	"strconv"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/pkg/common"
)

const (
	testMaxVarCharLength = 100
)

func ConstructCollectionSchemaWithKeys(collectionName string,
	fieldName2DataType map[string]schemapb.DataType,
	primaryFieldName string,
	partitionKeyFieldName string,
	clusteringKeyFieldName string,
	autoID bool,
	dim int,
) *schemapb.CollectionSchema {
	schema := ConstructCollectionSchemaByDataType(collectionName,
		fieldName2DataType,
		primaryFieldName,
		autoID,
		dim)
	for _, field := range schema.Fields {
		if field.Name == partitionKeyFieldName {
			field.IsPartitionKey = true
		}
		if field.Name == clusteringKeyFieldName {
			field.IsClusteringKey = true
		}
	}

	return schema
}

func isVectorType(dataType schemapb.DataType) bool {
	return dataType == schemapb.DataType_FloatVector ||
		dataType == schemapb.DataType_BinaryVector ||
		dataType == schemapb.DataType_Float16Vector ||
		dataType == schemapb.DataType_BFloat16Vector
}

func ConstructCollectionSchemaByDataType(collectionName string,
	fieldName2DataType map[string]schemapb.DataType,
	primaryFieldName string,
	autoID bool,
	dim int,
) *schemapb.CollectionSchema {
	fieldsSchema := make([]*schemapb.FieldSchema, 0)
	fieldIdx := int64(0)
	for fieldName, dataType := range fieldName2DataType {
		fieldSchema := &schemapb.FieldSchema{
			Name:     fieldName,
			DataType: dataType,
			FieldID:  fieldIdx,
		}
		fieldIdx += 1
		if isVectorType(dataType) {
			fieldSchema.TypeParams = []*commonpb.KeyValuePair{
				{
					Key:   common.DimKey,
					Value: strconv.Itoa(dim),
				},
			}
		}
		if dataType == schemapb.DataType_VarChar {
			fieldSchema.TypeParams = []*commonpb.KeyValuePair{
				{
					Key:   common.MaxLengthKey,
					Value: strconv.Itoa(testMaxVarCharLength),
				},
			}
		}
		if fieldName == primaryFieldName {
			fieldSchema.IsPrimaryKey = true
			fieldSchema.AutoID = autoID
		}

		fieldsSchema = append(fieldsSchema, fieldSchema)
	}

	return &schemapb.CollectionSchema{
		Name:   collectionName,
		Fields: fieldsSchema,
	}
}
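Illustrative sketch, not from this commit: building a schema that carries a primary key, a partition key, and a clustering key in one call to the new helper. The field names and dimension are arbitrary; the argument order follows the signature above.

	schema := testutil.ConstructCollectionSchemaWithKeys("demo_collection",
		map[string]schemapb.DataType{
			"pk_field":        schemapb.DataType_Int64,
			"partition_field": schemapb.DataType_Int64,
			"cluster_field":   schemapb.DataType_VarChar,
			"vec_field":       schemapb.DataType_FloatVector,
		},
		"pk_field",        // primary key
		"partition_field", // partition key
		"cluster_field",   // clustering key
		false,             // autoID
		128,               // vector dim, written into the vector field's TypeParams
	)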
@@ -119,3 +119,12 @@ func convertToArrowType(dataType schemapb.DataType) (arrow.DataType, error) {
		return nil, merr.WrapErrParameterInvalidMsg("unknown type %v", dataType.String())
	}
}

func GetClusteringKeyField(fields []*schemapb.FieldSchema) *schemapb.FieldSchema {
	for _, field := range fields {
		if field.IsClusteringKey {
			return field
		}
	}
	return nil
}
@@ -88,6 +88,9 @@ const (

	// SegmentIndexPath storage path const for segment index files.
	SegmentIndexPath = `index_files`

	// PartitionStatsPath storage path const for partition stats files
	PartitionStatsPath = `part_stats`
)

// Search, Index parameter keys
@@ -2039,6 +2039,7 @@ type queryNodeConfig struct {
	FlowGraphMaxParallelism ParamItem `refreshable:"false"`

	MemoryIndexLoadPredictMemoryUsageFactor ParamItem `refreshable:"true"`
	EnableSegmentPrune                      ParamItem `refreshable:"false"`
}

func (p *queryNodeConfig) init(base *BaseTable) {
@@ -2512,6 +2513,13 @@ Max read concurrency must greater than or equal to 1, and less than or equal to
		Doc: "memory usage prediction factor for memory index loaded",
	}
	p.MemoryIndexLoadPredictMemoryUsageFactor.Init(base.mgr)

	p.EnableSegmentPrune = ParamItem{
		Key:          "queryNode.enableSegmentPrune",
		Version:      "2.3.4",
		DefaultValue: "false",
		Doc:          "use partition prune function on shard delegator",
	}
	p.EnableSegmentPrune.Init(base.mgr)
}

// /////////////////////////////////////////////////////////////////////////////
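Illustrative sketch, not from this commit: code that wants to honor the new switch would read it through the ParamItem like any other query node setting. The accessor path QueryNodeCfg and the GetAsBool call are assumed to follow the usual paramtable pattern.

	// Assumed accessor path; only EnableSegmentPrune itself is introduced by this commit.
	if paramtable.Get().QueryNodeCfg.EnableSegmentPrune.GetAsBool() {
		// attach the segment pruner on the shard delegator
	}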