enhance: check the partition num when creating collection with partition key (#32670)

issue: #30577

Signed-off-by: SimFG <bang.fu@zilliz.com>
pull/32678/head
SimFG 2024-05-07 10:43:29 +08:00 committed by GitHub
parent cf4db3ff4e
commit 7da1ca9efb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 30 additions and 2 deletions

View File

@ -206,9 +206,19 @@ func (t *createCollectionTask) validatePartitionKey() error {
return errors.New("the specified partitions should be greater than 0 if partition key is used") return errors.New("the specified partitions should be greater than 0 if partition key is used")
} }
maxPartitionNum := Params.RootCoordCfg.MaxPartitionNum.GetAsInt64()
if t.GetNumPartitions() > maxPartitionNum {
return merr.WrapErrParameterInvalidMsg("partition number (%d) exceeds max configuration (%d)",
t.GetNumPartitions(), maxPartitionNum)
}
// set default physical partitions num if enable partition key mode // set default physical partitions num if enable partition key mode
if t.GetNumPartitions() == 0 { if t.GetNumPartitions() == 0 {
t.NumPartitions = common.DefaultPartitionsWithPartitionKey defaultNum := common.DefaultPartitionsWithPartitionKey
if defaultNum > maxPartitionNum {
defaultNum = maxPartitionNum
}
t.NumPartitions = defaultNum
} }
idx = i idx = i

View File

@ -2964,6 +2964,7 @@ func TestDescribeResourceGroupTaskFailed(t *testing.T) {
func TestCreateCollectionTaskWithPartitionKey(t *testing.T) { func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
rc := NewRootCoordMock() rc := NewRootCoordMock()
paramtable.Init()
defer rc.Close() defer rc.Close()
ctx := context.Background() ctx := context.Background()
@ -3029,6 +3030,7 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
} }
t.Run("PreExecute", func(t *testing.T) { t.Run("PreExecute", func(t *testing.T) {
defer Params.Reset(Params.RootCoordCfg.MaxPartitionNum.Key)
var err error var err error
// test default num partitions // test default num partitions
@ -3036,6 +3038,13 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, common.DefaultPartitionsWithPartitionKey, task.GetNumPartitions()) assert.Equal(t, common.DefaultPartitionsWithPartitionKey, task.GetNumPartitions())
Params.Save(Params.RootCoordCfg.MaxPartitionNum.Key, "16")
task.NumPartitions = 0
err = task.PreExecute(ctx)
assert.NoError(t, err)
assert.Equal(t, int64(16), task.GetNumPartitions())
Params.Reset(Params.RootCoordCfg.MaxPartitionNum.Key)
// test specify num partition without partition key field // test specify num partition without partition key field
partitionKeyField.IsPartitionKey = false partitionKeyField.IsPartitionKey = false
task.NumPartitions = common.DefaultPartitionsWithPartitionKey * 2 task.NumPartitions = common.DefaultPartitionsWithPartitionKey * 2
@ -3083,6 +3092,15 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
primaryField.IsPartitionKey = false primaryField.IsPartitionKey = false
// test partition num too large
Params.Save(Params.RootCoordCfg.MaxPartitionNum.Key, "16")
marshaledSchema, err = proto.Marshal(schema)
assert.NoError(t, err)
task.Schema = marshaledSchema
err = task.PreExecute(ctx)
assert.Error(t, err)
Params.Reset(Params.RootCoordCfg.MaxPartitionNum.Key)
marshaledSchema, err = proto.Marshal(schema) marshaledSchema, err = proto.Marshal(schema)
assert.NoError(t, err) assert.NoError(t, err)
task.Schema = marshaledSchema task.Schema = marshaledSchema

View File

@ -236,7 +236,7 @@ class TestPartitionKeyInvalidParams(TestcaseBase):
collection_w, _ = self.collection_wrap.init_collection(name=c_name, schema=schema, collection_w, _ = self.collection_wrap.init_collection(name=c_name, schema=schema,
num_partitions=num_partitions, num_partitions=num_partitions,
check_task=CheckTasks.err_res, check_task=CheckTasks.err_res,
check_items={"err_code": 2, "err_msg": err_msg}) check_items={"err_code": 1100, "err_msg": err_msg})
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_min_partitions(self): def test_min_partitions(self):