Change output fields with star behavior (#24162)

Signed-off-by: Enwei Jiao <enwei.jiao@zilliz.com>
Enwei Jiao 2023-05-17 12:41:22 +08:00 committed by GitHub
parent a53beba14f
commit cb2a36ab52
5 changed files with 23 additions and 118 deletions
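
In short: with this change, the "*" wildcard in output_fields expands to every field in the collection schema, scalar and vector alike, and the separate "%" wildcard for vector fields is removed. The sketch below illustrates the old and new expansions using made-up field names (A and B as scalar fields, C and D as vector fields); it is an illustration only, not code from the Milvus repository.

package main

import "fmt"

func main() {
	scalarFields := []string{"A", "B"} // illustrative scalar fields
	vectorFields := []string{"C", "D"} // illustrative vector fields

	// Old behavior: "*" expanded to scalar fields only and "%" to vector fields only.
	oldStar := scalarFields
	oldPercent := vectorFields

	// New behavior: "*" expands to every field; "%" is no longer a wildcard.
	newStar := append(append([]string{}, scalarFields...), vectorFields...)

	fmt.Println("old *:", oldStar)    // [A B]
	fmt.Println("old %:", oldPercent) // [C D]
	fmt.Println("new *:", newStar)    // [A B C D]
}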

View File

@@ -390,39 +390,19 @@ func TestTranslateOutputFields(t *testing.T) {
outputFields, err = translateOutputFields([]string{"*"}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{" * "}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{"%"}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{floatVectorFieldName, binaryVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{" % "}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{floatVectorFieldName, binaryVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{"*", "%"}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{"*", tsFieldName}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{"*", floatVectorFieldName}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{"%", floatVectorFieldName}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{floatVectorFieldName, binaryVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{"%", idFieldName}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
//=========================================================================
outputFields, err = translateOutputFields([]string{}, schema, true)
@@ -443,31 +423,15 @@ func TestTranslateOutputFields(t *testing.T) {
outputFields, err = translateOutputFields([]string{"*"}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{"%"}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{"*", "%"}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{"*", tsFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{"*", floatVectorFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{"%", floatVectorFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{"%", idFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
}
func TestCreateCollectionTask(t *testing.T) {

View File

@@ -765,20 +765,16 @@ func passwordVerify(ctx context.Context, username, rawPwd string, globalMetaCach
// Support wildcard in output fields:
//
// "*" - all scalar fields
// "%" - all vector fields
// "*" - all fields
//
// For example, A and B are scalar fields while C and D are vector fields; duplicated fields are removed automatically.
//
// output_fields=["*"] ==> [A,B]
// output_fields=["%"] ==> [C,D]
// output_fields=["*","%"] ==> [A,B,C,D]
// output_fields=["*",A] ==> [A,B]
// output_fields=["*",C] ==> [A,B,C]
// output_fields=["*"] ==> [A,B,C,D]
// output_fields=["*",A] ==> [A,B,C,D]
// output_fields=["*",C] ==> [A,B,C,D]
func translateOutputFields(outputFields []string, schema *schemapb.CollectionSchema, addPrimary bool) ([]string, error) {
var primaryFieldName string
scalarFieldNameMap := make(map[string]bool)
vectorFieldNameMap := make(map[string]bool)
allFieldNameMap := make(map[string]bool)
resultFieldNameMap := make(map[string]bool)
resultFieldNames := make([]string, 0)
@@ -786,21 +782,13 @@ func translateOutputFields(outputFields []string, schema *schemapb.CollectionSch
if field.IsPrimaryKey {
primaryFieldName = field.Name
}
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector {
vectorFieldNameMap[field.Name] = true
} else {
scalarFieldNameMap[field.Name] = true
}
allFieldNameMap[field.Name] = true
}
for _, outputFieldName := range outputFields {
outputFieldName = strings.TrimSpace(outputFieldName)
if outputFieldName == "*" {
for fieldName := range scalarFieldNameMap {
resultFieldNameMap[fieldName] = true
}
} else if outputFieldName == "%" {
for fieldName := range vectorFieldNameMap {
for fieldName := range allFieldNameMap {
resultFieldNameMap[fieldName] = true
}
} else {
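
Below is a self-contained sketch of the expansion flow shown in the diff above: trim each requested name, expand "*" to every field in the schema, de-duplicate, and optionally add the primary key. The Field and Schema types are simplified stand-ins for schemapb.CollectionSchema, and the validation of unknown field names in the real else branch is omitted, so treat this as an illustration of the flow rather than the actual Milvus implementation.

package main

import (
	"fmt"
	"strings"
)

// Field and Schema are simplified stand-ins for the schemapb types used by
// the real function; they exist only to keep this sketch self-contained.
type Field struct {
	Name         string
	IsPrimaryKey bool
}

type Schema struct {
	Fields []Field
}

// translateOutputFields trims each requested name, expands "*" to every field
// in the schema, de-duplicates the result, and optionally forces the
// primary-key field into the output. Unknown-name validation from the real
// implementation is omitted here.
func translateOutputFields(outputFields []string, schema *Schema, addPrimary bool) []string {
	var primaryFieldName string
	allFieldNames := make(map[string]bool)
	for _, f := range schema.Fields {
		if f.IsPrimaryKey {
			primaryFieldName = f.Name
		}
		allFieldNames[f.Name] = true
	}

	resultSet := make(map[string]bool)
	for _, name := range outputFields {
		name = strings.TrimSpace(name)
		if name == "*" {
			for fieldName := range allFieldNames {
				resultSet[fieldName] = true
			}
		} else {
			resultSet[name] = true
		}
	}
	if addPrimary && primaryFieldName != "" {
		resultSet[primaryFieldName] = true
	}

	result := make([]string, 0, len(resultSet))
	for name := range resultSet {
		result = append(result, name)
	}
	return result
}

func main() {
	schema := &Schema{Fields: []Field{
		{Name: "id", IsPrimaryKey: true},
		{Name: "ts"},
		{Name: "float_vector"},
	}}
	// Both calls yield all three fields (order may vary; map iteration is random).
	fmt.Println(translateOutputFields([]string{"*"}, schema, false))
	fmt.Println(translateOutputFields([]string{"ts", "*"}, schema, true))
}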

View File

@@ -449,7 +449,7 @@ class TestDeleteOperation(TestcaseBase):
"""
# init collection with nb default data
collection_w, _, _, ids = self.init_collection_general(prefix, insert_data=True)[0:4]
entity, _ = collection_w.query(tmp_expr, output_fields=["%"])
entity, _ = collection_w.query(tmp_expr, output_fields=["*"])
search_res, _ = collection_w.search([entity[0][ct.default_float_vec_field_name]],
ct.default_float_vec_field_name,
ct.default_search_params, ct.default_limit)
@@ -1318,7 +1318,7 @@ class TestDeleteString(TestcaseBase):
# init collection with nb default data
collection_w, _, _, ids = self.init_collection_general(prefix, insert_data=True,
primary_field=ct.default_string_field_name)[0:4]
entity, _ = collection_w.query(default_string_expr, output_fields=["%"])
entity, _ = collection_w.query(default_string_expr, output_fields=["*"])
search_res, _ = collection_w.search([entity[0][ct.default_float_vec_field_name]],
ct.default_float_vec_field_name,
ct.default_search_params, ct.default_limit)

View File

@@ -340,7 +340,7 @@ class TestQueryParams(TestcaseBase):
res.extend(df.iloc[i:i + 1, :-1].to_dict('records'))
self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
self.collection_wrap.load()
self.collection_wrap.query(term_expr, output_fields=["*"],
self.collection_wrap.query(term_expr, output_fields=["float", "int64", "int8", "varchar"],
check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@pytest.fixture(scope="function", params=cf.gen_normal_expressions())
@@ -411,7 +411,7 @@ class TestQueryParams(TestcaseBase):
pos = 100
term_expr = f'{field} not in {values[pos:]}'
res = df.iloc[:pos, :3].to_dict('records')
self.collection_wrap.query(term_expr, output_fields=["*"],
self.collection_wrap.query(term_expr, output_fields=["float", "int64", "varchar"],
check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@pytest.mark.tags(CaseLabel.L1)
@@ -678,7 +678,7 @@ class TestQueryParams(TestcaseBase):
check_items={exp_res: res, "with_vec": True})
# query with wildcard %
collection_w.query(default_term_expr, output_fields=["%"],
collection_w.query(default_term_expr, output_fields=["*"],
check_task=CheckTasks.check_query_results,
check_items={exp_res: res, "with_vec": True})
@@ -739,7 +739,7 @@ class TestQueryParams(TestcaseBase):
def test_query_output_fields_simple_wildcard(self):
"""
target: test query output_fields with simple wildcard (* and %)
method: specify output_fields as "*" and "*", "%"
method: specify output_fields as "*"
expected: output all fields
"""
# init collection with fields: int64, float, float_vec, float_vector1
@@ -747,24 +747,9 @@
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
df = vectors[0]
# query with wildcard scale(*)
output_fields = [ct.default_int64_field_name, ct.default_float_field_name, ct.default_string_field_name]
res = df.loc[:1, output_fields].to_dict('records')
collection_w.query(default_term_expr, output_fields=["*"],
check_task=CheckTasks.check_query_results,
check_items={exp_res: res})
# query with wildcard % output_fields2 = [ct.default_int64_field_name, ct.default_float_vec_field_name,
# ct.another_float_vec_field_name]
output_fields2 = [ct.default_int64_field_name, ct.default_float_vec_field_name]
res2 = df.loc[:1, output_fields2].to_dict('records')
collection_w.query(default_term_expr, output_fields=["%"],
check_task=CheckTasks.check_query_results,
check_items={exp_res: res2, "with_vec": True})
# query with wildcard all fields: vector(%) and scale(*)
# query with wildcard all fields
res3 = df.iloc[:2].to_dict('records')
collection_w.query(default_term_expr, output_fields=["*", "%"],
collection_w.query(default_term_expr, output_fields=["*"],
check_task=CheckTasks.check_query_results,
check_items={exp_res: res3, "with_vec": True})
@@ -780,45 +765,13 @@ class TestQueryParams(TestcaseBase):
df = vectors[0]
# query with output_fields=["*", float_vector)
res = df.iloc[:2, :4].to_dict('records')
res = df.iloc[:2].to_dict('records')
collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
collection_w.load()
collection_w.query(default_term_expr, output_fields=["*", ct.default_float_vec_field_name],
check_task=CheckTasks.check_query_results,
check_items={exp_res: res, "with_vec": True})
# query with output_fields=["*", float)
res2 = df.iloc[:2, :3].to_dict('records')
collection_w.load()
collection_w.query(default_term_expr, output_fields=["*", ct.default_float_field_name],
check_task=CheckTasks.check_query_results,
check_items={exp_res: res2})
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.skip(reason="https://github.com/milvus-io/milvus/issues/12680")
def test_query_output_fields_part_vector_wildcard(self):
"""
target: test query output_fields with part wildcard
method: specify output_fields as wildcard and part field
expected: verify query result
"""
# init collection with fields: int64, float, float_vec, float_vector1
collection_w, df = self.init_multi_fields_collection_wrap(cf.gen_unique_str(prefix))
collection_w.load()
# query with output_fields=["%", float), expected: all fields
res = df.iloc[:2].to_dict('records')
collection_w.query(default_term_expr, output_fields=["%", ct.default_float_field_name],
check_task=CheckTasks.check_query_results,
check_items={exp_res: res, "with_vec": True})
# query with output_fields=["%", float_vector), expected: int64, float_vector, float_vector1
output_fields = [ct.default_int64_field_name, ct.default_float_vec_field_name, ct.another_float_vec_field_name]
res2 = df.loc[:1, output_fields].to_dict('records')
collection_w.query(default_term_expr, output_fields=["%", ct.default_float_vec_field_name],
check_task=CheckTasks.check_query_results,
check_items={exp_res: res2, "with_vec": True})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("output_fields", [["*%"], ["**"], ["*", "@"]])
def test_query_invalid_wildcard(self, output_fields):
@@ -1766,7 +1719,7 @@ class TestQueryString(TestcaseBase):
df_dict_list = []
for df in df_list:
df_dict_list += df.to_dict('records')
output_fields = ["*", "%"]
output_fields = ["*"]
expression = "int64 >= 0"
collection_w.query(expression, output_fields=output_fields,
check_task=CheckTasks.check_query_results,

View File

@@ -831,7 +831,7 @@ class TestCollectionSearchInvalid(TestcaseBase):
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.skip(reason="Now support output vector field")
@pytest.mark.parametrize("output_fields", [[default_search_field], ["%"]])
@pytest.mark.parametrize("output_fields", [[default_search_field], ["*"]])
def test_search_output_field_vector(self, output_fields):
"""
target: test search with vector as output field