mirror of https://github.com/milvus-io/milvus.git
fix: skip the empty partition name in the rate limit interceptor (#32647)
issue: https://github.com/milvus-io/milvus/issues/30577 Signed-off-by: SimFG <bang.fu@zilliz.com>pull/32659/head
parent
c080dc1675
commit
9a719ec89e
|
@ -86,6 +86,9 @@ func getCollectionAndPartitionID(ctx context.Context, r reqPartName) (int64, map
|
|||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
if r.GetPartitionName() == "" {
|
||||
return db.dbID, map[int64][]int64{collectionID: {}}, nil
|
||||
}
|
||||
part, err := globalMetaCache.GetPartitionInfo(ctx, r.GetDbName(), r.GetCollectionName(), r.GetPartitionName())
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
|
|
|
@ -64,33 +64,65 @@ func TestRateLimitInterceptor(t *testing.T) {
|
|||
createdTimestamp: 1,
|
||||
}, nil)
|
||||
globalMetaCache = mockCache
|
||||
database, col2part, rt, size, err := getRequestInfo(context.Background(), &milvuspb.InsertRequest{})
|
||||
database, col2part, rt, size, err := getRequestInfo(context.Background(), &milvuspb.InsertRequest{
|
||||
CollectionName: "foo",
|
||||
PartitionName: "p1",
|
||||
DbName: "db1",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, proto.Size(&milvuspb.InsertRequest{}), size)
|
||||
assert.Equal(t, proto.Size(&milvuspb.InsertRequest{
|
||||
CollectionName: "foo",
|
||||
PartitionName: "p1",
|
||||
DbName: "db1",
|
||||
}), size)
|
||||
assert.Equal(t, internalpb.RateType_DMLInsert, rt)
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.True(t, len(col2part) == 1)
|
||||
assert.Equal(t, int64(10), col2part[1][0])
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.UpsertRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.UpsertRequest{
|
||||
CollectionName: "foo",
|
||||
PartitionName: "p1",
|
||||
DbName: "db1",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, proto.Size(&milvuspb.InsertRequest{}), size)
|
||||
assert.Equal(t, proto.Size(&milvuspb.InsertRequest{
|
||||
CollectionName: "foo",
|
||||
PartitionName: "p1",
|
||||
DbName: "db1",
|
||||
}), size)
|
||||
assert.Equal(t, internalpb.RateType_DMLUpsert, rt)
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.True(t, len(col2part) == 1)
|
||||
assert.Equal(t, int64(10), col2part[1][0])
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DeleteRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DeleteRequest{
|
||||
CollectionName: "foo",
|
||||
PartitionName: "p1",
|
||||
DbName: "db1",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, proto.Size(&milvuspb.DeleteRequest{}), size)
|
||||
assert.Equal(t, proto.Size(&milvuspb.DeleteRequest{
|
||||
CollectionName: "foo",
|
||||
PartitionName: "p1",
|
||||
DbName: "db1",
|
||||
}), size)
|
||||
assert.Equal(t, internalpb.RateType_DMLDelete, rt)
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.True(t, len(col2part) == 1)
|
||||
assert.Equal(t, int64(10), col2part[1][0])
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ImportRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ImportRequest{
|
||||
CollectionName: "foo",
|
||||
PartitionName: "p1",
|
||||
DbName: "db1",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, proto.Size(&milvuspb.ImportRequest{}), size)
|
||||
assert.Equal(t, proto.Size(&milvuspb.ImportRequest{
|
||||
CollectionName: "foo",
|
||||
PartitionName: "p1",
|
||||
DbName: "db1",
|
||||
}), size)
|
||||
assert.Equal(t, internalpb.RateType_DMLBulkLoad, rt)
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.True(t, len(col2part) == 1)
|
||||
|
@ -109,13 +141,19 @@ func TestRateLimitInterceptor(t *testing.T) {
|
|||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 1, len(col2part[1]))
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.QueryRequest{})
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.QueryRequest{
|
||||
CollectionName: "foo",
|
||||
PartitionNames: []string{
|
||||
"p1",
|
||||
},
|
||||
DbName: "db1",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, size)
|
||||
assert.Equal(t, internalpb.RateType_DQLQuery, rt)
|
||||
assert.Equal(t, database, int64(100))
|
||||
assert.Equal(t, 1, len(col2part))
|
||||
assert.Equal(t, 0, len(col2part[1]))
|
||||
assert.Equal(t, 1, len(col2part[1]))
|
||||
|
||||
database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreateCollectionRequest{})
|
||||
assert.NoError(t, err)
|
||||
|
@ -268,13 +306,21 @@ func TestRateLimitInterceptor(t *testing.T) {
|
|||
|
||||
limiter.limit = true
|
||||
interceptorFun := RateLimitInterceptor(&limiter)
|
||||
rsp, err := interceptorFun(context.Background(), &milvuspb.InsertRequest{}, serverInfo, handler)
|
||||
rsp, err := interceptorFun(context.Background(), &milvuspb.InsertRequest{
|
||||
CollectionName: "foo",
|
||||
PartitionName: "p1",
|
||||
DbName: "db1",
|
||||
}, serverInfo, handler)
|
||||
assert.Equal(t, commonpb.ErrorCode_RateLimit, rsp.(*milvuspb.MutationResult).GetStatus().GetErrorCode())
|
||||
assert.NoError(t, err)
|
||||
|
||||
limiter.limit = false
|
||||
interceptorFun = RateLimitInterceptor(&limiter)
|
||||
rsp, err = interceptorFun(context.Background(), &milvuspb.InsertRequest{}, serverInfo, handler)
|
||||
rsp, err = interceptorFun(context.Background(), &milvuspb.InsertRequest{
|
||||
CollectionName: "foo",
|
||||
PartitionName: "p1",
|
||||
DbName: "db1",
|
||||
}, serverInfo, handler)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, rsp.(*milvuspb.MutationResult).GetStatus().GetErrorCode())
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
@ -410,12 +456,12 @@ func TestGetInfo(t *testing.T) {
|
|||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
||||
dbID: 100,
|
||||
createdTimestamp: 1,
|
||||
}, nil).Twice()
|
||||
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(10), nil).Twice()
|
||||
}, nil).Times(3)
|
||||
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(10), nil).Times(3)
|
||||
mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{
|
||||
name: "p1",
|
||||
partitionID: 100,
|
||||
}, nil)
|
||||
}, nil).Twice()
|
||||
{
|
||||
db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
|
||||
DbName: "foo",
|
||||
|
@ -427,6 +473,16 @@ func TestGetInfo(t *testing.T) {
|
|||
assert.NotNil(t, col2par[10])
|
||||
assert.Equal(t, int64(100), col2par[10][0])
|
||||
}
|
||||
{
|
||||
db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
|
||||
DbName: "foo",
|
||||
CollectionName: "coo",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(100), db)
|
||||
assert.NotNil(t, col2par[10])
|
||||
assert.Equal(t, 0, len(col2par[10]))
|
||||
}
|
||||
{
|
||||
db, col2par, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{
|
||||
DbName: "foo",
|
||||
|
|
Loading…
Reference in New Issue