// Copyright (C) 2019-2020 Zilliz. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software distributed under the License // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License #include #include #include #include #include #include #include "common/Types.h" #include "pb/plan.pb.h" #include "query/Expr.h" #include "query/ExprImpl.h" #include "query/Plan.h" #include "query/PlanNode.h" #include "query/generated/ExecExprVisitor.h" #include "segcore/SegmentGrowingImpl.h" #include "simdjson/padded_string.h" #include "test_utils/DataGen.h" #include "index/IndexFactory.h" TEST(Expr, TestArrayRange) { using namespace milvus; using namespace milvus::query; using namespace milvus::segcore; std::vector>> testcases = { {R"(binary_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > lower_inclusive: false, upper_inclusive: false, lower_value: < int64_val: 1 > upper_value: < int64_val: 10000 > >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return 1 < val && val < 10000; }}, {R"(binary_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"1024" element_type:Int64 > lower_inclusive: false, upper_inclusive: false, lower_value: < int64_val: 1 > upper_value: < int64_val: 10000 > >)", "long", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return 1 < val && val < 10000; }}, {R"(binary_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > lower_inclusive: true, upper_inclusive: false, 
lower_value: < int64_val: 1 > upper_value: < int64_val: 10000 > >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return 1 <= val && val < 10000; }}, {R"(binary_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"1024" element_type:Int64 > lower_inclusive: true, upper_inclusive: false, lower_value: < int64_val: 1 > upper_value: < int64_val: 10000 > >)", "long", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return 1 <= val && val < 10000; }}, {R"(binary_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > lower_inclusive: false, upper_inclusive: true, lower_value: < int64_val: 1 > upper_value: < int64_val: 10000 > >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return 1 < val && val <= 10000; }}, {R"(binary_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"1024" element_type:Int64 > lower_inclusive: false, upper_inclusive: true, lower_value: < int64_val: 1 > upper_value: < int64_val: 10000 > >)", "long", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return 1 < val && val <= 10000; }}, {R"(binary_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > lower_inclusive: true, upper_inclusive: true, lower_value: < int64_val: 1 > upper_value: < int64_val: 10000 > >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return 1 <= val && val <= 10000; }}, {R"(binary_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"1024" element_type:Int64 > lower_inclusive: true, upper_inclusive: true, lower_value: < int64_val: 1 > upper_value: < int64_val: 10000 > >)", "long", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return 1 <= val && val <= 10000; }}, {R"(binary_range_expr: < column_info: < 
field_id: 104 data_type: Array nested_path:"0" element_type:VarChar > lower_inclusive: true, upper_inclusive: true, lower_value: < string_val: "aaa" > upper_value: < string_val: "zzz" > >)", "string", [](milvus::Array& array) { auto val = array.get_data(0); return "aaa" <= val && val <= "zzz"; }}, {R"(binary_range_expr: < column_info: < field_id: 105 data_type: Array nested_path:"0" element_type:Float > lower_inclusive: true, upper_inclusive: true, lower_value: < float_val: 1.1 > upper_value: < float_val: 2048.12 > >)", "float", [](milvus::Array& array) { auto val = array.get_data(0); return 1.1 <= val && val <= 2048.12; }}, {R"(unary_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > op: GreaterEqual, value: < int64_val: 10000 > >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return val >= 10000; }}, {R"(unary_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > op: GreaterThan, value: < int64_val: 2000 > >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return val > 2000; }}, {R"(unary_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > op: LessEqual, value: < int64_val: 2000 > >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return val <= 2000; }}, {R"(unary_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > op: LessThan, value: < int64_val: 2000 > >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return val < 2000; }}, {R"(unary_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > op: Equal, value: < int64_val: 2000 > >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return val == 2000; }}, {R"(unary_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > op: NotEqual, value: < int64_val: 2000 > 
>)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return val != 2000; }}, {R"(unary_range_expr: < column_info: < field_id: 103 data_type: Array nested_path:"0" element_type:Bool > op: Equal, value: < bool_val: false > >)", "bool", [](milvus::Array& array) { auto val = array.get_data(0); return !val; }}, {R"(unary_range_expr: < column_info: < field_id: 104 data_type: Array nested_path:"0" element_type:VarChar > op: Equal, value: < string_val: "abc" > >)", "string", [](milvus::Array& array) { auto val = array.get_data(0); return val == "abc"; }}, {R"(unary_range_expr: < column_info: < field_id: 105 data_type: Array nested_path:"0" element_type:Float > op: Equal, value: < float_val: 2.2 > >)", "float", [](milvus::Array& array) { auto val = array.get_data(0); return val == 2.2; }}, {R"(unary_range_expr: < column_info: < field_id: 105 data_type: Array nested_path:"1024" element_type:Float > op: Equal, value: < float_val: 2.2 > >)", "float", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return val == 2.2; }}, {R"(unary_range_expr: < column_info: < field_id: 105 data_type: Array nested_path:"1024" element_type:Float > op: NotEqual, value: < float_val: 2.2 > >)", "float", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return val != 2.2; }}, {R"(unary_range_expr: < column_info: < field_id: 105 data_type: Array nested_path:"1024" element_type:Float > op: GreaterEqual, value: < float_val: 2.2 > >)", "float", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return val >= 2.2; }}, {R"(unary_range_expr: < column_info: < field_id: 105 data_type: Array nested_path:"1024" element_type:Float > op: GreaterThan, value: < float_val: 2.2 > >)", "float", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return val > 2.2; }}, 
{R"(unary_range_expr: < column_info: < field_id: 105 data_type: Array nested_path:"1024" element_type:Float > op: LessEqual, value: < float_val: 2.2 > >)", "float", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return val <= 2.2; }}, {R"(unary_range_expr: < column_info: < field_id: 105 data_type: Array nested_path:"1024" element_type:Float > op: LessThan, value: < float_val: 2.2 > >)", "float", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return val < 2.2; }}, }; std::string raw_plan_tmp = R"(vector_anns: < field_id: 100 predicates: < @@@@ > query_info: < topk: 10 round_decimal: 3 metric_type: "L2" search_params: "{\"nprobe\": 10}" > placeholder_tag: "$0" >)"; auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField( "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); auto i64_fid = schema->AddDebugField("id", DataType::INT64); auto long_array_fid = schema->AddDebugField("long_array", DataType::ARRAY, DataType::INT64); auto bool_array_fid = schema->AddDebugField("bool_array", DataType::ARRAY, DataType::BOOL); auto string_array_fid = schema->AddDebugField( "string_array", DataType::ARRAY, DataType::VARCHAR); auto float_array_fid = schema->AddDebugField("double_array", DataType::ARRAY, DataType::FLOAT); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::map> array_cols; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_long_array_col = raw_data.get_col(long_array_fid); auto new_bool_array_col = raw_data.get_col(bool_array_fid); auto new_string_array_col = raw_data.get_col(string_array_fid); auto new_float_array_col = raw_data.get_col(float_array_fid); array_cols["long"].insert(array_cols["long"].end(), new_long_array_col.begin(), new_long_array_col.end()); 
array_cols["bool"].insert(array_cols["bool"].end(), new_bool_array_col.begin(), new_bool_array_col.end()); array_cols["string"].insert(array_cols["string"].end(), new_string_array_col.begin(), new_string_array_col.end()); array_cols["float"].insert(array_cols["float"].end(), new_float_array_col.begin(), new_float_array_col.end()); seg->PreInsert(N); seg->Insert(iter * N, N, raw_data.row_ids_.data(), raw_data.timestamps_.data(), raw_data.raw_); } auto seg_promote = dynamic_cast(seg.get()); ExecExprVisitor visitor( *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); for (auto [clause, array_type, ref_func] : testcases) { auto loc = raw_plan_tmp.find("@@@@"); auto raw_plan = raw_plan_tmp; raw_plan.replace(loc, 4, clause); auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; auto array = milvus::Array(array_cols[array_type][i]); auto ref = ref_func(array); ASSERT_EQ(ans, ref); } } } TEST(Expr, TestArrayEqual) { using namespace milvus; using namespace milvus::query; using namespace milvus::segcore; std::vector< std::tuple)>>> testcases = { {R"(unary_range_expr: < column_info: < field_id: 102 data_type: Array element_type:Int64 > op:Equal value:< array_val: array: array: same_type:true element_type:Int64 >> >)", [](std::vector v) { if (v.size() != 3) { return false; } for (int i = 0; i < 3; ++i) { if (v[i] != i + 1) { return false; } } return true; }}, {R"(unary_range_expr: < column_info: < field_id: 102 data_type: Array element_type:Int64 > op:NotEqual value: array: array: same_type:true element_type:Int64 >> >)", [](std::vector v) { if (v.size() != 3) { return true; } for (int i = 0; i < 3; ++i) { if (v[i] != i + 1) { return true; } } return false; }}, }; std::string raw_plan_tmp = 
R"(vector_anns: < field_id: 100 predicates: < @@@@ > query_info: < topk: 10 round_decimal: 3 metric_type: "L2" search_params: "{\"nprobe\": 10}" > placeholder_tag: "$0" >)"; auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField( "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); auto i64_fid = schema->AddDebugField("id", DataType::INT64); auto long_array_fid = schema->AddDebugField("long_array", DataType::ARRAY, DataType::INT64); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector long_array_col; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter, 0, 1, 3); auto new_long_array_col = raw_data.get_col(long_array_fid); long_array_col.insert(long_array_col.end(), new_long_array_col.begin(), new_long_array_col.end()); seg->PreInsert(N); seg->Insert(iter * N, N, raw_data.row_ids_.data(), raw_data.timestamps_.data(), raw_data.raw_); } auto seg_promote = dynamic_cast(seg.get()); ExecExprVisitor visitor( *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { auto loc = raw_plan_tmp.find("@@@@"); auto raw_plan = raw_plan_tmp; raw_plan.replace(loc, 4, clause); auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; auto array = milvus::Array(long_array_col[i]); std::vector array_values(array.length()); for (int j = 0; j < array.length(); ++j) { array_values.push_back(array.get_data(j)); } auto ref = ref_func(array_values); ASSERT_EQ(ans, ref); } } } TEST(Expr, PraseArrayContainsExpr) { using namespace milvus; using namespace milvus::query; using namespace milvus::segcore; std::vector raw_plans{ R"(vector_anns:< 
field_id:100 predicates:< json_contains_expr:< column_info:< field_id:101 data_type:Array element_type:Int64 > elements: op:Contains elements_same_type:true > > query_info:< topk: 10 round_decimal: 3 metric_type: "L2" search_params: "{\"nprobe\": 10}" > placeholder_tag:"$0" >)", R"(vector_anns:< field_id:100 predicates:< json_contains_expr:< column_info:< field_id:101 data_type:Array element_type:Int64 > elements: elements: elements: op:ContainsAll elements_same_type:true > > query_info:< topk: 10 round_decimal: 3 metric_type: "L2" search_params: "{\"nprobe\": 10}" > placeholder_tag:"$0" >)", R"(vector_anns:< field_id:100 predicates:< json_contains_expr:< column_info:< field_id:101 data_type:Array element_type:Int64 > elements: elements: elements: op:ContainsAny elements_same_type:true > > query_info:< topk: 10 round_decimal: 3 metric_type: "L2" search_params: "{\"nprobe\": 10}" > placeholder_tag:"$0" >)", }; for (auto& raw_plan : raw_plans) { auto plan_str = translate_text_plan_to_binary_plan(raw_plan); auto schema = std::make_shared(); schema->AddDebugField( "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); schema->AddField( FieldName("array"), FieldId(101), DataType::ARRAY, DataType::INT64); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); } } template struct ArrayTestcase { std::vector term; std::vector nested_path; }; TEST(Expr, TestArrayContains) { using namespace milvus; using namespace milvus::query; using namespace milvus::segcore; auto schema = std::make_shared(); auto i64_fid = schema->AddDebugField("id", DataType::INT64); auto int_array_fid = schema->AddDebugField("int_array", DataType::ARRAY, DataType::INT8); auto long_array_fid = schema->AddDebugField("long_array", DataType::ARRAY, DataType::INT64); auto bool_array_fid = schema->AddDebugField("bool_array", DataType::ARRAY, DataType::BOOL); auto float_array_fid = schema->AddDebugField("float_array", DataType::ARRAY, DataType::FLOAT); auto double_array_fid = 
schema->AddDebugField( "double_array", DataType::ARRAY, DataType::DOUBLE); auto string_array_fid = schema->AddDebugField( "string_array", DataType::ARRAY, DataType::VARCHAR); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::map> array_cols; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_int_array_col = raw_data.get_col(int_array_fid); auto new_long_array_col = raw_data.get_col(long_array_fid); auto new_bool_array_col = raw_data.get_col(bool_array_fid); auto new_float_array_col = raw_data.get_col(float_array_fid); auto new_double_array_col = raw_data.get_col(double_array_fid); auto new_string_array_col = raw_data.get_col(string_array_fid); array_cols["int"].insert(array_cols["int"].end(), new_int_array_col.begin(), new_int_array_col.end()); array_cols["long"].insert(array_cols["long"].end(), new_long_array_col.begin(), new_long_array_col.end()); array_cols["bool"].insert(array_cols["bool"].end(), new_bool_array_col.begin(), new_bool_array_col.end()); array_cols["float"].insert(array_cols["float"].end(), new_float_array_col.begin(), new_float_array_col.end()); array_cols["double"].insert(array_cols["double"].end(), new_double_array_col.begin(), new_double_array_col.end()); array_cols["string"].insert(array_cols["string"].end(), new_string_array_col.begin(), new_string_array_col.end()); seg->PreInsert(N); seg->Insert(iter * N, N, raw_data.row_ids_.data(), raw_data.timestamps_.data(), raw_data.raw_); } auto seg_promote = dynamic_cast(seg.get()); ExecExprVisitor visitor( *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); std::vector> bool_testcases{{{true, true}, {}}, {{false, false}, {}}}; for (auto testcase : bool_testcases) { auto check = [&](const std::vector& values) { for (auto const& e : testcase.term) { if (std::find(values.begin(), values.end(), e) != values.end()) { return true; } } return false; }; RetrievePlanNode 
plan; plan.predicate_ = std::make_unique>( ColumnInfo(bool_array_fid, DataType::ARRAY), testcase.term, true, proto::plan::JSONContainsExpr_JSONOp_Contains, proto::plan::GenericValue::ValCase::kBoolVal); auto start = std::chrono::steady_clock::now(); auto final = visitor.call_child(*plan.predicate_.value()); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count() << std::endl; EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; auto array = milvus::Array(array_cols["bool"][i]); std::vector res; for (int j = 0; j < array.length(); ++j) { res.push_back(array.get_data(j)); } ASSERT_EQ(ans, check(res)); } } std::vector> double_testcases{ {{1.123, 10.34}, {"double"}}, {{10.34, 100.234}, {"double"}}, {{100.234, 1000.4546}, {"double"}}, {{1000.4546, 1.123}, {"double"}}, {{1000.4546, 10.34}, {"double"}}, {{1.123, 100.234}, {"double"}}, }; for (auto testcase : double_testcases) { auto check = [&](const std::vector& values) { for (auto const& e : testcase.term) { if (std::find(values.begin(), values.end(), e) != values.end()) { return true; } } return false; }; RetrievePlanNode plan; plan.predicate_ = std::make_unique>( ColumnInfo(double_array_fid, DataType::ARRAY), testcase.term, true, proto::plan::JSONContainsExpr_JSONOp_Contains, proto::plan::GenericValue::ValCase::kFloatVal); auto start = std::chrono::steady_clock::now(); auto final = visitor.call_child(*plan.predicate_.value()); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count() << std::endl; EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; auto array = milvus::Array(array_cols["double"][i]); std::vector res; for (int j = 0; j < array.length(); ++j) { res.push_back(array.get_data(j)); } ASSERT_EQ(ans, check(res)); } } for (auto testcase : double_testcases) { auto check = [&](const std::vector& values) { for (auto const& e 
: testcase.term) { if (std::find(values.begin(), values.end(), e) != values.end()) { return true; } } return false; }; RetrievePlanNode plan; plan.predicate_ = std::make_unique>( ColumnInfo(float_array_fid, DataType::ARRAY), testcase.term, true, proto::plan::JSONContainsExpr_JSONOp_Contains, proto::plan::GenericValue::ValCase::kFloatVal); auto start = std::chrono::steady_clock::now(); auto final = visitor.call_child(*plan.predicate_.value()); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count() << std::endl; EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; auto array = milvus::Array(array_cols["float"][i]); std::vector res; for (int j = 0; j < array.length(); ++j) { res.push_back(array.get_data(j)); } ASSERT_EQ(ans, check(res)); } } std::vector> testcases{ {{1, 10}, {"int"}}, {{10, 100}, {"int"}}, {{100, 1000}, {"int"}}, {{1000, 10}, {"int"}}, {{2, 4, 6, 8, 10}, {"int"}}, {{1, 2, 3, 4, 5}, {"int"}}, }; for (auto testcase : testcases) { auto check = [&](const std::vector& values) { for (auto const& e : testcase.term) { if (std::find(values.begin(), values.end(), e) == values.end()) { return false; } } return true; }; RetrievePlanNode plan; plan.predicate_ = std::make_unique>( ColumnInfo(int_array_fid, DataType::ARRAY), testcase.term, true, proto::plan::JSONContainsExpr_JSONOp_ContainsAll, proto::plan::GenericValue::ValCase::kInt64Val); auto start = std::chrono::steady_clock::now(); auto final = visitor.call_child(*plan.predicate_.value()); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count() << std::endl; EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; auto array = milvus::Array(array_cols["int"][i]); std::vector res; for (int j = 0; j < array.length(); ++j) { res.push_back(array.get_data(j)); } ASSERT_EQ(ans, check(res)); } } for (auto testcase : testcases) { 
auto check = [&](const std::vector& values) { for (auto const& e : testcase.term) { if (std::find(values.begin(), values.end(), e) == values.end()) { return false; } } return true; }; RetrievePlanNode plan; plan.predicate_ = std::make_unique>( ColumnInfo(long_array_fid, DataType::ARRAY), testcase.term, true, proto::plan::JSONContainsExpr_JSONOp_ContainsAll, proto::plan::GenericValue::ValCase::kInt64Val); auto start = std::chrono::steady_clock::now(); auto final = visitor.call_child(*plan.predicate_.value()); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count() << std::endl; EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; auto array = milvus::Array(array_cols["long"][i]); std::vector res; for (int j = 0; j < array.length(); ++j) { res.push_back(array.get_data(j)); } ASSERT_EQ(ans, check(res)); } } std::vector> testcases_string = { {{"1sads", "10dsf"}, {"string"}}, {{"10dsf", "100"}, {"string"}}, {{"100", "10dsf", "1sads"}, {"string"}}, {{"100ddfdsssdfdsfsd0", "100"}, {"string"}}, }; for (auto testcase : testcases_string) { auto check = [&](const std::vector& values) { for (auto const& e : testcase.term) { if (std::find(values.begin(), values.end(), e) == values.end()) { return false; } } return true; }; RetrievePlanNode plan; plan.predicate_ = std::make_unique>( ColumnInfo(string_array_fid, DataType::ARRAY), testcase.term, true, proto::plan::JSONContainsExpr_JSONOp_ContainsAll, proto::plan::GenericValue::ValCase::kStringVal); auto start = std::chrono::steady_clock::now(); auto final = visitor.call_child(*plan.predicate_.value()); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count() << std::endl; EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; auto array = milvus::Array(array_cols["string"][i]); std::vector res; for (int j = 0; j < array.length(); ++j) { 
res.push_back(array.get_data(j)); } ASSERT_EQ(ans, check(res)); } } } TEST(Expr, TestArrayBinaryArith) { using namespace milvus; using namespace milvus::query; using namespace milvus::segcore; auto schema = std::make_shared(); auto i64_fid = schema->AddDebugField("id", DataType::INT64); auto int_array_fid = schema->AddDebugField("int_array", DataType::ARRAY, DataType::INT8); auto long_array_fid = schema->AddDebugField("long_array", DataType::ARRAY, DataType::INT64); auto float_array_fid = schema->AddDebugField("float_array", DataType::ARRAY, DataType::FLOAT); auto double_array_fid = schema->AddDebugField( "double_array", DataType::ARRAY, DataType::DOUBLE); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::map> array_cols; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_int_array_col = raw_data.get_col(int_array_fid); auto new_long_array_col = raw_data.get_col(long_array_fid); auto new_float_array_col = raw_data.get_col(float_array_fid); auto new_double_array_col = raw_data.get_col(double_array_fid); array_cols["int"].insert(array_cols["int"].end(), new_int_array_col.begin(), new_int_array_col.end()); array_cols["long"].insert(array_cols["long"].end(), new_long_array_col.begin(), new_long_array_col.end()); array_cols["float"].insert(array_cols["float"].end(), new_float_array_col.begin(), new_float_array_col.end()); array_cols["double"].insert(array_cols["double"].end(), new_double_array_col.begin(), new_double_array_col.end()); seg->PreInsert(N); seg->Insert(iter * N, N, raw_data.row_ids_.data(), raw_data.timestamps_.data(), raw_data.raw_); } auto seg_promote = dynamic_cast(seg.get()); ExecExprVisitor visitor( *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); std::vector>> testcases = { {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 101 data_type: Array nested_path:"0" element_type:Int8 > arith_op:Add 
right_operand: op:Equal value: >)", "int", [](milvus::Array& array) { auto val = array.get_data(0); return val + 2 == 5; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 101 data_type: Array nested_path:"0" element_type:Int8 > arith_op:Add right_operand: op:NotEqual value: >)", "int", [](milvus::Array& array) { auto val = array.get_data(0); return val + 2 != 5; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > arith_op:Sub right_operand: op:Equal value: >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return val - 1 == 144; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > arith_op:Sub right_operand: op:NotEqual value: >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return val - 1 != 144; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 103 data_type: Array nested_path:"0" element_type:Float > arith_op:Add right_operand: op:Equal value: >)", "float", [](milvus::Array& array) { auto val = array.get_data(0); return val + 2.2 == 133.2; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 103 data_type: Array nested_path:"0" element_type:Float > arith_op:Add right_operand: op:NotEqual value: >)", "float", [](milvus::Array& array) { auto val = array.get_data(0); return val + 2.2 != 133.2; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 104 data_type: Array nested_path:"0" element_type:Double > arith_op:Sub right_operand: op:Equal value: >)", "double", [](milvus::Array& array) { auto val = array.get_data(0); return val - 11.1 == 125.7; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 104 data_type: Array nested_path:"0" element_type:Double > arith_op:Sub right_operand: op:NotEqual value: >)", "double", [](milvus::Array& array) { auto val = array.get_data(0); return val - 11.1 != 125.7; }}, 
{R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > arith_op:Mul right_operand: op:Equal value: >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return val * 2 == 8; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > arith_op:Mul right_operand: op:NotEqual value: >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return val * 2 != 20; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > arith_op:Div right_operand: op:Equal value: >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return val / 2 == 8; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > arith_op:Div right_operand: op:NotEqual value: >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return val / 2 != 20; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > arith_op:Mod right_operand: op:Equal value: >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return val % 3 == 0; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Int64 > arith_op:Mod right_operand: op:NotEqual value: >)", "long", [](milvus::Array& array) { auto val = array.get_data(0); return val % 3 != 2; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 103 data_type: Array nested_path:"1024" element_type:Float > arith_op:Add right_operand: op:Equal value: >)", "float", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return val + 2.2 == 133.2; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 103 data_type: Array 
nested_path:"1024" element_type:Float > arith_op:Add right_operand: op:NotEqual value: >)", "float", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return val + 2.2 != 133.2; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 104 data_type: Array nested_path:"1024" element_type:Double > arith_op:Sub right_operand: op:Equal value: >)", "double", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return val - 11.1 == 125.7; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 104 data_type: Array nested_path:"1024" element_type:Double > arith_op:Sub right_operand: op:NotEqual value: >)", "double", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return val - 11.1 != 125.7; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"1024" element_type:Int64 > arith_op:Mul right_operand: op:Equal value: >)", "long", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return val * 2 == 8; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"1024" element_type:Int64 > arith_op:Mul right_operand: op:NotEqual value: >)", "long", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return val * 2 != 20; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"1024" element_type:Int64 > arith_op:Div right_operand: op:Equal value: >)", "long", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return val / 2 == 8; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"1024" element_type:Int64 > arith_op:Div right_operand: op:NotEqual 
value: >)", "long", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return val / 2 != 20; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"1024" element_type:Int64 > arith_op:Mod right_operand: op:Equal value: >)", "long", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return val % 3 == 0; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 data_type: Array nested_path:"1024" element_type:Int64 > arith_op:Mod right_operand: op:NotEqual value: >)", "long", [](milvus::Array& array) { if (array.length() <= 1024) { return false; } auto val = array.get_data(1024); return val % 3 != 2; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 101 data_type: Array nested_path:"0" element_type:Int8 > arith_op:ArrayLength op:Equal value: >)", "int", [](milvus::Array& array) { return array.length() == 10; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 101 data_type: Array nested_path:"0" element_type:Int8 > arith_op:ArrayLength op:NotEqual value: >)", "int", [](milvus::Array& array) { return array.length() != 8; }}, }; std::string raw_plan_tmp = R"(vector_anns: < field_id: 100 predicates: < @@@@ > query_info: < topk: 10 round_decimal: 3 metric_type: "L2" search_params: "{\"nprobe\": 10}" > placeholder_tag: "$0" >)"; for (auto [clause, array_type, ref_func] : testcases) { auto loc = raw_plan_tmp.find("@@@@"); auto raw_plan = raw_plan_tmp; raw_plan.replace(loc, 4, clause); auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; auto array = milvus::Array(array_cols[array_type][i]); auto ref = 
ref_func(array);
            ASSERT_EQ(ans, ref);
        }
    }
}

// NOTE(review): in this copy of the file the template argument lists appear
// to have been stripped (e.g. `template struct`, `std::vector>`,
// `dynamic_cast(...)`, `get_data(...)`, `duration_cast(...)`), and several
// proto payloads inside the text plans (`value:`, `values:`) look empty.
// Confirm against the upstream source before relying on this text.

// One parameterised case for a unary predicate on an array field: the
// operator, the operand, the nested path selecting the element, and a
// reference predicate that recomputes the expected result from a raw array.
template struct UnaryRangeTestcase {
    milvus::OpType op_type;
    T value;
    std::vector nested_path;
    std::function check_func;
};

// Evaluates PrefixMatch predicates on single elements of a VARCHAR array
// field and checks every row against a reference computed from the
// generated data.
TEST(Expr, TestArrayStringMatch) {
    using namespace milvus;
    using namespace milvus::query;
    using namespace milvus::segcore;
    // Schema: int64 primary key "id" plus a VARCHAR array field.
    auto schema = std::make_shared();
    auto i64_fid = schema->AddDebugField("id", DataType::INT64);
    auto string_array_fid = schema->AddDebugField(
        "string_array", DataType::ARRAY, DataType::VARCHAR);
    schema->set_primary_field_id(i64_fid);
    auto seg = CreateGrowingSegment(schema, empty_index_meta);
    int N = 1000;
    // Copies of the generated array columns, keyed by element type name,
    // used to recompute the expected answer for each row.
    std::map> array_cols;
    int num_iters = 1;
    for (int iter = 0; iter < num_iters; ++iter) {
        auto raw_data = DataGen(schema, N, iter);
        auto new_string_array_col = raw_data.get_col(string_array_fid);
        array_cols["string"].insert(array_cols["string"].end(),
                                    new_string_array_col.begin(),
                                    new_string_array_col.end());
        seg->PreInsert(N);
        seg->Insert(iter * N,
                    N,
                    raw_data.row_ids_.data(),
                    raw_data.timestamps_.data(),
                    raw_data.raw_);
    }

    auto seg_promote = dynamic_cast(seg.get());
    // Evaluates predicates over every inserted row at MAX_TIMESTAMP.
    ExecExprVisitor visitor(
        *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP);
    // Prefix matches on elements 0, 1 and 1024. The 1024 case expects false
    // for arrays with no element at that index; the 0/1 cases read the
    // element unconditionally (assumes generated arrays always have at
    // least 2 elements -- TODO confirm against DataGen).
    std::vector> prefix_testcases{
        {OpType::PrefixMatch,
         "abc",
         {"0"},
         [](milvus::Array& array) {
             return PrefixMatch(array.get_data(0), "abc");
         }},
        {OpType::PrefixMatch,
         "def",
         {"1"},
         [](milvus::Array& array) {
             return PrefixMatch(array.get_data(1), "def");
         }},
        {OpType::PrefixMatch,
         "def",
         {"1024"},
         [](milvus::Array& array) {
             if (array.length() <= 1024) {
                 return false;
             }
             return PrefixMatch(array.get_data(1024), "def");
         }},
    };

    // Reference text plan (partially elided in this copy):
    //vector_anns: op:PrefixMatch value: > > query_info:<> placeholder_tag:"$0" >
    for (auto& testcase : prefix_testcases) {
        // The predicate is constructed directly rather than parsed from a
        // text plan.
        RetrievePlanNode plan;
        plan.predicate_ = std::make_unique>(
            ColumnInfo(string_array_fid, DataType::ARRAY, testcase.nested_path),
            testcase.op_type,
            testcase.value,
            proto::plan::GenericValue::ValCase::kStringVal);
        auto start = std::chrono::steady_clock::now();
        auto final = visitor.call_child(*plan.predicate_.value());
        // Prints the elapsed evaluation time to stdout; the unit is the
        // duration_cast target type.
        std::cout << "cost"
                  << std::chrono::duration_cast(
                         std::chrono::steady_clock::now() - start)
                         .count()
                  << std::endl;
        // One result per inserted row, each equal to the reference predicate.
        EXPECT_EQ(final.size(), N * num_iters);
        for (int i = 0; i < N * num_iters; ++i) {
            auto ans = final[i];
            auto array = milvus::Array(array_cols["string"][i]);
            ASSERT_EQ(ans, testcase.check_func(array));
        }
    }
}

// Evaluates IN-list (term) predicates, parsed from text-format plans, on
// element 0 or element 1024 of array fields with several element types.
TEST(Expr, TestArrayInTerm) {
    using namespace milvus;
    using namespace milvus::query;
    using namespace milvus::segcore;
    // Schema: int64 primary key "id" plus INT64, BOOL, FLOAT and VARCHAR
    // array fields.
    auto schema = std::make_shared();
    auto i64_fid = schema->AddDebugField("id", DataType::INT64);
    auto long_array_fid =
        schema->AddDebugField("long_array", DataType::ARRAY, DataType::INT64);
    auto bool_array_fid =
        schema->AddDebugField("bool_array", DataType::ARRAY, DataType::BOOL);
    auto float_array_fid =
        schema->AddDebugField("float_array", DataType::ARRAY, DataType::FLOAT);
    auto string_array_fid = schema->AddDebugField(
        "string_array", DataType::ARRAY, DataType::VARCHAR);
    schema->set_primary_field_id(i64_fid);
    auto seg = CreateGrowingSegment(schema, empty_index_meta);
    int N = 1000;
    // Generated array columns keyed by element type name ("long", "bool",
    // "float", "string"), used to recompute expected answers per row.
    std::map> array_cols;
    int num_iters = 1;
    for (int iter = 0; iter < num_iters; ++iter) {
        auto raw_data = DataGen(schema, N, iter);
        auto new_long_array_col = raw_data.get_col(long_array_fid);
        auto new_bool_array_col = raw_data.get_col(bool_array_fid);
        auto new_float_array_col = raw_data.get_col(float_array_fid);
        auto new_string_array_col = raw_data.get_col(string_array_fid);
        array_cols["long"].insert(array_cols["long"].end(),
                                  new_long_array_col.begin(),
                                  new_long_array_col.end());
        array_cols["bool"].insert(array_cols["bool"].end(),
                                  new_bool_array_col.begin(),
                                  new_bool_array_col.end());
        array_cols["float"].insert(array_cols["float"].end(),
                                   new_float_array_col.begin(),
                                   new_float_array_col.end());
        array_cols["string"].insert(array_cols["string"].end(),
                                    new_string_array_col.begin(),
                                    new_string_array_col.end());
        seg->PreInsert(N);
        seg->Insert(iter * N,
                    N,
                    raw_data.row_ids_.data(),
                    raw_data.timestamps_.data(),
                    raw_data.raw_);
    }

    auto seg_promote = dynamic_cast(seg.get());
    ExecExprVisitor visitor(
        *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP);
    // Each case: a text-format term_expr clause, the array_cols key of the
    // column it targets, and a reference predicate. Clauses use field ids
    // 101-104 with element types Int64/Bool/Float/VarChar, matching the
    // "long"/"bool"/"float"/"string" keys (presumably the ids AddDebugField
    // assigned above -- confirm). A clause with no values is expected to
    // match no row.
    std::vector>> testcases = {
        {R"(term_expr: < column_info: < field_id: 101 data_type: Array nested_path:"0" element_type:Int64 > values: values: values: >)",
         "long",
         [](milvus::Array& array) {
             auto val = array.get_data(0);
             return val == 1 || val ==2 || val == 3;
         }},
        {R"(term_expr: < column_info: < field_id: 101 data_type: Array nested_path:"0" element_type:Int64 > >)",
         "long",
         [](milvus::Array& array) { return false; }},
        {R"(term_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Bool > values: values: >)",
         "bool",
         [](milvus::Array& array) {
             auto val = array.get_data(0);
             return !val;
         }},
        {R"(term_expr: < column_info: < field_id: 102 data_type: Array nested_path:"0" element_type:Bool > >)",
         "bool",
         [](milvus::Array& array) { return false; }},
        // NOTE(review): if `val` is a float, comparing it with the double
        // literals 1.23 / 124.31 may not agree with how the engine compares
        // float elements against the term values -- confirm.
        {R"(term_expr: < column_info: < field_id: 103 data_type: Array nested_path:"0" element_type:Float > values: values: >)",
         "float",
         [](milvus::Array& array) {
             auto val = array.get_data(0);
             return val == 1.23 || val == 124.31;
         }},
        {R"(term_expr: < column_info: < field_id: 103 data_type: Array nested_path:"0" element_type:Float > >)",
         "float",
         [](milvus::Array& array) { return false; }},
        {R"(term_expr: < column_info: < field_id: 104 data_type: Array nested_path:"0" element_type:VarChar > values: values: >)",
         "string",
         [](milvus::Array& array) {
             auto val = array.get_data(0);
             return val == "abc" || val == "idhgf1s";
         }},
        {R"(term_expr: < column_info: < field_id: 104 data_type: Array nested_path:"0" element_type:VarChar > >)",
         "string",
         [](milvus::Array& array) { return false; }},
        // Element 1024: rows whose array is too short are expected to be false.
        {R"(term_expr: < column_info: < field_id: 104 data_type: Array nested_path:"1024" element_type:VarChar > values: values: >)",
         "string",
         [](milvus::Array& array) {
             if (array.length() <= 1024) {
                 return false;
             }
             auto val = array.get_data(1024);
             return val == "abc" || val == "idhgf1s";
         }},
    };

    // Text-format search plan; the "@@@@" placeholder is replaced with each
    // case's clause before translation to a binary plan.
    std::string raw_plan_tmp = R"(vector_anns: < field_id: 100 predicates: < @@@@ > query_info: < topk: 10 round_decimal: 3 metric_type: "L2" search_params: "{\"nprobe\": 10}" > placeholder_tag: "$0" >)";
    for (auto [clause, array_type, ref_func] : testcases) {
        auto loc = raw_plan_tmp.find("@@@@");
        auto raw_plan = raw_plan_tmp;
        raw_plan.replace(loc, 4, clause);
        auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str());
        auto plan =
            CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
        // Only the plan's filter predicate is evaluated here.
        auto final = visitor.call_child(*plan->plan_node_->predicate_.value());
        EXPECT_EQ(final.size(), N * num_iters);
        for (int i = 0; i < N * num_iters; ++i) {
            auto ans = final[i];
            auto array = milvus::Array(array_cols[array_type][i]);
            ASSERT_EQ(ans, ref_func(array));
        }
    }
}

// Evaluates term predicates on an INT64 array field whose reference result
// is "some element of the array equals the term value".
TEST(Expr, TestTermInArray) {
    using namespace milvus;
    using namespace milvus::query;
    using namespace milvus::segcore;
    // Schema: int64 primary key "id" plus an INT64 array field.
    auto schema = std::make_shared();
    auto i64_fid = schema->AddDebugField("id", DataType::INT64);
    auto long_array_fid =
        schema->AddDebugField("long_array", DataType::ARRAY, DataType::INT64);
    schema->set_primary_field_id(i64_fid);
    auto seg = CreateGrowingSegment(schema, empty_index_meta);
    int N = 1000;
    // Generated array column keyed "long", used to recompute expected answers.
    std::map> array_cols;
    int num_iters = 1;
    for (int iter = 0; iter < num_iters; ++iter) {
        auto raw_data = DataGen(schema, N, iter);
        auto new_long_array_col = raw_data.get_col(long_array_fid);
        array_cols["long"].insert(array_cols["long"].end(),
                                  new_long_array_col.begin(),
                                  new_long_array_col.end());
        seg->PreInsert(N);
        seg->Insert(iter * N,
                    N,
                    raw_data.row_ids_.data(),
                    raw_data.timestamps_.data(),
                    raw_data.raw_);
    }

    auto seg_promote = dynamic_cast(seg.get());
    ExecExprVisitor visitor(
        *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP);

    // A term list, the nested path (empty in both cases below), and a
    // reference predicate for one row's array.
    struct TermTestCases {
        std::vector values;
        std::vector nested_path;
        std::function check_func;
    };
    // Each reference scans every element and returns true on the first one
    // equal to the term value.
    std::vector testcases = {
        {{100},
         {},
         [](milvus::Array& array) {
             for (int i = 0; i < array.length(); ++i) {
                 auto val = array.get_data(i);
                 if (val == 100) {
                     return true;
                 }
             }
             return false;
         }},
        {{1024},
         {},
         [](milvus::Array& array) {
             for (int i = 0; i < array.length(); ++i) {
                 auto val = array.get_data(i);
                 if (val == 1024) {
                     return true;
                 }
             }
             return false;
         }},
    };

    for (auto& testcase : testcases) {
        RetrievePlanNode plan;
        // The trailing `true` presumably requests "value in array" matching,
        // consistent with check_func scanning every element -- confirm
        // against the expression's constructor.
        plan.predicate_ = std::make_unique>(
            ColumnInfo(long_array_fid, DataType::ARRAY, testcase.nested_path),
            testcase.values,
            proto::plan::GenericValue::ValCase::kInt64Val,
            true);
        auto start = std::chrono::steady_clock::now();
        auto final = visitor.call_child(*plan.predicate_.value());
        // Prints the elapsed evaluation time to stdout; the unit is the
        // duration_cast target type.
        std::cout << "cost"
                  << std::chrono::duration_cast(
                         std::chrono::steady_clock::now() - start)
                         .count()
                  << std::endl;
        // One result per inserted row, each equal to the reference predicate.
        EXPECT_EQ(final.size(), N * num_iters);
        for (int i = 0; i < N * num_iters; ++i) {
            auto ans = final[i];
            auto array = milvus::Array(array_cols["long"][i]);
            ASSERT_EQ(ans, testcase.check_func(array));
        }
    }
}