mirror of https://github.com/milvus-io/milvus.git
parent
bf3a6dc3c6
commit
3bf205d9a8
|
@ -106,6 +106,13 @@ struct FieldMeta {
|
|||
return vector_info_->dim_;
|
||||
}
|
||||
|
||||
// Returns the metric type stored in this field's vector info.
// Asserts that the field is a vector field and that the vector info is present
// before reading it.
MetricType
get_metric_type() const {
    Assert(is_vector());
    Assert(vector_info_.has_value());
    const auto& info = *vector_info_;
    return info.metric_type_;
}
|
||||
|
||||
const std::string&
|
||||
get_name() const {
|
||||
return name_;
|
||||
|
|
|
@ -20,24 +20,25 @@
|
|||
|
||||
namespace milvus {
|
||||
|
||||
using boost::algorithm::to_lower_copy;
|
||||
using boost::algorithm::to_upper_copy;
|
||||
namespace Metric = knowhere::Metric;
|
||||
static const auto metric_bimap = [] {
|
||||
boost::bimap<std::string, MetricType> mapping;
|
||||
using pos = boost::bimap<std::string, MetricType>::value_type;
|
||||
mapping.insert(pos(to_lower_copy(std::string(Metric::L2)), MetricType::METRIC_L2));
|
||||
mapping.insert(pos(to_lower_copy(std::string(Metric::IP)), MetricType::METRIC_INNER_PRODUCT));
|
||||
mapping.insert(pos(to_lower_copy(std::string(Metric::JACCARD)), MetricType::METRIC_Jaccard));
|
||||
mapping.insert(pos(to_lower_copy(std::string(Metric::TANIMOTO)), MetricType::METRIC_Tanimoto));
|
||||
mapping.insert(pos(to_lower_copy(std::string(Metric::HAMMING)), MetricType::METRIC_Hamming));
|
||||
mapping.insert(pos(to_lower_copy(std::string(Metric::SUBSTRUCTURE)), MetricType::METRIC_Substructure));
|
||||
mapping.insert(pos(to_lower_copy(std::string(Metric::SUPERSTRUCTURE)), MetricType::METRIC_Superstructure));
|
||||
mapping.insert(pos(std::string(Metric::L2), MetricType::METRIC_L2));
|
||||
mapping.insert(pos(std::string(Metric::IP), MetricType::METRIC_INNER_PRODUCT));
|
||||
mapping.insert(pos(std::string(Metric::JACCARD), MetricType::METRIC_Jaccard));
|
||||
mapping.insert(pos(std::string(Metric::TANIMOTO), MetricType::METRIC_Tanimoto));
|
||||
mapping.insert(pos(std::string(Metric::HAMMING), MetricType::METRIC_Hamming));
|
||||
mapping.insert(pos(std::string(Metric::SUBSTRUCTURE), MetricType::METRIC_Substructure));
|
||||
mapping.insert(pos(std::string(Metric::SUPERSTRUCTURE), MetricType::METRIC_Superstructure));
|
||||
return mapping;
|
||||
}();
|
||||
|
||||
MetricType
|
||||
GetMetricType(const std::string& type_name) {
|
||||
auto real_name = to_lower_copy(type_name);
|
||||
// Assume Metric is all upper at knowhere
|
||||
auto real_name = to_upper_copy(type_name);
|
||||
AssertInfo(metric_bimap.left.count(real_name), "metric type not found: (" + type_name + ")");
|
||||
return metric_bimap.left.at(real_name);
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ set(MILVUS_QUERY_SRCS
|
|||
Plan.cpp
|
||||
Search.cpp
|
||||
SearchOnSealed.cpp
|
||||
SearchOnIndex.cpp
|
||||
SearchBruteForce.cpp
|
||||
SubQueryResult.cpp
|
||||
)
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include <faiss/utils/distances.h>
|
||||
#include "utils/tools.h"
|
||||
#include "query/SearchBruteForce.h"
|
||||
#include "query/SearchOnIndex.h"
|
||||
|
||||
namespace milvus::query {
|
||||
|
||||
|
@ -65,11 +66,14 @@ FloatSearch(const segcore::SegmentSmallIndex& segment,
|
|||
auto dim = field.get_dim();
|
||||
auto topK = info.topK_;
|
||||
auto total_count = topK * num_queries;
|
||||
auto metric_type = GetMetricType(info.metric_type_);
|
||||
// TODO: optimize
|
||||
|
||||
// step 3: small indexing search
|
||||
std::vector<int64_t> final_uids(total_count, -1);
|
||||
std::vector<float> final_dis(total_count, std::numeric_limits<float>::max());
|
||||
// std::vector<int64_t> final_uids(total_count, -1);
|
||||
// std::vector<float> final_dis(total_count, std::numeric_limits<float>::max());
|
||||
SubQueryResult final_qr(num_queries, topK, metric_type);
|
||||
dataset::FloatQueryDataset query_dataset{metric_type, num_queries, topK, dim, query_data};
|
||||
|
||||
auto max_indexed_id = indexing_record.get_finished_ack();
|
||||
const auto& indexing_entry = indexing_record.get_vec_entry(vecfield_offset);
|
||||
|
@ -77,20 +81,18 @@ FloatSearch(const segcore::SegmentSmallIndex& segment,
|
|||
|
||||
// TODO: use sub_qr
|
||||
for (int chunk_id = 0; chunk_id < max_indexed_id; ++chunk_id) {
|
||||
auto bitset = create_bitmap_view(bitmaps_opt, chunk_id);
|
||||
auto indexing = indexing_entry.get_vec_indexing(chunk_id);
|
||||
auto dataset = knowhere::GenDataset(num_queries, dim, query_data);
|
||||
auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id);
|
||||
auto ans = indexing->Query(dataset, search_conf, bitmap_view);
|
||||
auto dis = ans->Get<float*>(milvus::knowhere::meta::DISTANCE);
|
||||
auto uids = ans->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
auto sub_qr = SearchOnIndex(query_dataset, *indexing, search_conf, bitset);
|
||||
|
||||
// convert chunk uid to segment uid
|
||||
for (int64_t i = 0; i < total_count; ++i) {
|
||||
auto& x = uids[i];
|
||||
for (auto& x : sub_qr.mutable_labels()) {
|
||||
if (x != -1) {
|
||||
x += chunk_id * indexing_entry.get_chunk_size();
|
||||
}
|
||||
}
|
||||
segcore::merge_into(num_queries, topK, final_dis.data(), final_uids.data(), dis, uids);
|
||||
|
||||
final_qr.merge(sub_qr);
|
||||
}
|
||||
using segcore::FloatVector;
|
||||
auto vec_ptr = record.get_entity<FloatVector>(vecfield_offset);
|
||||
|
@ -100,37 +102,28 @@ FloatSearch(const segcore::SegmentSmallIndex& segment,
|
|||
Assert(vec_chunk_size == indexing_entry.get_chunk_size());
|
||||
auto max_chunk = upper_div(ins_barrier, vec_chunk_size);
|
||||
|
||||
// TODO: use sub_qr
|
||||
for (int chunk_id = max_indexed_id; chunk_id < max_chunk; ++chunk_id) {
|
||||
std::vector<int64_t> buf_uids(total_count, -1);
|
||||
std::vector<float> buf_dis(total_count, std::numeric_limits<float>::max());
|
||||
auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id);
|
||||
|
||||
// should be not visitable
|
||||
faiss::float_maxheap_array_t buf = {(size_t)num_queries, (size_t)topK, buf_uids.data(), buf_dis.data()};
|
||||
auto& chunk = vec_ptr->get_chunk(chunk_id);
|
||||
|
||||
auto element_begin = chunk_id * vec_chunk_size;
|
||||
auto element_end = std::min(ins_barrier, (chunk_id + 1) * vec_chunk_size);
|
||||
auto chunk_size = element_end - element_begin;
|
||||
|
||||
auto nsize = element_end - element_begin;
|
||||
|
||||
auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id);
|
||||
// TODO: make it wrapped
|
||||
faiss::knn_L2sqr(query_data, chunk.data(), dim, num_queries, nsize, &buf, bitmap_view);
|
||||
|
||||
Assert(buf_uids.size() == total_count);
|
||||
auto sub_qr = FloatSearchBruteForce(query_dataset, chunk.data(), chunk_size, bitmap_view);
|
||||
|
||||
// convert chunk uid to segment uid
|
||||
for (auto& x : buf_uids) {
|
||||
for (auto& x : sub_qr.mutable_labels()) {
|
||||
if (x != -1) {
|
||||
x += chunk_id * vec_chunk_size;
|
||||
}
|
||||
}
|
||||
segcore::merge_into(num_queries, topK, final_dis.data(), final_uids.data(), buf_dis.data(), buf_uids.data());
|
||||
final_qr.merge(sub_qr);
|
||||
}
|
||||
|
||||
results.result_distances_ = std::move(final_dis);
|
||||
results.internal_seg_offsets_ = std::move(final_uids);
|
||||
results.result_distances_ = std::move(final_qr.mutable_values());
|
||||
results.internal_seg_offsets_ = std::move(final_qr.mutable_labels());
|
||||
results.topK_ = topK;
|
||||
results.num_queries_ = num_queries;
|
||||
|
||||
|
@ -168,14 +161,13 @@ BinarySearch(const segcore::SegmentSmallIndex& segment,
|
|||
|
||||
Assert(field.get_data_type() == DataType::VECTOR_BINARY);
|
||||
auto dim = field.get_dim();
|
||||
auto code_size = dim / 8;
|
||||
auto topK = info.topK_;
|
||||
auto total_count = topK * num_queries;
|
||||
|
||||
// step 3: small indexing search
|
||||
// TODO: this is too intrusive
|
||||
// TODO: use QuerySubResult instead
|
||||
query::dataset::BinaryQueryDataset query_dataset{metric_type, num_queries, topK, code_size, query_data};
|
||||
query::dataset::BinaryQueryDataset query_dataset{metric_type, num_queries, topK, dim, query_data};
|
||||
|
||||
using segcore::BinaryVector;
|
||||
auto vec_ptr = record.get_entity<BinaryVector>(vecfield_offset);
|
||||
|
|
|
@ -16,11 +16,13 @@
|
|||
#include <queue>
|
||||
#include "SubQueryResult.h"
|
||||
|
||||
#include <faiss/utils/distances.h>
|
||||
|
||||
namespace milvus::query {
|
||||
|
||||
SubQueryResult
|
||||
BinarySearchBruteForceFast(MetricType metric_type,
|
||||
int64_t code_size,
|
||||
int64_t dim,
|
||||
const uint8_t* binary_chunk,
|
||||
int64_t chunk_size,
|
||||
int64_t topk,
|
||||
|
@ -31,6 +33,7 @@ BinarySearchBruteForceFast(MetricType metric_type,
|
|||
float* result_distances = sub_result.get_values();
|
||||
idx_t* result_labels = sub_result.get_labels();
|
||||
|
||||
int64_t code_size = dim / 8;
|
||||
const idx_t block_size = chunk_size;
|
||||
bool use_heap = true;
|
||||
|
||||
|
@ -95,14 +98,26 @@ BinarySearchBruteForceFast(MetricType metric_type,
|
|||
return sub_result;
|
||||
}
|
||||
|
||||
void
|
||||
FloatSearchBruteForceFast(MetricType metric_type,
|
||||
const float* chunk_data,
|
||||
int64_t chunk_size,
|
||||
float* result_distances,
|
||||
idx_t* result_labels,
|
||||
const faiss::BitsetView& bitset) {
|
||||
// TODO
|
||||
SubQueryResult
|
||||
FloatSearchBruteForce(const dataset::FloatQueryDataset& query_dataset,
|
||||
const float* chunk_data,
|
||||
int64_t chunk_size,
|
||||
const faiss::BitsetView& bitset) {
|
||||
auto metric_type = query_dataset.metric_type;
|
||||
auto num_queries = query_dataset.num_queries;
|
||||
auto topk = query_dataset.topk;
|
||||
auto dim = query_dataset.dim;
|
||||
SubQueryResult sub_qr(num_queries, topk, metric_type);
|
||||
|
||||
if (metric_type == MetricType::METRIC_L2) {
|
||||
faiss::float_maxheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_labels(), sub_qr.get_values()};
|
||||
faiss::knn_L2sqr(query_dataset.query_data, chunk_data, dim, num_queries, chunk_size, &buf, bitset);
|
||||
return sub_qr;
|
||||
} else {
|
||||
faiss::float_minheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_labels(), sub_qr.get_values()};
|
||||
faiss::knn_inner_product(query_dataset.query_data, chunk_data, dim, num_queries, chunk_size, &buf, bitset);
|
||||
return sub_qr;
|
||||
}
|
||||
}
|
||||
|
||||
SubQueryResult
|
||||
|
@ -111,7 +126,7 @@ BinarySearchBruteForce(const dataset::BinaryQueryDataset& query_dataset,
|
|||
int64_t chunk_size,
|
||||
const faiss::BitsetView& bitset) {
|
||||
// TODO: refactor the internal function
|
||||
return BinarySearchBruteForceFast(query_dataset.metric_type, query_dataset.code_size, binary_chunk, chunk_size,
|
||||
return BinarySearchBruteForceFast(query_dataset.metric_type, query_dataset.dim, binary_chunk, chunk_size,
|
||||
query_dataset.topk, query_dataset.num_queries, query_dataset.query_data, bitset);
|
||||
}
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -14,25 +14,20 @@
|
|||
#include "segcore/ConcurrentVector.h"
|
||||
#include "common/Schema.h"
|
||||
#include "query/SubQueryResult.h"
|
||||
#include "query/helper.h"
|
||||
|
||||
namespace milvus::query {
|
||||
using MetricType = faiss::MetricType;
|
||||
|
||||
namespace dataset {
|
||||
struct BinaryQueryDataset {
|
||||
MetricType metric_type;
|
||||
int64_t num_queries;
|
||||
int64_t topk;
|
||||
int64_t code_size;
|
||||
const uint8_t* query_data;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
||||
SubQueryResult
|
||||
BinarySearchBruteForce(const dataset::BinaryQueryDataset& query_dataset,
|
||||
const uint8_t* binary_chunk,
|
||||
int64_t chunk_size,
|
||||
const faiss::BitsetView& bitset = nullptr);
|
||||
const faiss::BitsetView& bitset);
|
||||
|
||||
SubQueryResult
|
||||
FloatSearchBruteForce(const dataset::FloatQueryDataset& query_dataset,
|
||||
const float* chunk_data,
|
||||
int64_t chunk_size,
|
||||
const faiss::BitsetView& bitset);
|
||||
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "SearchOnIndex.h"
|
||||
namespace milvus::query {
|
||||
// Runs the float queries in query_dataset against a vector index and packs the
// returned distances and ids into a SubQueryResult of num_queries * topk entries.
SubQueryResult
SearchOnIndex(const dataset::FloatQueryDataset& query_dataset,
              const knowhere::VecIndex& indexing,
              const knowhere::Config& search_conf,
              const faiss::BitsetView& bitset) {
    const auto nq = query_dataset.num_queries;
    const auto k = query_dataset.topk;
    const auto result_count = nq * k;

    auto query_ds = knowhere::GenDataset(nq, query_dataset.dim, query_dataset.query_data);

    // NOTE: VecIndex::Query is missing a const qualifier, so constness is cast
    // away here as a workaround.
    auto& mutable_index = const_cast<knowhere::VecIndex&>(indexing);
    auto answer = mutable_index.Query(query_ds, search_conf, bitset);

    const auto distances = answer->Get<float*>(milvus::knowhere::meta::DISTANCE);
    const auto ids = answer->Get<int64_t*>(milvus::knowhere::meta::IDS);

    // Copy the index output into a result object owned by the caller.
    SubQueryResult sub_qr(nq, k, query_dataset.metric_type);
    std::copy_n(distances, result_count, sub_qr.get_values());
    std::copy_n(ids, result_count, sub_qr.get_labels());
    return sub_qr;
}
|
||||
|
||||
} // namespace milvus::query
|
|
@ -0,0 +1,27 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "query/SubQueryResult.h"
|
||||
#include "query/helper.h"
|
||||
#include "knowhere/index/vector_index/VecIndex.h"
|
||||
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
|
||||
#include "utils/Json.h"
|
||||
|
||||
namespace milvus::query {
|
||||
SubQueryResult
|
||||
SearchOnIndex(const dataset::FloatQueryDataset& query_dataset,
|
||||
const knowhere::VecIndex& indexing,
|
||||
const knowhere::Config& search_conf,
|
||||
const faiss::BitsetView& bitset);
|
||||
|
||||
} // namespace milvus::query
|
|
@ -0,0 +1,35 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
#include "common/Types.h"
|
||||
|
||||
namespace milvus::query {
|
||||
namespace dataset {
|
||||
|
||||
struct FloatQueryDataset {
|
||||
MetricType metric_type;
|
||||
int64_t num_queries;
|
||||
int64_t topk;
|
||||
int64_t dim;
|
||||
const float* query_data;
|
||||
};
|
||||
|
||||
struct BinaryQueryDataset {
|
||||
MetricType metric_type;
|
||||
int64_t num_queries;
|
||||
int64_t topk;
|
||||
int64_t dim;
|
||||
const uint8_t* query_data;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace milvus::query
|
|
@ -22,6 +22,7 @@
|
|||
#include "utils/EasyAssert.h"
|
||||
#include "utils/tools.h"
|
||||
#include <boost/container/vector.hpp>
|
||||
#include "common/Types.h"
|
||||
|
||||
namespace milvus::segcore {
|
||||
|
||||
|
@ -213,10 +214,15 @@ class ConcurrentVector : public ConcurrentVectorImpl<Type, true> {
|
|||
// Empty base class shared by the vector tag types declared below.
class VectorTrait {};
|
||||
|
||||
class FloatVector : public VectorTrait {
|
||||
public:
|
||||
using embedded_type = float;
|
||||
static constexpr auto metric_type = DataType::VECTOR_FLOAT;
|
||||
};
|
||||
|
||||
class BinaryVector : public VectorTrait {
|
||||
public:
|
||||
using embedded_type = uint8_t;
|
||||
static constexpr auto metric_type = DataType::VECTOR_BINARY;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
|
|
@ -45,7 +45,7 @@ VecIndexingEntry::get_build_conf() const {
|
|||
return knowhere::Config{{knowhere::meta::DIM, field_meta_.get_dim()},
|
||||
{knowhere::IndexParams::nlist, 100},
|
||||
{knowhere::IndexParams::nprobe, 4},
|
||||
{knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
|
||||
{knowhere::Metric::TYPE, MetricTypeToName(field_meta_.get_metric_type())},
|
||||
{knowhere::meta::DEVICEID, 0}};
|
||||
}
|
||||
|
||||
|
@ -55,7 +55,7 @@ VecIndexingEntry::get_search_conf(int top_K) const {
|
|||
{knowhere::meta::TOPK, top_K},
|
||||
{knowhere::IndexParams::nlist, 100},
|
||||
{knowhere::IndexParams::nprobe, 4},
|
||||
{knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
|
||||
{knowhere::Metric::TYPE, MetricTypeToName(field_meta_.get_metric_type())},
|
||||
{knowhere::meta::DEVICEID, 0}};
|
||||
}
|
||||
|
||||
|
|
|
@ -246,17 +246,16 @@ TEST(Indexing, BinaryBruteForce) {
|
|||
schema->AddField("age", DataType::INT64);
|
||||
auto dataset = DataGen(schema, N, 10);
|
||||
auto bin_vec = dataset.get_col<uint8_t>(0);
|
||||
auto line_sizeof = schema->operator[](0).get_sizeof();
|
||||
auto query_data = 1024 * line_sizeof + bin_vec.data();
|
||||
auto query_data = 1024 * dim / 8 + bin_vec.data();
|
||||
query::dataset::BinaryQueryDataset query_dataset{
|
||||
faiss::MetricType::METRIC_Jaccard, //
|
||||
num_queries, //
|
||||
topk, //
|
||||
line_sizeof, //
|
||||
dim, //
|
||||
query_data //
|
||||
};
|
||||
|
||||
auto sub_result = query::BinarySearchBruteForce(query_dataset, bin_vec.data(), N);
|
||||
auto sub_result = query::BinarySearchBruteForce(query_dataset, bin_vec.data(), N, nullptr);
|
||||
|
||||
QueryResult qr;
|
||||
qr.num_queries_ = num_queries;
|
||||
|
|
|
@ -312,6 +312,51 @@ TEST(Query, ExecTerm) {
|
|||
// for(auto x: )
|
||||
}
|
||||
|
||||
// Searching a segment into which nothing was inserted must report only
// placeholder hits: offset -1 and distance float-max for every result slot.
TEST(Query, ExecEmpty) {
    using namespace milvus::query;
    using namespace milvus::segcore;
    auto schema = std::make_shared<Schema>();
    schema->AddField("age", DataType::FLOAT);
    schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
    std::string dsl = R"({
        "bool": {
            "must": [
            {
                "vector": {
                    "fakevec": {
                        "metric_type": "L2",
                        "params": {
                            "nprobe": 10
                        },
                        "query": "$0",
                        "topk": 5
                    }
                }
            }
            ]
        }
    })";
    // No PreInsert/Insert on purpose: the segment stays empty.
    auto segment = CreateSegment(schema);
    auto plan = CreatePlan(*schema, dsl);
    auto num_queries = 5;
    auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
    auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
    QueryResult qr;
    Timestamp time = 1000000;
    std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
    segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr);
    std::cout << QueryResultToJson(qr);

    for (auto i : qr.internal_seg_offsets_) {
        ASSERT_EQ(i, -1);
    }

    for (auto v : qr.result_distances_) {
        ASSERT_EQ(v, std::numeric_limits<float>::max());
    }
}
|
||||
|
||||
TEST(Query, ExecWithoutPredicate) {
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
|
@ -336,13 +381,13 @@ TEST(Query, ExecWithoutPredicate) {
|
|||
]
|
||||
}
|
||||
})";
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
int64_t N = 1000 * 1000;
|
||||
auto dataset = DataGen(schema, N);
|
||||
auto segment = CreateSegment(schema);
|
||||
segment->PreInsert(N);
|
||||
segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_);
|
||||
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
|
||||
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
|
@ -397,6 +442,47 @@ TEST(Query, ExecWithoutPredicate) {
|
|||
ASSERT_EQ(json.dump(2), ref.dump(2));
|
||||
}
|
||||
|
||||
// Smoke test for inner-product search: inserts generated vectors into a field
// named "normalized", queries with the first few inserted vectors and prints
// the result. It makes no assertions on the returned hits.
TEST(Indexing, InnerProduct) {
    int64_t N = 100000;
    constexpr auto dim = 16;
    auto num_queries = 5;
    auto schema = std::make_shared<Schema>();
    std::string dsl = R"({
        "bool": {
            "must": [
            {
                "vector": {
                    "normalized": {
                        "metric_type": "IP",
                        "params": {
                            "nprobe": 10
                        },
                        "query": "$0",
                        "topk": 5
                    }
                }
            }
            ]
        }
    })";
    schema->AddField("normalized", DataType::VECTOR_FLOAT, dim, MetricType::METRIC_INNER_PRODUCT);
    auto dataset = DataGen(schema, N);
    auto segment = CreateSegment(schema);
    auto plan = CreatePlan(*schema, dsl);
    segment->PreInsert(N);
    segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_);
    auto col = dataset.get_col<float>(0);

    // Query with the leading vectors of the inserted column; use `dim` rather
    // than a repeated literal so the placeholder stays in sync with the schema.
    auto ph_group_raw = CreatePlaceholderGroupFromBlob(num_queries, dim, col.data());
    auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
    std::vector<Timestamp> ts{(Timestamp)N * 2};
    const auto* ptr = ph_group.get();
    QueryResult qr;
    segment->Search(plan.get(), &ptr, ts.data(), 1, qr);
    std::cout << QueryResultToJson(qr).dump(2);
}
|
||||
|
||||
TEST(Query, FillSegment) {
|
||||
namespace pb = milvus::proto;
|
||||
pb::schema::CollectionSchema proto;
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
#include <cstring>
|
||||
#include "segcore/SegmentBase.h"
|
||||
#include "Constants.h"
|
||||
#include <boost/algorithm/string/predicate.hpp>
|
||||
using boost::algorithm::starts_with;
|
||||
|
||||
namespace milvus::segcore {
|
||||
|
||||
struct GeneratedData {
|
||||
|
@ -92,11 +95,25 @@ DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42) {
|
|||
switch (field.get_data_type()) {
|
||||
case engine::DataType::VECTOR_FLOAT: {
|
||||
auto dim = field.get_dim();
|
||||
vector<float> data(dim * N);
|
||||
for (auto& x : data) {
|
||||
x = distr(er) + offset;
|
||||
vector<float> final;
|
||||
bool is_ip = starts_with(field.get_name(), "normalized");
|
||||
for (int n = 0; n < N; ++n) {
|
||||
vector<float> data(dim);
|
||||
float sum = 0;
|
||||
for (auto& x : data) {
|
||||
x = distr(er) + offset;
|
||||
sum += x * x;
|
||||
}
|
||||
if (is_ip) {
|
||||
sum = sqrt(sum);
|
||||
for (auto& x : data) {
|
||||
x /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
final.insert(final.end(), data.begin(), data.end());
|
||||
}
|
||||
insert_cols(data);
|
||||
insert_cols(final);
|
||||
break;
|
||||
}
|
||||
case engine::DataType::VECTOR_BINARY: {
|
||||
|
@ -111,9 +128,9 @@ DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42) {
|
|||
}
|
||||
case engine::DataType::INT64: {
|
||||
vector<int64_t> data(N);
|
||||
int64_t index = 0;
|
||||
// begin with counter
|
||||
if (field.get_name().rfind("counter", 0) == 0) {
|
||||
if (starts_with(field.get_name(), "counter")) {
|
||||
int64_t index = 0;
|
||||
for (auto& x : data) {
|
||||
x = index++;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue