fix: mask with valid data when preCheckOverflow (#37221)

#37175

---------

Signed-off-by: lixinguo <xinguo.li@zilliz.com>
Co-authored-by: lixinguo <xinguo.li@zilliz.com>
pull/37327/head
smellthemoon 2024-10-31 10:44:26 +08:00 committed by GitHub
parent 2092dc0ba1
commit b8492498ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 125 additions and 38 deletions

View File

@ -147,15 +147,9 @@ PhyBinaryRangeFilterExpr::PreCheckOverflow(HighPrecisionType& val1,
? active_count_ - overflow_check_pos_
: batch_size_;
overflow_check_pos_ += batch_size;
if (cached_overflow_res_ != nullptr &&
cached_overflow_res_->size() == batch_size) {
return cached_overflow_res_;
}
auto valid_res = ProcessChunksForValid<T>(is_index_mode_);
auto res_vec = std::make_shared<ColumnVector>(TargetBitmap(batch_size),
std::move(valid_res));
cached_overflow_res_ = res_vec;
return res_vec;
};

View File

@ -235,7 +235,6 @@ class PhyBinaryRangeFilterExpr : public SegmentExpr {
private:
std::shared_ptr<const milvus::expr::BinaryRangeFilterExpr> expr_;
ColumnVectorPtr cached_overflow_res_{nullptr};
int64_t overflow_check_pos_{0};
};
} //namespace exec

View File

@ -503,7 +503,7 @@ class SegmentExpr : public Expr {
template <typename T>
TargetBitmap
ProcessDataChunksForValid() {
TargetBitmap valid_result(batch_size_);
TargetBitmap valid_result(GetNextBatchSize());
valid_result.set();
int64_t processed_size = 0;
for (size_t i = current_data_chunk_; i < num_data_chunk_; i++) {

View File

@ -754,57 +754,37 @@ PhyUnaryRangeFilterExpr::PreCheckOverflow() {
? active_count_ - overflow_check_pos_
: batch_size_;
overflow_check_pos_ += batch_size;
if (cached_overflow_res_ != nullptr &&
cached_overflow_res_->size() == batch_size) {
return cached_overflow_res_;
}
auto valid = ProcessChunksForValid<T>(CanUseIndex<T>());
auto res_vec = std::make_shared<ColumnVector>(
TargetBitmap(batch_size), std::move(valid));
TargetBitmapView res(res_vec->GetRawData(), batch_size);
TargetBitmapView valid_res(res_vec->GetValidRawData(), batch_size);
switch (expr_->op_type_) {
case proto::plan::GreaterThan:
case proto::plan::GreaterEqual: {
auto valid_res = ProcessChunksForValid<T>(CanUseIndex<T>());
auto res_vec = std::make_shared<ColumnVector>(
TargetBitmap(batch_size), std::move(valid_res));
TargetBitmapView res(res_vec->GetRawData(), batch_size);
cached_overflow_res_ = res_vec;
if (milvus::query::lt_lb<T>(val)) {
res.set();
res &= valid_res;
return res_vec;
}
return res_vec;
}
case proto::plan::LessThan:
case proto::plan::LessEqual: {
auto valid_res = ProcessChunksForValid<T>(CanUseIndex<T>());
auto res_vec = std::make_shared<ColumnVector>(
TargetBitmap(batch_size), std::move(valid_res));
TargetBitmapView res(res_vec->GetRawData(), batch_size);
cached_overflow_res_ = res_vec;
if (milvus::query::gt_ub<T>(val)) {
res.set();
res &= valid_res;
return res_vec;
}
return res_vec;
}
case proto::plan::Equal: {
auto valid_res = ProcessChunksForValid<T>(CanUseIndex<T>());
auto res_vec = std::make_shared<ColumnVector>(
TargetBitmap(batch_size), std::move(valid_res));
TargetBitmapView res(res_vec->GetRawData(), batch_size);
cached_overflow_res_ = res_vec;
res.reset();
return res_vec;
}
case proto::plan::NotEqual: {
auto valid_res = ProcessChunksForValid<T>(CanUseIndex<T>());
auto res_vec = std::make_shared<ColumnVector>(
TargetBitmap(batch_size), std::move(valid_res));
TargetBitmapView res(res_vec->GetRawData(), batch_size);
cached_overflow_res_ = res_vec;
res.set();
res &= valid_res;
return res_vec;
}
default: {

View File

@ -346,7 +346,6 @@ class PhyUnaryRangeFilterExpr : public SegmentExpr {
private:
std::shared_ptr<const milvus::expr::UnaryRangeFilterExpr> expr_;
ColumnVectorPtr cached_overflow_res_{nullptr};
int64_t overflow_check_pos_{0};
};
} // namespace exec

View File

@ -560,6 +560,117 @@ TEST_P(ExprTest, TestRangeNullable) {
}
return v != 2000;
}},
{R"(binary_range_expr: <
column_info: <
field_id: 103
data_type: Int8
>
lower_inclusive: false,
upper_inclusive: false,
lower_value: <
int64_val: 1000000
>
upper_value: <
int64_val: 1000001
>
>)",
[](int v, bool valid) { return false; }},
{R"(binary_range_expr: <
column_info: <
field_id: 103
data_type: Int8
>
lower_inclusive: false,
upper_inclusive: false,
lower_value: <
int64_val: -1000001
>
upper_value: <
int64_val: -1000000
>
>)",
[](int v, bool valid) { return false; }},
{R"(unary_range_expr: <
column_info: <
field_id: 103
data_type: Int8
>
op: GreaterEqual,
value: <
int64_val: 1000000
>
>)",
[](int v, bool valid) { return false; }},
{R"(unary_range_expr: <
column_info: <
field_id: 103
data_type: Int8
>
op: GreaterEqual,
value: <
int64_val: -1000000
>
>)",
[](int v, bool valid) {
if (!valid) {
return false;
}
return true;
}},
{R"(unary_range_expr: <
column_info: <
field_id: 103
data_type: Int8
>
op: LessEqual,
value: <
int64_val: 1000000
>
>)",
[](int v, bool valid) {
if (!valid) {
return false;
}
return true;
}},
{R"(unary_range_expr: <
column_info: <
field_id: 103
data_type: Int8
>
op: LessThan,
value: <
int64_val: -1000000
>
>)",
[](int v, bool valid) { return false; }},
{R"(unary_range_expr: <
column_info: <
field_id: 103
data_type: Int8
>
op: Equal,
value: <
int64_val: 1000000
>
>)",
[](int v, bool valid) { return false; }},
{R"(unary_range_expr: <
column_info: <
field_id: 103
data_type: Int8
>
op: NotEqual,
value: <
int64_val: 1000000
>
>)",
[](int v, bool valid) {
if (!valid) {
return false;
}
return true;
}},
};
std::string raw_plan_tmp = R"(vector_anns: <
@ -582,6 +693,9 @@ TEST_P(ExprTest, TestRangeNullable) {
auto nullable_fid =
schema->AddDebugField("nullable", DataType::INT64, true);
auto nullable_fid_pre_check =
schema->AddDebugField("pre_check", DataType::INT8, true);
auto seg = CreateGrowingSegment(schema, empty_index_meta);
int N = 1000;
std::vector<int> data_col;
@ -625,7 +739,8 @@ TEST_P(ExprTest, TestRangeNullable) {
auto val = data_col[i];
auto valid_data = valid_data_col[i];
auto ref = ref_func(val, valid_data);
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
ASSERT_EQ(ans, ref)
<< clause << "@" << i << "!!" << val << "!!" << valid_data;
}
}
}