From 9a719ec89e88c31c22b3dba8e3793198365befd7 Mon Sep 17 00:00:00 2001 From: SimFG Date: Sun, 28 Apr 2024 11:01:28 +0800 Subject: [PATCH] fix: skip the empty partition name in the rate limit interceptor (#32647) issue: https://github.com/milvus-io/milvus/issues/30577 Signed-off-by: SimFG --- internal/proxy/rate_limit_interceptor.go | 3 + internal/proxy/rate_limit_interceptor_test.go | 86 +++++++++++++++---- 2 files changed, 74 insertions(+), 15 deletions(-) diff --git a/internal/proxy/rate_limit_interceptor.go b/internal/proxy/rate_limit_interceptor.go index 541fcebdc7..14ac320334 100644 --- a/internal/proxy/rate_limit_interceptor.go +++ b/internal/proxy/rate_limit_interceptor.go @@ -86,6 +86,9 @@ func getCollectionAndPartitionID(ctx context.Context, r reqPartName) (int64, map if err != nil { return 0, nil, err } + if r.GetPartitionName() == "" { + return db.dbID, map[int64][]int64{collectionID: {}}, nil + } part, err := globalMetaCache.GetPartitionInfo(ctx, r.GetDbName(), r.GetCollectionName(), r.GetPartitionName()) if err != nil { return 0, nil, err diff --git a/internal/proxy/rate_limit_interceptor_test.go b/internal/proxy/rate_limit_interceptor_test.go index a018d5f307..cfea05d30b 100644 --- a/internal/proxy/rate_limit_interceptor_test.go +++ b/internal/proxy/rate_limit_interceptor_test.go @@ -64,33 +64,65 @@ func TestRateLimitInterceptor(t *testing.T) { createdTimestamp: 1, }, nil) globalMetaCache = mockCache - database, col2part, rt, size, err := getRequestInfo(context.Background(), &milvuspb.InsertRequest{}) + database, col2part, rt, size, err := getRequestInfo(context.Background(), &milvuspb.InsertRequest{ + CollectionName: "foo", + PartitionName: "p1", + DbName: "db1", + }) assert.NoError(t, err) - assert.Equal(t, proto.Size(&milvuspb.InsertRequest{}), size) + assert.Equal(t, proto.Size(&milvuspb.InsertRequest{ + CollectionName: "foo", + PartitionName: "p1", + DbName: "db1", + }), size) assert.Equal(t, internalpb.RateType_DMLInsert, rt) assert.Equal(t, database, int64(100)) assert.True(t, len(col2part) == 1) assert.Equal(t, int64(10), col2part[1][0]) - database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.UpsertRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.UpsertRequest{ + CollectionName: "foo", + PartitionName: "p1", + DbName: "db1", + }) assert.NoError(t, err) - assert.Equal(t, proto.Size(&milvuspb.InsertRequest{}), size) + assert.Equal(t, proto.Size(&milvuspb.InsertRequest{ + CollectionName: "foo", + PartitionName: "p1", + DbName: "db1", + }), size) assert.Equal(t, internalpb.RateType_DMLUpsert, rt) assert.Equal(t, database, int64(100)) assert.True(t, len(col2part) == 1) assert.Equal(t, int64(10), col2part[1][0]) - database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DeleteRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.DeleteRequest{ + CollectionName: "foo", + PartitionName: "p1", + DbName: "db1", + }) assert.NoError(t, err) - assert.Equal(t, proto.Size(&milvuspb.DeleteRequest{}), size) + assert.Equal(t, proto.Size(&milvuspb.DeleteRequest{ + CollectionName: "foo", + PartitionName: "p1", + DbName: "db1", + }), size) assert.Equal(t, internalpb.RateType_DMLDelete, rt) assert.Equal(t, database, int64(100)) assert.True(t, len(col2part) == 1) assert.Equal(t, int64(10), col2part[1][0]) - database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ImportRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.ImportRequest{ + CollectionName: "foo", + PartitionName: "p1", + DbName: "db1", + }) assert.NoError(t, err) - assert.Equal(t, proto.Size(&milvuspb.ImportRequest{}), size) + assert.Equal(t, proto.Size(&milvuspb.ImportRequest{ + CollectionName: "foo", + PartitionName: "p1", + DbName: "db1", + }), size) assert.Equal(t, internalpb.RateType_DMLBulkLoad, rt) assert.Equal(t, database, int64(100)) assert.True(t, len(col2part) == 1) @@ -109,13 +141,19 @@ func TestRateLimitInterceptor(t *testing.T) { assert.Equal(t, 1, len(col2part)) assert.Equal(t, 1, len(col2part[1])) - database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.QueryRequest{}) + database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.QueryRequest{ + CollectionName: "foo", + PartitionNames: []string{ + "p1", + }, + DbName: "db1", + }) assert.NoError(t, err) assert.Equal(t, 1, size) assert.Equal(t, internalpb.RateType_DQLQuery, rt) assert.Equal(t, database, int64(100)) assert.Equal(t, 1, len(col2part)) - assert.Equal(t, 0, len(col2part[1])) + assert.Equal(t, 1, len(col2part[1])) database, col2part, rt, size, err = getRequestInfo(context.Background(), &milvuspb.CreateCollectionRequest{}) assert.NoError(t, err) @@ -268,13 +306,21 @@ func TestRateLimitInterceptor(t *testing.T) { limiter.limit = true interceptorFun := RateLimitInterceptor(&limiter) - rsp, err := interceptorFun(context.Background(), &milvuspb.InsertRequest{}, serverInfo, handler) + rsp, err := interceptorFun(context.Background(), &milvuspb.InsertRequest{ + CollectionName: "foo", + PartitionName: "p1", + DbName: "db1", + }, serverInfo, handler) assert.Equal(t, commonpb.ErrorCode_RateLimit, rsp.(*milvuspb.MutationResult).GetStatus().GetErrorCode()) assert.NoError(t, err) limiter.limit = false interceptorFun = RateLimitInterceptor(&limiter) - rsp, err = interceptorFun(context.Background(), &milvuspb.InsertRequest{}, serverInfo, handler) + rsp, err = interceptorFun(context.Background(), &milvuspb.InsertRequest{ + CollectionName: "foo", + PartitionName: "p1", + DbName: "db1", + }, serverInfo, handler) assert.Equal(t, commonpb.ErrorCode_Success, rsp.(*milvuspb.MutationResult).GetStatus().GetErrorCode()) assert.NoError(t, err) @@ -410,12 +456,12 @@ func TestGetInfo(t *testing.T) { mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{ dbID: 100, createdTimestamp: 1, - }, nil).Twice() - mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(10), nil).Twice() + }, nil).Times(3) + mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(10), nil).Times(3) mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{ name: "p1", partitionID: 100, - }, nil) + }, nil).Twice() { db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{ DbName: "foo", @@ -427,6 +473,16 @@ func TestGetInfo(t *testing.T) { assert.NotNil(t, col2par[10]) assert.Equal(t, int64(100), col2par[10][0]) } + { + db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{ + DbName: "foo", + CollectionName: "coo", + }) + assert.NoError(t, err) + assert.Equal(t, int64(100), db) + assert.NotNil(t, col2par[10]) + assert.Equal(t, 0, len(col2par[10])) + } { db, col2par, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{ DbName: "foo",