feat: support rate limiter based on db and partition levels (#31070)

issue: https://github.com/milvus-io/milvus/issues/30577
co-author: @jaime0815

---------

Signed-off-by: Patrick Weizhi Xu <weizhi.xu@zilliz.com>
Signed-off-by: SimFG <bang.fu@zilliz.com>
Co-authored-by: Patrick Weizhi Xu <weizhi.xu@zilliz.com>
pull/32216/head
SimFG 2024-04-12 16:01:19 +08:00 committed by GitHub
parent fb376fd1e6
commit c012e6786f
64 changed files with 5539 additions and 1321 deletions

View File

@ -457,6 +457,7 @@ generate-mockery-datacoord: getdeps
$(INSTALL_PATH)/mockery --name=CompactionMeta --dir=internal/datacoord --filename=mock_compaction_meta.go --output=internal/datacoord --structname=MockCompactionMeta --with-expecter --inpackage
$(INSTALL_PATH)/mockery --name=Scheduler --dir=internal/datacoord --filename=mock_scheduler.go --output=internal/datacoord --structname=MockScheduler --with-expecter --inpackage
$(INSTALL_PATH)/mockery --name=ChannelManager --dir=internal/datacoord --filename=mock_channelmanager.go --output=internal/datacoord --structname=MockChannelManager --with-expecter --inpackage
$(INSTALL_PATH)/mockery --name=Broker --dir=internal/datacoord/broker --filename=mock_coordinator_broker.go --output=internal/datacoord/broker --structname=MockBroker --with-expecter --inpackage
generate-mockery-datanode: getdeps
$(INSTALL_PATH)/mockery --name=Allocator --dir=$(PWD)/internal/datanode/allocator --output=$(PWD)/internal/datanode/allocator --filename=mock_allocator.go --with-expecter --structname=MockAllocator --outpkg=allocator --inpackage

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.30.1. DO NOT EDIT.
// Code generated by mockery v2.32.4. DO NOT EDIT.
package broker
@ -77,6 +77,59 @@ func (_c *MockBroker_DescribeCollectionInternal_Call) RunAndReturn(run func(cont
return _c
}
// GetDatabaseID provides a mock function with given fields: ctx, dbName
func (_m *MockBroker) GetDatabaseID(ctx context.Context, dbName string) (int64, error) {
ret := _m.Called(ctx, dbName)
var r0 int64
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (int64, error)); ok {
return rf(ctx, dbName)
}
if rf, ok := ret.Get(0).(func(context.Context, string) int64); ok {
r0 = rf(ctx, dbName)
} else {
r0 = ret.Get(0).(int64)
}
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, dbName)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockBroker_GetDatabaseID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDatabaseID'
type MockBroker_GetDatabaseID_Call struct {
*mock.Call
}
// GetDatabaseID is a helper method to define mock.On call
// - ctx context.Context
// - dbName string
func (_e *MockBroker_Expecter) GetDatabaseID(ctx interface{}, dbName interface{}) *MockBroker_GetDatabaseID_Call {
return &MockBroker_GetDatabaseID_Call{Call: _e.mock.On("GetDatabaseID", ctx, dbName)}
}
func (_c *MockBroker_GetDatabaseID_Call) Run(run func(ctx context.Context, dbName string)) *MockBroker_GetDatabaseID_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string))
})
return _c
}
func (_c *MockBroker_GetDatabaseID_Call) Return(_a0 int64, _a1 error) *MockBroker_GetDatabaseID_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockBroker_GetDatabaseID_Call) RunAndReturn(run func(context.Context, string) (int64, error)) *MockBroker_GetDatabaseID_Call {
_c.Call.Return(run)
return _c
}
// HasCollection provides a mock function with given fields: ctx, collectionID
func (_m *MockBroker) HasCollection(ctx context.Context, collectionID int64) (bool, error) {
ret := _m.Called(ctx, collectionID)

View File

@ -263,7 +263,7 @@ func CheckDiskQuota(job ImportJob, meta *meta, imeta ImportMeta) (int64, error)
}
err := merr.WrapErrServiceQuotaExceeded("disk quota exceeded, please allocate more resources")
totalUsage, collectionsUsage := meta.GetCollectionBinlogSize()
totalUsage, collectionsUsage, _ := meta.GetCollectionBinlogSize()
tasks := imeta.GetTaskBy(WithJob(job.GetJobID()), WithType(PreImportTaskType))
files := make([]*datapb.ImportFileStats, 0)

View File

@ -85,6 +85,7 @@ type collectionInfo struct {
Properties map[string]string
CreatedAt Timestamp
DatabaseName string
DatabaseID int64
}
// NewMeta creates meta from provided `kv.TxnKV`
@ -200,6 +201,7 @@ func (m *meta) GetClonedCollectionInfo(collectionID UniqueID) *collectionInfo {
StartPositions: common.CloneKeyDataPairs(coll.StartPositions),
Properties: clonedProperties,
DatabaseName: coll.DatabaseName,
DatabaseID: coll.DatabaseID,
}
return cloneColl
@ -257,10 +259,11 @@ func (m *meta) GetNumRowsOfCollection(collectionID UniqueID) int64 {
}
// GetCollectionBinlogSize returns the total binlog size, the binlog size of each collection, and the binlog size of each partition within each collection.
func (m *meta) GetCollectionBinlogSize() (int64, map[UniqueID]int64) {
func (m *meta) GetCollectionBinlogSize() (int64, map[UniqueID]int64, map[UniqueID]map[UniqueID]int64) {
m.RLock()
defer m.RUnlock()
collectionBinlogSize := make(map[UniqueID]int64)
partitionBinlogSize := make(map[UniqueID]map[UniqueID]int64)
collectionRowsNum := make(map[UniqueID]map[commonpb.SegmentState]int64)
segments := m.segments.GetSegments()
var total int64
@ -270,6 +273,13 @@ func (m *meta) GetCollectionBinlogSize() (int64, map[UniqueID]int64) {
total += segmentSize
collectionBinlogSize[segment.GetCollectionID()] += segmentSize
partBinlogSize, ok := partitionBinlogSize[segment.GetCollectionID()]
if !ok {
partBinlogSize = make(map[int64]int64)
partitionBinlogSize[segment.GetCollectionID()] = partBinlogSize
}
partBinlogSize[segment.GetPartitionID()] += segmentSize
coll, ok := m.collections[segment.GetCollectionID()]
if ok {
metrics.DataCoordStoredBinlogSize.WithLabelValues(coll.DatabaseName,
@ -294,7 +304,7 @@ func (m *meta) GetCollectionBinlogSize() (int64, map[UniqueID]int64) {
}
}
}
return total, collectionBinlogSize
return total, collectionBinlogSize, partitionBinlogSize
}
func (m *meta) GetAllCollectionNumRows() map[int64]int64 {

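Since GetCollectionBinlogSize now also reports a per-partition breakdown, callers can read the stored binlog bytes of a single partition. A minimal sketch of such a lookup (not part of the diff; the helper name and the IDs it receives are illustrative):

func partitionBinlogSize(partSizes map[UniqueID]map[UniqueID]int64, collID, partID UniqueID) int64 {
	// partSizes is the third return value of meta.GetCollectionBinlogSize():
	// collection ID -> partition ID -> total binlog bytes of that partition.
	parts, ok := partSizes[collID]
	if !ok {
		return 0
	}
	return parts[partID]
}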
View File

@ -603,13 +603,13 @@ func TestMeta_Basic(t *testing.T) {
assert.NoError(t, err)
// check TotalBinlogSize
total, collectionBinlogSize := meta.GetCollectionBinlogSize()
total, collectionBinlogSize, _ := meta.GetCollectionBinlogSize()
assert.Len(t, collectionBinlogSize, 1)
assert.Equal(t, int64(size0+size1), collectionBinlogSize[collID])
assert.Equal(t, int64(size0+size1), total)
meta.collections[collID] = collInfo
total, collectionBinlogSize = meta.GetCollectionBinlogSize()
total, collectionBinlogSize, _ = meta.GetCollectionBinlogSize()
assert.Len(t, collectionBinlogSize, 1)
assert.Equal(t, int64(size0+size1), collectionBinlogSize[collID])
assert.Equal(t, int64(size0+size1), total)

View File

@ -37,10 +37,11 @@ import (
// getQuotaMetrics returns DataCoordQuotaMetrics.
func (s *Server) getQuotaMetrics() *metricsinfo.DataCoordQuotaMetrics {
total, colSizes := s.meta.GetCollectionBinlogSize()
total, colSizes, partSizes := s.meta.GetCollectionBinlogSize()
return &metricsinfo.DataCoordQuotaMetrics{
TotalBinlogSize: total,
CollectionBinlogSize: colSizes,
PartitionsBinlogSize: partSizes,
}
}

View File

@ -329,6 +329,15 @@ type mockRootCoordClient struct {
cnt int64
}
func (m *mockRootCoordClient) DescribeDatabase(ctx context.Context, in *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error) {
return &rootcoordpb.DescribeDatabaseResponse{
Status: merr.Success(),
DbID: 1,
DbName: "default",
CreatedTimestamp: 1,
}, nil
}
func (m *mockRootCoordClient) Close() error {
// TODO implement me
panic("implement me")

View File

@ -1159,6 +1159,7 @@ func (s *Server) loadCollectionFromRootCoord(ctx context.Context, collectionID i
Properties: properties,
CreatedAt: resp.GetCreatedTimestamp(),
DatabaseName: resp.GetDbName(),
DatabaseID: resp.GetDbId(),
}
s.meta.AddCollection(collInfo)
return nil

View File

@ -1551,6 +1551,7 @@ func (s *Server) BroadcastAlteredCollection(ctx context.Context, req *datapb.Alt
Partitions: req.GetPartitionIDs(),
StartPositions: req.GetStartPositions(),
Properties: properties,
DatabaseID: req.GetDbID(),
}
s.meta.AddCollection(collInfo)
return merr.Success(), nil

View File

@ -47,7 +47,7 @@ func initGlobalRateCollector() error {
// newRateCollector returns a new rateCollector.
func newRateCollector() (*rateCollector, error) {
rc, err := ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity)
rc, err := ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity, false)
if err != nil {
return nil, err
}

View File

@ -648,3 +648,14 @@ func (c *Client) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRe
}
return ret.(*milvuspb.ListDatabasesResponse), err
}
func (c *Client) DescribeDatabase(ctx context.Context, req *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error) {
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())),
)
return wrapGrpcCall(ctx, c, func(client rootcoordpb.RootCoordClient) (*rootcoordpb.DescribeDatabaseResponse, error) {
return client.DescribeDatabase(ctx, req)
})
}

View File

@ -77,6 +77,10 @@ type Server struct {
newQueryCoordClient func() types.QueryCoordClient
}
func (s *Server) DescribeDatabase(ctx context.Context, request *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error) {
return s.rootCoord.DescribeDatabase(ctx, request)
}
func (s *Server) CreateDatabase(ctx context.Context, request *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
return s.rootCoord.CreateDatabase(ctx, request)
}

View File

@ -861,6 +861,61 @@ func (_c *RootCoord_DescribeCollectionInternal_Call) RunAndReturn(run func(conte
return _c
}
// DescribeDatabase provides a mock function with given fields: _a0, _a1
func (_m *RootCoord) DescribeDatabase(_a0 context.Context, _a1 *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error) {
ret := _m.Called(_a0, _a1)
var r0 *rootcoordpb.DescribeDatabaseResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest) *rootcoordpb.DescribeDatabaseResponse); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*rootcoordpb.DescribeDatabaseResponse)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RootCoord_DescribeDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeDatabase'
type RootCoord_DescribeDatabase_Call struct {
*mock.Call
}
// DescribeDatabase is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *rootcoordpb.DescribeDatabaseRequest
func (_e *RootCoord_Expecter) DescribeDatabase(_a0 interface{}, _a1 interface{}) *RootCoord_DescribeDatabase_Call {
return &RootCoord_DescribeDatabase_Call{Call: _e.mock.On("DescribeDatabase", _a0, _a1)}
}
func (_c *RootCoord_DescribeDatabase_Call) Run(run func(_a0 context.Context, _a1 *rootcoordpb.DescribeDatabaseRequest)) *RootCoord_DescribeDatabase_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*rootcoordpb.DescribeDatabaseRequest))
})
return _c
}
func (_c *RootCoord_DescribeDatabase_Call) Return(_a0 *rootcoordpb.DescribeDatabaseResponse, _a1 error) *RootCoord_DescribeDatabase_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *RootCoord_DescribeDatabase_Call) RunAndReturn(run func(context.Context, *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error)) *RootCoord_DescribeDatabase_Call {
_c.Call.Return(run)
return _c
}
// DropAlias provides a mock function with given fields: _a0, _a1
func (_m *RootCoord) DropAlias(_a0 context.Context, _a1 *milvuspb.DropAliasRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)

View File

@ -1124,6 +1124,76 @@ func (_c *MockRootCoordClient_DescribeCollectionInternal_Call) RunAndReturn(run
return _c
}
// DescribeDatabase provides a mock function with given fields: ctx, in, opts
func (_m *MockRootCoordClient) DescribeDatabase(ctx context.Context, in *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *rootcoordpb.DescribeDatabaseResponse
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest, ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest, ...grpc.CallOption) *rootcoordpb.DescribeDatabaseResponse); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*rootcoordpb.DescribeDatabaseResponse)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *rootcoordpb.DescribeDatabaseRequest, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockRootCoordClient_DescribeDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeDatabase'
type MockRootCoordClient_DescribeDatabase_Call struct {
*mock.Call
}
// DescribeDatabase is a helper method to define mock.On call
// - ctx context.Context
// - in *rootcoordpb.DescribeDatabaseRequest
// - opts ...grpc.CallOption
func (_e *MockRootCoordClient_Expecter) DescribeDatabase(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_DescribeDatabase_Call {
return &MockRootCoordClient_DescribeDatabase_Call{Call: _e.mock.On("DescribeDatabase",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockRootCoordClient_DescribeDatabase_Call) Run(run func(ctx context.Context, in *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption)) *MockRootCoordClient_DescribeDatabase_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*rootcoordpb.DescribeDatabaseRequest), variadicArgs...)
})
return _c
}
func (_c *MockRootCoordClient_DescribeDatabase_Call) Return(_a0 *rootcoordpb.DescribeDatabaseResponse, _a1 error) *MockRootCoordClient_DescribeDatabase_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockRootCoordClient_DescribeDatabase_Call) RunAndReturn(run func(context.Context, *rootcoordpb.DescribeDatabaseRequest, ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error)) *MockRootCoordClient_DescribeDatabase_Call {
_c.Call.Return(run)
return _c
}
// DropAlias provides a mock function with given fields: ctx, in, opts
func (_m *MockRootCoordClient) DropAlias(ctx context.Context, in *milvuspb.DropAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
_va := make([]interface{}, len(opts))

View File

@ -634,6 +634,7 @@ message AlterCollectionRequest {
repeated int64 partitionIDs = 3;
repeated common.KeyDataPair start_positions = 4;
repeated common.KeyValuePair properties = 5;
int64 dbID = 6;
}
message GcConfirmRequest {

View File

@ -266,6 +266,13 @@ message ShowConfigurationsResponse {
repeated common.KeyValuePair configuations = 2;
}
enum RateScope {
Cluster = 0;
Database = 1;
Collection = 2;
Partition = 3;
}
enum RateType {
DDLCollection = 0;
DDLPartition = 1;

View File

@ -58,6 +58,7 @@ message RefreshPolicyInfoCacheRequest {
string opKey = 3;
}
// Deprecated: use ClusterLimiter instead
message CollectionRate {
int64 collection = 1;
repeated internal.Rate rates = 2;
@ -65,9 +66,27 @@ message CollectionRate {
repeated common.ErrorCode codes = 4;
}
message LimiterNode {
// self limiter information
Limiter limiter = 1;
// db id -> db limiter
// collection id -> collection limiter
// partition id -> partition limiter
map<int64, LimiterNode> children = 2;
}
message Limiter {
repeated internal.Rate rates = 1;
// we cannot use a map to store the quota states and error codes, because map keys cannot be enum types
repeated milvus.QuotaState states = 2;
repeated common.ErrorCode codes = 3;
}
message SetRatesRequest {
common.MsgBase base = 1;
// deprecated
repeated CollectionRate rates = 2;
LimiterNode rootLimiter = 3;
}
message ListClientInfosRequest {

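The LimiterNode tree mirrors the new hierarchy: the root node carries cluster-level rates, its children are databases, their children are collections, and the leaves are partitions. Below is a minimal sketch of building such a tree and wrapping it in a SetRatesRequest; it is not part of the diff, the IDs and the single DML insert rate are illustrative, and the Go field names are assumed to follow the generated code for the messages above.

import (
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/proxypb"
)

func buildSetRatesRequest() *proxypb.SetRatesRequest {
	// partition 1000: limit DML insert throughput to 1 MB/s
	partition := &proxypb.LimiterNode{
		Limiter: &proxypb.Limiter{
			Rates: []*internalpb.Rate{{Rt: internalpb.RateType_DMLInsert, R: 1 << 20}},
		},
		Children: map[int64]*proxypb.LimiterNode{},
	}
	// collection 100 holds the partition limiter
	collection := &proxypb.LimiterNode{
		Limiter:  &proxypb.Limiter{},
		Children: map[int64]*proxypb.LimiterNode{1000: partition},
	}
	// database 1 holds the collection limiter
	database := &proxypb.LimiterNode{
		Limiter:  &proxypb.Limiter{},
		Children: map[int64]*proxypb.LimiterNode{100: collection},
	}
	// cluster-level root; proxies apply it via SetRates(request.GetRootLimiter())
	root := &proxypb.LimiterNode{
		Limiter:  &proxypb.Limiter{},
		Children: map[int64]*proxypb.LimiterNode{1: database},
	}
	return &proxypb.SetRatesRequest{RootLimiter: root}
}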
View File

@ -140,6 +140,7 @@ service RootCoord {
rpc CreateDatabase(milvus.CreateDatabaseRequest) returns (common.Status) {}
rpc DropDatabase(milvus.DropDatabaseRequest) returns (common.Status) {}
rpc ListDatabases(milvus.ListDatabasesRequest) returns (milvus.ListDatabasesResponse) {}
rpc DescribeDatabase(DescribeDatabaseRequest) returns (DescribeDatabaseResponse) {}
}
message AllocTimestampRequest {
@ -206,3 +207,14 @@ message GetCredentialResponse {
string password = 3;
}
message DescribeDatabaseRequest {
common.MsgBase base = 1;
string db_name = 2;
}
message DescribeDatabaseResponse {
common.Status status = 1;
string db_name = 2;
int64 dbID = 3;
uint64 created_timestamp = 4;
}

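A minimal sketch of resolving a database ID through the new RPC (not part of the diff; the helper name is illustrative, the import paths are assumed to follow the existing layout, and types.RootCoordClient is assumed to expose DescribeDatabase as the mocks in this change do). Error handling mirrors MetaCache.describeDatabase later in this change:

import (
	"context"

	"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
	"github.com/milvus-io/milvus/internal/types"
	"github.com/milvus-io/milvus/pkg/util/merr"
)

func getDatabaseID(ctx context.Context, rc types.RootCoordClient, dbName string) (int64, error) {
	// Ask rootcoord to describe the database, then check both the RPC error
	// and the returned status before trusting the payload.
	resp, err := rc.DescribeDatabase(ctx, &rootcoordpb.DescribeDatabaseRequest{DbName: dbName})
	if err := merr.CheckRPCCall(resp, err); err != nil {
		return 0, err
	}
	return resp.GetDbID(), nil
}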
View File

@ -56,6 +56,8 @@ import (
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
"github.com/milvus-io/milvus/pkg/util/requestutil"
"github.com/milvus-io/milvus/pkg/util/retry"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
@ -160,8 +162,10 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
for _, alias := range aliasName {
metrics.CleanupProxyCollectionMetrics(paramtable.GetNodeID(), alias)
}
DeregisterSubLabel(ratelimitutil.GetCollectionSubLabel(request.GetDbName(), request.GetCollectionName()))
} else if msgType == commonpb.MsgType_DropDatabase {
metrics.CleanupProxyDBMetrics(paramtable.GetNodeID(), request.GetDbName())
DeregisterSubLabel(ratelimitutil.GetDBSubLabel(request.GetDbName()))
}
log.Info("complete to invalidate collection meta cache")
@ -289,6 +293,7 @@ func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatab
}
log.Info(rpcDone(method))
DeregisterSubLabel(ratelimitutil.GetDBSubLabel(request.GetDbName()))
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
@ -527,6 +532,7 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol
zap.Uint64("BeginTs", dct.BeginTs()),
zap.Uint64("EndTs", dct.EndTs()),
)
DeregisterSubLabel(ratelimitutil.GetCollectionSubLabel(request.GetDbName(), request.GetCollectionName()))
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
@ -2680,11 +2686,11 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
hookutil.FailCntKey: len(it.result.ErrIndex),
})
SetReportValue(it.result.GetStatus(), v)
rateCol.Add(internalpb.RateType_DMLUpsert.String(), float64(it.upsertMsg.InsertMsg.Size()+it.upsertMsg.DeleteMsg.Size()))
if merr.Ok(it.result.GetStatus()) {
metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeUpsert, dbName, username).Add(float64(v))
}
rateCol.Add(internalpb.RateType_DMLUpsert.String(), float64(it.upsertMsg.InsertMsg.Size()+it.upsertMsg.DeleteMsg.Size()))
metrics.ProxyFunctionCall.WithLabelValues(nodeID, method,
metrics.SuccessLabel, dbName, collectionName).Inc()
successCnt := it.result.UpsertCnt - int64(len(it.result.ErrIndex))
@ -2700,6 +2706,19 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
return it.result, nil
}
func GetDBAndCollectionRateSubLabels(req any) []string {
subLabels := make([]string, 2)
dbName, _ := requestutil.GetDbNameFromRequest(req)
if dbName != "" {
subLabels[0] = ratelimitutil.GetDBSubLabel(dbName.(string))
}
collectionName, _ := requestutil.GetCollectionNameFromRequest(req)
if collectionName != "" {
subLabels[1] = ratelimitutil.GetCollectionSubLabel(dbName.(string), collectionName.(string))
}
return subLabels
}
// Search searches the most similar records of requests.
func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
var err error
@ -2734,7 +2753,8 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest)
request.GetCollectionName(),
).Add(float64(request.GetNq()))
rateCol.Add(internalpb.RateType_DQLSearch.String(), float64(request.GetNq()))
subLabels := GetDBAndCollectionRateSubLabels(request)
rateCol.Add(internalpb.RateType_DQLSearch.String(), float64(request.GetNq()), subLabels...)
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
return &milvuspb.SearchResults{
@ -2909,8 +2929,9 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest)
if merr.Ok(qt.result.GetStatus()) {
metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeSearch, dbName, username).Add(float64(v))
}
metrics.ProxyReadReqSendBytes.WithLabelValues(nodeID).Add(float64(sentSize))
rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize))
rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize), subLabels...)
}
return qt.result, nil
}
@ -2941,6 +2962,13 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
request.GetCollectionName(),
).Add(float64(receiveSize))
subLabels := GetDBAndCollectionRateSubLabels(request)
allNQ := int64(0)
for _, searchRequest := range request.Requests {
allNQ += searchRequest.GetNq()
}
rateCol.Add(internalpb.RateType_DQLSearch.String(), float64(allNQ), subLabels...)
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
return &milvuspb.SearchResults{
Status: merr.Status(err),
@ -3098,8 +3126,9 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
if merr.Ok(qt.result.GetStatus()) {
metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeHybridSearch, dbName, username).Add(float64(v))
}
metrics.ProxyReadReqSendBytes.WithLabelValues(nodeID).Add(float64(sentSize))
rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize))
rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize), subLabels...)
}
return qt.result, nil
}
@ -3246,7 +3275,8 @@ func (node *Proxy) query(ctx context.Context, qt *queryTask) (*milvuspb.QueryRes
request.GetCollectionName(),
).Add(float64(1))
rateCol.Add(internalpb.RateType_DQLQuery.String(), 1)
subLabels := GetDBAndCollectionRateSubLabels(request)
rateCol.Add(internalpb.RateType_DQLQuery.String(), 1, subLabels...)
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
return &milvuspb.QueryResults{
@ -3364,7 +3394,7 @@ func (node *Proxy) query(ctx context.Context, qt *queryTask) (*milvuspb.QueryRes
).Observe(float64(tr.ElapseSpan().Milliseconds()))
sentSize := proto.Size(qt.result)
rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize))
rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize), subLabels...)
metrics.ProxyReadReqSendBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(sentSize))
return qt.result, nil
@ -5092,7 +5122,7 @@ func (node *Proxy) SetRates(ctx context.Context, request *proxypb.SetRatesReques
return resp, nil
}
err := node.multiRateLimiter.SetRates(request.GetRates())
err := node.simpleLimiter.SetRates(request.GetRootLimiter())
// TODO: set multiple rate limiter rates
if err != nil {
resp = merr.Status(err)
@ -5162,12 +5192,9 @@ func (node *Proxy) CheckHealth(ctx context.Context, request *milvuspb.CheckHealt
}, nil
}
states, reasons := node.multiRateLimiter.GetQuotaStates()
return &milvuspb.CheckHealthResponse{
Status: merr.Success(),
QuotaStates: states,
Reasons: reasons,
IsHealthy: true,
Status: merr.Success(),
IsHealthy: true,
}, nil
}
@ -5978,3 +6005,10 @@ func (node *Proxy) ListImports(ctx context.Context, req *internalpb.ListImportsR
metrics.ProxyReqLatency.WithLabelValues(nodeID, method).Observe(float64(tr.ElapseSpan().Milliseconds()))
return resp, nil
}
// DeregisterSubLabel removes the given sub-label from every rate type recorded with sub-labels; if other rate types start using sub-labels, they must be added here as well
func DeregisterSubLabel(subLabel string) {
rateCol.DeregisterSubLabel(internalpb.RateType_DQLQuery.String(), subLabel)
rateCol.DeregisterSubLabel(internalpb.RateType_DQLSearch.String(), subLabel)
rateCol.DeregisterSubLabel(metricsinfo.ReadResultThroughput, subLabel)
}

View File

@ -63,6 +63,7 @@ func TestProxy_InvalidateCollectionMetaCache_remove_stream(t *testing.T) {
chMgr.EXPECT().removeDMLStream(mock.Anything).Return()
node := &Proxy{chMgr: chMgr}
_ = node.initRateCollector()
node.UpdateStateCode(commonpb.StateCode_Healthy)
ctx := context.Background()
@ -78,7 +79,7 @@ func TestProxy_InvalidateCollectionMetaCache_remove_stream(t *testing.T) {
func TestProxy_CheckHealth(t *testing.T) {
t.Run("not healthy", func(t *testing.T) {
node := &Proxy{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}}
node.multiRateLimiter = NewMultiRateLimiter()
node.simpleLimiter = NewSimpleLimiter()
node.UpdateStateCode(commonpb.StateCode_Abnormal)
ctx := context.Background()
resp, err := node.CheckHealth(ctx, &milvuspb.CheckHealthRequest{})
@ -96,7 +97,7 @@ func TestProxy_CheckHealth(t *testing.T) {
dataCoord: NewDataCoordMock(),
session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}},
}
node.multiRateLimiter = NewMultiRateLimiter()
node.simpleLimiter = NewSimpleLimiter()
node.UpdateStateCode(commonpb.StateCode_Healthy)
ctx := context.Background()
resp, err := node.CheckHealth(ctx, &milvuspb.CheckHealthRequest{})
@ -129,7 +130,7 @@ func TestProxy_CheckHealth(t *testing.T) {
queryCoord: qc,
dataCoord: dataCoordMock,
}
node.multiRateLimiter = NewMultiRateLimiter()
node.simpleLimiter = NewSimpleLimiter()
node.UpdateStateCode(commonpb.StateCode_Healthy)
ctx := context.Background()
resp, err := node.CheckHealth(ctx, &milvuspb.CheckHealthRequest{})
@ -146,7 +147,7 @@ func TestProxy_CheckHealth(t *testing.T) {
dataCoord: NewDataCoordMock(),
queryCoord: qc,
}
node.multiRateLimiter = NewMultiRateLimiter()
node.simpleLimiter = NewSimpleLimiter()
node.UpdateStateCode(commonpb.StateCode_Healthy)
resp, err := node.CheckHealth(context.Background(), &milvuspb.CheckHealthRequest{})
assert.NoError(t, err)
@ -156,18 +157,30 @@ func TestProxy_CheckHealth(t *testing.T) {
states := []milvuspb.QuotaState{milvuspb.QuotaState_DenyToWrite, milvuspb.QuotaState_DenyToRead}
codes := []commonpb.ErrorCode{commonpb.ErrorCode_MemoryQuotaExhausted, commonpb.ErrorCode_ForceDeny}
node.multiRateLimiter.SetRates([]*proxypb.CollectionRate{
{
Collection: 1,
States: states,
Codes: codes,
err = node.simpleLimiter.SetRates(&proxypb.LimiterNode{
Limiter: &proxypb.Limiter{},
// db level
Children: map[int64]*proxypb.LimiterNode{
1: {
Limiter: &proxypb.Limiter{},
// collection level
Children: map[int64]*proxypb.LimiterNode{
100: {
Limiter: &proxypb.Limiter{
States: states,
Codes: codes,
},
Children: make(map[int64]*proxypb.LimiterNode),
},
},
},
},
})
assert.NoError(t, err)
resp, err = node.CheckHealth(context.Background(), &milvuspb.CheckHealthRequest{})
assert.NoError(t, err)
assert.Equal(t, true, resp.IsHealthy)
assert.Equal(t, 2, len(resp.GetQuotaStates()))
assert.Equal(t, 2, len(resp.GetReasons()))
})
}
@ -229,7 +242,7 @@ func TestProxy_ResourceGroup(t *testing.T) {
node, err := NewProxy(ctx, factory)
assert.NoError(t, err)
node.multiRateLimiter = NewMultiRateLimiter()
node.simpleLimiter = NewSimpleLimiter()
node.UpdateStateCode(commonpb.StateCode_Healthy)
qc := mocks.NewMockQueryCoordClient(t)
@ -321,7 +334,7 @@ func TestProxy_InvalidResourceGroupName(t *testing.T) {
node, err := NewProxy(ctx, factory)
assert.NoError(t, err)
node.multiRateLimiter = NewMultiRateLimiter()
node.simpleLimiter = NewSimpleLimiter()
node.UpdateStateCode(commonpb.StateCode_Healthy)
qc := mocks.NewMockQueryCoordClient(t)
@ -922,7 +935,7 @@ func TestProxyCreateDatabase(t *testing.T) {
node.tsoAllocator = &timestampAllocator{
tso: newMockTimestampAllocatorInterface(),
}
node.multiRateLimiter = NewMultiRateLimiter()
node.simpleLimiter = NewSimpleLimiter()
node.UpdateStateCode(commonpb.StateCode_Healthy)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
node.sched.ddQueue.setMaxTaskNum(10)
@ -977,11 +990,12 @@ func TestProxyDropDatabase(t *testing.T) {
ctx := context.Background()
node, err := NewProxy(ctx, factory)
node.initRateCollector()
assert.NoError(t, err)
node.tsoAllocator = &timestampAllocator{
tso: newMockTimestampAllocatorInterface(),
}
node.multiRateLimiter = NewMultiRateLimiter()
node.simpleLimiter = NewSimpleLimiter()
node.UpdateStateCode(commonpb.StateCode_Healthy)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
node.sched.ddQueue.setMaxTaskNum(10)
@ -1040,7 +1054,7 @@ func TestProxyListDatabase(t *testing.T) {
node.tsoAllocator = &timestampAllocator{
tso: newMockTimestampAllocatorInterface(),
}
node.multiRateLimiter = NewMultiRateLimiter()
node.simpleLimiter = NewSimpleLimiter()
node.UpdateStateCode(commonpb.StateCode_Healthy)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
node.sched.ddQueue.setMaxTaskNum(10)

View File

@ -89,10 +89,10 @@ type Cache interface {
RemoveDatabase(ctx context.Context, database string)
HasDatabase(ctx context.Context, database string) bool
GetDatabaseInfo(ctx context.Context, database string) (*databaseInfo, error)
// AllocID is only used for requests that need to skip timestamp allocation; don't overuse it.
AllocID(ctx context.Context) (int64, error)
}
type collectionBasicInfo struct {
collID typeutil.UniqueID
createdTimestamp uint64
@ -109,6 +109,11 @@ type collectionInfo struct {
consistencyLevel commonpb.ConsistencyLevel
}
type databaseInfo struct {
dbID typeutil.UniqueID
createdTimestamp uint64
}
// schemaInfo is a helper function wraps *schemapb.CollectionSchema
// with extra fields mapping and methods
type schemaInfo struct {
@ -244,17 +249,19 @@ type MetaCache struct {
rootCoord types.RootCoordClient
queryCoord types.QueryCoordClient
collInfo map[string]map[string]*collectionInfo // database -> collectionName -> collection_info
collLeader map[string]map[string]*shardLeaders // database -> collectionName -> collection_leaders
dbInfo map[string]map[typeutil.UniqueID]string // database -> collectionID -> collectionName
credMap map[string]*internalpb.CredentialInfo // cache for credential, lazy load
privilegeInfos map[string]struct{} // privileges cache
userToRoles map[string]map[string]struct{} // user to role cache
mu sync.RWMutex
credMut sync.RWMutex
leaderMut sync.RWMutex
shardMgr shardClientMgr
sfGlobal conc.Singleflight[*collectionInfo]
dbInfo map[string]*databaseInfo // database -> db_info
collInfo map[string]map[string]*collectionInfo // database -> collectionName -> collection_info
collLeader map[string]map[string]*shardLeaders // database -> collectionName -> collection_leaders
dbCollectionInfo map[string]map[typeutil.UniqueID]string // database -> collectionID -> collectionName
credMap map[string]*internalpb.CredentialInfo // cache for credential, lazy load
privilegeInfos map[string]struct{} // privileges cache
userToRoles map[string]map[string]struct{} // user to role cache
mu sync.RWMutex
credMut sync.RWMutex
leaderMut sync.RWMutex
shardMgr shardClientMgr
sfGlobal conc.Singleflight[*collectionInfo]
sfDB conc.Singleflight[*databaseInfo]
IDStart int64
IDCount int64
@ -287,15 +294,16 @@ func InitMetaCache(ctx context.Context, rootCoord types.RootCoordClient, queryCo
// NewMetaCache creates a MetaCache with provided RootCoord and QueryNode
func NewMetaCache(rootCoord types.RootCoordClient, queryCoord types.QueryCoordClient, shardMgr shardClientMgr) (*MetaCache, error) {
return &MetaCache{
rootCoord: rootCoord,
queryCoord: queryCoord,
collInfo: map[string]map[string]*collectionInfo{},
collLeader: map[string]map[string]*shardLeaders{},
dbInfo: map[string]map[typeutil.UniqueID]string{},
credMap: map[string]*internalpb.CredentialInfo{},
shardMgr: shardMgr,
privilegeInfos: map[string]struct{}{},
userToRoles: map[string]map[string]struct{}{},
rootCoord: rootCoord,
queryCoord: queryCoord,
dbInfo: map[string]*databaseInfo{},
collInfo: map[string]map[string]*collectionInfo{},
collLeader: map[string]map[string]*shardLeaders{},
dbCollectionInfo: map[string]map[typeutil.UniqueID]string{},
credMap: map[string]*internalpb.CredentialInfo{},
shardMgr: shardMgr,
privilegeInfos: map[string]struct{}{},
userToRoles: map[string]map[string]struct{}{},
}, nil
}
@ -510,7 +518,7 @@ func (m *MetaCache) innerGetCollectionByID(collectionID int64) (string, string)
m.mu.RLock()
defer m.mu.RUnlock()
for database, db := range m.dbInfo {
for database, db := range m.dbCollectionInfo {
name, ok := db[collectionID]
if ok {
return database, name
@ -554,7 +562,7 @@ func (m *MetaCache) updateDBInfo(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
m.dbInfo = dbInfo
m.dbCollectionInfo = dbInfo
return nil
}
@ -739,6 +747,19 @@ func (m *MetaCache) showPartitions(ctx context.Context, dbName string, collectio
return partitions, nil
}
func (m *MetaCache) describeDatabase(ctx context.Context, dbName string) (*rootcoordpb.DescribeDatabaseResponse, error) {
req := &rootcoordpb.DescribeDatabaseRequest{
DbName: dbName,
}
resp, err := m.rootCoord.DescribeDatabase(ctx, req)
if err = merr.CheckRPCCall(resp, err); err != nil {
return nil, err
}
return resp, nil
}
// parsePartitionsInfo parse partitionInfo list to partitionInfos struct.
// prepare all name to id & info map
// try parse partition names to partitionKey index.
@ -1084,6 +1105,7 @@ func (m *MetaCache) RemoveDatabase(ctx context.Context, database string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.collInfo, database)
delete(m.dbInfo, database)
}
func (m *MetaCache) HasDatabase(ctx context.Context, database string) bool {
@ -1093,6 +1115,41 @@ func (m *MetaCache) HasDatabase(ctx context.Context, database string) bool {
return ok
}
func (m *MetaCache) GetDatabaseInfo(ctx context.Context, database string) (*databaseInfo, error) {
dbInfo := m.safeGetDBInfo(database)
if dbInfo != nil {
return dbInfo, nil
}
dbInfo, err, _ := m.sfDB.Do(database, func() (*databaseInfo, error) {
resp, err := m.describeDatabase(ctx, database)
if err != nil {
return nil, err
}
m.mu.Lock()
defer m.mu.Unlock()
dbInfo := &databaseInfo{
dbID: resp.GetDbID(),
createdTimestamp: resp.GetCreatedTimestamp(),
}
m.dbInfo[database] = dbInfo
return dbInfo, nil
})
return dbInfo, err
}
func (m *MetaCache) safeGetDBInfo(database string) *databaseInfo {
m.mu.RLock()
defer m.mu.RUnlock()
db, ok := m.dbInfo[database]
if !ok {
return nil
}
return db
}
func (m *MetaCache) AllocID(ctx context.Context) (int64, error) {
m.IDLock.Lock()
defer m.IDLock.Unlock()

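GetDatabaseInfo first checks the cache under a read lock and only falls back to DescribeDatabase through the sfDB singleflight group, so concurrent misses for the same database collapse into a single RPC, and the entry stays cached until RemoveDatabase evicts it. A minimal sketch of a proxy-internal caller (not part of the diff; the helper name is illustrative and it is assumed to live in the proxy package so the unexported dbID field is reachable):

func databaseIDOf(ctx context.Context, cache Cache, dbName string) (int64, error) {
	// Repeated or concurrent calls for the same dbName are served from the
	// local cache after the first DescribeDatabase round-trip.
	info, err := cache.GetDatabaseInfo(ctx, dbName)
	if err != nil {
		return 0, err
	}
	return info.dbID, nil
}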
View File

@ -817,6 +817,49 @@ func TestMetaCache_Database(t *testing.T) {
assert.Equal(t, CheckDatabase(ctx, dbName), true)
}
func TestGetDatabaseInfo(t *testing.T) {
t.Run("success", func(t *testing.T) {
ctx := context.Background()
rootCoord := mocks.NewMockRootCoordClient(t)
queryCoord := &mocks.MockQueryCoordClient{}
shardMgr := newShardClientMgr()
cache, err := NewMetaCache(rootCoord, queryCoord, shardMgr)
assert.NoError(t, err)
rootCoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{
Status: merr.Success(),
DbID: 1,
DbName: "default",
}, nil).Once()
{
dbInfo, err := cache.GetDatabaseInfo(ctx, "default")
assert.NoError(t, err)
assert.Equal(t, UniqueID(1), dbInfo.dbID)
}
{
dbInfo, err := cache.GetDatabaseInfo(ctx, "default")
assert.NoError(t, err)
assert.Equal(t, UniqueID(1), dbInfo.dbID)
}
})
t.Run("error", func(t *testing.T) {
ctx := context.Background()
rootCoord := mocks.NewMockRootCoordClient(t)
queryCoord := &mocks.MockQueryCoordClient{}
shardMgr := newShardClientMgr()
cache, err := NewMetaCache(rootCoord, queryCoord, shardMgr)
assert.NoError(t, err)
rootCoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{
Status: merr.Status(errors.New("mock error: describe database")),
}, nil).Once()
_, err = cache.GetDatabaseInfo(ctx, "default")
assert.Error(t, err)
})
}
func TestMetaCache_AllocID(t *testing.T) {
ctx := context.Background()
queryCoord := &mocks.MockQueryCoordClient{}
@ -935,9 +978,9 @@ func TestGlobalMetaCache_UpdateDBInfo(t *testing.T) {
}, nil).Once()
err := cache.updateDBInfo(ctx)
assert.NoError(t, err)
assert.Len(t, cache.dbInfo, 1)
assert.Len(t, cache.dbInfo["db1"], 1)
assert.Equal(t, "collection1", cache.dbInfo["db1"][1])
assert.Len(t, cache.dbCollectionInfo, 1)
assert.Len(t, cache.dbCollectionInfo["db1"], 1)
assert.Equal(t, "collection1", cache.dbCollectionInfo["db1"][1])
})
}

View File

@ -50,12 +50,29 @@ func getQuotaMetrics() (*metricsinfo.ProxyQuotaMetrics, error) {
Rate: rate,
})
}
getSubLabelRateMetric := func(label string) {
rates, err2 := rateCol.RateSubLabel(label, ratelimitutil.DefaultAvgDuration)
if err2 != nil {
err = err2
return
}
for s, f := range rates {
rms = append(rms, metricsinfo.RateMetric{
Label: s,
Rate: f,
})
}
}
getRateMetric(internalpb.RateType_DMLInsert.String())
getRateMetric(internalpb.RateType_DMLUpsert.String())
getRateMetric(internalpb.RateType_DMLDelete.String())
getRateMetric(internalpb.RateType_DQLSearch.String())
getSubLabelRateMetric(internalpb.RateType_DQLSearch.String())
getRateMetric(internalpb.RateType_DQLQuery.String())
getSubLabelRateMetric(internalpb.RateType_DQLQuery.String())
getRateMetric(metricsinfo.ReadResultThroughput)
getSubLabelRateMetric(metricsinfo.ReadResultThroughput)
if err != nil {
return nil, err
}

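getSubLabelRateMetric surfaces per-database and per-collection rates next to the node-level ones. A minimal sketch of reading those rates directly from the proxy's rate collector (not part of the diff; it assumes it runs inside the proxy package where rateCol and the logger are available, and the log fields are illustrative):

rates, err := rateCol.RateSubLabel(internalpb.RateType_DQLSearch.String(), ratelimitutil.DefaultAvgDuration)
if err == nil {
	for subLabel, rate := range rates {
		// subLabel was produced by ratelimitutil.GetDBSubLabel or
		// ratelimitutil.GetCollectionSubLabel when the request was recorded.
		log.Info("average search rate", zap.String("subLabel", subLabel), zap.Float64("rate", rate))
	}
}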
View File

@ -450,6 +450,61 @@ func (_c *MockCache_GetCredentialInfo_Call) RunAndReturn(run func(context.Contex
return _c
}
// GetDatabaseInfo provides a mock function with given fields: ctx, database
func (_m *MockCache) GetDatabaseInfo(ctx context.Context, database string) (*databaseInfo, error) {
ret := _m.Called(ctx, database)
var r0 *databaseInfo
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (*databaseInfo, error)); ok {
return rf(ctx, database)
}
if rf, ok := ret.Get(0).(func(context.Context, string) *databaseInfo); ok {
r0 = rf(ctx, database)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*databaseInfo)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, database)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCache_GetDatabaseInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDatabaseInfo'
type MockCache_GetDatabaseInfo_Call struct {
*mock.Call
}
// GetDatabaseInfo is a helper method to define mock.On call
// - ctx context.Context
// - database string
func (_e *MockCache_Expecter) GetDatabaseInfo(ctx interface{}, database interface{}) *MockCache_GetDatabaseInfo_Call {
return &MockCache_GetDatabaseInfo_Call{Call: _e.mock.On("GetDatabaseInfo", ctx, database)}
}
func (_c *MockCache_GetDatabaseInfo_Call) Run(run func(ctx context.Context, database string)) *MockCache_GetDatabaseInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string))
})
return _c
}
func (_c *MockCache_GetDatabaseInfo_Call) Return(_a0 *databaseInfo, _a1 error) *MockCache_GetDatabaseInfo_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockCache_GetDatabaseInfo_Call) RunAndReturn(run func(context.Context, string) (*databaseInfo, error)) *MockCache_GetDatabaseInfo_Call {
_c.Call.Return(run)
return _c
}
// GetPartitionID provides a mock function with given fields: ctx, database, collectionName, partitionName
func (_m *MockCache) GetPartitionID(ctx context.Context, database string, collectionName string, partitionName string) (int64, error) {
ret := _m.Called(ctx, database, collectionName, partitionName)

View File

@ -1,375 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"fmt"
"strconv"
"sync"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/pkg/config"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
var QuotaErrorString = map[commonpb.ErrorCode]string{
commonpb.ErrorCode_ForceDeny: "the writing has been deactivated by the administrator",
commonpb.ErrorCode_MemoryQuotaExhausted: "memory quota exceeded, please allocate more resources",
commonpb.ErrorCode_DiskQuotaExhausted: "disk quota exceeded, please allocate more resources",
commonpb.ErrorCode_TimeTickLongDelay: "time tick long delay",
}
func GetQuotaErrorString(errCode commonpb.ErrorCode) string {
return QuotaErrorString[errCode]
}
// MultiRateLimiter includes multilevel rate limiters, such as global rateLimiter,
// collection level rateLimiter and so on. It also implements Limiter interface.
type MultiRateLimiter struct {
quotaStatesMu sync.RWMutex
// for DML and DQL
collectionLimiters map[int64]*rateLimiter
// for DDL
globalDDLLimiter *rateLimiter
}
// NewMultiRateLimiter returns a new MultiRateLimiter.
func NewMultiRateLimiter() *MultiRateLimiter {
m := &MultiRateLimiter{
collectionLimiters: make(map[int64]*rateLimiter, 0),
globalDDLLimiter: newRateLimiter(true),
}
return m
}
// Check checks if request would be limited or denied.
func (m *MultiRateLimiter) Check(collectionIDs []int64, rt internalpb.RateType, n int) error {
if !Params.QuotaConfig.QuotaAndLimitsEnabled.GetAsBool() {
return nil
}
m.quotaStatesMu.RLock()
defer m.quotaStatesMu.RUnlock()
checkFunc := func(limiter *rateLimiter) error {
if limiter == nil {
return nil
}
limit, rate := limiter.limit(rt, n)
if rate == 0 {
return limiter.getQuotaExceededError(rt)
}
if limit {
return limiter.getRateLimitError(rate)
}
return nil
}
// first, check global level rate limits
ret := checkFunc(m.globalDDLLimiter)
// second check collection level rate limits
// only dml, dql and flush have collection level rate limits
if ret == nil && len(collectionIDs) > 0 && !isNotCollectionLevelLimitRequest(rt) {
// store done limiters to cancel them when error occurs.
doneLimiters := make([]*rateLimiter, 0, len(collectionIDs)+1)
doneLimiters = append(doneLimiters, m.globalDDLLimiter)
for _, collectionID := range collectionIDs {
ret = checkFunc(m.collectionLimiters[collectionID])
if ret != nil {
for _, limiter := range doneLimiters {
limiter.cancel(rt, n)
}
break
}
doneLimiters = append(doneLimiters, m.collectionLimiters[collectionID])
}
}
return ret
}
func isNotCollectionLevelLimitRequest(rt internalpb.RateType) bool {
// Most ddl is global level, only DDLFlush will be applied at collection
switch rt {
case internalpb.RateType_DDLCollection, internalpb.RateType_DDLPartition, internalpb.RateType_DDLIndex,
internalpb.RateType_DDLCompaction:
return true
default:
return false
}
}
// GetQuotaStates returns quota states.
func (m *MultiRateLimiter) GetQuotaStates() ([]milvuspb.QuotaState, []string) {
m.quotaStatesMu.RLock()
defer m.quotaStatesMu.RUnlock()
serviceStates := make(map[milvuspb.QuotaState]typeutil.Set[commonpb.ErrorCode])
// deduplicate same (state, code) pair from different collection
for _, limiter := range m.collectionLimiters {
limiter.quotaStates.Range(func(state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool {
if serviceStates[state] == nil {
serviceStates[state] = typeutil.NewSet[commonpb.ErrorCode]()
}
serviceStates[state].Insert(errCode)
return true
})
}
states := make([]milvuspb.QuotaState, 0)
reasons := make([]string, 0)
for state, errCodes := range serviceStates {
for errCode := range errCodes {
states = append(states, state)
reasons = append(reasons, GetQuotaErrorString(errCode))
}
}
return states, reasons
}
// SetQuotaStates sets quota states for MultiRateLimiter.
func (m *MultiRateLimiter) SetRates(rates []*proxypb.CollectionRate) error {
m.quotaStatesMu.Lock()
defer m.quotaStatesMu.Unlock()
collectionSet := typeutil.NewUniqueSet()
for _, collectionRates := range rates {
collectionSet.Insert(collectionRates.Collection)
rateLimiter, ok := m.collectionLimiters[collectionRates.GetCollection()]
if !ok {
rateLimiter = newRateLimiter(false)
}
err := rateLimiter.setRates(collectionRates)
if err != nil {
return err
}
m.collectionLimiters[collectionRates.GetCollection()] = rateLimiter
}
// remove dropped collection's rate limiter
for collectionID := range m.collectionLimiters {
if !collectionSet.Contain(collectionID) {
delete(m.collectionLimiters, collectionID)
}
}
return nil
}
// rateLimiter implements Limiter.
type rateLimiter struct {
limiters *typeutil.ConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter]
quotaStates *typeutil.ConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]
}
// newRateLimiter returns a new RateLimiter.
func newRateLimiter(globalLevel bool) *rateLimiter {
rl := &rateLimiter{
limiters: typeutil.NewConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter](),
quotaStates: typeutil.NewConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode](),
}
rl.registerLimiters(globalLevel)
return rl
}
// limit returns true, the request will be rejected.
// Otherwise, the request will pass.
func (rl *rateLimiter) limit(rt internalpb.RateType, n int) (bool, float64) {
limit, ok := rl.limiters.Get(rt)
if !ok {
return false, -1
}
return !limit.AllowN(time.Now(), n), float64(limit.Limit())
}
func (rl *rateLimiter) cancel(rt internalpb.RateType, n int) {
limit, ok := rl.limiters.Get(rt)
if !ok {
return
}
limit.Cancel(n)
}
func (rl *rateLimiter) setRates(collectionRate *proxypb.CollectionRate) error {
log := log.Ctx(context.TODO()).WithRateGroup("proxy.rateLimiter", 1.0, 60.0).With(
zap.Int64("proxyNodeID", paramtable.GetNodeID()),
zap.Int64("CollectionID", collectionRate.Collection),
)
for _, r := range collectionRate.GetRates() {
if limit, ok := rl.limiters.Get(r.GetRt()); ok {
limit.SetLimit(ratelimitutil.Limit(r.GetR()))
setRateGaugeByRateType(r.GetRt(), paramtable.GetNodeID(), collectionRate.Collection, r.GetR())
} else {
return fmt.Errorf("unregister rateLimiter for rateType %s", r.GetRt().String())
}
log.RatedDebug(30, "current collection rates in proxy",
zap.String("rateType", r.Rt.String()),
zap.String("rateLimit", ratelimitutil.Limit(r.GetR()).String()),
)
}
// clear old quota states
rl.quotaStates = typeutil.NewConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]()
for i := 0; i < len(collectionRate.GetStates()); i++ {
rl.quotaStates.Insert(collectionRate.States[i], collectionRate.Codes[i])
log.RatedWarn(30, "Proxy set collection quota states",
zap.String("state", collectionRate.GetStates()[i].String()),
zap.String("reason", collectionRate.GetCodes()[i].String()),
)
}
return nil
}
func (rl *rateLimiter) getQuotaExceededError(rt internalpb.RateType) error {
switch rt {
case internalpb.RateType_DMLInsert, internalpb.RateType_DMLUpsert, internalpb.RateType_DMLDelete, internalpb.RateType_DMLBulkLoad:
if errCode, ok := rl.quotaStates.Get(milvuspb.QuotaState_DenyToWrite); ok {
return merr.WrapErrServiceQuotaExceeded(GetQuotaErrorString(errCode))
}
case internalpb.RateType_DQLSearch, internalpb.RateType_DQLQuery:
if errCode, ok := rl.quotaStates.Get(milvuspb.QuotaState_DenyToRead); ok {
return merr.WrapErrServiceQuotaExceeded(GetQuotaErrorString(errCode))
}
}
return nil
}
func (rl *rateLimiter) getRateLimitError(rate float64) error {
return merr.WrapErrServiceRateLimit(rate, "request is rejected by grpc RateLimiter middleware, please retry later")
}
// setRateGaugeByRateType sets ProxyLimiterRate metrics.
func setRateGaugeByRateType(rateType internalpb.RateType, nodeID int64, collectionID int64, rate float64) {
if ratelimitutil.Limit(rate) == ratelimitutil.Inf {
return
}
nodeIDStr := strconv.FormatInt(nodeID, 10)
collectionIDStr := strconv.FormatInt(collectionID, 10)
switch rateType {
case internalpb.RateType_DMLInsert:
metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, collectionIDStr, metrics.InsertLabel).Set(rate)
case internalpb.RateType_DMLUpsert:
metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, collectionIDStr, metrics.UpsertLabel).Set(rate)
case internalpb.RateType_DMLDelete:
metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, collectionIDStr, metrics.DeleteLabel).Set(rate)
case internalpb.RateType_DQLSearch:
metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, collectionIDStr, metrics.SearchLabel).Set(rate)
case internalpb.RateType_DQLQuery:
metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, collectionIDStr, metrics.QueryLabel).Set(rate)
}
}
// registerLimiters register limiter for all rate types.
func (rl *rateLimiter) registerLimiters(globalLevel bool) {
log := log.Ctx(context.TODO()).WithRateGroup("proxy.rateLimiter", 1.0, 60.0)
quotaConfig := &Params.QuotaConfig
for rt := range internalpb.RateType_name {
var r *paramtable.ParamItem
switch internalpb.RateType(rt) {
case internalpb.RateType_DDLCollection:
r = &quotaConfig.DDLCollectionRate
case internalpb.RateType_DDLPartition:
r = &quotaConfig.DDLPartitionRate
case internalpb.RateType_DDLIndex:
r = &quotaConfig.MaxIndexRate
case internalpb.RateType_DDLFlush:
if globalLevel {
r = &quotaConfig.MaxFlushRate
} else {
r = &quotaConfig.MaxFlushRatePerCollection
}
case internalpb.RateType_DDLCompaction:
r = &quotaConfig.MaxCompactionRate
case internalpb.RateType_DMLInsert:
if globalLevel {
r = &quotaConfig.DMLMaxInsertRate
} else {
r = &quotaConfig.DMLMaxInsertRatePerCollection
}
case internalpb.RateType_DMLUpsert:
if globalLevel {
r = &quotaConfig.DMLMaxUpsertRate
} else {
r = &quotaConfig.DMLMaxUpsertRatePerCollection
}
case internalpb.RateType_DMLDelete:
if globalLevel {
r = &quotaConfig.DMLMaxDeleteRate
} else {
r = &quotaConfig.DMLMaxDeleteRatePerCollection
}
case internalpb.RateType_DMLBulkLoad:
if globalLevel {
r = &quotaConfig.DMLMaxBulkLoadRate
} else {
r = &quotaConfig.DMLMaxBulkLoadRatePerCollection
}
case internalpb.RateType_DQLSearch:
if globalLevel {
r = &quotaConfig.DQLMaxSearchRate
} else {
r = &quotaConfig.DQLMaxSearchRatePerCollection
}
case internalpb.RateType_DQLQuery:
if globalLevel {
r = &quotaConfig.DQLMaxQueryRate
} else {
r = &quotaConfig.DQLMaxQueryRatePerCollection
}
}
limit := ratelimitutil.Limit(r.GetAsFloat())
burst := r.GetAsFloat() // use rate as burst, because Limiter is with punishment mechanism, burst is insignificant.
rl.limiters.GetOrInsert(internalpb.RateType(rt), ratelimitutil.NewLimiter(limit, burst))
onEvent := func(rateType internalpb.RateType) func(*config.Event) {
return func(event *config.Event) {
f, err := strconv.ParseFloat(r.Formatter(event.Value), 64)
if err != nil {
log.Info("Error format for rateLimit",
zap.String("rateType", rateType.String()),
zap.String("key", event.Key),
zap.String("value", event.Value),
zap.Error(err))
return
}
limit, ok := rl.limiters.Get(rateType)
if !ok {
return
}
limit.SetLimit(ratelimitutil.Limit(f))
}
}(internalpb.RateType(rt))
paramtable.Get().Watch(r.Key, config.NewHandler(fmt.Sprintf("rateLimiter-%d", rt), onEvent))
log.RatedDebug(30, "RateLimiter register for rateType",
zap.String("rateType", internalpb.RateType_name[rt]),
zap.String("rateLimit", ratelimitutil.Limit(r.GetAsFloat()).String()),
zap.String("burst", fmt.Sprintf("%v", burst)))
}
}

View File

@ -1,338 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"fmt"
"math"
"math/rand"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
)
func TestMultiRateLimiter(t *testing.T) {
collectionID := int64(1)
t.Run("test multiRateLimiter", func(t *testing.T) {
bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue()
paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
multiLimiter := NewMultiRateLimiter()
multiLimiter.collectionLimiters[collectionID] = newRateLimiter(false)
for _, rt := range internalpb.RateType_value {
if isNotCollectionLevelLimitRequest(internalpb.RateType(rt)) {
multiLimiter.globalDDLLimiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(5), 1))
} else {
multiLimiter.collectionLimiters[collectionID].limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1))
}
}
for _, rt := range internalpb.RateType_value {
if isNotCollectionLevelLimitRequest(internalpb.RateType(rt)) {
err := multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), 1)
assert.NoError(t, err)
err = multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), 5)
assert.NoError(t, err)
err = multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), 5)
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
} else {
err := multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), 1)
assert.NoError(t, err)
err = multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), math.MaxInt)
assert.NoError(t, err)
err = multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), math.MaxInt)
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
}
}
Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak)
})
t.Run("test global static limit", func(t *testing.T) {
bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue()
paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
multiLimiter := NewMultiRateLimiter()
multiLimiter.collectionLimiters[1] = newRateLimiter(false)
multiLimiter.collectionLimiters[2] = newRateLimiter(false)
multiLimiter.collectionLimiters[3] = newRateLimiter(false)
for _, rt := range internalpb.RateType_value {
if isNotCollectionLevelLimitRequest(internalpb.RateType(rt)) {
multiLimiter.globalDDLLimiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(5), 1))
} else {
multiLimiter.globalDDLLimiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1))
multiLimiter.collectionLimiters[1].limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1))
multiLimiter.collectionLimiters[2].limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1))
multiLimiter.collectionLimiters[3].limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1))
}
}
for _, rt := range internalpb.RateType_value {
if internalpb.RateType(rt) == internalpb.RateType_DDLFlush {
err := multiLimiter.Check([]int64{1, 2, 3}, internalpb.RateType(rt), 1)
assert.NoError(t, err)
err = multiLimiter.Check([]int64{1, 2, 3}, internalpb.RateType(rt), 5)
assert.NoError(t, err)
err = multiLimiter.Check([]int64{1, 2, 3}, internalpb.RateType(rt), 5)
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
} else if isNotCollectionLevelLimitRequest(internalpb.RateType(rt)) {
err := multiLimiter.Check([]int64{1}, internalpb.RateType(rt), 1)
assert.NoError(t, err)
err = multiLimiter.Check([]int64{1}, internalpb.RateType(rt), 5)
assert.NoError(t, err)
err = multiLimiter.Check([]int64{1}, internalpb.RateType(rt), 5)
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
} else {
err := multiLimiter.Check([]int64{1}, internalpb.RateType(rt), 1)
assert.NoError(t, err)
err = multiLimiter.Check([]int64{2}, internalpb.RateType(rt), 1)
assert.NoError(t, err)
err = multiLimiter.Check([]int64{3}, internalpb.RateType(rt), 1)
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
}
}
Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak)
})
t.Run("not enable quotaAndLimit", func(t *testing.T) {
multiLimiter := NewMultiRateLimiter()
multiLimiter.collectionLimiters[collectionID] = newRateLimiter(false)
bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue()
paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
for _, rt := range internalpb.RateType_value {
err := multiLimiter.Check([]int64{collectionID}, internalpb.RateType(rt), 1)
assert.NoError(t, err)
}
Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak)
})
t.Run("test limit", func(t *testing.T) {
run := func(insertRate float64) {
bakInsertRate := Params.QuotaConfig.DMLMaxInsertRate.GetValue()
paramtable.Get().Save(Params.QuotaConfig.DMLMaxInsertRate.Key, fmt.Sprintf("%f", insertRate))
multiLimiter := NewMultiRateLimiter()
bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue()
paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
err := multiLimiter.Check([]int64{collectionID}, internalpb.RateType_DMLInsert, 1*1024*1024)
assert.NoError(t, err)
Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak)
Params.Save(Params.QuotaConfig.DMLMaxInsertRate.Key, bakInsertRate)
}
run(math.MaxFloat64)
run(math.MaxFloat64 / 1.2)
run(math.MaxFloat64 / 2)
run(math.MaxFloat64 / 3)
run(math.MaxFloat64 / 10000)
})
t.Run("test set rates", func(t *testing.T) {
multiLimiter := NewMultiRateLimiter()
zeroRates := make([]*internalpb.Rate, 0, len(internalpb.RateType_value))
for _, rt := range internalpb.RateType_value {
zeroRates = append(zeroRates, &internalpb.Rate{
Rt: internalpb.RateType(rt), R: 0,
})
}
err := multiLimiter.SetRates([]*proxypb.CollectionRate{
{
Collection: 1,
Rates: zeroRates,
},
{
Collection: 2,
Rates: zeroRates,
},
})
assert.NoError(t, err)
})
t.Run("test quota states", func(t *testing.T) {
multiLimiter := NewMultiRateLimiter()
zeroRates := make([]*internalpb.Rate, 0, len(internalpb.RateType_value))
for _, rt := range internalpb.RateType_value {
zeroRates = append(zeroRates, &internalpb.Rate{
Rt: internalpb.RateType(rt), R: 0,
})
}
err := multiLimiter.SetRates([]*proxypb.CollectionRate{
{
Collection: 1,
Rates: zeroRates,
States: []milvuspb.QuotaState{
milvuspb.QuotaState_DenyToWrite,
},
Codes: []commonpb.ErrorCode{
commonpb.ErrorCode_DiskQuotaExhausted,
},
},
{
Collection: 2,
Rates: zeroRates,
States: []milvuspb.QuotaState{
milvuspb.QuotaState_DenyToRead,
},
Codes: []commonpb.ErrorCode{
commonpb.ErrorCode_ForceDeny,
},
},
})
assert.NoError(t, err)
states, codes := multiLimiter.GetQuotaStates()
assert.Len(t, states, 2)
assert.Len(t, codes, 2)
assert.Contains(t, codes, GetQuotaErrorString(commonpb.ErrorCode_DiskQuotaExhausted))
assert.Contains(t, codes, GetQuotaErrorString(commonpb.ErrorCode_ForceDeny))
})
}
func TestRateLimiter(t *testing.T) {
t.Run("test limit", func(t *testing.T) {
paramtable.Get().CleanEvent()
limiter := newRateLimiter(false)
for _, rt := range internalpb.RateType_value {
limiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1))
}
for _, rt := range internalpb.RateType_value {
ok, _ := limiter.limit(internalpb.RateType(rt), 1)
assert.False(t, ok)
ok, _ = limiter.limit(internalpb.RateType(rt), math.MaxInt)
assert.False(t, ok)
ok, _ = limiter.limit(internalpb.RateType(rt), math.MaxInt)
assert.True(t, ok)
}
})
t.Run("test setRates", func(t *testing.T) {
paramtable.Get().CleanEvent()
limiter := newRateLimiter(false)
for _, rt := range internalpb.RateType_value {
limiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1))
}
zeroRates := make([]*internalpb.Rate, 0, len(internalpb.RateType_value))
for _, rt := range internalpb.RateType_value {
zeroRates = append(zeroRates, &internalpb.Rate{
Rt: internalpb.RateType(rt), R: 0,
})
}
err := limiter.setRates(&proxypb.CollectionRate{
Collection: 1,
Rates: zeroRates,
})
assert.NoError(t, err)
for _, rt := range internalpb.RateType_value {
for i := 0; i < 100; i++ {
ok, _ := limiter.limit(internalpb.RateType(rt), 1)
assert.True(t, ok)
}
}
err = limiter.setRates(&proxypb.CollectionRate{
Collection: 1,
States: []milvuspb.QuotaState{milvuspb.QuotaState_DenyToRead, milvuspb.QuotaState_DenyToWrite},
Codes: []commonpb.ErrorCode{commonpb.ErrorCode_DiskQuotaExhausted, commonpb.ErrorCode_DiskQuotaExhausted},
})
assert.NoError(t, err)
assert.Equal(t, limiter.quotaStates.Len(), 2)
err = limiter.setRates(&proxypb.CollectionRate{
Collection: 1,
States: []milvuspb.QuotaState{},
})
assert.NoError(t, err)
assert.Equal(t, limiter.quotaStates.Len(), 0)
})
t.Run("test get error code", func(t *testing.T) {
paramtable.Get().CleanEvent()
limiter := newRateLimiter(false)
for _, rt := range internalpb.RateType_value {
limiter.limiters.Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1))
}
zeroRates := make([]*internalpb.Rate, 0, len(internalpb.RateType_value))
for _, rt := range internalpb.RateType_value {
zeroRates = append(zeroRates, &internalpb.Rate{
Rt: internalpb.RateType(rt), R: 0,
})
}
err := limiter.setRates(&proxypb.CollectionRate{
Collection: 1,
Rates: zeroRates,
States: []milvuspb.QuotaState{
milvuspb.QuotaState_DenyToWrite,
milvuspb.QuotaState_DenyToRead,
},
Codes: []commonpb.ErrorCode{
commonpb.ErrorCode_DiskQuotaExhausted,
commonpb.ErrorCode_ForceDeny,
},
})
assert.NoError(t, err)
assert.Error(t, limiter.getQuotaExceededError(internalpb.RateType_DQLQuery))
assert.Error(t, limiter.getQuotaExceededError(internalpb.RateType_DMLInsert))
})
t.Run("tests refresh rate by config", func(t *testing.T) {
paramtable.Get().CleanEvent()
limiter := newRateLimiter(false)
etcdCli, _ := etcd.GetEtcdClient(
Params.EtcdCfg.UseEmbedEtcd.GetAsBool(),
Params.EtcdCfg.EtcdUseSSL.GetAsBool(),
Params.EtcdCfg.Endpoints.GetAsStrings(),
Params.EtcdCfg.EtcdTLSCert.GetValue(),
Params.EtcdCfg.EtcdTLSKey.GetValue(),
Params.EtcdCfg.EtcdTLSCACert.GetValue(),
Params.EtcdCfg.EtcdTLSMinVersion.GetValue())
Params.Save(Params.QuotaConfig.DDLLimitEnabled.Key, "true")
defer Params.Reset(Params.QuotaConfig.DDLLimitEnabled.Key)
Params.Save(Params.QuotaConfig.DMLLimitEnabled.Key, "true")
defer Params.Reset(Params.QuotaConfig.DMLLimitEnabled.Key)
ctx := context.Background()
// append a trailing digit so the value never ends in zero, avoiding formatting/precision issues when the limit is compared as a string
newRate := fmt.Sprintf("%.2f1", rand.Float64())
etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/ddl/collectionRate", newRate)
defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/ddl/collectionRate")
etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/ddl/partitionRate", "invalid")
defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/ddl/partitionRate")
etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/dml/insertRate/collection/max", "8")
defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/dml/insertRate/collection/max")
assert.Eventually(t, func() bool {
limit, _ := limiter.limiters.Get(internalpb.RateType_DDLCollection)
return newRate == limit.Limit().String()
}, 20*time.Second, time.Second)
limit, _ := limiter.limiters.Get(internalpb.RateType_DDLPartition)
assert.Equal(t, "+inf", limit.Limit().String())
limit, _ = limiter.limiters.Get(internalpb.RateType_DMLInsert)
assert.Equal(t, "8.388608e+06", limit.Limit().String())
})
}

View File

@ -93,7 +93,7 @@ type Proxy struct {
dataCoord types.DataCoordClient
queryCoord types.QueryCoordClient
multiRateLimiter *MultiRateLimiter
simpleLimiter *SimpleLimiter
chMgr channelsMgr
@ -147,7 +147,7 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
factory: factory,
searchResultCh: make(chan *internalpb.SearchResults, n),
shardMgr: mgr,
multiRateLimiter: NewMultiRateLimiter(),
simpleLimiter: NewSimpleLimiter(),
lbPolicy: lbPolicy,
resourceManager: resourceManager,
replicateStreamManager: replicateStreamManager,
@ -197,7 +197,7 @@ func (node *Proxy) initSession() error {
// initRateCollector creates and starts rateCollector in Proxy.
func (node *Proxy) initRateCollector() error {
var err error
rateCol, err = ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity)
rateCol, err = ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity, true)
if err != nil {
return err
}
@ -542,8 +542,8 @@ func (node *Proxy) SetQueryNodeCreator(f func(ctx context.Context, addr string,
// GetRateLimiter returns the rateLimiter in Proxy.
func (node *Proxy) GetRateLimiter() (types.Limiter, error) {
if node.multiRateLimiter == nil {
if node.simpleLimiter == nil {
return nil, fmt.Errorf("nil rate limiter in Proxy")
}
return node.multiRateLimiter, nil
return node.simpleLimiter, nil
}

View File

@ -298,8 +298,7 @@ func (s *proxyTestServer) startGrpc(ctx context.Context, wg *sync.WaitGroup, p *
ctx, cancel := context.WithCancel(ctx)
defer cancel()
multiLimiter := NewMultiRateLimiter()
s.multiRateLimiter = multiLimiter
s.simpleLimiter = NewSimpleLimiter()
opts := tracer.GetInterceptorOpts()
s.grpcServer = grpc.NewServer(
@ -309,7 +308,7 @@ func (s *proxyTestServer) startGrpc(ctx context.Context, wg *sync.WaitGroup, p *
grpc.MaxSendMsgSize(p.ServerMaxSendSize.GetAsInt()),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
otelgrpc.UnaryServerInterceptor(opts...),
RateLimitInterceptor(multiLimiter),
RateLimitInterceptor(s.simpleLimiter),
)),
grpc.StreamInterceptor(otelgrpc.StreamServerInterceptor(opts...)))
proxypb.RegisterProxyServer(s.grpcServer, s)

View File

@ -23,25 +23,29 @@ import (
"strconv"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/requestutil"
)
// RateLimitInterceptor returns a new unary server interceptor that performs request rate limiting.
func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
collectionIDs, rt, n, err := getRequestInfo(req)
dbID, collectionIDToPartIDs, rt, n, err := getRequestInfo(ctx, req)
if err != nil {
log.RatedWarn(10, "failed to get request info", zap.Error(err))
return handler(ctx, req)
}
err = limiter.Check(collectionIDs, rt, n)
err = limiter.Check(dbID, collectionIDToPartIDs, rt, n)
nodeID := strconv.FormatInt(paramtable.GetNodeID(), 10)
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.TotalLabel).Inc()
if err != nil {
@ -56,72 +60,146 @@ func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor {
}
}
type reqPartName interface {
requestutil.DBNameGetter
requestutil.CollectionNameGetter
requestutil.PartitionNameGetter
}
type reqPartNames interface {
requestutil.DBNameGetter
requestutil.CollectionNameGetter
requestutil.PartitionNamesGetter
}
type reqCollName interface {
requestutil.DBNameGetter
requestutil.CollectionNameGetter
}
func getCollectionAndPartitionID(ctx context.Context, r reqPartName) (int64, map[int64][]int64, error) {
db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName())
if err != nil {
return 0, nil, err
}
collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), r.GetCollectionName())
if err != nil {
return 0, nil, err
}
part, err := globalMetaCache.GetPartitionInfo(ctx, r.GetDbName(), r.GetCollectionName(), r.GetPartitionName())
if err != nil {
return 0, nil, err
}
return db.dbID, map[int64][]int64{collectionID: {part.partitionID}}, nil
}
func getCollectionAndPartitionIDs(ctx context.Context, r reqPartNames) (int64, map[int64][]int64, error) {
db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName())
if err != nil {
return 0, nil, err
}
collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), r.GetCollectionName())
if err != nil {
return 0, nil, err
}
parts := make([]int64, len(r.GetPartitionNames()))
for i, s := range r.GetPartitionNames() {
part, err := globalMetaCache.GetPartitionInfo(ctx, r.GetDbName(), r.GetCollectionName(), s)
if err != nil {
return 0, nil, err
}
parts[i] = part.partitionID
}
return db.dbID, map[int64][]int64{collectionID: parts}, nil
}
func getCollectionID(r reqCollName) (int64, map[int64][]int64) {
db, _ := globalMetaCache.GetDatabaseInfo(context.TODO(), r.GetDbName())
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return db.dbID, map[int64][]int64{collectionID: {}}
}
// getRequestInfo returns the db id, the collection-to-partition id map, the rate type of the request, and the number of tokens needed.
func getRequestInfo(req interface{}) ([]int64, internalpb.RateType, int, error) {
func getRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]int64, internalpb.RateType, int, error) {
switch r := req.(type) {
case *milvuspb.InsertRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return []int64{collectionID}, internalpb.RateType_DMLInsert, proto.Size(r), nil
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
return dbID, collToPartIDs, internalpb.RateType_DMLInsert, proto.Size(r), err
case *milvuspb.UpsertRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return []int64{collectionID}, internalpb.RateType_DMLUpsert, proto.Size(r), nil
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
return dbID, collToPartIDs, internalpb.RateType_DMLUpsert, proto.Size(r), err
case *milvuspb.DeleteRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return []int64{collectionID}, internalpb.RateType_DMLDelete, proto.Size(r), nil
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
return dbID, collToPartIDs, internalpb.RateType_DMLDelete, proto.Size(r), err
case *milvuspb.ImportRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return []int64{collectionID}, internalpb.RateType_DMLBulkLoad, proto.Size(r), nil
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
return dbID, collToPartIDs, internalpb.RateType_DMLBulkLoad, proto.Size(r), err
case *milvuspb.SearchRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return []int64{collectionID}, internalpb.RateType_DQLSearch, int(r.GetNq()), nil
dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames))
return dbID, collToPartIDs, internalpb.RateType_DQLSearch, int(r.GetNq()), err
case *milvuspb.QueryRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return []int64{collectionID}, internalpb.RateType_DQLQuery, 1, nil // think of the query request's nq as 1
dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames))
return dbID, collToPartIDs, internalpb.RateType_DQLQuery, 1, err // think of the query request's nq as 1
case *milvuspb.CreateCollectionRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return []int64{collectionID}, internalpb.RateType_DDLCollection, 1, nil
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.DropCollectionRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return []int64{collectionID}, internalpb.RateType_DDLCollection, 1, nil
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.LoadCollectionRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return []int64{collectionID}, internalpb.RateType_DDLCollection, 1, nil
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.ReleaseCollectionRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return []int64{collectionID}, internalpb.RateType_DDLCollection, 1, nil
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.CreatePartitionRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return []int64{collectionID}, internalpb.RateType_DDLPartition, 1, nil
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.DropPartitionRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return []int64{collectionID}, internalpb.RateType_DDLPartition, 1, nil
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.LoadPartitionsRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return []int64{collectionID}, internalpb.RateType_DDLPartition, 1, nil
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.ReleasePartitionsRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return []int64{collectionID}, internalpb.RateType_DDLPartition, 1, nil
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.CreateIndexRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return []int64{collectionID}, internalpb.RateType_DDLIndex, 1, nil
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil
case *milvuspb.DropIndexRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return []int64{collectionID}, internalpb.RateType_DDLIndex, 1, nil
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil
case *milvuspb.FlushRequest:
collectionIDs := make([]int64, 0, len(r.GetCollectionNames()))
db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName())
if err != nil {
return 0, map[int64][]int64{}, 0, 0, err
}
collToPartIDs := make(map[int64][]int64, 0)
for _, collectionName := range r.GetCollectionNames() {
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), collectionName)
collectionIDs = append(collectionIDs, collectionID)
collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), collectionName)
if err != nil {
return 0, map[int64][]int64{}, 0, 0, err
}
collToPartIDs[collectionID] = []int64{}
}
return collectionIDs, internalpb.RateType_DDLFlush, 1, nil
return db.dbID, collToPartIDs, internalpb.RateType_DDLFlush, 1, nil
case *milvuspb.ManualCompactionRequest:
return nil, internalpb.RateType_DDLCompaction, 1, nil
// TODO: support more request
default:
if req == nil {
return nil, 0, 0, fmt.Errorf("null request")
dbName := GetCurDBNameFromContextOrDefault(ctx)
dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, dbName)
if err != nil {
return 0, map[int64][]int64{}, 0, 0, err
}
return nil, 0, 0, fmt.Errorf("unsupported request type %s", reflect.TypeOf(req).Name())
return dbInfo.dbID, map[int64][]int64{
r.GetCollectionID(): {},
}, internalpb.RateType_DDLCompaction, 1, nil
default: // TODO: support more request
if req == nil {
return 0, map[int64][]int64{}, 0, 0, fmt.Errorf("null request")
}
return 0, map[int64][]int64{}, 0, 0, fmt.Errorf("unsupported request type %s", reflect.TypeOf(req).Name())
}
}
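For orientation, here is a minimal, self-contained sketch of the shape getRequestInfo now hands to the interceptor: a db id plus a collection-to-partition map instead of the old flat collection slice. All names below are illustrative stand-ins, not the real proxy types.
package main
import "fmt"
// requestInfo mirrors the tuple returned by getRequestInfo in this patch;
// the field names are hypothetical and exist only for this sketch.
type requestInfo struct {
	dbID        int64
	collToParts map[int64][]int64
	rateType    string
	tokens      int
}
func main() {
	// an insert into partition 10 of collection 1 inside db 100;
	// tokens would be proto.Size(request) in the real interceptor
	info := requestInfo{
		dbID:        100,
		collToParts: map[int64][]int64{1: {10}},
		rateType:    "DMLInsert",
		tokens:      2048,
	}
	fmt.Printf("db=%d colls=%v rt=%s n=%d\n", info.dbID, info.collToParts, info.rateType, info.tokens)
}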

View File

@ -20,6 +20,7 @@ import (
"context"
"testing"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
@ -38,7 +39,7 @@ type limiterMock struct {
quotaStateReasons []commonpb.ErrorCode
}
func (l *limiterMock) Check(collection []int64, rt internalpb.RateType, n int) error {
func (l *limiterMock) Check(dbID int64, collectionIDToPartIDs map[int64][]int64, rt internalpb.RateType, n int) error {
if l.rate == 0 {
return merr.ErrServiceQuotaExceeded
}
@ -51,119 +52,173 @@ func (l *limiterMock) Check(collection []int64, rt internalpb.RateType, n int) e
func TestRateLimitInterceptor(t *testing.T) {
t.Run("test getRequestInfo", func(t *testing.T) {
mockCache := NewMockCache(t)
mockCache.On("GetCollectionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(int64(0), nil)
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil)
mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{
name: "p1",
partitionID: 10,
createdTimestamp: 10001,
createdUtcTimestamp: 10002,
}, nil)
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
dbID: 100,
createdTimestamp: 1,
}, nil)
globalMetaCache = mockCache
collection, rt, size, err := getRequestInfo(&milvuspb.InsertRequest{})
database, col2part, rt, size, err := getRequestInfo(context.Background(), &milvuspb.InsertRequest{})
assert.NoError(t, err)
assert.Equal(t, proto.Size(&milvuspb.InsertRequest{}), size)
assert.Equal(t, internalpb.RateType_DMLInsert, rt)
assert.ElementsMatch(t, collection, []int64{int64(0)})
assert.Equal(t, database, int64(100))
assert.True(t, len(col2part) == 1)
assert.Equal(t, int64(10), col2part[1][0])
collection, rt, size, err = getRequestInfo(&milvuspb.UpsertRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.UpsertRequest{})
assert.NoError(t, err)
assert.Equal(t, proto.Size(&milvuspb.UpsertRequest{}), size)
assert.Equal(t, internalpb.RateType_DMLUpsert, rt)
assert.ElementsMatch(t, collection, []int64{int64(0)})
assert.Equal(t, database, int64(100))
assert.True(t, len(col2part) == 1)
assert.Equal(t, int64(10), col2part[1][0])
collection, rt, size, err = getRequestInfo(&milvuspb.DeleteRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DeleteRequest{})
assert.NoError(t, err)
assert.Equal(t, proto.Size(&milvuspb.DeleteRequest{}), size)
assert.Equal(t, internalpb.RateType_DMLDelete, rt)
assert.ElementsMatch(t, collection, []int64{int64(0)})
assert.Equal(t, database, int64(100))
assert.True(t, len(col2part) == 1)
assert.Equal(t, int64(10), col2part[1][0])
collection, rt, size, err = getRequestInfo(&milvuspb.ImportRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ImportRequest{})
assert.NoError(t, err)
assert.Equal(t, proto.Size(&milvuspb.ImportRequest{}), size)
assert.Equal(t, internalpb.RateType_DMLBulkLoad, rt)
assert.ElementsMatch(t, collection, []int64{int64(0)})
assert.Equal(t, database, int64(100))
assert.True(t, len(col2part) == 1)
assert.Equal(t, int64(10), col2part[1][0])
collection, rt, size, err = getRequestInfo(&milvuspb.SearchRequest{Nq: 5})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.SearchRequest{
Nq: 5,
PartitionNames: []string{
"p1",
},
})
assert.NoError(t, err)
assert.Equal(t, 5, size)
assert.Equal(t, internalpb.RateType_DQLSearch, rt)
assert.ElementsMatch(t, collection, []int64{int64(0)})
assert.Equal(t, database, int64(100))
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 1, len(col2part[1]))
collection, rt, size, err = getRequestInfo(&milvuspb.QueryRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.QueryRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DQLQuery, rt)
assert.ElementsMatch(t, collection, []int64{int64(0)})
assert.Equal(t, database, int64(100))
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
collection, rt, size, err = getRequestInfo(&milvuspb.CreateCollectionRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreateCollectionRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLCollection, rt)
assert.ElementsMatch(t, collection, []int64{int64(0)})
assert.Equal(t, database, int64(100))
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
collection, rt, size, err = getRequestInfo(&milvuspb.LoadCollectionRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.LoadCollectionRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLCollection, rt)
assert.ElementsMatch(t, collection, []int64{int64(0)})
assert.Equal(t, database, int64(100))
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
collection, rt, size, err = getRequestInfo(&milvuspb.ReleaseCollectionRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ReleaseCollectionRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLCollection, rt)
assert.ElementsMatch(t, collection, []int64{int64(0)})
assert.Equal(t, database, int64(100))
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
collection, rt, size, err = getRequestInfo(&milvuspb.DropCollectionRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DropCollectionRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLCollection, rt)
assert.ElementsMatch(t, collection, []int64{int64(0)})
assert.Equal(t, database, int64(100))
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
collection, rt, size, err = getRequestInfo(&milvuspb.CreatePartitionRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreatePartitionRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLPartition, rt)
assert.ElementsMatch(t, collection, []int64{int64(0)})
assert.Equal(t, database, int64(100))
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
collection, rt, size, err = getRequestInfo(&milvuspb.LoadPartitionsRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.LoadPartitionsRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLPartition, rt)
assert.ElementsMatch(t, collection, []int64{int64(0)})
assert.Equal(t, database, int64(100))
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
collection, rt, size, err = getRequestInfo(&milvuspb.ReleasePartitionsRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ReleasePartitionsRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLPartition, rt)
assert.ElementsMatch(t, collection, []int64{int64(0)})
assert.Equal(t, database, int64(100))
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
collection, rt, size, err = getRequestInfo(&milvuspb.DropPartitionRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DropPartitionRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLPartition, rt)
assert.ElementsMatch(t, collection, []int64{int64(0)})
assert.Equal(t, database, int64(100))
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
collection, rt, size, err = getRequestInfo(&milvuspb.CreateIndexRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreateIndexRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLIndex, rt)
assert.ElementsMatch(t, collection, []int64{int64(0)})
assert.Equal(t, database, int64(100))
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
collection, rt, size, err = getRequestInfo(&milvuspb.DropIndexRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DropIndexRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLIndex, rt)
assert.ElementsMatch(t, collection, []int64{int64(0)})
assert.Equal(t, database, int64(100))
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
collection, rt, size, err = getRequestInfo(&milvuspb.FlushRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.FlushRequest{
CollectionNames: []string{
"col1",
},
})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLFlush, rt)
assert.Len(t, collection, 0)
assert.Equal(t, database, int64(100))
assert.Equal(t, 1, len(col2part))
collection, rt, size, err = getRequestInfo(&milvuspb.ManualCompactionRequest{})
database, _, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ManualCompactionRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DDLCompaction, rt)
assert.Len(t, collection, 0)
assert.Equal(t, database, int64(100))
_, _, _, _, err = getRequestInfo(context.Background(), nil)
assert.Error(t, err)
_, _, _, _, err = getRequestInfo(context.Background(), &milvuspb.CalcDistanceRequest{})
assert.Error(t, err)
})
t.Run("test getFailedResponse", func(t *testing.T) {
@ -190,11 +245,17 @@ func TestRateLimitInterceptor(t *testing.T) {
t.Run("test RateLimitInterceptor", func(t *testing.T) {
mockCache := NewMockCache(t)
mockCache.On("GetCollectionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(int64(0), nil)
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil)
mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{
name: "p1",
partitionID: 10,
createdTimestamp: 10001,
createdUtcTimestamp: 10002,
}, nil)
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
dbID: 100,
createdTimestamp: 1,
}, nil)
globalMetaCache = mockCache
limiter := limiterMock{rate: 100}
@ -224,4 +285,158 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.Equal(t, commonpb.ErrorCode_ForceDeny, rsp.(*milvuspb.MutationResult).GetStatus().GetErrorCode())
assert.NoError(t, err)
})
t.Run("request info fail", func(t *testing.T) {
mockCache := NewMockCache(t)
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: get database info"))
originCache := globalMetaCache
globalMetaCache = mockCache
defer func() {
globalMetaCache = originCache
}()
limiter := limiterMock{rate: 100}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return &milvuspb.MutationResult{
Status: merr.Success(),
}, nil
}
serverInfo := &grpc.UnaryServerInfo{FullMethod: "MockFullMethod"}
limiter.limit = true
interceptorFun := RateLimitInterceptor(&limiter)
rsp, err := interceptorFun(context.Background(), &milvuspb.InsertRequest{}, serverInfo, handler)
assert.Equal(t, commonpb.ErrorCode_Success, rsp.(*milvuspb.MutationResult).GetStatus().GetErrorCode())
assert.NoError(t, err)
})
}
func TestGetInfo(t *testing.T) {
mockCache := NewMockCache(t)
ctx := context.Background()
originCache := globalMetaCache
globalMetaCache = mockCache
defer func() {
globalMetaCache = originCache
}()
t.Run("fail to get database", func(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: get database info")).Times(4)
{
_, _, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
DbName: "foo",
CollectionName: "coo",
PartitionName: "p1",
})
assert.Error(t, err)
}
{
_, _, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{
DbName: "foo",
CollectionName: "coo",
PartitionNames: []string{"p1"},
})
assert.Error(t, err)
}
{
_, _, _, _, err := getRequestInfo(ctx, &milvuspb.FlushRequest{
DbName: "foo",
})
assert.Error(t, err)
}
{
_, _, _, _, err := getRequestInfo(ctx, &milvuspb.ManualCompactionRequest{})
assert.Error(t, err)
}
})
t.Run("fail to get collection", func(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
dbID: 100,
createdTimestamp: 1,
}, nil).Times(3)
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(0), errors.New("mock error: get collection id")).Times(3)
{
_, _, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
DbName: "foo",
CollectionName: "coo",
PartitionName: "p1",
})
assert.Error(t, err)
}
{
_, _, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{
DbName: "foo",
CollectionName: "coo",
PartitionNames: []string{"p1"},
})
assert.Error(t, err)
}
{
_, _, _, _, err := getRequestInfo(ctx, &milvuspb.FlushRequest{
DbName: "foo",
CollectionNames: []string{"coo"},
})
assert.Error(t, err)
}
})
t.Run("fail to get partition", func(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
dbID: 100,
createdTimestamp: 1,
}, nil).Twice()
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil).Twice()
mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("mock error: get partition info")).Twice()
{
_, _, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
DbName: "foo",
CollectionName: "coo",
PartitionName: "p1",
})
assert.Error(t, err)
}
{
_, _, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{
DbName: "foo",
CollectionName: "coo",
PartitionNames: []string{"p1"},
})
assert.Error(t, err)
}
})
t.Run("success", func(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
dbID: 100,
createdTimestamp: 1,
}, nil).Twice()
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(10), nil).Twice()
mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{
name: "p1",
partitionID: 100,
}, nil)
{
db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
DbName: "foo",
CollectionName: "coo",
PartitionName: "p1",
})
assert.NoError(t, err)
assert.Equal(t, int64(100), db)
assert.NotNil(t, col2par[10])
assert.Equal(t, int64(100), col2par[10][0])
}
{
db, col2par, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{
DbName: "foo",
CollectionName: "coo",
PartitionNames: []string{"p1"},
})
assert.NoError(t, err)
assert.Equal(t, int64(100), db)
assert.NotNil(t, col2par[10])
assert.Equal(t, int64(100), col2par[10][0])
}
})
}

View File

@ -1111,6 +1111,10 @@ func (coord *RootCoordMock) RenameCollection(ctx context.Context, req *milvuspb.
return &commonpb.Status{}, nil
}
func (coord *RootCoordMock) DescribeDatabase(ctx context.Context, in *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error) {
return &rootcoordpb.DescribeDatabaseResponse{}, nil
}
type DescribeCollectionFunc func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error)
type ShowPartitionsFunc func(ctx context.Context, request *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error)

View File

@ -0,0 +1,344 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"fmt"
"strconv"
"sync"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/util/quota"
rlinternal "github.com/milvus-io/milvus/internal/util/ratelimitutil"
"github.com/milvus-io/milvus/pkg/config"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// SimpleLimiter implements the Limiter interface.
type SimpleLimiter struct {
quotaStatesMu sync.RWMutex
rateLimiter *rlinternal.RateLimiterTree
}
// NewSimpleLimiter returns a new SimpleLimiter.
func NewSimpleLimiter() *SimpleLimiter {
rootRateLimiter := newClusterLimiter()
m := &SimpleLimiter{rateLimiter: rlinternal.NewRateLimiterTree(rootRateLimiter)}
return m
}
// Check checks whether the request would be rate limited or denied.
func (m *SimpleLimiter) Check(dbID int64, collectionIDToPartIDs map[int64][]int64, rt internalpb.RateType, n int) error {
if !Params.QuotaConfig.QuotaAndLimitsEnabled.GetAsBool() {
return nil
}
m.quotaStatesMu.RLock()
defer m.quotaStatesMu.RUnlock()
// 1. check global(cluster) level rate limits
clusterRateLimiters := m.rateLimiter.GetRootLimiters()
ret := clusterRateLimiters.Check(rt, n)
if ret != nil {
clusterRateLimiters.Cancel(rt, n)
return ret
}
// collect the limiters that have already been charged so they can be cancelled if a later check fails.
doneLimiters := make([]*rlinternal.RateLimiterNode, 0)
doneLimiters = append(doneLimiters, clusterRateLimiters)
cancelAllLimiters := func() {
for _, limiter := range doneLimiters {
limiter.Cancel(rt, n)
}
}
// 2. check database level rate limits
if ret == nil {
dbRateLimiters := m.rateLimiter.GetOrCreateDatabaseLimiters(dbID, newDatabaseLimiter)
ret = dbRateLimiters.Check(rt, n)
if ret != nil {
cancelAllLimiters()
return ret
}
doneLimiters = append(doneLimiters, dbRateLimiters)
}
// 3. check collection level rate limits
if ret == nil && len(collectionIDToPartIDs) > 0 && !isNotCollectionLevelLimitRequest(rt) {
for collectionID := range collectionIDToPartIDs {
// only dml and dql have collection level rate limits
collectionRateLimiters := m.rateLimiter.GetOrCreateCollectionLimiters(dbID, collectionID,
newDatabaseLimiter, newCollectionLimiters)
ret = collectionRateLimiters.Check(rt, n)
if ret != nil {
cancelAllLimiters()
return ret
}
doneLimiters = append(doneLimiters, collectionRateLimiters)
}
}
// 4. check partition level rate limits
if ret == nil && len(collectionIDToPartIDs) > 0 {
for collectionID, partitionIDs := range collectionIDToPartIDs {
for _, partID := range partitionIDs {
partitionRateLimiters := m.rateLimiter.GetOrCreatePartitionLimiters(dbID, collectionID, partID,
newDatabaseLimiter, newCollectionLimiters, newPartitionLimiters)
ret = partitionRateLimiters.Check(rt, n)
if ret != nil {
cancelAllLimiters()
return ret
}
doneLimiters = append(doneLimiters, partitionRateLimiters)
}
}
}
return ret
}
func isNotCollectionLevelLimitRequest(rt internalpb.RateType) bool {
// Most DDL requests are limited at the cluster level; only DDLFlush is applied at the collection level.
switch rt {
case internalpb.RateType_DDLCollection,
internalpb.RateType_DDLPartition,
internalpb.RateType_DDLIndex,
internalpb.RateType_DDLCompaction:
return true
default:
return false
}
}
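The staged check above can be pictured with ordinary token buckets. The sketch below is a minimal stand-in that assumes golang.org/x/time/rate rather than the internal ratelimitutil package; it only illustrates the charge-then-refund ordering across levels, not the punishment or quota-state behavior.
package main
import (
	"errors"
	"fmt"
	"time"
	"golang.org/x/time/rate"
)
// hierarchicalCheck mimics the cascading Check above: charge each level in
// order (cluster -> db -> collection -> partition) and refund the levels that
// already accepted the tokens when a deeper level rejects them.
func hierarchicalCheck(levels []*rate.Limiter, n int) error {
	charged := make([]*rate.Reservation, 0, len(levels))
	for _, l := range levels {
		r := l.ReserveN(time.Now(), n)
		if !r.OK() || r.Delay() > 0 { // no budget right now: reject instead of waiting
			r.Cancel()
			for _, c := range charged {
				c.Cancel() // refund the limiters that were already charged
			}
			return errors.New("rate limit exceeded")
		}
		charged = append(charged, r)
	}
	return nil
}
func main() {
	cluster := rate.NewLimiter(1000, 1000)
	db := rate.NewLimiter(100, 100)
	collection := rate.NewLimiter(10, 10) // the tightest level decides
	levels := []*rate.Limiter{cluster, db, collection}
	fmt.Println(hierarchicalCheck(levels, 5))  // <nil>
	fmt.Println(hierarchicalCheck(levels, 50)) // rate limit exceeded
}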
// GetQuotaStates returns quota states.
func (m *SimpleLimiter) GetQuotaStates() ([]milvuspb.QuotaState, []string) {
m.quotaStatesMu.RLock()
defer m.quotaStatesMu.RUnlock()
serviceStates := make(map[milvuspb.QuotaState]typeutil.Set[commonpb.ErrorCode])
rlinternal.TraverseRateLimiterTree(m.rateLimiter.GetRootLimiters(), nil,
func(node *rlinternal.RateLimiterNode, state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool {
if serviceStates[state] == nil {
serviceStates[state] = typeutil.NewSet[commonpb.ErrorCode]()
}
serviceStates[state].Insert(errCode)
return true
})
states := make([]milvuspb.QuotaState, 0)
reasons := make([]string, 0)
for state, errCodes := range serviceStates {
for errCode := range errCodes {
states = append(states, state)
reasons = append(reasons, ratelimitutil.GetQuotaErrorString(errCode))
}
}
return states, reasons
}
// SetRates sets quota states for SimpleLimiter.
func (m *SimpleLimiter) SetRates(rootLimiter *proxypb.LimiterNode) error {
m.quotaStatesMu.Lock()
defer m.quotaStatesMu.Unlock()
if err := m.updateRateLimiter(rootLimiter); err != nil {
return err
}
m.rateLimiter.ClearInvalidLimiterNode(rootLimiter)
return nil
}
func initLimiter(rln *rlinternal.RateLimiterNode, rateLimiterConfigs map[internalpb.RateType]*paramtable.ParamItem) {
log := log.Ctx(context.TODO()).WithRateGroup("proxy.rateLimiter", 1.0, 60.0)
for rt, p := range rateLimiterConfigs {
limit := ratelimitutil.Limit(p.GetAsFloat())
burst := p.GetAsFloat() // use the rate as the burst; SimpleLimiter has a punishment mechanism, so burst is insignificant.
rln.GetLimiters().GetOrInsert(rt, ratelimitutil.NewLimiter(limit, burst))
onEvent := func(rateType internalpb.RateType, formatFunc func(originValue string) string) func(*config.Event) {
return func(event *config.Event) {
f, err := strconv.ParseFloat(formatFunc(event.Value), 64)
if err != nil {
log.Info("Error format for rateLimit",
zap.String("rateType", rateType.String()),
zap.String("key", event.Key),
zap.String("value", event.Value),
zap.Error(err))
return
}
l, ok := rln.GetLimiters().Get(rateType)
if !ok {
log.Info("rateLimiter not found for rateType", zap.String("rateType", rateType.String()))
return
}
l.SetLimit(ratelimitutil.Limit(f))
}
}(rt, p.Formatter)
paramtable.Get().Watch(p.Key, config.NewHandler(fmt.Sprintf("rateLimiter-%d", rt), onEvent))
log.RatedDebug(30, "RateLimiter register for rateType",
zap.String("rateType", internalpb.RateType_name[(int32(rt))]),
zap.String("rateLimit", ratelimitutil.Limit(p.GetAsFloat()).String()),
zap.String("burst", fmt.Sprintf("%v", burst)))
}
}
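The config watch registered above reduces to parsing the new value and calling SetLimit on the bucket registered for that rate type; invalid values are logged and ignored. A minimal sketch of that update path, again assuming golang.org/x/time/rate instead of the internal limiter:
package main
import (
	"fmt"
	"strconv"
	"golang.org/x/time/rate"
)
// applyRateEvent mirrors the onEvent closure in initLimiter above: parse the
// updated config value and, when it is valid, swap the limit of the existing
// bucket in place; invalid values leave the limit untouched.
func applyRateEvent(l *rate.Limiter, newValue string) {
	f, err := strconv.ParseFloat(newValue, 64)
	if err != nil {
		fmt.Println("ignoring invalid rate value:", newValue)
		return
	}
	l.SetLimit(rate.Limit(f))
}
func main() {
	l := rate.NewLimiter(rate.Limit(5), 5)
	applyRateEvent(l, "8")       // limit becomes 8
	applyRateEvent(l, "invalid") // ignored, limit stays 8
	fmt.Println("current limit:", l.Limit())
}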
// newClusterLimiter initializes the cluster-level limiter for all rate types.
// The cluster rate limiter does not accumulate metrics dynamically; it only uses
// configured values as its limits.
func newClusterLimiter() *rlinternal.RateLimiterNode {
clusterRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Cluster)
clusterLimiterConfigs := getDefaultLimiterConfig(internalpb.RateScope_Cluster)
initLimiter(clusterRateLimiters, clusterLimiterConfigs)
return clusterRateLimiters
}
func newDatabaseLimiter() *rlinternal.RateLimiterNode {
dbRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Database)
databaseLimiterConfigs := getDefaultLimiterConfig(internalpb.RateScope_Database)
initLimiter(dbRateLimiters, databaseLimiterConfigs)
return dbRateLimiters
}
func newCollectionLimiters() *rlinternal.RateLimiterNode {
collectionRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Collection)
collectionLimiterConfigs := getDefaultLimiterConfig(internalpb.RateScope_Collection)
initLimiter(collectionRateLimiters, collectionLimiterConfigs)
return collectionRateLimiters
}
func newPartitionLimiters() *rlinternal.RateLimiterNode {
partRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Partition)
collectionLimiterConfigs := getDefaultLimiterConfig(internalpb.RateScope_Partition)
initLimiter(partRateLimiters, collectionLimiterConfigs)
return partRateLimiters
}
func (m *SimpleLimiter) updateLimiterNode(req *proxypb.Limiter, node *rlinternal.RateLimiterNode, sourceID string) error {
curLimiters := node.GetLimiters()
for _, rate := range req.GetRates() {
limit, ok := curLimiters.Get(rate.GetRt())
if !ok {
return fmt.Errorf("unregister rateLimiter for rateType %s", rate.GetRt().String())
}
limit.SetLimit(ratelimitutil.Limit(rate.GetR()))
setRateGaugeByRateType(rate.GetRt(), paramtable.GetNodeID(), sourceID, rate.GetR())
}
quotaStates := typeutil.NewConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]()
states := req.GetStates()
codes := req.GetCodes()
for i, state := range states {
quotaStates.Insert(state, codes[i])
}
node.SetQuotaStates(quotaStates)
return nil
}
func (m *SimpleLimiter) updateRateLimiter(reqRootLimiterNode *proxypb.LimiterNode) error {
reqClusterLimiter := reqRootLimiterNode.GetLimiter()
clusterLimiter := m.rateLimiter.GetRootLimiters()
err := m.updateLimiterNode(reqClusterLimiter, clusterLimiter, "cluster")
if err != nil {
log.Warn("update cluster rate limiters failed", zap.Error(err))
return err
}
getDBSourceID := func(dbID int64) string {
return fmt.Sprintf("db.%d", dbID)
}
getCollectionSourceID := func(collectionID int64) string {
return fmt.Sprintf("collection.%d", collectionID)
}
getPartitionSourceID := func(partitionID int64) string {
return fmt.Sprintf("partition.%d", partitionID)
}
for dbID, reqDBRateLimiters := range reqRootLimiterNode.GetChildren() {
// update database rate limiters
dbRateLimiters := m.rateLimiter.GetOrCreateDatabaseLimiters(dbID, newDatabaseLimiter)
err := m.updateLimiterNode(reqDBRateLimiters.GetLimiter(), dbRateLimiters, getDBSourceID(dbID))
if err != nil {
log.Warn("update database rate limiters failed", zap.Error(err))
return err
}
// update collection rate limiters
for collectionID, reqCollectionRateLimiter := range reqDBRateLimiters.GetChildren() {
collectionRateLimiter := m.rateLimiter.GetOrCreateCollectionLimiters(dbID, collectionID,
newDatabaseLimiter, newCollectionLimiters)
err := m.updateLimiterNode(reqCollectionRateLimiter.GetLimiter(), collectionRateLimiter,
getCollectionSourceID(collectionID))
if err != nil {
log.Warn("update collection rate limiters failed", zap.Error(err))
return err
}
// update partition rate limiters
for partitionID, reqPartitionRateLimiters := range reqCollectionRateLimiter.GetChildren() {
partitionRateLimiter := m.rateLimiter.GetOrCreatePartitionLimiters(dbID, collectionID, partitionID,
newDatabaseLimiter, newCollectionLimiters, newPartitionLimiters)
err := m.updateLimiterNode(reqPartitionRateLimiters.GetLimiter(), partitionRateLimiter,
getPartitionSourceID(partitionID))
if err != nil {
log.Warn("update partition rate limiters failed", zap.Error(err))
return err
}
}
}
}
return nil
}
// setRateGaugeByRateType sets ProxyLimiterRate metrics.
func setRateGaugeByRateType(rateType internalpb.RateType, nodeID int64, sourceID string, rate float64) {
if ratelimitutil.Limit(rate) == ratelimitutil.Inf {
return
}
nodeIDStr := strconv.FormatInt(nodeID, 10)
metrics.ProxyLimiterRate.WithLabelValues(nodeIDStr, sourceID, rateType.String()).Set(rate)
}
func getDefaultLimiterConfig(scope internalpb.RateScope) map[internalpb.RateType]*paramtable.ParamItem {
return quota.GetQuotaConfigMap(scope)
}
func IsDDLRequest(rt internalpb.RateType) bool {
switch rt {
case internalpb.RateType_DDLCollection,
internalpb.RateType_DDLPartition,
internalpb.RateType_DDLIndex,
internalpb.RateType_DDLFlush,
internalpb.RateType_DDLCompaction:
return true
default:
return false
}
}

View File

@ -0,0 +1,415 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"fmt"
"math"
"math/rand"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
rlinternal "github.com/milvus-io/milvus/internal/util/ratelimitutil"
"github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
)
func TestSimpleRateLimiter(t *testing.T) {
collectionID := int64(1)
collectionIDToPartIDs := map[int64][]int64{collectionID: {}}
t.Run("test simpleRateLimiter", func(t *testing.T) {
bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue()
paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
simpleLimiter := NewSimpleLimiter()
clusterRateLimiters := simpleLimiter.rateLimiter.GetRootLimiters()
simpleLimiter.rateLimiter.GetOrCreateCollectionLimiters(0, collectionID, newDatabaseLimiter,
func() *rlinternal.RateLimiterNode {
collectionRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Cluster)
for _, rt := range internalpb.RateType_value {
if IsDDLRequest(internalpb.RateType(rt)) {
clusterRateLimiters.GetLimiters().
Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(5), 1))
} else {
collectionRateLimiters.GetLimiters().
Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1))
}
}
return collectionRateLimiters
})
for _, rt := range internalpb.RateType_value {
if IsDDLRequest(internalpb.RateType(rt)) {
err := simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1)
assert.NoError(t, err)
err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 5)
assert.NoError(t, err)
err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 5)
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
} else {
err := simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1)
assert.NoError(t, err)
err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), math.MaxInt)
assert.NoError(t, err)
err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), math.MaxInt)
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
}
}
Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak)
})
t.Run("test global static limit", func(t *testing.T) {
bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue()
paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
simpleLimiter := NewSimpleLimiter()
clusterRateLimiters := simpleLimiter.rateLimiter.GetRootLimiters()
collectionIDToPartIDs := map[int64][]int64{
1: {},
2: {},
3: {},
}
for i := 1; i <= 3; i++ {
simpleLimiter.rateLimiter.GetOrCreateCollectionLimiters(0, int64(i), newDatabaseLimiter,
func() *rlinternal.RateLimiterNode {
collectionRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Cluster)
for _, rt := range internalpb.RateType_value {
if IsDDLRequest(internalpb.RateType(rt)) {
clusterRateLimiters.GetLimiters().
Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(5), 1))
} else {
clusterRateLimiters.GetLimiters().
Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1))
collectionRateLimiters.GetLimiters().
Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(2), 1))
}
}
return collectionRateLimiters
})
}
for _, rt := range internalpb.RateType_value {
if IsDDLRequest(internalpb.RateType(rt)) {
err := simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1)
assert.NoError(t, err)
err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 5)
assert.NoError(t, err)
err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 5)
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
} else {
err := simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1)
assert.NoError(t, err)
err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1)
assert.NoError(t, err)
err = simpleLimiter.Check(0, collectionIDToPartIDs, internalpb.RateType(rt), 1)
assert.ErrorIs(t, err, merr.ErrServiceRateLimit)
}
}
Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak)
})
t.Run("not enable quotaAndLimit", func(t *testing.T) {
simpleLimiter := NewSimpleLimiter()
bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue()
paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
for _, rt := range internalpb.RateType_value {
err := simpleLimiter.Check(0, nil, internalpb.RateType(rt), 1)
assert.NoError(t, err)
}
Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak)
})
t.Run("test limit", func(t *testing.T) {
run := func(insertRate float64) {
bakInsertRate := Params.QuotaConfig.DMLMaxInsertRate.GetValue()
paramtable.Get().Save(Params.QuotaConfig.DMLMaxInsertRate.Key, fmt.Sprintf("%f", insertRate))
simpleLimiter := NewSimpleLimiter()
bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue()
paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
err := simpleLimiter.Check(0, nil, internalpb.RateType_DMLInsert, 1*1024*1024)
assert.NoError(t, err)
Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak)
Params.Save(Params.QuotaConfig.DMLMaxInsertRate.Key, bakInsertRate)
}
run(math.MaxFloat64)
run(math.MaxFloat64 / 1.2)
run(math.MaxFloat64 / 2)
run(math.MaxFloat64 / 3)
run(math.MaxFloat64 / 10000)
})
t.Run("test set rates", func(t *testing.T) {
simpleLimiter := NewSimpleLimiter()
zeroRates := getZeroCollectionRates()
err := simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{
1: {
Limiter: &proxypb.Limiter{
Rates: zeroRates,
},
Children: make(map[int64]*proxypb.LimiterNode),
},
2: {
Limiter: &proxypb.Limiter{
Rates: zeroRates,
},
Children: make(map[int64]*proxypb.LimiterNode),
},
}))
assert.NoError(t, err)
})
t.Run("test quota states", func(t *testing.T) {
simpleLimiter := NewSimpleLimiter()
err := simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{
1: {
// collection limiter
Limiter: &proxypb.Limiter{
Rates: getZeroCollectionRates(),
States: []milvuspb.QuotaState{milvuspb.QuotaState_DenyToWrite, milvuspb.QuotaState_DenyToRead},
Codes: []commonpb.ErrorCode{commonpb.ErrorCode_DiskQuotaExhausted, commonpb.ErrorCode_ForceDeny},
},
Children: make(map[int64]*proxypb.LimiterNode),
},
}))
assert.NoError(t, err)
states, codes := simpleLimiter.GetQuotaStates()
assert.Len(t, states, 2)
assert.Len(t, codes, 2)
assert.Contains(t, codes, ratelimitutil.GetQuotaErrorString(commonpb.ErrorCode_DiskQuotaExhausted))
assert.Contains(t, codes, ratelimitutil.GetQuotaErrorString(commonpb.ErrorCode_ForceDeny))
})
}
func getZeroRates() []*internalpb.Rate {
zeroRates := make([]*internalpb.Rate, 0, len(internalpb.RateType_value))
for _, rt := range internalpb.RateType_value {
zeroRates = append(zeroRates, &internalpb.Rate{
Rt: internalpb.RateType(rt), R: 0,
})
}
return zeroRates
}
func getZeroCollectionRates() []*internalpb.Rate {
collectionRate := []internalpb.RateType{
internalpb.RateType_DMLInsert,
internalpb.RateType_DMLDelete,
internalpb.RateType_DMLBulkLoad,
internalpb.RateType_DQLSearch,
internalpb.RateType_DQLQuery,
internalpb.RateType_DDLFlush,
}
zeroRates := make([]*internalpb.Rate, 0, len(collectionRate))
for _, rt := range collectionRate {
zeroRates = append(zeroRates, &internalpb.Rate{
Rt: rt, R: 0,
})
}
return zeroRates
}
func newCollectionLimiterNode(collectionLimiterNodes map[int64]*proxypb.LimiterNode) *proxypb.LimiterNode {
return &proxypb.LimiterNode{
// cluster limiter
Limiter: &proxypb.Limiter{},
// db level
Children: map[int64]*proxypb.LimiterNode{
0: {
// db limiter
Limiter: &proxypb.Limiter{},
// collection level
Children: collectionLimiterNodes,
},
},
}
}
func TestRateLimiter(t *testing.T) {
t.Run("test limit", func(t *testing.T) {
simpleLimiter := NewSimpleLimiter()
rootLimiters := simpleLimiter.rateLimiter.GetRootLimiters()
for _, rt := range internalpb.RateType_value {
rootLimiters.GetLimiters().Insert(internalpb.RateType(rt), ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1))
}
for _, rt := range internalpb.RateType_value {
ok, _ := rootLimiters.Limit(internalpb.RateType(rt), 1)
assert.False(t, ok)
ok, _ = rootLimiters.Limit(internalpb.RateType(rt), math.MaxInt)
assert.False(t, ok)
ok, _ = rootLimiters.Limit(internalpb.RateType(rt), math.MaxInt)
assert.True(t, ok)
}
})
t.Run("test setRates", func(t *testing.T) {
simpleLimiter := NewSimpleLimiter()
collectionRateLimiters := simpleLimiter.rateLimiter.GetOrCreateCollectionLimiters(0, int64(1), newDatabaseLimiter,
func() *rlinternal.RateLimiterNode {
collectionRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Cluster)
for _, rt := range internalpb.RateType_value {
collectionRateLimiters.GetLimiters().Insert(internalpb.RateType(rt),
ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1))
}
return collectionRateLimiters
})
err := simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{
1: {
// collection limiter
Limiter: &proxypb.Limiter{
Rates: getZeroRates(),
},
Children: make(map[int64]*proxypb.LimiterNode),
},
}))
assert.NoError(t, err)
for _, rt := range internalpb.RateType_value {
for i := 0; i < 100; i++ {
ok, _ := collectionRateLimiters.Limit(internalpb.RateType(rt), 1)
assert.True(t, ok)
}
}
err = simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{
1: {
// collection limiter
Limiter: &proxypb.Limiter{
States: []milvuspb.QuotaState{milvuspb.QuotaState_DenyToRead, milvuspb.QuotaState_DenyToWrite},
Codes: []commonpb.ErrorCode{commonpb.ErrorCode_DiskQuotaExhausted, commonpb.ErrorCode_DiskQuotaExhausted},
},
Children: make(map[int64]*proxypb.LimiterNode),
},
}))
collectionRateLimiter := simpleLimiter.rateLimiter.GetCollectionLimiters(0, 1)
assert.NotNil(t, collectionRateLimiter)
assert.NoError(t, err)
assert.Equal(t, collectionRateLimiter.GetQuotaStates().Len(), 2)
err = simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{
1: {
// collection limiter
Limiter: &proxypb.Limiter{
States: []milvuspb.QuotaState{},
},
Children: make(map[int64]*proxypb.LimiterNode),
},
}))
assert.NoError(t, err)
assert.Equal(t, collectionRateLimiter.GetQuotaStates().Len(), 0)
})
t.Run("test get error code", func(t *testing.T) {
simpleLimiter := NewSimpleLimiter()
collectionRateLimiters := simpleLimiter.rateLimiter.GetOrCreateCollectionLimiters(0, int64(1), newDatabaseLimiter,
func() *rlinternal.RateLimiterNode {
collectionRateLimiters := rlinternal.NewRateLimiterNode(internalpb.RateScope_Cluster)
for _, rt := range internalpb.RateType_value {
collectionRateLimiters.GetLimiters().Insert(internalpb.RateType(rt),
ratelimitutil.NewLimiter(ratelimitutil.Limit(1000), 1))
}
return collectionRateLimiters
})
err := simpleLimiter.SetRates(newCollectionLimiterNode(map[int64]*proxypb.LimiterNode{
1: {
// collection limiter
Limiter: &proxypb.Limiter{
Rates: getZeroRates(),
States: []milvuspb.QuotaState{
milvuspb.QuotaState_DenyToWrite,
milvuspb.QuotaState_DenyToRead,
},
Codes: []commonpb.ErrorCode{
commonpb.ErrorCode_DiskQuotaExhausted,
commonpb.ErrorCode_ForceDeny,
},
},
Children: make(map[int64]*proxypb.LimiterNode),
},
}))
assert.NoError(t, err)
assert.Error(t, collectionRateLimiters.GetQuotaExceededError(internalpb.RateType_DQLQuery))
assert.Error(t, collectionRateLimiters.GetQuotaExceededError(internalpb.RateType_DMLInsert))
})
t.Run("tests refresh rate by config", func(t *testing.T) {
simpleLimiter := NewSimpleLimiter()
clusterRateLimiter := simpleLimiter.rateLimiter.GetRootLimiters()
etcdCli, _ := etcd.GetEtcdClient(
Params.EtcdCfg.UseEmbedEtcd.GetAsBool(),
Params.EtcdCfg.EtcdUseSSL.GetAsBool(),
Params.EtcdCfg.Endpoints.GetAsStrings(),
Params.EtcdCfg.EtcdTLSCert.GetValue(),
Params.EtcdCfg.EtcdTLSKey.GetValue(),
Params.EtcdCfg.EtcdTLSCACert.GetValue(),
Params.EtcdCfg.EtcdTLSMinVersion.GetValue())
Params.Save(Params.QuotaConfig.DDLLimitEnabled.Key, "true")
defer Params.Reset(Params.QuotaConfig.DDLLimitEnabled.Key)
Params.Save(Params.QuotaConfig.DMLLimitEnabled.Key, "true")
defer Params.Reset(Params.QuotaConfig.DMLLimitEnabled.Key)
ctx := context.Background()
// format to two decimal places to avoid float precision issues in the comparison below
r := rand.Float64()
newRate := fmt.Sprintf("%.2f", r)
etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/ddl/collectionRate", newRate)
defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/ddl/collectionRate")
etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/ddl/partitionRate", "invalid")
defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/ddl/partitionRate")
etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/dml/insertRate/max", "8")
defer etcdCli.KV.Delete(ctx, "by-dev/config/quotaAndLimits/dml/insertRate/max")
assert.Eventually(t, func() bool {
limit, _ := clusterRateLimiter.GetLimiters().Get(internalpb.RateType_DDLCollection)
return math.Abs(r-float64(limit.Limit())) < 0.01
}, 10*time.Second, 1*time.Second)
limit, _ := clusterRateLimiter.GetLimiters().Get(internalpb.RateType_DDLPartition)
assert.Equal(t, "+inf", limit.Limit().String())
limit, _ = clusterRateLimiter.GetLimiters().Get(internalpb.RateType_DMLInsert)
assert.True(t, math.Abs(8*1024*1024-float64(limit.Limit())) < 0.01)
})
}
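
For orientation, here is a minimal, hypothetical sketch (same package as the tests above) of how a caller might drive the new multi-level Check entry point; the db, collection, and partition IDs are purely illustrative.

```go
// Sketch only: mirrors the call pattern exercised by the tests above.
func exampleSimpleLimiterCheck() {
	limiter := NewSimpleLimiter()

	// dbID 0, collection 1 with partitions 10 and 11 (illustrative IDs).
	collectionIDToPartIDs := map[int64][]int64{1: {10, 11}}

	// Charge one unit against the DML insert quota; the limiter walks the
	// cluster -> database -> collection -> partition limiters.
	if err := limiter.Check(0, collectionIDToPartIDs, internalpb.RateType_DMLInsert, 1); err != nil {
		// err is merr.ErrServiceRateLimit or merr.ErrServiceQuotaExceeded,
		// depending on the configured rates and quota states.
		_ = err
	}
}
```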

View File

@ -59,7 +59,7 @@ func ConstructLabel(subs ...string) string {
func init() {
var err error
Rate, err = ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity)
Rate, err = ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity, false)
if err != nil {
log.Fatal("failed to initialize querynode rate collector", zap.Error(err))
}

View File

@ -234,6 +234,11 @@ func (b *ServerBroker) BroadcastAlteredCollection(ctx context.Context, req *milv
return err
}
db, err := b.s.meta.GetDatabaseByName(ctx, req.GetDbName(), typeutil.MaxTimestamp)
if err != nil {
return err
}
partitionIDs := make([]int64, len(colMeta.Partitions))
for _, p := range colMeta.Partitions {
partitionIDs = append(partitionIDs, p.PartitionID)
@ -249,6 +254,7 @@ func (b *ServerBroker) BroadcastAlteredCollection(ctx context.Context, req *milv
PartitionIDs: partitionIDs,
StartPositions: colMeta.StartPositions,
Properties: req.GetProperties(),
DbID: db.ID,
}
resp, err := b.s.dataCoord.BroadcastAlteredCollection(ctx, dcReq)

View File

@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/datapb"
pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/pkg/util/merr"
@ -239,6 +240,7 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) {
mock.Anything,
mock.Anything,
).Return(collMeta, nil)
mockGetDatabase(meta)
c.meta = meta
b := newServerBroker(c)
ctx := context.Background()
@ -256,6 +258,7 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) {
mock.Anything,
mock.Anything,
).Return(collMeta, nil)
mockGetDatabase(meta)
c.meta = meta
b := newServerBroker(c)
ctx := context.Background()
@ -273,6 +276,7 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) {
mock.Anything,
mock.Anything,
).Return(collMeta, nil)
mockGetDatabase(meta)
c.meta = meta
b := newServerBroker(c)
ctx := context.Background()
@ -327,3 +331,11 @@ func TestServerBroker_GcConfirm(t *testing.T) {
assert.True(t, broker.GcConfirm(context.Background(), 100, 10000))
})
}
func mockGetDatabase(meta *mockrootcoord.IMetaTable) {
db := model.NewDatabase(1, "default", pb.DatabaseState_DatabaseCreated)
meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything).
Return(db, nil).Maybe()
meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything).
Return(db, nil).Maybe()
}

View File

@ -44,6 +44,7 @@ func (t *describeCollectionTask) Execute(ctx context.Context) (err error) {
if err != nil {
return err
}
aliases := t.core.meta.ListAliasesByID(coll.CollectionID)
db, err := t.core.meta.GetDatabaseByID(ctx, coll.DBID, t.GetTs())
if err != nil {

View File

@ -0,0 +1,56 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// describeDBTask is the task that handles a DescribeDatabase request
type describeDBTask struct {
baseTask
Req *rootcoordpb.DescribeDatabaseRequest
Rsp *rootcoordpb.DescribeDatabaseResponse
allowUnavailable bool
}
func (t *describeDBTask) Prepare(ctx context.Context) error {
return nil
}
// Execute runs the describe database task
func (t *describeDBTask) Execute(ctx context.Context) (err error) {
db, err := t.core.meta.GetDatabaseByName(ctx, t.Req.GetDbName(), typeutil.MaxTimestamp)
if err != nil {
t.Rsp = &rootcoordpb.DescribeDatabaseResponse{
Status: merr.Status(err),
}
return err
}
t.Rsp = &rootcoordpb.DescribeDatabaseResponse{
Status: merr.Success(),
DbID: db.ID,
DbName: db.Name,
CreatedTimestamp: db.CreatedTime,
}
return nil
}
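
For context, a hypothetical caller-side sketch of the new RPC; `rc` stands for any RootCoord client that exposes DescribeDatabase (names are illustrative, error handling abbreviated).

```go
// Sketch only: ask RootCoord for database metadata and read the resolved db ID.
resp, err := rc.DescribeDatabase(ctx, &rootcoordpb.DescribeDatabaseRequest{DbName: "db1"})
if err = merr.CheckRPCCall(resp.GetStatus(), err); err != nil {
	return err
}
dbID := resp.GetDbID()                   // filled from db.ID in Execute above
createdTs := resp.GetCreatedTimestamp()  // filled from db.CreatedTime above
_, _ = dbID, createdTs
```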

View File

@ -0,0 +1,88 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/pkg/util"
)
func Test_describeDatabaseTask_Execute(t *testing.T) {
t.Run("failed to get database by name", func(t *testing.T) {
core := newTestCore(withInvalidMeta())
task := &describeDBTask{
baseTask: newBaseTask(context.Background(), core),
Req: &rootcoordpb.DescribeDatabaseRequest{
DbName: "testDB",
},
}
err := task.Execute(context.Background())
assert.Error(t, err)
assert.NotNil(t, task.Rsp)
assert.NotNil(t, task.Rsp.Status)
})
t.Run("describe with empty database name", func(t *testing.T) {
meta := mockrootcoord.NewIMetaTable(t)
meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything).
Return(model.NewDefaultDatabase(), nil)
core := newTestCore(withMeta(meta))
task := &describeDBTask{
baseTask: newBaseTask(context.Background(), core),
Req: &rootcoordpb.DescribeDatabaseRequest{},
}
err := task.Execute(context.Background())
assert.NoError(t, err)
assert.NotNil(t, task.Rsp)
assert.Equal(t, task.Rsp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
assert.Equal(t, util.DefaultDBName, task.Rsp.GetDbName())
assert.Equal(t, util.DefaultDBID, task.Rsp.GetDbID())
})
t.Run("describe with specified database name", func(t *testing.T) {
meta := mockrootcoord.NewIMetaTable(t)
meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything).
Return(&model.Database{
Name: "db1",
ID: 100,
CreatedTime: 1,
}, nil)
core := newTestCore(withMeta(meta))
task := &describeDBTask{
baseTask: newBaseTask(context.Background(), core),
Req: &rootcoordpb.DescribeDatabaseRequest{DbName: "db1"},
}
err := task.Execute(context.Background())
assert.NoError(t, err)
assert.NotNil(t, task.Rsp)
assert.Equal(t, task.Rsp.GetStatus().GetCode(), int32(commonpb.ErrorCode_Success))
assert.Equal(t, "db1", task.Rsp.GetDbName())
assert.Equal(t, int64(100), task.Rsp.GetDbID())
assert.Equal(t, uint64(1), task.Rsp.GetCreatedTimestamp())
})
}

View File

@ -55,6 +55,7 @@ type IMetaTable interface {
RemoveCollection(ctx context.Context, collectionID UniqueID, ts Timestamp) error
GetCollectionByName(ctx context.Context, dbName string, collectionName string, ts Timestamp) (*model.Collection, error)
GetCollectionByID(ctx context.Context, dbName string, collectionID UniqueID, ts Timestamp, allowUnavailable bool) (*model.Collection, error)
GetCollectionByIDWithMaxTs(ctx context.Context, collectionID UniqueID) (*model.Collection, error)
ListCollections(ctx context.Context, dbName string, ts Timestamp, onlyAvail bool) ([]*model.Collection, error)
ListAllAvailCollections(ctx context.Context) map[int64][]int64
ListCollectionPhysicalChannels() map[typeutil.UniqueID][]string
@ -362,7 +363,7 @@ func (mt *MetaTable) getDatabaseByNameInternal(_ context.Context, dbName string,
db, ok := mt.dbName2Meta[dbName]
if !ok {
return nil, fmt.Errorf("database:%s not found", dbName)
return nil, merr.WrapErrDatabaseNotFound(dbName)
}
return db, nil
@ -519,12 +520,12 @@ func filterUnavailable(coll *model.Collection) *model.Collection {
}
// getLatestCollectionByIDInternal should be called with ts = typeutil.MaxTimestamp
func (mt *MetaTable) getLatestCollectionByIDInternal(ctx context.Context, collectionID UniqueID, allowAvailable bool) (*model.Collection, error) {
func (mt *MetaTable) getLatestCollectionByIDInternal(ctx context.Context, collectionID UniqueID, allowUnavailable bool) (*model.Collection, error) {
coll, ok := mt.collID2Meta[collectionID]
if !ok || coll == nil {
return nil, merr.WrapErrCollectionNotFound(collectionID)
}
if allowAvailable {
if allowUnavailable {
return coll.Clone(), nil
}
if !coll.Available() {
@ -623,6 +624,11 @@ func (mt *MetaTable) GetCollectionByID(ctx context.Context, dbName string, colle
return mt.getCollectionByIDInternal(ctx, dbName, collectionID, ts, allowUnavailable)
}
// GetCollectionByIDWithMaxTs gets the collection by ID; dbName can be ignored because ts is the max timestamp
func (mt *MetaTable) GetCollectionByIDWithMaxTs(ctx context.Context, collectionID UniqueID) (*model.Collection, error) {
return mt.GetCollectionByID(ctx, "", collectionID, typeutil.MaxTimestamp, false)
}
func (mt *MetaTable) ListAllAvailCollections(ctx context.Context) map[int64][]int64 {
mt.ddLock.RLock()
defer mt.ddLock.RUnlock()

View File

@ -95,6 +95,11 @@ type mockMetaTable struct {
DropGrantFunc func(tenant string, role *milvuspb.RoleEntity) error
ListPolicyFunc func(tenant string) ([]string, error)
ListUserRoleFunc func(tenant string) ([]string, error)
DescribeDatabaseFunc func(ctx context.Context, dbName string) (*model.Database, error)
}
func (m mockMetaTable) GetDatabaseByName(ctx context.Context, dbName string, ts Timestamp) (*model.Database, error) {
return m.DescribeDatabaseFunc(ctx, dbName)
}
func (m mockMetaTable) ListDatabases(ctx context.Context, ts typeutil.Timestamp) ([]*model.Database, error) {
@ -516,6 +521,9 @@ func withInvalidMeta() Opt {
meta.ListAliasesFunc = func(ctx context.Context, dbName, collectionName string, ts Timestamp) ([]string, error) {
return nil, errors.New("error mock ListAliases")
}
meta.DescribeDatabaseFunc = func(ctx context.Context, dbName string) (*model.Database, error) {
return nil, errors.New("error mock DescribeDatabase")
}
return withMeta(meta)
}

View File

@ -843,6 +843,61 @@ func (_c *IMetaTable_GetCollectionByID_Call) RunAndReturn(run func(context.Conte
return _c
}
// GetCollectionByIDWithMaxTs provides a mock function with given fields: ctx, collectionID
func (_m *IMetaTable) GetCollectionByIDWithMaxTs(ctx context.Context, collectionID int64) (*model.Collection, error) {
ret := _m.Called(ctx, collectionID)
var r0 *model.Collection
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, int64) (*model.Collection, error)); ok {
return rf(ctx, collectionID)
}
if rf, ok := ret.Get(0).(func(context.Context, int64) *model.Collection); ok {
r0 = rf(ctx, collectionID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.Collection)
}
}
if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok {
r1 = rf(ctx, collectionID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// IMetaTable_GetCollectionByIDWithMaxTs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionByIDWithMaxTs'
type IMetaTable_GetCollectionByIDWithMaxTs_Call struct {
*mock.Call
}
// GetCollectionByIDWithMaxTs is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
func (_e *IMetaTable_Expecter) GetCollectionByIDWithMaxTs(ctx interface{}, collectionID interface{}) *IMetaTable_GetCollectionByIDWithMaxTs_Call {
return &IMetaTable_GetCollectionByIDWithMaxTs_Call{Call: _e.mock.On("GetCollectionByIDWithMaxTs", ctx, collectionID)}
}
func (_c *IMetaTable_GetCollectionByIDWithMaxTs_Call) Run(run func(ctx context.Context, collectionID int64)) *IMetaTable_GetCollectionByIDWithMaxTs_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
func (_c *IMetaTable_GetCollectionByIDWithMaxTs_Call) Return(_a0 *model.Collection, _a1 error) *IMetaTable_GetCollectionByIDWithMaxTs_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *IMetaTable_GetCollectionByIDWithMaxTs_Call) RunAndReturn(run func(context.Context, int64) (*model.Collection, error)) *IMetaTable_GetCollectionByIDWithMaxTs_Call {
_c.Call.Return(run)
return _c
}
// GetCollectionByName provides a mock function with given fields: ctx, dbName, collectionName, ts
func (_m *IMetaTable) GetCollectionByName(ctx context.Context, dbName string, collectionName string, ts uint64) (*model.Collection, error) {
ret := _m.Called(ctx, dbName, collectionName, ts)

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -2669,6 +2669,40 @@ func (c *Core) RenameCollection(ctx context.Context, req *milvuspb.RenameCollect
return merr.Success(), nil
}
func (c *Core) DescribeDatabase(ctx context.Context, req *rootcoordpb.DescribeDatabaseRequest) (*rootcoordpb.DescribeDatabaseResponse, error) {
if err := merr.CheckHealthy(c.GetStateCode()); err != nil {
return &rootcoordpb.DescribeDatabaseResponse{Status: merr.Status(err)}, nil
}
log := log.Ctx(ctx).With(zap.String("dbName", req.GetDbName()))
log.Info("received request to describe database ")
metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeDatabase", metrics.TotalLabel).Inc()
tr := timerecord.NewTimeRecorder("DescribeDatabase")
t := &describeDBTask{
baseTask: newBaseTask(ctx, c),
Req: req,
}
if err := c.scheduler.AddTask(t); err != nil {
log.Warn("failed to enqueue request to describe database", zap.Error(err))
metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeDatabase", metrics.FailLabel).Inc()
return &rootcoordpb.DescribeDatabaseResponse{Status: merr.Status(err)}, nil
}
if err := t.WaitToFinish(); err != nil {
log.Warn("failed to describe database", zap.Uint64("ts", t.GetTs()), zap.Error(err))
metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeDatabase", metrics.FailLabel).Inc()
return &rootcoordpb.DescribeDatabaseResponse{Status: merr.Status(err)}, nil
}
metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeDatabase", metrics.SuccessLabel).Inc()
metrics.RootCoordDDLReqLatency.WithLabelValues("DescribeDatabase").Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Info("done to describe database", zap.Uint64("ts", t.GetTs()))
return t.Rsp, nil
}
func (c *Core) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) {
if err := merr.CheckHealthy(c.GetStateCode()); err != nil {
return &milvuspb.CheckHealthResponse{

View File

@ -1443,6 +1443,43 @@ func TestRootCoord_CheckHealth(t *testing.T) {
})
}
func TestRootCoord_DescribeDatabase(t *testing.T) {
t.Run("not healthy", func(t *testing.T) {
ctx := context.Background()
c := newTestCore(withAbnormalCode())
resp, err := c.DescribeDatabase(ctx, &rootcoordpb.DescribeDatabaseRequest{})
assert.NoError(t, err)
assert.Error(t, merr.CheckRPCCall(resp.GetStatus(), nil))
})
t.Run("add task failed", func(t *testing.T) {
ctx := context.Background()
c := newTestCore(withHealthyCode(),
withInvalidScheduler())
resp, err := c.DescribeDatabase(ctx, &rootcoordpb.DescribeDatabaseRequest{})
assert.NoError(t, err)
assert.Error(t, merr.CheckRPCCall(resp.GetStatus(), nil))
})
t.Run("execute task failed", func(t *testing.T) {
ctx := context.Background()
c := newTestCore(withHealthyCode(),
withTaskFailScheduler())
resp, err := c.DescribeDatabase(ctx, &rootcoordpb.DescribeDatabaseRequest{})
assert.NoError(t, err)
assert.Error(t, merr.CheckRPCCall(resp.GetStatus(), nil))
})
t.Run("run ok", func(t *testing.T) {
ctx := context.Background()
c := newTestCore(withHealthyCode(),
withValidScheduler())
resp, err := c.DescribeDatabase(ctx, &rootcoordpb.DescribeDatabaseRequest{})
assert.NoError(t, err)
assert.NoError(t, merr.CheckRPCCall(resp.GetStatus(), nil))
})
}
func TestRootCoord_RBACError(t *testing.T) {
ctx := context.Background()
c := newTestCore(withHealthyCode(), withInvalidMeta())

View File

@ -138,13 +138,16 @@ func getCollectionRateLimitConfigDefaultValue(configKey string) float64 {
return Params.QuotaConfig.DQLMinSearchRatePerCollection.GetAsFloat()
case common.CollectionDiskQuotaKey:
return Params.QuotaConfig.DiskQuotaPerCollection.GetAsFloat()
default:
return float64(0)
}
}
func getCollectionRateLimitConfig(properties map[string]string, configKey string) float64 {
return getRateLimitConfig(properties, configKey, getCollectionRateLimitConfigDefaultValue(configKey))
}
func getRateLimitConfig(properties map[string]string, configKey string, configValue float64) float64 {
megaBytes2Bytes := func(v float64) float64 {
return v * 1024.0 * 1024.0
}
@ -189,15 +192,15 @@ func getCollectionRateLimitConfig(properties map[string]string, configKey string
log.Warn("invalid configuration for collection dml rate",
zap.String("config item", configKey),
zap.String("config value", v))
return getCollectionRateLimitConfigDefaultValue(configKey)
return configValue
}
rateInBytes := toBytesIfNecessary(rate)
if rateInBytes < 0 {
return getCollectionRateLimitConfigDefaultValue(configKey)
return configValue
}
return rateInBytes
}
return getCollectionRateLimitConfigDefaultValue(configKey)
return configValue
}

View File

@ -292,3 +292,27 @@ func Test_getCollectionRateLimitConfig(t *testing.T) {
})
}
}
func TestGetRateLimitConfigErr(t *testing.T) {
key := common.CollectionQueryRateMaxKey
t.Run("negative value", func(t *testing.T) {
v := getRateLimitConfig(map[string]string{
key: "-1",
}, key, 1)
assert.EqualValues(t, 1, v)
})
t.Run("valid value", func(t *testing.T) {
v := getRateLimitConfig(map[string]string{
key: "1",
}, key, 100)
assert.EqualValues(t, 1, v)
})
t.Run("not exist value", func(t *testing.T) {
v := getRateLimitConfig(map[string]string{
key: "1",
}, "b", 100)
assert.EqualValues(t, 100, v)
})
}

View File

@ -37,7 +37,7 @@ import (
// If Limit function return true, the request will be rejected.
// Otherwise, the request will pass. Limit also returns limit of limiter.
type Limiter interface {
Check(collectionIDs []int64, rt internalpb.RateType, n int) error
Check(dbID int64, collectionIDToPartIDs map[int64][]int64, rt internalpb.RateType, n int) error
}
// Component is the interface all services implement
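
A minimal, hypothetical stub showing the shape of the widened interface; a real implementation (such as the proxy's SimpleLimiter) walks the cluster, database, collection, and partition limiters.

```go
// noopLimiter is an illustrative pass-through implementation of the new Check signature.
type noopLimiter struct{}

func (noopLimiter) Check(dbID int64, collectionIDToPartIDs map[int64][]int64, rt internalpb.RateType, n int) error {
	// A real limiter would consult the per-db, per-collection and per-partition limiters here
	// and return merr.ErrServiceRateLimit / merr.ErrServiceQuotaExceeded when a level rejects.
	return nil
}
```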

View File

@ -37,6 +37,10 @@ type GrpcRootCoordClient struct {
Err error
}
func (m *GrpcRootCoordClient) DescribeDatabase(ctx context.Context, in *rootcoordpb.DescribeDatabaseRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeDatabaseResponse, error) {
return &rootcoordpb.DescribeDatabaseResponse{}, m.Err
}
func (m *GrpcRootCoordClient) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}

View File

@ -0,0 +1,106 @@
/*
* Licensed to the LF AI & Data foundation under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package quota
import (
"math"
"sync"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
var (
initOnce sync.Once
limitConfigMap map[internalpb.RateScope]map[internalpb.RateType]*paramtable.ParamItem
)
func initLimitConfigMaps() {
initOnce.Do(func() {
quotaConfig := &paramtable.Get().QuotaConfig
limitConfigMap = map[internalpb.RateScope]map[internalpb.RateType]*paramtable.ParamItem{
internalpb.RateScope_Cluster: {
internalpb.RateType_DDLCollection: &quotaConfig.DDLCollectionRate,
internalpb.RateType_DDLPartition: &quotaConfig.DDLPartitionRate,
internalpb.RateType_DDLIndex: &quotaConfig.MaxIndexRate,
internalpb.RateType_DDLFlush: &quotaConfig.MaxFlushRate,
internalpb.RateType_DDLCompaction: &quotaConfig.MaxCompactionRate,
internalpb.RateType_DMLInsert: &quotaConfig.DMLMaxInsertRate,
internalpb.RateType_DMLUpsert: &quotaConfig.DMLMaxUpsertRate,
internalpb.RateType_DMLDelete: &quotaConfig.DMLMaxDeleteRate,
internalpb.RateType_DMLBulkLoad: &quotaConfig.DMLMaxBulkLoadRate,
internalpb.RateType_DQLSearch: &quotaConfig.DQLMaxSearchRate,
internalpb.RateType_DQLQuery: &quotaConfig.DQLMaxQueryRate,
},
internalpb.RateScope_Database: {
internalpb.RateType_DDLCollection: &quotaConfig.DDLCollectionRatePerDB,
internalpb.RateType_DDLPartition: &quotaConfig.DDLPartitionRatePerDB,
internalpb.RateType_DDLIndex: &quotaConfig.MaxIndexRatePerDB,
internalpb.RateType_DDLFlush: &quotaConfig.MaxFlushRatePerDB,
internalpb.RateType_DDLCompaction: &quotaConfig.MaxCompactionRatePerDB,
internalpb.RateType_DMLInsert: &quotaConfig.DMLMaxInsertRatePerDB,
internalpb.RateType_DMLUpsert: &quotaConfig.DMLMaxUpsertRatePerDB,
internalpb.RateType_DMLDelete: &quotaConfig.DMLMaxDeleteRatePerDB,
internalpb.RateType_DMLBulkLoad: &quotaConfig.DMLMaxBulkLoadRatePerDB,
internalpb.RateType_DQLSearch: &quotaConfig.DQLMaxSearchRatePerDB,
internalpb.RateType_DQLQuery: &quotaConfig.DQLMaxQueryRatePerDB,
},
internalpb.RateScope_Collection: {
internalpb.RateType_DMLInsert: &quotaConfig.DMLMaxInsertRatePerCollection,
internalpb.RateType_DMLUpsert: &quotaConfig.DMLMaxUpsertRatePerCollection,
internalpb.RateType_DMLDelete: &quotaConfig.DMLMaxDeleteRatePerCollection,
internalpb.RateType_DMLBulkLoad: &quotaConfig.DMLMaxBulkLoadRatePerCollection,
internalpb.RateType_DQLSearch: &quotaConfig.DQLMaxSearchRatePerCollection,
internalpb.RateType_DQLQuery: &quotaConfig.DQLMaxQueryRatePerCollection,
internalpb.RateType_DDLFlush: &quotaConfig.MaxFlushRatePerCollection,
},
internalpb.RateScope_Partition: {
internalpb.RateType_DMLInsert: &quotaConfig.DMLMaxInsertRatePerPartition,
internalpb.RateType_DMLUpsert: &quotaConfig.DMLMaxUpsertRatePerPartition,
internalpb.RateType_DMLDelete: &quotaConfig.DMLMaxDeleteRatePerPartition,
internalpb.RateType_DMLBulkLoad: &quotaConfig.DMLMaxBulkLoadRatePerPartition,
internalpb.RateType_DQLSearch: &quotaConfig.DQLMaxSearchRatePerPartition,
internalpb.RateType_DQLQuery: &quotaConfig.DQLMaxQueryRatePerPartition,
},
}
})
}
func GetQuotaConfigMap(scope internalpb.RateScope) map[internalpb.RateType]*paramtable.ParamItem {
initLimitConfigMaps()
configMap, ok := limitConfigMap[scope]
if !ok {
log.Warn("Unknown rate scope", zap.Any("scope", scope))
return make(map[internalpb.RateType]*paramtable.ParamItem)
}
return configMap
}
func GetQuotaValue(scope internalpb.RateScope, rateType internalpb.RateType, params *paramtable.ComponentParam) float64 {
configMap := GetQuotaConfigMap(scope)
config, ok := configMap[rateType]
if !ok {
log.Warn("Unknown rate type", zap.Any("rateType", rateType))
return math.MaxFloat64
}
return config.GetAsFloat()
}

View File

@ -0,0 +1,91 @@
/*
* Licensed to the LF AI & Data foundation under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package quota
import (
"math"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
func TestGetQuotaConfigMap(t *testing.T) {
paramtable.Init()
{
m := GetQuotaConfigMap(internalpb.RateScope_Cluster)
assert.Equal(t, 11, len(m))
}
{
m := GetQuotaConfigMap(internalpb.RateScope_Database)
assert.Equal(t, 11, len(m))
}
{
m := GetQuotaConfigMap(internalpb.RateScope_Collection)
assert.Equal(t, 7, len(m))
}
{
m := GetQuotaConfigMap(internalpb.RateScope_Partition)
assert.Equal(t, 6, len(m))
}
{
m := GetQuotaConfigMap(internalpb.RateScope(1000))
assert.Equal(t, 0, len(m))
}
}
func TestGetQuotaValue(t *testing.T) {
paramtable.Init()
param := paramtable.Get()
param.Save(param.QuotaConfig.DDLLimitEnabled.Key, "true")
defer param.Reset(param.QuotaConfig.DDLLimitEnabled.Key)
param.Save(param.QuotaConfig.DMLLimitEnabled.Key, "true")
defer param.Reset(param.QuotaConfig.DMLLimitEnabled.Key)
t.Run("cluster", func(t *testing.T) {
param.Save(param.QuotaConfig.DDLCollectionRate.Key, "10")
defer param.Reset(param.QuotaConfig.DDLCollectionRate.Key)
v := GetQuotaValue(internalpb.RateScope_Cluster, internalpb.RateType_DDLCollection, param)
assert.EqualValues(t, 10, v)
})
t.Run("database", func(t *testing.T) {
param.Save(param.QuotaConfig.DDLCollectionRatePerDB.Key, "10")
defer param.Reset(param.QuotaConfig.DDLCollectionRatePerDB.Key)
v := GetQuotaValue(internalpb.RateScope_Database, internalpb.RateType_DDLCollection, param)
assert.EqualValues(t, 10, v)
})
t.Run("collection", func(t *testing.T) {
param.Save(param.QuotaConfig.DMLMaxInsertRatePerCollection.Key, "10")
defer param.Reset(param.QuotaConfig.DMLMaxInsertRatePerCollection.Key)
v := GetQuotaValue(internalpb.RateScope_Collection, internalpb.RateType_DMLInsert, param)
assert.EqualValues(t, 10*1024*1024, v)
})
t.Run("partition", func(t *testing.T) {
param.Save(param.QuotaConfig.DMLMaxInsertRatePerPartition.Key, "10")
defer param.Reset(param.QuotaConfig.DMLMaxInsertRatePerPartition.Key)
v := GetQuotaValue(internalpb.RateScope_Partition, internalpb.RateType_DMLInsert, param)
assert.EqualValues(t, 10*1024*1024, v)
})
t.Run("unknown", func(t *testing.T) {
v := GetQuotaValue(internalpb.RateScope(1000), internalpb.RateType(1000), param)
assert.EqualValues(t, math.MaxFloat64, v)
})
}

View File

@ -0,0 +1,336 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ratelimitutil
import (
"fmt"
"sync"
"time"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type RateLimiterNode struct {
limiters *typeutil.ConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter]
quotaStates *typeutil.ConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]
level internalpb.RateScope
// id is the db, collection, or partition id; it is 0 at the cluster level
id int64
// children will be databases if current level is cluster
// children will be collections if current level is database
// children will be partitions if current level is collection
children *typeutil.ConcurrentMap[int64, *RateLimiterNode]
}
func NewRateLimiterNode(level internalpb.RateScope) *RateLimiterNode {
rln := &RateLimiterNode{
limiters: typeutil.NewConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter](),
quotaStates: typeutil.NewConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode](),
children: typeutil.NewConcurrentMap[int64, *RateLimiterNode](),
level: level,
}
return rln
}
func (rln *RateLimiterNode) Level() internalpb.RateScope {
return rln.level
}
// Limit returns true if the request should be rejected, together with the current limit;
// otherwise, the request passes.
func (rln *RateLimiterNode) Limit(rt internalpb.RateType, n int) (bool, float64) {
limit, ok := rln.limiters.Get(rt)
if !ok {
return false, -1
}
return !limit.AllowN(time.Now(), n), float64(limit.Limit())
}
func (rln *RateLimiterNode) Cancel(rt internalpb.RateType, n int) {
limit, ok := rln.limiters.Get(rt)
if !ok {
return
}
limit.Cancel(n)
}
func (rln *RateLimiterNode) Check(rt internalpb.RateType, n int) error {
limit, rate := rln.Limit(rt, n)
if rate == 0 {
return rln.GetQuotaExceededError(rt)
}
if limit {
return rln.GetRateLimitError(rate)
}
return nil
}
func (rln *RateLimiterNode) GetQuotaExceededError(rt internalpb.RateType) error {
switch rt {
case internalpb.RateType_DMLInsert, internalpb.RateType_DMLUpsert, internalpb.RateType_DMLDelete, internalpb.RateType_DMLBulkLoad:
if errCode, ok := rln.quotaStates.Get(milvuspb.QuotaState_DenyToWrite); ok {
return merr.WrapErrServiceQuotaExceeded(ratelimitutil.GetQuotaErrorString(errCode))
}
case internalpb.RateType_DQLSearch, internalpb.RateType_DQLQuery:
if errCode, ok := rln.quotaStates.Get(milvuspb.QuotaState_DenyToRead); ok {
return merr.WrapErrServiceQuotaExceeded(ratelimitutil.GetQuotaErrorString(errCode))
}
}
return merr.WrapErrServiceQuotaExceeded(fmt.Sprintf("rate type: %s", rt.String()))
}
func (rln *RateLimiterNode) GetRateLimitError(rate float64) error {
return merr.WrapErrServiceRateLimit(rate, "request is rejected by grpc RateLimiter middleware, please retry later")
}
func TraverseRateLimiterTree(root *RateLimiterNode, fn1 func(internalpb.RateType, *ratelimitutil.Limiter) bool,
fn2 func(node *RateLimiterNode, state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool,
) {
if fn1 != nil {
root.limiters.Range(fn1)
}
if fn2 != nil {
root.quotaStates.Range(func(state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool {
return fn2(root, state, errCode)
})
}
root.GetChildren().Range(func(key int64, child *RateLimiterNode) bool {
TraverseRateLimiterTree(child, fn1, fn2)
return true
})
}
func (rln *RateLimiterNode) AddChild(key int64, child *RateLimiterNode) {
rln.children.Insert(key, child)
}
func (rln *RateLimiterNode) GetChild(key int64) *RateLimiterNode {
n, _ := rln.children.Get(key)
return n
}
func (rln *RateLimiterNode) GetChildren() *typeutil.ConcurrentMap[int64, *RateLimiterNode] {
return rln.children
}
func (rln *RateLimiterNode) GetLimiters() *typeutil.ConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter] {
return rln.limiters
}
func (rln *RateLimiterNode) SetLimiters(new *typeutil.ConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter]) {
rln.limiters = new
}
func (rln *RateLimiterNode) GetQuotaStates() *typeutil.ConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode] {
return rln.quotaStates
}
func (rln *RateLimiterNode) SetQuotaStates(new *typeutil.ConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]) {
rln.quotaStates = new
}
func (rln *RateLimiterNode) GetID() int64 {
return rln.id
}
// RateLimiterTree is implemented based on RateLimiterNode to operate multilevel rate limiters
//
// it contains the following four levels generally:
//
// -> global level
// -> database level
// -> collection level
// -> partition level
type RateLimiterTree struct {
root *RateLimiterNode
mu sync.RWMutex
}
// NewRateLimiterTree returns a new RateLimiterTree.
func NewRateLimiterTree(root *RateLimiterNode) *RateLimiterTree {
return &RateLimiterTree{root: root}
}
// GetRootLimiters get root limiters
func (m *RateLimiterTree) GetRootLimiters() *RateLimiterNode {
return m.root
}
func (m *RateLimiterTree) ClearInvalidLimiterNode(req *proxypb.LimiterNode) {
m.mu.Lock()
defer m.mu.Unlock()
reqDBLimits := req.GetChildren()
removeDBLimits := make([]int64, 0)
m.GetRootLimiters().GetChildren().Range(func(key int64, _ *RateLimiterNode) bool {
if _, ok := reqDBLimits[key]; !ok {
removeDBLimits = append(removeDBLimits, key)
}
return true
})
for _, dbID := range removeDBLimits {
m.GetRootLimiters().GetChildren().Remove(dbID)
}
m.GetRootLimiters().GetChildren().Range(func(dbID int64, dbNode *RateLimiterNode) bool {
reqCollectionLimits := reqDBLimits[dbID].GetChildren()
removeCollectionLimits := make([]int64, 0)
dbNode.GetChildren().Range(func(key int64, _ *RateLimiterNode) bool {
if _, ok := reqCollectionLimits[key]; !ok {
removeCollectionLimits = append(removeCollectionLimits, key)
}
return true
})
for _, collectionID := range removeCollectionLimits {
dbNode.GetChildren().Remove(collectionID)
}
return true
})
m.GetRootLimiters().GetChildren().Range(func(dbID int64, dbNode *RateLimiterNode) bool {
dbNode.GetChildren().Range(func(collectionID int64, collectionNode *RateLimiterNode) bool {
reqPartitionLimits := reqDBLimits[dbID].GetChildren()[collectionID].GetChildren()
removePartitionLimits := make([]int64, 0)
collectionNode.GetChildren().Range(func(key int64, _ *RateLimiterNode) bool {
if _, ok := reqPartitionLimits[key]; !ok {
removePartitionLimits = append(removePartitionLimits, key)
}
return true
})
for _, partitionID := range removePartitionLimits {
collectionNode.GetChildren().Remove(partitionID)
}
return true
})
return true
})
}
func (m *RateLimiterTree) GetDatabaseLimiters(dbID int64) *RateLimiterNode {
m.mu.RLock()
defer m.mu.RUnlock()
return m.root.GetChild(dbID)
}
// GetOrCreateDatabaseLimiters gets the database-level limiter, creating it if it doesn't exist.
func (m *RateLimiterTree) GetOrCreateDatabaseLimiters(dbID int64, newDBRateLimiter func() *RateLimiterNode) *RateLimiterNode {
dbRateLimiters := m.GetDatabaseLimiters(dbID)
if dbRateLimiters != nil {
return dbRateLimiters
}
m.mu.Lock()
defer m.mu.Unlock()
if cur := m.root.GetChild(dbID); cur != nil {
return cur
}
dbRateLimiters = newDBRateLimiter()
dbRateLimiters.id = dbID
m.root.AddChild(dbID, dbRateLimiters)
return dbRateLimiters
}
func (m *RateLimiterTree) GetCollectionLimiters(dbID, collectionID int64) *RateLimiterNode {
m.mu.RLock()
defer m.mu.RUnlock()
dbRateLimiters := m.root.GetChild(dbID)
// database rate limiter not found
if dbRateLimiters == nil {
return nil
}
return dbRateLimiters.GetChild(collectionID)
}
// GetOrCreateCollectionLimiters gets the collection-level limiter for all rate types and rate scopes,
// creating the database rate limiter first if it does not exist.
func (m *RateLimiterTree) GetOrCreateCollectionLimiters(dbID, collectionID int64,
newDBRateLimiter func() *RateLimiterNode, newCollectionRateLimiter func() *RateLimiterNode,
) *RateLimiterNode {
collectionRateLimiters := m.GetCollectionLimiters(dbID, collectionID)
if collectionRateLimiters != nil {
return collectionRateLimiters
}
dbRateLimiters := m.GetOrCreateDatabaseLimiters(dbID, newDBRateLimiter)
m.mu.Lock()
defer m.mu.Unlock()
if cur := dbRateLimiters.GetChild(collectionID); cur != nil {
return cur
}
collectionRateLimiters = newCollectionRateLimiter()
collectionRateLimiters.id = collectionID
dbRateLimiters.AddChild(collectionID, collectionRateLimiters)
return collectionRateLimiters
}
// GetPartitionLimiters checks whether rate limiters exist for the database, collection, and partition,
// and returns the corresponding partition-level rate limiter node (or nil).
func (m *RateLimiterTree) GetPartitionLimiters(dbID, collectionID, partitionID int64) *RateLimiterNode {
m.mu.RLock()
defer m.mu.RUnlock()
dbRateLimiters := m.root.GetChild(dbID)
// database rate limiter not found
if dbRateLimiters == nil {
return nil
}
collectionRateLimiters := dbRateLimiters.GetChild(collectionID)
// collection rate limiter not found
if collectionRateLimiters == nil {
return nil
}
return collectionRateLimiters.GetChild(partitionID)
}
// GetOrCreatePartitionLimiters gets the partition-level limiter for all rate types and rate scopes,
// creating the database and collection rate limiters first if they do not exist.
func (m *RateLimiterTree) GetOrCreatePartitionLimiters(dbID int64, collectionID int64, partitionID int64,
newDBRateLimiter func() *RateLimiterNode, newCollectionRateLimiter func() *RateLimiterNode,
newPartRateLimiter func() *RateLimiterNode,
) *RateLimiterNode {
partRateLimiters := m.GetPartitionLimiters(dbID, collectionID, partitionID)
if partRateLimiters != nil {
return partRateLimiters
}
collectionRateLimiters := m.GetOrCreateCollectionLimiters(dbID, collectionID, newDBRateLimiter, newCollectionRateLimiter)
m.mu.Lock()
defer m.mu.Unlock()
if cur := collectionRateLimiters.GetChild(partitionID); cur != nil {
return cur
}
partRateLimiters = newPartRateLimiter()
partRateLimiters.id = partitionID
collectionRateLimiters.AddChild(partitionID, partRateLimiters)
return partRateLimiters
}
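
A minimal usage sketch of the tree with illustrative IDs and limits, mirroring the constructor closures used in the tests that follow.

```go
// Sketch only: build a cluster -> database -> collection -> partition chain and check one rate.
newNode := func(scope internalpb.RateScope) func() *RateLimiterNode {
	return func() *RateLimiterNode {
		n := NewRateLimiterNode(scope)
		n.GetLimiters().Insert(internalpb.RateType_DMLInsert,
			ratelimitutil.NewLimiter(ratelimitutil.Limit(10), 1)) // 10 tokens/s, illustrative
		return n
	}
}

tree := NewRateLimiterTree(NewRateLimiterNode(internalpb.RateScope_Cluster))
part := tree.GetOrCreatePartitionLimiters(1, 10, 100, // dbID, collectionID, partitionID
	newNode(internalpb.RateScope_Database),
	newNode(internalpb.RateScope_Collection),
	newNode(internalpb.RateScope_Partition),
)

// rejected stays false while the partition-level limiter still has tokens.
rejected, limit := part.Limit(internalpb.RateType_DMLInsert, 1)
_, _ = rejected, limit
```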

View File

@ -0,0 +1,205 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ratelimitutil
import (
"strings"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func TestRateLimiterNode_AddAndGetChild(t *testing.T) {
rln := NewRateLimiterNode(internalpb.RateScope_Cluster)
child := NewRateLimiterNode(internalpb.RateScope_Cluster)
// Positive test case
rln.AddChild(1, child)
if rln.GetChild(1) != child {
t.Error("AddChild did not add the child correctly")
}
// Negative test case
invalidChild := &RateLimiterNode{}
rln.AddChild(2, child)
if rln.GetChild(2) == invalidChild {
t.Error("AddChild added an invalid child")
}
}
func TestTraverseRateLimiterTree(t *testing.T) {
limiters := typeutil.NewConcurrentMap[internalpb.RateType, *ratelimitutil.Limiter]()
limiters.Insert(internalpb.RateType_DDLCollection, ratelimitutil.NewLimiter(ratelimitutil.Inf, 0))
quotaStates := typeutil.NewConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]()
quotaStates.Insert(milvuspb.QuotaState_DenyToWrite, commonpb.ErrorCode_ForceDeny)
root := NewRateLimiterNode(internalpb.RateScope_Cluster)
root.SetLimiters(limiters)
root.SetQuotaStates(quotaStates)
// Add a child to the root node
child := NewRateLimiterNode(internalpb.RateScope_Cluster)
child.SetLimiters(limiters)
child.SetQuotaStates(quotaStates)
root.AddChild(123, child)
// Add a child to the root node
child2 := NewRateLimiterNode(internalpb.RateScope_Cluster)
child2.SetLimiters(limiters)
child2.SetQuotaStates(quotaStates)
child.AddChild(123, child2)
// Positive test case for fn1
var fn1Count int
fn1 := func(rateType internalpb.RateType, limiter *ratelimitutil.Limiter) bool {
fn1Count++
return true
}
// Negative test case for fn2
var fn2Count int
fn2 := func(node *RateLimiterNode, state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool {
fn2Count++
return true
}
// Call TraverseRateLimiterTree with fn1 and fn2
TraverseRateLimiterTree(root, fn1, fn2)
assert.Equal(t, 3, fn1Count)
assert.Equal(t, 3, fn2Count)
}
func TestRateLimiterNodeCancel(t *testing.T) {
t.Run("cancel not exist type", func(t *testing.T) {
limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster)
limitNode.Cancel(internalpb.RateType_DMLInsert, 10)
})
}
func TestRateLimiterNodeCheck(t *testing.T) {
t.Run("quota exceed", func(t *testing.T) {
limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster)
limitNode.limiters.Insert(internalpb.RateType_DMLInsert, ratelimitutil.NewLimiter(0, 0))
limitNode.quotaStates.Insert(milvuspb.QuotaState_DenyToWrite, commonpb.ErrorCode_ForceDeny)
err := limitNode.Check(internalpb.RateType_DMLInsert, 10)
assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded))
})
t.Run("rate limit", func(t *testing.T) {
limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster)
limitNode.limiters.Insert(internalpb.RateType_DMLInsert, ratelimitutil.NewLimiter(0.01, 0.01))
{
err := limitNode.Check(internalpb.RateType_DMLInsert, 1)
assert.NoError(t, err)
}
{
err := limitNode.Check(internalpb.RateType_DMLInsert, 1)
assert.True(t, errors.Is(err, merr.ErrServiceRateLimit))
}
})
}
func TestRateLimiterNodeGetQuotaExceededError(t *testing.T) {
t.Run("write", func(t *testing.T) {
limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster)
limitNode.quotaStates.Insert(milvuspb.QuotaState_DenyToWrite, commonpb.ErrorCode_ForceDeny)
err := limitNode.GetQuotaExceededError(internalpb.RateType_DMLInsert)
assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded))
// reference: ratelimitutil.GetQuotaErrorString(errCode)
assert.True(t, strings.Contains(err.Error(), "deactivated"))
})
t.Run("read", func(t *testing.T) {
limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster)
limitNode.quotaStates.Insert(milvuspb.QuotaState_DenyToRead, commonpb.ErrorCode_ForceDeny)
err := limitNode.GetQuotaExceededError(internalpb.RateType_DQLSearch)
assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded))
// reference: ratelimitutil.GetQuotaErrorString(errCode)
assert.True(t, strings.Contains(err.Error(), "deactivated"))
})
t.Run("unknown", func(t *testing.T) {
limitNode := NewRateLimiterNode(internalpb.RateScope_Cluster)
err := limitNode.GetQuotaExceededError(internalpb.RateType_DDLCompaction)
assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded))
assert.True(t, strings.Contains(err.Error(), "rate type"))
})
}
func TestRateLimiterTreeClearInvalidLimiterNode(t *testing.T) {
root := NewRateLimiterNode(internalpb.RateScope_Cluster)
tree := NewRateLimiterTree(root)
generateNodeFFunc := func(level internalpb.RateScope) func() *RateLimiterNode {
return func() *RateLimiterNode {
return NewRateLimiterNode(level)
}
}
tree.GetOrCreatePartitionLimiters(1, 10, 100,
generateNodeFFunc(internalpb.RateScope_Database),
generateNodeFFunc(internalpb.RateScope_Collection),
generateNodeFFunc(internalpb.RateScope_Partition),
)
tree.GetOrCreatePartitionLimiters(1, 10, 200,
generateNodeFFunc(internalpb.RateScope_Database),
generateNodeFFunc(internalpb.RateScope_Collection),
generateNodeFFunc(internalpb.RateScope_Partition),
)
tree.GetOrCreatePartitionLimiters(1, 20, 300,
generateNodeFFunc(internalpb.RateScope_Database),
generateNodeFFunc(internalpb.RateScope_Collection),
generateNodeFFunc(internalpb.RateScope_Partition),
)
tree.GetOrCreatePartitionLimiters(2, 30, 400,
generateNodeFFunc(internalpb.RateScope_Database),
generateNodeFFunc(internalpb.RateScope_Collection),
generateNodeFFunc(internalpb.RateScope_Partition),
)
assert.Equal(t, 2, root.GetChildren().Len())
assert.Equal(t, 2, root.GetChild(1).GetChildren().Len())
assert.Equal(t, 2, root.GetChild(1).GetChild(10).GetChildren().Len())
tree.ClearInvalidLimiterNode(&proxypb.LimiterNode{
Children: map[int64]*proxypb.LimiterNode{
1: {
Children: map[int64]*proxypb.LimiterNode{
10: {
Children: map[int64]*proxypb.LimiterNode{
100: {},
},
},
},
},
},
})
assert.Equal(t, 1, root.GetChildren().Len())
assert.Equal(t, 1, root.GetChild(1).GetChildren().Len())
assert.Equal(t, 1, root.GetChild(1).GetChild(10).GetChildren().Len())
}

View File

@ -132,6 +132,8 @@ const (
CollectionSearchRateMaxKey = "collection.searchRate.max.vps"
CollectionSearchRateMinKey = "collection.searchRate.min.vps"
CollectionDiskQuotaKey = "collection.diskProtection.diskQuota.mb"
PartitionDiskQuotaKey = "partition.diskProtection.diskQuota.mb"
)
// common properties

View File

@ -167,7 +167,7 @@ var (
Help: "The quota states of cluster",
}, []string{
"quota_states",
"db_name",
"name",
})
// RootCoordRateLimitRatio reflects the ratio of rate limit.

View File

@ -87,6 +87,7 @@ type QueryNodeQuotaMetrics struct {
type DataCoordQuotaMetrics struct {
TotalBinlogSize int64
CollectionBinlogSize map[int64]int64
PartitionsBinlogSize map[int64]map[int64]int64
}
// DataNodeQuotaMetrics are metrics of DataNode.

View File

@ -61,6 +61,12 @@ type quotaConfig struct {
CompactionLimitEnabled ParamItem `refreshable:"true"`
MaxCompactionRate ParamItem `refreshable:"true"`
DDLCollectionRatePerDB ParamItem `refreshable:"true"`
DDLPartitionRatePerDB ParamItem `refreshable:"true"`
MaxIndexRatePerDB ParamItem `refreshable:"true"`
MaxFlushRatePerDB ParamItem `refreshable:"true"`
MaxCompactionRatePerDB ParamItem `refreshable:"true"`
// dml
DMLLimitEnabled ParamItem `refreshable:"true"`
DMLMaxInsertRate ParamItem `refreshable:"true"`
@ -71,6 +77,14 @@ type quotaConfig struct {
DMLMinDeleteRate ParamItem `refreshable:"true"`
DMLMaxBulkLoadRate ParamItem `refreshable:"true"`
DMLMinBulkLoadRate ParamItem `refreshable:"true"`
DMLMaxInsertRatePerDB ParamItem `refreshable:"true"`
DMLMinInsertRatePerDB ParamItem `refreshable:"true"`
DMLMaxUpsertRatePerDB ParamItem `refreshable:"true"`
DMLMinUpsertRatePerDB ParamItem `refreshable:"true"`
DMLMaxDeleteRatePerDB ParamItem `refreshable:"true"`
DMLMinDeleteRatePerDB ParamItem `refreshable:"true"`
DMLMaxBulkLoadRatePerDB ParamItem `refreshable:"true"`
DMLMinBulkLoadRatePerDB ParamItem `refreshable:"true"`
DMLMaxInsertRatePerCollection ParamItem `refreshable:"true"`
DMLMinInsertRatePerCollection ParamItem `refreshable:"true"`
DMLMaxUpsertRatePerCollection ParamItem `refreshable:"true"`
@ -79,6 +93,14 @@ type quotaConfig struct {
DMLMinDeleteRatePerCollection ParamItem `refreshable:"true"`
DMLMaxBulkLoadRatePerCollection ParamItem `refreshable:"true"`
DMLMinBulkLoadRatePerCollection ParamItem `refreshable:"true"`
DMLMaxInsertRatePerPartition ParamItem `refreshable:"true"`
DMLMinInsertRatePerPartition ParamItem `refreshable:"true"`
DMLMaxUpsertRatePerPartition ParamItem `refreshable:"true"`
DMLMinUpsertRatePerPartition ParamItem `refreshable:"true"`
DMLMaxDeleteRatePerPartition ParamItem `refreshable:"true"`
DMLMinDeleteRatePerPartition ParamItem `refreshable:"true"`
DMLMaxBulkLoadRatePerPartition ParamItem `refreshable:"true"`
DMLMinBulkLoadRatePerPartition ParamItem `refreshable:"true"`
// dql
DQLLimitEnabled ParamItem `refreshable:"true"`
@ -86,10 +108,18 @@ type quotaConfig struct {
DQLMinSearchRate ParamItem `refreshable:"true"`
DQLMaxQueryRate ParamItem `refreshable:"true"`
DQLMinQueryRate ParamItem `refreshable:"true"`
DQLMaxSearchRatePerDB ParamItem `refreshable:"true"`
DQLMinSearchRatePerDB ParamItem `refreshable:"true"`
DQLMaxQueryRatePerDB ParamItem `refreshable:"true"`
DQLMinQueryRatePerDB ParamItem `refreshable:"true"`
DQLMaxSearchRatePerCollection ParamItem `refreshable:"true"`
DQLMinSearchRatePerCollection ParamItem `refreshable:"true"`
DQLMaxQueryRatePerCollection ParamItem `refreshable:"true"`
DQLMinQueryRatePerCollection ParamItem `refreshable:"true"`
DQLMaxSearchRatePerPartition ParamItem `refreshable:"true"`
DQLMinSearchRatePerPartition ParamItem `refreshable:"true"`
DQLMaxQueryRatePerPartition ParamItem `refreshable:"true"`
DQLMinQueryRatePerPartition ParamItem `refreshable:"true"`
// limits
MaxCollectionNum ParamItem `refreshable:"true"`
@ -114,16 +144,20 @@ type quotaConfig struct {
GrowingSegmentsSizeHighWaterLevel ParamItem `refreshable:"true"`
DiskProtectionEnabled ParamItem `refreshable:"true"`
DiskQuota ParamItem `refreshable:"true"`
DiskQuotaPerDB ParamItem `refreshable:"true"`
DiskQuotaPerCollection ParamItem `refreshable:"true"`
DiskQuotaPerPartition ParamItem `refreshable:"true"`
// limit reading
ForceDenyReading ParamItem `refreshable:"true"`
QueueProtectionEnabled ParamItem `refreshable:"true"`
NQInQueueThreshold ParamItem `refreshable:"true"`
QueueLatencyThreshold ParamItem `refreshable:"true"`
ResultProtectionEnabled ParamItem `refreshable:"true"`
MaxReadResultRate ParamItem `refreshable:"true"`
CoolOffSpeed ParamItem `refreshable:"true"`
ForceDenyReading ParamItem `refreshable:"true"`
QueueProtectionEnabled ParamItem `refreshable:"true"`
NQInQueueThreshold ParamItem `refreshable:"true"`
QueueLatencyThreshold ParamItem `refreshable:"true"`
ResultProtectionEnabled ParamItem `refreshable:"true"`
MaxReadResultRate ParamItem `refreshable:"true"`
MaxReadResultRatePerDB ParamItem `refreshable:"true"`
MaxReadResultRatePerCollection ParamItem `refreshable:"true"`
CoolOffSpeed ParamItem `refreshable:"true"`
}
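
A small, hypothetical sketch of reading one of the new database-level items once paramtable is initialized; the returned value is in bytes/s after the MB-to-bytes formatter, unless it is left at the no-limit default.

```go
// Sketch only: look up the db-level insert ceiling.
paramtable.Init()
qc := &paramtable.Get().QuotaConfig
maxInsertPerDB := qc.DMLMaxInsertRatePerDB.GetAsFloat() // bytes/s; defaults to no limit
_ = maxInsertPerDB
```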
func (p *quotaConfig) init(base *BaseTable) {
@ -185,6 +219,25 @@ seconds, (0 ~ 65536)`,
}
p.DDLCollectionRate.Init(base.mgr)
p.DDLCollectionRatePerDB = ParamItem{
Key: "quotaAndLimits.ddl.db.collectionRate",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.DDLLimitEnabled.GetAsBool() {
return max
}
// [0 ~ Inf)
if getAsInt(v) < 0 {
return max
}
return v
},
Doc: "qps of db level , default no limit, rate for CreateCollection, DropCollection, LoadCollection, ReleaseCollection",
Export: true,
}
p.DDLCollectionRatePerDB.Init(base.mgr)
p.DDLPartitionRate = ParamItem{
Key: "quotaAndLimits.ddl.partitionRate",
Version: "2.2.0",
@ -204,6 +257,25 @@ seconds, (0 ~ 65536)`,
}
p.DDLPartitionRate.Init(base.mgr)
p.DDLPartitionRatePerDB = ParamItem{
Key: "quotaAndLimits.ddl.db.partitionRate",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.DDLLimitEnabled.GetAsBool() {
return max
}
// [0 ~ Inf)
if getAsInt(v) < 0 {
return max
}
return v
},
Doc: "qps of db level, default no limit, rate for CreatePartition, DropPartition, LoadPartition, ReleasePartition",
Export: true,
}
p.DDLPartitionRatePerDB.Init(base.mgr)
p.IndexLimitEnabled = ParamItem{
Key: "quotaAndLimits.indexRate.enabled",
Version: "2.2.0",
@ -231,6 +303,25 @@ seconds, (0 ~ 65536)`,
}
p.MaxIndexRate.Init(base.mgr)
p.MaxIndexRatePerDB = ParamItem{
Key: "quotaAndLimits.indexRate.db.max",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.IndexLimitEnabled.GetAsBool() {
return max
}
// [0 ~ Inf)
if getAsFloat(v) < 0 {
return max
}
return v
},
Doc: "qps of db level, default no limit, rate for CreateIndex, DropIndex",
Export: true,
}
p.MaxIndexRatePerDB.Init(base.mgr)
p.FlushLimitEnabled = ParamItem{
Key: "quotaAndLimits.flushRate.enabled",
Version: "2.2.0",
@ -258,6 +349,25 @@ seconds, (0 ~ 65536)`,
}
p.MaxFlushRate.Init(base.mgr)
p.MaxFlushRatePerDB = ParamItem{
Key: "quotaAndLimits.flushRate.db.max",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.FlushLimitEnabled.GetAsBool() {
return max
}
// [0 ~ Inf)
if getAsInt(v) < 0 {
return max
}
return v
},
Doc: "qps of db level, default no limit, rate for flush",
Export: true,
}
p.MaxFlushRatePerDB.Init(base.mgr)
p.MaxFlushRatePerCollection = ParamItem{
Key: "quotaAndLimits.flushRate.collection.max",
Version: "2.3.9",
@ -304,6 +414,25 @@ seconds, (0 ~ 65536)`,
}
p.MaxCompactionRate.Init(base.mgr)
p.MaxCompactionRatePerDB = ParamItem{
Key: "quotaAndLimits.compactionRate.db.max",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.CompactionLimitEnabled.GetAsBool() {
return max
}
// [0 ~ Inf)
if getAsInt(v) < 0 {
return max
}
return v
},
Doc: "qps of db level, default no limit, rate for manualCompaction",
Export: true,
}
p.MaxCompactionRatePerDB.Init(base.mgr)
// dml
p.DMLLimitEnabled = ParamItem{
Key: "quotaAndLimits.dml.enabled",
@ -359,6 +488,50 @@ The maximum rate will not be greater than ` + "max" + `.`,
}
p.DMLMinInsertRate.Init(base.mgr)
p.DMLMaxInsertRatePerDB = ParamItem{
Key: "quotaAndLimits.dml.insertRate.db.max",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.DMLLimitEnabled.GetAsBool() {
return max
}
rate := getAsFloat(v)
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
rate = megaBytes2Bytes(rate)
}
// [0, inf)
if rate < 0 {
return p.DMLMaxInsertRate.GetValue()
}
return fmt.Sprintf("%f", rate)
},
Doc: "MB/s, default no limit",
Export: true,
}
p.DMLMaxInsertRatePerDB.Init(base.mgr)
p.DMLMinInsertRatePerDB = ParamItem{
Key: "quotaAndLimits.dml.insertRate.db.min",
Version: "2.4.1",
DefaultValue: min,
Formatter: func(v string) string {
if !p.DMLLimitEnabled.GetAsBool() {
return min
}
rate := megaBytes2Bytes(getAsFloat(v))
// [0, inf)
if rate < 0 {
return min
}
if !p.checkMinMaxLegal(rate, p.DMLMaxInsertRatePerDB.GetAsFloat()) {
return min
}
return fmt.Sprintf("%f", rate)
},
}
p.DMLMinInsertRatePerDB.Init(base.mgr)
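The db-level (and, further below, partition-level) DML throughput limits come as max/min pairs expressed in MB/s: the max formatter converts the value to bytes/s unless it equals the defaultMax sentinel and falls back to the cluster-level limit for negative input, while the min formatter additionally requires min <= max via checkMinMaxLegal and otherwise collapses to 0. A condensed sketch of both sides, with math.MaxFloat64 standing in for the package's defaultMax sentinel and a plain comparison standing in for checkMinMaxLegal:
import "math"
const noLimit = math.MaxFloat64 // stand-in for the package's defaultMax sentinel
func mbToBytes(mb float64) float64 { return mb * 1024 * 1024 }
// max side: convert MB/s to bytes/s, leave the sentinel untouched,
// and fall back to the cluster-level limit for negative input.
func resolveMaxMBps(enabled bool, mbps, clusterBytesPerSec float64) float64 {
	if !enabled || mbps == noLimit {
		return noLimit
	}
	if b := mbToBytes(mbps); b >= 0 {
		return b
	}
	return clusterBytesPerSec
}
// min side: must stay within [0, max]; anything else collapses to 0 (no lower bound).
func resolveMinMBps(enabled bool, mbps, maxBytesPerSec float64) float64 {
	if !enabled {
		return 0
	}
	b := mbToBytes(mbps)
	if b < 0 || b > maxBytesPerSec {
		return 0
	}
	return b
}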
p.DMLMaxInsertRatePerCollection = ParamItem{
Key: "quotaAndLimits.dml.insertRate.collection.max",
Version: "2.2.9",
@ -403,6 +576,50 @@ The maximum rate will not be greater than ` + "max" + `.`,
}
p.DMLMinInsertRatePerCollection.Init(base.mgr)
p.DMLMaxInsertRatePerPartition = ParamItem{
Key: "quotaAndLimits.dml.insertRate.partition.max",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.DMLLimitEnabled.GetAsBool() {
return max
}
rate := getAsFloat(v)
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
rate = megaBytes2Bytes(rate)
}
// [0, inf)
if rate < 0 {
return p.DMLMaxInsertRate.GetValue()
}
return fmt.Sprintf("%f", rate)
},
Doc: "MB/s, default no limit",
Export: true,
}
p.DMLMaxInsertRatePerPartition.Init(base.mgr)
p.DMLMinInsertRatePerPartition = ParamItem{
Key: "quotaAndLimits.dml.insertRate.partition.min",
Version: "2.4.1",
DefaultValue: min,
Formatter: func(v string) string {
if !p.DMLLimitEnabled.GetAsBool() {
return min
}
rate := megaBytes2Bytes(getAsFloat(v))
// [0, inf)
if rate < 0 {
return min
}
if !p.checkMinMaxLegal(rate, p.DMLMaxInsertRatePerPartition.GetAsFloat()) {
return min
}
return fmt.Sprintf("%f", rate)
},
}
p.DMLMinInsertRatePerPartition.Init(base.mgr)
p.DMLMaxUpsertRate = ParamItem{
Key: "quotaAndLimits.dml.upsertRate.max",
Version: "2.3.0",
@ -447,6 +664,50 @@ The maximum rate will not be greater than ` + "max" + `.`,
}
p.DMLMinUpsertRate.Init(base.mgr)
p.DMLMaxUpsertRatePerDB = ParamItem{
Key: "quotaAndLimits.dml.upsertRate.db.max",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.DMLLimitEnabled.GetAsBool() {
return max
}
rate := getAsFloat(v)
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
rate = megaBytes2Bytes(rate)
}
// [0, inf)
if rate < 0 {
return p.DMLMaxUpsertRate.GetValue()
}
return fmt.Sprintf("%f", rate)
},
Doc: "MB/s, default no limit",
Export: true,
}
p.DMLMaxUpsertRatePerDB.Init(base.mgr)
p.DMLMinUpsertRatePerDB = ParamItem{
Key: "quotaAndLimits.dml.upsertRate.db.min",
Version: "2.4.1",
DefaultValue: min,
Formatter: func(v string) string {
if !p.DMLLimitEnabled.GetAsBool() {
return min
}
rate := megaBytes2Bytes(getAsFloat(v))
// [0, inf)
if rate < 0 {
return min
}
if !p.checkMinMaxLegal(rate, p.DMLMaxUpsertRatePerDB.GetAsFloat()) {
return min
}
return fmt.Sprintf("%f", rate)
},
}
p.DMLMinUpsertRatePerDB.Init(base.mgr)
p.DMLMaxUpsertRatePerCollection = ParamItem{
Key: "quotaAndLimits.dml.upsertRate.collection.max",
Version: "2.3.0",
@ -491,6 +752,50 @@ The maximum rate will not be greater than ` + "max" + `.`,
}
p.DMLMinUpsertRatePerCollection.Init(base.mgr)
p.DMLMaxUpsertRatePerPartition = ParamItem{
Key: "quotaAndLimits.dml.upsertRate.partition.max",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.DMLLimitEnabled.GetAsBool() {
return max
}
rate := getAsFloat(v)
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
rate = megaBytes2Bytes(rate)
}
// [0, inf)
if rate < 0 {
return p.DMLMaxUpsertRate.GetValue()
}
return fmt.Sprintf("%f", rate)
},
Doc: "MB/s, default no limit",
Export: true,
}
p.DMLMaxUpsertRatePerPartition.Init(base.mgr)
p.DMLMinUpsertRatePerPartition = ParamItem{
Key: "quotaAndLimits.dml.upsertRate.partition.min",
Version: "2.4.1",
DefaultValue: min,
Formatter: func(v string) string {
if !p.DMLLimitEnabled.GetAsBool() {
return min
}
rate := megaBytes2Bytes(getAsFloat(v))
// [0, inf)
if rate < 0 {
return min
}
if !p.checkMinMaxLegal(rate, p.DMLMaxUpsertRatePerPartition.GetAsFloat()) {
return min
}
return fmt.Sprintf("%f", rate)
},
}
p.DMLMinUpsertRatePerPartition.Init(base.mgr)
p.DMLMaxDeleteRate = ParamItem{
Key: "quotaAndLimits.dml.deleteRate.max",
Version: "2.2.0",
@ -535,6 +840,50 @@ The maximum rate will not be greater than ` + "max" + `.`,
}
p.DMLMinDeleteRate.Init(base.mgr)
p.DMLMaxDeleteRatePerDB = ParamItem{
Key: "quotaAndLimits.dml.deleteRate.db.max",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.DMLLimitEnabled.GetAsBool() {
return max
}
rate := getAsFloat(v)
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
rate = megaBytes2Bytes(rate)
}
// [0, inf)
if rate < 0 {
return p.DMLMaxDeleteRate.GetValue()
}
return fmt.Sprintf("%f", rate)
},
Doc: "MB/s, default no limit",
Export: true,
}
p.DMLMaxDeleteRatePerDB.Init(base.mgr)
p.DMLMinDeleteRatePerDB = ParamItem{
Key: "quotaAndLimits.dml.deleteRate.db.min",
Version: "2.4.1",
DefaultValue: min,
Formatter: func(v string) string {
if !p.DMLLimitEnabled.GetAsBool() {
return min
}
rate := megaBytes2Bytes(getAsFloat(v))
// [0, inf)
if rate < 0 {
return min
}
if !p.checkMinMaxLegal(rate, p.DMLMaxDeleteRatePerDB.GetAsFloat()) {
return min
}
return fmt.Sprintf("%f", rate)
},
}
p.DMLMinDeleteRatePerDB.Init(base.mgr)
p.DMLMaxDeleteRatePerCollection = ParamItem{
Key: "quotaAndLimits.dml.deleteRate.collection.max",
Version: "2.2.9",
@ -579,6 +928,50 @@ The maximum rate will not be greater than ` + "max" + `.`,
}
p.DMLMinDeleteRatePerCollection.Init(base.mgr)
p.DMLMaxDeleteRatePerPartition = ParamItem{
Key: "quotaAndLimits.dml.deleteRate.partition.max",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.DMLLimitEnabled.GetAsBool() {
return max
}
rate := getAsFloat(v)
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
rate = megaBytes2Bytes(rate)
}
// [0, inf)
if rate < 0 {
return p.DMLMaxDeleteRate.GetValue()
}
return fmt.Sprintf("%f", rate)
},
Doc: "MB/s, default no limit",
Export: true,
}
p.DMLMaxDeleteRatePerPartition.Init(base.mgr)
p.DMLMinDeleteRatePerPartition = ParamItem{
Key: "quotaAndLimits.dml.deleteRate.partition.min",
Version: "2.4.1",
DefaultValue: min,
Formatter: func(v string) string {
if !p.DMLLimitEnabled.GetAsBool() {
return min
}
rate := megaBytes2Bytes(getAsFloat(v))
// [0, inf)
if rate < 0 {
return min
}
if !p.checkMinMaxLegal(rate, p.DMLMaxDeleteRatePerPartition.GetAsFloat()) {
return min
}
return fmt.Sprintf("%f", rate)
},
}
p.DMLMinDeleteRatePerPartition.Init(base.mgr)
p.DMLMaxBulkLoadRate = ParamItem{
Key: "quotaAndLimits.dml.bulkLoadRate.max",
Version: "2.2.0",
@ -623,6 +1016,50 @@ The maximum rate will not be greater than ` + "max" + `.`,
}
p.DMLMinBulkLoadRate.Init(base.mgr)
p.DMLMaxBulkLoadRatePerDB = ParamItem{
Key: "quotaAndLimits.dml.bulkLoadRate.db.max",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.DMLLimitEnabled.GetAsBool() {
return max
}
rate := getAsFloat(v)
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
rate = megaBytes2Bytes(rate)
}
// [0, inf)
if rate < 0 {
return p.DMLMaxBulkLoadRate.GetValue()
}
return fmt.Sprintf("%f", rate)
},
Doc: "MB/s, default no limit, not supported yet. TODO: limit db bulkLoad rate",
Export: true,
}
p.DMLMaxBulkLoadRatePerDB.Init(base.mgr)
p.DMLMinBulkLoadRatePerDB = ParamItem{
Key: "quotaAndLimits.dml.bulkLoadRate.db.min",
Version: "2.4.1",
DefaultValue: min,
Formatter: func(v string) string {
if !p.DMLLimitEnabled.GetAsBool() {
return min
}
rate := megaBytes2Bytes(getAsFloat(v))
// [0, inf)
if rate < 0 {
return min
}
if !p.checkMinMaxLegal(rate, p.DMLMaxBulkLoadRatePerDB.GetAsFloat()) {
return min
}
return fmt.Sprintf("%f", rate)
},
}
p.DMLMinBulkLoadRatePerDB.Init(base.mgr)
p.DMLMaxBulkLoadRatePerCollection = ParamItem{
Key: "quotaAndLimits.dml.bulkLoadRate.collection.max",
Version: "2.2.9",
@ -667,6 +1104,50 @@ The maximum rate will not be greater than ` + "max" + `.`,
}
p.DMLMinBulkLoadRatePerCollection.Init(base.mgr)
p.DMLMaxBulkLoadRatePerPartition = ParamItem{
Key: "quotaAndLimits.dml.bulkLoadRate.partition.max",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.DMLLimitEnabled.GetAsBool() {
return max
}
rate := getAsFloat(v)
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
rate = megaBytes2Bytes(rate)
}
// [0, inf)
if rate < 0 {
return p.DMLMaxBulkLoadRate.GetValue()
}
return fmt.Sprintf("%f", rate)
},
Doc: "MB/s, default no limit, not supported yet. TODO: limit partition bulkLoad rate",
Export: true,
}
p.DMLMaxBulkLoadRatePerPartition.Init(base.mgr)
p.DMLMinBulkLoadRatePerPartition = ParamItem{
Key: "quotaAndLimits.dml.bulkLoadRate.partition.min",
Version: "2.4.1",
DefaultValue: min,
Formatter: func(v string) string {
if !p.DMLLimitEnabled.GetAsBool() {
return min
}
rate := megaBytes2Bytes(getAsFloat(v))
// [0, inf)
if rate < 0 {
return min
}
if !p.checkMinMaxLegal(rate, p.DMLMaxBulkLoadRatePerPartition.GetAsFloat()) {
return min
}
return fmt.Sprintf("%f", rate)
},
}
p.DMLMinBulkLoadRatePerPartition.Init(base.mgr)
// dql
p.DQLLimitEnabled = ParamItem{
Key: "quotaAndLimits.dql.enabled",
@ -718,6 +1199,46 @@ The maximum rate will not be greater than ` + "max" + `.`,
}
p.DQLMinSearchRate.Init(base.mgr)
p.DQLMaxSearchRatePerDB = ParamItem{
Key: "quotaAndLimits.dql.searchRate.db.max",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.DQLLimitEnabled.GetAsBool() {
return max
}
// [0, inf)
if getAsFloat(v) < 0 {
return p.DQLMaxSearchRate.GetValue()
}
return v
},
Doc: "vps (vectors per second), default no limit",
Export: true,
}
p.DQLMaxSearchRatePerDB.Init(base.mgr)
p.DQLMinSearchRatePerDB = ParamItem{
Key: "quotaAndLimits.dql.searchRate.db.min",
Version: "2.4.1",
DefaultValue: min,
Formatter: func(v string) string {
if !p.DQLLimitEnabled.GetAsBool() {
return min
}
rate := getAsFloat(v)
// [0, inf)
if rate < 0 {
return min
}
if !p.checkMinMaxLegal(rate, p.DQLMaxSearchRatePerDB.GetAsFloat()) {
return min
}
return v
},
}
p.DQLMinSearchRatePerDB.Init(base.mgr)
p.DQLMaxSearchRatePerCollection = ParamItem{
Key: "quotaAndLimits.dql.searchRate.collection.max",
Version: "2.2.9",
@ -758,6 +1279,46 @@ The maximum rate will not be greater than ` + "max" + `.`,
}
p.DQLMinSearchRatePerCollection.Init(base.mgr)
p.DQLMaxSearchRatePerPartition = ParamItem{
Key: "quotaAndLimits.dql.searchRate.partition.max",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.DQLLimitEnabled.GetAsBool() {
return max
}
// [0, inf)
if getAsFloat(v) < 0 {
return p.DQLMaxSearchRate.GetValue()
}
return v
},
Doc: "vps (vectors per second), default no limit",
Export: true,
}
p.DQLMaxSearchRatePerPartition.Init(base.mgr)
p.DQLMinSearchRatePerPartition = ParamItem{
Key: "quotaAndLimits.dql.searchRate.partition.min",
Version: "2.4.1",
DefaultValue: min,
Formatter: func(v string) string {
if !p.DQLLimitEnabled.GetAsBool() {
return min
}
rate := getAsFloat(v)
// [0, inf)
if rate < 0 {
return min
}
if !p.checkMinMaxLegal(rate, p.DQLMaxSearchRatePerPartition.GetAsFloat()) {
return min
}
return v
},
}
p.DQLMinSearchRatePerPartition.Init(base.mgr)
p.DQLMaxQueryRate = ParamItem{
Key: "quotaAndLimits.dql.queryRate.max",
Version: "2.2.0",
@ -798,6 +1359,46 @@ The maximum rate will not be greater than ` + "max" + `.`,
}
p.DQLMinQueryRate.Init(base.mgr)
p.DQLMaxQueryRatePerDB = ParamItem{
Key: "quotaAndLimits.dql.queryRate.db.max",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.DQLLimitEnabled.GetAsBool() {
return max
}
// [0, inf)
if getAsFloat(v) < 0 {
return p.DQLMaxQueryRate.GetValue()
}
return v
},
Doc: "qps, default no limit",
Export: true,
}
p.DQLMaxQueryRatePerDB.Init(base.mgr)
p.DQLMinQueryRatePerDB = ParamItem{
Key: "quotaAndLimits.dql.queryRate.db.min",
Version: "2.4.1",
DefaultValue: min,
Formatter: func(v string) string {
if !p.DQLLimitEnabled.GetAsBool() {
return min
}
rate := getAsFloat(v)
// [0, inf)
if rate < 0 {
return min
}
if !p.checkMinMaxLegal(rate, p.DQLMaxQueryRatePerDB.GetAsFloat()) {
return min
}
return v
},
}
p.DQLMinQueryRatePerDB.Init(base.mgr)
p.DQLMaxQueryRatePerCollection = ParamItem{
Key: "quotaAndLimits.dql.queryRate.collection.max",
Version: "2.2.9",
@ -838,6 +1439,46 @@ The maximum rate will not be greater than ` + "max" + `.`,
}
p.DQLMinQueryRatePerCollection.Init(base.mgr)
p.DQLMaxQueryRatePerPartition = ParamItem{
Key: "quotaAndLimits.dql.queryRate.partition.max",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.DQLLimitEnabled.GetAsBool() {
return max
}
// [0, inf)
if getAsFloat(v) < 0 {
return p.DQLMaxQueryRate.GetValue()
}
return v
},
Doc: "qps, default no limit",
Export: true,
}
p.DQLMaxQueryRatePerPartition.Init(base.mgr)
p.DQLMinQueryRatePerPartition = ParamItem{
Key: "quotaAndLimits.dql.queryRate.partition.min",
Version: "2.4.1",
DefaultValue: min,
Formatter: func(v string) string {
if !p.DQLLimitEnabled.GetAsBool() {
return min
}
rate := getAsFloat(v)
// [0, inf)
if rate < 0 {
return min
}
if !p.checkMinMaxLegal(rate, p.DQLMaxQueryRatePerPartition.GetAsFloat()) {
return min
}
return v
},
}
p.DQLMinQueryRatePerPartition.Init(base.mgr)
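The DQL variants differ from the DDL-style formatters only in their fallback: with the limiter enabled, a negative db/collection/partition value falls back to the corresponding cluster-level rate rather than straight to the no-limit sentinel. A one-function sketch of that difference (the real formatters work on strings; this uses float64 for brevity):
// resolveDQLSubLimit sketches the db/collection/partition DQL formatters above.
// clusterRate is the already-resolved cluster-level search or query rate.
func resolveDQLSubLimit(limitEnabled bool, v, noLimit, clusterRate float64) float64 {
	if !limitEnabled {
		return noLimit
	}
	if v < 0 {
		return clusterRate
	}
	return v
}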
// limits
p.MaxCollectionNum = ParamItem{
Key: "quotaAndLimits.limits.maxCollectionNum",
@ -1132,6 +1773,27 @@ but the rate will not be lower than minRateRatio * dmlRate.`,
}
p.DiskQuota.Init(base.mgr)
p.DiskQuotaPerDB = ParamItem{
Key: "quotaAndLimits.limitWriting.diskProtection.diskQuotaPerDB",
Version: "2.4.1",
DefaultValue: quota,
Formatter: func(v string) string {
if !p.DiskProtectionEnabled.GetAsBool() {
return max
}
level := getAsFloat(v)
// (0, +inf)
if level <= 0 {
return p.DiskQuota.GetValue()
}
// megabytes to bytes
return fmt.Sprintf("%f", megaBytes2Bytes(level))
},
Doc: "MB, (0, +inf), default no limit",
Export: true,
}
p.DiskQuotaPerDB.Init(base.mgr)
p.DiskQuotaPerCollection = ParamItem{
Key: "quotaAndLimits.limitWriting.diskProtection.diskQuotaPerCollection",
Version: "2.2.8",
@ -1153,6 +1815,27 @@ but the rate will not be lower than minRateRatio * dmlRate.`,
}
p.DiskQuotaPerCollection.Init(base.mgr)
p.DiskQuotaPerPartition = ParamItem{
Key: "quotaAndLimits.limitWriting.diskProtection.diskQuotaPerPartition",
Version: "2.4.1",
DefaultValue: quota,
Formatter: func(v string) string {
if !p.DiskProtectionEnabled.GetAsBool() {
return max
}
level := getAsFloat(v)
// (0, +inf)
if level <= 0 {
return p.DiskQuota.GetValue()
}
// megabytes to bytes
return fmt.Sprintf("%f", megaBytes2Bytes(level))
},
Doc: "MB, (0, +inf), default no limit",
Export: true,
}
p.DiskQuotaPerPartition.Init(base.mgr)
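Like the existing per-collection quota, the new per-DB and per-partition disk quotas return the no-limit sentinel when disk protection is disabled, fall back to the global DiskQuota for non-positive values, and otherwise store the configured megabytes as bytes. A tiny worked example, assuming megaBytes2Bytes multiplies by 1024*1024:
const mib = 1024 * 1024
cfgMB := 8192.0           // diskQuotaPerDB value from milvus.yaml (illustrative)
quotaBytes := cfgMB * mib // 8589934592 bytes is what the formatter ends up storing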
// limit reading
p.ForceDenyReading = ParamItem{
Key: "quotaAndLimits.limitReading.forceDeny",
@ -1253,6 +1936,50 @@ MB/s, default no limit`,
}
p.MaxReadResultRate.Init(base.mgr)
p.MaxReadResultRatePerDB = ParamItem{
Key: "quotaAndLimits.limitReading.resultProtection.maxReadResultRatePerDB",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.ResultProtectionEnabled.GetAsBool() {
return max
}
rate := getAsFloat(v)
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
return fmt.Sprintf("%f", megaBytes2Bytes(rate))
}
// [0, inf)
if rate < 0 {
return max
}
return v
},
Export: true,
}
p.MaxReadResultRatePerDB.Init(base.mgr)
p.MaxReadResultRatePerCollection = ParamItem{
Key: "quotaAndLimits.limitReading.resultProtection.maxReadResultRatePerCollection",
Version: "2.4.1",
DefaultValue: max,
Formatter: func(v string) string {
if !p.ResultProtectionEnabled.GetAsBool() {
return max
}
rate := getAsFloat(v)
if math.Abs(rate-defaultMax) > 0.001 { // maxRate != defaultMax
return fmt.Sprintf("%f", megaBytes2Bytes(rate))
}
// [0, inf)
if rate < 0 {
return max
}
return v
},
Export: true,
}
p.MaxReadResultRatePerCollection.Init(base.mgr)
const defaultSpeed = "0.9"
p.CoolOffSpeed = ParamItem{
Key: "quotaAndLimits.limitReading.coolOffSpeed",
View File
@ -45,12 +45,13 @@ const Inf = Limit(math.MaxFloat64)
// in bucket may be negative, and the latter events would be "punished",
// any event should wait for the tokens to be filled to greater or equal to 0.
type Limiter struct {
mu sync.Mutex
mu sync.RWMutex
limit Limit
burst float64
tokens float64
// last is the last time the limiter's tokens field was updated
last time.Time
last time.Time
hasUpdated bool
}
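The comment above describes a punishment-style token bucket: AllowN never blocks, the token count may go negative, and Cancel refunds tokens for work that did not happen. A minimal usage sketch against the API in this file (the numbers are illustrative):
lim := NewLimiter(Limit(100), 100) // ~100 events/s; burst mirrors the rate, see SetLimit below
if lim.AllowN(time.Now(), 150) {
	// Allowed, but the bucket can now be negative ("punished"):
	// subsequent AllowN calls fail until elapsed time refills it back to >= 0.
}
lim.Cancel(50) // refund 50 tokens if that work was ultimately not performed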
// NewLimiter returns a new Limiter that allows events up to rate r.
@ -63,13 +64,20 @@ func NewLimiter(r Limit, b float64) *Limiter {
// Limit returns the maximum overall event rate.
func (lim *Limiter) Limit() Limit {
lim.mu.Lock()
defer lim.mu.Unlock()
lim.mu.RLock()
defer lim.mu.RUnlock()
return lim.limit
}
// AllowN reports whether n events may happen at time now.
func (lim *Limiter) AllowN(now time.Time, n int) bool {
lim.mu.RLock()
if lim.limit == Inf {
lim.mu.RUnlock()
return true
}
lim.mu.RUnlock()
lim.mu.Lock()
defer lim.mu.Unlock()
@ -119,6 +127,7 @@ func (lim *Limiter) SetLimit(newLimit Limit) {
// use rate as burst, because Limiter is with punishment mechanism, burst is insignificant.
lim.burst = float64(newLimit)
}
lim.hasUpdated = true
}
// Cancel the AllowN operation and refund the tokens that have already been deducted by the limiter.
@ -128,6 +137,12 @@ func (lim *Limiter) Cancel(n int) {
lim.tokens += float64(n)
}
func (lim *Limiter) HasUpdated() bool {
lim.mu.RLock()
defer lim.mu.RUnlock()
return lim.hasUpdated
}
// advance calculates and returns an updated state for lim resulting from the passage of time.
// lim is not changed. advance requires that lim.mu is held.
func (lim *Limiter) advance(now time.Time) (newNow time.Time, newLast time.Time, newTokens float64) {
View File
@ -19,8 +19,11 @@ package ratelimitutil
import (
"fmt"
"math"
"strings"
"sync"
"time"
"github.com/samber/lo"
)
const (
@ -34,21 +37,22 @@ const (
type RateCollector struct {
sync.Mutex
window time.Duration
granularity time.Duration
position int
values map[string][]float64
window time.Duration
granularity time.Duration
position int
values map[string][]float64
deprecatedSubLabels []lo.Tuple2[string, string]
last time.Time
}
// NewRateCollector is shorthand for newRateCollector(window, granularity, time.Now()).
func NewRateCollector(window time.Duration, granularity time.Duration) (*RateCollector, error) {
return newRateCollector(window, granularity, time.Now())
func NewRateCollector(window time.Duration, granularity time.Duration, enableSubLabel bool) (*RateCollector, error) {
return newRateCollector(window, granularity, time.Now(), enableSubLabel)
}
// newRateCollector returns a new RateCollector with given window and granularity.
func newRateCollector(window time.Duration, granularity time.Duration, now time.Time) (*RateCollector, error) {
func newRateCollector(window time.Duration, granularity time.Duration, now time.Time, enableSubLabel bool) (*RateCollector, error) {
if window == 0 || granularity == 0 {
return nil, fmt.Errorf("create RateCollector failed, window or granularity cannot be 0, window = %d, granularity = %d", window, granularity)
}
@ -62,9 +66,52 @@ func newRateCollector(window time.Duration, granularity time.Duration, now time.
values: make(map[string][]float64),
last: now,
}
if enableSubLabel {
go rc.cleanDeprecateSubLabels()
}
return rc, nil
}
func (r *RateCollector) cleanDeprecateSubLabels() {
tick := time.NewTicker(r.window * 2)
defer tick.Stop()
for range tick.C {
r.Lock()
for _, labelInfo := range r.deprecatedSubLabels {
r.removeSubLabel(labelInfo)
}
r.Unlock()
}
}
func (r *RateCollector) removeSubLabel(labelInfo lo.Tuple2[string, string]) {
label := labelInfo.A
subLabel := labelInfo.B
if subLabel == "" {
return
}
removeKeys := make([]string, 1)
removeKeys[0] = FormatSubLabel(label, subLabel)
deleteCollectionSubLabelWithPrefix := func(dbName string) {
for key := range r.values {
if strings.HasPrefix(key, FormatSubLabel(label, GetCollectionSubLabel(dbName, ""))) {
removeKeys = append(removeKeys, key)
}
}
}
parts := strings.Split(subLabel, ".")
if strings.HasPrefix(subLabel, GetDBSubLabel("")) {
dbName := parts[1]
deleteCollectionSubLabelWithPrefix(dbName)
}
for _, key := range removeKeys {
delete(r.values, key)
}
}
// Register init values of RateCollector for specified label.
func (r *RateCollector) Register(label string) {
r.Lock()
@ -81,21 +128,77 @@ func (r *RateCollector) Deregister(label string) {
delete(r.values, label)
}
func GetDBSubLabel(dbName string) string {
return fmt.Sprintf("db.%s", dbName)
}
func GetCollectionSubLabel(dbName, collectionName string) string {
return fmt.Sprintf("collection.%s.%s", dbName, collectionName)
}
func FormatSubLabel(label, subLabel string) string {
return fmt.Sprintf("%s-%s", label, subLabel)
}
func GetDBFromSubLabel(label, fullLabel string) (string, bool) {
if !strings.HasPrefix(fullLabel, FormatSubLabel(label, GetDBSubLabel(""))) {
return "", false
}
return fullLabel[len(FormatSubLabel(label, GetDBSubLabel(""))):], true
}
func GetCollectionFromSubLabel(label, fullLabel string) (string, string, bool) {
if !strings.HasPrefix(fullLabel, FormatSubLabel(label, "")) {
return "", "", false
}
subLabels := strings.Split(fullLabel[len(FormatSubLabel(label, "")):], ".")
if len(subLabels) != 3 || subLabels[0] != "collection" {
return "", "", false
}
return subLabels[1], subLabels[2], true
}
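Taken together, these helpers pin down the sub-label string scheme used by the collector; a quick illustration of the strings they produce and parse back (db and collection names are illustrative):
fmt.Println(GetDBSubLabel("db1"))                                                // db.db1
fmt.Println(GetCollectionSubLabel("db1", "coll1"))                               // collection.db1.coll1
fmt.Println(FormatSubLabel("insert", GetDBSubLabel("db1")))                      // insert-db.db1
fmt.Println(GetDBFromSubLabel("insert", "insert-db.db1"))                        // db1 true
fmt.Println(GetCollectionFromSubLabel("insert", "insert-collection.db1.coll1"))  // db1 coll1 true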
func (r *RateCollector) DeregisterSubLabel(label, subLabel string) {
r.Lock()
defer r.Unlock()
r.deprecatedSubLabels = append(r.deprecatedSubLabels, lo.Tuple2[string, string]{
A: label,
B: subLabel,
})
}
// Add is shorthand for add(label, value, time.Now()).
func (r *RateCollector) Add(label string, value float64) {
r.add(label, value, time.Now())
func (r *RateCollector) Add(label string, value float64, subLabels ...string) {
r.add(label, value, time.Now(), subLabels...)
}
// add increases the current value of specified label.
func (r *RateCollector) add(label string, value float64, now time.Time) {
func (r *RateCollector) add(label string, value float64, now time.Time, subLabels ...string) {
r.Lock()
defer r.Unlock()
r.update(now)
if _, ok := r.values[label]; ok {
r.values[label][r.position] += value
for _, subLabel := range subLabels {
r.unsafeAddForSubLabels(label, subLabel, value)
}
}
}
func (r *RateCollector) unsafeAddForSubLabels(label, subLabel string, value float64) {
if subLabel == "" {
return
}
sub := FormatSubLabel(label, subLabel)
if _, ok := r.values[sub]; ok {
r.values[sub][r.position] += value
return
}
r.values[sub] = make([]float64, int(r.window/r.granularity))
r.values[sub][r.position] = value
}
// Max is shorthand for max(label, time.Now()).
func (r *RateCollector) Max(label string, now time.Time) (float64, error) {
return r.max(label, time.Now())
@ -145,6 +248,26 @@ func (r *RateCollector) Rate(label string, duration time.Duration) (float64, err
return r.rate(label, duration, time.Now())
}
func (r *RateCollector) RateSubLabel(label string, duration time.Duration) (map[string]float64, error) {
subLabelPrefix := FormatSubLabel(label, "")
subLabels := make(map[string]float64)
r.Lock()
for s := range r.values {
if strings.HasPrefix(s, subLabelPrefix) {
subLabels[s] = 0
}
}
r.Unlock()
for s := range subLabels {
v, err := r.rate(s, duration, time.Now())
if err != nil {
return nil, err
}
subLabels[s] = v
}
return subLabels, nil
}
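A minimal usage sketch of the sub-label plumbing added above, using only the functions in this file (label and names are illustrative):
rc, _ := NewRateCollector(5*time.Second, time.Second, true)
rc.Register("insert")
// attribute one write of 10 units to db "db1" and to collection "coll1" in "db1"
rc.Add("insert", 10, GetDBSubLabel("db1"), GetCollectionSubLabel("db1", "coll1"))
perScope, _ := rc.RateSubLabel("insert", time.Second)
// perScope now has entries for "insert-db.db1" and "insert-collection.db1.coll1"
_ = perScope
rc.DeregisterSubLabel("insert", GetCollectionSubLabel("db1", "coll1")) // removed lazily by the background ticker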
// rate returns the latest mean value of the specified duration.
func (r *RateCollector) rate(label string, duration time.Duration, now time.Time) (float64, error) {
if duration > r.window {
View File
@ -22,6 +22,7 @@ import (
"testing"
"time"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
)
@ -36,7 +37,7 @@ func TestRateCollector(t *testing.T) {
ts100 = ts0.Add(time.Duration(100.0 * float64(time.Second)))
)
rc, err := newRateCollector(DefaultWindow, DefaultGranularity, ts0)
rc, err := newRateCollector(DefaultWindow, DefaultGranularity, ts0, false)
assert.NoError(t, err)
label := "mock_label"
rc.Register(label)
@ -78,7 +79,7 @@ func TestRateCollector(t *testing.T) {
ts31 = ts0.Add(time.Duration(3.1 * float64(time.Second)))
)
rc, err := newRateCollector(DefaultWindow, DefaultGranularity, ts0)
rc, err := newRateCollector(DefaultWindow, DefaultGranularity, ts0, false)
assert.NoError(t, err)
label := "mock_label"
rc.Register(label)
@ -105,7 +106,7 @@ func TestRateCollector(t *testing.T) {
start := tt.now()
end := start.Add(testPeriod * time.Second)
rc, err := newRateCollector(DefaultWindow, DefaultGranularity, start)
rc, err := newRateCollector(DefaultWindow, DefaultGranularity, start, false)
assert.NoError(t, err)
label := "mock_label"
rc.Register(label)
@ -138,3 +139,111 @@ func TestRateCollector(t *testing.T) {
}
})
}
func TestRateSubLabel(t *testing.T) {
rateCollector, err := NewRateCollector(5*time.Second, time.Second, true)
assert.NoError(t, err)
var (
label = "search"
db = "hoo"
collection = "foo"
dbSubLabel = GetDBSubLabel(db)
collectionSubLabel = GetCollectionSubLabel(db, collection)
ts0 = time.Now()
ts10 = ts0.Add(time.Duration(1.0 * float64(time.Second)))
ts19 = ts0.Add(time.Duration(1.9 * float64(time.Second)))
ts20 = ts0.Add(time.Duration(2.0 * float64(time.Second)))
ts30 = ts0.Add(time.Duration(3.0 * float64(time.Second)))
ts40 = ts0.Add(time.Duration(4.0 * float64(time.Second)))
)
rateCollector.Register(label)
defer rateCollector.Deregister(label)
rateCollector.add(label, 10, ts0, dbSubLabel, collectionSubLabel)
rateCollector.add(label, 20, ts10, dbSubLabel, collectionSubLabel)
rateCollector.add(label, 30, ts19, dbSubLabel, collectionSubLabel)
rateCollector.add(label, 40, ts20, dbSubLabel, collectionSubLabel)
rateCollector.add(label, 50, ts30, dbSubLabel, collectionSubLabel)
rateCollector.add(label, 60, ts40, dbSubLabel, collectionSubLabel)
time.Sleep(4 * time.Second)
// 10 20+30 40 50 60
{
avg, err := rateCollector.Rate(label, 3*time.Second)
assert.NoError(t, err)
assert.Equal(t, float64(50), avg)
}
{
avg, err := rateCollector.Rate(label, 5*time.Second)
assert.NoError(t, err)
assert.Equal(t, float64(42), avg)
}
{
avgs, err := rateCollector.RateSubLabel(label, 3*time.Second)
assert.NoError(t, err)
assert.Equal(t, 2, len(avgs))
assert.Equal(t, float64(50), avgs[FormatSubLabel(label, dbSubLabel)])
assert.Equal(t, float64(50), avgs[FormatSubLabel(label, collectionSubLabel)])
}
rateCollector.Add(label, 10, GetCollectionSubLabel(db, collection))
rateCollector.Add(label, 10, GetCollectionSubLabel(db, "col2"))
rateCollector.DeregisterSubLabel(label, GetCollectionSubLabel(db, "col2"))
rateCollector.DeregisterSubLabel(label, dbSubLabel)
rateCollector.removeSubLabel(lo.Tuple2[string, string]{
A: "aaa",
})
rateCollector.Lock()
for _, labelInfo := range rateCollector.deprecatedSubLabels {
rateCollector.removeSubLabel(labelInfo)
}
rateCollector.Unlock()
{
_, ok := rateCollector.values[FormatSubLabel(label, dbSubLabel)]
assert.False(t, ok)
}
{
_, ok := rateCollector.values[FormatSubLabel(label, collectionSubLabel)]
assert.False(t, ok)
}
{
assert.Len(t, rateCollector.values, 1)
_, ok := rateCollector.values[label]
assert.True(t, ok)
}
}
func TestLabelUtil(t *testing.T) {
assert.Equal(t, GetDBSubLabel("db"), "db.db")
assert.Equal(t, GetCollectionSubLabel("db", "collection"), "collection.db.collection")
{
db, ok := GetDBFromSubLabel("foo", FormatSubLabel("foo", GetDBSubLabel("db1")))
assert.True(t, ok)
assert.Equal(t, "db1", db)
}
{
_, ok := GetDBFromSubLabel("foo", "aaa")
assert.False(t, ok)
}
{
db, col, ok := GetCollectionFromSubLabel("foo", FormatSubLabel("foo", GetCollectionSubLabel("db1", "col1")))
assert.True(t, ok)
assert.Equal(t, "col1", col)
assert.Equal(t, "db1", db)
}
{
_, _, ok := GetCollectionFromSubLabel("foo", "aaa")
assert.False(t, ok)
}
}
View File
@ -0,0 +1,30 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ratelimitutil
import "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
var QuotaErrorString = map[commonpb.ErrorCode]string{
commonpb.ErrorCode_ForceDeny: "the writing has been deactivated by the administrator",
commonpb.ErrorCode_MemoryQuotaExhausted: "memory quota exceeded, please allocate more resources",
commonpb.ErrorCode_DiskQuotaExhausted: "disk quota exceeded, please allocate more resources",
commonpb.ErrorCode_TimeTickLongDelay: "time tick long delay",
}
func GetQuotaErrorString(errCode commonpb.ErrorCode) string {
return QuotaErrorString[errCode]
}
View File
@ -0,0 +1,43 @@
package ratelimitutil
import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)
func TestGetQuotaErrorString(t *testing.T) {
tests := []struct {
name string
args commonpb.ErrorCode
want string
}{
{
name: "Test ErrorCode_ForceDeny",
args: commonpb.ErrorCode_ForceDeny,
want: "the writing has been deactivated by the administrator",
},
{
name: "Test ErrorCode_MemoryQuotaExhausted",
args: commonpb.ErrorCode_MemoryQuotaExhausted,
want: "memory quota exceeded, please allocate more resources",
},
{
name: "Test ErrorCode_DiskQuotaExhausted",
args: commonpb.ErrorCode_DiskQuotaExhausted,
want: "disk quota exceeded, please allocate more resources",
},
{
name: "Test ErrorCode_TimeTickLongDelay",
args: commonpb.ErrorCode_TimeTickLongDelay,
want: "time tick long delay",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := GetQuotaErrorString(tt.args); got != tt.want {
t.Errorf("GetQuotaErrorString() = %v, want %v", got, tt.want)
}
})
}
}
View File
@ -30,6 +30,16 @@ func TestUniqueSet(t *testing.T) {
assert.True(t, set.Contain(9))
assert.True(t, set.Contain(5, 7, 9))
containFive := false
set.Range(func(i UniqueID) bool {
if i == 5 {
containFive = true
return false
}
return true
})
assert.True(t, containFive)
set.Remove(7)
assert.True(t, set.Contain(5))
assert.False(t, set.Contain(7))