mirror of https://github.com/milvus-io/milvus.git
Rename SubSearchResult fields for better readability (#12341)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>pull/12359/head
parent
9276fa5133
commit
365b5a5d01
|
@ -76,14 +76,14 @@ BinarySearchBruteForceFast(MetricType metric_type,
|
|||
const uint8_t* query_data,
|
||||
const faiss::BitsetView& bitset) {
|
||||
SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal);
|
||||
float* result_distances = sub_result.get_values();
|
||||
idx_t* result_labels = sub_result.get_labels();
|
||||
float* result_distances = sub_result.get_distances();
|
||||
idx_t* result_ids = sub_result.get_ids();
|
||||
|
||||
int64_t code_size = dim / 8;
|
||||
const idx_t block_size = size_per_chunk;
|
||||
|
||||
raw_search(metric_type, binary_chunk, size_per_chunk, code_size, num_queries, query_data, topk, result_distances,
|
||||
result_labels, bitset);
|
||||
result_ids, bitset);
|
||||
sub_result.round_values();
|
||||
return sub_result;
|
||||
}
|
||||
|
@ -103,12 +103,12 @@ FloatSearchBruteForce(const dataset::SearchDataset& dataset,
|
|||
auto chunk_data = reinterpret_cast<const float*>(chunk_data_raw);
|
||||
|
||||
if (metric_type == MetricType::METRIC_L2) {
|
||||
faiss::float_maxheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_labels(), sub_qr.get_values()};
|
||||
faiss::float_maxheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_ids(), sub_qr.get_distances()};
|
||||
faiss::knn_L2sqr(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, bitset);
|
||||
sub_qr.round_values();
|
||||
return sub_qr;
|
||||
} else {
|
||||
faiss::float_minheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_labels(), sub_qr.get_values()};
|
||||
faiss::float_minheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_ids(), sub_qr.get_distances()};
|
||||
faiss::knn_inner_product(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, bitset);
|
||||
sub_qr.round_values();
|
||||
return sub_qr;
|
||||
|
|
|
@ -67,7 +67,7 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment,
|
|||
auto sub_qr = SearchOnIndex(search_dataset, *indexing, search_conf, sub_view);
|
||||
|
||||
// convert chunk uid to segment uid
|
||||
for (auto& x : sub_qr.mutable_labels()) {
|
||||
for (auto& x : sub_qr.mutable_ids()) {
|
||||
if (x != -1) {
|
||||
x += chunk_id * size_per_chunk;
|
||||
}
|
||||
|
@ -93,7 +93,7 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment,
|
|||
auto sub_qr = FloatSearchBruteForce(search_dataset, chunk.data(), size_per_chunk, sub_view);
|
||||
|
||||
// convert chunk uid to segment uid
|
||||
for (auto& x : sub_qr.mutable_labels()) {
|
||||
for (auto& x : sub_qr.mutable_ids()) {
|
||||
if (x != -1) {
|
||||
x += chunk_id * vec_size_per_chunk;
|
||||
}
|
||||
|
@ -101,8 +101,8 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment,
|
|||
final_qr.merge(sub_qr);
|
||||
}
|
||||
current_chunk_id = max_chunk;
|
||||
results.distances_ = std::move(final_qr.mutable_values());
|
||||
results.ids_ = std::move(final_qr.mutable_labels());
|
||||
results.distances_ = std::move(final_qr.mutable_distances());
|
||||
results.ids_ = std::move(final_qr.mutable_ids());
|
||||
results.topk_ = topk;
|
||||
results.num_queries_ = num_queries;
|
||||
|
||||
|
@ -155,7 +155,7 @@ BinarySearch(const segcore::SegmentGrowingImpl& segment,
|
|||
auto sub_result = BinarySearchBruteForce(search_dataset, chunk.data(), nsize, sub_view);
|
||||
|
||||
// convert chunk uid to segment uid
|
||||
for (auto& x : sub_result.mutable_labels()) {
|
||||
for (auto& x : sub_result.mutable_ids()) {
|
||||
if (x != -1) {
|
||||
x += chunk_id * vec_size_per_chunk;
|
||||
}
|
||||
|
@ -164,8 +164,8 @@ BinarySearch(const segcore::SegmentGrowingImpl& segment,
|
|||
}
|
||||
|
||||
final_result.round_values();
|
||||
results.distances_ = std::move(final_result.mutable_values());
|
||||
results.ids_ = std::move(final_result.mutable_labels());
|
||||
results.distances_ = std::move(final_result.mutable_distances());
|
||||
results.ids_ = std::move(final_result.mutable_ids());
|
||||
results.topk_ = topk;
|
||||
results.num_queries_ = num_queries;
|
||||
|
||||
|
|
|
@ -32,8 +32,8 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset,
|
|||
auto uids = ans->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
|
||||
SubSearchResult sub_qr(num_queries, topK, metric_type, round_decimal);
|
||||
std::copy_n(dis, num_queries * topK, sub_qr.get_values());
|
||||
std::copy_n(uids, num_queries * topK, sub_qr.get_labels());
|
||||
std::copy_n(dis, num_queries * topK, sub_qr.get_distances());
|
||||
std::copy_n(uids, num_queries * topK, sub_qr.get_ids());
|
||||
sub_qr.round_values();
|
||||
return sub_qr;
|
||||
}
|
||||
|
|
|
@ -27,34 +27,34 @@ SubSearchResult::merge_impl(const SubSearchResult& right) {
|
|||
for (int64_t qn = 0; qn < num_queries_; ++qn) {
|
||||
auto offset = qn * topk_;
|
||||
|
||||
int64_t* __restrict__ left_labels = this->get_labels() + offset;
|
||||
float* __restrict__ left_values = this->get_values() + offset;
|
||||
int64_t* __restrict__ left_ids = this->get_ids() + offset;
|
||||
float* __restrict__ left_distances = this->get_distances() + offset;
|
||||
|
||||
auto right_labels = right.get_labels() + offset;
|
||||
auto right_values = right.get_values() + offset;
|
||||
auto right_ids = right.get_ids() + offset;
|
||||
auto right_distances = right.get_distances() + offset;
|
||||
|
||||
std::vector<float> buf_values(topk_);
|
||||
std::vector<int64_t> buf_labels(topk_);
|
||||
std::vector<float> buf_distances(topk_);
|
||||
std::vector<int64_t> buf_ids(topk_);
|
||||
|
||||
auto lit = 0; // left iter
|
||||
auto rit = 0; // right iter
|
||||
|
||||
for (auto buf_iter = 0; buf_iter < topk_; ++buf_iter) {
|
||||
auto left_v = left_values[lit];
|
||||
auto right_v = right_values[rit];
|
||||
auto left_v = left_distances[lit];
|
||||
auto right_v = right_distances[rit];
|
||||
// optimize out at compiling
|
||||
if (is_desc ? (left_v >= right_v) : (left_v <= right_v)) {
|
||||
buf_values[buf_iter] = left_values[lit];
|
||||
buf_labels[buf_iter] = left_labels[lit];
|
||||
buf_distances[buf_iter] = left_distances[lit];
|
||||
buf_ids[buf_iter] = left_ids[lit];
|
||||
++lit;
|
||||
} else {
|
||||
buf_values[buf_iter] = right_values[rit];
|
||||
buf_labels[buf_iter] = right_labels[rit];
|
||||
buf_distances[buf_iter] = right_distances[rit];
|
||||
buf_ids[buf_iter] = right_ids[rit];
|
||||
++rit;
|
||||
}
|
||||
}
|
||||
std::copy_n(buf_values.data(), topk_, left_values);
|
||||
std::copy_n(buf_labels.data(), topk_, left_labels);
|
||||
std::copy_n(buf_distances.data(), topk_, left_distances);
|
||||
std::copy_n(buf_ids.data(), topk_, left_ids);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -80,7 +80,7 @@ SubSearchResult::round_values() {
|
|||
if (round_decimal_ == -1)
|
||||
return;
|
||||
const float multiplier = pow(10.0, round_decimal_);
|
||||
for (auto it = this->values_.begin(); it != this->values_.end(); it++) {
|
||||
for (auto it = this->distances_.begin(); it != this->distances_.end(); it++) {
|
||||
*it = round(*it * multiplier) / multiplier;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,9 +10,11 @@
|
|||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
#include "common/Types.h"
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
#include "common/Types.h"
|
||||
|
||||
namespace milvus::query {
|
||||
|
||||
class SubSearchResult {
|
||||
|
@ -21,8 +23,8 @@ class SubSearchResult {
|
|||
: metric_type_(metric_type),
|
||||
num_queries_(num_queries),
|
||||
topk_(topk),
|
||||
labels_(num_queries * topk, -1),
|
||||
values_(num_queries * topk, init_value(metric_type)),
|
||||
ids_(num_queries * topk, -1),
|
||||
distances_(num_queries * topk, init_value(metric_type)),
|
||||
round_decimal_(round_decimal) {
|
||||
}
|
||||
|
||||
|
@ -47,35 +49,42 @@ class SubSearchResult {
|
|||
get_num_queries() const {
|
||||
return num_queries_;
|
||||
}
|
||||
|
||||
int64_t
|
||||
get_topk() const {
|
||||
return topk_;
|
||||
}
|
||||
|
||||
const int64_t*
|
||||
get_labels() const {
|
||||
return labels_.data();
|
||||
get_ids() const {
|
||||
return ids_.data();
|
||||
}
|
||||
|
||||
int64_t*
|
||||
get_labels() {
|
||||
return labels_.data();
|
||||
get_ids() {
|
||||
return ids_.data();
|
||||
}
|
||||
|
||||
const float*
|
||||
get_values() const {
|
||||
return values_.data();
|
||||
get_distances() const {
|
||||
return distances_.data();
|
||||
}
|
||||
|
||||
float*
|
||||
get_values() {
|
||||
return values_.data();
|
||||
get_distances() {
|
||||
return distances_.data();
|
||||
}
|
||||
|
||||
auto&
|
||||
mutable_labels() {
|
||||
return labels_;
|
||||
mutable_ids() {
|
||||
return ids_;
|
||||
}
|
||||
|
||||
auto&
|
||||
mutable_values() {
|
||||
return values_;
|
||||
mutable_distances() {
|
||||
return distances_;
|
||||
}
|
||||
|
||||
void
|
||||
round_values();
|
||||
|
||||
|
@ -95,8 +104,8 @@ class SubSearchResult {
|
|||
int64_t topk_;
|
||||
int64_t round_decimal_;
|
||||
MetricType metric_type_;
|
||||
std::vector<int64_t> labels_;
|
||||
std::vector<float> values_;
|
||||
std::vector<int64_t> ids_;
|
||||
std::vector<float> distances_;
|
||||
};
|
||||
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -68,8 +68,8 @@ empty_search_result(int64_t num_queries, int64_t topk, int64_t round_decimal, Me
|
|||
SubSearchResult result(num_queries, topk, metric_type, round_decimal);
|
||||
final_result.num_queries_ = num_queries;
|
||||
final_result.topk_ = topk;
|
||||
final_result.ids_ = std::move(result.mutable_labels());
|
||||
final_result.distances_ = std::move(result.mutable_values());
|
||||
final_result.ids_ = std::move(result.mutable_ids());
|
||||
final_result.distances_ = std::move(result.mutable_distances());
|
||||
return final_result;
|
||||
}
|
||||
|
||||
|
|
|
@ -351,8 +351,8 @@ SegmentSealedImpl::vector_search(int64_t vec_count,
|
|||
}();
|
||||
|
||||
SearchResult results;
|
||||
results.distances_ = std::move(sub_qr.mutable_values());
|
||||
results.ids_ = std::move(sub_qr.mutable_labels());
|
||||
results.distances_ = std::move(sub_qr.mutable_distances());
|
||||
results.ids_ = std::move(sub_qr.mutable_ids());
|
||||
results.topk_ = dataset.topk;
|
||||
results.num_queries_ = dataset.num_queries;
|
||||
|
||||
|
|
|
@ -242,8 +242,8 @@ TEST(Indexing, BinaryBruteForce) {
|
|||
SearchResult sr;
|
||||
sr.num_queries_ = num_queries;
|
||||
sr.topk_ = topk;
|
||||
sr.ids_ = std::move(sub_result.mutable_labels());
|
||||
sr.distances_ = std::move(sub_result.mutable_values());
|
||||
sr.ids_ = std::move(sub_result.mutable_ids());
|
||||
sr.distances_ = std::move(sub_result.mutable_distances());
|
||||
|
||||
auto json = SearchResultToJson(sr);
|
||||
std::cout << json.dump(2);
|
||||
|
|
|
@ -37,22 +37,22 @@ TEST(Reduce, SubQueryResult) {
|
|||
std::default_random_engine e(42);
|
||||
SubSearchResult final_result(num_queries, topk, metric_type, round_decimal);
|
||||
for (int i = 0; i < iteration; ++i) {
|
||||
std::vector<int64_t> labels;
|
||||
std::vector<float> values;
|
||||
std::vector<int64_t> ids;
|
||||
std::vector<float> distances;
|
||||
for (int n = 0; n < num_queries; ++n) {
|
||||
for (int k = 0; k < topk; ++k) {
|
||||
auto gen_x = e() % limit;
|
||||
ref_results[n].push(gen_x);
|
||||
ref_results[n].pop();
|
||||
labels.push_back(gen_x);
|
||||
values.push_back(gen_x);
|
||||
ids.push_back(gen_x);
|
||||
distances.push_back(gen_x);
|
||||
}
|
||||
std::sort(labels.begin() + n * topk, labels.begin() + n * topk + topk);
|
||||
std::sort(values.begin() + n * topk, values.begin() + n * topk + topk);
|
||||
std::sort(ids.begin() + n * topk, ids.begin() + n * topk + topk);
|
||||
std::sort(distances.begin() + n * topk, distances.begin() + n * topk + topk);
|
||||
}
|
||||
SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal);
|
||||
sub_result.mutable_values() = values;
|
||||
sub_result.mutable_labels() = labels;
|
||||
sub_result.mutable_distances() = distances;
|
||||
sub_result.mutable_ids() = ids;
|
||||
final_result.merge(sub_result);
|
||||
}
|
||||
|
||||
|
@ -62,10 +62,10 @@ TEST(Reduce, SubQueryResult) {
|
|||
auto ref_x = ref_results[n].top();
|
||||
ref_results[n].pop();
|
||||
auto index = n * topk + topk - 1 - k;
|
||||
auto label = final_result.get_labels()[index];
|
||||
auto value = final_result.get_values()[index];
|
||||
ASSERT_EQ(label, ref_x);
|
||||
ASSERT_EQ(value, ref_x);
|
||||
auto id = final_result.get_ids()[index];
|
||||
auto distance = final_result.get_distances()[index];
|
||||
ASSERT_EQ(id, ref_x);
|
||||
ASSERT_EQ(distance, ref_x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -89,22 +89,22 @@ TEST(Reduce, SubSearchResultDesc) {
|
|||
std::default_random_engine e(42);
|
||||
SubSearchResult final_result(num_queries, topk, metric_type, round_decimal);
|
||||
for (int i = 0; i < iteration; ++i) {
|
||||
std::vector<int64_t> labels;
|
||||
std::vector<float> values;
|
||||
std::vector<int64_t> ids;
|
||||
std::vector<float> distances;
|
||||
for (int n = 0; n < num_queries; ++n) {
|
||||
for (int k = 0; k < topk; ++k) {
|
||||
auto gen_x = e() % limit;
|
||||
ref_results[n].push(gen_x);
|
||||
ref_results[n].pop();
|
||||
labels.push_back(gen_x);
|
||||
values.push_back(gen_x);
|
||||
ids.push_back(gen_x);
|
||||
distances.push_back(gen_x);
|
||||
}
|
||||
std::sort(labels.begin() + n * topk, labels.begin() + n * topk + topk, std::greater<int64_t>());
|
||||
std::sort(values.begin() + n * topk, values.begin() + n * topk + topk, std::greater<float>());
|
||||
std::sort(ids.begin() + n * topk, ids.begin() + n * topk + topk, std::greater<int64_t>());
|
||||
std::sort(distances.begin() + n * topk, distances.begin() + n * topk + topk, std::greater<float>());
|
||||
}
|
||||
SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal);
|
||||
sub_result.mutable_values() = values;
|
||||
sub_result.mutable_labels() = labels;
|
||||
sub_result.mutable_distances() = distances;
|
||||
sub_result.mutable_ids() = ids;
|
||||
final_result.merge(sub_result);
|
||||
}
|
||||
|
||||
|
@ -114,10 +114,10 @@ TEST(Reduce, SubSearchResultDesc) {
|
|||
auto ref_x = ref_results[n].top();
|
||||
ref_results[n].pop();
|
||||
auto index = n * topk + topk - 1 - k;
|
||||
auto label = final_result.get_labels()[index];
|
||||
auto value = final_result.get_values()[index];
|
||||
ASSERT_EQ(label, ref_x);
|
||||
ASSERT_EQ(value, ref_x);
|
||||
auto id = final_result.get_ids()[index];
|
||||
auto distance = final_result.get_distances()[index];
|
||||
ASSERT_EQ(id, ref_x);
|
||||
ASSERT_EQ(distance, ref_x);
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue