enhance: Add skip load validation for create collection task (#35737)

Related to #35415

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/35738/head
congqixia 2024-08-29 10:05:08 +08:00 committed by GitHub
parent 99dff06391
commit 985d84d3ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 191 additions and 0 deletions

View File

@ -394,6 +394,10 @@ func (t *createCollectionTask) PreExecute(ctx context.Context) error {
return err
}
if err := validateLoadFieldsList(t.schema); err != nil {
return err
}
t.CreateCollectionRequest.Schema, err = proto.Marshal(t.schema)
if err != nil {
return err

View File

@ -875,6 +875,22 @@ func TestCreateCollectionTask(t *testing.T) {
err = task.PreExecute(ctx)
assert.Error(t, err)
// ValidateVectorField
schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
for _, field := range schema.Fields {
field.TypeParams = append(field.TypeParams, &commonpb.KeyValuePair{
Key: common.FieldSkipLoadKey,
Value: "true",
})
}
// Validate default load list
skipLoadSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
task.CreateCollectionRequest.Schema = skipLoadSchema
err = task.PreExecute(ctx)
assert.Error(t, err)
schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
for idx := range schema.Fields {
if schema.Fields[idx].DataType == schemapb.DataType_FloatVector ||

View File

@ -633,6 +633,40 @@ func validateMultipleVectorFields(schema *schemapb.CollectionSchema) error {
return nil
}
// validateLoadFieldsList verifies the per-field load configuration of schema.
//
// For every field, common.ShouldFieldBeLoaded decides from the field's type
// params whether the field is loaded; any error it reports is returned as-is.
// A field that is not loaded must be neither the primary key nor the
// partition key. After all fields pass, at least one loaded vector field must
// have been seen, otherwise a parameter-invalid error is returned.
func validateLoadFieldsList(schema *schemapb.CollectionSchema) error {
	loadedVectors := 0

	for _, f := range schema.Fields {
		loaded, err := common.ShouldFieldBeLoaded(f.GetTypeParams())
		if err != nil {
			return err
		}

		switch {
		case loaded && typeutil.IsVectorType(f.GetDataType()):
			// Only loaded vector fields count towards the final check.
			loadedVectors++
		case loaded:
			// Loaded non-vector fields need no further validation.
		case f.IsPrimaryKey:
			return merr.WrapErrParameterInvalidMsg("Primary key field %s cannot skip loading", f.GetName())
		case f.IsPartitionKey:
			return merr.WrapErrParameterInvalidMsg("Partition Key field %s cannot skip loading", f.GetName())
		}
	}

	if loadedVectors == 0 {
		return merr.WrapErrParameterInvalidMsg("cannot config all vector field(s) skip loading")
	}
	return nil
}
// parsePrimaryFieldData2IDs get IDs to fill grpc result, for example insert request, delete request etc.
func parsePrimaryFieldData2IDs(fieldData *schemapb.FieldData) (*schemapb.IDs, error) {
primaryData := &schemapb.IDs{}

View File

@ -2472,3 +2472,140 @@ func TestGetCostValue(t *testing.T) {
assert.Equal(t, 100, cost)
})
}
// TestValidateLoadFieldsList runs validateLoadFieldsList against a set of
// schemas that differ only in which field carries the skip-load attribute,
// and checks whether an error is reported.
func TestValidateLoadFieldsList(t *testing.T) {
	type testCase struct {
		tag       string
		schema    *schemapb.CollectionSchema
		expectErr bool
	}

	// Shared field fixtures; test cases must not mutate them.
	rowIDField := &schemapb.FieldSchema{
		FieldID:  common.RowIDField,
		Name:     common.RowIDFieldName,
		DataType: schemapb.DataType_Int64,
	}
	timestampField := &schemapb.FieldSchema{
		FieldID:  common.TimeStampField,
		Name:     common.TimeStampFieldName,
		DataType: schemapb.DataType_Int64,
	}
	pkField := &schemapb.FieldSchema{
		FieldID:      common.StartOfUserFieldID,
		Name:         "pk",
		DataType:     schemapb.DataType_Int64,
		IsPrimaryKey: true,
	}
	scalarField := &schemapb.FieldSchema{
		FieldID:  common.StartOfUserFieldID + 1,
		Name:     "text",
		DataType: schemapb.DataType_VarChar,
	}
	partitionKeyField := &schemapb.FieldSchema{
		FieldID:        common.StartOfUserFieldID + 2,
		Name:           "part_key",
		DataType:       schemapb.DataType_Int64,
		IsPartitionKey: true,
	}
	vectorField := &schemapb.FieldSchema{
		FieldID:  common.StartOfUserFieldID + 3,
		Name:     "vector",
		DataType: schemapb.DataType_FloatVector,
		TypeParams: []*commonpb.KeyValuePair{
			{Key: common.DimKey, Value: "768"},
		},
	}
	dynamicField := &schemapb.FieldSchema{
		FieldID:   common.StartOfUserFieldID + 4,
		Name:      common.MetaFieldName,
		DataType:  schemapb.DataType_JSON,
		IsDynamic: true,
	}

	// addSkipLoadAttr returns a copy of f whose type params additionally carry
	// common.FieldSkipLoadKey set to flag.
	addSkipLoadAttr := func(f *schemapb.FieldSchema, flag bool) *schemapb.FieldSchema {
		result := typeutil.Clone(f)
		// Append to the clone's own slice rather than f.TypeParams so the
		// shared fixture's backing array is never written through.
		// NOTE(review): assumes typeutil.Clone deep-copies TypeParams — confirm.
		result.TypeParams = append(result.TypeParams, &commonpb.KeyValuePair{
			Key:   common.FieldSkipLoadKey,
			Value: strconv.FormatBool(flag),
		})
		return result
	}

	testCases := []testCase{
		{
			tag: "default",
			schema: &schemapb.CollectionSchema{
				EnableDynamicField: true,
				Fields: []*schemapb.FieldSchema{
					rowIDField,
					timestampField,
					pkField,
					scalarField,
					partitionKeyField,
					vectorField,
					dynamicField,
				},
			},
			expectErr: false,
		},
		{
			tag: "pk_not_loaded",
			schema: &schemapb.CollectionSchema{
				EnableDynamicField: true,
				Fields: []*schemapb.FieldSchema{
					rowIDField,
					timestampField,
					addSkipLoadAttr(pkField, true),
					scalarField,
					partitionKeyField,
					vectorField,
					dynamicField,
				},
			},
			expectErr: true,
		},
		{
			// The primary key stays loaded here so that the error can only
			// come from the partition key being marked as skip-load.
			tag: "part_key_not_loaded",
			schema: &schemapb.CollectionSchema{
				EnableDynamicField: true,
				Fields: []*schemapb.FieldSchema{
					rowIDField,
					timestampField,
					pkField,
					scalarField,
					addSkipLoadAttr(partitionKeyField, true),
					vectorField,
					dynamicField,
				},
			},
			expectErr: true,
		},
		{
			tag: "vector_not_loaded",
			schema: &schemapb.CollectionSchema{
				EnableDynamicField: true,
				Fields: []*schemapb.FieldSchema{
					rowIDField,
					timestampField,
					pkField,
					scalarField,
					partitionKeyField,
					addSkipLoadAttr(vectorField, true),
					dynamicField,
				},
			},
			expectErr: true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.tag, func(t *testing.T) {
			err := validateLoadFieldsList(tc.schema)
			if tc.expectErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}