mirror of https://github.com/milvus-io/milvus.git
[skip e2e] Add more testcases with different parameter combinations in test_reduce (#18967)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com> Signed-off-by: yudong.cai <yudong.cai@zilliz.com>pull/18975/head
parent
dcddc9d665
commit
da96659569
|
@ -24,7 +24,7 @@ using SubSearchResultUniq = std::unique_ptr<SubSearchResult>;
|
|||
|
||||
std::default_random_engine e(42);
|
||||
|
||||
std::unique_ptr<SubSearchResult>
|
||||
SubSearchResultUniq
|
||||
GenSubSearchResult(const int64_t nq,
|
||||
const int64_t topk,
|
||||
const knowhere::MetricType &metric_type,
|
||||
|
@ -34,8 +34,8 @@ GenSubSearchResult(const int64_t nq,
|
|||
SubSearchResultUniq sub_result = std::make_unique<SubSearchResult>(nq, topk, metric_type, round_decimal);
|
||||
std::vector<int64_t> ids;
|
||||
std::vector<float> distances;
|
||||
for (int n = 0; n < nq; ++n) {
|
||||
for (int k = 0; k < topk; ++k) {
|
||||
for (auto n = 0; n < nq; ++n) {
|
||||
for (auto k = 0; k < topk; ++k) {
|
||||
auto gen_x = e() % limit;
|
||||
ids.push_back(gen_x);
|
||||
distances.push_back(gen_x);
|
||||
|
@ -57,7 +57,7 @@ template<class queue_type>
|
|||
void
|
||||
CheckSubSearchResult(const int64_t nq,
|
||||
const int64_t topk,
|
||||
SubSearchResult& search_result,
|
||||
SubSearchResult& result,
|
||||
std::vector<queue_type>& result_ref) {
|
||||
ASSERT_EQ(result_ref.size(), nq);
|
||||
for (int n = 0; n < nq; ++n) {
|
||||
|
@ -66,8 +66,8 @@ CheckSubSearchResult(const int64_t nq,
|
|||
auto ref_x = result_ref[n].top();
|
||||
result_ref[n].pop();
|
||||
auto index = n * topk + topk - 1 - k;
|
||||
auto id = search_result.get_seg_offsets()[index];
|
||||
auto distance = search_result.get_distances()[index];
|
||||
auto id = result.get_seg_offsets()[index];
|
||||
auto distance = result.get_distances()[index];
|
||||
ASSERT_EQ(id, ref_x);
|
||||
ASSERT_EQ(distance, ref_x);
|
||||
}
|
||||
|
@ -76,19 +76,19 @@ CheckSubSearchResult(const int64_t nq,
|
|||
|
||||
template<class queue_type>
|
||||
void
|
||||
TestSubSearchResultMerge(const knowhere::MetricType& metric_type) {
|
||||
int64_t num_queries = 16;
|
||||
int64_t topk = 10;
|
||||
int64_t iteration = 10;
|
||||
int64_t round_decimal = 3;
|
||||
TestSubSearchResultMerge(const knowhere::MetricType& metric_type,
|
||||
const int64_t iteration,
|
||||
const int64_t nq,
|
||||
const int64_t topk) {
|
||||
const int64_t round_decimal = 3;
|
||||
|
||||
std::vector<queue_type> result_ref(num_queries);
|
||||
std::vector<queue_type> result_ref(nq);
|
||||
|
||||
SubSearchResult final_result(num_queries, topk, metric_type, round_decimal);
|
||||
SubSearchResult final_result(nq, topk, metric_type, round_decimal);
|
||||
for (int i = 0; i < iteration; ++i) {
|
||||
SubSearchResultUniq sub_result = GenSubSearchResult(num_queries, topk, metric_type, round_decimal);
|
||||
SubSearchResultUniq sub_result = GenSubSearchResult(nq, topk, metric_type, round_decimal);
|
||||
auto ids = sub_result->get_ids();
|
||||
for (int n = 0; n < num_queries; ++n) {
|
||||
for (int n = 0; n < nq; ++n) {
|
||||
for (int k = 0; k < topk; ++k) {
|
||||
int64_t x = ids[n * topk + k];
|
||||
result_ref[n].push(x);
|
||||
|
@ -99,12 +99,28 @@ TestSubSearchResultMerge(const knowhere::MetricType& metric_type) {
|
|||
}
|
||||
final_result.merge(*sub_result);
|
||||
}
|
||||
CheckSubSearchResult<queue_type>(num_queries, topk, final_result, result_ref);
|
||||
CheckSubSearchResult<queue_type>(nq, topk, final_result, result_ref);
|
||||
}
|
||||
|
||||
TEST(Reduce, SubSearchResult) {
|
||||
using queue_type_l2 = std::priority_queue<int64_t, std::vector<int64_t>, std::less<int64_t>>;
|
||||
using queue_type_ip = std::priority_queue<int64_t, std::vector<int64_t>, std::greater<int64_t>>;
|
||||
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2);
|
||||
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP);
|
||||
|
||||
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2, 1, 1, 1);
|
||||
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2, 1, 1, 10);
|
||||
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2, 1, 16, 1);
|
||||
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2, 1, 16, 10);
|
||||
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2, 4, 1, 1);
|
||||
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2, 4, 1, 10);
|
||||
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2, 4, 16, 1);
|
||||
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2, 4, 16, 10);
|
||||
|
||||
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP, 1, 1, 1);
|
||||
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP, 1, 1, 10);
|
||||
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP, 1, 16, 1);
|
||||
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP, 1, 16, 10);
|
||||
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP, 4, 1, 1);
|
||||
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP, 4, 1, 10);
|
||||
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP, 4, 16, 1);
|
||||
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP, 4, 16, 10);
|
||||
}
|
Loading…
Reference in New Issue