Rename SubSearchResult fields for better readability (#12341)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/12359/head
Cai Yudong 2021-11-29 17:07:40 +08:00 committed by GitHub
parent 9276fa5133
commit 365b5a5d01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 85 additions and 76 deletions

View File

@ -76,14 +76,14 @@ BinarySearchBruteForceFast(MetricType metric_type,
const uint8_t* query_data,
const faiss::BitsetView& bitset) {
SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal);
float* result_distances = sub_result.get_values();
idx_t* result_labels = sub_result.get_labels();
float* result_distances = sub_result.get_distances();
idx_t* result_ids = sub_result.get_ids();
int64_t code_size = dim / 8;
const idx_t block_size = size_per_chunk;
raw_search(metric_type, binary_chunk, size_per_chunk, code_size, num_queries, query_data, topk, result_distances,
result_labels, bitset);
result_ids, bitset);
sub_result.round_values();
return sub_result;
}
@ -103,12 +103,12 @@ FloatSearchBruteForce(const dataset::SearchDataset& dataset,
auto chunk_data = reinterpret_cast<const float*>(chunk_data_raw);
if (metric_type == MetricType::METRIC_L2) {
faiss::float_maxheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_labels(), sub_qr.get_values()};
faiss::float_maxheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_ids(), sub_qr.get_distances()};
faiss::knn_L2sqr(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, bitset);
sub_qr.round_values();
return sub_qr;
} else {
faiss::float_minheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_labels(), sub_qr.get_values()};
faiss::float_minheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_ids(), sub_qr.get_distances()};
faiss::knn_inner_product(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, bitset);
sub_qr.round_values();
return sub_qr;

View File

@ -67,7 +67,7 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment,
auto sub_qr = SearchOnIndex(search_dataset, *indexing, search_conf, sub_view);
// convert chunk uid to segment uid
for (auto& x : sub_qr.mutable_labels()) {
for (auto& x : sub_qr.mutable_ids()) {
if (x != -1) {
x += chunk_id * size_per_chunk;
}
@ -93,7 +93,7 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment,
auto sub_qr = FloatSearchBruteForce(search_dataset, chunk.data(), size_per_chunk, sub_view);
// convert chunk uid to segment uid
for (auto& x : sub_qr.mutable_labels()) {
for (auto& x : sub_qr.mutable_ids()) {
if (x != -1) {
x += chunk_id * vec_size_per_chunk;
}
@ -101,8 +101,8 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment,
final_qr.merge(sub_qr);
}
current_chunk_id = max_chunk;
results.distances_ = std::move(final_qr.mutable_values());
results.ids_ = std::move(final_qr.mutable_labels());
results.distances_ = std::move(final_qr.mutable_distances());
results.ids_ = std::move(final_qr.mutable_ids());
results.topk_ = topk;
results.num_queries_ = num_queries;
@ -155,7 +155,7 @@ BinarySearch(const segcore::SegmentGrowingImpl& segment,
auto sub_result = BinarySearchBruteForce(search_dataset, chunk.data(), nsize, sub_view);
// convert chunk uid to segment uid
for (auto& x : sub_result.mutable_labels()) {
for (auto& x : sub_result.mutable_ids()) {
if (x != -1) {
x += chunk_id * vec_size_per_chunk;
}
@ -164,8 +164,8 @@ BinarySearch(const segcore::SegmentGrowingImpl& segment,
}
final_result.round_values();
results.distances_ = std::move(final_result.mutable_values());
results.ids_ = std::move(final_result.mutable_labels());
results.distances_ = std::move(final_result.mutable_distances());
results.ids_ = std::move(final_result.mutable_ids());
results.topk_ = topk;
results.num_queries_ = num_queries;

View File

@ -32,8 +32,8 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset,
auto uids = ans->Get<int64_t*>(milvus::knowhere::meta::IDS);
SubSearchResult sub_qr(num_queries, topK, metric_type, round_decimal);
std::copy_n(dis, num_queries * topK, sub_qr.get_values());
std::copy_n(uids, num_queries * topK, sub_qr.get_labels());
std::copy_n(dis, num_queries * topK, sub_qr.get_distances());
std::copy_n(uids, num_queries * topK, sub_qr.get_ids());
sub_qr.round_values();
return sub_qr;
}

View File

@ -27,34 +27,34 @@ SubSearchResult::merge_impl(const SubSearchResult& right) {
for (int64_t qn = 0; qn < num_queries_; ++qn) {
auto offset = qn * topk_;
int64_t* __restrict__ left_labels = this->get_labels() + offset;
float* __restrict__ left_values = this->get_values() + offset;
int64_t* __restrict__ left_ids = this->get_ids() + offset;
float* __restrict__ left_distances = this->get_distances() + offset;
auto right_labels = right.get_labels() + offset;
auto right_values = right.get_values() + offset;
auto right_ids = right.get_ids() + offset;
auto right_distances = right.get_distances() + offset;
std::vector<float> buf_values(topk_);
std::vector<int64_t> buf_labels(topk_);
std::vector<float> buf_distances(topk_);
std::vector<int64_t> buf_ids(topk_);
auto lit = 0; // left iter
auto rit = 0; // right iter
for (auto buf_iter = 0; buf_iter < topk_; ++buf_iter) {
auto left_v = left_values[lit];
auto right_v = right_values[rit];
auto left_v = left_distances[lit];
auto right_v = right_distances[rit];
// optimize out at compiling
if (is_desc ? (left_v >= right_v) : (left_v <= right_v)) {
buf_values[buf_iter] = left_values[lit];
buf_labels[buf_iter] = left_labels[lit];
buf_distances[buf_iter] = left_distances[lit];
buf_ids[buf_iter] = left_ids[lit];
++lit;
} else {
buf_values[buf_iter] = right_values[rit];
buf_labels[buf_iter] = right_labels[rit];
buf_distances[buf_iter] = right_distances[rit];
buf_ids[buf_iter] = right_ids[rit];
++rit;
}
}
std::copy_n(buf_values.data(), topk_, left_values);
std::copy_n(buf_labels.data(), topk_, left_labels);
std::copy_n(buf_distances.data(), topk_, left_distances);
std::copy_n(buf_ids.data(), topk_, left_ids);
}
}
@ -80,7 +80,7 @@ SubSearchResult::round_values() {
if (round_decimal_ == -1)
return;
const float multiplier = pow(10.0, round_decimal_);
for (auto it = this->values_.begin(); it != this->values_.end(); it++) {
for (auto it = this->distances_.begin(); it != this->distances_.end(); it++) {
*it = round(*it * multiplier) / multiplier;
}
}

View File

@ -10,9 +10,11 @@
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "common/Types.h"
#include <limits>
#include <vector>
#include "common/Types.h"
namespace milvus::query {
class SubSearchResult {
@ -21,8 +23,8 @@ class SubSearchResult {
: metric_type_(metric_type),
num_queries_(num_queries),
topk_(topk),
labels_(num_queries * topk, -1),
values_(num_queries * topk, init_value(metric_type)),
ids_(num_queries * topk, -1),
distances_(num_queries * topk, init_value(metric_type)),
round_decimal_(round_decimal) {
}
@ -47,35 +49,42 @@ class SubSearchResult {
get_num_queries() const {
return num_queries_;
}
int64_t
get_topk() const {
return topk_;
}
const int64_t*
get_labels() const {
return labels_.data();
get_ids() const {
return ids_.data();
}
int64_t*
get_labels() {
return labels_.data();
get_ids() {
return ids_.data();
}
const float*
get_values() const {
return values_.data();
get_distances() const {
return distances_.data();
}
float*
get_values() {
return values_.data();
get_distances() {
return distances_.data();
}
auto&
mutable_labels() {
return labels_;
mutable_ids() {
return ids_;
}
auto&
mutable_values() {
return values_;
mutable_distances() {
return distances_;
}
void
round_values();
@ -95,8 +104,8 @@ class SubSearchResult {
int64_t topk_;
int64_t round_decimal_;
MetricType metric_type_;
std::vector<int64_t> labels_;
std::vector<float> values_;
std::vector<int64_t> ids_;
std::vector<float> distances_;
};
} // namespace milvus::query

View File

@ -68,8 +68,8 @@ empty_search_result(int64_t num_queries, int64_t topk, int64_t round_decimal, Me
SubSearchResult result(num_queries, topk, metric_type, round_decimal);
final_result.num_queries_ = num_queries;
final_result.topk_ = topk;
final_result.ids_ = std::move(result.mutable_labels());
final_result.distances_ = std::move(result.mutable_values());
final_result.ids_ = std::move(result.mutable_ids());
final_result.distances_ = std::move(result.mutable_distances());
return final_result;
}

View File

@ -351,8 +351,8 @@ SegmentSealedImpl::vector_search(int64_t vec_count,
}();
SearchResult results;
results.distances_ = std::move(sub_qr.mutable_values());
results.ids_ = std::move(sub_qr.mutable_labels());
results.distances_ = std::move(sub_qr.mutable_distances());
results.ids_ = std::move(sub_qr.mutable_ids());
results.topk_ = dataset.topk;
results.num_queries_ = dataset.num_queries;

View File

@ -242,8 +242,8 @@ TEST(Indexing, BinaryBruteForce) {
SearchResult sr;
sr.num_queries_ = num_queries;
sr.topk_ = topk;
sr.ids_ = std::move(sub_result.mutable_labels());
sr.distances_ = std::move(sub_result.mutable_values());
sr.ids_ = std::move(sub_result.mutable_ids());
sr.distances_ = std::move(sub_result.mutable_distances());
auto json = SearchResultToJson(sr);
std::cout << json.dump(2);

View File

@ -37,22 +37,22 @@ TEST(Reduce, SubQueryResult) {
std::default_random_engine e(42);
SubSearchResult final_result(num_queries, topk, metric_type, round_decimal);
for (int i = 0; i < iteration; ++i) {
std::vector<int64_t> labels;
std::vector<float> values;
std::vector<int64_t> ids;
std::vector<float> distances;
for (int n = 0; n < num_queries; ++n) {
for (int k = 0; k < topk; ++k) {
auto gen_x = e() % limit;
ref_results[n].push(gen_x);
ref_results[n].pop();
labels.push_back(gen_x);
values.push_back(gen_x);
ids.push_back(gen_x);
distances.push_back(gen_x);
}
std::sort(labels.begin() + n * topk, labels.begin() + n * topk + topk);
std::sort(values.begin() + n * topk, values.begin() + n * topk + topk);
std::sort(ids.begin() + n * topk, ids.begin() + n * topk + topk);
std::sort(distances.begin() + n * topk, distances.begin() + n * topk + topk);
}
SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal);
sub_result.mutable_values() = values;
sub_result.mutable_labels() = labels;
sub_result.mutable_distances() = distances;
sub_result.mutable_ids() = ids;
final_result.merge(sub_result);
}
@ -62,10 +62,10 @@ TEST(Reduce, SubQueryResult) {
auto ref_x = ref_results[n].top();
ref_results[n].pop();
auto index = n * topk + topk - 1 - k;
auto label = final_result.get_labels()[index];
auto value = final_result.get_values()[index];
ASSERT_EQ(label, ref_x);
ASSERT_EQ(value, ref_x);
auto id = final_result.get_ids()[index];
auto distance = final_result.get_distances()[index];
ASSERT_EQ(id, ref_x);
ASSERT_EQ(distance, ref_x);
}
}
}
@ -89,22 +89,22 @@ TEST(Reduce, SubSearchResultDesc) {
std::default_random_engine e(42);
SubSearchResult final_result(num_queries, topk, metric_type, round_decimal);
for (int i = 0; i < iteration; ++i) {
std::vector<int64_t> labels;
std::vector<float> values;
std::vector<int64_t> ids;
std::vector<float> distances;
for (int n = 0; n < num_queries; ++n) {
for (int k = 0; k < topk; ++k) {
auto gen_x = e() % limit;
ref_results[n].push(gen_x);
ref_results[n].pop();
labels.push_back(gen_x);
values.push_back(gen_x);
ids.push_back(gen_x);
distances.push_back(gen_x);
}
std::sort(labels.begin() + n * topk, labels.begin() + n * topk + topk, std::greater<int64_t>());
std::sort(values.begin() + n * topk, values.begin() + n * topk + topk, std::greater<float>());
std::sort(ids.begin() + n * topk, ids.begin() + n * topk + topk, std::greater<int64_t>());
std::sort(distances.begin() + n * topk, distances.begin() + n * topk + topk, std::greater<float>());
}
SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal);
sub_result.mutable_values() = values;
sub_result.mutable_labels() = labels;
sub_result.mutable_distances() = distances;
sub_result.mutable_ids() = ids;
final_result.merge(sub_result);
}
@ -114,10 +114,10 @@ TEST(Reduce, SubSearchResultDesc) {
auto ref_x = ref_results[n].top();
ref_results[n].pop();
auto index = n * topk + topk - 1 - k;
auto label = final_result.get_labels()[index];
auto value = final_result.get_values()[index];
ASSERT_EQ(label, ref_x);
ASSERT_EQ(value, ref_x);
auto id = final_result.get_ids()[index];
auto distance = final_result.get_distances()[index];
ASSERT_EQ(id, ref_x);
ASSERT_EQ(distance, ref_x);
}
}
}