fix: Make legacy non-lexicographic branch break switch (#36125)

Related to #35941
Previous PR: #36034

This patch corrects the switch branching logic and makes the unit
test work for cases that do not select the whole dataset.

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/36144/head
congqixia 2024-09-10 10:15:07 +08:00 committed by GitHub
parent c0c12c6c5b
commit 851f3b9883
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 40 additions and 11 deletions

View File

@ -374,7 +374,6 @@ StringIndexMarisa::Range(std::string value, OpType op) {
}
ids.push_back(agent.key().id());
}
break;
} else {
// lexicographic order is not guaranteed, check all values
while (trie_.predictive_search(agent)) {
@ -385,6 +384,7 @@ StringIndexMarisa::Range(std::string value, OpType op) {
}
};
}
break;
}
case OpType::LessEqual: {
if (in_lexico_order) {
@ -396,7 +396,6 @@ StringIndexMarisa::Range(std::string value, OpType op) {
}
ids.push_back(agent.key().id());
}
break;
} else {
// lexicographic order is not guaranteed, check all values
while (trie_.predictive_search(agent)) {
@ -407,6 +406,7 @@ StringIndexMarisa::Range(std::string value, OpType op) {
}
};
}
break;
}
default:
PanicInfo(

View File

@ -21,6 +21,7 @@
#include "test_utils/indexbuilder_test_utils.h"
#include "test_utils/AssertUtils.h"
#include <boost/filesystem.hpp>
#include <numeric>
#include "test_utils/storage_test_utils.h"
constexpr int64_t nb = 100;
@ -83,39 +84,67 @@ TEST_F(StringIndexMarisaTest, NotIn) {
// Checks StringIndexMarisa::Range for one-sided (OpType) and two-sided
// (lower/upper bound) queries over nb randomly generated single-digit strings.
// Expected hit counts are derived from a per-digit tally of the generated data.
//
// NOTE(review): this listing appears to contain both the superseded and the
// replacement lines of a diff hunk (e.g. two assignments to strings[i], and
// more than one `bitset` declaration / Count assertion inside one scope).
// Confirm against the real file which lines are live before compiling.
TEST_F(StringIndexMarisaTest, Range) {
auto index = milvus::index::CreateStringIndexMarisa();
std::vector<std::string> strings(nb);
// counts[d] tallies how many generated strings equal the digit d (0..9).
std::vector<int> counts(10);
for (int i = 0; i < nb; ++i) {
strings[i] = std::to_string(std::rand() % 10);
int val = std::rand() % 10;
counts[val]++;
strings[i] = std::to_string(val);
}
index->Build(nb, strings.data());
{
// [0...9]
// Every generated value is a digit in 0..9, so ">= "0"" is expected to hit all nb rows.
auto bitset = index->Range("0", milvus::OpType::GreaterEqual);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
}
{
auto bitset = index->Range("90", milvus::OpType::LessThan);
// [5...9]
// Expected hits = number of generated digits in 5..9 (sum of counts[5..9]).
int expect = std::accumulate(counts.begin() + 5, counts.end(), 0);
auto bitset = index->Range("5", milvus::OpType::GreaterEqual);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
ASSERT_EQ(Count(bitset), expect);
}
{
auto bitset = index->Range("9", milvus::OpType::LessEqual);
// [6...9]
// Strictly greater than "5": sum of counts[6..9].
int expect = std::accumulate(counts.begin() + 6, counts.end(), 0);
auto bitset = index->Range("5", milvus::OpType::GreaterThan);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
ASSERT_EQ(Count(bitset), expect);
}
{
auto bitset = index->Range("0", true, "9", true);
// [0...3]
// Strictly less than "4": sum of counts[0..3].
int expect = std::accumulate(counts.begin(), counts.begin() + 4, 0);
auto bitset = index->Range("4", milvus::OpType::LessThan);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
ASSERT_EQ(Count(bitset), expect);
}
{
auto bitset = index->Range("0", true, "90", false);
// [0...4]
// Less than or equal to "4": sum of counts[0..4].
int expect = std::accumulate(counts.begin(), counts.begin() + 5, 0);
auto bitset = index->Range("4", milvus::OpType::LessEqual);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
ASSERT_EQ(Count(bitset), expect);
}
{
// [2...8]
// Two-sided query; the bool arguments presumably mark each bound as
// inclusive (both true here, matching the counts[2..8] expectation) — TODO confirm.
int expect = std::accumulate(counts.begin() + 2, counts.begin() + 9, 0);
auto bitset = index->Range("2", true, "8", true);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), expect);
}
{
// [0...8]
// Upper-bound flag is false here and the expectation sums counts[0..8],
// i.e. "9" itself is not counted.
int expect = std::accumulate(counts.begin(), counts.begin() + 9, 0);
auto bitset = index->Range("0", true, "9", false);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), expect);
}
}