enhance: skip segment when using pk in (..) expr (#29394)

#29293

Signed-off-by: luzhang <luzhang@zilliz.com>
Co-authored-by: luzhang <luzhang@zilliz.com>
pull/29297/head
zhagnlu 2023-12-21 20:06:42 +08:00 committed by GitHub
parent ad71b9aeb2
commit a6eb7e5f9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 107 additions and 3 deletions

View File

@ -117,11 +117,35 @@ PhyTermFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
}
}
// Decides whether a primary-key `in (...)` term filter can skip this segment.
//
// Computes the smallest and largest value in the term list and asks the
// segment's skip index whether the closed range [min, max] can be skipped for
// field_id_. When it can, the cached result is initialised as "no hit":
// cached_bits_ is sized to num_rows_ with every bit false, cached_offsets_
// becomes an empty INT64 column, and cached_offsets_inited_ is set.
//
// T is the C++ value type of the primary key.
// Returns true when the segment was skipped and the cache has been filled,
// false when the caller has to run the regular primary-key lookup.
template <typename T>
bool
PhyTermFilterExpr::CanSkipSegment() {
    // The skip index is only consulted for sealed segments, so bail out
    // before scanning the term list for any other segment type.
    if (segment_->type() != SegmentType::Sealed) {
        return false;
    }

    const auto& terms = expr_->vals_;
    const auto term_count = terms.size();
    // An empty term list has no minimum or maximum. Leaving the bounds
    // uninitialized would hand indeterminate values to the skip index, so
    // fall back to the regular lookup instead.
    if (term_count == 0) {
        return false;
    }

    // Seed both bounds from the first term, then widen with the rest.
    T lower = GetValueFromProto<T>(terms[0]);
    T upper = lower;
    for (auto i = decltype(term_count){1}; i < term_count; ++i) {
        auto val = GetValueFromProto<T>(terms[i]);
        lower = std::min(lower, val);
        upper = std::max(upper, val);
    }

    // using skip index to help skipping this segment
    const auto& skip_index = segment_->GetSkipIndex();
    if (skip_index.CanSkipBinaryRange<T>(
            field_id_, 0, lower, upper, true, true)) {
        cached_bits_.resize(num_rows_, false);
        cached_offsets_ = std::make_shared<ColumnVector>(DataType::INT64, 0);
        cached_offsets_inited_ = true;
        return true;
    }
    return false;
}
void
PhyTermFilterExpr::InitPkCacheOffset() {
auto id_array = std::make_unique<IdArray>();
switch (pk_type_) {
case DataType::INT64: {
if (CanSkipSegment<int64_t>()) {
return;
}
auto dst_ids = id_array->mutable_int_id();
for (const auto& id : expr_->vals_) {
dst_ids->add_data(GetValueFromProto<int64_t>(id));
@ -129,6 +153,9 @@ PhyTermFilterExpr::InitPkCacheOffset() {
break;
}
case DataType::VARCHAR: {
if (CanSkipSegment<std::string>()) {
return;
}
auto dst_ids = id_array->mutable_str_id();
for (const auto& id : expr_->vals_) {
dst_ids->add_data(GetValueFromProto<std::string>(id));
@ -142,7 +169,7 @@ PhyTermFilterExpr::InitPkCacheOffset() {
auto [uids, seg_offsets] =
segment_->search_ids(*id_array, query_timestamp_);
cached_bits_.resize(num_rows_);
cached_bits_.resize(num_rows_, false);
cached_offsets_ =
std::make_shared<ColumnVector>(DataType::INT64, seg_offsets.size());
int64_t* cached_offsets_ptr = (int64_t*)cached_offsets_->GetRawData();
@ -164,7 +191,6 @@ PhyTermFilterExpr::ExecPkTermImpl() {
auto real_batch_size = current_data_chunk_pos_ + batch_size_ >= num_rows_
? num_rows_ - current_data_chunk_pos_
: batch_size_;
current_data_chunk_pos_ += real_batch_size;
if (real_batch_size == 0) {
return nullptr;
@ -175,7 +201,7 @@ PhyTermFilterExpr::ExecPkTermImpl() {
bool* res = (bool*)res_vec->GetRawData();
for (size_t i = 0; i < real_batch_size; ++i) {
res[i] = cached_bits_[i];
res[i] = cached_bits_[current_data_chunk_pos_++];
}
std::vector<VectorPtr> vecs{res_vec, cached_offsets_};

View File

@ -84,6 +84,10 @@ class PhyTermFilterExpr : public SegmentExpr {
// Resolves the primary-key term values to segment offsets and stores the
// outcome in cached_bits_ / cached_offsets_ for later batches.
void
InitPkCacheOffset();
// Asks the segment's skip index whether the [min, max] range of the term
// values can be skipped for this field. Returns true when the segment is
// skipped, in which case the cached bits/offsets are set up as an empty hit
// set. T is the C++ value type of the primary key.
template <typename T>
bool
CanSkipSegment();
// Produces the next batch of the primary-key term filter from cached_bits_,
// packaged together with cached_offsets_; yields nullptr once no rows remain.
VectorPtr
ExecPkTermImpl();

View File

@ -111,6 +111,13 @@ ExecPlanNodeVisitor::ExecuteExprNodeInternal(
// offset cache only get once because not support iterator batch
auto cache_offset_vec =
std::dynamic_pointer_cast<ColumnVector>(row->child(1));
// If the cached offsets are empty, no record in this segment was hit,
// so there is no need to fetch the next batch.
if (cache_offset_vec->size() == 0) {
auto active_count = segment->get_active_count(timestamp_);
bitset_holder.resize(active_count);
break;
}
auto cache_offset_vec_ptr =
(int64_t*)(cache_offset_vec->GetRawData());
for (size_t i = 0; i < cache_offset_vec->size(); ++i) {

View File

@ -1730,6 +1730,73 @@ TEST(Expr, TestExprs) {
// test_case(500);
}
// Exercises the primary-key `in (...)` term filter on a sealed segment:
// first with ten keys expected to hit the first ten rows, then with ten keys
// (N .. N+9) expected to hit nothing.
TEST(Expr, test_term_pk) {
    using namespace milvus;
    using namespace milvus::query;
    using namespace milvus::segcore;

    // Schema: a timestamp column, a float vector, a varchar and an int64
    // column; the int64 column is the primary key.
    auto schema = std::make_shared<Schema>();
    schema->AddField(FieldName("Timestamp"), FieldId(1), DataType::INT64);
    auto vector_fid = schema->AddDebugField(
        "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
    auto varchar_fid = schema->AddDebugField("string1", DataType::VARCHAR);
    auto pk_fid = schema->AddDebugField("int64", DataType::INT64);
    schema->set_primary_field_id(pk_fid);

    auto seg = CreateSealedSegment(schema);
    const int row_count = 100000;
    auto generated = DataGen(schema, row_count);

    // Load every generated column into the sealed segment.
    auto field_metas = schema->get_fields();
    for (auto column : generated.raw_->fields_data()) {
        const int64_t raw_id = column.field_id();
        auto load_info = FieldDataInfo(raw_id, row_count, "/tmp/a");
        const auto& meta = field_metas.at(FieldId(raw_id));
        load_info.channel->push(
            CreateFieldDataFromDataArray(row_count, &column, meta));
        load_info.channel->close();
        seg->LoadFieldData(FieldId(raw_id), load_info);
    }

    // Builds ten consecutive int64 term values starting at `first`.
    auto make_pk_terms = [](int first) {
        std::vector<proto::plan::GenericValue> terms;
        for (int k = 0; k < 10; ++k) {
            proto::plan::GenericValue term;
            term.set_int64_val(first + k);
            terms.push_back(term);
        }
        return terms;
    };

    query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP);
    BitsetType hits;

    // Runs a PK term filter over `terms`, writing the result into `hits`.
    auto run_term_filter =
        [&](const std::vector<proto::plan::GenericValue>& terms) {
            auto term_expr = std::make_shared<expr::TermFilterExpr>(
                expr::ColumnInfo(pk_fid, DataType::INT64), terms);
            auto filter_plan = std::make_shared<plan::FilterBitsNode>(
                DEFAULT_PLANNODE_ID, term_expr);
            visitor.ExecuteExprNode(filter_plan, seg.get(), hits);
        };

    // Keys 0..9: only the first ten rows are expected to match.
    run_term_filter(make_pk_terms(0));
    EXPECT_EQ(hits.size(), row_count);
    for (int row = 0; row < row_count; ++row) {
        if (row < 10) {
            EXPECT_EQ(hits[row], true);
        } else {
            EXPECT_EQ(hits[row], false);
        }
    }

    // Keys N..N+9: no row is expected to match.
    run_term_filter(make_pk_terms(row_count));
    EXPECT_EQ(hits.size(), row_count);
    for (int row = 0; row < row_count; ++row) {
        EXPECT_EQ(hits[row], false);
    }
}
TEST(Expr, TestSealedSegmentGetBatchSize) {
using namespace milvus;
using namespace milvus::query;