From 140c5a0a75f76ecf9aa08841ac82536fde918107 Mon Sep 17 00:00:00 2001 From: Bingyi Sun Date: Mon, 20 Jan 2025 19:03:04 +0800 Subject: [PATCH] enhance: add unit test for string pk (#39329) https://github.com/milvus-io/milvus/issues/39107 --------- Signed-off-by: sunby --- .../core/unittest/test_chunked_segment.cpp | 87 ++++++++++++++----- 1 file changed, 64 insertions(+), 23 deletions(-) diff --git a/internal/core/unittest/test_chunked_segment.cpp b/internal/core/unittest/test_chunked_segment.cpp index bbfef50c04..55fe403f9b 100644 --- a/internal/core/unittest/test_chunked_segment.cpp +++ b/internal/core/unittest/test_chunked_segment.cpp @@ -21,6 +21,7 @@ #include "common/Schema.h" #include "common/Types.h" #include "expr/ITypeExpr.h" +#include "gtest/gtest.h" #include "index/IndexFactory.h" #include "index/IndexInfo.h" #include "index/Meta.h" @@ -160,13 +161,16 @@ TEST(test_chunk_segment, TestSearchOnSealed) { } } -class TestChunkSegment : public testing::Test { +class TestChunkSegment : public testing::TestWithParam { protected: void SetUp() override { + bool pk_is_string = GetParam(); auto schema = std::make_shared(); auto int64_fid = schema->AddDebugField("int64", DataType::INT64, true); - auto pk_fid = schema->AddDebugField("pk", DataType::INT64, true); + + auto pk_fid = schema->AddDebugField( + "pk", pk_is_string ? DataType::VARCHAR : DataType::INT64, true); auto str_fid = schema->AddDebugField("string1", DataType::VARCHAR, true); auto str2_fid = @@ -185,10 +189,11 @@ class TestChunkSegment : public testing::Test { test_data_count = 10000; auto arrow_i64_field = arrow::field("int64", arrow::int64()); - auto arrow_pk_field = arrow::field("pk", arrow::int64()); + auto arrow_pk_field = + arrow::field("pk", pk_is_string ? arrow::utf8() : arrow::int64()); auto arrow_ts_field = arrow::field("ts", arrow::int64()); - auto arrow_str_field = arrow::field("string1", arrow::int64()); - auto arrow_str2_field = arrow::field("string2", arrow::int64()); + auto arrow_str_field = arrow::field("string1", arrow::utf8()); + auto arrow_str2_field = arrow::field("string2", arrow::utf8()); std::vector> arrow_fields = { arrow_i64_field, arrow_pk_field, @@ -204,7 +209,7 @@ class TestChunkSegment : public testing::Test { {"string1", str_fid}, {"string2", str2_fid}}; - int start_id = 1; + int start_id = 0; chunk_num = 2; std::vector field_infos; @@ -215,6 +220,12 @@ class TestChunkSegment : public testing::Test { field_infos.push_back(field_info); } + std::vector str_data; + for (int i = 0; i < test_data_count * chunk_num; i++) { + str_data.push_back("test" + std::to_string(i)); + } + std::sort(str_data.begin(), str_data.end()); + // generate data for (int chunk_id = 0; chunk_id < chunk_num; chunk_id++, start_id += test_data_count) { @@ -232,7 +243,7 @@ class TestChunkSegment : public testing::Test { auto str_builder = std::make_shared(); for (int i = 0; i < test_data_count; i++) { - auto status = str_builder->Append("test" + std::to_string(i)); + auto status = str_builder->Append(str_data[start_id + i]); ASSERT_TRUE(status.ok()); } std::shared_ptr arrow_str; @@ -245,7 +256,9 @@ class TestChunkSegment : public testing::Test { auto arrow_schema = std::make_shared(arrow::FieldVector(1, f)); - auto col = i < 3 ? arrow_int64 : arrow_str; + auto col = i < 3 && (field_ids[i] != pk_fid || !pk_is_string) + ? arrow_int64 + : arrow_str; auto record_batch = arrow::RecordBatch::Make( arrow_schema, arrow_int64->length(), {col}); @@ -272,7 +285,10 @@ class TestChunkSegment : public testing::Test { std::unordered_map fields; }; -TEST_F(TestChunkSegment, TestTermExpr) { +INSTANTIATE_TEST_SUITE_P(TestChunkSegment, TestChunkSegment, testing::Bool()); + +TEST_P(TestChunkSegment, TestTermExpr) { + bool pk_is_string = GetParam(); // query int64 expr std::vector filter_data; for (int i = 1; i <= 10; ++i) { @@ -289,9 +305,17 @@ TEST_F(TestChunkSegment, TestTermExpr) { plan, segment.get(), chunk_num * test_data_count, MAX_TIMESTAMP); ASSERT_EQ(10, final.count()); + std::vector filter_str_data; + for (int i = 1; i <= 10; ++i) { + proto::plan::GenericValue v; + v.set_string_val("test" + std::to_string(i)); + filter_str_data.push_back(v); + } // query pk expr auto pk_term_filter_expr = std::make_shared( - expr::ColumnInfo(fields.at("pk"), DataType::INT64), filter_data); + expr::ColumnInfo(fields.at("pk"), + pk_is_string ? DataType::VARCHAR : DataType::INT64), + pk_is_string ? filter_str_data : filter_data); plan = std::make_shared(DEFAULT_PLANNODE_ID, pk_term_filter_expr); final = query::ExecuteQueryExpr( @@ -301,10 +325,17 @@ TEST_F(TestChunkSegment, TestTermExpr) { // query pk in second chunk std::vector filter_data2; proto::plan::GenericValue v; - v.set_int64_val(test_data_count + 1); + if (pk_is_string) { + v.set_string_val("test" + std::to_string(test_data_count + 1)); + } else { + v.set_int64_val(test_data_count + 1); + } filter_data2.push_back(v); + pk_term_filter_expr = std::make_shared( - expr::ColumnInfo(fields.at("pk"), DataType::INT64), filter_data2); + expr::ColumnInfo(fields.at("pk"), + pk_is_string ? DataType::VARCHAR : DataType::INT64), + filter_data2); plan = std::make_shared(DEFAULT_PLANNODE_ID, pk_term_filter_expr); final = query::ExecuteQueryExpr( @@ -312,12 +343,16 @@ TEST_F(TestChunkSegment, TestTermExpr) { ASSERT_EQ(1, final.count()); } -TEST_F(TestChunkSegment, TestCompareExpr) { - auto expr = std::make_shared(fields.at("int64"), - fields.at("pk"), - DataType::INT64, - DataType::INT64, - proto::plan::OpType::Equal); +TEST_P(TestChunkSegment, TestCompareExpr) { + srand(time(NULL)); + bool pk_is_string = GetParam(); + DataType pk_data_type = pk_is_string ? DataType::VARCHAR : DataType::INT64; + auto expr = std::make_shared( + pk_is_string ? fields.at("string1") : fields.at("int64"), + fields.at("pk"), + pk_data_type, + pk_data_type, + proto::plan::OpType::Equal); auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); BitsetType final = query::ExecuteQueryExpr( @@ -341,6 +376,11 @@ TEST_F(TestChunkSegment, TestCompareExpr) { milvus::proto::schema::Int64); file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(fid.get()); file_manager_ctx.fieldDataMeta.field_id = fid.get(); + milvus::storage::IndexMeta index_meta; + index_meta.field_id = fid.get(); + index_meta.build_id = rand(); + index_meta.index_version = rand(); + file_manager_ctx.indexMeta = index_meta; index::CreateIndexInfo create_index_info; create_index_info.field_type = DataType::INT64; create_index_info.index_type = index::INVERTED_INDEX_TYPE; @@ -360,11 +400,12 @@ TEST_F(TestChunkSegment, TestCompareExpr) { load_index_info.field_id = fid.get(); segment->LoadIndex(load_index_info); - expr = std::make_shared(fields.at("int64"), - fields.at("pk"), - DataType::INT64, - DataType::INT64, - proto::plan::OpType::Equal); + expr = std::make_shared( + pk_is_string ? fields.at("string1") : fields.at("int64"), + fields.at("pk"), + pk_data_type, + pk_data_type, + proto::plan::OpType::Equal); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); final = query::ExecuteQueryExpr( plan, segment.get(), chunk_num * test_data_count, MAX_TIMESTAMP);