diff --git a/internal/core/src/query/SearchBruteForce.cpp b/internal/core/src/query/SearchBruteForce.cpp index c755d4174c..28a17e13bb 100644 --- a/internal/core/src/query/SearchBruteForce.cpp +++ b/internal/core/src/query/SearchBruteForce.cpp @@ -76,14 +76,14 @@ BinarySearchBruteForceFast(MetricType metric_type, const uint8_t* query_data, const faiss::BitsetView& bitset) { SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal); - float* result_distances = sub_result.get_values(); - idx_t* result_labels = sub_result.get_labels(); + float* result_distances = sub_result.get_distances(); + idx_t* result_ids = sub_result.get_ids(); int64_t code_size = dim / 8; const idx_t block_size = size_per_chunk; raw_search(metric_type, binary_chunk, size_per_chunk, code_size, num_queries, query_data, topk, result_distances, - result_labels, bitset); + result_ids, bitset); sub_result.round_values(); return sub_result; } @@ -103,12 +103,12 @@ FloatSearchBruteForce(const dataset::SearchDataset& dataset, auto chunk_data = reinterpret_cast(chunk_data_raw); if (metric_type == MetricType::METRIC_L2) { - faiss::float_maxheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_labels(), sub_qr.get_values()}; + faiss::float_maxheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_ids(), sub_qr.get_distances()}; faiss::knn_L2sqr(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, bitset); sub_qr.round_values(); return sub_qr; } else { - faiss::float_minheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_labels(), sub_qr.get_values()}; + faiss::float_minheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_ids(), sub_qr.get_distances()}; faiss::knn_inner_product(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, bitset); sub_qr.round_values(); return sub_qr; diff --git a/internal/core/src/query/SearchOnGrowing.cpp b/internal/core/src/query/SearchOnGrowing.cpp index 2ee6fac1ee..b9b6df32b4 100644 --- a/internal/core/src/query/SearchOnGrowing.cpp +++ b/internal/core/src/query/SearchOnGrowing.cpp @@ -67,7 +67,7 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment, auto sub_qr = SearchOnIndex(search_dataset, *indexing, search_conf, sub_view); // convert chunk uid to segment uid - for (auto& x : sub_qr.mutable_labels()) { + for (auto& x : sub_qr.mutable_ids()) { if (x != -1) { x += chunk_id * size_per_chunk; } @@ -93,7 +93,7 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment, auto sub_qr = FloatSearchBruteForce(search_dataset, chunk.data(), size_per_chunk, sub_view); // convert chunk uid to segment uid - for (auto& x : sub_qr.mutable_labels()) { + for (auto& x : sub_qr.mutable_ids()) { if (x != -1) { x += chunk_id * vec_size_per_chunk; } @@ -101,8 +101,8 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment, final_qr.merge(sub_qr); } current_chunk_id = max_chunk; - results.distances_ = std::move(final_qr.mutable_values()); - results.ids_ = std::move(final_qr.mutable_labels()); + results.distances_ = std::move(final_qr.mutable_distances()); + results.ids_ = std::move(final_qr.mutable_ids()); results.topk_ = topk; results.num_queries_ = num_queries; @@ -155,7 +155,7 @@ BinarySearch(const segcore::SegmentGrowingImpl& segment, auto sub_result = BinarySearchBruteForce(search_dataset, chunk.data(), nsize, sub_view); // convert chunk uid to segment uid - for (auto& x : sub_result.mutable_labels()) { + for (auto& x : sub_result.mutable_ids()) { if (x != -1) { x += chunk_id * vec_size_per_chunk; } @@ -164,8 +164,8 @@ BinarySearch(const segcore::SegmentGrowingImpl& segment, } final_result.round_values(); - results.distances_ = std::move(final_result.mutable_values()); - results.ids_ = std::move(final_result.mutable_labels()); + results.distances_ = std::move(final_result.mutable_distances()); + results.ids_ = std::move(final_result.mutable_ids()); results.topk_ = topk; results.num_queries_ = num_queries; diff --git a/internal/core/src/query/SearchOnIndex.cpp b/internal/core/src/query/SearchOnIndex.cpp index db44bf6ec4..69c64cc701 100644 --- a/internal/core/src/query/SearchOnIndex.cpp +++ b/internal/core/src/query/SearchOnIndex.cpp @@ -32,8 +32,8 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset, auto uids = ans->Get(milvus::knowhere::meta::IDS); SubSearchResult sub_qr(num_queries, topK, metric_type, round_decimal); - std::copy_n(dis, num_queries * topK, sub_qr.get_values()); - std::copy_n(uids, num_queries * topK, sub_qr.get_labels()); + std::copy_n(dis, num_queries * topK, sub_qr.get_distances()); + std::copy_n(uids, num_queries * topK, sub_qr.get_ids()); sub_qr.round_values(); return sub_qr; } diff --git a/internal/core/src/query/SubSearchResult.cpp b/internal/core/src/query/SubSearchResult.cpp index e0d7d7bd3a..0838664e36 100644 --- a/internal/core/src/query/SubSearchResult.cpp +++ b/internal/core/src/query/SubSearchResult.cpp @@ -27,34 +27,34 @@ SubSearchResult::merge_impl(const SubSearchResult& right) { for (int64_t qn = 0; qn < num_queries_; ++qn) { auto offset = qn * topk_; - int64_t* __restrict__ left_labels = this->get_labels() + offset; - float* __restrict__ left_values = this->get_values() + offset; + int64_t* __restrict__ left_ids = this->get_ids() + offset; + float* __restrict__ left_distances = this->get_distances() + offset; - auto right_labels = right.get_labels() + offset; - auto right_values = right.get_values() + offset; + auto right_ids = right.get_ids() + offset; + auto right_distances = right.get_distances() + offset; - std::vector buf_values(topk_); - std::vector buf_labels(topk_); + std::vector buf_distances(topk_); + std::vector buf_ids(topk_); auto lit = 0; // left iter auto rit = 0; // right iter for (auto buf_iter = 0; buf_iter < topk_; ++buf_iter) { - auto left_v = left_values[lit]; - auto right_v = right_values[rit]; + auto left_v = left_distances[lit]; + auto right_v = right_distances[rit]; // optimize out at compiling if (is_desc ? (left_v >= right_v) : (left_v <= right_v)) { - buf_values[buf_iter] = left_values[lit]; - buf_labels[buf_iter] = left_labels[lit]; + buf_distances[buf_iter] = left_distances[lit]; + buf_ids[buf_iter] = left_ids[lit]; ++lit; } else { - buf_values[buf_iter] = right_values[rit]; - buf_labels[buf_iter] = right_labels[rit]; + buf_distances[buf_iter] = right_distances[rit]; + buf_ids[buf_iter] = right_ids[rit]; ++rit; } } - std::copy_n(buf_values.data(), topk_, left_values); - std::copy_n(buf_labels.data(), topk_, left_labels); + std::copy_n(buf_distances.data(), topk_, left_distances); + std::copy_n(buf_ids.data(), topk_, left_ids); } } @@ -80,7 +80,7 @@ SubSearchResult::round_values() { if (round_decimal_ == -1) return; const float multiplier = pow(10.0, round_decimal_); - for (auto it = this->values_.begin(); it != this->values_.end(); it++) { + for (auto it = this->distances_.begin(); it != this->distances_.end(); it++) { *it = round(*it * multiplier) / multiplier; } } diff --git a/internal/core/src/query/SubSearchResult.h b/internal/core/src/query/SubSearchResult.h index fcc7b41ef6..b12927f800 100644 --- a/internal/core/src/query/SubSearchResult.h +++ b/internal/core/src/query/SubSearchResult.h @@ -10,9 +10,11 @@ // or implied. See the License for the specific language governing permissions and limitations under the License #pragma once -#include "common/Types.h" + #include #include +#include "common/Types.h" + namespace milvus::query { class SubSearchResult { @@ -21,8 +23,8 @@ class SubSearchResult { : metric_type_(metric_type), num_queries_(num_queries), topk_(topk), - labels_(num_queries * topk, -1), - values_(num_queries * topk, init_value(metric_type)), + ids_(num_queries * topk, -1), + distances_(num_queries * topk, init_value(metric_type)), round_decimal_(round_decimal) { } @@ -47,35 +49,42 @@ class SubSearchResult { get_num_queries() const { return num_queries_; } + int64_t get_topk() const { return topk_; } const int64_t* - get_labels() const { - return labels_.data(); + get_ids() const { + return ids_.data(); } + int64_t* - get_labels() { - return labels_.data(); + get_ids() { + return ids_.data(); } + const float* - get_values() const { - return values_.data(); + get_distances() const { + return distances_.data(); } + float* - get_values() { - return values_.data(); + get_distances() { + return distances_.data(); } + auto& - mutable_labels() { - return labels_; + mutable_ids() { + return ids_; } + auto& - mutable_values() { - return values_; + mutable_distances() { + return distances_; } + void round_values(); @@ -95,8 +104,8 @@ class SubSearchResult { int64_t topk_; int64_t round_decimal_; MetricType metric_type_; - std::vector labels_; - std::vector values_; + std::vector ids_; + std::vector distances_; }; } // namespace milvus::query diff --git a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp index 6d15991662..e3f61d7b29 100644 --- a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp @@ -68,8 +68,8 @@ empty_search_result(int64_t num_queries, int64_t topk, int64_t round_decimal, Me SubSearchResult result(num_queries, topk, metric_type, round_decimal); final_result.num_queries_ = num_queries; final_result.topk_ = topk; - final_result.ids_ = std::move(result.mutable_labels()); - final_result.distances_ = std::move(result.mutable_values()); + final_result.ids_ = std::move(result.mutable_ids()); + final_result.distances_ = std::move(result.mutable_distances()); return final_result; } diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index 600c4b8b37..553c62ceee 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -351,8 +351,8 @@ SegmentSealedImpl::vector_search(int64_t vec_count, }(); SearchResult results; - results.distances_ = std::move(sub_qr.mutable_values()); - results.ids_ = std::move(sub_qr.mutable_labels()); + results.distances_ = std::move(sub_qr.mutable_distances()); + results.ids_ = std::move(sub_qr.mutable_ids()); results.topk_ = dataset.topk; results.num_queries_ = dataset.num_queries; diff --git a/internal/core/unittest/test_indexing.cpp b/internal/core/unittest/test_indexing.cpp index e8dd47b4bc..5ab4d7426e 100644 --- a/internal/core/unittest/test_indexing.cpp +++ b/internal/core/unittest/test_indexing.cpp @@ -242,8 +242,8 @@ TEST(Indexing, BinaryBruteForce) { SearchResult sr; sr.num_queries_ = num_queries; sr.topk_ = topk; - sr.ids_ = std::move(sub_result.mutable_labels()); - sr.distances_ = std::move(sub_result.mutable_values()); + sr.ids_ = std::move(sub_result.mutable_ids()); + sr.distances_ = std::move(sub_result.mutable_distances()); auto json = SearchResultToJson(sr); std::cout << json.dump(2); diff --git a/internal/core/unittest/test_reduce.cpp b/internal/core/unittest/test_reduce.cpp index 7f028d1e45..f55c7d13f1 100644 --- a/internal/core/unittest/test_reduce.cpp +++ b/internal/core/unittest/test_reduce.cpp @@ -37,22 +37,22 @@ TEST(Reduce, SubQueryResult) { std::default_random_engine e(42); SubSearchResult final_result(num_queries, topk, metric_type, round_decimal); for (int i = 0; i < iteration; ++i) { - std::vector labels; - std::vector values; + std::vector ids; + std::vector distances; for (int n = 0; n < num_queries; ++n) { for (int k = 0; k < topk; ++k) { auto gen_x = e() % limit; ref_results[n].push(gen_x); ref_results[n].pop(); - labels.push_back(gen_x); - values.push_back(gen_x); + ids.push_back(gen_x); + distances.push_back(gen_x); } - std::sort(labels.begin() + n * topk, labels.begin() + n * topk + topk); - std::sort(values.begin() + n * topk, values.begin() + n * topk + topk); + std::sort(ids.begin() + n * topk, ids.begin() + n * topk + topk); + std::sort(distances.begin() + n * topk, distances.begin() + n * topk + topk); } SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal); - sub_result.mutable_values() = values; - sub_result.mutable_labels() = labels; + sub_result.mutable_distances() = distances; + sub_result.mutable_ids() = ids; final_result.merge(sub_result); } @@ -62,10 +62,10 @@ TEST(Reduce, SubQueryResult) { auto ref_x = ref_results[n].top(); ref_results[n].pop(); auto index = n * topk + topk - 1 - k; - auto label = final_result.get_labels()[index]; - auto value = final_result.get_values()[index]; - ASSERT_EQ(label, ref_x); - ASSERT_EQ(value, ref_x); + auto id = final_result.get_ids()[index]; + auto distance = final_result.get_distances()[index]; + ASSERT_EQ(id, ref_x); + ASSERT_EQ(distance, ref_x); } } } @@ -89,22 +89,22 @@ TEST(Reduce, SubSearchResultDesc) { std::default_random_engine e(42); SubSearchResult final_result(num_queries, topk, metric_type, round_decimal); for (int i = 0; i < iteration; ++i) { - std::vector labels; - std::vector values; + std::vector ids; + std::vector distances; for (int n = 0; n < num_queries; ++n) { for (int k = 0; k < topk; ++k) { auto gen_x = e() % limit; ref_results[n].push(gen_x); ref_results[n].pop(); - labels.push_back(gen_x); - values.push_back(gen_x); + ids.push_back(gen_x); + distances.push_back(gen_x); } - std::sort(labels.begin() + n * topk, labels.begin() + n * topk + topk, std::greater()); - std::sort(values.begin() + n * topk, values.begin() + n * topk + topk, std::greater()); + std::sort(ids.begin() + n * topk, ids.begin() + n * topk + topk, std::greater()); + std::sort(distances.begin() + n * topk, distances.begin() + n * topk + topk, std::greater()); } SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal); - sub_result.mutable_values() = values; - sub_result.mutable_labels() = labels; + sub_result.mutable_distances() = distances; + sub_result.mutable_ids() = ids; final_result.merge(sub_result); } @@ -114,10 +114,10 @@ TEST(Reduce, SubSearchResultDesc) { auto ref_x = ref_results[n].top(); ref_results[n].pop(); auto index = n * topk + topk - 1 - k; - auto label = final_result.get_labels()[index]; - auto value = final_result.get_values()[index]; - ASSERT_EQ(label, ref_x); - ASSERT_EQ(value, ref_x); + auto id = final_result.get_ids()[index]; + auto distance = final_result.get_distances()[index]; + ASSERT_EQ(id, ref_x); + ASSERT_EQ(distance, ref_x); } } } \ No newline at end of file