Let growing segment call knowhere brute search API (#18227)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/18240/head
Cai Yudong 2022-07-12 11:58:25 +08:00 committed by GitHub
parent ff3de654c8
commit 015a2f0866
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 26 additions and 120 deletions

View File

@ -12,105 +12,26 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <faiss/utils/BinaryDistance.h>
#include <faiss/utils/distances.h>
#include "SearchBruteForce.h" #include "SearchBruteForce.h"
#include "SubSearchResult.h" #include "knowhere/archive/BruteForce.h"
#include "common/Types.h"
#include "segcore/Utils.h"
namespace milvus::query { namespace milvus::query {
// copy from faiss/IndexBinaryFlat.cpp::IndexBinaryFlat::search()
// disable lint to make further migration easier
static void
binary_search(const knowhere::MetricType& metric_type,
const uint8_t* xb,
int64_t ntotal,
int code_size,
idx_t n, // num_queries
const uint8_t* x,
idx_t k, // topk
float* D,
idx_t* labels,
const BitsetView bitset) {
using namespace faiss; // NOLINT
if (metric_type == knowhere::metric::JACCARD || metric_type == knowhere::metric::TANIMOTO) {
float_maxheap_array_t res = {size_t(n), size_t(k), labels, D};
binary_distance_knn_hc(METRIC_Jaccard, &res, x, xb, ntotal, code_size, bitset);
if (metric_type == knowhere::metric::TANIMOTO) {
for (int i = 0; i < k * n; i++) {
D[i] = Jaccard_2_Tanimoto(D[i]);
}
}
} else if (metric_type == knowhere::metric::HAMMING) {
std::vector<int32_t> int_distances(n * k);
int_maxheap_array_t res = {size_t(n), size_t(k), labels, int_distances.data()};
binary_distance_knn_hc(METRIC_Hamming, &res, x, xb, ntotal, code_size, bitset);
for (int i = 0; i < n * k; ++i) {
D[i] = int_distances[i];
}
} else if (metric_type == knowhere::metric::SUBSTRUCTURE || metric_type == knowhere::metric::SUPERSTRUCTURE) {
// only matched ids will be chosen, not to use heap
auto faiss_metric_type = knowhere::GetFaissMetricType(metric_type);
binary_distance_knn_mc(faiss_metric_type, x, xb, n, ntotal, k, code_size, D, labels, bitset);
} else {
std::string msg = "binary search not support metric type: " + metric_type;
PanicInfo(msg);
}
}
SubSearchResult SubSearchResult
BinarySearchBruteForce(const dataset::SearchDataset& dataset, BruteForceSearch(const dataset::SearchDataset& dataset,
const void* chunk_data_raw, const void* chunk_data_raw,
int64_t size_per_chunk, int64_t chunk_rows,
const BitsetView& bitset) { const BitsetView& bitset) {
// TODO: refactor the internal function SubSearchResult sub_result(dataset.num_queries, dataset.topk, dataset.metric_type, dataset.round_decimal);
auto metric_type = dataset.metric_type; try {
auto num_queries = dataset.num_queries; knowhere::BruteForceSearch(dataset.metric_type, chunk_data_raw, dataset.query_data, dataset.dim, chunk_rows,
auto topk = dataset.topk; dataset.num_queries, dataset.topk, sub_result.get_seg_offsets(),
auto dim = dataset.dim; sub_result.get_distances(), bitset);
auto round_decimal = dataset.round_decimal; } catch (std::exception& e) {
SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal); PanicInfo(e.what());
auto query_data = reinterpret_cast<const uint8_t*>(dataset.query_data); }
auto chunk_data = reinterpret_cast<const uint8_t*>(chunk_data_raw);
int64_t code_size = dim / 8;
binary_search(metric_type, chunk_data, size_per_chunk, code_size, num_queries, query_data, topk,
sub_result.get_distances(), sub_result.get_seg_offsets(), bitset);
sub_result.round_values(); sub_result.round_values();
return sub_result; return sub_result;
} }
SubSearchResult
FloatSearchBruteForce(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t size_per_chunk,
const BitsetView& bitset) {
auto metric_type = dataset.metric_type;
auto num_queries = dataset.num_queries;
auto topk = dataset.topk;
auto dim = dataset.dim;
auto round_decimal = dataset.round_decimal;
SubSearchResult sub_qr(num_queries, topk, metric_type, round_decimal);
auto query_data = reinterpret_cast<const float*>(dataset.query_data);
auto chunk_data = reinterpret_cast<const float*>(chunk_data_raw);
if (metric_type == knowhere::metric::L2) {
faiss::float_maxheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_seg_offsets(),
sub_qr.get_distances()};
faiss::knn_L2sqr(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, nullptr, bitset);
} else if (metric_type == knowhere::metric::IP) {
faiss::float_minheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_seg_offsets(),
sub_qr.get_distances()};
faiss::knn_inner_product(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, bitset);
} else {
std::string msg = "search not support metric type: " + metric_type;
PanicInfo(msg);
}
sub_qr.round_values();
return sub_qr;
}
} // namespace milvus::query } // namespace milvus::query

View File

@ -11,24 +11,16 @@
#pragma once #pragma once
#include "common/Schema.h"
#include "common/BitsetView.h" #include "common/BitsetView.h"
#include "query/SubSearchResult.h" #include "query/SubSearchResult.h"
#include "query/helper.h" #include "query/helper.h"
#include "segcore/ConcurrentVector.h"
namespace milvus::query { namespace milvus::query {
SubSearchResult SubSearchResult
BinarySearchBruteForce(const dataset::SearchDataset& dataset, BruteForceSearch(const dataset::SearchDataset& dataset,
const void* chunk_data_raw, const void* chunk_data_raw,
int64_t size_per_chunk, int64_t chunk_rows,
const BitsetView& bitset);
SubSearchResult
FloatSearchBruteForce(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t size_per_chunk,
const BitsetView& bitset); const BitsetView& bitset);
} // namespace milvus::query } // namespace milvus::query

View File

@ -89,7 +89,7 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment,
auto size_per_chunk = element_end - element_begin; auto size_per_chunk = element_end - element_begin;
auto sub_view = bitset.subview(element_begin, size_per_chunk); auto sub_view = bitset.subview(element_begin, size_per_chunk);
auto sub_qr = FloatSearchBruteForce(search_dataset, chunk.data(), size_per_chunk, sub_view); auto sub_qr = BruteForceSearch(search_dataset, chunk.data(), size_per_chunk, sub_view);
// convert chunk uid to segment uid // convert chunk uid to segment uid
for (auto& x : sub_qr.mutable_seg_offsets()) { for (auto& x : sub_qr.mutable_seg_offsets()) {
@ -150,7 +150,7 @@ BinarySearch(const segcore::SegmentGrowingImpl& segment,
auto nsize = element_end - element_begin; auto nsize = element_end - element_begin;
auto sub_view = bitset.subview(element_begin, nsize); auto sub_view = bitset.subview(element_begin, nsize);
auto sub_result = BinarySearchBruteForce(search_dataset, chunk.data(), nsize, sub_view); auto sub_result = BruteForceSearch(search_dataset, chunk.data(), nsize, sub_view);
// convert chunk uid to segment uid // convert chunk uid to segment uid
for (auto& x : sub_result.mutable_seg_offsets()) { for (auto& x : sub_result.mutable_seg_offsets()) {

View File

@ -381,14 +381,7 @@ SegmentSealedImpl::vector_search(int64_t vec_count,
auto vec_data = insert_record_.get_field_data_base(field_id); auto vec_data = insert_record_.get_field_data_base(field_id);
AssertInfo(vec_data->num_chunk() == 1, "num chunk not equal to 1 for sealed segment"); AssertInfo(vec_data->num_chunk() == 1, "num chunk not equal to 1 for sealed segment");
auto chunk_data = vec_data->get_chunk_data(0); auto chunk_data = vec_data->get_chunk_data(0);
auto sub_qr = query::BruteForceSearch(dataset, chunk_data, row_count, bitset);
auto sub_qr = [&] {
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
return query::FloatSearchBruteForce(dataset, chunk_data, row_count, bitset);
} else {
return query::BinarySearchBruteForce(dataset, chunk_data, row_count, bitset);
}
}();
SearchResult results; SearchResult results;
results.distances_ = std::move(sub_qr.mutable_distances()); results.distances_ = std::move(sub_qr.mutable_distances());

View File

@ -11,8 +11,8 @@
# or implied. See the License for the specific language governing permissions and limitations under the License. # or implied. See the License for the specific language governing permissions and limitations under the License.
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
set( KNOWHERE_VERSION v1.1.13 ) set( KNOWHERE_VERSION v1.1.14 )
set( KNOWHERE_SOURCE_MD5 "5ea7ce8ae71b4aa496ee3c66ccf56d5a") set( KNOWHERE_SOURCE_MD5 "de9303c3f667662aa92f3676a1f6ef96")
if ( DEFINED ENV{MILVUS_KNOWHERE_URL} ) if ( DEFINED ENV{MILVUS_KNOWHERE_URL} )
set( KNOWHERE_SOURCE_URL "$ENV{MILVUS_KNOWHERE_URL}" ) set( KNOWHERE_SOURCE_URL "$ENV{MILVUS_KNOWHERE_URL}" )

View File

@ -114,10 +114,10 @@ class TestFloatSearchBruteForce : public ::testing::Test {
dataset::SearchDataset dataset{metric_type, nq, topk, -1, dim, query.data()}; dataset::SearchDataset dataset{metric_type, nq, topk, -1, dim, query.data()};
if (!is_supported_float_metric(metric_type)) { if (!is_supported_float_metric(metric_type)) {
ASSERT_ANY_THROW(FloatSearchBruteForce(dataset, base.data(), nb, bitset_view)); ASSERT_ANY_THROW(BruteForceSearch(dataset, base.data(), nb, bitset_view));
return; return;
} }
auto result = FloatSearchBruteForce(dataset, base.data(), nb, bitset_view); auto result = BruteForceSearch(dataset, base.data(), nb, bitset_view);
for (int i = 0; i < nq; i++) { for (int i = 0; i < nq; i++) {
auto ref = Ref(base.data(), query.data() + i * dim, nb, dim, topk, metric_type); auto ref = Ref(base.data(), query.data() + i * dim, nb, dim, topk, metric_type);
auto ans = result.get_seg_offsets() + i * topk; auto ans = result.get_seg_offsets() + i * topk;

View File

@ -324,7 +324,7 @@ TEST(Indexing, BinaryBruteForce) {
query_data // query_data //
}; };
auto sub_result = query::BinarySearchBruteForce(search_dataset, bin_vec.data(), N, nullptr); auto sub_result = query::BruteForceSearch(search_dataset, bin_vec.data(), N, nullptr);
SearchResult sr; SearchResult sr;
sr.total_nq_ = num_queries; sr.total_nq_ = num_queries;

View File

@ -538,7 +538,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
dim, // dim, //
query_ptr // query_ptr //
}; };
auto sub_result = FloatSearchBruteForce(search_dataset, vec_col.data(), N, nullptr); auto sub_result = BruteForceSearch(search_dataset, vec_col.data(), N, nullptr);
auto sr = segment->Search(plan.get(), ph_group.get(), time); auto sr = segment->Search(plan.get(), ph_group.get(), time);
segment->FillPrimaryKeys(plan.get(), *sr); segment->FillPrimaryKeys(plan.get(), *sr);