mirror of https://github.com/milvus-io/milvus.git
enhance: check the partition num when creating collection with partition key (#32670)
issue: #30577 Signed-off-by: SimFG <bang.fu@zilliz.com>pull/32678/head
parent
cf4db3ff4e
commit
7da1ca9efb
|
@ -206,9 +206,19 @@ func (t *createCollectionTask) validatePartitionKey() error {
|
||||||
return errors.New("the specified partitions should be greater than 0 if partition key is used")
|
return errors.New("the specified partitions should be greater than 0 if partition key is used")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
maxPartitionNum := Params.RootCoordCfg.MaxPartitionNum.GetAsInt64()
|
||||||
|
if t.GetNumPartitions() > maxPartitionNum {
|
||||||
|
return merr.WrapErrParameterInvalidMsg("partition number (%d) exceeds max configuration (%d)",
|
||||||
|
t.GetNumPartitions(), maxPartitionNum)
|
||||||
|
}
|
||||||
|
|
||||||
// set default physical partitions num if enable partition key mode
|
// set default physical partitions num if enable partition key mode
|
||||||
if t.GetNumPartitions() == 0 {
|
if t.GetNumPartitions() == 0 {
|
||||||
t.NumPartitions = common.DefaultPartitionsWithPartitionKey
|
defaultNum := common.DefaultPartitionsWithPartitionKey
|
||||||
|
if defaultNum > maxPartitionNum {
|
||||||
|
defaultNum = maxPartitionNum
|
||||||
|
}
|
||||||
|
t.NumPartitions = defaultNum
|
||||||
}
|
}
|
||||||
|
|
||||||
idx = i
|
idx = i
|
||||||
|
|
|
@ -2964,6 +2964,7 @@ func TestDescribeResourceGroupTaskFailed(t *testing.T) {
|
||||||
|
|
||||||
func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
|
func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
|
||||||
rc := NewRootCoordMock()
|
rc := NewRootCoordMock()
|
||||||
|
paramtable.Init()
|
||||||
|
|
||||||
defer rc.Close()
|
defer rc.Close()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -3029,6 +3030,7 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("PreExecute", func(t *testing.T) {
|
t.Run("PreExecute", func(t *testing.T) {
|
||||||
|
defer Params.Reset(Params.RootCoordCfg.MaxPartitionNum.Key)
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// test default num partitions
|
// test default num partitions
|
||||||
|
@ -3036,6 +3038,13 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, common.DefaultPartitionsWithPartitionKey, task.GetNumPartitions())
|
assert.Equal(t, common.DefaultPartitionsWithPartitionKey, task.GetNumPartitions())
|
||||||
|
|
||||||
|
Params.Save(Params.RootCoordCfg.MaxPartitionNum.Key, "16")
|
||||||
|
task.NumPartitions = 0
|
||||||
|
err = task.PreExecute(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(16), task.GetNumPartitions())
|
||||||
|
Params.Reset(Params.RootCoordCfg.MaxPartitionNum.Key)
|
||||||
|
|
||||||
// test specify num partition without partition key field
|
// test specify num partition without partition key field
|
||||||
partitionKeyField.IsPartitionKey = false
|
partitionKeyField.IsPartitionKey = false
|
||||||
task.NumPartitions = common.DefaultPartitionsWithPartitionKey * 2
|
task.NumPartitions = common.DefaultPartitionsWithPartitionKey * 2
|
||||||
|
@ -3083,6 +3092,15 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
primaryField.IsPartitionKey = false
|
primaryField.IsPartitionKey = false
|
||||||
|
|
||||||
|
// test partition num too large
|
||||||
|
Params.Save(Params.RootCoordCfg.MaxPartitionNum.Key, "16")
|
||||||
|
marshaledSchema, err = proto.Marshal(schema)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
task.Schema = marshaledSchema
|
||||||
|
err = task.PreExecute(ctx)
|
||||||
|
assert.Error(t, err)
|
||||||
|
Params.Reset(Params.RootCoordCfg.MaxPartitionNum.Key)
|
||||||
|
|
||||||
marshaledSchema, err = proto.Marshal(schema)
|
marshaledSchema, err = proto.Marshal(schema)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
task.Schema = marshaledSchema
|
task.Schema = marshaledSchema
|
||||||
|
|
|
@ -236,7 +236,7 @@ class TestPartitionKeyInvalidParams(TestcaseBase):
|
||||||
collection_w, _ = self.collection_wrap.init_collection(name=c_name, schema=schema,
|
collection_w, _ = self.collection_wrap.init_collection(name=c_name, schema=schema,
|
||||||
num_partitions=num_partitions,
|
num_partitions=num_partitions,
|
||||||
check_task=CheckTasks.err_res,
|
check_task=CheckTasks.err_res,
|
||||||
check_items={"err_code": 2, "err_msg": err_msg})
|
check_items={"err_code": 1100, "err_msg": err_msg})
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
def test_min_partitions(self):
|
def test_min_partitions(self):
|
||||||
|
|
Loading…
Reference in New Issue