mirror of https://github.com/milvus-io/milvus.git

feat: support rate limiter based on db and partition levels (#31070)

issue: https://github.com/milvus-io/milvus/issues/30577
co-author: @jaime0815
---------
Signed-off-by: Patrick Weizhi Xu <weizhi.xu@zilliz.com>
Signed-off-by: SimFG <bang.fu@zilliz.com>
Co-authored-by: Patrick Weizhi Xu <weizhi.xu@zilliz.com>
branch: pull/32216/head

parent fb376fd1e6
commit c012e6786f
Makefile
@@ -457,6 +457,7 @@ generate-mockery-datacoord: getdeps
	$(INSTALL_PATH)/mockery --name=CompactionMeta --dir=internal/datacoord --filename=mock_compaction_meta.go --output=internal/datacoord --structname=MockCompactionMeta --with-expecter --inpackage
	$(INSTALL_PATH)/mockery --name=Scheduler --dir=internal/datacoord --filename=mock_scheduler.go --output=internal/datacoord --structname=MockScheduler --with-expecter --inpackage
	$(INSTALL_PATH)/mockery --name=ChannelManager --dir=internal/datacoord --filename=mock_channelmanager.go --output=internal/datacoord --structname=MockChannelManager --with-expecter --inpackage
	$(INSTALL_PATH)/mockery --name=Broker --dir=internal/datacoord/broker --filename=mock_coordinator_broker.go --output=internal/datacoord/broker --structname=MockBroker --with-expecter --inpackage

generate-mockery-datanode: getdeps
	$(INSTALL_PATH)/mockery --name=Allocator --dir=$(PWD)/internal/datanode/allocator --output=$(PWD)/internal/datanode/allocator --filename=mock_allocator.go --with-expecter --structname=MockAllocator --outpkg=allocator --inpackage
@@ -1,4 +1,4 @@
// Code generated by mockery v2.30.1. DO NOT EDIT.
// Code generated by mockery v2.32.4. DO NOT EDIT.

package broker
@ -77,6 +77,59 @@ func (_c *MockBroker_DescribeCollectionInternal_Call) RunAndReturn(run func(cont
|
|||
return _c
|
||||
}
|
||||
|
||||
// GetDatabaseID provides a mock function with given fields: ctx, dbName
|
||||
func (_m *MockBroker) GetDatabaseID(ctx context.Context, dbName string) (int64, error) {
|
||||
ret := _m.Called(ctx, dbName)
|
||||
|
||||
var r0 int64
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) (int64, error)); ok {
|
||||
return rf(ctx, dbName)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) int64); ok {
|
||||
r0 = rf(ctx, dbName)
|
||||
} else {
|
||||
r0 = ret.Get(0).(int64)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = rf(ctx, dbName)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockBroker_GetDatabaseID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDatabaseID'
|
||||
type MockBroker_GetDatabaseID_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetDatabaseID is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - dbName string
|
||||
func (_e *MockBroker_Expecter) GetDatabaseID(ctx interface{}, dbName interface{}) *MockBroker_GetDatabaseID_Call {
|
||||
return &MockBroker_GetDatabaseID_Call{Call: _e.mock.On("GetDatabaseID", ctx, dbName)}
|
||||
}
|
||||
|
||||
func (_c *MockBroker_GetDatabaseID_Call) Run(run func(ctx context.Context, dbName string)) *MockBroker_GetDatabaseID_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockBroker_GetDatabaseID_Call) Return(_a0 int64, _a1 error) *MockBroker_GetDatabaseID_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockBroker_GetDatabaseID_Call) RunAndReturn(run func(context.Context, string) (int64, error)) *MockBroker_GetDatabaseID_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// HasCollection provides a mock function with given fields: ctx, collectionID
|
||||
func (_m *MockBroker) HasCollection(ctx context.Context, collectionID int64) (bool, error) {
|
||||
ret := _m.Called(ctx, collectionID)
|
||||
|
|
|
@@ -263,7 +263,7 @@ func CheckDiskQuota(job ImportJob, meta *meta, imeta ImportMeta) (int64, error)
	}

	err := merr.WrapErrServiceQuotaExceeded("disk quota exceeded, please allocate more resources")
	totalUsage, collectionsUsage := meta.GetCollectionBinlogSize()
	totalUsage, collectionsUsage, _ := meta.GetCollectionBinlogSize()

	tasks := imeta.GetTaskBy(WithJob(job.GetJobID()), WithType(PreImportTaskType))
	files := make([]*datapb.ImportFileStats, 0)
@@ -85,6 +85,7 @@ type collectionInfo struct {
	Properties   map[string]string
	CreatedAt    Timestamp
	DatabaseName string
	DatabaseID   int64
}

// NewMeta creates meta from provided `kv.TxnKV`
@@ -200,6 +201,7 @@ func (m *meta) GetClonedCollectionInfo(collectionID UniqueID) *collectionInfo {
		StartPositions: common.CloneKeyDataPairs(coll.StartPositions),
		Properties:     clonedProperties,
		DatabaseName:   coll.DatabaseName,
		DatabaseID:     coll.DatabaseID,
	}

	return cloneColl
@@ -257,10 +259,11 @@ func (m *meta) GetNumRowsOfCollection(collectionID UniqueID) int64 {
}

// GetCollectionBinlogSize returns the total binlog size and binlog size of collections.
func (m *meta) GetCollectionBinlogSize() (int64, map[UniqueID]int64) {
func (m *meta) GetCollectionBinlogSize() (int64, map[UniqueID]int64, map[UniqueID]map[UniqueID]int64) {
	m.RLock()
	defer m.RUnlock()
	collectionBinlogSize := make(map[UniqueID]int64)
	partitionBinlogSize := make(map[UniqueID]map[UniqueID]int64)
	collectionRowsNum := make(map[UniqueID]map[commonpb.SegmentState]int64)
	segments := m.segments.GetSegments()
	var total int64
@@ -270,6 +273,13 @@ func (m *meta) GetCollectionBinlogSize() (int64, map[UniqueID]int64) {
		total += segmentSize
		collectionBinlogSize[segment.GetCollectionID()] += segmentSize

		partBinlogSize, ok := partitionBinlogSize[segment.GetCollectionID()]
		if !ok {
			partBinlogSize = make(map[int64]int64)
			partitionBinlogSize[segment.GetCollectionID()] = partBinlogSize
		}
		partBinlogSize[segment.GetPartitionID()] += segmentSize

		coll, ok := m.collections[segment.GetCollectionID()]
		if ok {
			metrics.DataCoordStoredBinlogSize.WithLabelValues(coll.DatabaseName,
@@ -294,7 +304,7 @@ func (m *meta) GetCollectionBinlogSize() (int64, map[UniqueID]int64) {
			}
		}
	}
	return total, collectionBinlogSize
	return total, collectionBinlogSize, partitionBinlogSize
}

func (m *meta) GetAllCollectionNumRows() map[int64]int64 {
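
GetCollectionBinlogSize now returns a third value, a collection -> partition -> binlog-size map, which is what allows disk usage to be tracked below collection granularity. For intuition, a minimal self-contained sketch of rolling such a nested map up into per-collection and per-database totals; the aggregateSizes helper and the dbOfCollection mapping are illustrative assumptions, not code from this change:

package main

import "fmt"

type UniqueID = int64

// aggregateSizes rolls a collection -> partition -> bytes map up into
// per-collection and per-database totals, given a collection -> database mapping.
func aggregateSizes(
	partitionSizes map[UniqueID]map[UniqueID]int64,
	dbOfCollection map[UniqueID]UniqueID,
) (perCollection, perDatabase map[UniqueID]int64) {
	perCollection = make(map[UniqueID]int64)
	perDatabase = make(map[UniqueID]int64)
	for collID, parts := range partitionSizes {
		for _, size := range parts {
			perCollection[collID] += size
			perDatabase[dbOfCollection[collID]] += size
		}
	}
	return perCollection, perDatabase
}

func main() {
	partitionSizes := map[UniqueID]map[UniqueID]int64{
		100: {1001: 1 << 20, 1002: 2 << 20}, // collection 100, two partitions
		200: {2001: 4 << 20},                // collection 200, one partition
	}
	dbOfCollection := map[UniqueID]UniqueID{100: 1, 200: 1} // both live in database 1
	perColl, perDB := aggregateSizes(partitionSizes, dbOfCollection)
	fmt.Println(perColl[100], perDB[1]) // 3145728 7340032
}
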
@@ -603,13 +603,13 @@ func TestMeta_Basic(t *testing.T) {
		assert.NoError(t, err)

		// check TotalBinlogSize
		total, collectionBinlogSize := meta.GetCollectionBinlogSize()
		total, collectionBinlogSize, _ := meta.GetCollectionBinlogSize()
		assert.Len(t, collectionBinlogSize, 1)
		assert.Equal(t, int64(size0+size1), collectionBinlogSize[collID])
		assert.Equal(t, int64(size0+size1), total)

		meta.collections[collID] = collInfo
		total, collectionBinlogSize = meta.GetCollectionBinlogSize()
		total, collectionBinlogSize, _ = meta.GetCollectionBinlogSize()
		assert.Len(t, collectionBinlogSize, 1)
		assert.Equal(t, int64(size0+size1), collectionBinlogSize[collID])
		assert.Equal(t, int64(size0+size1), total)
@@ -37,10 +37,11 @@ import (

// getQuotaMetrics returns DataCoordQuotaMetrics.
func (s *Server) getQuotaMetrics() *metricsinfo.DataCoordQuotaMetrics {
	total, colSizes := s.meta.GetCollectionBinlogSize()
	total, colSizes, partSizes := s.meta.GetCollectionBinlogSize()
	return &metricsinfo.DataCoordQuotaMetrics{
		TotalBinlogSize:      total,
		CollectionBinlogSize: colSizes,
		PartitionsBinlogSize: partSizes,
	}
}
@@ -329,6 +329,15 @@ type mockRootCoordClient struct {
	cnt int64
}

func (m *mockRootCoordClient) DescribeDatabase(ctx context.Context, in *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error) {
	return &rootcoordpb.DescribeDatabaseResponse{
		Status:           merr.Success(),
		DbID:             1,
		DbName:           "default",
		CreatedTimestamp: 1,
	}, nil
}

func (m *mockRootCoordClient) Close() error {
	// TODO implement me
	panic("implement me")
@@ -1159,6 +1159,7 @@ func (s *Server) loadCollectionFromRootCoord(ctx context.Context, collectionID i
		Properties:   properties,
		CreatedAt:    resp.GetCreatedTimestamp(),
		DatabaseName: resp.GetDbName(),
		DatabaseID:   resp.GetDbId(),
	}
	s.meta.AddCollection(collInfo)
	return nil
@@ -1551,6 +1551,7 @@ func (s *Server) BroadcastAlteredCollection(ctx context.Context, req *datapb.Alt
		Partitions:     req.GetPartitionIDs(),
		StartPositions: req.GetStartPositions(),
		Properties:     properties,
		DatabaseID:     req.GetDbID(),
	}
	s.meta.AddCollection(collInfo)
	return merr.Success(), nil
@@ -47,7 +47,7 @@ func initGlobalRateCollector() error {

// newRateCollector returns a new rateCollector.
func newRateCollector() (*rateCollector, error) {
	rc, err := ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity)
	rc, err := ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity, false)
	if err != nil {
		return nil, err
	}
@@ -648,3 +648,14 @@ func (c *Client) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRe
	}
	return ret.(*milvuspb.ListDatabasesResponse), err
}

func (c *Client) DescribeDatabase(ctx context.Context, req *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error) {
	req = typeutil.Clone(req)
	commonpbutil.UpdateMsgBase(
		req.GetBase(),
		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
	)
	return wrapGrpcCall(ctx, c, func(client rootcoordpb.RootCoordClient) (*rootcoordpb.DescribeDatabaseResponse, error) {
		return client.DescribeDatabase(ctx, req)
	})
}
@@ -77,6 +77,10 @@ type Server struct {
	newQueryCoordClient func() types.QueryCoordClient
}

func (s *Server) DescribeDatabase(ctx context.Context, request *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error) {
	return s.rootCoord.DescribeDatabase(ctx, request)
}

func (s *Server) CreateDatabase(ctx context.Context, request *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
	return s.rootCoord.CreateDatabase(ctx, request)
}
@ -861,6 +861,61 @@ func (_c *RootCoord_DescribeCollectionInternal_Call) RunAndReturn(run func(conte
|
|||
return _c
|
||||
}
|
||||
|
||||
// DescribeDatabase provides a mock function with given fields: _a0, _a1
|
||||
func (_m *RootCoord) DescribeDatabase(_a0 context.Context, _a1 *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 *rootcoordpb.DescribeDatabaseResponse
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error)); ok {
|
||||
return rf(_a0, _a1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest) *rootcoordpb.DescribeDatabaseResponse); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*rootcoordpb.DescribeDatabaseResponse)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// RootCoord_DescribeDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeDatabase'
|
||||
type RootCoord_DescribeDatabase_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// DescribeDatabase is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *rootcoordpb.DescribeDatabaseRequest
|
||||
func (_e *RootCoord_Expecter) DescribeDatabase(_a0 interface{}, _a1 interface{}) *RootCoord_DescribeDatabase_Call {
|
||||
return &RootCoord_DescribeDatabase_Call{Call: _e.mock.On("DescribeDatabase", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *RootCoord_DescribeDatabase_Call) Run(run func(_a0 context.Context, _a1 *rootcoordpb.DescribeDatabaseRequest)) *RootCoord_DescribeDatabase_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*rootcoordpb.DescribeDatabaseRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *RootCoord_DescribeDatabase_Call) Return(_a0 *rootcoordpb.DescribeDatabaseResponse, _a1 error) *RootCoord_DescribeDatabase_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *RootCoord_DescribeDatabase_Call) RunAndReturn(run func(context.Context, *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error)) *RootCoord_DescribeDatabase_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// DropAlias provides a mock function with given fields: _a0, _a1
|
||||
func (_m *RootCoord) DropAlias(_a0 context.Context, _a1 *milvuspb.DropAliasRequest) (*commonpb.Status, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
|
|
@ -1124,6 +1124,76 @@ func (_c *MockRootCoordClient_DescribeCollectionInternal_Call) RunAndReturn(run
|
|||
return _c
|
||||
}
|
||||
|
||||
// DescribeDatabase provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockRootCoordClient) DescribeDatabase(ctx context.Context, in *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
for _i := range opts {
|
||||
_va[_i] = opts[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx, in)
|
||||
_ca = append(_ca, _va...)
|
||||
ret := _m.Called(_ca...)
|
||||
|
||||
var r0 *rootcoordpb.DescribeDatabaseResponse
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest, ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error)); ok {
|
||||
return rf(ctx, in, opts...)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest, ...grpc.CallOption) *rootcoordpb.DescribeDatabaseResponse); ok {
|
||||
r0 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*rootcoordpb.DescribeDatabaseResponse)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest, ...grpc.CallOption) error); ok {
|
||||
r1 = rf(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockRootCoordClient_DescribeDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeDatabase'
|
||||
type MockRootCoordClient_DescribeDatabase_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// DescribeDatabase is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *rootcoordpb.DescribeDatabaseRequest
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *MockRootCoordClient_Expecter) DescribeDatabase(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_DescribeDatabase_Call {
|
||||
return &MockRootCoordClient_DescribeDatabase_Call{Call: _e.mock.On("DescribeDatabase",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockRootCoordClient_DescribeDatabase_Call) Run(run func(ctx context.Context, in *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption)) *MockRootCoordClient_DescribeDatabase_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]grpc.CallOption, len(args)-2)
|
||||
for i, a := range args[2:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(grpc.CallOption)
|
||||
}
|
||||
}
|
||||
run(args[0].(context.Context), args[1].(*rootcoordpb.DescribeDatabaseRequest), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockRootCoordClient_DescribeDatabase_Call) Return(_a0 *rootcoordpb.DescribeDatabaseResponse, _a1 error) *MockRootCoordClient_DescribeDatabase_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockRootCoordClient_DescribeDatabase_Call) RunAndReturn(run func(context.Context, *rootcoordpb.DescribeDatabaseRequest, ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error)) *MockRootCoordClient_DescribeDatabase_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// DropAlias provides a mock function with given fields: ctx, in, opts
|
||||
func (_m *MockRootCoordClient) DropAlias(ctx context.Context, in *milvuspb.DropAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
|
|
|
@@ -634,6 +634,7 @@ message AlterCollectionRequest {
  repeated int64 partitionIDs = 3;
  repeated common.KeyDataPair start_positions = 4;
  repeated common.KeyValuePair properties = 5;
  int64 dbID = 6;
}

message GcConfirmRequest {
@@ -266,6 +266,13 @@ message ShowConfigurationsResponse {
  repeated common.KeyValuePair configuations = 2;
}

enum RateScope {
  Cluster = 0;
  Database = 1;
  Collection = 2;
  Partition = 3;
}

enum RateType {
  DDLCollection = 0;
  DDLPartition = 1;
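
The new RateScope enum is what lets one rate type carry different limits at cluster, database, collection, and partition granularity. As a rough self-contained illustration of that idea (the numbers are made up; real limits come from the quota configuration, not constants):

package main

import "fmt"

type RateScope int32

const (
	ScopeCluster RateScope = iota
	ScopeDatabase
	ScopeCollection
	ScopePartition
)

// insertRateLimit returns a hypothetical bytes-per-second insert limit per scope,
// mimicking how a limit can tighten as the scope narrows.
func insertRateLimit(scope RateScope) float64 {
	switch scope {
	case ScopeCluster:
		return 64 << 20
	case ScopeDatabase:
		return 32 << 20
	case ScopeCollection:
		return 16 << 20
	case ScopePartition:
		return 8 << 20
	default:
		return 0
	}
}

func main() {
	for s := ScopeCluster; s <= ScopePartition; s++ {
		fmt.Printf("scope %d -> %.0f bytes/s\n", s, insertRateLimit(s))
	}
}
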
@@ -58,6 +58,7 @@ message RefreshPolicyInfoCacheRequest {
  string opKey = 3;
}

// Deprecated: use ClusterLimiter instead
message CollectionRate {
  int64 collection = 1;
  repeated internal.Rate rates = 2;
@@ -65,9 +66,27 @@ message CollectionRate {
  repeated common.ErrorCode codes = 4;
}

message LimiterNode {
  // self limiter information
  Limiter limiter = 1;
  // db id -> db limiter
  // collection id -> collection limiter
  // partition id -> partition limiter
  map<int64, LimiterNode> children = 2;
}

message Limiter {
  repeated internal.Rate rates = 1;
  // we cannot use a map to store quota states and error codes, because keys in map fields cannot be enum types
  repeated milvus.QuotaState states = 2;
  repeated common.ErrorCode codes = 3;
}

message SetRatesRequest {
  common.MsgBase base = 1;
  // deprecated
  repeated CollectionRate rates = 2;
  LimiterNode rootLimiter = 3;
}

message ListClientInfosRequest {
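
The LimiterNode message forms a tree: the root (cluster) limiter holds database limiters as children, each database holds collection limiters, and each collection holds partition limiters, all keyed by ID. A minimal sketch of how a request check can walk that hierarchy and fail at the first level that rejects it; limiterNode and checkRequest are simplified stand-ins, not the generated proto types or the proxy's actual limiter:

package main

import "fmt"

// limiterNode mirrors the shape of the LimiterNode message: a per-node limiter
// plus children keyed by ID (db -> collection -> partition).
type limiterNode struct {
	allow    func(n int) bool // stand-in for the node's rate limiter
	children map[int64]*limiterNode
}

// checkRequest walks cluster -> database -> collection -> partition and stops
// at the first level whose limiter rejects the request.
func checkRequest(root *limiterNode, path []int64, n int) error {
	levels := []string{"cluster", "database", "collection", "partition"}
	node := root
	for i := 0; ; i++ {
		if node.allow != nil && !node.allow(n) {
			return fmt.Errorf("rate limit exceeded at %s level", levels[i])
		}
		if i >= len(path) {
			return nil
		}
		child, ok := node.children[path[i]]
		if !ok {
			return nil // no limiter configured below this level
		}
		node = child
	}
}

func main() {
	allow := func(int) bool { return true }
	deny := func(int) bool { return false }
	root := &limiterNode{
		allow: allow,
		children: map[int64]*limiterNode{
			1: { // database 1
				allow: allow,
				children: map[int64]*limiterNode{
					100: {allow: deny}, // collection 100 is throttled
				},
			},
		},
	}
	// path is [dbID, collectionID, partitionID]
	fmt.Println(checkRequest(root, []int64{1, 100, 1001}, 1))
}
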
@@ -140,6 +140,7 @@ service RootCoord {
  rpc CreateDatabase(milvus.CreateDatabaseRequest) returns (common.Status) {}
  rpc DropDatabase(milvus.DropDatabaseRequest) returns (common.Status) {}
  rpc ListDatabases(milvus.ListDatabasesRequest) returns (milvus.ListDatabasesResponse) {}
  rpc DescribeDatabase(DescribeDatabaseRequest) returns (DescribeDatabaseResponse) {}
}

message AllocTimestampRequest {
@@ -206,3 +207,14 @@ message GetCredentialResponse {
  string password = 3;
}

message DescribeDatabaseRequest {
  common.MsgBase base = 1;
  string db_name = 2;
}

message DescribeDatabaseResponse {
  common.Status status = 1;
  string db_name = 2;
  int64 dbID = 3;
  uint64 created_timestamp = 4;
}
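
With the request and response messages above, calling the new RPC from any generated rootcoordpb.RootCoordClient follows the usual gRPC shape. A sketch of resolving a database name to its ID; the import paths are assumed from the surrounding code, and Milvus itself routes the status check through merr.CheckRPCCall as shown later in the MetaCache change:

package example

import (
	"context"
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
)

// getDatabaseID resolves a database name to its ID via the new DescribeDatabase RPC.
func getDatabaseID(ctx context.Context, client rootcoordpb.RootCoordClient, dbName string) (int64, error) {
	resp, err := client.DescribeDatabase(ctx, &rootcoordpb.DescribeDatabaseRequest{DbName: dbName})
	if err != nil {
		return 0, err
	}
	if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
		return 0, fmt.Errorf("DescribeDatabase failed: %s", resp.GetStatus().GetReason())
	}
	return resp.GetDbID(), nil
}
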
@@ -56,6 +56,8 @@ import (
	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/pkg/util/metricsinfo"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
	"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
	"github.com/milvus-io/milvus/pkg/util/requestutil"
	"github.com/milvus-io/milvus/pkg/util/retry"
	"github.com/milvus-io/milvus/pkg/util/timerecord"
	"github.com/milvus-io/milvus/pkg/util/tsoutil"
@@ -160,8 +162,10 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
		for _, alias := range aliasName {
			metrics.CleanupProxyCollectionMetrics(paramtable.GetNodeID(), alias)
		}
		DeregisterSubLabel(ratelimitutil.GetCollectionSubLabel(request.GetDbName(), request.GetCollectionName()))
	} else if msgType == commonpb.MsgType_DropDatabase {
		metrics.CleanupProxyDBMetrics(paramtable.GetNodeID(), request.GetDbName())
		DeregisterSubLabel(ratelimitutil.GetDBSubLabel(request.GetDbName()))
	}
	log.Info("complete to invalidate collection meta cache")
@@ -289,6 +293,7 @@ func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatab
	}

	log.Info(rpcDone(method))
	DeregisterSubLabel(ratelimitutil.GetDBSubLabel(request.GetDbName()))
	metrics.ProxyFunctionCall.WithLabelValues(
		strconv.FormatInt(paramtable.GetNodeID(), 10),
		method,
@@ -527,6 +532,7 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol
		zap.Uint64("BeginTs", dct.BeginTs()),
		zap.Uint64("EndTs", dct.EndTs()),
	)
	DeregisterSubLabel(ratelimitutil.GetCollectionSubLabel(request.GetDbName(), request.GetCollectionName()))

	metrics.ProxyFunctionCall.WithLabelValues(
		strconv.FormatInt(paramtable.GetNodeID(), 10),
@@ -2680,11 +2686,11 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
			hookutil.FailCntKey: len(it.result.ErrIndex),
		})
		SetReportValue(it.result.GetStatus(), v)

		rateCol.Add(internalpb.RateType_DMLUpsert.String(), float64(it.upsertMsg.InsertMsg.Size()+it.upsertMsg.DeleteMsg.Size()))
		if merr.Ok(it.result.GetStatus()) {
			metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeUpsert, dbName, username).Add(float64(v))
		}

		rateCol.Add(internalpb.RateType_DMLUpsert.String(), float64(it.upsertMsg.InsertMsg.Size()+it.upsertMsg.DeleteMsg.Size()))
		metrics.ProxyFunctionCall.WithLabelValues(nodeID, method,
			metrics.SuccessLabel, dbName, collectionName).Inc()
		successCnt := it.result.UpsertCnt - int64(len(it.result.ErrIndex))
@@ -2700,6 +2706,19 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
	return it.result, nil
}

func GetDBAndCollectionRateSubLabels(req any) []string {
	subLabels := make([]string, 2)
	dbName, _ := requestutil.GetDbNameFromRequest(req)
	if dbName != "" {
		subLabels[0] = ratelimitutil.GetDBSubLabel(dbName.(string))
	}
	collectionName, _ := requestutil.GetCollectionNameFromRequest(req)
	if collectionName != "" {
		subLabels[1] = ratelimitutil.GetCollectionSubLabel(dbName.(string), collectionName.(string))
	}
	return subLabels
}

// Search searches the most similar records of requests.
func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
	var err error
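
GetDBAndCollectionRateSubLabels builds per-database and per-collection sub-labels from a request so a single rateCol.Add call can account the same sample against the node-wide series and the finer-grained series at once. A toy sketch of that fan-out; the real ratelimitutil.RateCollector keeps sliding time windows rather than plain counters, and the label formats below are assumptions:

package main

import "fmt"

// subLabelCollector is a toy stand-in for a rate collector that fans each
// sample out to the main label plus any non-empty sub-labels.
type subLabelCollector struct {
	values map[string]float64
}

func (c *subLabelCollector) Add(label string, v float64, subLabels ...string) {
	c.values[label] += v
	for _, sub := range subLabels {
		if sub != "" {
			c.values[label+"|"+sub] += v // assumed composite key format
		}
	}
}

func main() {
	col := &subLabelCollector{values: map[string]float64{}}
	// One search with nq=8 against db1.coll1 counts toward the node total,
	// the database series, and the collection series.
	col.Add("DQLSearch", 8, "db.db1", "collection.db1.coll1")
	fmt.Println(col.values)
}
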
@@ -2734,7 +2753,8 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest)
		request.GetCollectionName(),
	).Add(float64(request.GetNq()))

	rateCol.Add(internalpb.RateType_DQLSearch.String(), float64(request.GetNq()))
	subLabels := GetDBAndCollectionRateSubLabels(request)
	rateCol.Add(internalpb.RateType_DQLSearch.String(), float64(request.GetNq()), subLabels...)

	if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
		return &milvuspb.SearchResults{
@@ -2909,8 +2929,9 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest)
		if merr.Ok(qt.result.GetStatus()) {
			metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeSearch, dbName, username).Add(float64(v))
		}

		metrics.ProxyReadReqSendBytes.WithLabelValues(nodeID).Add(float64(sentSize))
		rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize))
		rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize), subLabels...)
	}
	return qt.result, nil
}
@@ -2941,6 +2962,13 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
		request.GetCollectionName(),
	).Add(float64(receiveSize))

	subLabels := GetDBAndCollectionRateSubLabels(request)
	allNQ := int64(0)
	for _, searchRequest := range request.Requests {
		allNQ += searchRequest.GetNq()
	}
	rateCol.Add(internalpb.RateType_DQLSearch.String(), float64(allNQ), subLabels...)

	if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
		return &milvuspb.SearchResults{
			Status: merr.Status(err),
@@ -3098,8 +3126,9 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
		if merr.Ok(qt.result.GetStatus()) {
			metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeHybridSearch, dbName, username).Add(float64(v))
		}

		metrics.ProxyReadReqSendBytes.WithLabelValues(nodeID).Add(float64(sentSize))
		rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize))
		rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize), subLabels...)
	}
	return qt.result, nil
}
@@ -3246,7 +3275,8 @@ func (node *Proxy) query(ctx context.Context, qt *queryTask) (*milvuspb.QueryRes
		request.GetCollectionName(),
	).Add(float64(1))

	rateCol.Add(internalpb.RateType_DQLQuery.String(), 1)
	subLabels := GetDBAndCollectionRateSubLabels(request)
	rateCol.Add(internalpb.RateType_DQLQuery.String(), 1, subLabels...)

	if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
		return &milvuspb.QueryResults{
@@ -3364,7 +3394,7 @@ func (node *Proxy) query(ctx context.Context, qt *queryTask) (*milvuspb.QueryRes
	).Observe(float64(tr.ElapseSpan().Milliseconds()))

	sentSize := proto.Size(qt.result)
	rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize))
	rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize), subLabels...)
	metrics.ProxyReadReqSendBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(sentSize))

	return qt.result, nil
@@ -5092,7 +5122,7 @@ func (node *Proxy) SetRates(ctx context.Context, request *proxypb.SetRatesReques
		return resp, nil
	}

	err := node.multiRateLimiter.SetRates(request.GetRates())
	err := node.simpleLimiter.SetRates(request.GetRootLimiter())
	// TODO: set multiple rate limiter rates
	if err != nil {
		resp = merr.Status(err)
@@ -5162,12 +5192,9 @@ func (node *Proxy) CheckHealth(ctx context.Context, request *milvuspb.CheckHealt
		}, nil
	}

	states, reasons := node.multiRateLimiter.GetQuotaStates()
	return &milvuspb.CheckHealthResponse{
		Status:      merr.Success(),
		QuotaStates: states,
		Reasons:     reasons,
		IsHealthy:   true,
		Status:    merr.Success(),
		IsHealthy: true,
	}, nil
}
@@ -5978,3 +6005,10 @@ func (node *Proxy) ListImports(ctx context.Context, req *internalpb.ListImportsR
	metrics.ProxyReqLatency.WithLabelValues(nodeID, method).Observe(float64(tr.ElapseSpan().Milliseconds()))
	return resp, nil
}

// DeregisterSubLabel removes the given sub-label from all rate metrics; any metric that starts using sub-labels must also be added here.
func DeregisterSubLabel(subLabel string) {
	rateCol.DeregisterSubLabel(internalpb.RateType_DQLQuery.String(), subLabel)
	rateCol.DeregisterSubLabel(internalpb.RateType_DQLSearch.String(), subLabel)
	rateCol.DeregisterSubLabel(metricsinfo.ReadResultThroughput, subLabel)
}
@ -63,6 +63,7 @@ func TestProxy_InvalidateCollectionMetaCache_remove_stream(t *testing.T) {
|
|||
chMgr.EXPECT().removeDMLStream(mock.Anything).Return()
|
||||
|
||||
node := &Proxy{chMgr: chMgr}
|
||||
_ = node.initRateCollector()
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
|
||||
ctx := context.Background()
|
||||
|
@ -78,7 +79,7 @@ func TestProxy_InvalidateCollectionMetaCache_remove_stream(t *testing.T) {
|
|||
func TestProxy_CheckHealth(t *testing.T) {
|
||||
t.Run("not healthy", func(t *testing.T) {
|
||||
node := &Proxy{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}}
|
||||
node.multiRateLimiter = NewMultiRateLimiter()
|
||||
node.simpleLimiter = NewSimpleLimiter()
|
||||
node.UpdateStateCode(commonpb.StateCode_Abnormal)
|
||||
ctx := context.Background()
|
||||
resp, err := node.CheckHealth(ctx, &milvuspb.CheckHealthRequest{})
|
||||
|
@ -96,7 +97,7 @@ func TestProxy_CheckHealth(t *testing.T) {
|
|||
dataCoord: NewDataCoordMock(),
|
||||
session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}},
|
||||
}
|
||||
node.multiRateLimiter = NewMultiRateLimiter()
|
||||
node.simpleLimiter = NewSimpleLimiter()
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
ctx := context.Background()
|
||||
resp, err := node.CheckHealth(ctx, &milvuspb.CheckHealthRequest{})
|
||||
|
@ -129,7 +130,7 @@ func TestProxy_CheckHealth(t *testing.T) {
|
|||
queryCoord: qc,
|
||||
dataCoord: dataCoordMock,
|
||||
}
|
||||
node.multiRateLimiter = NewMultiRateLimiter()
|
||||
node.simpleLimiter = NewSimpleLimiter()
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
ctx := context.Background()
|
||||
resp, err := node.CheckHealth(ctx, &milvuspb.CheckHealthRequest{})
|
||||
|
@ -146,7 +147,7 @@ func TestProxy_CheckHealth(t *testing.T) {
|
|||
dataCoord: NewDataCoordMock(),
|
||||
queryCoord: qc,
|
||||
}
|
||||
node.multiRateLimiter = NewMultiRateLimiter()
|
||||
node.simpleLimiter = NewSimpleLimiter()
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
resp, err := node.CheckHealth(context.Background(), &milvuspb.CheckHealthRequest{})
|
||||
assert.NoError(t, err)
|
||||
|
@ -156,18 +157,30 @@ func TestProxy_CheckHealth(t *testing.T) {
|
|||
|
||||
states := []milvuspb.QuotaState{milvuspb.QuotaState_DenyToWrite, milvuspb.QuotaState_DenyToRead}
|
||||
codes := []commonpb.ErrorCode{commonpb.ErrorCode_MemoryQuotaExhausted, commonpb.ErrorCode_ForceDeny}
|
||||
node.multiRateLimiter.SetRates([]*proxypb.CollectionRate{
|
||||
{
|
||||
Collection: 1,
|
||||
States: states,
|
||||
Codes: codes,
|
||||
err = node.simpleLimiter.SetRates(&proxypb.LimiterNode{
|
||||
Limiter: &proxypb.Limiter{},
|
||||
// db level
|
||||
Children: map[int64]*proxypb.LimiterNode{
|
||||
1: {
|
||||
Limiter: &proxypb.Limiter{},
|
||||
// collection level
|
||||
Children: map[int64]*proxypb.LimiterNode{
|
||||
100: {
|
||||
Limiter: &proxypb.Limiter{
|
||||
States: states,
|
||||
Codes: codes,
|
||||
},
|
||||
Children: make(map[int64]*proxypb.LimiterNode),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
resp, err = node.CheckHealth(context.Background(), &milvuspb.CheckHealthRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, resp.IsHealthy)
|
||||
assert.Equal(t, 2, len(resp.GetQuotaStates()))
|
||||
assert.Equal(t, 2, len(resp.GetReasons()))
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -229,7 +242,7 @@ func TestProxy_ResourceGroup(t *testing.T) {
|
|||
|
||||
node, err := NewProxy(ctx, factory)
|
||||
assert.NoError(t, err)
|
||||
node.multiRateLimiter = NewMultiRateLimiter()
|
||||
node.simpleLimiter = NewSimpleLimiter()
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
|
||||
qc := mocks.NewMockQueryCoordClient(t)
|
||||
|
@ -321,7 +334,7 @@ func TestProxy_InvalidResourceGroupName(t *testing.T) {
|
|||
|
||||
node, err := NewProxy(ctx, factory)
|
||||
assert.NoError(t, err)
|
||||
node.multiRateLimiter = NewMultiRateLimiter()
|
||||
node.simpleLimiter = NewSimpleLimiter()
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
|
||||
qc := mocks.NewMockQueryCoordClient(t)
|
||||
|
@ -922,7 +935,7 @@ func TestProxyCreateDatabase(t *testing.T) {
|
|||
node.tsoAllocator = ×tampAllocator{
|
||||
tso: newMockTimestampAllocatorInterface(),
|
||||
}
|
||||
node.multiRateLimiter = NewMultiRateLimiter()
|
||||
node.simpleLimiter = NewSimpleLimiter()
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
|
||||
node.sched.ddQueue.setMaxTaskNum(10)
|
||||
|
@ -977,11 +990,12 @@ func TestProxyDropDatabase(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
|
||||
node, err := NewProxy(ctx, factory)
|
||||
node.initRateCollector()
|
||||
assert.NoError(t, err)
|
||||
node.tsoAllocator = ×tampAllocator{
|
||||
tso: newMockTimestampAllocatorInterface(),
|
||||
}
|
||||
node.multiRateLimiter = NewMultiRateLimiter()
|
||||
node.simpleLimiter = NewSimpleLimiter()
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
|
||||
node.sched.ddQueue.setMaxTaskNum(10)
|
||||
|
@ -1040,7 +1054,7 @@ func TestProxyListDatabase(t *testing.T) {
|
|||
node.tsoAllocator = ×tampAllocator{
|
||||
tso: newMockTimestampAllocatorInterface(),
|
||||
}
|
||||
node.multiRateLimiter = NewMultiRateLimiter()
|
||||
node.simpleLimiter = NewSimpleLimiter()
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
|
||||
node.sched.ddQueue.setMaxTaskNum(10)
|
||||
|
|
|
@@ -89,10 +89,10 @@ type Cache interface {
	RemoveDatabase(ctx context.Context, database string)
	HasDatabase(ctx context.Context, database string) bool
	GetDatabaseInfo(ctx context.Context, database string) (*databaseInfo, error)
	// AllocID is only used on requests that need to skip timestamp allocation; don't overuse it.
	AllocID(ctx context.Context) (int64, error)
}

type collectionBasicInfo struct {
	collID           typeutil.UniqueID
	createdTimestamp uint64
@@ -109,6 +109,11 @@ type collectionInfo struct {
	consistencyLevel commonpb.ConsistencyLevel
}

type databaseInfo struct {
	dbID             typeutil.UniqueID
	createdTimestamp uint64
}

// schemaInfo is a helper function wraps *schemapb.CollectionSchema
// with extra fields mapping and methods
type schemaInfo struct {
@ -244,17 +249,19 @@ type MetaCache struct {
|
|||
rootCoord types.RootCoordClient
|
||||
queryCoord types.QueryCoordClient
|
||||
|
||||
collInfo map[string]map[string]*collectionInfo // database -> collectionName -> collection_info
|
||||
collLeader map[string]map[string]*shardLeaders // database -> collectionName -> collection_leaders
|
||||
dbInfo map[string]map[typeutil.UniqueID]string // database -> collectionID -> collectionName
|
||||
credMap map[string]*internalpb.CredentialInfo // cache for credential, lazy load
|
||||
privilegeInfos map[string]struct{} // privileges cache
|
||||
userToRoles map[string]map[string]struct{} // user to role cache
|
||||
mu sync.RWMutex
|
||||
credMut sync.RWMutex
|
||||
leaderMut sync.RWMutex
|
||||
shardMgr shardClientMgr
|
||||
sfGlobal conc.Singleflight[*collectionInfo]
|
||||
dbInfo map[string]*databaseInfo // database -> db_info
|
||||
collInfo map[string]map[string]*collectionInfo // database -> collectionName -> collection_info
|
||||
collLeader map[string]map[string]*shardLeaders // database -> collectionName -> collection_leaders
|
||||
dbCollectionInfo map[string]map[typeutil.UniqueID]string // database -> collectionID -> collectionName
|
||||
credMap map[string]*internalpb.CredentialInfo // cache for credential, lazy load
|
||||
privilegeInfos map[string]struct{} // privileges cache
|
||||
userToRoles map[string]map[string]struct{} // user to role cache
|
||||
mu sync.RWMutex
|
||||
credMut sync.RWMutex
|
||||
leaderMut sync.RWMutex
|
||||
shardMgr shardClientMgr
|
||||
sfGlobal conc.Singleflight[*collectionInfo]
|
||||
sfDB conc.Singleflight[*databaseInfo]
|
||||
|
||||
IDStart int64
|
||||
IDCount int64
|
||||
|
@ -287,15 +294,16 @@ func InitMetaCache(ctx context.Context, rootCoord types.RootCoordClient, queryCo
|
|||
// NewMetaCache creates a MetaCache with provided RootCoord and QueryNode
|
||||
func NewMetaCache(rootCoord types.RootCoordClient, queryCoord types.QueryCoordClient, shardMgr shardClientMgr) (*MetaCache, error) {
|
||||
return &MetaCache{
|
||||
rootCoord: rootCoord,
|
||||
queryCoord: queryCoord,
|
||||
collInfo: map[string]map[string]*collectionInfo{},
|
||||
collLeader: map[string]map[string]*shardLeaders{},
|
||||
dbInfo: map[string]map[typeutil.UniqueID]string{},
|
||||
credMap: map[string]*internalpb.CredentialInfo{},
|
||||
shardMgr: shardMgr,
|
||||
privilegeInfos: map[string]struct{}{},
|
||||
userToRoles: map[string]map[string]struct{}{},
|
||||
rootCoord: rootCoord,
|
||||
queryCoord: queryCoord,
|
||||
dbInfo: map[string]*databaseInfo{},
|
||||
collInfo: map[string]map[string]*collectionInfo{},
|
||||
collLeader: map[string]map[string]*shardLeaders{},
|
||||
dbCollectionInfo: map[string]map[typeutil.UniqueID]string{},
|
||||
credMap: map[string]*internalpb.CredentialInfo{},
|
||||
shardMgr: shardMgr,
|
||||
privilegeInfos: map[string]struct{}{},
|
||||
userToRoles: map[string]map[string]struct{}{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@@ -510,7 +518,7 @@ func (m *MetaCache) innerGetCollectionByID(collectionID int64) (string, string)
	m.mu.RLock()
	defer m.mu.RUnlock()

	for database, db := range m.dbInfo {
	for database, db := range m.dbCollectionInfo {
		name, ok := db[collectionID]
		if ok {
			return database, name
@@ -554,7 +562,7 @@ func (m *MetaCache) updateDBInfo(ctx context.Context) error {

	m.mu.Lock()
	defer m.mu.Unlock()
	m.dbInfo = dbInfo
	m.dbCollectionInfo = dbInfo

	return nil
}
@@ -739,6 +747,19 @@ func (m *MetaCache) showPartitions(ctx context.Context, dbName string, collectio
	return partitions, nil
}

func (m *MetaCache) describeDatabase(ctx context.Context, dbName string) (*rootcoordpb.DescribeDatabaseResponse, error) {
	req := &rootcoordpb.DescribeDatabaseRequest{
		DbName: dbName,
	}

	resp, err := m.rootCoord.DescribeDatabase(ctx, req)
	if err = merr.CheckRPCCall(resp, err); err != nil {
		return nil, err
	}

	return resp, nil
}

// parsePartitionsInfo parse partitionInfo list to partitionInfos struct.
// prepare all name to id & info map
// try parse partition names to partitionKey index.
@@ -1084,6 +1105,7 @@ func (m *MetaCache) RemoveDatabase(ctx context.Context, database string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.collInfo, database)
	delete(m.dbInfo, database)
}

func (m *MetaCache) HasDatabase(ctx context.Context, database string) bool {
@@ -1093,6 +1115,41 @@ func (m *MetaCache) HasDatabase(ctx context.Context, database string) bool {
	return ok
}

func (m *MetaCache) GetDatabaseInfo(ctx context.Context, database string) (*databaseInfo, error) {
	dbInfo := m.safeGetDBInfo(database)
	if dbInfo != nil {
		return dbInfo, nil
	}

	dbInfo, err, _ := m.sfDB.Do(database, func() (*databaseInfo, error) {
		resp, err := m.describeDatabase(ctx, database)
		if err != nil {
			return nil, err
		}

		m.mu.Lock()
		defer m.mu.Unlock()
		dbInfo := &databaseInfo{
			dbID:             resp.GetDbID(),
			createdTimestamp: resp.GetCreatedTimestamp(),
		}
		m.dbInfo[database] = dbInfo
		return dbInfo, nil
	})

	return dbInfo, err
}

func (m *MetaCache) safeGetDBInfo(database string) *databaseInfo {
	m.mu.RLock()
	defer m.mu.RUnlock()
	db, ok := m.dbInfo[database]
	if !ok {
		return nil
	}
	return db
}

func (m *MetaCache) AllocID(ctx context.Context) (int64, error) {
	m.IDLock.Lock()
	defer m.IDLock.Unlock()
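
GetDatabaseInfo checks the cache first and, on a miss, funnels concurrent lookups for the same database through a singleflight group, so a burst of requests produces at most one DescribeDatabase RPC. The same pattern expressed with the standard golang.org/x/sync/singleflight package, as a sketch under assumptions rather than the Milvus conc.Singleflight wrapper itself:

package main

import (
	"context"
	"fmt"
	"sync"

	"golang.org/x/sync/singleflight"
)

type dbInfo struct{ dbID int64 }

type dbCache struct {
	mu    sync.RWMutex
	infos map[string]*dbInfo
	sf    singleflight.Group
	fetch func(ctx context.Context, name string) (*dbInfo, error) // e.g. a DescribeDatabase RPC
}

func (c *dbCache) Get(ctx context.Context, name string) (*dbInfo, error) {
	c.mu.RLock()
	if info, ok := c.infos[name]; ok {
		c.mu.RUnlock()
		return info, nil
	}
	c.mu.RUnlock()

	// Concurrent misses for the same key share a single fetch.
	v, err, _ := c.sf.Do(name, func() (interface{}, error) {
		info, err := c.fetch(ctx, name)
		if err != nil {
			return nil, err
		}
		c.mu.Lock()
		c.infos[name] = info
		c.mu.Unlock()
		return info, nil
	})
	if err != nil {
		return nil, err
	}
	return v.(*dbInfo), nil
}

func main() {
	calls := 0
	cache := &dbCache{
		infos: map[string]*dbInfo{},
		fetch: func(ctx context.Context, name string) (*dbInfo, error) {
			calls++
			return &dbInfo{dbID: 1}, nil
		},
	}
	info, _ := cache.Get(context.Background(), "default")
	info, _ = cache.Get(context.Background(), "default") // second call is served from the cache
	fmt.Println(info.dbID, calls)                        // 1 1
}
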
@ -817,6 +817,49 @@ func TestMetaCache_Database(t *testing.T) {
|
|||
assert.Equal(t, CheckDatabase(ctx, dbName), true)
|
||||
}
|
||||
|
||||
func TestGetDatabaseInfo(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rootCoord := mocks.NewMockRootCoordClient(t)
|
||||
queryCoord := &mocks.MockQueryCoordClient{}
|
||||
shardMgr := newShardClientMgr()
|
||||
cache, err := NewMetaCache(rootCoord, queryCoord, shardMgr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rootCoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{
|
||||
Status: merr.Success(),
|
||||
DbID: 1,
|
||||
DbName: "default",
|
||||
}, nil).Once()
|
||||
{
|
||||
dbInfo, err := cache.GetDatabaseInfo(ctx, "default")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, UniqueID(1), dbInfo.dbID)
|
||||
}
|
||||
|
||||
{
|
||||
dbInfo, err := cache.GetDatabaseInfo(ctx, "default")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, UniqueID(1), dbInfo.dbID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rootCoord := mocks.NewMockRootCoordClient(t)
|
||||
queryCoord := &mocks.MockQueryCoordClient{}
|
||||
shardMgr := newShardClientMgr()
|
||||
cache, err := NewMetaCache(rootCoord, queryCoord, shardMgr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rootCoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{
|
||||
Status: merr.Status(errors.New("mock error: describe database")),
|
||||
}, nil).Once()
|
||||
_, err = cache.GetDatabaseInfo(ctx, "default")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetaCache_AllocID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
queryCoord := &mocks.MockQueryCoordClient{}
|
||||
|
@ -935,9 +978,9 @@ func TestGlobalMetaCache_UpdateDBInfo(t *testing.T) {
|
|||
}, nil).Once()
|
||||
err := cache.updateDBInfo(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, cache.dbInfo, 1)
|
||||
assert.Len(t, cache.dbInfo["db1"], 1)
|
||||
assert.Equal(t, "collection1", cache.dbInfo["db1"][1])
|
||||
assert.Len(t, cache.dbCollectionInfo, 1)
|
||||
assert.Len(t, cache.dbCollectionInfo["db1"], 1)
|
||||
assert.Equal(t, "collection1", cache.dbCollectionInfo["db1"][1])
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@@ -50,12 +50,29 @@ func getQuotaMetrics() (*metricsinfo.ProxyQuotaMetrics, error) {
			Rate:  rate,
		})
	}

	getSubLabelRateMetric := func(label string) {
		rates, err2 := rateCol.RateSubLabel(label, ratelimitutil.DefaultAvgDuration)
		if err2 != nil {
			err = err2
			return
		}
		for s, f := range rates {
			rms = append(rms, metricsinfo.RateMetric{
				Label: s,
				Rate:  f,
			})
		}
	}
	getRateMetric(internalpb.RateType_DMLInsert.String())
	getRateMetric(internalpb.RateType_DMLUpsert.String())
	getRateMetric(internalpb.RateType_DMLDelete.String())
	getRateMetric(internalpb.RateType_DQLSearch.String())
	getSubLabelRateMetric(internalpb.RateType_DQLSearch.String())
	getRateMetric(internalpb.RateType_DQLQuery.String())
	getSubLabelRateMetric(internalpb.RateType_DQLQuery.String())
	getRateMetric(metricsinfo.ReadResultThroughput)
	getSubLabelRateMetric(metricsinfo.ReadResultThroughput)
	if err != nil {
		return nil, err
	}
@ -450,6 +450,61 @@ func (_c *MockCache_GetCredentialInfo_Call) RunAndReturn(run func(context.Contex
|
|||
return _c
|
||||
}
|
||||
|
||||
// GetDatabaseInfo provides a mock function with given fields: ctx, database
|
||||
func (_m *MockCache) GetDatabaseInfo(ctx context.Context, database string) (*databaseInfo, error) {
|
||||
ret := _m.Called(ctx, database)
|
||||
|
||||
var r0 *databaseInfo
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) (*databaseInfo, error)); ok {
|
||||
return rf(ctx, database)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) *databaseInfo); ok {
|
||||
r0 = rf(ctx, database)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*databaseInfo)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = rf(ctx, database)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockCache_GetDatabaseInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDatabaseInfo'
|
||||
type MockCache_GetDatabaseInfo_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetDatabaseInfo is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - database string
|
||||
func (_e *MockCache_Expecter) GetDatabaseInfo(ctx interface{}, database interface{}) *MockCache_GetDatabaseInfo_Call {
|
||||
return &MockCache_GetDatabaseInfo_Call{Call: _e.mock.On("GetDatabaseInfo", ctx, database)}
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetDatabaseInfo_Call) Run(run func(ctx context.Context, database string)) *MockCache_GetDatabaseInfo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetDatabaseInfo_Call) Return(_a0 *databaseInfo, _a1 error) *MockCache_GetDatabaseInfo_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetDatabaseInfo_Call) RunAndReturn(run func(context.Context, string) (*databaseInfo, error)) *MockCache_GetDatabaseInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetPartitionID provides a mock function with given fields: ctx, database, collectionName, partitionName
|
||||
func (_m *MockCache) GetPartitionID(ctx context.Context, database string, collectionName string, partitionName string) (int64, error) {
|
||||
ret := _m.Called(ctx, database, collectionName, partitionName)
|
||||
|
|
|
@ -1,375 +0,0 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/proxypb"
|
||||
"github.com/milvus-io/milvus/pkg/config"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
var QuotaErrorString = map[commonpb.ErrorCode]string{
|
||||
commonpb.ErrorCode_ForceDeny: "the writing has been deactivated by the administrator",
|
||||
commonpb.ErrorCode_MemoryQuotaExhausted: "memory quota exceeded, please allocate more resources",
|
||||
commonpb.ErrorCode_DiskQuotaExhausted: "disk quota exceeded, please allocate more resources",
|
||||
commonpb.ErrorCode_TimeTickLongDelay: "time tick long delay",
|
||||
}
|
||||
|
||||
func GetQuotaErrorString(errCode commonpb.ErrorCode) string {
|
||||
return QuotaErrorString[errCode]
|
||||
}
|
||||
|
||||
// MultiRateLimiter includes multilevel rate limiters, such as global rateLimiter,
|
||||
// collection level rateLimiter and so on. It also implements Limiter interface.
|
||||
type MultiRateLimiter struct {
|
||||
quotaStatesMu sync.RWMutex
|
||||
// for DML and DQL
|
||||
collectionLimiters map[int64]*rateLimiter
|
||||
// for DDL
|
||||
globalDDLLimiter *rateLimiter
|
||||
}
|
||||
|
||||
// NewMultiRateLimiter returns a new MultiRateLimiter.
|
||||
func NewMultiRateLimiter() *MultiRateLimiter {
|
||||
m := &MultiRateLimiter{
|
||||
collectionLimiters: make(map[int64]*rateLimiter, 0),
|
||||
globalDDLLimiter: newRateLimiter(true),
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Check checks if request would be limited or denied.
|
||||
func (m *MultiRateLimiter) Check(collectionIDs []int64, rt internalpb.RateType, n int) error {
|
||||
if !Params.QuotaConfig.QuotaAndLimitsEnabled.GetAsBool() {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.quotaStatesMu.RLock()
|
||||
defer m.quotaStatesMu.RUnlock()
|
||||
|
||||
checkFunc := func(limiter *rateLimiter) error {
|
||||
if limiter == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
limit, rate := limiter.limit(rt, n)
|
||||
if rate == 0 {
|
||||
return limiter.getQuotaExceededError(rt)
|
||||
}
|
||||
if limit {
|
||||
return limiter.getRateLimitError(rate)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// first, check global level rate limits
|
||||
ret := checkFunc(m.globalDDLLimiter)
|
||||
|
||||
// second check collection level rate limits
|
||||
// only dml, dql and flush have collection level rate limits
|
||||
if ret == nil && len(collectionIDs) > 0 && !isNotCollectionLevelLimitRequest(rt) {
|
||||
// store done limiters to cancel them when error occurs.
|
||||
doneLimiters := make([]*rateLimiter, 0, len(collectionIDs)+1)
|
||||
doneLimiters = append(doneLimiters, m.globalDDLLimiter)
|
||||
|
||||
for _, collectionID := range collectionIDs {
|
||||
ret = checkFunc(m.collectionLimiters[collectionID])
|
||||
if ret != nil {
|
||||
for _, limiter := range doneLimiters {
|
||||
limiter.cancel(rt, n)
|
||||
}
|
||||
break
|
||||
}
|
||||
doneLimiters = append(doneLimiters, m.collectionLimiters[collectionID])
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func isNotCollectionLevelLimitRequest(rt internalpb.RateType) bool {
|
||||
// Most ddl is global level, only DDLFlush will be applied at collection
|
||||
switch rt {
|
||||
case internalpb.RateType_DDLCollection, internalpb.RateType_DDLPartition, internalpb.RateType_DDLIndex,
|
||||
internalpb.RateType_DDLCompaction:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// GetQuotaStates returns quota states.
|
||||
func (m *MultiRateLimiter) GetQuotaStates() ([]milvuspb.QuotaState, []string) {
|
||||
m.quotaStatesMu.RLock()
|
||||
defer m.quotaStatesMu.RUnlock()
|
||||
serviceStates := make(map[milvuspb.QuotaState]typeutil.Set[commonpb.ErrorCode])
|
||||
|
||||
// deduplicate same (state, code) pair from different collection
|
||||
for _, limiter := range m.collectionLimiters {
|
||||
limiter.quotaStates.Range(func(state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool {
|
||||
if serviceStates[state] == nil {
|
||||
serviceStates[state] = typeutil.NewSet[commonpb.ErrorCode]()
|
||||
}
|
||||
serviceStates[state].Insert(errCode)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
states := make([]milvuspb.QuotaState, 0)
|
||||
reasons := make([]string, 0)
|
||||
for state, errCodes := range serviceStates {
|
||||
for errCode := range errCodes {
|
||||
states = append(states, state)
|
||||
reasons = append(reasons, GetQuotaErrorString(errCode))
|
||||
}
|
||||
}
|
||||
|
||||
return states, reasons
|
||||
}
|
||||
|
||||
// SetQuotaStates sets quota states for MultiRateLimiter.
|
||||
func (m *MultiRateLimiter) SetRates(rates []*proxypb.CollectionRate) error {
|
||||
m.quotaStatesMu.Lock()
|
||||
defer m.quotaStatesMu.Unlock()
|
||||
collectionSet := typeutil.NewUniqueSet()
|
||||
for _, collectionRates := range rates {
|
||||
collectionSet.Insert(collectionRates.Collection)
|
||||
rateLimiter, ok := m.collectionLimiters[collectionRates.GetCollection()]
|
||||
if !ok {
|
||||
rateLimiter = newRateLimiter(false)
|
||||
}
|
||||
err := rateLimiter.setRates(collectionRates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.collectionLimiters[collectionRates.GetCollection()] = rateLimiter
|
||||
}
|
||||
|
||||
// remove dropped collection's rate limiter
|
||||
for collectionID := range m.collectionLimiters {
|
||||
if !collectionSet.Contain(collectionID) {
|
||||
delete(m.collectionLimiters, collectionID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// rateLimiter implements Limiter.
|
||||
type rateLimiter struct {
|
||||
limiters *typeutil.ConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter]
|
||||
quotaStates *typeutil.ConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]
|
||||
}
|
||||
|
||||
// newRateLimiter returns a new RateLimiter.
|
||||
func newRateLimiter(globalLevel bool) *rateLimiter {
|
||||
rl := &rateLimiter{
|
||||
limiters: typeutil.NewConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter](),
|
||||
quotaStates: typeutil.NewConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode](),
|
||||
}
|
||||
rl.registerLimiters(globalLevel)
|
||||
return rl
|
||||
}
|
||||
|
||||
// limit returns true, the request will be rejected.
|
||||
// Otherwise, the request will pass.
|
||||
func (rl *rateLimiter) limit(rt internalpb.RateType, n int) (bool, float64) {
|
||||
limit, ok := rl.limiters.Get(rt)
|
||||
if !ok {
|
||||
return false, -1
|
||||
}
|
||||
return !limit.AllowN(time.Now(), n), float64(limit.Limit())
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) cancel(rt internalpb.RateType, n int) {
|
||||
limit, ok := rl.limiters.Get(rt)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
limit.Cancel(n)
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) setRates(collectionRate *proxypb.CollectionRate) error {
|
||||
log := log.Ctx(context.TODO()).WithRateGroup("proxy.rateLimiter", 1.0, 60.0).With(
|
||||
zap.Int64("proxyNodeID", paramtable.GetNodeID()),
|
||||
zap.Int64("CollectionID", collectionRate.Collection),
|
||||
)
|
||||
for _, r := range collectionRate.GetRates() {
|
||||
if limit, ok := rl.limiters.Get(r.GetRt()); ok {
|
||||
limit.SetLimit(ratelimitutil.Limit(r.GetR()))
|
||||
setRateGaugeByRateType(r.GetRt(), paramtable.GetNodeID(), collectionRate.Collection, r.GetR())
|
||||
} else {
|
||||
return fmt.Errorf("unregister rateLimiter for rateType %s", r.GetRt().String())
|
||||
}
|
||||
log.RatedDebug(30, "current collection rates in proxy",
|
||||
zap.String("rateType", r.Rt.String()),
|
||||
zap.String("rateLimit", ratelimitutil.Limit(r.GetR()).String()),
|
||||
)
|
||||
}
|
||||
|
||||
// clear old quota states
|
||||
rl.quotaStates = typeutil.NewConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]()
|
||||
for i := 0; i < len(collectionRate.GetStates()); i++ {
|
||||
rl.quotaStates.Insert(collectionRate.States[i], collectionRate.Codes[i])
|
||||
log.RatedWarn(30, "Proxy set collection quota states",
|
||||
zap.String("state", collectionRate.GetStates()[i].String()),
|
||||
zap.String("reason", collectionRate.GetCodes()[i].String()),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) getQuotaExceededError(rt internalpb.RateType) error {
|
||||
switch rt {
|
||||
case internalpb.RateType_DMLInsert, internalpb.RateType_DMLUpsert, internalpb.RateType_DMLDelete, internalpb.RateType_DMLBulkLoad:
|
||||
if errCode, ok := rl.quotaStates.Get(milvuspb.QuotaState_DenyToWrite); ok {
|
||||
return merr.WrapErrServiceQuotaExceeded(GetQuotaErrorString(errCode))
|
||||
}
|
||||
case internalpb.RateType_DQLSearch, internalpb.RateType_DQLQuery:
|
||||
if errCode, ok := rl.quotaStates.Get(milvuspb.QuotaState_DenyToRead); ok {
|
||||
return merr.WrapErrServiceQuotaExceeded(GetQuotaErrorString(errCode))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) getRateLimitError(rate float64) error {
|
||||
return merr.WrapErrServiceRateLimit(rate, "request is rejected by grpc RateLimiter middleware, please retry later")
|
||||
}
|
||||
|
||||
// setRateGaugeByRateType sets ProxyLimiterRate metrics.
|
||||
func setRateGaugeByRateType(rateType internalpb.RateType, nodeID int64, collectionID int64, rate float64) {
|
||||
if ratelimitutil.Limit(rate) == ratelimitutil.Inf {
|
||||
return
|
||||
}
|
||||
nodeIDStr := strconv.FormatInt(nodeID, 10)
|
||||
collectionIDStr := strconv.FormatInt(collectionID, 10)
|
||||
switch rateType {
|
||||
case internalpb.RateType_DMLInsert:
|
||||
metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, collectionIDStr, metrics.InsertLabel).Set(rate)
|
||||
case internalpb.RateType_DMLUpsert:
|
||||
metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, collectionIDStr, metrics.UpsertLabel).Set(rate)
|
||||
case internalpb.RateType_DMLDelete:
|
||||
metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, collectionIDStr, metrics.DeleteLabel).Set(rate)
|
||||
case internalpb.RateType_DQLSearch:
|
||||
metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, collectionIDStr, metrics.SearchLabel).Set(rate)
|
||||
case internalpb.RateType_DQLQuery:
|
||||
metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, collectionIDStr, metrics.QueryLabel).Set(rate)
|
||||
}
|
||||
}
|
||||
|
||||
// registerLimiters registers limiters for all rate types.
|
||||
func (rl *rateLimiter) registerLimiters(globalLevel bool) {
|
||||
log := log.Ctx(context.TODO()).WithRateGroup("proxy.rateLimiter", 1.0, 60.0)
|
||||
quotaConfig := &Params.QuotaConfig
|
||||
for rt := range internalpb.RateType_name {
|
||||
var r *paramtable.ParamItem
|
||||
switch internalpb.RateType(rt) {
|
||||
case internalpb.RateType_DDLCollection:
|
||||
r = "aConfig.DDLCollectionRate
|
||||
case internalpb.RateType_DDLPartition:
|
||||
r = "aConfig.DDLPartitionRate
|
||||
case internalpb.RateType_DDLIndex:
|
||||
r = "aConfig.MaxIndexRate
|
||||
case internalpb.RateType_DDLFlush:
|
||||
if globalLevel {
|
||||
r = "aConfig.MaxFlushRate
|
||||
} else {
|
||||
r = "aConfig.MaxFlushRatePerCollection
|
||||
}
|
||||
case internalpb.RateType_DDLCompaction:
|
||||
r = "aConfig.MaxCompactionRate
|
||||
case internalpb.RateType_DMLInsert:
|
||||
if globalLevel {
|
||||
r = "aConfig.DMLMaxInsertRate
|
||||
} else {
|
||||
r = "aConfig.DMLMaxInsertRatePerCollection
|
||||
}
|
||||
case internalpb.RateType_DMLUpsert:
|
||||
if globalLevel {
|
||||
r = "aConfig.DMLMaxUpsertRate
|
||||
} else {
|
||||
r = "aConfig.DMLMaxUpsertRatePerCollection
|
||||
}
|
||||
case internalpb.RateType_DMLDelete:
|
||||
if globalLevel {
|
||||
r = "aConfig.DMLMaxDeleteRate
|
||||
} else {
|
||||
r = "aConfig.DMLMaxDeleteRatePerCollection
|
||||
}
|
||||
case internalpb.RateType_DMLBulkLoad:
|
||||
if globalLevel {
|
||||
r = "aConfig.DMLMaxBulkLoadRate
|
||||
} else {
|
||||
r = "aConfig.DMLMaxBulkLoadRatePerCollection
|
||||
}
|
||||
case internalpb.RateType_DQLSearch:
|
||||
if globalLevel {
|
||||
r = "aConfig.DQLMaxSearchRate
|
||||
} else {
|
||||
r = "aConfig.DQLMaxSearchRatePerCollection
|
||||
}
|
||||
case internalpb.RateType_DQLQuery:
|
||||
if globalLevel {
|
||||
r = "aConfig.DQLMaxQueryRate
|
||||
} else {
|
||||
r = "aConfig.DQLMaxQueryRatePerCollection
|
||||
}
|
||||
}
|
||||
limit := ratelimitutil.Limit(r.GetAsFloat())
|
||||
burst := r.GetAsFloat() // use the rate as burst; the Limiter has a punishment mechanism, so burst is insignificant.
|
||||
rl.limiters.GetOrInsert(internalpb.RateType(rt), ratelimitutil.NewLimiter(limit, burst))
|
||||
onEvent := func(rateType internalpb.RateType) func(*config.Event) {
|
||||
return func(event *config.Event) {
|
||||
f, err := strconv.ParseFloat(r.Formatter(event.Value), 64)
|
||||
if err != nil {
|
||||
log.Info("Error format for rateLimit",
|
||||
zap.String("rateType", rateType.String()),
|
||||
zap.String("key", event.Key),
|
||||
zap.String("value", event.Value),
|
||||
zap.Error(err))
|
||||
return
|
||||
}
|
||||
limit, ok := rl.limiters.Get(rateType)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
limit.SetLimit(ratelimitutil.Limit(f))
|
||||
}
|
||||
}(internalpb.RateType(rt))
|
||||
paramtable.Get().Watch(r.Key, config.NewHandler(fmt.Sprintf("rateLimiter-%d", rt), onEvent))
|
||||
log.RatedDebug(30, "RateLimiter register for rateType",
|
||||
zap.String("rateType", internalpb.RateType_name[rt]),
|
||||
zap.String("rateLimit", ratelimitutil.Limit(r.GetAsFloat()).String()),
|
||||
zap.String("burst", fmt.Sprintf("%v", burst)))
|
||||
}
|
||||
}
|
|
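Each limiter registered above is also refreshed at runtime: a watcher on the quota config key re-parses the new value and calls SetLimit on the existing limiter, so subsequent requests see the new rate without rebuilding anything. A minimal sketch of that update pattern, using the standard golang.org/x/time/rate package as a stand-in for the internal ratelimitutil.Limiter (an assumption for illustration only; the real limiter additionally supports Cancel to return tokens), might look like this:

package main

import (
	"fmt"
	"strconv"
	"time"

	"golang.org/x/time/rate"
)

// updateLimit parses a newly observed config value and, when it is a
// valid float, applies it to the existing limiter in place.
func updateLimit(l *rate.Limiter, raw string) error {
	f, err := strconv.ParseFloat(raw, 64)
	if err != nil {
		return fmt.Errorf("invalid rate %q: %w", raw, err)
	}
	l.SetLimit(rate.Limit(f)) // takes effect for subsequent AllowN calls
	return nil
}

func main() {
	l := rate.NewLimiter(rate.Limit(5), 5) // 5 tokens/s, burst 5
	fmt.Println("allowed:", l.AllowN(time.Now(), 3))

	// Simulate a config event lowering the rate to 1 token/s.
	if err := updateLimit(l, "1"); err != nil {
		fmt.Println(err)
	}
	fmt.Println("allowed after update:", l.AllowN(time.Now(), 3))
}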
@ -1,338 +0,0 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/proxypb"
|
||||
"github.com/milvus-io/milvus/pkg/util/etcd"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
|
||||
)
|
||||
|
||||
func TestMultiRateLimiter(t *testing.T) {
|
||||
collectionID := int64(1)
|
||||
t.Run("test multiRateLimiter", func(t *testing.T) {
|
||||
bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue()
|
||||
paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
|
||||
multiLimiter := NewMultiRateLimiter()
|
||||
multiLimiter.collectionLimiters[collectionID] = newRateLimiter(false)
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
if isNotCollectionLevelLimitRequest(internalpb.RateType(rt)) {
|
||||
multiLimiter.globalDDLLimiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(5), 1))
|
||||
} else {
|
||||
multiLimiter.collectionLimiters[collectionID].limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1))
|
||||
}
|
||||
}
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
if isNotCollectionLevelLimitRequest(internalpb.RateType(rt)) {
|
||||
err := multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), 1)
|
||||
assert.NoError(t, err)
|
||||
err = multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), 5)
|
||||
assert.NoError(t, err)
|
||||
err = multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), 5)
|
||||
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
|
||||
} else {
|
||||
err := multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), 1)
|
||||
assert.NoError(t, err)
|
||||
err = multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), math.MaxInt)
|
||||
assert.NoError(t, err)
|
||||
err = multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), math.MaxInt)
|
||||
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
|
||||
}
|
||||
}
|
||||
Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak)
|
||||
})
|
||||
|
||||
t.Run("test global static limit", func(t *testing.T) {
|
||||
bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue()
|
||||
paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
|
||||
multiLimiter := NewMultiRateLimiter()
|
||||
multiLimiter.collectionLimiters[1] = newRateLimiter(false)
|
||||
multiLimiter.collectionLimiters[2] = newRateLimiter(false)
|
||||
multiLimiter.collectionLimiters[3] = newRateLimiter(false)
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
if isNotCollectionLevelLimitRequest(internalpb.RateType(rt)) {
|
||||
multiLimiter.globalDDLLimiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(5), 1))
|
||||
} else {
|
||||
multiLimiter.globalDDLLimiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1))
|
||||
multiLimiter.collectionLimiters[1].limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1))
|
||||
multiLimiter.collectionLimiters[2].limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1))
|
||||
multiLimiter.collectionLimiters[3].limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1))
|
||||
}
|
||||
}
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
if internalpb.RateType(rt) == internalpb.RateType_DDLFlush {
|
||||
err := multiLimiter.Check([]int64{1, 2, 3}, internalpb.RateType(rt), 1)
|
||||
assert.NoError(t, err)
|
||||
err = multiLimiter.Check([]int64{1, 2, 3}, internalpb.RateType(rt), 5)
|
||||
assert.NoError(t, err)
|
||||
err = multiLimiter.Check([]int64{1, 2, 3}, internalpb.RateType(rt), 5)
|
||||
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
|
||||
} else if isNotCollectionLevelLimitRequest(internalpb.RateType(rt)) {
|
||||
err := multiLimiter.Check([]int64{1}, internalpb.RateType(rt), 1)
|
||||
assert.NoError(t, err)
|
||||
err = multiLimiter.Check([]int64{1}, internalpb.RateType(rt), 5)
|
||||
assert.NoError(t, err)
|
||||
err = multiLimiter.Check([]int64{1}, internalpb.RateType(rt), 5)
|
||||
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
|
||||
} else {
|
||||
err := multiLimiter.Check([]int64{1}, internalpb.RateType(rt), 1)
|
||||
assert.NoError(t, err)
|
||||
err = multiLimiter.Check([]int64{2}, internalpb.RateType(rt), 1)
|
||||
assert.NoError(t, err)
|
||||
err = multiLimiter.Check([]int64{3}, internalpb.RateType(rt), 1)
|
||||
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
|
||||
}
|
||||
}
|
||||
Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak)
|
||||
})
|
||||
|
||||
t.Run("not enable quotaAndLimit", func(t *testing.T) {
|
||||
multiLimiter := NewMultiRateLimiter()
|
||||
multiLimiter.collectionLimiters[collectionID] = newRateLimiter(false)
|
||||
bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue()
|
||||
paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
err := multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), 1)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak)
|
||||
})
|
||||
|
||||
t.Run("test limit", func(t *testing.T) {
|
||||
run := func(insertRate float64) {
|
||||
bakInsertRate := Params.QuotaConfig.DMLMaxInsertRate.GetValue()
|
||||
paramtable.Get().Save(Params.QuotaConfig.DMLMaxInsertRate.Key, fmt.Sprintf("%f", insertRate))
|
||||
multiLimiter := NewMultiRateLimiter()
|
||||
bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue()
|
||||
paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
|
||||
err := multiLimiter.Check([]int64{collectionID}, internalpb.RateType_DMLInsert, 1*1024*1024)
|
||||
assert.NoError(t, err)
|
||||
Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak)
|
||||
Params.Save(Params.QuotaConfig.DMLMaxInsertRate.Key, bakInsertRate)
|
||||
}
|
||||
run(math.MaxFloat64)
|
||||
run(math.MaxFloat64 / 1.2)
|
||||
run(math.MaxFloat64 / 2)
|
||||
run(math.MaxFloat64 / 3)
|
||||
run(math.MaxFloat64 / 10000)
|
||||
})
|
||||
|
||||
t.Run("test set rates", func(t *testing.T) {
|
||||
multiLimiter := NewMultiRateLimiter()
|
||||
zeroRates := make([]*internalpb.Rate, 0, len(internalpb.RateType_value))
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
zeroRates = append(zeroRates, &internalpb.Rate{
|
||||
Rt: internalpb.RateType(rt), R: 0,
|
||||
})
|
||||
}
|
||||
|
||||
err := multiLimiter.SetRates([]*proxypb.CollectionRate{
|
||||
{
|
||||
Collection: 1,
|
||||
Rates: zeroRates,
|
||||
},
|
||||
{
|
||||
Collection: 2,
|
||||
Rates: zeroRates,
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test quota states", func(t *testing.T) {
|
||||
multiLimiter := NewMultiRateLimiter()
|
||||
zeroRates := make([]*internalpb.Rate, 0, len(internalpb.RateType_value))
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
zeroRates = append(zeroRates, &internalpb.Rate{
|
||||
Rt: internalpb.RateType(rt), R: 0,
|
||||
})
|
||||
}
|
||||
|
||||
err := multiLimiter.SetRates([]*proxypb.CollectionRate{
|
||||
{
|
||||
Collection: 1,
|
||||
Rates: zeroRates,
|
||||
States: []milvuspb.QuotaState{
|
||||
milvuspb.QuotaState_DenyToWrite,
|
||||
},
|
||||
Codes: []commonpb.ErrorCode{
|
||||
commonpb.ErrorCode_DiskQuotaExhausted,
|
||||
},
|
||||
},
|
||||
{
|
||||
Collection: 2,
|
||||
Rates: zeroRates,
|
||||
|
||||
States: []milvuspb.QuotaState{
|
||||
milvuspb.QuotaState_DenyToRead,
|
||||
},
|
||||
Codes: []commonpb.ErrorCode{
|
||||
commonpb.ErrorCode_ForceDeny,
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
states, codes := multiLimiter.GetQuotaStates()
|
||||
assert.Len(t, states, 2)
|
||||
assert.Len(t, codes, 2)
|
||||
assert.Contains(t, codes, GetQuotaErrorString(commonpb.ErrorCode_DiskQuotaExhausted))
|
||||
assert.Contains(t, codes, GetQuotaErrorString(commonpb.ErrorCode_ForceDeny))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRateLimiter(t *testing.T) {
|
||||
t.Run("test limit", func(t *testing.T) {
|
||||
paramtable.Get().CleanEvent()
|
||||
limiter := newRateLimiter(false)
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
limiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1))
|
||||
}
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
ok, _ := limiter.limit(internalpb.RateType(rt), 1)
|
||||
assert.False(t, ok)
|
||||
ok, _ = limiter.limit(internalpb.RateType(rt), math.MaxInt)
|
||||
assert.False(t, ok)
|
||||
ok, _ = limiter.limit(internalpb.RateType(rt), math.MaxInt)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test setRates", func(t *testing.T) {
|
||||
paramtable.Get().CleanEvent()
|
||||
limiter := newRateLimiter(false)
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
limiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1))
|
||||
}
|
||||
|
||||
zeroRates := make([]*internalpb.Rate, 0, len(internalpb.RateType_value))
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
zeroRates = append(zeroRates, &internalpb.Rate{
|
||||
Rt: internalpb.RateType(rt), R: 0,
|
||||
})
|
||||
}
|
||||
err := limiter.setRates(&proxypb.CollectionRate{
|
||||
Collection: 1,
|
||||
Rates: zeroRates,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
for i := 0; i < 100; i++ {
|
||||
ok, _ := limiter.limit(internalpb.RateType(rt), 1)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
}
|
||||
|
||||
err = limiter.setRates(&proxypb.CollectionRate{
|
||||
Collection: 1,
|
||||
States: []milvuspb.QuotaState{milvuspb.QuotaState_DenyToRead, milvuspb.QuotaState_DenyToWrite},
|
||||
Codes: []commonpb.ErrorCode{commonpb.ErrorCode_DiskQuotaExhausted, commonpb.ErrorCode_DiskQuotaExhausted},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, limiter.quotaStates.Len(), 2)
|
||||
|
||||
err = limiter.setRates(&proxypb.CollectionRate{
|
||||
Collection: 1,
|
||||
States: []milvuspb.QuotaState{},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, limiter.quotaStates.Len(), 0)
|
||||
})
|
||||
|
||||
t.Run("test get error code", func(t *testing.T) {
|
||||
paramtable.Get().CleanEvent()
|
||||
limiter := newRateLimiter(false)
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
limiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1))
|
||||
}
|
||||
|
||||
zeroRates := make([]*internalpb.Rate, 0, len(internalpb.RateType_value))
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
zeroRates = append(zeroRates, &internalpb.Rate{
|
||||
Rt: internalpb.RateType(rt), R: 0,
|
||||
})
|
||||
}
|
||||
err := limiter.setRates(&proxypb.CollectionRate{
|
||||
Collection: 1,
|
||||
Rates: zeroRates,
|
||||
States: []milvuspb.QuotaState{
|
||||
milvuspb.QuotaState_DenyToWrite,
|
||||
milvuspb.QuotaState_DenyToRead,
|
||||
},
|
||||
Codes: []commonpb.ErrorCode{
|
||||
commonpb.ErrorCode_DiskQuotaExhausted,
|
||||
commonpb.ErrorCode_ForceDeny,
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Error(t, limiter.getQuotaExceededError(internalpb.RateType_DQLQuery))
|
||||
assert.Error(t, limiter.getQuotaExceededError(internalpb.RateType_DMLInsert))
|
||||
})
|
||||
|
||||
t.Run("tests refresh rate by config", func(t *testing.T) {
|
||||
paramtable.Get().CleanEvent()
|
||||
limiter := newRateLimiter(false)
|
||||
|
||||
etcdCli, _ := etcd.GetEtcdClient(
|
||||
Params.EtcdCfg.UseEmbedEtcd.GetAsBool(),
|
||||
Params.EtcdCfg.EtcdUseSSL.GetAsBool(),
|
||||
Params.EtcdCfg.Endpoints.GetAsStrings(),
|
||||
Params.EtcdCfg.EtcdTLSCert.GetValue(),
|
||||
Params.EtcdCfg.EtcdTLSKey.GetValue(),
|
||||
Params.EtcdCfg.EtcdTLSCACert.GetValue(),
|
||||
Params.EtcdCfg.EtcdTLSMinVersion.GetValue())
|
||||
|
||||
Params.Save(Params.QuotaConfig.DDLLimitEnabled.Key, "true")
|
||||
defer Params.Reset(Params.QuotaConfig.DDLLimitEnabled.Key)
|
||||
Params.Save(Params.QuotaConfig.DMLLimitEnabled.Key, "true")
|
||||
defer Params.Reset(Params.QuotaConfig.DMLLimitEnabled.Key)
|
||||
ctx := context.Background()
|
||||
// append a non-zero digit so the formatted value does not end in 0, avoiding precision issues in the string comparison below
|
||||
newRate := fmt.Sprintf("%.2f1", rand.Float64())
|
||||
etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/ddl/collectionRate", newRate)
|
||||
defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/ddl/collectionRate")
|
||||
etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/ddl/partitionRate", "invalid")
|
||||
defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/ddl/partitionRate")
|
||||
etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/dml/insertRate/collection/max", "8")
|
||||
defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/dml/insertRate/collection/max")
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
limit, _ := limiter.limiters.Get(internalpb.RateType_DDLCollection)
|
||||
return newRate == limit.Limit().String()
|
||||
}, 20*time.Second, time.Second)
|
||||
|
||||
limit, _ := limiter.limiters.Get(internalpb.RateType_DDLPartition)
|
||||
assert.Equal(t, "+inf", limit.Limit().String())
|
||||
|
||||
limit, _ = limiter.limiters.Get(internalpb.RateType_DMLInsert)
|
||||
assert.Equal(t, "8.388608e+06", limit.Limit().String())
|
||||
})
|
||||
}
|
|
@ -93,7 +93,7 @@ type Proxy struct {
|
|||
dataCoord types.DataCoordClient
|
||||
queryCoord types.QueryCoordClient
|
||||
|
||||
multiRateLimiter *MultiRateLimiter
|
||||
simpleLimiter *SimpleLimiter
|
||||
|
||||
chMgr channelsMgr
|
||||
|
||||
|
@ -147,7 +147,7 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
|
|||
factory: factory,
|
||||
searchResultCh: make(chan *internalpb.SearchResults, n),
|
||||
shardMgr: mgr,
|
||||
multiRateLimiter: NewMultiRateLimiter(),
|
||||
simpleLimiter: NewSimpleLimiter(),
|
||||
lbPolicy: lbPolicy,
|
||||
resourceManager: resourceManager,
|
||||
replicateStreamManager: replicateStreamManager,
|
||||
|
@ -197,7 +197,7 @@ func (node *Proxy) initSession() error {
|
|||
// initRateCollector creates and starts rateCollector in Proxy.
|
||||
func (node *Proxy) initRateCollector() error {
|
||||
var err error
|
||||
rateCol, err = ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity)
|
||||
rateCol, err = ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -542,8 +542,8 @@ func (node *Proxy) SetQueryNodeCreator(f func(ctx context.Context, addr string,
|
|||
|
||||
// GetRateLimiter returns the rateLimiter in Proxy.
|
||||
func (node *Proxy) GetRateLimiter() (types.Limiter, error) {
|
||||
if node.multiRateLimiter == nil {
|
||||
if node.simpleLimiter == nil {
|
||||
return nil, fmt.Errorf("nil rate limiter in Proxy")
|
||||
}
|
||||
return node.multiRateLimiter, nil
|
||||
return node.simpleLimiter, nil
|
||||
}
|
||||
|
|
|
@ -298,8 +298,7 @@ func (s *proxyTestServer) startGrpc(ctx context.Context, wg *sync.WaitGroup, p *
|
|||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
multiLimiter := NewMultiRateLimiter()
|
||||
s.multiRateLimiter = multiLimiter
|
||||
s.simpleLimiter = NewSimpleLimiter()
|
||||
|
||||
opts := tracer.GetInterceptorOpts()
|
||||
s.grpcServer = grpc.NewServer(
|
||||
|
@ -309,7 +308,7 @@ func (s *proxyTestServer) startGrpc(ctx context.Context, wg *sync.WaitGroup, p *
|
|||
grpc.MaxSendMsgSize(p.ServerMaxSendSize.GetAsInt()),
|
||||
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
|
||||
otelgrpc.UnaryServerInterceptor(opts...),
|
||||
RateLimitInterceptor(multiLimiter),
|
||||
RateLimitInterceptor(s.simpleLimiter),
|
||||
)),
|
||||
grpc.StreamInterceptor(otelgrpc.StreamServerInterceptor(opts...)))
|
||||
proxypb.RegisterProxyServer(s.grpcServer, s)
|
||||
|
|
|
@ -23,25 +23,29 @@ import (
|
|||
"strconv"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/requestutil"
|
||||
)
|
||||
|
||||
// RateLimitInterceptor returns a new unary server interceptor that performs request rate limiting.
|
||||
func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
collectionIDs, rt, n, err := getRequestInfo(req)
|
||||
dbID, collectionIDToPartIDs, rt, n, err := getRequestInfo(ctx, req)
|
||||
if err != nil {
|
||||
log.RatedWarn(10, "failed to get request info", zap.Error(err))
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
err = limiter.Check(collectionIDs, rt, n)
|
||||
err = limiter.Check(dbID, collectionIDToPartIDs, rt, n)
|
||||
nodeID := strconv.FormatInt(paramtable.GetNodeID(), 10)
|
||||
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.TotalLabel).Inc()
|
||||
if err != nil {
|
||||
|
@ -56,72 +60,146 @@ func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor {
|
|||
}
|
||||
}
|
||||
|
||||
type reqPartName interface {
|
||||
requestutil.DBNameGetter
|
||||
requestutil.CollectionNameGetter
|
||||
requestutil.PartitionNameGetter
|
||||
}
|
||||
|
||||
type reqPartNames interface {
|
||||
requestutil.DBNameGetter
|
||||
requestutil.CollectionNameGetter
|
||||
requestutil.PartitionNamesGetter
|
||||
}
|
||||
|
||||
type reqCollName interface {
|
||||
requestutil.DBNameGetter
|
||||
requestutil.CollectionNameGetter
|
||||
}
|
||||
|
||||
func getCollectionAndPartitionID(ctx context.Context, r reqPartName) (int64, map[int64][]int64, error) {
|
||||
db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName())
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), r.GetCollectionName())
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
part, err := globalMetaCache.GetPartitionInfo(ctx, r.GetDbName(), r.GetCollectionName(), r.GetPartitionName())
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
return db.dbID, map[int64][]int64{collectionID: {part.partitionID}}, nil
|
||||
}
|
||||
|
||||
func getCollectionAndPartitionIDs(ctx context.Context, r reqPartNames) (int64, map[int64][]int64, error) {
|
||||
db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName())
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), r.GetCollectionName())
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
parts := make([]int64, len(r.GetPartitionNames()))
|
||||
for i, s := range r.GetPartitionNames() {
|
||||
part, err := globalMetaCache.GetPartitionInfo(ctx, r.GetDbName(), r.GetCollectionName(), s)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
parts[i] = part.partitionID
|
||||
}
|
||||
|
||||
return db.dbID, map[int64][]int64{collectionID: parts}, nil
|
||||
}
|
||||
|
||||
func getCollectionID(r reqCollName) (int64, map[int64][]int64) {
|
||||
db, _ := globalMetaCache.GetDatabaseInfo(context.TODO(), r.GetDbName())
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return db.dbID, map[int64][]int64{collectionID: {}}
|
||||
}
|
||||
|
||||
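The reqCollName, reqPartName, and reqPartNames interfaces above rely on Go's implicit interface satisfaction: any generated request type that exposes the matching GetDbName/GetCollectionName/GetPartitionName(s) getters can be passed to the shared helpers without declaring anything. A small self-contained sketch of the idea, with a hypothetical request type standing in for the milvuspb ones, is:

package main

import "fmt"

// insertReq is a hypothetical request type; the real ones are the
// generated milvuspb requests, which expose the same getters.
type insertReq struct{ db, coll, part string }

func (r insertReq) GetDbName() string         { return r.db }
func (r insertReq) GetCollectionName() string { return r.coll }
func (r insertReq) GetPartitionName() string  { return r.part }

// reqPartName mirrors the interface used above: any request with these
// three getters satisfies it implicitly, no declaration required.
type reqPartName interface {
	GetDbName() string
	GetCollectionName() string
	GetPartitionName() string
}

func describe(r reqPartName) string {
	return fmt.Sprintf("%s/%s/%s", r.GetDbName(), r.GetCollectionName(), r.GetPartitionName())
}

func main() {
	fmt.Println(describe(insertReq{db: "default", coll: "docs", part: "p1"}))
}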
// getRequestInfo returns the collection information, the rateType of the request, and the tokens needed.
|
||||
func getRequestInfo(req interface{}) ([]int64, internalpb.RateType, int, error) {
|
||||
func getRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]int64, internalpb.RateType, int, error) {
|
||||
switch r := req.(type) {
|
||||
case *milvuspb.InsertRequest:
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return []int64{collectionID}, internalpb.RateType_DMLInsert, proto.Size(r), nil
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DMLInsert, proto.Size(r), err
|
||||
case *milvuspb.UpsertRequest:
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return []int64{collectionID}, internalpb.RateType_DMLUpsert, proto.Size(r), nil
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DMLUpsert, proto.Size(r), err
|
||||
case *milvuspb.DeleteRequest:
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return []int64{collectionID}, internalpb.RateType_DMLDelete, proto.Size(r), nil
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DMLDelete, proto.Size(r), err
|
||||
case *milvuspb.ImportRequest:
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return []int64{collectionID}, internalpb.RateType_DMLBulkLoad, proto.Size(r), nil
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DMLBulkLoad, proto.Size(r), err
|
||||
case *milvuspb.SearchRequest:
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return []int64{collectionID}, internalpb.RateType_DQLSearch, int(r.GetNq()), nil
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DQLSearch, int(r.GetNq()), err
|
||||
case *milvuspb.QueryRequest:
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return []int64{collectionID}, internalpb.RateType_DQLQuery, 1, nil // think of the query request's nq as 1
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DQLQuery, 1, err // think of the query request's nq as 1
|
||||
case *milvuspb.CreateCollectionRequest:
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return []int64{collectionID}, internalpb.RateType_DDLCollection, 1, nil
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
|
||||
case *milvuspb.DropCollectionRequest:
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return []int64{collectionID}, internalpb.RateType_DDLCollection, 1, nil
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
|
||||
case *milvuspb.LoadCollectionRequest:
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return []int64{collectionID}, internalpb.RateType_DDLCollection, 1, nil
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
|
||||
case *milvuspb.ReleaseCollectionRequest:
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return []int64{collectionID}, internalpb.RateType_DDLCollection, 1, nil
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
|
||||
case *milvuspb.CreatePartitionRequest:
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return []int64{collectionID}, internalpb.RateType_DDLPartition, 1, nil
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
|
||||
case *milvuspb.DropPartitionRequest:
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return []int64{collectionID}, internalpb.RateType_DDLPartition, 1, nil
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
|
||||
case *milvuspb.LoadPartitionsRequest:
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return []int64{collectionID}, internalpb.RateType_DDLPartition, 1, nil
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
|
||||
case *milvuspb.ReleasePartitionsRequest:
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return []int64{collectionID}, internalpb.RateType_DDLPartition, 1, nil
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
|
||||
case *milvuspb.CreateIndexRequest:
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return []int64{collectionID}, internalpb.RateType_DDLIndex, 1, nil
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil
|
||||
case *milvuspb.DropIndexRequest:
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
|
||||
return []int64{collectionID}, internalpb.RateType_DDLIndex, 1, nil
|
||||
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
||||
return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil
|
||||
case *milvuspb.FlushRequest:
|
||||
collectionIDs := make([]int64, 0, len(r.GetCollectionNames()))
|
||||
db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName())
|
||||
if err != nil {
|
||||
return 0, map[int64][]int64{}, 0, 0, err
|
||||
}
|
||||
|
||||
collToPartIDs := make(map[int64][]int64, 0)
|
||||
for _, collectionName := range r.GetCollectionNames() {
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), collectionName)
|
||||
collectionIDs = append(collectionIDs, collectionID)
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
return 0, map[int64][]int64{}, 0, 0, err
|
||||
}
|
||||
collToPartIDs[collectionID] = []int64{}
|
||||
}
|
||||
return collectionIDs, internalpb.RateType_DDLFlush, 1, nil
|
||||
return db.dbID, collToPartIDs, internalpb.RateType_DDLFlush, 1, nil
|
||||
case *milvuspb.ManualCompactionRequest:
|
||||
return nil, internalpb.RateType_DDLCompaction, 1, nil
|
||||
// TODO: support more request
|
||||
default:
|
||||
if req == nil {
|
||||
return nil, 0, 0, fmt.Errorf("null request")
|
||||
dbName := GetCurDBNameFromContextOrDefault(ctx)
|
||||
dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, dbName)
|
||||
if err != nil {
|
||||
return 0, map[int64][]int64{}, 0, 0, err
|
||||
}
|
||||
return nil, 0, 0, fmt.Errorf("unsupported request type %s", reflect.TypeOf(req).Name())
|
||||
return dbInfo.dbID, map[int64][]int64{
|
||||
r.GetCollectionID(): {},
|
||||
}, internalpb.RateType_DDLCompaction, 1, nil
|
||||
default: // TODO: support more request
|
||||
if req == nil {
|
||||
return 0, map[int64][]int64{}, 0, 0, fmt.Errorf("null request")
|
||||
}
|
||||
return 0, map[int64][]int64{}, 0, 0, fmt.Errorf("unsupported request type %s", reflect.TypeOf(req).Name())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
@ -38,7 +39,7 @@ type limiterMock struct {
|
|||
quotaStateReasons []commonpb.ErrorCode
|
||||
}
|
||||
|
||||
func (l *limiterMock) Check(collection []int64, rt internalpb.RateType, n int) error {
|
||||
func (l *limiterMock) Check(dbID int64, collectionIDToPartIDs map[int64][]int64, rt internalpb.RateType, n int) error {
|
||||
if l.rate == 0 {
|
||||
return merr.ErrServiceQuotaExceeded
|
||||
}
|
||||
|
@ -51,119 +52,173 @@ func (l *limiterMock) Check(collection []int64, rt internalpb.RateType, n int) e
|
|||
func TestRateLimitInterceptor(t *testing.T) {
|
||||
t.Run("test getRequestInfo", func(t *testing.T) {
|
||||
mockCache := NewMockCache(t)
|
||||
mockCache.On("GetCollectionID",
|
||||
mock.Anything, // context.Context
|
||||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
).Return(int64(0), nil)
|
||||
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil)
|
||||
mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{
|
||||
name: "p1",
|
||||
partitionID: 10,
|
||||
createdTimestamp: 10001,
|
||||
createdUtcTimestamp: 10002,
|
||||
}, nil)
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
||||
dbID: 100,
|
||||
createdTimestamp: 1,
|
||||
}, nil)
|
||||
globalMetaCache = mockCache
|
||||
collection, rt, size, err := getRequestInfo(&milvuspb.InsertRequest{})
|
||||
database, col2part, rt, size, err := getRequestInfo(context.Background(), &milvuspb.InsertRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, proto.Size(&milvuspb.InsertRequest{}), size)
|
||||
assert.Equal(t, internalpb.RateType_DMLInsert, rt)
|
||||
assert.ElementsMatch(t, collection, []int64{int64(0)})
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.True(t, len(col2part) == 1)
|
||||
assert.Equal(t, int64(10), col2part[1][0])
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.UpsertRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.UpsertRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, proto.Size(&milvuspb.InsertRequest{}), size)
|
||||
assert.Equal(t, internalpb.RateType_DMLUpsert, rt)
|
||||
assert.ElementsMatch(t, collection, []int64{int64(0)})
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.True(t, len(col2part) == 1)
|
||||
assert.Equal(t, int64(10), col2part[1][0])
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.DeleteRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DeleteRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, proto.Size(&milvuspb.DeleteRequest{}), size)
|
||||
assert.Equal(t, internalpb.RateType_DMLDelete, rt)
|
||||
assert.ElementsMatch(t, collection, []int64{int64(0)})
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.True(t, len(col2part) == 1)
|
||||
assert.Equal(t, int64(10), col2part[1][0])
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.ImportRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ImportRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, proto.Size(&milvuspb.ImportRequest{}), size)
|
||||
assert.Equal(t, internalpb.RateType_DMLBulkLoad, rt)
|
||||
assert.ElementsMatch(t, collection, []int64{int64(0)})
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.True(t, len(col2part) == 1)
|
||||
assert.Equal(t, int64(10), col2part[1][0])
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.SearchRequest{Nq: 5})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.SearchRequest{
|
||||
Nq: 5,
|
||||
PartitionNames: []string{
|
||||
"p1",
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 5, size)
|
||||
assert.Equal(t, internalpb.RateType_DQLSearch, rt)
|
||||
assert.ElementsMatch(t, collection, []int64{int64(0)})
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 1, len(col2part[1]))
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.QueryRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.QueryRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DQLQuery, rt)
|
||||
assert.ElementsMatch(t, collection, []int64{int64(0)})
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.CreateCollectionRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreateCollectionRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLCollection, rt)
|
||||
assert.ElementsMatch(t, collection, []int64{int64(0)})
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.LoadCollectionRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.LoadCollectionRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLCollection, rt)
|
||||
assert.ElementsMatch(t, collection, []int64{int64(0)})
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.ReleaseCollectionRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ReleaseCollectionRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLCollection, rt)
|
||||
assert.ElementsMatch(t, collection, []int64{int64(0)})
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.DropCollectionRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DropCollectionRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLCollection, rt)
|
||||
assert.ElementsMatch(t, collection, []int64{int64(0)})
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.CreatePartitionRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreatePartitionRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLPartition, rt)
|
||||
assert.ElementsMatch(t, collection, []int64{int64(0)})
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.LoadPartitionsRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.LoadPartitionsRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLPartition, rt)
|
||||
assert.ElementsMatch(t, collection, []int64{int64(0)})
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.ReleasePartitionsRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ReleasePartitionsRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLPartition, rt)
|
||||
assert.ElementsMatch(t, collection, []int64{int64(0)})
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.DropPartitionRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DropPartitionRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLPartition, rt)
|
||||
assert.ElementsMatch(t, collection, []int64{int64(0)})
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.CreateIndexRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreateIndexRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLIndex, rt)
|
||||
assert.ElementsMatch(t, collection, []int64{int64(0)})
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.DropIndexRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DropIndexRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLIndex, rt)
|
||||
assert.ElementsMatch(t, collection, []int64{int64(0)})
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.FlushRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.FlushRequest{
|
||||
CollectionNames: []string{
|
||||
"col1",
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLFlush, rt)
|
||||
assert.Len(t, collection, 0)
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
|
||||
collection, rt, size, err = getRequestInfo(&milvuspb.ManualCompactionRequest{})
|
||||
database, _, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ManualCompactionRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DDLCompaction, rt)
|
||||
assert.Len(t, collection, 0)
|
||||
assert.Equal(t, database, int64(100))
|
||||
|
||||
_, _, _, _, err = getRequestInfo(context.Background(), nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, _, _, _, err = getRequestInfo(context.Background(), &milvuspb.CalcDistanceRequest{})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("test getFailedResponse", func(t *testing.T) {
|
||||
|
@ -190,11 +245,17 @@ func TestRateLimitInterceptor(t *testing.T) {
|
|||
|
||||
t.Run("test RateLimitInterceptor", func(t *testing.T) {
|
||||
mockCache := NewMockCache(t)
|
||||
mockCache.On("GetCollectionID",
|
||||
mock.Anything, // context.Context
|
||||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
).Return(int64(0), nil)
|
||||
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil)
|
||||
mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{
|
||||
name: "p1",
|
||||
partitionID: 10,
|
||||
createdTimestamp: 10001,
|
||||
createdUtcTimestamp: 10002,
|
||||
}, nil)
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
||||
dbID: 100,
|
||||
createdTimestamp: 1,
|
||||
}, nil)
|
||||
globalMetaCache = mockCache
|
||||
|
||||
limiter := limiterMock{rate: 100}
|
||||
|
@ -224,4 +285,158 @@ func TestRateLimitInterceptor(t *testing.T) {
|
|||
assert.Equal(t, commonpb.ErrorCode_ForceDeny, rsp.(*milvuspb.MutationResult).GetStatus().GetErrorCode())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("request info fail", func(t *testing.T) {
|
||||
mockCache := NewMockCache(t)
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: get database info"))
|
||||
originCache := globalMetaCache
|
||||
globalMetaCache = mockCache
|
||||
defer func() {
|
||||
globalMetaCache = originCache
|
||||
}()
|
||||
|
||||
limiter := limiterMock{rate: 100}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return &milvuspb.MutationResult{
|
||||
Status: merr.Success(),
|
||||
}, nil
|
||||
}
|
||||
serverInfo := &grpc.UnaryServerInfo{FullMethod: "MockFullMethod"}
|
||||
|
||||
limiter.limit = true
|
||||
interceptorFun := RateLimitInterceptor(&limiter)
|
||||
rsp, err := interceptorFun(context.Background(), &milvuspb.InsertRequest{}, serverInfo, handler)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, rsp.(*milvuspb.MutationResult).GetStatus().GetErrorCode())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetInfo(t *testing.T) {
|
||||
mockCache := NewMockCache(t)
|
||||
ctx := context.Background()
|
||||
originCache := globalMetaCache
|
||||
globalMetaCache = mockCache
|
||||
defer func() {
|
||||
globalMetaCache = originCache
|
||||
}()
|
||||
|
||||
t.Run("fail to get database", func(t *testing.T) {
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: get database info")).Times(4)
|
||||
{
|
||||
_, _, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
|
||||
DbName: "foo",
|
||||
CollectionName: "coo",
|
||||
PartitionName: "p1",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
{
|
||||
_, _, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{
|
||||
DbName: "foo",
|
||||
CollectionName: "coo",
|
||||
PartitionNames: []string{"p1"},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
{
|
||||
_, _, _, _, err := getRequestInfo(ctx, &milvuspb.FlushRequest{
|
||||
DbName: "foo",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
{
|
||||
_, _, _, _, err := getRequestInfo(ctx, &milvuspb.ManualCompactionRequest{})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fail to get collection", func(t *testing.T) {
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
||||
dbID: 100,
|
||||
createdTimestamp: 1,
|
||||
}, nil).Times(3)
|
||||
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(0), errors.New("mock error: get collection id")).Times(3)
|
||||
{
|
||||
_, _, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
|
||||
DbName: "foo",
|
||||
CollectionName: "coo",
|
||||
PartitionName: "p1",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
{
|
||||
_, _, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{
|
||||
DbName: "foo",
|
||||
CollectionName: "coo",
|
||||
PartitionNames: []string{"p1"},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
{
|
||||
_, _, _, _, err := getRequestInfo(ctx, &milvuspb.FlushRequest{
|
||||
DbName: "foo",
|
||||
CollectionNames: []string{"coo"},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fail to get partition", func(t *testing.T) {
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
||||
dbID: 100,
|
||||
createdTimestamp: 1,
|
||||
}, nil).Twice()
|
||||
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil).Twice()
|
||||
mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("mock error: get partition info")).Twice()
|
||||
{
|
||||
_, _, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
|
||||
DbName: "foo",
|
||||
CollectionName: "coo",
|
||||
PartitionName: "p1",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
{
|
||||
_, _, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{
|
||||
DbName: "foo",
|
||||
CollectionName: "coo",
|
||||
PartitionNames: []string{"p1"},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
||||
dbID: 100,
|
||||
createdTimestamp: 1,
|
||||
}, nil).Twice()
|
||||
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(10), nil).Twice()
|
||||
mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{
|
||||
name: "p1",
|
||||
partitionID: 100,
|
||||
}, nil)
|
||||
{
|
||||
db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
|
||||
DbName: "foo",
|
||||
CollectionName: "coo",
|
||||
PartitionName: "p1",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(100), db)
|
||||
assert.NotNil(t, col2par[10])
|
||||
assert.Equal(t, int64(100), col2par[10][0])
|
||||
}
|
||||
{
|
||||
db, col2par, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{
|
||||
DbName: "foo",
|
||||
CollectionName: "coo",
|
||||
PartitionNames: []string{"p1"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(100), db)
|
||||
assert.NotNil(t, col2par[10])
|
||||
assert.Equal(t, int64(100), col2par[10][0])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1111,6 +1111,10 @@ func (coord *RootCoordMock) RenameCollection(ctx context.Context, req *milvuspb.
|
|||
return &commonpb.Status{}, nil
|
||||
}
|
||||
|
||||
func (coord *RootCoordMock) DescribeDatabase(ctx context.Context, in *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error) {
|
||||
return &rootcoordpb.DescribeDatabaseResponse{}, nil
|
||||
}
|
||||
|
||||
type DescribeCollectionFunc func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error)
|
||||
|
||||
type ShowPartitionsFunc func(ctx context.Context, request *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error)
|
||||
|
|
|
@ -0,0 +1,344 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/proxypb"
|
||||
"github.com/milvus-io/milvus/internal/util/quota"
|
||||
rlinternal "github.com/milvus-io/milvus/internal/util/ratelimitutil"
|
||||
"github.com/milvus-io/milvus/pkg/config"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
// SimpleLimiter is implemented based on the Limiter interface.
|
||||
type SimpleLimiter struct {
|
||||
quotaStatesMu sync.RWMutex
|
||||
rateLimiter *rlinternal.RateLimiterTree
|
||||
}
|
||||
|
||||
// NewSimpleLimiter returns a new SimpleLimiter.
|
||||
func NewSimpleLimiter() *SimpleLimiter {
|
||||
rootRateLimiter := newClusterLimiter()
|
||||
m := &SimpleLimiter{rateLimiter: rlinternal.NewRateLimiterTree(rootRateLimiter)}
|
||||
return m
|
||||
}
|
||||
|
||||
// Check checks whether the request would be limited or denied.
|
||||
func (m *SimpleLimiter) Check(dbID int64, collectionIDToPartIDs map[int64][]int64, rt internalpb.RateType, n int) error {
|
||||
if !Params.QuotaConfig.QuotaAndLimitsEnabled.GetAsBool() {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.quotaStatesMu.RLock()
|
||||
defer m.quotaStatesMu.RUnlock()
|
||||
|
||||
// 1. check global(cluster) level rate limits
|
||||
clusterRateLimiters := m.rateLimiter.GetRootLimiters()
|
||||
ret := clusterRateLimiters.Check(rt, n)
|
||||
|
||||
if ret != nil {
|
||||
clusterRateLimiters.Cancel(rt, n)
|
||||
return ret
|
||||
}
|
||||
|
||||
// store limiters that have already been charged so they can be cancelled if a later check fails.
|
||||
doneLimiters := make([]*rlinternal.RateLimiterNode, 0)
|
||||
doneLimiters = append(doneLimiters, clusterRateLimiters)
|
||||
|
||||
cancelAllLimiters := func() {
|
||||
for _, limiter := range doneLimiters {
|
||||
limiter.Cancel(rt, n)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. check database level rate limits
|
||||
if ret == nil {
|
||||
dbRateLimiters := m.rateLimiter.GetOrCreateDatabaseLimiters(dbID, newDatabaseLimiter)
|
||||
ret = dbRateLimiters.Check(rt, n)
|
||||
if ret != nil {
|
||||
cancelAllLimiters()
|
||||
return ret
|
||||
}
|
||||
doneLimiters = append(doneLimiters, dbRateLimiters)
|
||||
}
|
||||
|
||||
// 3. check collection level rate limits
|
||||
if ret == nil && len(collectionIDToPartIDs) > 0 && !isNotCollectionLevelLimitRequest(rt) {
|
||||
for collectionID := range collectionIDToPartIDs {
|
||||
// only DML and DQL requests have collection-level rate limits
|
||||
collectionRateLimiters := m.rateLimiter.GetOrCreateCollectionLimiters(dbID, collectionID,
|
||||
newDatabaseLimiter, newCollectionLimiters)
|
||||
ret = collectionRateLimiters.Check(rt, n)
|
||||
if ret != nil {
|
||||
cancelAllLimiters()
|
||||
return ret
|
||||
}
|
||||
doneLimiters = append(doneLimiters, collectionRateLimiters)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. check partition level rate limits
|
||||
if ret == nil && len(collectionIDToPartIDs) > 0 {
|
||||
for collectionID, partitionIDs := range collectionIDToPartIDs {
|
||||
for _, partID := range partitionIDs {
|
||||
partitionRateLimiters := m.rateLimiter.GetOrCreatePartitionLimiters(dbID, collectionID, partID,
|
||||
newDatabaseLimiter, newCollectionLimiters, newPartitionLimiters)
|
||||
ret = partitionRateLimiters.Check(rt, n)
|
||||
if ret != nil {
|
||||
cancelAllLimiters()
|
||||
return ret
|
||||
}
|
||||
doneLimiters = append(doneLimiters, partitionRateLimiters)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
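For reference, a minimal usage sketch of the new multi-level check (not part of this patch); the dbID, collection ID and partition ID below are illustrative placeholders:

	limiter := NewSimpleLimiter()
	// db 1, collection 100, partition 101 are hypothetical IDs used only for illustration
	if err := limiter.Check(1, map[int64][]int64{100: {101}}, internalpb.RateType_DMLInsert, 1024); err != nil {
		// the request was throttled at the cluster, database, collection or partition level
		log.Warn("request rejected by rate limiter", zap.Error(err))
	}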
|
||||
|
||||
func isNotCollectionLevelLimitRequest(rt internalpb.RateType) bool {
	// Most DDL requests are limited at the cluster level; only DDLFlush is applied at the collection level.
	switch rt {
	case internalpb.RateType_DDLCollection,
		internalpb.RateType_DDLPartition,
		internalpb.RateType_DDLIndex,
		internalpb.RateType_DDLCompaction:
		return true
	default:
		return false
	}
}
|
||||
|
||||
// GetQuotaStates returns quota states.
|
||||
func (m *SimpleLimiter) GetQuotaStates() ([]milvuspb.QuotaState, []string) {
|
||||
m.quotaStatesMu.RLock()
|
||||
defer m.quotaStatesMu.RUnlock()
|
||||
serviceStates := make(map[milvuspb.QuotaState]typeutil.Set[commonpb.ErrorCode])
|
||||
|
||||
rlinternal.TraverseRateLimiterTree(m.rateLimiter.GetRootLimiters(), nil,
|
||||
func(node *rlinternal.RateLimiterNode, state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool {
|
||||
if serviceStates[state] == nil {
|
||||
serviceStates[state] = typeutil.NewSet[commonpb.ErrorCode]()
|
||||
}
|
||||
serviceStates[state].Insert(errCode)
|
||||
return true
|
||||
})
|
||||
|
||||
states := make([]milvuspb.QuotaState, 0)
|
||||
reasons := make([]string, 0)
|
||||
for state, errCodes := range serviceStates {
|
||||
for errCode := range errCodes {
|
||||
states = append(states, state)
|
||||
reasons = append(reasons, ratelimitutil.GetQuotaErrorString(errCode))
|
||||
}
|
||||
}
|
||||
|
||||
return states, reasons
|
||||
}
|
||||
|
||||
// SetRates sets quota states for SimpleLimiter.
|
||||
func (m *SimpleLimiter) SetRates(rootLimiter *proxypb.LimiterNode) error {
|
||||
m.quotaStatesMu.Lock()
|
||||
defer m.quotaStatesMu.Unlock()
|
||||
if err := m.updateRateLimiter(rootLimiter); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.rateLimiter.ClearInvalidLimiterNode(rootLimiter)
|
||||
return nil
|
||||
}
|
||||
|
||||
func initLimiter(rln *rlinternal.RateLimiterNode, rateLimiterConfigs map[internalpb.RateType]*paramtable.ParamItem) {
|
||||
log := log.Ctx(context.TODO()).WithRateGroup("proxy.rateLimiter", 1.0, 60.0)
|
||||
for rt, p := range rateLimiterConfigs {
|
||||
limit := ratelimitutil.Limit(p.GetAsFloat())
|
||||
burst := p.GetAsFloat() // use the rate as the burst: SimpleLimiter has a punishment mechanism, so the burst value is insignificant.
|
||||
rln.GetLimiters().GetOrInsert(rt, ratelimitutil.NewLimiter(limit, burst))
|
||||
onEvent := func(rateType internalpb.RateType, formatFunc func(originValue string) string) func(*config.Event) {
|
||||
return func(event *config.Event) {
|
||||
f, err := strconv.ParseFloat(formatFunc(event.Value), 64)
|
||||
if err != nil {
|
||||
log.Info("Error format for rateLimit",
|
||||
zap.String("rateType", rateType.String()),
|
||||
zap.String("key", event.Key),
|
||||
zap.String("value", event.Value),
|
||||
zap.Error(err))
|
||||
return
|
||||
}
|
||||
l, ok := rln.GetLimiters().Get(rateType)
|
||||
if !ok {
|
||||
log.Info("rateLimiter not found for rateType", zap.String("rateType", rateType.String()))
|
||||
return
|
||||
}
|
||||
l.SetLimit(ratelimitutil.Limit(f))
|
||||
}
|
||||
}(rt, p.Formatter)
|
||||
paramtable.Get().Watch(p.Key, config.NewHandler(fmt.Sprintf("rateLimiter-%d", rt), onEvent))
|
||||
log.RatedDebug(30, "RateLimiter register for rateType",
|
||||
zap.String("rateType", internalpb.RateType_name[(int32(rt))]),
|
||||
zap.String("rateLimit", ratelimitutil.Limit(p.GetAsFloat()).String()),
|
||||
zap.String("burst", fmt.Sprintf("%v", burst)))
|
||||
}
|
||||
}
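initLimiter also registers a config watcher per rate type, so limits can be refreshed at runtime without a restart. A hedged sketch of overriding one limit through etcd (the key path is the same one exercised by the refresh test later in this patch; the etcd client variable is an assumption):

	// assumes an etcd-backed config source and an existing etcd client named etcdCli
	etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/dml/insertRate/max", "8")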
|
||||
|
||||
// newClusterLimiter initializes the cluster-level limiter for all rate types.
// The cluster rate limiter does not accumulate metrics dynamically; it only uses
// configured values as limits.
func newClusterLimiter() *rlinternal.RateLimiterNode {
	clusterRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Cluster)
	clusterLimiterConfigs := getDefaultLimiterConfig(internalpb.RateScope_Cluster)
	initLimiter(clusterRateLimiters, clusterLimiterConfigs)
	return clusterRateLimiters
}
|
||||
|
||||
func newDatabaseLimiter() *rlinternal.RateLimiterNode {
|
||||
dbRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Database)
|
||||
databaseLimiterConfigs := getDefaultLimiterConfig(internalpb.RateScope_Database)
|
||||
initLimiter(dbRateLimiters, databaseLimiterConfigs)
|
||||
return dbRateLimiters
|
||||
}
|
||||
|
||||
func newCollectionLimiters() *rlinternal.RateLimiterNode {
|
||||
collectionRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Collection)
|
||||
collectionLimiterConfigs := getDefaultLimiterConfig(internalpb.RateScope_Collection)
|
||||
initLimiter(collectionRateLimiters, collectionLimiterConfigs)
|
||||
return collectionRateLimiters
|
||||
}
|
||||
|
||||
func newPartitionLimiters() *rlinternal.RateLimiterNode {
	partRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Partition)
	partitionLimiterConfigs := getDefaultLimiterConfig(internalpb.RateScope_Partition)
	initLimiter(partRateLimiters, partitionLimiterConfigs)
	return partRateLimiters
}
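These constructors are passed as callbacks so the limiter tree can lazily create missing nodes. A rough sketch of how a partition-level node would be created on demand (the IDs are illustrative):

	tree := rlinternal.NewRateLimiterTree(newClusterLimiter())
	// lazily creates db 1 -> collection 100 -> partition 101 if they do not exist yet
	partNode := tree.GetOrCreatePartitionLimiters(1, 100, 101,
		newDatabaseLimiter, newCollectionLimiters, newPartitionLimiters)
	_ = partNode.Check(internalpb.RateType_DMLInsert, 1)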
|
||||
|
||||
func (m *SimpleLimiter) updateLimiterNode(req *proxypb.Limiter, node *rlinternal.RateLimiterNode, sourceID string) error {
|
||||
curLimiters := node.GetLimiters()
|
||||
for _, rate := range req.GetRates() {
|
||||
limit, ok := curLimiters.Get(rate.GetRt())
|
||||
if !ok {
|
||||
return fmt.Errorf("unregister rateLimiter for rateType %s", rate.GetRt().String())
|
||||
}
|
||||
limit.SetLimit(ratelimitutil.Limit(rate.GetR()))
|
||||
setRateGaugeByRateType(rate.GetRt(), paramtable.GetNodeID(), sourceID, rate.GetR())
|
||||
}
|
||||
quotaStates := typeutil.NewConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]()
|
||||
states := req.GetStates()
|
||||
codes := req.GetCodes()
|
||||
for i, state := range states {
|
||||
quotaStates.Insert(state, codes[i])
|
||||
}
|
||||
node.SetQuotaStates(quotaStates)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SimpleLimiter) updateRateLimiter(reqRootLimiterNode *proxypb.LimiterNode) error {
|
||||
reqClusterLimiter := reqRootLimiterNode.GetLimiter()
|
||||
clusterLimiter := m.rateLimiter.GetRootLimiters()
|
||||
err := m.updateLimiterNode(reqClusterLimiter, clusterLimiter, "cluster")
|
||||
if err != nil {
|
||||
log.Warn("update cluster rate limiters failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
getDBSourceID := func(dbID int64) string {
|
||||
return fmt.Sprintf("db.%d", dbID)
|
||||
}
|
||||
getCollectionSourceID := func(collectionID int64) string {
|
||||
return fmt.Sprintf("collection.%d", collectionID)
|
||||
}
|
||||
getPartitionSourceID := func(partitionID int64) string {
|
||||
return fmt.Sprintf("partition.%d", partitionID)
|
||||
}
|
||||
|
||||
for dbID, reqDBRateLimiters := range reqRootLimiterNode.GetChildren() {
|
||||
// update database rate limiters
|
||||
dbRateLimiters := m.rateLimiter.GetOrCreateDatabaseLimiters(dbID, newDatabaseLimiter)
|
||||
err := m.updateLimiterNode(reqDBRateLimiters.GetLimiter(), dbRateLimiters, getDBSourceID(dbID))
|
||||
if err != nil {
|
||||
log.Warn("update database rate limiters failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// update collection rate limiters
|
||||
for collectionID, reqCollectionRateLimiter := range reqDBRateLimiters.GetChildren() {
|
||||
collectionRateLimiter := m.rateLimiter.GetOrCreateCollectionLimiters(dbID, collectionID,
|
||||
newDatabaseLimiter, newCollectionLimiters)
|
||||
err := m.updateLimiterNode(reqCollectionRateLimiter.GetLimiter(), collectionRateLimiter,
|
||||
getCollectionSourceID(collectionID))
|
||||
if err != nil {
|
||||
log.Warn("update collection rate limiters failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// update partition rate limiters
|
||||
for partitionID, reqPartitionRateLimiters := range reqCollectionRateLimiter.GetChildren() {
|
||||
partitionRateLimiter := m.rateLimiter.GetOrCreatePartitionLimiters(dbID, collectionID, partitionID,
|
||||
newDatabaseLimiter, newCollectionLimiters, newPartitionLimiters)
|
||||
|
||||
err := m.updateLimiterNode(reqPartitionRateLimiters.GetLimiter(), partitionRateLimiter,
|
||||
getPartitionSourceID(partitionID))
|
||||
if err != nil {
|
||||
log.Warn("update partition rate limiters failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
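SetRates expects a cluster-rooted tree: the cluster limiter at the top, databases as children, then collections, then partitions. A minimal illustrative proxypb.LimiterNode follows (all IDs and rates are placeholders; simpleLimiter is an assumed *SimpleLimiter instance):

	node := &proxypb.LimiterNode{
		Limiter: &proxypb.Limiter{}, // cluster level
		Children: map[int64]*proxypb.LimiterNode{
			1: { // database 1
				Limiter: &proxypb.Limiter{},
				Children: map[int64]*proxypb.LimiterNode{
					100: { // collection 100
						Limiter:  &proxypb.Limiter{Rates: []*internalpb.Rate{{Rt: internalpb.RateType_DMLInsert, R: 1024}}},
						Children: map[int64]*proxypb.LimiterNode{},
					},
				},
			},
		},
	}
	_ = simpleLimiter.SetRates(node)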
|
||||
|
||||
// setRateGaugeByRateType sets ProxyLimiterRate metrics.
|
||||
func setRateGaugeByRateType(rateType internalpb.RateType, nodeID int64, sourceID string, rate float64) {
|
||||
if ratelimitutil.Limit(rate) == ratelimitutil.Inf {
|
||||
return
|
||||
}
|
||||
nodeIDStr := strconv.FormatInt(nodeID, 10)
|
||||
metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, sourceID, rateType.String()).Set(rate)
|
||||
}
|
||||
|
||||
func getDefaultLimiterConfig(scope internalpb.RateScope) map[internalpb.RateType]*paramtable.ParamItem {
|
||||
return quota.GetQuotaConfigMap(scope)
|
||||
}
|
||||
|
||||
func IsDDLRequest(rt internalpb.RateType) bool {
|
||||
switch rt {
|
||||
case internalpb.RateType_DDLCollection,
|
||||
internalpb.RateType_DDLPartition,
|
||||
internalpb.RateType_DDLIndex,
|
||||
internalpb.RateType_DDLFlush,
|
||||
internalpb.RateType_DDLCompaction:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
|
@ -0,0 +1,415 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/proxypb"
|
||||
rlinternal "github.com/milvus-io/milvus/internal/util/ratelimitutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/etcd"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
|
||||
)
|
||||
|
||||
func TestSimpleRateLimiter(t *testing.T) {
|
||||
collectionID := int64(1)
|
||||
collectionIDToPartIDs := map[int64][]int64{collectionID: {}}
|
||||
t.Run("test simpleRateLimiter", func(t *testing.T) {
|
||||
bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue()
|
||||
paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
|
||||
|
||||
simpleLimiter := NewSimpleLimiter()
|
||||
clusterRateLimiters := simpleLimiter.rateLimiter.GetRootLimiters()
|
||||
|
||||
simpleLimiter.rateLimiter.GetOrCreateCollectionLimiters(0, collectionID, newDatabaseLimiter,
|
||||
func() *rlinternal.RateLimiterNode {
|
||||
collectionRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Cluster)
|
||||
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
if IsDDLRequest(internalpb.RateType(rt)) {
|
||||
clusterRateLimiters.GetLimiters().
|
||||
Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(5), 1))
|
||||
} else {
|
||||
collectionRateLimiters.GetLimiters().
|
||||
Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1))
|
||||
}
|
||||
}
|
||||
|
||||
return collectionRateLimiters
|
||||
})
|
||||
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
if IsDDLRequest(internalpb.RateType(rt)) {
|
||||
err := simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1)
|
||||
assert.NoError(t, err)
|
||||
err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 5)
|
||||
assert.NoError(t, err)
|
||||
err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 5)
|
||||
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
|
||||
} else {
|
||||
err := simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1)
|
||||
assert.NoError(t, err)
|
||||
err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), math.MaxInt)
|
||||
assert.NoError(t, err)
|
||||
err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), math.MaxInt)
|
||||
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
|
||||
}
|
||||
}
|
||||
Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak)
|
||||
})
|
||||
|
||||
t.Run("test global static limit", func(t *testing.T) {
|
||||
bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue()
|
||||
paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
|
||||
simpleLimiter := NewSimpleLimiter()
|
||||
clusterRateLimiters := simpleLimiter.rateLimiter.GetRootLimiters()
|
||||
|
||||
collectionIDToPartIDs := map[int64][]int64{
|
||||
1: {},
|
||||
2: {},
|
||||
3: {},
|
||||
}
|
||||
|
||||
for i := 1; i <= 3; i++ {
|
||||
simpleLimiter.rateLimiter.GetOrCreateCollectionLimiters(0, int64(i), newDatabaseLimiter,
|
||||
func() *rlinternal.RateLimiterNode {
|
||||
collectionRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Cluster)
|
||||
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
if IsDDLRequest(internalpb.RateType(rt)) {
|
||||
clusterRateLimiters.GetLimiters().
|
||||
Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(5), 1))
|
||||
} else {
|
||||
clusterRateLimiters.GetLimiters().
|
||||
Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1))
|
||||
collectionRateLimiters.GetLimiters().
|
||||
Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1))
|
||||
}
|
||||
}
|
||||
|
||||
return collectionRateLimiters
|
||||
})
|
||||
}
|
||||
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
if IsDDLRequest(internalpb.RateType(rt)) {
|
||||
err := simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1)
|
||||
assert.NoError(t, err)
|
||||
err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 5)
|
||||
assert.NoError(t, err)
|
||||
err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 5)
|
||||
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
|
||||
} else {
|
||||
err := simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1)
|
||||
assert.NoError(t, err)
|
||||
err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1)
|
||||
assert.NoError(t, err)
|
||||
err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1)
|
||||
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
|
||||
}
|
||||
}
|
||||
Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak)
|
||||
})
|
||||
|
||||
t.Run("not enable quotaAndLimit", func(t *testing.T) {
|
||||
simpleLimiter := NewSimpleLimiter()
|
||||
bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue()
|
||||
paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
err := simpleLimiter.Check(0, nil, internalpb.RateType(rt), 1)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak)
|
||||
})
|
||||
|
||||
t.Run("test limit", func(t *testing.T) {
|
||||
run := func(insertRate float64) {
|
||||
bakInsertRate := Params.QuotaConfig.DMLMaxInsertRate.GetValue()
|
||||
paramtable.Get().Save(Params.QuotaConfig.DMLMaxInsertRate.Key, fmt.Sprintf("%f", insertRate))
|
||||
simpleLimiter := NewSimpleLimiter()
|
||||
bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue()
|
||||
paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
|
||||
err := simpleLimiter.Check(0, nil, internalpb.RateType_DMLInsert, 1*1024*1024)
|
||||
assert.NoError(t, err)
|
||||
Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak)
|
||||
Params.Save(Params.QuotaConfig.DMLMaxInsertRate.Key, bakInsertRate)
|
||||
}
|
||||
run(math.MaxFloat64)
|
||||
run(math.MaxFloat64 / 1.2)
|
||||
run(math.MaxFloat64 / 2)
|
||||
run(math.MaxFloat64 / 3)
|
||||
run(math.MaxFloat64 / 10000)
|
||||
})
|
||||
|
||||
t.Run("test set rates", func(t *testing.T) {
|
||||
simpleLimiter := NewSimpleLimiter()
|
||||
zeroRates := getZeroCollectionRates()
|
||||
|
||||
err := simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{
|
||||
1: {
|
||||
Limiter: &proxypb.Limiter{
|
||||
Rates: zeroRates,
|
||||
},
|
||||
Children: make(map[int64]*proxypb.LimiterNode),
|
||||
},
|
||||
2: {
|
||||
Limiter: &proxypb.Limiter{
|
||||
Rates: zeroRates,
|
||||
},
|
||||
Children: make(map[int64]*proxypb.LimiterNode),
|
||||
},
|
||||
}))
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test quota states", func(t *testing.T) {
|
||||
simpleLimiter := NewSimpleLimiter()
|
||||
err := simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{
|
||||
1: {
|
||||
// collection limiter
|
||||
Limiter: &proxypb.Limiter{
|
||||
Rates: getZeroCollectionRates(),
|
||||
States: []milvuspb.QuotaState{milvuspb.QuotaState_DenyToWrite, milvuspb.QuotaState_DenyToRead},
|
||||
Codes: []commonpb.ErrorCode{commonpb.ErrorCode_DiskQuotaExhausted, commonpb.ErrorCode_ForceDeny},
|
||||
},
|
||||
Children: make(map[int64]*proxypb.LimiterNode),
|
||||
},
|
||||
}))
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
states, codes := simpleLimiter.GetQuotaStates()
|
||||
assert.Len(t, states, 2)
|
||||
assert.Len(t, codes, 2)
|
||||
assert.Contains(t, codes, ratelimitutil.GetQuotaErrorString(commonpb.ErrorCode_DiskQuotaExhausted))
|
||||
assert.Contains(t, codes, ratelimitutil.GetQuotaErrorString(commonpb.ErrorCode_ForceDeny))
|
||||
})
|
||||
}
|
||||
|
||||
func getZeroRates() []*internalpb.Rate {
|
||||
zeroRates := make([]*internalpb.Rate, 0, len(internalpb.RateType_value))
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
zeroRates = append(zeroRates, &internalpb.Rate{
|
||||
Rt: internalpb.RateType(rt), R: 0,
|
||||
})
|
||||
}
|
||||
return zeroRates
|
||||
}
|
||||
|
||||
func getZeroCollectionRates() []*internalpb.Rate {
|
||||
collectionRate := []internalpb.RateType{
|
||||
internalpb.RateType_DMLInsert,
|
||||
internalpb.RateType_DMLDelete,
|
||||
internalpb.RateType_DMLBulkLoad,
|
||||
internalpb.RateType_DQLSearch,
|
||||
internalpb.RateType_DQLQuery,
|
||||
internalpb.RateType_DDLFlush,
|
||||
}
|
||||
zeroRates := make([]*internalpb.Rate, 0, len(collectionRate))
|
||||
for _, rt := range collectionRate {
|
||||
zeroRates = append(zeroRates, &internalpb.Rate{
|
||||
Rt: rt, R: 0,
|
||||
})
|
||||
}
|
||||
return zeroRates
|
||||
}
|
||||
|
||||
func newCollectionLimiterNode(collectionLimiterNodes map[int64]*proxypb.LimiterNode) *proxypb.LimiterNode {
|
||||
return &proxypb.LimiterNode{
|
||||
// cluster limiter
|
||||
Limiter: &proxypb.Limiter{},
|
||||
// db level
|
||||
Children: map[int64]*proxypb.LimiterNode{
|
||||
0: {
|
||||
// db limiter
|
||||
Limiter: &proxypb.Limiter{},
|
||||
// collection level
|
||||
Children: collectionLimiterNodes,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter(t *testing.T) {
|
||||
t.Run("test limit", func(t *testing.T) {
|
||||
simpleLimiter := NewSimpleLimiter()
|
||||
rootLimiters := simpleLimiter.rateLimiter.GetRootLimiters()
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
rootLimiters.GetLimiters().Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1))
|
||||
}
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
ok, _ := rootLimiters.Limit(internalpb.RateType(rt), 1)
|
||||
assert.False(t, ok)
|
||||
ok, _ = rootLimiters.Limit(internalpb.RateType(rt), math.MaxInt)
|
||||
assert.False(t, ok)
|
||||
ok, _ = rootLimiters.Limit(internalpb.RateType(rt), math.MaxInt)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test setRates", func(t *testing.T) {
|
||||
simpleLimiter := NewSimpleLimiter()
|
||||
|
||||
collectionRateLimiters := simpleLimiter.rateLimiter.GetOrCreateCollectionLimiters(0, int64(1), newDatabaseLimiter,
|
||||
func() *rlinternal.RateLimiterNode {
|
||||
collectionRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Cluster)
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
collectionRateLimiters.GetLimiters().Insert(internalpb.RateType(rt),
|
||||
ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1))
|
||||
}
|
||||
|
||||
return collectionRateLimiters
|
||||
})
|
||||
|
||||
err := simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{
|
||||
1: {
|
||||
// collection limiter
|
||||
Limiter: &proxypb.Limiter{
|
||||
Rates: getZeroRates(),
|
||||
},
|
||||
Children: make(map[int64]*proxypb.LimiterNode),
|
||||
},
|
||||
}))
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
for i := 0; i < 100; i++ {
|
||||
ok, _ := collectionRateLimiters.Limit(internalpb.RateType(rt), 1)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
}
|
||||
|
||||
err = simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{
|
||||
1: {
|
||||
// collection limiter
|
||||
Limiter: &proxypb.Limiter{
|
||||
States: []milvuspb.QuotaState{milvuspb.QuotaState_DenyToRead, milvuspb.QuotaState_DenyToWrite},
|
||||
Codes: []commonpb.ErrorCode{commonpb.ErrorCode_DiskQuotaExhausted, commonpb.ErrorCode_DiskQuotaExhausted},
|
||||
},
|
||||
Children: make(map[int64]*proxypb.LimiterNode),
|
||||
},
|
||||
}))
|
||||
|
||||
collectionRateLimiter := simpleLimiter.rateLimiter.GetCollectionLimiters(0, 1)
|
||||
assert.NotNil(t, collectionRateLimiter)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collectionRateLimiter.GetQuotaStates().Len(), 2)
|
||||
|
||||
err = simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{
|
||||
1: {
|
||||
// collection limiter
|
||||
Limiter: &proxypb.Limiter{
|
||||
States: []milvuspb.QuotaState{},
|
||||
},
|
||||
Children: make(map[int64]*proxypb.LimiterNode),
|
||||
},
|
||||
}))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collectionRateLimiter.GetQuotaStates().Len(), 0)
|
||||
})
|
||||
|
||||
t.Run("test get error code", func(t *testing.T) {
|
||||
simpleLimiter := NewSimpleLimiter()
|
||||
|
||||
collectionRateLimiters := simpleLimiter.rateLimiter.GetOrCreateCollectionLimiters(0, int64(1), newDatabaseLimiter,
|
||||
func() *rlinternal.RateLimiterNode {
|
||||
collectionRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Cluster)
|
||||
for _, rt := range internalpb.RateType_value {
|
||||
collectionRateLimiters.GetLimiters().Insert(internalpb.RateType(rt),
|
||||
ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1))
|
||||
}
|
||||
|
||||
return collectionRateLimiters
|
||||
})
|
||||
|
||||
err := simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{
|
||||
1: {
|
||||
// collection limiter
|
||||
Limiter: &proxypb.Limiter{
|
||||
Rates: getZeroRates(),
|
||||
States: []milvuspb.QuotaState{
|
||||
milvuspb.QuotaState_DenyToWrite,
|
||||
milvuspb.QuotaState_DenyToRead,
|
||||
},
|
||||
Codes: []commonpb.ErrorCode{
|
||||
commonpb.ErrorCode_DiskQuotaExhausted,
|
||||
commonpb.ErrorCode_ForceDeny,
|
||||
},
|
||||
},
|
||||
Children: make(map[int64]*proxypb.LimiterNode),
|
||||
},
|
||||
}))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Error(t, collectionRateLimiters.GetQuotaExceededError(internalpb.RateType_DQLQuery))
|
||||
assert.Error(t, collectionRateLimiters.GetQuotaExceededError(internalpb.RateType_DMLInsert))
|
||||
})
|
||||
|
||||
t.Run("tests refresh rate by config", func(t *testing.T) {
|
||||
simpleLimiter := NewSimpleLimiter()
|
||||
clusterRateLimiter := simpleLimiter.rateLimiter.GetRootLimiters()
|
||||
etcdCli, _ := etcd.GetEtcdClient(
|
||||
Params.EtcdCfg.UseEmbedEtcd.GetAsBool(),
|
||||
Params.EtcdCfg.EtcdUseSSL.GetAsBool(),
|
||||
Params.EtcdCfg.Endpoints.GetAsStrings(),
|
||||
Params.EtcdCfg.EtcdTLSCert.GetValue(),
|
||||
Params.EtcdCfg.EtcdTLSKey.GetValue(),
|
||||
Params.EtcdCfg.EtcdTLSCACert.GetValue(),
|
||||
Params.EtcdCfg.EtcdTLSMinVersion.GetValue())
|
||||
|
||||
Params.Save(Params.QuotaConfig.DDLLimitEnabled.Key, "true")
|
||||
defer Params.Reset(Params.QuotaConfig.DDLLimitEnabled.Key)
|
||||
Params.Save(Params.QuotaConfig.DMLLimitEnabled.Key, "true")
|
||||
defer Params.Reset(Params.QuotaConfig.DMLLimitEnabled.Key)
|
||||
ctx := context.Background()
|
||||
// use a two-decimal value to avoid floating-point precision issues in the comparison below
|
||||
r := rand.Float64()
|
||||
newRate := fmt.Sprintf("%.2f", r)
|
||||
etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/ddl/collectionRate", newRate)
|
||||
defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/ddl/collectionRate")
|
||||
etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/ddl/partitionRate", "invalid")
|
||||
defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/ddl/partitionRate")
|
||||
etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/dml/insertRate/max", "8")
|
||||
defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/dml/insertRate/max")
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
limit, _ := clusterRateLimiter.GetLimiters().Get(internalpb.RateType_DDLCollection)
|
||||
return math.Abs(r-float64(limit.Limit())) < 0.01
|
||||
}, 10*time.Second, 1*time.Second)
|
||||
|
||||
limit, _ := clusterRateLimiter.GetLimiters().Get(internalpb.RateType_DDLPartition)
|
||||
assert.Equal(t, "+inf", limit.Limit().String())
|
||||
|
||||
limit, _ = clusterRateLimiter.GetLimiters().Get(internalpb.RateType_DMLInsert)
|
||||
assert.True(t, math.Abs(8*1024*1024-float64(limit.Limit())) < 0.01)
|
||||
})
|
||||
}
|
|
@ -59,7 +59,7 @@ func ConstructLabel(subs ...string) string {
|
|||
|
||||
func init() {
|
||||
var err error
|
||||
Rate, err = ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity)
|
||||
Rate, err = ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity, false)
|
||||
if err != nil {
|
||||
log.Fatal("failed to initialize querynode rate collector", zap.Error(err))
|
||||
}
|
||||
|
|
|
@ -234,6 +234,11 @@ func (b *ServerBroker) BroadcastAlteredCollection(ctx context.Context, req *milv
|
|||
return err
|
||||
}
|
||||
|
||||
db, err := b.s.meta.GetDatabaseByName(ctx, req.GetDbName(), typeutil.MaxTimestamp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
partitionIDs := make([]int64, len(colMeta.Partitions))
|
||||
for _, p := range colMeta.Partitions {
|
||||
partitionIDs = append(partitionIDs, p.PartitionID)
|
||||
|
@ -249,6 +254,7 @@ func (b *ServerBroker) BroadcastAlteredCollection(ctx context.Context, req *milv
|
|||
PartitionIDs: partitionIDs,
|
||||
StartPositions: colMeta.StartPositions,
|
||||
Properties: req.GetProperties(),
|
||||
DbID: db.ID,
|
||||
}
|
||||
|
||||
resp, err := b.s.dataCoord.BroadcastAlteredCollection(ctx, dcReq)
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/metastore/model"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
|
@ -239,6 +240,7 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) {
|
|||
mock.Anything,
|
||||
mock.Anything,
|
||||
).Return(collMeta, nil)
|
||||
mockGetDatabase(meta)
|
||||
c.meta = meta
|
||||
b := newServerBroker(c)
|
||||
ctx := context.Background()
|
||||
|
@ -256,6 +258,7 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) {
|
|||
mock.Anything,
|
||||
mock.Anything,
|
||||
).Return(collMeta, nil)
|
||||
mockGetDatabase(meta)
|
||||
c.meta = meta
|
||||
b := newServerBroker(c)
|
||||
ctx := context.Background()
|
||||
|
@ -273,6 +276,7 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) {
|
|||
mock.Anything,
|
||||
mock.Anything,
|
||||
).Return(collMeta, nil)
|
||||
mockGetDatabase(meta)
|
||||
c.meta = meta
|
||||
b := newServerBroker(c)
|
||||
ctx := context.Background()
|
||||
|
@ -327,3 +331,11 @@ func TestServerBroker_GcConfirm(t *testing.T) {
|
|||
assert.True(t, broker.GcConfirm(context.Background(), 100, 10000))
|
||||
})
|
||||
}
|
||||
|
||||
func mockGetDatabase(meta *mockrootcoord.IMetaTable) {
|
||||
db := model.NewDatabase(1, "default", pb.DatabaseState_DatabaseCreated)
|
||||
meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(db, nil).Maybe()
|
||||
meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(db, nil).Maybe()
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@ func (t *describeCollectionTask) Execute(ctx context.Context) (err error) {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
aliases := t.core.meta.ListAliasesByID(coll.CollectionID)
|
||||
db, err := t.core.meta.GetDatabaseByID(ctx, coll.DBID, t.GetTs())
|
||||
if err != nil {
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package rootcoord
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
// describeDBTask is the task that serves DescribeDatabase requests
|
||||
type describeDBTask struct {
|
||||
baseTask
|
||||
Req *rootcoordpb.DescribeDatabaseRequest
|
||||
Rsp *rootcoordpb.DescribeDatabaseResponse
|
||||
allowUnavailable bool
|
||||
}
|
||||
|
||||
func (t *describeDBTask) Prepare(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute runs the describe database task
|
||||
func (t *describeDBTask) Execute(ctx context.Context) (err error) {
|
||||
db, err := t.core.meta.GetDatabaseByName(ctx, t.Req.GetDbName(), typeutil.MaxTimestamp)
|
||||
if err != nil {
|
||||
t.Rsp = &rootcoordpb.DescribeDatabaseResponse{
|
||||
Status: merr.Status(err),
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
t.Rsp = &rootcoordpb.DescribeDatabaseResponse{
|
||||
Status: merr.Success(),
|
||||
DbID: db.ID,
|
||||
DbName: db.Name,
|
||||
CreatedTimestamp: db.CreatedTime,
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package rootcoord
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/metastore/model"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
|
||||
"github.com/milvus-io/milvus/pkg/util"
|
||||
)
|
||||
|
||||
func Test_describeDatabaseTask_Execute(t *testing.T) {
|
||||
t.Run("failed to get database by name", func(t *testing.T) {
|
||||
core := newTestCore(withInvalidMeta())
|
||||
task := &describeDBTask{
|
||||
baseTask: newBaseTask(context.Background(), core),
|
||||
Req: &rootcoordpb.DescribeDatabaseRequest{
|
||||
DbName: "testDB",
|
||||
},
|
||||
}
|
||||
err := task.Execute(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.NotNil(t, task.Rsp)
|
||||
assert.NotNil(t, task.Rsp.Status)
|
||||
})
|
||||
|
||||
t.Run("describe with empty database name", func(t *testing.T) {
|
||||
meta := mockrootcoord.NewIMetaTable(t)
|
||||
meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(model.NewDefaultDatabase(), nil)
|
||||
core := newTestCore(withMeta(meta))
|
||||
|
||||
task := &describeDBTask{
|
||||
baseTask: newBaseTask(context.Background(), core),
|
||||
Req: &rootcoordpb.DescribeDatabaseRequest{},
|
||||
}
|
||||
err := task.Execute(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, task.Rsp)
|
||||
assert.Equal(t, task.Rsp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
|
||||
assert.Equal(t, util.DefaultDBName, task.Rsp.GetDbName())
|
||||
assert.Equal(t, util.DefaultDBID, task.Rsp.GetDbID())
|
||||
})
|
||||
|
||||
t.Run("describe with specified database name", func(t *testing.T) {
|
||||
meta := mockrootcoord.NewIMetaTable(t)
|
||||
meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(&model.Database{
|
||||
Name: "db1",
|
||||
ID: 100,
|
||||
CreatedTime: 1,
|
||||
}, nil)
|
||||
core := newTestCore(withMeta(meta))
|
||||
|
||||
task := &describeDBTask{
|
||||
baseTask: newBaseTask(context.Background(), core),
|
||||
Req: &rootcoordpb.DescribeDatabaseRequest{DbName: "db1"},
|
||||
}
|
||||
err := task.Execute(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, task.Rsp)
|
||||
assert.Equal(t, task.Rsp.GetStatus().GetCode(), int32(commonpb.ErrorCode_Success))
|
||||
assert.Equal(t, "db1", task.Rsp.GetDbName())
|
||||
assert.Equal(t, int64(100), task.Rsp.GetDbID())
|
||||
assert.Equal(t, uint64(1), task.Rsp.GetCreatedTimestamp())
|
||||
})
|
||||
}
|
|
@ -55,6 +55,7 @@ type IMetaTable interface {
|
|||
RemoveCollection(ctx context.Context, collectionID UniqueID, ts Timestamp) error
|
||||
GetCollectionByName(ctx context.Context, dbName string, collectionName string, ts Timestamp) (*model.Collection, error)
|
||||
GetCollectionByID(ctx context.Context, dbName string, collectionID UniqueID, ts Timestamp, allowUnavailable bool) (*model.Collection, error)
|
||||
GetCollectionByIDWithMaxTs(ctx context.Context, collectionID UniqueID) (*model.Collection, error)
|
||||
ListCollections(ctx context.Context, dbName string, ts Timestamp, onlyAvail bool) ([]*model.Collection, error)
|
||||
ListAllAvailCollections(ctx context.Context) map[int64][]int64
|
||||
ListCollectionPhysicalChannels() map[typeutil.UniqueID][]string
|
||||
|
@ -362,7 +363,7 @@ func (mt *MetaTable) getDatabaseByNameInternal(_ context.Context, dbName string,
|
|||
|
||||
db, ok := mt.dbName2Meta[dbName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("database:%s not found", dbName)
|
||||
return nil, merr.WrapErrDatabaseNotFound(dbName)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
|
@ -519,12 +520,12 @@ func filterUnavailable(coll *model.Collection) *model.Collection {
|
|||
}
|
||||
|
||||
// getLatestCollectionByIDInternal should be called with ts = typeutil.MaxTimestamp
|
||||
func (mt *MetaTable) getLatestCollectionByIDInternal(ctx context.Context, collectionID UniqueID, allowAvailable bool) (*model.Collection, error) {
|
||||
func (mt *MetaTable) getLatestCollectionByIDInternal(ctx context.Context, collectionID UniqueID, allowUnavailable bool) (*model.Collection, error) {
|
||||
coll, ok := mt.collID2Meta[collectionID]
|
||||
if !ok || coll == nil {
|
||||
return nil, merr.WrapErrCollectionNotFound(collectionID)
|
||||
}
|
||||
if allowAvailable {
|
||||
if allowUnavailable {
|
||||
return coll.Clone(), nil
|
||||
}
|
||||
if !coll.Available() {
|
||||
|
@ -623,6 +624,11 @@ func (mt *MetaTable) GetCollectionByID(ctx context.Context, dbName string, colle
|
|||
return mt.getCollectionByIDInternal(ctx, dbName, collectionID, ts, allowUnavailable)
|
||||
}
|
||||
|
||||
// GetCollectionByIDWithMaxTs gets a collection; dbName can be ignored because ts is the max timestamp
|
||||
func (mt *MetaTable) GetCollectionByIDWithMaxTs(ctx context.Context, collectionID UniqueID) (*model.Collection, error) {
|
||||
return mt.GetCollectionByID(ctx, "", collectionID, typeutil.MaxTimestamp, false)
|
||||
}
|
||||
|
||||
func (mt *MetaTable) ListAllAvailCollections(ctx context.Context) map[int64][]int64 {
|
||||
mt.ddLock.RLock()
|
||||
defer mt.ddLock.RUnlock()
|
||||
|
|
|
@ -95,6 +95,11 @@ type mockMetaTable struct {
|
|||
DropGrantFunc func(tenant string, role *milvuspb.RoleEntity) error
|
||||
ListPolicyFunc func(tenant string) ([]string, error)
|
||||
ListUserRoleFunc func(tenant string) ([]string, error)
|
||||
DescribeDatabaseFunc func(ctx context.Context, dbName string) (*model.Database, error)
|
||||
}
|
||||
|
||||
func (m mockMetaTable) GetDatabaseByName(ctx context.Context, dbName string, ts Timestamp) (*model.Database, error) {
|
||||
return m.DescribeDatabaseFunc(ctx, dbName)
|
||||
}
|
||||
|
||||
func (m mockMetaTable) ListDatabases(ctx context.Context, ts typeutil.Timestamp) ([]*model.Database, error) {
|
||||
|
@ -516,6 +521,9 @@ func withInvalidMeta() Opt {
|
|||
meta.ListAliasesFunc = func(ctx context.Context, dbName, collectionName string, ts Timestamp) ([]string, error) {
|
||||
return nil, errors.New("error mock ListAliases")
|
||||
}
|
||||
meta.DescribeDatabaseFunc = func(ctx context.Context, dbName string) (*model.Database, error) {
|
||||
return nil, errors.New("error mock DescribeDatabase")
|
||||
}
|
||||
return withMeta(meta)
|
||||
}
|
||||
|
||||
|
|
|
@ -843,6 +843,61 @@ func (_c *IMetaTable_GetCollectionByID_Call) RunAndReturn(run func(context.Conte
|
|||
return _c
|
||||
}
|
||||
|
||||
// GetCollectionByIDWithMaxTs provides a mock function with given fields: ctx, collectionID
|
||||
func (_m *IMetaTable) GetCollectionByIDWithMaxTs(ctx context.Context, collectionID int64) (*model.Collection, error) {
|
||||
ret := _m.Called(ctx, collectionID)
|
||||
|
||||
var r0 *model.Collection
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int64) (*model.Collection, error)); ok {
|
||||
return rf(ctx, collectionID)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int64) *model.Collection); ok {
|
||||
r0 = rf(ctx, collectionID)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*model.Collection)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok {
|
||||
r1 = rf(ctx, collectionID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// IMetaTable_GetCollectionByIDWithMaxTs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionByIDWithMaxTs'
|
||||
type IMetaTable_GetCollectionByIDWithMaxTs_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetCollectionByIDWithMaxTs is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - collectionID int64
|
||||
func (_e *IMetaTable_Expecter) GetCollectionByIDWithMaxTs(ctx interface{}, collectionID interface{}) *IMetaTable_GetCollectionByIDWithMaxTs_Call {
|
||||
return &IMetaTable_GetCollectionByIDWithMaxTs_Call{Call: _e.mock.On("GetCollectionByIDWithMaxTs", ctx, collectionID)}
|
||||
}
|
||||
|
||||
func (_c *IMetaTable_GetCollectionByIDWithMaxTs_Call) Run(run func(ctx context.Context, collectionID int64)) *IMetaTable_GetCollectionByIDWithMaxTs_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(int64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *IMetaTable_GetCollectionByIDWithMaxTs_Call) Return(_a0 *model.Collection, _a1 error) *IMetaTable_GetCollectionByIDWithMaxTs_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *IMetaTable_GetCollectionByIDWithMaxTs_Call) RunAndReturn(run func(context.Context, int64) (*model.Collection, error)) *IMetaTable_GetCollectionByIDWithMaxTs_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetCollectionByName provides a mock function with given fields: ctx, dbName, collectionName, ts
|
||||
func (_m *IMetaTable) GetCollectionByName(ctx context.Context, dbName string, collectionName string, ts uint64) (*model.Collection, error) {
|
||||
ret := _m.Called(ctx, dbName, collectionName, ts)
|
||||
|
|
File diff suppressed because it is too large
File diff suppressed because it is too large
|
@ -2669,6 +2669,40 @@ func (c *Core) RenameCollection(ctx context.Context, req *milvuspb.RenameCollect
|
|||
return merr.Success(), nil
|
||||
}
|
||||
|
||||
func (c *Core) DescribeDatabase(ctx context.Context, req *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error) {
|
||||
if err := merr.CheckHealthy(c.GetStateCode()); err != nil {
|
||||
return &rootcoordpb.DescribeDatabaseResponse{Status: merr.Status(err)}, nil
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(zap.String("dbName", req.GetDbName()))
|
||||
log.Info("received request to describe database ")
|
||||
|
||||
metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeDatabase", metrics.TotalLabel).Inc()
|
||||
tr := timerecord.NewTimeRecorder("DescribeDatabase")
|
||||
t := &describeDBTask{
|
||||
baseTask: newBaseTask(ctx, c),
|
||||
Req: req,
|
||||
}
|
||||
|
||||
if err := c.scheduler.AddTask(t); err != nil {
|
||||
log.Warn("failed to enqueue request to describe database", zap.Error(err))
|
||||
metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeDatabase", metrics.FailLabel).Inc()
|
||||
return &rootcoordpb.DescribeDatabaseResponse{Status: merr.Status(err)}, nil
|
||||
}
|
||||
|
||||
if err := t.WaitToFinish(); err != nil {
|
||||
log.Warn("failed to describe database", zap.Uint64("ts", t.GetTs()), zap.Error(err))
|
||||
metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeDatabase", metrics.FailLabel).Inc()
|
||||
return &rootcoordpb.DescribeDatabaseResponse{Status: merr.Status(err)}, nil
|
||||
}
|
||||
|
||||
metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeDatabase", metrics.SuccessLabel).Inc()
|
||||
metrics.RootCoordDDLReqLatency.WithLabelValues("DescribeDatabase").Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
|
||||
log.Info("done to describe database", zap.Uint64("ts", t.GetTs()))
|
||||
return t.Rsp, nil
|
||||
}
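From a client's point of view, the new RPC would be used roughly like this (sketch only; the client variable and database name are placeholders):

	resp, err := rootCoordClient.DescribeDatabase(ctx, &rootcoordpb.DescribeDatabaseRequest{DbName: "default"})
	if merr.CheckRPCCall(resp.GetStatus(), err) == nil {
		// DbID, DbName and CreatedTimestamp are filled in by describeDBTask
		fmt.Println(resp.GetDbID(), resp.GetDbName(), resp.GetCreatedTimestamp())
	}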
|
||||
|
||||
func (c *Core) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) {
|
||||
if err := merr.CheckHealthy(c.GetStateCode()); err != nil {
|
||||
return &milvuspb.CheckHealthResponse{
|
||||
|
|
|
@ -1443,6 +1443,43 @@ func TestRootCoord_CheckHealth(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestRootCoord_DescribeDatabase(t *testing.T) {
|
||||
t.Run("not healthy", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
c := newTestCore(withAbnormalCode())
|
||||
resp, err := c.DescribeDatabase(ctx, &rootcoordpb.DescribeDatabaseRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Error(t, merr.CheckRPCCall(resp.GetStatus(), nil))
|
||||
})
|
||||
|
||||
t.Run("add task failed", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
c := newTestCore(withHealthyCode(),
|
||||
withInvalidScheduler())
|
||||
resp, err := c.DescribeDatabase(ctx, &rootcoordpb.DescribeDatabaseRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Error(t, merr.CheckRPCCall(resp.GetStatus(), nil))
|
||||
})
|
||||
|
||||
t.Run("execute task failed", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
c := newTestCore(withHealthyCode(),
|
||||
withTaskFailScheduler())
|
||||
resp, err := c.DescribeDatabase(ctx, &rootcoordpb.DescribeDatabaseRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Error(t, merr.CheckRPCCall(resp.GetStatus(), nil))
|
||||
})
|
||||
|
||||
t.Run("run ok", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
c := newTestCore(withHealthyCode(),
|
||||
withValidScheduler())
|
||||
resp, err := c.DescribeDatabase(ctx, &rootcoordpb.DescribeDatabaseRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, merr.CheckRPCCall(resp.GetStatus(), nil))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRootCoord_RBACError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
c := newTestCore(withHealthyCode(), withInvalidMeta())
|
||||
|
|
|
@ -138,13 +138,16 @@ func getCollectionRateLimitConfigDefaultValue(configKey string) float64 {
|
|||
return Params.QuotaConfig.DQLMinSearchRatePerCollection.GetAsFloat()
|
||||
case common.CollectionDiskQuotaKey:
|
||||
return Params.QuotaConfig.DiskQuotaPerCollection.GetAsFloat()
|
||||
|
||||
default:
|
||||
return float64(0)
|
||||
}
|
||||
}
|
||||
|
||||
func getCollectionRateLimitConfig(properties map[string]string, configKey string) float64 {
|
||||
return getRateLimitConfig(properties, configKey, getCollectionRateLimitConfigDefaultValue(configKey))
|
||||
}
|
||||
|
||||
func getRateLimitConfig(properties map[string]string, configKey string, configValue float64) float64 {
|
||||
megaBytes2Bytes := func(v float64) float64 {
|
||||
return v * 1024.0 * 1024.0
|
||||
}
|
||||
|
@ -189,15 +192,15 @@ func getCollectionRateLimitConfig(properties map[string]string, configKey string
|
|||
log.Warn("invalid configuration for collection dml rate",
|
||||
zap.String("config item", configKey),
|
||||
zap.String("config value", v))
|
||||
return getCollectionRateLimitConfigDefaultValue(configKey)
|
||||
return configValue
|
||||
}
|
||||
|
||||
rateInBytes := toBytesIfNecessary(rate)
|
||||
if rateInBytes < 0 {
|
||||
return getCollectionRateLimitConfigDefaultValue(configKey)
|
||||
return configValue
|
||||
}
|
||||
return rateInBytes
|
||||
}
|
||||
|
||||
return getCollectionRateLimitConfigDefaultValue(configKey)
|
||||
return configValue
|
||||
}
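With this change a per-collection property, when present and valid, overrides the passed-in default value; a short sketch mirroring the test below:

	// returns 1 because the collection property overrides the default of 100
	v := getRateLimitConfig(map[string]string{common.CollectionQueryRateMaxKey: "1"}, common.CollectionQueryRateMaxKey, 100)
	_ = v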
|
||||
|
|
|
@ -292,3 +292,27 @@ func Test_getCollectionRateLimitConfig(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRateLimitConfigErr(t *testing.T) {
|
||||
key := common.CollectionQueryRateMaxKey
|
||||
t.Run("negative value", func(t *testing.T) {
|
||||
v := getRateLimitConfig(map[string]string{
|
||||
key: "-1",
|
||||
}, key, 1)
|
||||
assert.EqualValues(t, 1, v)
|
||||
})
|
||||
|
||||
t.Run("valid value", func(t *testing.T) {
|
||||
v := getRateLimitConfig(map[string]string{
|
||||
key: "1",
|
||||
}, key, 100)
|
||||
assert.EqualValues(t, 1, v)
|
||||
})
|
||||
|
||||
t.Run("not exist value", func(t *testing.T) {
|
||||
v := getRateLimitConfig(map[string]string{
|
||||
key: "1",
|
||||
}, "b", 100)
|
||||
assert.EqualValues(t, 100, v)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ import (
|
|||
// If Limit function return true, the request will be rejected.
|
||||
// Otherwise, the request will pass. Limit also returns limit of limiter.
|
||||
type Limiter interface {
|
||||
Check(collectionIDs []int64, rt internalpb.RateType, n int) error
|
||||
Check(dbID int64, collectionIDToPartIDs map[int64][]int64, rt internalpb.RateType, n int) error
|
||||
}
|
||||
|
||||
// Component is the interface all services implement
|
||||
|
|
|
@ -37,6 +37,10 @@ type GrpcRootCoordClient struct {
|
|||
Err error
|
||||
}
|
||||
|
||||
func (m *GrpcRootCoordClient) DescribeDatabase(ctx context.Context, in *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error) {
|
||||
return &rootcoordpb.DescribeDatabaseResponse{}, m.Err
|
||||
}
|
||||
|
||||
func (m *GrpcRootCoordClient) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.Err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
/*
|
||||
* Licensed to the LF AI & Data foundation under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package quota
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
var (
|
||||
initOnce sync.Once
|
||||
limitConfigMap map[internalpb.RateScope]map[internalpb.RateType]*paramtable.ParamItem
|
||||
)
|
||||
|
||||
func initLimitConfigMaps() {
|
||||
initOnce.Do(func() {
|
||||
quotaConfig := ¶mtable.Get().QuotaConfig
|
||||
limitConfigMap = map[internalpb.RateScope]map[internalpb.RateType]*paramtable.ParamItem{
|
||||
internalpb.RateScope_Cluster: {
|
||||
internalpb.RateType_DDLCollection: "aConfig.DDLCollectionRate,
|
||||
internalpb.RateType_DDLPartition: "aConfig.DDLPartitionRate,
|
||||
internalpb.RateType_DDLIndex: "aConfig.MaxIndexRate,
|
||||
internalpb.RateType_DDLFlush: "aConfig.MaxFlushRate,
|
||||
internalpb.RateType_DDLCompaction: "aConfig.MaxCompactionRate,
|
||||
internalpb.RateType_DMLInsert: "aConfig.DMLMaxInsertRate,
|
||||
internalpb.RateType_DMLUpsert: "aConfig.DMLMaxUpsertRate,
|
||||
internalpb.RateType_DMLDelete: "aConfig.DMLMaxDeleteRate,
|
||||
internalpb.RateType_DMLBulkLoad: "aConfig.DMLMaxBulkLoadRate,
|
||||
internalpb.RateType_DQLSearch: "aConfig.DQLMaxSearchRate,
|
||||
internalpb.RateType_DQLQuery: "aConfig.DQLMaxQueryRate,
|
||||
},
|
||||
internalpb.RateScope_Database: {
|
||||
internalpb.RateType_DDLCollection: "aConfig.DDLCollectionRatePerDB,
|
||||
internalpb.RateType_DDLPartition: "aConfig.DDLPartitionRatePerDB,
|
||||
internalpb.RateType_DDLIndex: "aConfig.MaxIndexRatePerDB,
|
||||
internalpb.RateType_DDLFlush: "aConfig.MaxFlushRatePerDB,
|
||||
internalpb.RateType_DDLCompaction: "aConfig.MaxCompactionRatePerDB,
|
||||
internalpb.RateType_DMLInsert: "aConfig.DMLMaxInsertRatePerDB,
|
||||
internalpb.RateType_DMLUpsert: "aConfig.DMLMaxUpsertRatePerDB,
|
||||
internalpb.RateType_DMLDelete: "aConfig.DMLMaxDeleteRatePerDB,
|
||||
internalpb.RateType_DMLBulkLoad: "aConfig.DMLMaxBulkLoadRatePerDB,
|
||||
internalpb.RateType_DQLSearch: "aConfig.DQLMaxSearchRatePerDB,
|
||||
internalpb.RateType_DQLQuery: "aConfig.DQLMaxQueryRatePerDB,
|
||||
},
|
||||
internalpb.RateScope_Collection: {
|
||||
internalpb.RateType_DMLInsert: "aConfig.DMLMaxInsertRatePerCollection,
|
||||
internalpb.RateType_DMLUpsert: "aConfig.DMLMaxUpsertRatePerCollection,
|
||||
internalpb.RateType_DMLDelete: "aConfig.DMLMaxDeleteRatePerCollection,
|
||||
internalpb.RateType_DMLBulkLoad: "aConfig.DMLMaxBulkLoadRatePerCollection,
|
||||
internalpb.RateType_DQLSearch: "aConfig.DQLMaxSearchRatePerCollection,
|
||||
internalpb.RateType_DQLQuery: "aConfig.DQLMaxQueryRatePerCollection,
|
||||
internalpb.RateType_DDLFlush: "aConfig.MaxFlushRatePerCollection,
|
||||
},
|
||||
internalpb.RateScope_Partition: {
|
||||
internalpb.RateType_DMLInsert: "aConfig.DMLMaxInsertRatePerPartition,
|
||||
internalpb.RateType_DMLUpsert: "aConfig.DMLMaxUpsertRatePerPartition,
|
||||
internalpb.RateType_DMLDelete: "aConfig.DMLMaxDeleteRatePerPartition,
|
||||
internalpb.RateType_DMLBulkLoad: "aConfig.DMLMaxBulkLoadRatePerPartition,
|
||||
internalpb.RateType_DQLSearch: "aConfig.DQLMaxSearchRatePerPartition,
|
||||
internalpb.RateType_DQLQuery: "aConfig.DQLMaxQueryRatePerPartition,
|
||||
},
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func GetQuotaConfigMap(scope internalpb.RateScope) map[internalpb.RateType]*paramtable.ParamItem {
|
||||
initLimitConfigMaps()
|
||||
configMap, ok := limitConfigMap[scope]
|
||||
if !ok {
|
||||
log.Warn("Unknown rate scope", zap.Any("scope", scope))
|
||||
return make(map[internalpb.RateType]*paramtable.ParamItem)
|
||||
}
|
||||
return configMap
|
||||
}
|
||||
|
||||
func GetQuotaValue(scope internalpb.RateScope, rateType internalpb.RateType, params *paramtable.ComponentParam) float64 {
|
||||
configMap := GetQuotaConfigMap(scope)
|
||||
config, ok := configMap[rateType]
|
||||
if !ok {
|
||||
log.Warn("Unknown rate type", zap.Any("rateType", rateType))
|
||||
return math.MaxFloat64
|
||||
}
|
||||
return config.GetAsFloat()
|
||||
}
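A quick illustration of reading one default quota value from another package (sketch only; assumes paramtable has already been initialized):

	paramtable.Init()
	limit := quota.GetQuotaValue(internalpb.RateScope_Collection, internalpb.RateType_DMLInsert, paramtable.Get())
	// DML rates are configured in MB/s and returned here in bytes/s
	fmt.Println(limit)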
|
|
@ -0,0 +1,91 @@
|
|||
/*
|
||||
* Licensed to the LF AI & Data foundation under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package quota
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
func TestGetQuotaConfigMap(t *testing.T) {
|
||||
paramtable.Init()
|
||||
{
|
||||
m := GetQuotaConfigMap(internalpb.RateScope_Cluster)
|
||||
assert.Equal(t, 11, len(m))
|
||||
}
|
||||
{
|
||||
m := GetQuotaConfigMap(internalpb.RateScope_Database)
|
||||
assert.Equal(t, 11, len(m))
|
||||
}
|
||||
{
|
||||
m := GetQuotaConfigMap(internalpb.RateScope_Collection)
|
||||
assert.Equal(t, 7, len(m))
|
||||
}
|
||||
{
|
||||
m := GetQuotaConfigMap(internalpb.RateScope_Partition)
|
||||
assert.Equal(t, 6, len(m))
|
||||
}
|
||||
{
|
||||
m := GetQuotaConfigMap(internalpb.RateScope(1000))
|
||||
assert.Equal(t, 0, len(m))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetQuotaValue(t *testing.T) {
|
||||
paramtable.Init()
|
||||
param := paramtable.Get()
|
||||
param.Save(param.QuotaConfig.DDLLimitEnabled.Key, "true")
|
||||
defer param.Reset(param.QuotaConfig.DDLLimitEnabled.Key)
|
||||
param.Save(param.QuotaConfig.DMLLimitEnabled.Key, "true")
|
||||
defer param.Reset(param.QuotaConfig.DMLLimitEnabled.Key)
|
||||
|
||||
t.Run("cluster", func(t *testing.T) {
|
||||
param.Save(param.QuotaConfig.DDLCollectionRate.Key, "10")
|
||||
defer param.Reset(param.QuotaConfig.DDLCollectionRate.Key)
|
||||
v := GetQuotaValue(internalpb.RateScope_Cluster, internalpb.RateType_DDLCollection, param)
|
||||
assert.EqualValues(t, 10, v)
|
||||
})
|
||||
t.Run("database", func(t *testing.T) {
|
||||
param.Save(param.QuotaConfig.DDLCollectionRatePerDB.Key, "10")
|
||||
defer param.Reset(param.QuotaConfig.DDLCollectionRatePerDB.Key)
|
||||
v := GetQuotaValue(internalpb.RateScope_Database, internalpb.RateType_DDLCollection, param)
|
||||
assert.EqualValues(t, 10, v)
|
||||
})
|
||||
t.Run("collection", func(t *testing.T) {
|
||||
param.Save(param.QuotaConfig.DMLMaxInsertRatePerCollection.Key, "10")
|
||||
defer param.Reset(param.QuotaConfig.DMLMaxInsertRatePerCollection.Key)
|
||||
v := GetQuotaValue(internalpb.RateScope_Collection, internalpb.RateType_DMLInsert, param)
|
||||
assert.EqualValues(t, 10*1024*1024, v)
|
||||
})
|
||||
t.Run("partition", func(t *testing.T) {
|
||||
param.Save(param.QuotaConfig.DMLMaxInsertRatePerPartition.Key, "10")
|
||||
defer param.Reset(param.QuotaConfig.DMLMaxInsertRatePerPartition.Key)
|
||||
v := GetQuotaValue(internalpb.RateScope_Partition, internalpb.RateType_DMLInsert, param)
|
||||
assert.EqualValues(t, 10*1024*1024, v)
|
||||
})
|
||||
t.Run("unknown", func(t *testing.T) {
|
||||
v := GetQuotaValue(internalpb.RateScope(1000), internalpb.RateType(1000), param)
|
||||
assert.EqualValues(t, math.MaxFloat64, v)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,336 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ratelimitutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/proxypb"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type RateLimiterNode struct {
|
||||
limiters *typeutil.ConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter]
|
||||
quotaStates *typeutil.ConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]
|
||||
level internalpb.RateScope
|
||||
|
||||
// db id, collection id or partition id; the id is 0 at the cluster level
|
||||
id int64
|
||||
|
||||
// children will be databases if current level is cluster
|
||||
// children will be collections if current level is database
|
||||
// children will be partitions if current level is collection
|
||||
children *typeutil.ConcurrentMap[int64, *RateLimiterNode]
|
||||
}
|
||||
|
||||
func NewRateLimiterNode(level internalpb.RateScope) *RateLimiterNode {
|
||||
rln := &RateLimiterNode{
|
||||
limiters: typeutil.NewConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter](),
|
||||
quotaStates: typeutil.NewConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode](),
|
||||
children: typeutil.NewConcurrentMap[int64, *RateLimiterNode](),
|
||||
level: level,
|
||||
}
|
||||
return rln
|
||||
}
|
||||
|
||||
func (rln *RateLimiterNode) Level() internalpb.RateScope {
|
||||
return rln.level
|
||||
}
|
||||
|
||||
// Limit returns true if the request should be rejected;
// otherwise the request is allowed to pass.
|
||||
func (rln *RateLimiterNode) Limit(rt internalpb.RateType, n int) (bool, float64) {
|
||||
limit, ok := rln.limiters.Get(rt)
|
||||
if !ok {
|
||||
return false, -1
|
||||
}
|
||||
return !limit.AllowN(time.Now(), n), float64(limit.Limit())
|
||||
}
|
||||
|
||||
func (rln *RateLimiterNode) Cancel(rt internalpb.RateType, n int) {
|
||||
limit, ok := rln.limiters.Get(rt)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
limit.Cancel(n)
|
||||
}
|
||||
|
||||
func (rln *RateLimiterNode) Check(rt internalpb.RateType, n int) error {
|
||||
limit, rate := rln.Limit(rt, n)
|
||||
if rate == 0 {
|
||||
return rln.GetQuotaExceededError(rt)
|
||||
}
|
||||
if limit {
|
||||
return rln.GetRateLimitError(rate)
|
||||
}
|
||||
return nil
|
||||
}
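
A minimal sketch of how a caller might interpret Check's two failure modes; the helper name is illustrative and not part of this change:

// admitInsert is a hypothetical caller, not part of this change.
func admitInsert(node *RateLimiterNode, sizeInBytes int) error {
	err := node.Check(internalpb.RateType_DMLInsert, sizeInBytes)
	// merr.ErrServiceQuotaExceeded: the configured rate is 0, writes are force-denied.
	// merr.ErrServiceRateLimit: the token bucket is exhausted, the caller may retry later.
	return err
}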
|
||||
|
||||
func (rln *RateLimiterNode) GetQuotaExceededError(rt internalpb.RateType) error {
|
||||
switch rt {
|
||||
case internalpb.RateType_DMLInsert, internalpb.RateType_DMLUpsert, internalpb.RateType_DMLDelete, internalpb.RateType_DMLBulkLoad:
|
||||
if errCode, ok := rln.quotaStates.Get(milvuspb.QuotaState_DenyToWrite); ok {
|
||||
return merr.WrapErrServiceQuotaExceeded(ratelimitutil.GetQuotaErrorString(errCode))
|
||||
}
|
||||
case internalpb.RateType_DQLSearch, internalpb.RateType_DQLQuery:
|
||||
if errCode, ok := rln.quotaStates.Get(milvuspb.QuotaState_DenyToRead); ok {
|
||||
return merr.WrapErrServiceQuotaExceeded(ratelimitutil.GetQuotaErrorString(errCode))
|
||||
}
|
||||
}
|
||||
return merr.WrapErrServiceQuotaExceeded(fmt.Sprintf("rate type: %s", rt.String()))
|
||||
}
|
||||
|
||||
func (rln *RateLimiterNode) GetRateLimitError(rate float64) error {
|
||||
return merr.WrapErrServiceRateLimit(rate, "request is rejected by grpc RateLimiter middleware, please retry later")
|
||||
}
|
||||
|
||||
func TraverseRateLimiterTree(root *RateLimiterNode, fn1 func(internalpb.RateType, *ratelimitutil.Limiter) bool,
|
||||
fn2 func(node *RateLimiterNode, state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool,
|
||||
) {
|
||||
if fn1 != nil {
|
||||
root.limiters.Range(fn1)
|
||||
}
|
||||
|
||||
if fn2 != nil {
|
||||
root.quotaStates.Range(func(state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool {
|
||||
return fn2(root, state, errCode)
|
||||
})
|
||||
}
|
||||
root.GetChildren().Range(func(key int64, child *RateLimiterNode) bool {
|
||||
TraverseRateLimiterTree(child, fn1, fn2)
|
||||
return true
|
||||
})
|
||||
}
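
For illustration, a hedged sketch that uses the traversal helper to count registered limiters; the helper is not part of this change:

// countLimiters is an illustrative helper, not part of this change.
func countLimiters(root *RateLimiterNode) int {
	total := 0
	TraverseRateLimiterTree(root, func(rt internalpb.RateType, limiter *ratelimitutil.Limiter) bool {
		total++
		return true // keep ranging over the remaining limiters
	}, nil)
	return total
}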
|
||||
|
||||
func (rln *RateLimiterNode) AddChild(key int64, child *RateLimiterNode) {
|
||||
rln.children.Insert(key, child)
|
||||
}
|
||||
|
||||
func (rln *RateLimiterNode) GetChild(key int64) *RateLimiterNode {
|
||||
n, _ := rln.children.Get(key)
|
||||
return n
|
||||
}
|
||||
|
||||
func (rln *RateLimiterNode) GetChildren() *typeutil.ConcurrentMap[int64, *RateLimiterNode] {
|
||||
return rln.children
|
||||
}
|
||||
|
||||
func (rln *RateLimiterNode) GetLimiters() *typeutil.ConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter] {
|
||||
return rln.limiters
|
||||
}
|
||||
|
||||
func (rln *RateLimiterNode) SetLimiters(new *typeutil.ConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter]) {
|
||||
rln.limiters = new
|
||||
}
|
||||
|
||||
func (rln *RateLimiterNode) GetQuotaStates() *typeutil.ConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode] {
|
||||
return rln.quotaStates
|
||||
}
|
||||
|
||||
func (rln *RateLimiterNode) SetQuotaStates(new *typeutil.ConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]) {
|
||||
rln.quotaStates = new
|
||||
}
|
||||
|
||||
func (rln *RateLimiterNode) GetID() int64 {
|
||||
return rln.id
|
||||
}
|
||||
|
||||
// RateLimiterTree is implemented based on RateLimiterNode to operate multilevel rate limiters
|
||||
//
|
||||
// it contains the following four levels generally:
|
||||
//
|
||||
// -> global level
|
||||
// -> database level
|
||||
// -> collection level
|
||||
// -> partition level
|
||||
type RateLimiterTree struct {
|
||||
root *RateLimiterNode
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRateLimiterTree returns a new RateLimiterTree.
|
||||
func NewRateLimiterTree(root *RateLimiterNode) *RateLimiterTree {
|
||||
return &RateLimiterTree{root: root}
|
||||
}
|
||||
|
||||
// GetRootLimiters returns the root limiters node.
|
||||
func (m *RateLimiterTree) GetRootLimiters() *RateLimiterNode {
|
||||
return m.root
|
||||
}
|
||||
|
||||
func (m *RateLimiterTree) ClearInvalidLimiterNode(req *proxypb.LimiterNode) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
reqDBLimits := req.GetChildren()
|
||||
removeDBLimits := make([]int64, 0)
|
||||
m.GetRootLimiters().GetChildren().Range(func(key int64, _ *RateLimiterNode) bool {
|
||||
if _, ok := reqDBLimits[key]; !ok {
|
||||
removeDBLimits = append(removeDBLimits, key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
for _, dbID := range removeDBLimits {
|
||||
m.GetRootLimiters().GetChildren().Remove(dbID)
|
||||
}
|
||||
|
||||
m.GetRootLimiters().GetChildren().Range(func(dbID int64, dbNode *RateLimiterNode) bool {
|
||||
reqCollectionLimits := reqDBLimits[dbID].GetChildren()
|
||||
removeCollectionLimits := make([]int64, 0)
|
||||
dbNode.GetChildren().Range(func(key int64, _ *RateLimiterNode) bool {
|
||||
if _, ok := reqCollectionLimits[key]; !ok {
|
||||
removeCollectionLimits = append(removeCollectionLimits, key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
for _, collectionID := range removeCollectionLimits {
|
||||
dbNode.GetChildren().Remove(collectionID)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
m.GetRootLimiters().GetChildren().Range(func(dbID int64, dbNode *RateLimiterNode) bool {
|
||||
dbNode.GetChildren().Range(func(collectionID int64, collectionNode *RateLimiterNode) bool {
|
||||
reqPartitionLimits := reqDBLimits[dbID].GetChildren()[collectionID].GetChildren()
|
||||
removePartitionLimits := make([]int64, 0)
|
||||
collectionNode.GetChildren().Range(func(key int64, _ *RateLimiterNode) bool {
|
||||
if _, ok := reqPartitionLimits[key]; !ok {
|
||||
removePartitionLimits = append(removePartitionLimits, key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
for _, partitionID := range removePartitionLimits {
|
||||
collectionNode.GetChildren().Remove(partitionID)
|
||||
}
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (m *RateLimiterTree) GetDatabaseLimiters(dbID int64) *RateLimiterNode {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.root.GetChild(dbID)
|
||||
}
|
||||
|
||||
// GetOrCreateDatabaseLimiters gets the database-level limiters, creating them if they do not exist.
|
||||
func (m *RateLimiterTree) GetOrCreateDatabaseLimiters(dbID int64, newDBRateLimiter func() *RateLimiterNode) *RateLimiterNode {
|
||||
dbRateLimiters := m.GetDatabaseLimiters(dbID)
|
||||
if dbRateLimiters != nil {
|
||||
return dbRateLimiters
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if cur := m.root.GetChild(dbID); cur != nil {
|
||||
return cur
|
||||
}
|
||||
dbRateLimiters = newDBRateLimiter()
|
||||
dbRateLimiters.id = dbID
|
||||
m.root.AddChild(dbID, dbRateLimiters)
|
||||
return dbRateLimiters
|
||||
}
|
||||
|
||||
func (m *RateLimiterTree) GetCollectionLimiters(dbID, collectionID int64) *RateLimiterNode {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
dbRateLimiters := m.root.GetChild(dbID)
|
||||
|
||||
// database rate limiter not found
|
||||
if dbRateLimiters == nil {
|
||||
return nil
|
||||
}
|
||||
return dbRateLimiters.GetChild(collectionID)
|
||||
}
|
||||
|
||||
// GetOrCreateCollectionLimiters gets the collection-level limiters for all rate types, creating them if they do not exist.
// It also creates the database-level rate limiters when the db rate limiter does not exist.
|
||||
func (m *RateLimiterTree) GetOrCreateCollectionLimiters(dbID, collectionID int64,
|
||||
newDBRateLimiter func() *RateLimiterNode, newCollectionRateLimiter func() *RateLimiterNode,
|
||||
) *RateLimiterNode {
|
||||
collectionRateLimiters := m.GetCollectionLimiters(dbID, collectionID)
|
||||
if collectionRateLimiters != nil {
|
||||
return collectionRateLimiters
|
||||
}
|
||||
|
||||
dbRateLimiters := m.GetOrCreateDatabaseLimiters(dbID, newDBRateLimiter)
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if cur := dbRateLimiters.GetChild(collectionID); cur != nil {
|
||||
return cur
|
||||
}
|
||||
|
||||
collectionRateLimiters = newCollectionRateLimiter()
|
||||
collectionRateLimiters.id = collectionID
|
||||
dbRateLimiters.AddChild(collectionID, collectionRateLimiters)
|
||||
return collectionRateLimiters
|
||||
}
|
||||
|
||||
// GetPartitionLimiters checks whether rate limiters exist for the database, collection, and partition,
// and returns the corresponding partition-level rate limiter node, or nil if any level is missing.
|
||||
func (m *RateLimiterTree) GetPartitionLimiters(dbID, collectionID, partitionID int64) *RateLimiterNode {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
dbRateLimiters := m.root.GetChild(dbID)
|
||||
|
||||
// database rate limiter not found
|
||||
if dbRateLimiters == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
collectionRateLimiters := dbRateLimiters.GetChild(collectionID)
|
||||
|
||||
// collection rate limiter not found
|
||||
if collectionRateLimiters == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return collectionRateLimiters.GetChild(partitionID)
|
||||
}
|
||||
|
||||
// GetOrCreatePartitionLimiters gets the partition-level limiters for all rate types, creating them if they do not exist.
// It also creates the database-level rate limiters if the db rate limiter does not exist,
// and the collection-level rate limiters if the collection rate limiter does not exist.
|
||||
func (m *RateLimiterTree) GetOrCreatePartitionLimiters(dbID int64, collectionID int64, partitionID int64,
|
||||
newDBRateLimiter func() *RateLimiterNode, newCollectionRateLimiter func() *RateLimiterNode,
|
||||
newPartRateLimiter func() *RateLimiterNode,
|
||||
) *RateLimiterNode {
|
||||
partRateLimiters := m.GetPartitionLimiters(dbID, collectionID, partitionID)
|
||||
if partRateLimiters != nil {
|
||||
return partRateLimiters
|
||||
}
|
||||
|
||||
collectionRateLimiters := m.GetOrCreateCollectionLimiters(dbID, collectionID, newDBRateLimiter, newCollectionRateLimiter)
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if cur := collectionRateLimiters.GetChild(partitionID); cur != nil {
|
||||
return cur
|
||||
}
|
||||
|
||||
partRateLimiters = newPartRateLimiter()
|
||||
partRateLimiters.id = partitionID
|
||||
collectionRateLimiters.AddChild(partitionID, partRateLimiters)
|
||||
return partRateLimiters
|
||||
}
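
A hedged sketch of the lazy-construction pattern above, mirroring the factory-function style used in the tests below; the helper names are illustrative, not part of this change:

// newNode is an illustrative factory, not part of this change.
func newNode(scope internalpb.RateScope) func() *RateLimiterNode {
	return func() *RateLimiterNode { return NewRateLimiterNode(scope) }
}

// getPartitionLimiter creates any missing database and collection nodes on the way down.
func getPartitionLimiter(tree *RateLimiterTree, dbID, collectionID, partitionID int64) *RateLimiterNode {
	return tree.GetOrCreatePartitionLimiters(dbID, collectionID, partitionID,
		newNode(internalpb.RateScope_Database),
		newNode(internalpb.RateScope_Collection),
		newNode(internalpb.RateScope_Partition),
	)
}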
|
|
@ -0,0 +1,205 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ratelimitutil
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/proxypb"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
func TestRateLimiterNode_AddAndGetChild(t *testing.T) {
|
||||
rln := NewRateLimiterNode(internalpb.RateScope_Cluster)
|
||||
child := NewRateLimiterNode(internalpb.RateScope_Cluster)
|
||||
|
||||
// Positive test case
|
||||
rln.AddChild(1, child)
|
||||
if rln.GetChild(1) != child {
|
||||
t.Error("AddChild did not add the child correctly")
|
||||
}
|
||||
|
||||
// Negative test case
|
||||
invalidChild := &RateLimiterNode{}
|
||||
rln.AddChild(2, child)
|
||||
if rln.GetChild(2) == invalidChild {
|
||||
t.Error("AddChild added an invalid child")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTraverseRateLimiterTree(t *testing.T) {
|
||||
limiters := typeutil.NewConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter]()
|
||||
limiters.Insert(internalpb.RateType_DDLCollection, ratelimitutil.NewLimiter(ratelimitutil.Inf, 0))
|
||||
quotaStates := typeutil.NewConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]()
|
||||
quotaStates.Insert(milvuspb.QuotaState_DenyToWrite, commonpb.ErrorCode_ForceDeny)
|
||||
|
||||
root := NewRateLimiterNode(internalpb.RateScope_Cluster)
|
||||
root.SetLimiters(limiters)
|
||||
root.SetQuotaStates(quotaStates)
|
||||
|
||||
// Add a child to the root node
|
||||
child := NewRateLimiterNode(internalpb.RateScope_Cluster)
|
||||
child.SetLimiters(limiters)
|
||||
child.SetQuotaStates(quotaStates)
|
||||
root.AddChild(123, child)
|
||||
|
||||
// Add a grandchild under the first child node
|
||||
child2 := NewRateLimiterNode(internalpb.RateScope_Cluster)
|
||||
child2.SetLimiters(limiters)
|
||||
child2.SetQuotaStates(quotaStates)
|
||||
child.AddChild(123, child2)
|
||||
|
||||
// Positive test case for fn1
|
||||
var fn1Count int
|
||||
fn1 := func(rateType internalpb.RateType, limiter *ratelimitutil.Limiter) bool {
|
||||
fn1Count++
|
||||
return true
|
||||
}
|
||||
|
||||
// Negative test case for fn2
|
||||
var fn2Count int
|
||||
fn2 := func(node *RateLimiterNode, state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool {
|
||||
fn2Count++
|
||||
return true
|
||||
}
|
||||
|
||||
// Call TraverseRateLimiterTree with fn1 and fn2
|
||||
TraverseRateLimiterTree(root, fn1, fn2)
|
||||
|
||||
assert.Equal(t, 3, fn1Count)
|
||||
assert.Equal(t, 3, fn2Count)
|
||||
}
|
||||
|
||||
func TestRateLimiterNodeCancel(t *testing.T) {
|
||||
t.Run("cancel not exist type", func(t *testing.T) {
|
||||
limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster)
|
||||
limitNode.Cancel(internalpb.RateType_DMLInsert, 10)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRateLimiterNodeCheck(t *testing.T) {
|
||||
t.Run("quota exceed", func(t *testing.T) {
|
||||
limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster)
|
||||
limitNode.limiters.Insert(internalpb.RateType_DMLInsert, ratelimitutil.NewLimiter(0, 0))
|
||||
limitNode.quotaStates.Insert(milvuspb.QuotaState_DenyToWrite, commonpb.ErrorCode_ForceDeny)
|
||||
err := limitNode.Check(internalpb.RateType_DMLInsert, 10)
|
||||
assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded))
|
||||
})
|
||||
|
||||
t.Run("rate limit", func(t *testing.T) {
|
||||
limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster)
|
||||
limitNode.limiters.Insert(internalpb.RateType_DMLInsert, ratelimitutil.NewLimiter(0.01, 0.01))
|
||||
{
|
||||
err := limitNode.Check(internalpb.RateType_DMLInsert, 1)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
{
|
||||
err := limitNode.Check(internalpb.RateType_DMLInsert, 1)
|
||||
assert.True(t, errors.Is(err, merr.ErrServiceRateLimit))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRateLimiterNodeGetQuotaExceededError(t *testing.T) {
|
||||
t.Run("write", func(t *testing.T) {
|
||||
limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster)
|
||||
limitNode.quotaStates.Insert(milvuspb.QuotaState_DenyToWrite, commonpb.ErrorCode_ForceDeny)
|
||||
err := limitNode.GetQuotaExceededError(internalpb.RateType_DMLInsert)
|
||||
assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded))
|
||||
// reference: ratelimitutil.GetQuotaErrorString(errCode)
|
||||
assert.True(t, strings.Contains(err.Error(), "deactivated"))
|
||||
})
|
||||
|
||||
t.Run("read", func(t *testing.T) {
|
||||
limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster)
|
||||
limitNode.quotaStates.Insert(milvuspb.QuotaState_DenyToRead, commonpb.ErrorCode_ForceDeny)
|
||||
err := limitNode.GetQuotaExceededError(internalpb.RateType_DQLSearch)
|
||||
assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded))
|
||||
// reference: ratelimitutil.GetQuotaErrorString(errCode)
|
||||
assert.True(t, strings.Contains(err.Error(), "deactivated"))
|
||||
})
|
||||
|
||||
t.Run("unknown", func(t *testing.T) {
|
||||
limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster)
|
||||
err := limitNode.GetQuotaExceededError(internalpb.RateType_DDLCompaction)
|
||||
assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded))
|
||||
assert.True(t, strings.Contains(err.Error(), "rate type"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRateLimiterTreeClearInvalidLimiterNode(t *testing.T) {
|
||||
root := NewRateLimiterNode(internalpb.RateScope_Cluster)
|
||||
tree := NewRateLimiterTree(root)
|
||||
|
||||
generateNodeFFunc := func(level internalpb.RateScope) func() *RateLimiterNode {
|
||||
return func() *RateLimiterNode {
|
||||
return NewRateLimiterNode(level)
|
||||
}
|
||||
}
|
||||
|
||||
tree.GetOrCreatePartitionLimiters(1, 10, 100,
|
||||
generateNodeFFunc(internalpb.RateScope_Database),
|
||||
generateNodeFFunc(internalpb.RateScope_Collection),
|
||||
generateNodeFFunc(internalpb.RateScope_Partition),
|
||||
)
|
||||
tree.GetOrCreatePartitionLimiters(1, 10, 200,
|
||||
generateNodeFFunc(internalpb.RateScope_Database),
|
||||
generateNodeFFunc(internalpb.RateScope_Collection),
|
||||
generateNodeFFunc(internalpb.RateScope_Partition),
|
||||
)
|
||||
tree.GetOrCreatePartitionLimiters(1, 20, 300,
|
||||
generateNodeFFunc(internalpb.RateScope_Database),
|
||||
generateNodeFFunc(internalpb.RateScope_Collection),
|
||||
generateNodeFFunc(internalpb.RateScope_Partition),
|
||||
)
|
||||
tree.GetOrCreatePartitionLimiters(2, 30, 400,
|
||||
generateNodeFFunc(internalpb.RateScope_Database),
|
||||
generateNodeFFunc(internalpb.RateScope_Collection),
|
||||
generateNodeFFunc(internalpb.RateScope_Partition),
|
||||
)
|
||||
|
||||
assert.Equal(t, 2, root.GetChildren().Len())
|
||||
assert.Equal(t, 2, root.GetChild(1).GetChildren().Len())
|
||||
assert.Equal(t, 2, root.GetChild(1).GetChild(10).GetChildren().Len())
|
||||
|
||||
tree.ClearInvalidLimiterNode(&proxypb.LimiterNode{
|
||||
Children: map[int64]*proxypb.LimiterNode{
|
||||
1: {
|
||||
Children: map[int64]*proxypb.LimiterNode{
|
||||
10: {
|
||||
Children: map[int64]*proxypb.LimiterNode{
|
||||
100: {},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
assert.Equal(t, 1, root.GetChildren().Len())
|
||||
assert.Equal(t, 1, root.GetChild(1).GetChildren().Len())
|
||||
assert.Equal(t, 1, root.GetChild(1).GetChild(10).GetChildren().Len())
|
||||
}
|
|
@ -132,6 +132,8 @@ const (
|
|||
CollectionSearchRateMaxKey = "collection.searchRate.max.vps"
|
||||
CollectionSearchRateMinKey = "collection.searchRate.min.vps"
|
||||
CollectionDiskQuotaKey = "collection.diskProtection.diskQuota.mb"
|
||||
|
||||
PartitionDiskQuotaKey = "partition.diskProtection.diskQuota.mb"
|
||||
)
|
||||
|
||||
// common properties
|
||||
|
|
|
@ -167,7 +167,7 @@ var (
|
|||
Help: "The quota states of cluster",
|
||||
}, []string{
|
||||
"quota_states",
|
||||
"db_name",
|
||||
"name",
|
||||
})
|
||||
|
||||
// RootCoordRateLimitRatio reflects the ratio of rate limit.
|
||||
|
|
|
@ -87,6 +87,7 @@ type QueryNodeQuotaMetrics struct {
|
|||
type DataCoordQuotaMetrics struct {
|
||||
TotalBinlogSize int64
|
||||
CollectionBinlogSize map[int64]int64
|
||||
PartitionsBinlogSize map[int64]map[int64]int64
|
||||
}
|
||||
|
||||
// DataNodeQuotaMetrics are metrics of DataNode.
|
||||
|
|
|
@ -61,6 +61,12 @@ type quotaConfig struct {
|
|||
CompactionLimitEnabled ParamItem `refreshable:"true"`
|
||||
MaxCompactionRate ParamItem `refreshable:"true"`
|
||||
|
||||
DDLCollectionRatePerDB ParamItem `refreshable:"true"`
|
||||
DDLPartitionRatePerDB ParamItem `refreshable:"true"`
|
||||
MaxIndexRatePerDB ParamItem `refreshable:"true"`
|
||||
MaxFlushRatePerDB ParamItem `refreshable:"true"`
|
||||
MaxCompactionRatePerDB ParamItem `refreshable:"true"`
|
||||
|
||||
// dml
|
||||
DMLLimitEnabled ParamItem `refreshable:"true"`
|
||||
DMLMaxInsertRate ParamItem `refreshable:"true"`
|
||||
|
@ -71,6 +77,14 @@ type quotaConfig struct {
|
|||
DMLMinDeleteRate ParamItem `refreshable:"true"`
|
||||
DMLMaxBulkLoadRate ParamItem `refreshable:"true"`
|
||||
DMLMinBulkLoadRate ParamItem `refreshable:"true"`
|
||||
DMLMaxInsertRatePerDB ParamItem `refreshable:"true"`
|
||||
DMLMinInsertRatePerDB ParamItem `refreshable:"true"`
|
||||
DMLMaxUpsertRatePerDB ParamItem `refreshable:"true"`
|
||||
DMLMinUpsertRatePerDB ParamItem `refreshable:"true"`
|
||||
DMLMaxDeleteRatePerDB ParamItem `refreshable:"true"`
|
||||
DMLMinDeleteRatePerDB ParamItem `refreshable:"true"`
|
||||
DMLMaxBulkLoadRatePerDB ParamItem `refreshable:"true"`
|
||||
DMLMinBulkLoadRatePerDB ParamItem `refreshable:"true"`
|
||||
DMLMaxInsertRatePerCollection ParamItem `refreshable:"true"`
|
||||
DMLMinInsertRatePerCollection ParamItem `refreshable:"true"`
|
||||
DMLMaxUpsertRatePerCollection ParamItem `refreshable:"true"`
|
||||
|
@ -79,6 +93,14 @@ type quotaConfig struct {
|
|||
DMLMinDeleteRatePerCollection ParamItem `refreshable:"true"`
|
||||
DMLMaxBulkLoadRatePerCollection ParamItem `refreshable:"true"`
|
||||
DMLMinBulkLoadRatePerCollection ParamItem `refreshable:"true"`
|
||||
DMLMaxInsertRatePerPartition ParamItem `refreshable:"true"`
|
||||
DMLMinInsertRatePerPartition ParamItem `refreshable:"true"`
|
||||
DMLMaxUpsertRatePerPartition ParamItem `refreshable:"true"`
|
||||
DMLMinUpsertRatePerPartition ParamItem `refreshable:"true"`
|
||||
DMLMaxDeleteRatePerPartition ParamItem `refreshable:"true"`
|
||||
DMLMinDeleteRatePerPartition ParamItem `refreshable:"true"`
|
||||
DMLMaxBulkLoadRatePerPartition ParamItem `refreshable:"true"`
|
||||
DMLMinBulkLoadRatePerPartition ParamItem `refreshable:"true"`
|
||||
|
||||
// dql
|
||||
DQLLimitEnabled ParamItem `refreshable:"true"`
|
||||
|
@ -86,10 +108,18 @@ type quotaConfig struct {
|
|||
DQLMinSearchRate ParamItem `refreshable:"true"`
|
||||
DQLMaxQueryRate ParamItem `refreshable:"true"`
|
||||
DQLMinQueryRate ParamItem `refreshable:"true"`
|
||||
DQLMaxSearchRatePerDB ParamItem `refreshable:"true"`
|
||||
DQLMinSearchRatePerDB ParamItem `refreshable:"true"`
|
||||
DQLMaxQueryRatePerDB ParamItem `refreshable:"true"`
|
||||
DQLMinQueryRatePerDB ParamItem `refreshable:"true"`
|
||||
DQLMaxSearchRatePerCollection ParamItem `refreshable:"true"`
|
||||
DQLMinSearchRatePerCollection ParamItem `refreshable:"true"`
|
||||
DQLMaxQueryRatePerCollection ParamItem `refreshable:"true"`
|
||||
DQLMinQueryRatePerCollection ParamItem `refreshable:"true"`
|
||||
DQLMaxSearchRatePerPartition ParamItem `refreshable:"true"`
|
||||
DQLMinSearchRatePerPartition ParamItem `refreshable:"true"`
|
||||
DQLMaxQueryRatePerPartition ParamItem `refreshable:"true"`
|
||||
DQLMinQueryRatePerPartition ParamItem `refreshable:"true"`
|
||||
|
||||
// limits
|
||||
MaxCollectionNum ParamItem `refreshable:"true"`
|
||||
|
@ -114,16 +144,20 @@ type quotaConfig struct {
|
|||
GrowingSegmentsSizeHighWaterLevel ParamItem `refreshable:"true"`
|
||||
DiskProtectionEnabled ParamItem `refreshable:"true"`
|
||||
DiskQuota ParamItem `refreshable:"true"`
|
||||
DiskQuotaPerDB ParamItem `refreshable:"true"`
|
||||
DiskQuotaPerCollection ParamItem `refreshable:"true"`
|
||||
DiskQuotaPerPartition ParamItem `refreshable:"true"`
|
||||
|
||||
// limit reading
|
||||
ForceDenyReading ParamItem `refreshable:"true"`
|
||||
QueueProtectionEnabled ParamItem `refreshable:"true"`
|
||||
NQInQueueThreshold ParamItem `refreshable:"true"`
|
||||
QueueLatencyThreshold ParamItem `refreshable:"true"`
|
||||
ResultProtectionEnabled ParamItem `refreshable:"true"`
|
||||
MaxReadResultRate ParamItem `refreshable:"true"`
|
||||
CoolOffSpeed ParamItem `refreshable:"true"`
|
||||
ForceDenyReading ParamItem `refreshable:"true"`
|
||||
QueueProtectionEnabled ParamItem `refreshable:"true"`
|
||||
NQInQueueThreshold ParamItem `refreshable:"true"`
|
||||
QueueLatencyThreshold ParamItem `refreshable:"true"`
|
||||
ResultProtectionEnabled ParamItem `refreshable:"true"`
|
||||
MaxReadResultRate ParamItem `refreshable:"true"`
|
||||
MaxReadResultRatePerDB ParamItem `refreshable:"true"`
|
||||
MaxReadResultRatePerCollection ParamItem `refreshable:"true"`
|
||||
CoolOffSpeed ParamItem `refreshable:"true"`
|
||||
}
|
||||
|
||||
func (p *quotaConfig) init(base *BaseTable) {
|
||||
|
@ -185,6 +219,25 @@ seconds, (0 ~ 65536)`,
|
|||
}
|
||||
p.DDLCollectionRate.Init(base.mgr)
|
||||
|
||||
p.DDLCollectionRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.ddl.db.collectionRate",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DDLLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
// [0 ~ Inf)
|
||||
if getAsInt(v) < 0 {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
},
|
||||
Doc: "qps of db level , default no limit, rate for CreateCollection, DropCollection, LoadCollection, ReleaseCollection",
|
||||
Export: true,
|
||||
}
|
||||
p.DDLCollectionRatePerDB.Init(base.mgr)
|
||||
|
||||
p.DDLPartitionRate = ParamItem{
|
||||
Key: "quotaAndLimits.ddl.partitionRate",
|
||||
Version: "2.2.0",
|
||||
|
@ -204,6 +257,25 @@ seconds, (0 ~ 65536)`,
|
|||
}
|
||||
p.DDLPartitionRate.Init(base.mgr)
|
||||
|
||||
p.DDLPartitionRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.ddl.db.partitionRate",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DDLLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
// [0 ~ Inf)
|
||||
if getAsInt(v) < 0 {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
},
|
||||
Doc: "qps of db level, default no limit, rate for CreatePartition, DropPartition, LoadPartition, ReleasePartition",
|
||||
Export: true,
|
||||
}
|
||||
p.DDLPartitionRatePerDB.Init(base.mgr)
|
||||
|
||||
p.IndexLimitEnabled = ParamItem{
|
||||
Key: "quotaAndLimits.indexRate.enabled",
|
||||
Version: "2.2.0",
|
||||
|
@ -231,6 +303,25 @@ seconds, (0 ~ 65536)`,
|
|||
}
|
||||
p.MaxIndexRate.Init(base.mgr)
|
||||
|
||||
p.MaxIndexRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.indexRate.db.max",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.IndexLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
// [0 ~ Inf)
|
||||
if getAsFloat(v) < 0 {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
},
|
||||
Doc: "qps of db level, default no limit, rate for CreateIndex, DropIndex",
|
||||
Export: true,
|
||||
}
|
||||
p.MaxIndexRatePerDB.Init(base.mgr)
|
||||
|
||||
p.FlushLimitEnabled = ParamItem{
|
||||
Key: "quotaAndLimits.flushRate.enabled",
|
||||
Version: "2.2.0",
|
||||
|
@ -258,6 +349,25 @@ seconds, (0 ~ 65536)`,
|
|||
}
|
||||
p.MaxFlushRate.Init(base.mgr)
|
||||
|
||||
p.MaxFlushRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.flushRate.db.max",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.FlushLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
// [0 ~ Inf)
|
||||
if getAsInt(v) < 0 {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
},
|
||||
Doc: "qps of db level, default no limit, rate for flush",
|
||||
Export: true,
|
||||
}
|
||||
p.MaxFlushRatePerDB.Init(base.mgr)
|
||||
|
||||
p.MaxFlushRatePerCollection = ParamItem{
|
||||
Key: "quotaAndLimits.flushRate.collection.max",
|
||||
Version: "2.3.9",
|
||||
|
@ -304,6 +414,25 @@ seconds, (0 ~ 65536)`,
|
|||
}
|
||||
p.MaxCompactionRate.Init(base.mgr)
|
||||
|
||||
p.MaxCompactionRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.compactionRate.db.max",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.CompactionLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
// [0 ~ Inf)
|
||||
if getAsInt(v) < 0 {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
},
|
||||
Doc: "qps of db level, default no limit, rate for manualCompaction",
|
||||
Export: true,
|
||||
}
|
||||
p.MaxCompactionRatePerDB.Init(base.mgr)
|
||||
|
||||
// dml
|
||||
p.DMLLimitEnabled = ParamItem{
|
||||
Key: "quotaAndLimits.dml.enabled",
|
||||
|
@ -359,6 +488,50 @@ The maximum rate will not be greater than ` + "max" + `.`,
|
|||
}
|
||||
p.DMLMinInsertRate.Init(base.mgr)
|
||||
|
||||
p.DMLMaxInsertRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.dml.insertRate.db.max",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DMLLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
rate := getAsFloat(v)
|
||||
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
|
||||
rate = megaBytes2Bytes(rate)
|
||||
}
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return p.DMLMaxInsertRate.GetValue()
|
||||
}
|
||||
return fmt.Sprintf("%f", rate)
|
||||
},
|
||||
Doc: "MB/s, default no limit",
|
||||
Export: true,
|
||||
}
|
||||
p.DMLMaxInsertRatePerDB.Init(base.mgr)
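
// Illustrative note, not part of this change: saving "10" for
// quotaAndLimits.dml.insertRate.db.max is interpreted as MB/s, so GetAsFloat
// reports 10*1024*1024 bytes/s; the collection- and partition-level tests
// above assert the same conversion.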
|
||||
|
||||
p.DMLMinInsertRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.dml.insertRate.db.min",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: min,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DMLLimitEnabled.GetAsBool() {
|
||||
return min
|
||||
}
|
||||
rate := megaBytes2Bytes(getAsFloat(v))
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return min
|
||||
}
|
||||
if !p.checkMinMaxLegal(rate, p.DMLMaxInsertRatePerDB.GetAsFloat()) {
|
||||
return min
|
||||
}
|
||||
return fmt.Sprintf("%f", rate)
|
||||
},
|
||||
}
|
||||
p.DMLMinInsertRatePerDB.Init(base.mgr)
|
||||
|
||||
p.DMLMaxInsertRatePerCollection = ParamItem{
|
||||
Key: "quotaAndLimits.dml.insertRate.collection.max",
|
||||
Version: "2.2.9",
|
||||
|
@ -403,6 +576,50 @@ The maximum rate will not be greater than ` + "max" + `.`,
|
|||
}
|
||||
p.DMLMinInsertRatePerCollection.Init(base.mgr)
|
||||
|
||||
p.DMLMaxInsertRatePerPartition = ParamItem{
|
||||
Key: "quotaAndLimits.dml.insertRate.partition.max",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DMLLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
rate := getAsFloat(v)
|
||||
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
|
||||
rate = megaBytes2Bytes(rate)
|
||||
}
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return p.DMLMaxInsertRate.GetValue()
|
||||
}
|
||||
return fmt.Sprintf("%f", rate)
|
||||
},
|
||||
Doc: "MB/s, default no limit",
|
||||
Export: true,
|
||||
}
|
||||
p.DMLMaxInsertRatePerPartition.Init(base.mgr)
|
||||
|
||||
p.DMLMinInsertRatePerPartition = ParamItem{
|
||||
Key: "quotaAndLimits.dml.insertRate.partition.min",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: min,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DMLLimitEnabled.GetAsBool() {
|
||||
return min
|
||||
}
|
||||
rate := megaBytes2Bytes(getAsFloat(v))
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return min
|
||||
}
|
||||
if !p.checkMinMaxLegal(rate, p.DMLMaxInsertRatePerPartition.GetAsFloat()) {
|
||||
return min
|
||||
}
|
||||
return fmt.Sprintf("%f", rate)
|
||||
},
|
||||
}
|
||||
p.DMLMinInsertRatePerPartition.Init(base.mgr)
|
||||
|
||||
p.DMLMaxUpsertRate = ParamItem{
|
||||
Key: "quotaAndLimits.dml.upsertRate.max",
|
||||
Version: "2.3.0",
|
||||
|
@ -447,6 +664,50 @@ The maximum rate will not be greater than ` + "max" + `.`,
|
|||
}
|
||||
p.DMLMinUpsertRate.Init(base.mgr)
|
||||
|
||||
p.DMLMaxUpsertRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.dml.upsertRate.db.max",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DMLLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
rate := getAsFloat(v)
|
||||
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
|
||||
rate = megaBytes2Bytes(rate)
|
||||
}
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return p.DMLMaxUpsertRate.GetValue()
|
||||
}
|
||||
return fmt.Sprintf("%f", rate)
|
||||
},
|
||||
Doc: "MB/s, default no limit",
|
||||
Export: true,
|
||||
}
|
||||
p.DMLMaxUpsertRatePerDB.Init(base.mgr)
|
||||
|
||||
p.DMLMinUpsertRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.dml.upsertRate.db.min",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: min,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DMLLimitEnabled.GetAsBool() {
|
||||
return min
|
||||
}
|
||||
rate := megaBytes2Bytes(getAsFloat(v))
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return min
|
||||
}
|
||||
if !p.checkMinMaxLegal(rate, p.DMLMaxUpsertRatePerDB.GetAsFloat()) {
|
||||
return min
|
||||
}
|
||||
return fmt.Sprintf("%f", rate)
|
||||
},
|
||||
}
|
||||
p.DMLMinUpsertRatePerDB.Init(base.mgr)
|
||||
|
||||
p.DMLMaxUpsertRatePerCollection = ParamItem{
|
||||
Key: "quotaAndLimits.dml.upsertRate.collection.max",
|
||||
Version: "2.3.0",
|
||||
|
@ -491,6 +752,50 @@ The maximum rate will not be greater than ` + "max" + `.`,
|
|||
}
|
||||
p.DMLMinUpsertRatePerCollection.Init(base.mgr)
|
||||
|
||||
p.DMLMaxUpsertRatePerPartition = ParamItem{
|
||||
Key: "quotaAndLimits.dml.upsertRate.partition.max",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DMLLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
rate := getAsFloat(v)
|
||||
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
|
||||
rate = megaBytes2Bytes(rate)
|
||||
}
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return p.DMLMaxUpsertRate.GetValue()
|
||||
}
|
||||
return fmt.Sprintf("%f", rate)
|
||||
},
|
||||
Doc: "MB/s, default no limit",
|
||||
Export: true,
|
||||
}
|
||||
p.DMLMaxUpsertRatePerPartition.Init(base.mgr)
|
||||
|
||||
p.DMLMinUpsertRatePerPartition = ParamItem{
|
||||
Key: "quotaAndLimits.dml.upsertRate.partition.min",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: min,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DMLLimitEnabled.GetAsBool() {
|
||||
return min
|
||||
}
|
||||
rate := megaBytes2Bytes(getAsFloat(v))
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return min
|
||||
}
|
||||
if !p.checkMinMaxLegal(rate, p.DMLMaxUpsertRatePerPartition.GetAsFloat()) {
|
||||
return min
|
||||
}
|
||||
return fmt.Sprintf("%f", rate)
|
||||
},
|
||||
}
|
||||
p.DMLMinUpsertRatePerPartition.Init(base.mgr)
|
||||
|
||||
p.DMLMaxDeleteRate = ParamItem{
|
||||
Key: "quotaAndLimits.dml.deleteRate.max",
|
||||
Version: "2.2.0",
|
||||
|
@ -535,6 +840,50 @@ The maximum rate will not be greater than ` + "max" + `.`,
|
|||
}
|
||||
p.DMLMinDeleteRate.Init(base.mgr)
|
||||
|
||||
p.DMLMaxDeleteRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.dml.deleteRate.db.max",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DMLLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
rate := getAsFloat(v)
|
||||
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
|
||||
rate = megaBytes2Bytes(rate)
|
||||
}
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return p.DMLMaxDeleteRate.GetValue()
|
||||
}
|
||||
return fmt.Sprintf("%f", rate)
|
||||
},
|
||||
Doc: "MB/s, default no limit",
|
||||
Export: true,
|
||||
}
|
||||
p.DMLMaxDeleteRatePerDB.Init(base.mgr)
|
||||
|
||||
p.DMLMinDeleteRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.dml.deleteRate.db.min",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: min,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DMLLimitEnabled.GetAsBool() {
|
||||
return min
|
||||
}
|
||||
rate := megaBytes2Bytes(getAsFloat(v))
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return min
|
||||
}
|
||||
if !p.checkMinMaxLegal(rate, p.DMLMaxDeleteRatePerDB.GetAsFloat()) {
|
||||
return min
|
||||
}
|
||||
return fmt.Sprintf("%f", rate)
|
||||
},
|
||||
}
|
||||
p.DMLMinDeleteRatePerDB.Init(base.mgr)
|
||||
|
||||
p.DMLMaxDeleteRatePerCollection = ParamItem{
|
||||
Key: "quotaAndLimits.dml.deleteRate.collection.max",
|
||||
Version: "2.2.9",
|
||||
|
@ -579,6 +928,50 @@ The maximum rate will not be greater than ` + "max" + `.`,
|
|||
}
|
||||
p.DMLMinDeleteRatePerCollection.Init(base.mgr)
|
||||
|
||||
p.DMLMaxDeleteRatePerPartition = ParamItem{
|
||||
Key: "quotaAndLimits.dml.deleteRate.partition.max",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DMLLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
rate := getAsFloat(v)
|
||||
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
|
||||
rate = megaBytes2Bytes(rate)
|
||||
}
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return p.DMLMaxDeleteRate.GetValue()
|
||||
}
|
||||
return fmt.Sprintf("%f", rate)
|
||||
},
|
||||
Doc: "MB/s, default no limit",
|
||||
Export: true,
|
||||
}
|
||||
p.DMLMaxDeleteRatePerPartition.Init(base.mgr)
|
||||
|
||||
p.DMLMinDeleteRatePerPartition = ParamItem{
|
||||
Key: "quotaAndLimits.dml.deleteRate.partition.min",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: min,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DMLLimitEnabled.GetAsBool() {
|
||||
return min
|
||||
}
|
||||
rate := megaBytes2Bytes(getAsFloat(v))
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return min
|
||||
}
|
||||
if !p.checkMinMaxLegal(rate, p.DMLMaxDeleteRatePerPartition.GetAsFloat()) {
|
||||
return min
|
||||
}
|
||||
return fmt.Sprintf("%f", rate)
|
||||
},
|
||||
}
|
||||
p.DMLMinDeleteRatePerPartition.Init(base.mgr)
|
||||
|
||||
p.DMLMaxBulkLoadRate = ParamItem{
|
||||
Key: "quotaAndLimits.dml.bulkLoadRate.max",
|
||||
Version: "2.2.0",
|
||||
|
@ -623,6 +1016,50 @@ The maximum rate will not be greater than ` + "max" + `.`,
|
|||
}
|
||||
p.DMLMinBulkLoadRate.Init(base.mgr)
|
||||
|
||||
p.DMLMaxBulkLoadRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.dml.bulkLoadRate.db.max",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DMLLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
rate := getAsFloat(v)
|
||||
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
|
||||
rate = megaBytes2Bytes(rate)
|
||||
}
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return p.DMLMaxBulkLoadRate.GetValue()
|
||||
}
|
||||
return fmt.Sprintf("%f", rate)
|
||||
},
|
||||
Doc: "MB/s, default no limit, not support yet. TODO: limit db bulkLoad rate",
|
||||
Export: true,
|
||||
}
|
||||
p.DMLMaxBulkLoadRatePerDB.Init(base.mgr)
|
||||
|
||||
p.DMLMinBulkLoadRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.dml.bulkLoadRate.db.min",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: min,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DMLLimitEnabled.GetAsBool() {
|
||||
return min
|
||||
}
|
||||
rate := megaBytes2Bytes(getAsFloat(v))
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return min
|
||||
}
|
||||
if !p.checkMinMaxLegal(rate, p.DMLMaxBulkLoadRatePerDB.GetAsFloat()) {
|
||||
return min
|
||||
}
|
||||
return fmt.Sprintf("%f", rate)
|
||||
},
|
||||
}
|
||||
p.DMLMinBulkLoadRatePerDB.Init(base.mgr)
|
||||
|
||||
p.DMLMaxBulkLoadRatePerCollection = ParamItem{
|
||||
Key: "quotaAndLimits.dml.bulkLoadRate.collection.max",
|
||||
Version: "2.2.9",
|
||||
|
@ -667,6 +1104,50 @@ The maximum rate will not be greater than ` + "max" + `.`,
|
|||
}
|
||||
p.DMLMinBulkLoadRatePerCollection.Init(base.mgr)
|
||||
|
||||
p.DMLMaxBulkLoadRatePerPartition = ParamItem{
|
||||
Key: "quotaAndLimits.dml.bulkLoadRate.partition.max",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DMLLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
rate := getAsFloat(v)
|
||||
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
|
||||
rate = megaBytes2Bytes(rate)
|
||||
}
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return p.DMLMaxBulkLoadRate.GetValue()
|
||||
}
|
||||
return fmt.Sprintf("%f", rate)
|
||||
},
|
||||
Doc: "MB/s, default no limit, not support yet. TODO: limit partition bulkLoad rate",
|
||||
Export: true,
|
||||
}
|
||||
p.DMLMaxBulkLoadRatePerPartition.Init(base.mgr)
|
||||
|
||||
p.DMLMinBulkLoadRatePerPartition = ParamItem{
|
||||
Key: "quotaAndLimits.dml.bulkLoadRate.partition.min",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: min,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DMLLimitEnabled.GetAsBool() {
|
||||
return min
|
||||
}
|
||||
rate := megaBytes2Bytes(getAsFloat(v))
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return min
|
||||
}
|
||||
if !p.checkMinMaxLegal(rate, p.DMLMaxBulkLoadRatePerPartition.GetAsFloat()) {
|
||||
return min
|
||||
}
|
||||
return fmt.Sprintf("%f", rate)
|
||||
},
|
||||
}
|
||||
p.DMLMinBulkLoadRatePerPartition.Init(base.mgr)
|
||||
|
||||
// dql
|
||||
p.DQLLimitEnabled = ParamItem{
|
||||
Key: "quotaAndLimits.dql.enabled",
|
||||
|
@ -718,6 +1199,46 @@ The maximum rate will not be greater than ` + "max" + `.`,
|
|||
}
|
||||
p.DQLMinSearchRate.Init(base.mgr)
|
||||
|
||||
p.DQLMaxSearchRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.dql.searchRate.db.max",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DQLLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
// [0, inf)
|
||||
if getAsFloat(v) < 0 {
|
||||
return p.DQLMaxSearchRate.GetValue()
|
||||
}
|
||||
return v
|
||||
},
|
||||
Doc: "vps (vectors per second), default no limit",
|
||||
Export: true,
|
||||
}
|
||||
p.DQLMaxSearchRatePerDB.Init(base.mgr)
|
||||
|
||||
p.DQLMinSearchRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.dql.searchRate.db.min",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: min,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DQLLimitEnabled.GetAsBool() {
|
||||
return min
|
||||
}
|
||||
rate := getAsFloat(v)
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return min
|
||||
}
|
||||
if !p.checkMinMaxLegal(rate, p.DQLMaxSearchRatePerDB.GetAsFloat()) {
|
||||
return min
|
||||
}
|
||||
return v
|
||||
},
|
||||
}
|
||||
p.DQLMinSearchRatePerDB.Init(base.mgr)
|
||||
|
||||
p.DQLMaxSearchRatePerCollection = ParamItem{
|
||||
Key: "quotaAndLimits.dql.searchRate.collection.max",
|
||||
Version: "2.2.9",
|
||||
|
@ -758,6 +1279,46 @@ The maximum rate will not be greater than ` + "max" + `.`,
|
|||
}
|
||||
p.DQLMinSearchRatePerCollection.Init(base.mgr)
|
||||
|
||||
p.DQLMaxSearchRatePerPartition = ParamItem{
|
||||
Key: "quotaAndLimits.dql.searchRate.partition.max",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DQLLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
// [0, inf)
|
||||
if getAsFloat(v) < 0 {
|
||||
return p.DQLMaxSearchRate.GetValue()
|
||||
}
|
||||
return v
|
||||
},
|
||||
Doc: "vps (vectors per second), default no limit",
|
||||
Export: true,
|
||||
}
|
||||
p.DQLMaxSearchRatePerPartition.Init(base.mgr)
|
||||
|
||||
p.DQLMinSearchRatePerPartition = ParamItem{
|
||||
Key: "quotaAndLimits.dql.searchRate.partition.min",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: min,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DQLLimitEnabled.GetAsBool() {
|
||||
return min
|
||||
}
|
||||
rate := getAsFloat(v)
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return min
|
||||
}
|
||||
if !p.checkMinMaxLegal(rate, p.DQLMaxSearchRatePerPartition.GetAsFloat()) {
|
||||
return min
|
||||
}
|
||||
return v
|
||||
},
|
||||
}
|
||||
p.DQLMinSearchRatePerPartition.Init(base.mgr)
|
||||
|
||||
p.DQLMaxQueryRate = ParamItem{
|
||||
Key: "quotaAndLimits.dql.queryRate.max",
|
||||
Version: "2.2.0",
|
||||
|
@ -798,6 +1359,46 @@ The maximum rate will not be greater than ` + "max" + `.`,
|
|||
}
|
||||
p.DQLMinQueryRate.Init(base.mgr)
|
||||
|
||||
p.DQLMaxQueryRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.dql.queryRate.db.max",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DQLLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
// [0, inf)
|
||||
if getAsFloat(v) < 0 {
|
||||
return p.DQLMaxQueryRate.GetValue()
|
||||
}
|
||||
return v
|
||||
},
|
||||
Doc: "qps, default no limit",
|
||||
Export: true,
|
||||
}
|
||||
p.DQLMaxQueryRatePerDB.Init(base.mgr)
|
||||
|
||||
p.DQLMinQueryRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.dql.queryRate.db.min",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: min,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DQLLimitEnabled.GetAsBool() {
|
||||
return min
|
||||
}
|
||||
rate := getAsFloat(v)
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return min
|
||||
}
|
||||
if !p.checkMinMaxLegal(rate, p.DQLMaxQueryRatePerDB.GetAsFloat()) {
|
||||
return min
|
||||
}
|
||||
return v
|
||||
},
|
||||
}
|
||||
p.DQLMinQueryRatePerDB.Init(base.mgr)
|
||||
|
||||
p.DQLMaxQueryRatePerCollection = ParamItem{
|
||||
Key: "quotaAndLimits.dql.queryRate.collection.max",
|
||||
Version: "2.2.9",
|
||||
|
@ -838,6 +1439,46 @@ The maximum rate will not be greater than ` + "max" + `.`,
|
|||
}
|
||||
p.DQLMinQueryRatePerCollection.Init(base.mgr)
|
||||
|
||||
p.DQLMaxQueryRatePerPartition = ParamItem{
|
||||
Key: "quotaAndLimits.dql.queryRate.partition.max",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DQLLimitEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
// [0, inf)
|
||||
if getAsFloat(v) < 0 {
|
||||
return p.DQLMaxQueryRate.GetValue()
|
||||
}
|
||||
return v
|
||||
},
|
||||
Doc: "qps, default no limit",
|
||||
Export: true,
|
||||
}
|
||||
p.DQLMaxQueryRatePerPartition.Init(base.mgr)
|
||||
|
||||
p.DQLMinQueryRatePerPartition = ParamItem{
|
||||
Key: "quotaAndLimits.dql.queryRate.partition.min",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: min,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DQLLimitEnabled.GetAsBool() {
|
||||
return min
|
||||
}
|
||||
rate := getAsFloat(v)
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return min
|
||||
}
|
||||
if !p.checkMinMaxLegal(rate, p.DQLMaxQueryRatePerPartition.GetAsFloat()) {
|
||||
return min
|
||||
}
|
||||
return v
|
||||
},
|
||||
}
|
||||
p.DQLMinQueryRatePerPartition.Init(base.mgr)
|
||||
|
||||
// limits
|
||||
p.MaxCollectionNum = ParamItem{
|
||||
Key: "quotaAndLimits.limits.maxCollectionNum",
|
||||
|
@ -1132,6 +1773,27 @@ but the rate will not be lower than minRateRatio * dmlRate.`,
|
|||
}
|
||||
p.DiskQuota.Init(base.mgr)
|
||||
|
||||
p.DiskQuotaPerDB = ParamItem{
|
||||
Key: "quotaAndLimits.limitWriting.diskProtection.diskQuotaPerDB",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: quota,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DiskProtectionEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
level := getAsFloat(v)
|
||||
// (0, +inf)
|
||||
if level <= 0 {
|
||||
return p.DiskQuota.GetValue()
|
||||
}
|
||||
// megabytes to bytes
|
||||
return fmt.Sprintf("%f", megaBytes2Bytes(level))
|
||||
},
|
||||
Doc: "MB, (0, +inf), default no limit",
|
||||
Export: true,
|
||||
}
|
||||
p.DiskQuotaPerDB.Init(base.mgr)
|
||||
|
||||
p.DiskQuotaPerCollection = ParamItem{
|
||||
Key: "quotaAndLimits.limitWriting.diskProtection.diskQuotaPerCollection",
|
||||
Version: "2.2.8",
|
||||
|
@ -1153,6 +1815,27 @@ but the rate will not be lower than minRateRatio * dmlRate.`,
|
|||
}
|
||||
p.DiskQuotaPerCollection.Init(base.mgr)
|
||||
|
||||
p.DiskQuotaPerPartition = ParamItem{
|
||||
Key: "quotaAndLimits.limitWriting.diskProtection.diskQuotaPerPartition",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: quota,
|
||||
Formatter: func(v string) string {
|
||||
if !p.DiskProtectionEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
level := getAsFloat(v)
|
||||
// (0, +inf)
|
||||
if level <= 0 {
|
||||
return p.DiskQuota.GetValue()
|
||||
}
|
||||
// megabytes to bytes
|
||||
return fmt.Sprintf("%f", megaBytes2Bytes(level))
|
||||
},
|
||||
Doc: "MB, (0, +inf), default no limit",
|
||||
Export: true,
|
||||
}
|
||||
p.DiskQuotaPerPartition.Init(base.mgr)
|
||||
|
||||
// limit reading
|
||||
p.ForceDenyReading = ParamItem{
|
||||
Key: "quotaAndLimits.limitReading.forceDeny",
|
||||
|
@ -1253,6 +1936,50 @@ MB/s, default no limit`,
|
|||
}
|
||||
p.MaxReadResultRate.Init(base.mgr)
|
||||
|
||||
p.MaxReadResultRatePerDB = ParamItem{
|
||||
Key: "quotaAndLimits.limitReading.resultProtection.maxReadResultRatePerDB",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.ResultProtectionEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
rate := getAsFloat(v)
|
||||
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
|
||||
return fmt.Sprintf("%f", megaBytes2Bytes(rate))
|
||||
}
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
},
|
||||
Export: true,
|
||||
}
|
||||
p.MaxReadResultRatePerDB.Init(base.mgr)
|
||||
|
||||
p.MaxReadResultRatePerCollection = ParamItem{
|
||||
Key: "quotaAndLimits.limitReading.resultProtection.maxReadResultRatePerCollection",
|
||||
Version: "2.4.1",
|
||||
DefaultValue: max,
|
||||
Formatter: func(v string) string {
|
||||
if !p.ResultProtectionEnabled.GetAsBool() {
|
||||
return max
|
||||
}
|
||||
rate := getAsFloat(v)
|
||||
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
|
||||
return fmt.Sprintf("%f", megaBytes2Bytes(rate))
|
||||
}
|
||||
// [0, inf)
|
||||
if rate < 0 {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
},
|
||||
Export: true,
|
||||
}
|
||||
p.MaxReadResultRatePerCollection.Init(base.mgr)
|
||||
|
||||
const defaultSpeed = "0.9"
|
||||
p.CoolOffSpeed = ParamItem{
|
||||
Key: "quotaAndLimits.limitReading.coolOffSpeed",
|
||||
|
|
|
@ -45,12 +45,13 @@ const Inf = Limit(math.MaxFloat64)
// in bucket may be negative, and the latter events would be "punished",
// any event should wait for the tokens to be filled to greater or equal to 0.
type Limiter struct {
	mu     sync.Mutex
	mu     sync.RWMutex
	limit  Limit
	burst  float64
	tokens float64
	// last is the last time the limiter's tokens field was updated
	last time.Time
	last       time.Time
	hasUpdated bool
}

// NewLimiter returns a new Limiter that allows events up to rate r.

@ -63,13 +64,20 @@ func NewLimiter(r Limit, b float64) *Limiter {

// Limit returns the maximum overall event rate.
func (lim *Limiter) Limit() Limit {
	lim.mu.Lock()
	defer lim.mu.Unlock()
	lim.mu.RLock()
	defer lim.mu.RUnlock()
	return lim.limit
}

// AllowN reports whether n events may happen at time now.
func (lim *Limiter) AllowN(now time.Time, n int) bool {
	lim.mu.RLock()
	if lim.limit == Inf {
		lim.mu.RUnlock()
		return true
	}
	lim.mu.RUnlock()

	lim.mu.Lock()
	defer lim.mu.Unlock()

@ -119,6 +127,7 @@ func (lim *Limiter) SetLimit(newLimit Limit) {
		// use rate as burst, because Limiter is with punishment mechanism, burst is insignificant.
		lim.burst = float64(newLimit)
	}
	lim.hasUpdated = true
}

// Cancel the AllowN operation and refund the tokens that have already been deducted by the limiter.

@ -128,6 +137,12 @@ func (lim *Limiter) Cancel(n int) {
	lim.tokens += float64(n)
}

func (lim *Limiter) HasUpdated() bool {
	lim.mu.RLock()
	defer lim.mu.RUnlock()
	return lim.hasUpdated
}

// advance calculates and returns an updated state for lim resulting from the passage of time.
// lim is not changed. advance requires that lim.mu is held.
func (lim *Limiter) advance(now time.Time) (newNow time.Time, newLast time.Time, newTokens float64) {

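A rough usage sketch of the Limiter methods visible in this diff (NewLimiter, AllowN, Cancel, SetLimit, Limit, HasUpdated); the import path is an assumption and the rates and counts are arbitrary.

// Sketch only, under the assumptions stated above.
package main

import (
	"fmt"
	"time"

	"github.com/milvus-io/milvus/pkg/util/ratelimitutil" // assumed import path
)

func main() {
	// Allow roughly 100 events per second; burst is tied to the rate internally.
	lim := ratelimitutil.NewLimiter(ratelimitutil.Limit(100), 100)

	if lim.AllowN(time.Now(), 30) {
		fmt.Println("30 events admitted")
	}

	// Refund the tokens if the admitted work is abandoned.
	lim.Cancel(30)

	// Tighten the limit at runtime; HasUpdated reports that SetLimit has run.
	lim.SetLimit(ratelimitutil.Limit(10))
	fmt.Println("limit:", lim.Limit(), "updated:", lim.HasUpdated())
}
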
@ -19,8 +19,11 @@ package ratelimitutil
import (
	"fmt"
	"math"
	"strings"
	"sync"
	"time"

	"github.com/samber/lo"
)

const (

@ -34,21 +37,22 @@ const (
type RateCollector struct {
	sync.Mutex

	window      time.Duration
	granularity time.Duration
	position    int
	values      map[string][]float64
	window              time.Duration
	granularity         time.Duration
	position            int
	values              map[string][]float64
	deprecatedSubLabels []lo.Tuple2[string, string]

	last time.Time
}

// NewRateCollector is shorthand for newRateCollector(window, granularity, time.Now()).
func NewRateCollector(window time.Duration, granularity time.Duration) (*RateCollector, error) {
	return newRateCollector(window, granularity, time.Now())
func NewRateCollector(window time.Duration, granularity time.Duration, enableSubLabel bool) (*RateCollector, error) {
	return newRateCollector(window, granularity, time.Now(), enableSubLabel)
}

// newRateCollector returns a new RateCollector with given window and granularity.
func newRateCollector(window time.Duration, granularity time.Duration, now time.Time) (*RateCollector, error) {
func newRateCollector(window time.Duration, granularity time.Duration, now time.Time, enableSubLabel bool) (*RateCollector, error) {
	if window == 0 || granularity == 0 {
		return nil, fmt.Errorf("create RateCollector failed, window or granularity cannot be 0, window = %d, granularity = %d", window, granularity)
	}

@ -62,9 +66,52 @@ func newRateCollector(window time.Duration, granularity time.Duration, now time.
		values:      make(map[string][]float64),
		last:        now,
	}

	if enableSubLabel {
		go rc.cleanDeprecateSubLabels()
	}
	return rc, nil
}

func (r *RateCollector) cleanDeprecateSubLabels() {
	tick := time.NewTicker(r.window * 2)
	defer tick.Stop()
	for range tick.C {
		r.Lock()
		for _, labelInfo := range r.deprecatedSubLabels {
			r.removeSubLabel(labelInfo)
		}
		r.Unlock()
	}
}

func (r *RateCollector) removeSubLabel(labelInfo lo.Tuple2[string, string]) {
	label := labelInfo.A
	subLabel := labelInfo.B
	if subLabel == "" {
		return
	}
	removeKeys := make([]string, 1)
	removeKeys[0] = FormatSubLabel(label, subLabel)

	deleteCollectionSubLabelWithPrefix := func(dbName string) {
		for key := range r.values {
			if strings.HasPrefix(key, FormatSubLabel(label, GetCollectionSubLabel(dbName, ""))) {
				removeKeys = append(removeKeys, key)
			}
		}
	}

	parts := strings.Split(subLabel, ".")
	if strings.HasPrefix(subLabel, GetDBSubLabel("")) {
		dbName := parts[1]
		deleteCollectionSubLabelWithPrefix(dbName)
	}
	for _, key := range removeKeys {
		delete(r.values, key)
	}
}

// Register init values of RateCollector for specified label.
func (r *RateCollector) Register(label string) {
	r.Lock()

@ -81,21 +128,77 @@ func (r *RateCollector) Deregister(label string) {
	delete(r.values, label)
}

func GetDBSubLabel(dbName string) string {
	return fmt.Sprintf("db.%s", dbName)
}

func GetCollectionSubLabel(dbName, collectionName string) string {
	return fmt.Sprintf("collection.%s.%s", dbName, collectionName)
}

func FormatSubLabel(label, subLabel string) string {
	return fmt.Sprintf("%s-%s", label, subLabel)
}

func GetDBFromSubLabel(label, fullLabel string) (string, bool) {
	if !strings.HasPrefix(fullLabel, FormatSubLabel(label, GetDBSubLabel(""))) {
		return "", false
	}
	return fullLabel[len(FormatSubLabel(label, GetDBSubLabel(""))):], true
}

func GetCollectionFromSubLabel(label, fullLabel string) (string, string, bool) {
	if !strings.HasPrefix(fullLabel, FormatSubLabel(label, "")) {
		return "", "", false
	}
	subLabels := strings.Split(fullLabel[len(FormatSubLabel(label, "")):], ".")
	if len(subLabels) != 3 || subLabels[0] != "collection" {
		return "", "", false
	}

	return subLabels[1], subLabels[2], true
}

func (r *RateCollector) DeregisterSubLabel(label, subLabel string) {
	r.Lock()
	defer r.Unlock()
	r.deprecatedSubLabels = append(r.deprecatedSubLabels, lo.Tuple2[string, string]{
		A: label,
		B: subLabel,
	})
}

// Add is shorthand for add(label, value, time.Now()).
func (r *RateCollector) Add(label string, value float64) {
	r.add(label, value, time.Now())
func (r *RateCollector) Add(label string, value float64, subLabels ...string) {
	r.add(label, value, time.Now(), subLabels...)
}

// add increases the current value of specified label.
func (r *RateCollector) add(label string, value float64, now time.Time) {
func (r *RateCollector) add(label string, value float64, now time.Time, subLabels ...string) {
	r.Lock()
	defer r.Unlock()
	r.update(now)
	if _, ok := r.values[label]; ok {
		r.values[label][r.position] += value
		for _, subLabel := range subLabels {
			r.unsafeAddForSubLabels(label, subLabel, value)
		}
	}
}

func (r *RateCollector) unsafeAddForSubLabels(label, subLabel string, value float64) {
	if subLabel == "" {
		return
	}
	sub := FormatSubLabel(label, subLabel)
	if _, ok := r.values[sub]; ok {
		r.values[sub][r.position] += value
		return
	}
	r.values[sub] = make([]float64, int(r.window/r.granularity))
	r.values[sub][r.position] = value
}

// Max is shorthand for max(label, time.Now()).
func (r *RateCollector) Max(label string, now time.Time) (float64, error) {
	return r.max(label, time.Now())

@ -145,6 +248,26 @@ func (r *RateCollector) Rate(label string, duration time.Duration) (float64, err
	return r.rate(label, duration, time.Now())
}

func (r *RateCollector) RateSubLabel(label string, duration time.Duration) (map[string]float64, error) {
	subLabelPrefix := FormatSubLabel(label, "")
	subLabels := make(map[string]float64)
	r.Lock()
	for s := range r.values {
		if strings.HasPrefix(s, subLabelPrefix) {
			subLabels[s] = 0
		}
	}
	r.Unlock()
	for s := range subLabels {
		v, err := r.rate(s, duration, time.Now())
		if err != nil {
			return nil, err
		}
		subLabels[s] = v
	}
	return subLabels, nil
}

// rate returns the latest mean value of the specified duration.
func (r *RateCollector) rate(label string, duration time.Duration, now time.Time) (float64, error) {
	if duration > r.window {

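A hedged usage sketch of the sub-label flow added above (Register, Add with db/collection sub-labels, RateSubLabel, DeregisterSubLabel); the import path and the window/granularity values are assumptions, while the calls themselves match this diff.

// Sketch only, under the assumptions stated above.
package main

import (
	"fmt"
	"time"

	"github.com/milvus-io/milvus/pkg/util/ratelimitutil" // assumed import path
)

func main() {
	// enableSubLabel=true also starts the background cleanup goroutine.
	rc, err := ratelimitutil.NewRateCollector(5*time.Second, time.Second, true)
	if err != nil {
		panic(err)
	}

	label := "insert"
	rc.Register(label)
	defer rc.Deregister(label)

	// Attribute the same sample to the db-level and the collection-level sub-label.
	rc.Add(label, 128,
		ratelimitutil.GetDBSubLabel("db1"),
		ratelimitutil.GetCollectionSubLabel("db1", "col1"))

	// Per-sub-label mean rates over the last second.
	rates, err := rc.RateSubLabel(label, time.Second)
	if err != nil {
		panic(err)
	}
	for k, v := range rates {
		fmt.Println(k, v)
	}

	// Queue the collection sub-label for lazy removal by the cleanup goroutine.
	rc.DeregisterSubLabel(label, ratelimitutil.GetCollectionSubLabel("db1", "col1"))
}
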
@ -22,6 +22,7 @@ import (
	"testing"
	"time"

	"github.com/samber/lo"
	"github.com/stretchr/testify/assert"
)

@ -36,7 +37,7 @@ func TestRateCollector(t *testing.T) {
		ts100 = ts0.Add(time.Duration(100.0 * float64(time.Second)))
	)

	rc, err := newRateCollector(DefaultWindow, DefaultGranularity, ts0)
	rc, err := newRateCollector(DefaultWindow, DefaultGranularity, ts0, false)
	assert.NoError(t, err)
	label := "mock_label"
	rc.Register(label)

@ -78,7 +79,7 @@ func TestRateCollector(t *testing.T) {
		ts31 = ts0.Add(time.Duration(3.1 * float64(time.Second)))
	)

	rc, err := newRateCollector(DefaultWindow, DefaultGranularity, ts0)
	rc, err := newRateCollector(DefaultWindow, DefaultGranularity, ts0, false)
	assert.NoError(t, err)
	label := "mock_label"
	rc.Register(label)

@ -105,7 +106,7 @@ func TestRateCollector(t *testing.T) {
	start := tt.now()
	end := start.Add(testPeriod * time.Second)

	rc, err := newRateCollector(DefaultWindow, DefaultGranularity, start)
	rc, err := newRateCollector(DefaultWindow, DefaultGranularity, start, false)
	assert.NoError(t, err)
	label := "mock_label"
	rc.Register(label)

@ -138,3 +139,111 @@ func TestRateCollector(t *testing.T) {
		}
	})
}

func TestRateSubLabel(t *testing.T) {
	rateCollector, err := NewRateCollector(5*time.Second, time.Second, true)
	assert.NoError(t, err)

	var (
		label              = "search"
		db                 = "hoo"
		collection         = "foo"
		dbSubLabel         = GetDBSubLabel(db)
		collectionSubLabel = GetCollectionSubLabel(db, collection)
		ts0                = time.Now()
		ts10               = ts0.Add(time.Duration(1.0 * float64(time.Second)))
		ts19               = ts0.Add(time.Duration(1.9 * float64(time.Second)))
		ts20               = ts0.Add(time.Duration(2.0 * float64(time.Second)))
		ts30               = ts0.Add(time.Duration(3.0 * float64(time.Second)))
		ts40               = ts0.Add(time.Duration(4.0 * float64(time.Second)))
	)

	rateCollector.Register(label)
	defer rateCollector.Deregister(label)
	rateCollector.add(label, 10, ts0, dbSubLabel, collectionSubLabel)
	rateCollector.add(label, 20, ts10, dbSubLabel, collectionSubLabel)
	rateCollector.add(label, 30, ts19, dbSubLabel, collectionSubLabel)
	rateCollector.add(label, 40, ts20, dbSubLabel, collectionSubLabel)
	rateCollector.add(label, 50, ts30, dbSubLabel, collectionSubLabel)
	rateCollector.add(label, 60, ts40, dbSubLabel, collectionSubLabel)

	time.Sleep(4 * time.Second)

	// 10 20+30 40 50 60
	{
		avg, err := rateCollector.Rate(label, 3*time.Second)
		assert.NoError(t, err)
		assert.Equal(t, float64(50), avg)
	}
	{
		avg, err := rateCollector.Rate(label, 5*time.Second)
		assert.NoError(t, err)
		assert.Equal(t, float64(42), avg)
	}
	{
		avgs, err := rateCollector.RateSubLabel(label, 3*time.Second)
		assert.NoError(t, err)
		assert.Equal(t, 2, len(avgs))
		assert.Equal(t, float64(50), avgs[FormatSubLabel(label, dbSubLabel)])
		assert.Equal(t, float64(50), avgs[FormatSubLabel(label, collectionSubLabel)])
	}

	rateCollector.Add(label, 10, GetCollectionSubLabel(db, collection))
	rateCollector.Add(label, 10, GetCollectionSubLabel(db, "col2"))

	rateCollector.DeregisterSubLabel(label, GetCollectionSubLabel(db, "col2"))
	rateCollector.DeregisterSubLabel(label, dbSubLabel)

	rateCollector.removeSubLabel(lo.Tuple2[string, string]{
		A: "aaa",
	})

	rateCollector.Lock()
	for _, labelInfo := range rateCollector.deprecatedSubLabels {
		rateCollector.removeSubLabel(labelInfo)
	}
	rateCollector.Unlock()

	{
		_, ok := rateCollector.values[FormatSubLabel(label, dbSubLabel)]
		assert.False(t, ok)
	}

	{
		_, ok := rateCollector.values[FormatSubLabel(label, collectionSubLabel)]
		assert.False(t, ok)
	}

	{
		assert.Len(t, rateCollector.values, 1)
		_, ok := rateCollector.values[label]
		assert.True(t, ok)
	}
}

func TestLabelUtil(t *testing.T) {
	assert.Equal(t, GetDBSubLabel("db"), "db.db")
	assert.Equal(t, GetCollectionSubLabel("db", "collection"), "collection.db.collection")
	{
		db, ok := GetDBFromSubLabel("foo", FormatSubLabel("foo", GetDBSubLabel("db1")))
		assert.True(t, ok)
		assert.Equal(t, "db1", db)
	}

	{
		_, ok := GetDBFromSubLabel("foo", "aaa")
		assert.False(t, ok)
	}

	{
		db, col, ok := GetCollectionFromSubLabel("foo", FormatSubLabel("foo", GetCollectionSubLabel("db1", "col1")))
		assert.True(t, ok)
		assert.Equal(t, "col1", col)
		assert.Equal(t, "db1", db)
	}

	{
		_, _, ok := GetCollectionFromSubLabel("foo", "aaa")
		assert.False(t, ok)
	}
}

@ -0,0 +1,30 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ratelimitutil

import "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"

var QuotaErrorString = map[commonpb.ErrorCode]string{
	commonpb.ErrorCode_ForceDeny:            "the writing has been deactivated by the administrator",
	commonpb.ErrorCode_MemoryQuotaExhausted: "memory quota exceeded, please allocate more resources",
	commonpb.ErrorCode_DiskQuotaExhausted:   "disk quota exceeded, please allocate more resources",
	commonpb.ErrorCode_TimeTickLongDelay:    "time tick long delay",
}

func GetQuotaErrorString(errCode commonpb.ErrorCode) string {
	return QuotaErrorString[errCode]
}

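A small sketch of how the lookup above might be used to explain a quota rejection; it relies only on the exported map and helper shown in this file, and the ratelimitutil import path is an assumption.

// Sketch only, under the assumptions stated above.
package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus/pkg/util/ratelimitutil" // assumed import path
)

func main() {
	code := commonpb.ErrorCode_DiskQuotaExhausted
	fmt.Printf("request rejected: %s\n", ratelimitutil.GetQuotaErrorString(code))

	// Codes that are not in the map simply yield an empty string.
	fmt.Printf("unmapped: %q\n", ratelimitutil.GetQuotaErrorString(commonpb.ErrorCode_Success))
}
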
@ -0,0 +1,43 @@
package ratelimitutil

import (
	"testing"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)

func TestGetQuotaErrorString(t *testing.T) {
	tests := []struct {
		name string
		args commonpb.ErrorCode
		want string
	}{
		{
			name: "Test ErrorCode_ForceDeny",
			args: commonpb.ErrorCode_ForceDeny,
			want: "the writing has been deactivated by the administrator",
		},
		{
			name: "Test ErrorCode_MemoryQuotaExhausted",
			args: commonpb.ErrorCode_MemoryQuotaExhausted,
			want: "memory quota exceeded, please allocate more resources",
		},
		{
			name: "Test ErrorCode_DiskQuotaExhausted",
			args: commonpb.ErrorCode_DiskQuotaExhausted,
			want: "disk quota exceeded, please allocate more resources",
		},
		{
			name: "Test ErrorCode_TimeTickLongDelay",
			args: commonpb.ErrorCode_TimeTickLongDelay,
			want: "time tick long delay",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := GetQuotaErrorString(tt.args); got != tt.want {
				t.Errorf("GetQuotaErrorString() = %v, want %v", got, tt.want)
			}
		})
	}
}

@ -30,6 +30,16 @@ func TestUniqueSet(t *testing.T) {
	assert.True(t, set.Contain(9))
	assert.True(t, set.Contain(5, 7, 9))

	containFive := false
	set.Range(func(i UniqueID) bool {
		if i == 5 {
			containFive = true
			return false
		}
		return true
	})
	assert.True(t, containFive)

	set.Remove(7)
	assert.True(t, set.Contain(5))
	assert.False(t, set.Contain(7))