fix: skip the empty partition name in the rate limit interceptor (#32647)

issue: https://github.com/milvus-io/milvus/issues/30577

Signed-off-by: SimFG <bang.fu@zilliz.com>
pull/32659/head
SimFG 2024-04-28 11:01:28 +08:00 committed by GitHub
parent c080dc1675
commit 9a719ec89e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 74 additions and 15 deletions

View File

@ -86,6 +86,9 @@ func getCollectionAndPartitionID(ctx context.Context, r reqPartName) (int64, map
if err != nil {
return 0, nil, err
}
if r.GetPartitionName() == "" {
return db.dbID, map[int64][]int64{collectionID: {}}, nil
}
part, err := globalMetaCache.GetPartitionInfo(ctx, r.GetDbName(), r.GetCollectionName(), r.GetPartitionName())
if err != nil {
return 0, nil, err

View File

@ -64,33 +64,65 @@ func TestRateLimitInterceptor(t *testing.T) {
createdTimestamp: 1,
}, nil)
globalMetaCache = mockCache
database, col2part, rt, size, err := getRequestInfo(context.Background(), &milvuspb.InsertRequest{})
database, col2part, rt, size, err := getRequestInfo(context.Background(), &milvuspb.InsertRequest{
CollectionName: "foo",
PartitionName: "p1",
DbName: "db1",
})
assert.NoError(t, err)
assert.Equal(t, proto.Size(&milvuspb.InsertRequest{}), size)
assert.Equal(t, proto.Size(&milvuspb.InsertRequest{
CollectionName: "foo",
PartitionName: "p1",
DbName: "db1",
}), size)
assert.Equal(t, internalpb.RateType_DMLInsert, rt)
assert.Equal(t, database, int64(100))
assert.True(t, len(col2part) == 1)
assert.Equal(t, int64(10), col2part[1][0])
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.UpsertRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.UpsertRequest{
CollectionName: "foo",
PartitionName: "p1",
DbName: "db1",
})
assert.NoError(t, err)
assert.Equal(t, proto.Size(&milvuspb.InsertRequest{}), size)
assert.Equal(t, proto.Size(&milvuspb.InsertRequest{
CollectionName: "foo",
PartitionName: "p1",
DbName: "db1",
}), size)
assert.Equal(t, internalpb.RateType_DMLUpsert, rt)
assert.Equal(t, database, int64(100))
assert.True(t, len(col2part) == 1)
assert.Equal(t, int64(10), col2part[1][0])
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DeleteRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DeleteRequest{
CollectionName: "foo",
PartitionName: "p1",
DbName: "db1",
})
assert.NoError(t, err)
assert.Equal(t, proto.Size(&milvuspb.DeleteRequest{}), size)
assert.Equal(t, proto.Size(&milvuspb.DeleteRequest{
CollectionName: "foo",
PartitionName: "p1",
DbName: "db1",
}), size)
assert.Equal(t, internalpb.RateType_DMLDelete, rt)
assert.Equal(t, database, int64(100))
assert.True(t, len(col2part) == 1)
assert.Equal(t, int64(10), col2part[1][0])
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ImportRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ImportRequest{
CollectionName: "foo",
PartitionName: "p1",
DbName: "db1",
})
assert.NoError(t, err)
assert.Equal(t, proto.Size(&milvuspb.ImportRequest{}), size)
assert.Equal(t, proto.Size(&milvuspb.ImportRequest{
CollectionName: "foo",
PartitionName: "p1",
DbName: "db1",
}), size)
assert.Equal(t, internalpb.RateType_DMLBulkLoad, rt)
assert.Equal(t, database, int64(100))
assert.True(t, len(col2part) == 1)
@ -109,13 +141,19 @@ func TestRateLimitInterceptor(t *testing.T) {
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 1, len(col2part[1]))
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.QueryRequest{})
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.QueryRequest{
CollectionName: "foo",
PartitionNames: []string{
"p1",
},
DbName: "db1",
})
assert.NoError(t, err)
assert.Equal(t, 1, size)
assert.Equal(t, internalpb.RateType_DQLQuery, rt)
assert.Equal(t, database, int64(100))
assert.Equal(t, 1, len(col2part))
assert.Equal(t, 0, len(col2part[1]))
assert.Equal(t, 1, len(col2part[1]))
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreateCollectionRequest{})
assert.NoError(t, err)
@ -268,13 +306,21 @@ func TestRateLimitInterceptor(t *testing.T) {
limiter.limit = true
interceptorFun := RateLimitInterceptor(&limiter)
rsp, err := interceptorFun(context.Background(), &milvuspb.InsertRequest{}, serverInfo, handler)
rsp, err := interceptorFun(context.Background(), &milvuspb.InsertRequest{
CollectionName: "foo",
PartitionName: "p1",
DbName: "db1",
}, serverInfo, handler)
assert.Equal(t, commonpb.ErrorCode_RateLimit, rsp.(*milvuspb.MutationResult).GetStatus().GetErrorCode())
assert.NoError(t, err)
limiter.limit = false
interceptorFun = RateLimitInterceptor(&limiter)
rsp, err = interceptorFun(context.Background(), &milvuspb.InsertRequest{}, serverInfo, handler)
rsp, err = interceptorFun(context.Background(), &milvuspb.InsertRequest{
CollectionName: "foo",
PartitionName: "p1",
DbName: "db1",
}, serverInfo, handler)
assert.Equal(t, commonpb.ErrorCode_Success, rsp.(*milvuspb.MutationResult).GetStatus().GetErrorCode())
assert.NoError(t, err)
@ -410,12 +456,12 @@ func TestGetInfo(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
dbID: 100,
createdTimestamp: 1,
}, nil).Twice()
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(10), nil).Twice()
}, nil).Times(3)
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(10), nil).Times(3)
mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{
name: "p1",
partitionID: 100,
}, nil)
}, nil).Twice()
{
db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
DbName: "foo",
@ -427,6 +473,16 @@ func TestGetInfo(t *testing.T) {
assert.NotNil(t, col2par[10])
assert.Equal(t, int64(100), col2par[10][0])
}
{
db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
DbName: "foo",
CollectionName: "coo",
})
assert.NoError(t, err)
assert.Equal(t, int64(100), db)
assert.NotNil(t, col2par[10])
assert.Equal(t, 0, len(col2par[10]))
}
{
db, col2par, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{
DbName: "foo",