enhance: optimize marisa trie range search for performance (#30079)

#30078
#29986

Signed-off-by: luzhang <luzhang@zilliz.com>
Co-authored-by: luzhang <luzhang@zilliz.com>
pull/30261/head
zhagnlu 2024-01-25 10:07:00 +08:00 committed by GitHub
parent ba862ef91d
commit 8c58d9af67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 107 additions and 37 deletions

View File

@ -205,4 +205,12 @@ Join(const std::vector<T>& items, const std::string& delimiter) {
return ss.str();
}
inline std::string
GetCommonPrefix(const std::string& str1, const std::string& str2) {
size_t len = std::min(str1.length(), str2.length());
size_t i = 0;
while (i < len && str1[i] == str2[i]) ++i;
return str1.substr(0, i);
}
} // namespace milvus

View File

@ -389,31 +389,64 @@ const TargetBitmap
StringIndexMarisa::Range(std::string value, OpType op) {
auto count = Count();
TargetBitmap bitset(count);
std::vector<size_t> ids;
marisa::Agent agent;
for (size_t offset = 0; offset < count; ++offset) {
agent.set_query(str_ids_[offset]);
trie_.reverse_lookup(agent);
std::string raw_data(agent.key().ptr(), agent.key().length());
bool set = false;
switch (op) {
case OpType::LessThan:
set = raw_data.compare(value) < 0;
break;
case OpType::LessEqual:
set = raw_data.compare(value) <= 0;
break;
case OpType::GreaterThan:
set = raw_data.compare(value) > 0;
break;
case OpType::GreaterEqual:
set = raw_data.compare(value) >= 0;
break;
default:
throw SegcoreError(OpTypeInvalid,
fmt::format("Invalid OperatorType: {}",
static_cast<int>(op)));
switch (op) {
case OpType::GreaterThan: {
while (trie_.predictive_search(agent)) {
auto key = std::string(agent.key().ptr(), agent.key().length());
if (key > value) {
ids.push_back(agent.key().id());
break;
}
};
while (trie_.predictive_search(agent)) {
ids.push_back(agent.key().id());
}
break;
}
if (set) {
case OpType::GreaterEqual: {
while (trie_.predictive_search(agent)) {
auto key = std::string(agent.key().ptr(), agent.key().length());
if (key >= value) {
ids.push_back(agent.key().id());
break;
}
}
while (trie_.predictive_search(agent)) {
ids.push_back(agent.key().id());
}
break;
}
case OpType::LessThan: {
while (trie_.predictive_search(agent)) {
auto key = std::string(agent.key().ptr(), agent.key().length());
if (key >= value) {
break;
}
ids.push_back(agent.key().id());
}
break;
}
case OpType::LessEqual: {
while (trie_.predictive_search(agent)) {
auto key = std::string(agent.key().ptr(), agent.key().length());
if (key > value) {
break;
}
ids.push_back(agent.key().id());
}
break;
}
default:
throw SegcoreError(
OpTypeInvalid,
fmt::format("Invalid OperatorType: {}", static_cast<int>(op)));
}
for (const auto str_id : ids) {
auto offsets = str_ids_to_offsets_[str_id];
for (auto offset : offsets) {
bitset[offset] = true;
}
}
@ -432,26 +465,38 @@ StringIndexMarisa::Range(std::string lower_bound_value,
!(lb_inclusive && ub_inclusive))) {
return bitset;
}
auto common_prefix = GetCommonPrefix(lower_bound_value, upper_bound_value);
marisa::Agent agent;
for (size_t offset = 0; offset < count; ++offset) {
agent.set_query(str_ids_[offset]);
trie_.reverse_lookup(agent);
std::string raw_data(agent.key().ptr(), agent.key().length());
bool set = true;
if (lb_inclusive) {
set &= raw_data.compare(lower_bound_value) >= 0;
} else {
set &= raw_data.compare(lower_bound_value) > 0;
agent.set_query(common_prefix.c_str());
std::vector<size_t> ids;
while (trie_.predictive_search(agent)) {
std::string_view val =
std::string_view(agent.key().ptr(), agent.key().length());
if (val > upper_bound_value ||
(!ub_inclusive && val == upper_bound_value)) {
break;
}
if (ub_inclusive) {
set &= raw_data.compare(upper_bound_value) <= 0;
} else {
set &= raw_data.compare(upper_bound_value) < 0;
if (val < lower_bound_value ||
(!lb_inclusive && val == lower_bound_value)) {
continue;
}
if (set) {
if (((lb_inclusive && lower_bound_value <= val) ||
(!lb_inclusive && lower_bound_value < val)) &&
((ub_inclusive && val <= upper_bound_value) ||
(!ub_inclusive && val < upper_bound_value))) {
ids.push_back(agent.key().id());
}
}
for (const auto str_id : ids) {
auto offsets = str_ids_to_offsets_[str_id];
for (auto offset : offsets) {
bitset[offset] = true;
}
}
return bitset;
}

View File

@ -190,3 +190,20 @@ TEST(Util, read_from_fd) {
tmp_file.fd, read_buf.get(), data_size * max_loop, INT_MAX),
milvus::SegcoreError);
}
TEST(Util, get_common_prefix) {
std::string str1 = "";
std::string str2 = "milvus";
auto common_prefix = milvus::GetCommonPrefix(str1, str2);
EXPECT_STREQ(common_prefix.c_str(), "");
str1 = "milvus";
str2 = "milvus is great";
common_prefix = milvus::GetCommonPrefix(str1, str2);
EXPECT_STREQ(common_prefix.c_str(), "milvus");
str1 = "milvus";
str2 = "";
common_prefix = milvus::GetCommonPrefix(str1, str2);
EXPECT_STREQ(common_prefix.c_str(), "");
}