mirror of https://github.com/milvus-io/milvus.git
Let growing segment call knowhere brute search API (#18227)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>pull/18240/head
parent
ff3de654c8
commit
015a2f0866
|
@ -12,105 +12,26 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <faiss/utils/BinaryDistance.h>
|
||||
#include <faiss/utils/distances.h>
|
||||
|
||||
#include "SearchBruteForce.h"
|
||||
#include "SubSearchResult.h"
|
||||
#include "common/Types.h"
|
||||
#include "segcore/Utils.h"
|
||||
#include "knowhere/archive/BruteForce.h"
|
||||
|
||||
namespace milvus::query {
|
||||
|
||||
// copy from faiss/IndexBinaryFlat.cpp::IndexBinaryFlat::search()
|
||||
// disable lint to make further migration easier
|
||||
static void
|
||||
binary_search(const knowhere::MetricType& metric_type,
|
||||
const uint8_t* xb,
|
||||
int64_t ntotal,
|
||||
int code_size,
|
||||
idx_t n, // num_queries
|
||||
const uint8_t* x,
|
||||
idx_t k, // topk
|
||||
float* D,
|
||||
idx_t* labels,
|
||||
const BitsetView bitset) {
|
||||
using namespace faiss; // NOLINT
|
||||
if (metric_type == knowhere::metric::JACCARD || metric_type == knowhere::metric::TANIMOTO) {
|
||||
float_maxheap_array_t res = {size_t(n), size_t(k), labels, D};
|
||||
binary_distance_knn_hc(METRIC_Jaccard, &res, x, xb, ntotal, code_size, bitset);
|
||||
|
||||
if (metric_type == knowhere::metric::TANIMOTO) {
|
||||
for (int i = 0; i < k * n; i++) {
|
||||
D[i] = Jaccard_2_Tanimoto(D[i]);
|
||||
}
|
||||
}
|
||||
} else if (metric_type == knowhere::metric::HAMMING) {
|
||||
std::vector<int32_t> int_distances(n * k);
|
||||
int_maxheap_array_t res = {size_t(n), size_t(k), labels, int_distances.data()};
|
||||
binary_distance_knn_hc(METRIC_Hamming, &res, x, xb, ntotal, code_size, bitset);
|
||||
for (int i = 0; i < n * k; ++i) {
|
||||
D[i] = int_distances[i];
|
||||
}
|
||||
} else if (metric_type == knowhere::metric::SUBSTRUCTURE || metric_type == knowhere::metric::SUPERSTRUCTURE) {
|
||||
// only matched ids will be chosen, not to use heap
|
||||
auto faiss_metric_type = knowhere::GetFaissMetricType(metric_type);
|
||||
binary_distance_knn_mc(faiss_metric_type, x, xb, n, ntotal, k, code_size, D, labels, bitset);
|
||||
} else {
|
||||
std::string msg = "binary search not support metric type: " + metric_type;
|
||||
PanicInfo(msg);
|
||||
}
|
||||
}
|
||||
|
||||
SubSearchResult
|
||||
BinarySearchBruteForce(const dataset::SearchDataset& dataset,
|
||||
BruteForceSearch(const dataset::SearchDataset& dataset,
|
||||
const void* chunk_data_raw,
|
||||
int64_t size_per_chunk,
|
||||
int64_t chunk_rows,
|
||||
const BitsetView& bitset) {
|
||||
// TODO: refactor the internal function
|
||||
auto metric_type = dataset.metric_type;
|
||||
auto num_queries = dataset.num_queries;
|
||||
auto topk = dataset.topk;
|
||||
auto dim = dataset.dim;
|
||||
auto round_decimal = dataset.round_decimal;
|
||||
SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal);
|
||||
auto query_data = reinterpret_cast<const uint8_t*>(dataset.query_data);
|
||||
auto chunk_data = reinterpret_cast<const uint8_t*>(chunk_data_raw);
|
||||
|
||||
int64_t code_size = dim / 8;
|
||||
binary_search(metric_type, chunk_data, size_per_chunk, code_size, num_queries, query_data, topk,
|
||||
sub_result.get_distances(), sub_result.get_seg_offsets(), bitset);
|
||||
SubSearchResult sub_result(dataset.num_queries, dataset.topk, dataset.metric_type, dataset.round_decimal);
|
||||
try {
|
||||
knowhere::BruteForceSearch(dataset.metric_type, chunk_data_raw, dataset.query_data, dataset.dim, chunk_rows,
|
||||
dataset.num_queries, dataset.topk, sub_result.get_seg_offsets(),
|
||||
sub_result.get_distances(), bitset);
|
||||
} catch (std::exception& e) {
|
||||
PanicInfo(e.what());
|
||||
}
|
||||
sub_result.round_values();
|
||||
return sub_result;
|
||||
}
|
||||
|
||||
SubSearchResult
|
||||
FloatSearchBruteForce(const dataset::SearchDataset& dataset,
|
||||
const void* chunk_data_raw,
|
||||
int64_t size_per_chunk,
|
||||
const BitsetView& bitset) {
|
||||
auto metric_type = dataset.metric_type;
|
||||
auto num_queries = dataset.num_queries;
|
||||
auto topk = dataset.topk;
|
||||
auto dim = dataset.dim;
|
||||
auto round_decimal = dataset.round_decimal;
|
||||
SubSearchResult sub_qr(num_queries, topk, metric_type, round_decimal);
|
||||
auto query_data = reinterpret_cast<const float*>(dataset.query_data);
|
||||
auto chunk_data = reinterpret_cast<const float*>(chunk_data_raw);
|
||||
if (metric_type == knowhere::metric::L2) {
|
||||
faiss::float_maxheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_seg_offsets(),
|
||||
sub_qr.get_distances()};
|
||||
faiss::knn_L2sqr(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, nullptr, bitset);
|
||||
} else if (metric_type == knowhere::metric::IP) {
|
||||
faiss::float_minheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_seg_offsets(),
|
||||
sub_qr.get_distances()};
|
||||
faiss::knn_inner_product(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, bitset);
|
||||
} else {
|
||||
std::string msg = "search not support metric type: " + metric_type;
|
||||
PanicInfo(msg);
|
||||
}
|
||||
sub_qr.round_values();
|
||||
return sub_qr;
|
||||
}
|
||||
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -11,24 +11,16 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "common/Schema.h"
|
||||
#include "common/BitsetView.h"
|
||||
#include "query/SubSearchResult.h"
|
||||
#include "query/helper.h"
|
||||
#include "segcore/ConcurrentVector.h"
|
||||
|
||||
namespace milvus::query {
|
||||
|
||||
SubSearchResult
|
||||
BinarySearchBruteForce(const dataset::SearchDataset& dataset,
|
||||
BruteForceSearch(const dataset::SearchDataset& dataset,
|
||||
const void* chunk_data_raw,
|
||||
int64_t size_per_chunk,
|
||||
const BitsetView& bitset);
|
||||
|
||||
SubSearchResult
|
||||
FloatSearchBruteForce(const dataset::SearchDataset& dataset,
|
||||
const void* chunk_data_raw,
|
||||
int64_t size_per_chunk,
|
||||
int64_t chunk_rows,
|
||||
const BitsetView& bitset);
|
||||
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -89,7 +89,7 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment,
|
|||
auto size_per_chunk = element_end - element_begin;
|
||||
|
||||
auto sub_view = bitset.subview(element_begin, size_per_chunk);
|
||||
auto sub_qr = FloatSearchBruteForce(search_dataset, chunk.data(), size_per_chunk, sub_view);
|
||||
auto sub_qr = BruteForceSearch(search_dataset, chunk.data(), size_per_chunk, sub_view);
|
||||
|
||||
// convert chunk uid to segment uid
|
||||
for (auto& x : sub_qr.mutable_seg_offsets()) {
|
||||
|
@ -150,7 +150,7 @@ BinarySearch(const segcore::SegmentGrowingImpl& segment,
|
|||
auto nsize = element_end - element_begin;
|
||||
|
||||
auto sub_view = bitset.subview(element_begin, nsize);
|
||||
auto sub_result = BinarySearchBruteForce(search_dataset, chunk.data(), nsize, sub_view);
|
||||
auto sub_result = BruteForceSearch(search_dataset, chunk.data(), nsize, sub_view);
|
||||
|
||||
// convert chunk uid to segment uid
|
||||
for (auto& x : sub_result.mutable_seg_offsets()) {
|
||||
|
|
|
@ -381,14 +381,7 @@ SegmentSealedImpl::vector_search(int64_t vec_count,
|
|||
auto vec_data = insert_record_.get_field_data_base(field_id);
|
||||
AssertInfo(vec_data->num_chunk() == 1, "num chunk not equal to 1 for sealed segment");
|
||||
auto chunk_data = vec_data->get_chunk_data(0);
|
||||
|
||||
auto sub_qr = [&] {
|
||||
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
|
||||
return query::FloatSearchBruteForce(dataset, chunk_data, row_count, bitset);
|
||||
} else {
|
||||
return query::BinarySearchBruteForce(dataset, chunk_data, row_count, bitset);
|
||||
}
|
||||
}();
|
||||
auto sub_qr = query::BruteForceSearch(dataset, chunk_data, row_count, bitset);
|
||||
|
||||
SearchResult results;
|
||||
results.distances_ = std::move(sub_qr.mutable_distances());
|
||||
|
|
|
@ -11,8 +11,8 @@
|
|||
# or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
set( KNOWHERE_VERSION v1.1.13 )
|
||||
set( KNOWHERE_SOURCE_MD5 "5ea7ce8ae71b4aa496ee3c66ccf56d5a")
|
||||
set( KNOWHERE_VERSION v1.1.14 )
|
||||
set( KNOWHERE_SOURCE_MD5 "de9303c3f667662aa92f3676a1f6ef96")
|
||||
|
||||
if ( DEFINED ENV{MILVUS_KNOWHERE_URL} )
|
||||
set( KNOWHERE_SOURCE_URL "$ENV{MILVUS_KNOWHERE_URL}" )
|
||||
|
|
|
@ -114,10 +114,10 @@ class TestFloatSearchBruteForce : public ::testing::Test {
|
|||
|
||||
dataset::SearchDataset dataset{metric_type, nq, topk, -1, dim, query.data()};
|
||||
if (!is_supported_float_metric(metric_type)) {
|
||||
ASSERT_ANY_THROW(FloatSearchBruteForce(dataset, base.data(), nb, bitset_view));
|
||||
ASSERT_ANY_THROW(BruteForceSearch(dataset, base.data(), nb, bitset_view));
|
||||
return;
|
||||
}
|
||||
auto result = FloatSearchBruteForce(dataset, base.data(), nb, bitset_view);
|
||||
auto result = BruteForceSearch(dataset, base.data(), nb, bitset_view);
|
||||
for (int i = 0; i < nq; i++) {
|
||||
auto ref = Ref(base.data(), query.data() + i * dim, nb, dim, topk, metric_type);
|
||||
auto ans = result.get_seg_offsets() + i * topk;
|
||||
|
|
|
@ -324,7 +324,7 @@ TEST(Indexing, BinaryBruteForce) {
|
|||
query_data //
|
||||
};
|
||||
|
||||
auto sub_result = query::BinarySearchBruteForce(search_dataset, bin_vec.data(), N, nullptr);
|
||||
auto sub_result = query::BruteForceSearch(search_dataset, bin_vec.data(), N, nullptr);
|
||||
|
||||
SearchResult sr;
|
||||
sr.total_nq_ = num_queries;
|
||||
|
|
|
@ -538,7 +538,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
|
|||
dim, //
|
||||
query_ptr //
|
||||
};
|
||||
auto sub_result = FloatSearchBruteForce(search_dataset, vec_col.data(), N, nullptr);
|
||||
auto sub_result = BruteForceSearch(search_dataset, vec_col.data(), N, nullptr);
|
||||
|
||||
auto sr = segment->Search(plan.get(), ph_group.get(), time);
|
||||
segment->FillPrimaryKeys(plan.get(), *sr);
|
||||
|
|
Loading…
Reference in New Issue