mirror of https://github.com/milvus-io/milvus.git
enhance: optimize marisa trie range search for performance (#30079)
#30078 #29986 Signed-off-by: luzhang <luzhang@zilliz.com> Co-authored-by: luzhang <luzhang@zilliz.com>pull/30261/head
parent
ba862ef91d
commit
8c58d9af67
|
@ -205,4 +205,12 @@ Join(const std::vector<T>& items, const std::string& delimiter) {
|
|||
return ss.str();
|
||||
}
|
||||
|
||||
inline std::string
|
||||
GetCommonPrefix(const std::string& str1, const std::string& str2) {
|
||||
size_t len = std::min(str1.length(), str2.length());
|
||||
size_t i = 0;
|
||||
while (i < len && str1[i] == str2[i]) ++i;
|
||||
return str1.substr(0, i);
|
||||
}
|
||||
|
||||
} // namespace milvus
|
||||
|
|
|
@ -389,31 +389,64 @@ const TargetBitmap
|
|||
StringIndexMarisa::Range(std::string value, OpType op) {
|
||||
auto count = Count();
|
||||
TargetBitmap bitset(count);
|
||||
std::vector<size_t> ids;
|
||||
marisa::Agent agent;
|
||||
for (size_t offset = 0; offset < count; ++offset) {
|
||||
agent.set_query(str_ids_[offset]);
|
||||
trie_.reverse_lookup(agent);
|
||||
std::string raw_data(agent.key().ptr(), agent.key().length());
|
||||
bool set = false;
|
||||
switch (op) {
|
||||
case OpType::LessThan:
|
||||
set = raw_data.compare(value) < 0;
|
||||
break;
|
||||
case OpType::LessEqual:
|
||||
set = raw_data.compare(value) <= 0;
|
||||
break;
|
||||
case OpType::GreaterThan:
|
||||
set = raw_data.compare(value) > 0;
|
||||
break;
|
||||
case OpType::GreaterEqual:
|
||||
set = raw_data.compare(value) >= 0;
|
||||
break;
|
||||
default:
|
||||
throw SegcoreError(OpTypeInvalid,
|
||||
fmt::format("Invalid OperatorType: {}",
|
||||
static_cast<int>(op)));
|
||||
switch (op) {
|
||||
case OpType::GreaterThan: {
|
||||
while (trie_.predictive_search(agent)) {
|
||||
auto key = std::string(agent.key().ptr(), agent.key().length());
|
||||
if (key > value) {
|
||||
ids.push_back(agent.key().id());
|
||||
break;
|
||||
}
|
||||
};
|
||||
while (trie_.predictive_search(agent)) {
|
||||
ids.push_back(agent.key().id());
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (set) {
|
||||
case OpType::GreaterEqual: {
|
||||
while (trie_.predictive_search(agent)) {
|
||||
auto key = std::string(agent.key().ptr(), agent.key().length());
|
||||
if (key >= value) {
|
||||
ids.push_back(agent.key().id());
|
||||
break;
|
||||
}
|
||||
}
|
||||
while (trie_.predictive_search(agent)) {
|
||||
ids.push_back(agent.key().id());
|
||||
}
|
||||
break;
|
||||
}
|
||||
case OpType::LessThan: {
|
||||
while (trie_.predictive_search(agent)) {
|
||||
auto key = std::string(agent.key().ptr(), agent.key().length());
|
||||
if (key >= value) {
|
||||
break;
|
||||
}
|
||||
ids.push_back(agent.key().id());
|
||||
}
|
||||
break;
|
||||
}
|
||||
case OpType::LessEqual: {
|
||||
while (trie_.predictive_search(agent)) {
|
||||
auto key = std::string(agent.key().ptr(), agent.key().length());
|
||||
if (key > value) {
|
||||
break;
|
||||
}
|
||||
ids.push_back(agent.key().id());
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw SegcoreError(
|
||||
OpTypeInvalid,
|
||||
fmt::format("Invalid OperatorType: {}", static_cast<int>(op)));
|
||||
}
|
||||
|
||||
for (const auto str_id : ids) {
|
||||
auto offsets = str_ids_to_offsets_[str_id];
|
||||
for (auto offset : offsets) {
|
||||
bitset[offset] = true;
|
||||
}
|
||||
}
|
||||
|
@ -432,26 +465,38 @@ StringIndexMarisa::Range(std::string lower_bound_value,
|
|||
!(lb_inclusive && ub_inclusive))) {
|
||||
return bitset;
|
||||
}
|
||||
|
||||
auto common_prefix = GetCommonPrefix(lower_bound_value, upper_bound_value);
|
||||
marisa::Agent agent;
|
||||
for (size_t offset = 0; offset < count; ++offset) {
|
||||
agent.set_query(str_ids_[offset]);
|
||||
trie_.reverse_lookup(agent);
|
||||
std::string raw_data(agent.key().ptr(), agent.key().length());
|
||||
bool set = true;
|
||||
if (lb_inclusive) {
|
||||
set &= raw_data.compare(lower_bound_value) >= 0;
|
||||
} else {
|
||||
set &= raw_data.compare(lower_bound_value) > 0;
|
||||
agent.set_query(common_prefix.c_str());
|
||||
std::vector<size_t> ids;
|
||||
while (trie_.predictive_search(agent)) {
|
||||
std::string_view val =
|
||||
std::string_view(agent.key().ptr(), agent.key().length());
|
||||
if (val > upper_bound_value ||
|
||||
(!ub_inclusive && val == upper_bound_value)) {
|
||||
break;
|
||||
}
|
||||
if (ub_inclusive) {
|
||||
set &= raw_data.compare(upper_bound_value) <= 0;
|
||||
} else {
|
||||
set &= raw_data.compare(upper_bound_value) < 0;
|
||||
|
||||
if (val < lower_bound_value ||
|
||||
(!lb_inclusive && val == lower_bound_value)) {
|
||||
continue;
|
||||
}
|
||||
if (set) {
|
||||
|
||||
if (((lb_inclusive && lower_bound_value <= val) ||
|
||||
(!lb_inclusive && lower_bound_value < val)) &&
|
||||
((ub_inclusive && val <= upper_bound_value) ||
|
||||
(!ub_inclusive && val < upper_bound_value))) {
|
||||
ids.push_back(agent.key().id());
|
||||
}
|
||||
}
|
||||
for (const auto str_id : ids) {
|
||||
auto offsets = str_ids_to_offsets_[str_id];
|
||||
for (auto offset : offsets) {
|
||||
bitset[offset] = true;
|
||||
}
|
||||
}
|
||||
|
||||
return bitset;
|
||||
}
|
||||
|
||||
|
|
|
@ -190,3 +190,20 @@ TEST(Util, read_from_fd) {
|
|||
tmp_file.fd, read_buf.get(), data_size * max_loop, INT_MAX),
|
||||
milvus::SegcoreError);
|
||||
}
|
||||
|
||||
TEST(Util, get_common_prefix) {
|
||||
std::string str1 = "";
|
||||
std::string str2 = "milvus";
|
||||
auto common_prefix = milvus::GetCommonPrefix(str1, str2);
|
||||
EXPECT_STREQ(common_prefix.c_str(), "");
|
||||
|
||||
str1 = "milvus";
|
||||
str2 = "milvus is great";
|
||||
common_prefix = milvus::GetCommonPrefix(str1, str2);
|
||||
EXPECT_STREQ(common_prefix.c_str(), "milvus");
|
||||
|
||||
str1 = "milvus";
|
||||
str2 = "";
|
||||
common_prefix = milvus::GetCommonPrefix(str1, str2);
|
||||
EXPECT_STREQ(common_prefix.c_str(), "");
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue