Fix inner product

Signed-off-by: FluorineDog <guilin.gou@zilliz.com>
pull/4973/head^2
FluorineDog 2021-01-07 09:32:17 +08:00 committed by yefu.chen
parent bf3a6dc3c6
commit 3bf205d9a8
14 changed files with 294 additions and 73 deletions

View File

@ -106,6 +106,13 @@ struct FieldMeta {
return vector_info_->dim_;
}
MetricType
get_metric_type() const {
    // A metric is only stored for vector fields; both checks guard the dereference below.
    Assert(is_vector());
    Assert(vector_info_.has_value());
    const auto& info = *vector_info_;
    return info.metric_type_;
}
const std::string&
get_name() const {
return name_;

View File

@ -20,24 +20,25 @@
namespace milvus {
using boost::algorithm::to_lower_copy;
using boost::algorithm::to_upper_copy;
namespace Metric = knowhere::Metric;
// Bidirectional lookup table between metric names (strings taken from knowhere::Metric)
// and MetricType enum values. The table is produced by an immediately-invoked lambda.
static const auto metric_bimap = [] {
boost::bimap<std::string, MetricType> mapping;
using pos = boost::bimap<std::string, MetricType>::value_type;
// Entries keyed by the lower-cased spelling of each knowhere metric name.
mapping.insert(pos(to_lower_copy(std::string(Metric::L2)), MetricType::METRIC_L2));
mapping.insert(pos(to_lower_copy(std::string(Metric::IP)), MetricType::METRIC_INNER_PRODUCT));
mapping.insert(pos(to_lower_copy(std::string(Metric::JACCARD)), MetricType::METRIC_Jaccard));
mapping.insert(pos(to_lower_copy(std::string(Metric::TANIMOTO)), MetricType::METRIC_Tanimoto));
mapping.insert(pos(to_lower_copy(std::string(Metric::HAMMING)), MetricType::METRIC_Hamming));
mapping.insert(pos(to_lower_copy(std::string(Metric::SUBSTRUCTURE)), MetricType::METRIC_Substructure));
mapping.insert(pos(to_lower_copy(std::string(Metric::SUPERSTRUCTURE)), MetricType::METRIC_Superstructure));
// Entries keyed by the knowhere metric name exactly as knowhere defines it.
// NOTE(review): this group targets the same MetricType values as the lower-cased group
// above; presumably only one of the two groups is meant to be live (this view looks like
// the removed and added lines of a diff shown together) -- confirm against the real file.
mapping.insert(pos(std::string(Metric::L2), MetricType::METRIC_L2));
mapping.insert(pos(std::string(Metric::IP), MetricType::METRIC_INNER_PRODUCT));
mapping.insert(pos(std::string(Metric::JACCARD), MetricType::METRIC_Jaccard));
mapping.insert(pos(std::string(Metric::TANIMOTO), MetricType::METRIC_Tanimoto));
mapping.insert(pos(std::string(Metric::HAMMING), MetricType::METRIC_Hamming));
mapping.insert(pos(std::string(Metric::SUBSTRUCTURE), MetricType::METRIC_Substructure));
mapping.insert(pos(std::string(Metric::SUPERSTRUCTURE), MetricType::METRIC_Superstructure));
return mapping;
}();
// Translates a metric name string into its MetricType enum value.
// The name is case-normalised before the lookup; AssertInfo fires with
// "metric type not found: (<type_name>)" when the normalised name is not a left key
// of metric_bimap.
MetricType
GetMetricType(const std::string& type_name) {
// NOTE(review): real_name is declared twice below (lower-cased, then upper-cased).
// This looks like the old and new lines of a diff shown together -- confirm which
// single declaration the actual file keeps.
auto real_name = to_lower_copy(type_name);
// Knowhere metric names are assumed to be all upper case.
auto real_name = to_upper_copy(type_name);
AssertInfo(metric_bimap.left.count(real_name), "metric type not found: (" + type_name + ")");
return metric_bimap.left.at(real_name);
}

View File

@ -12,6 +12,7 @@ set(MILVUS_QUERY_SRCS
Plan.cpp
Search.cpp
SearchOnSealed.cpp
SearchOnIndex.cpp
SearchBruteForce.cpp
SubQueryResult.cpp
)

View File

@ -17,6 +17,7 @@
#include <faiss/utils/distances.h>
#include "utils/tools.h"
#include "query/SearchBruteForce.h"
#include "query/SearchOnIndex.h"
namespace milvus::query {
@ -65,11 +66,14 @@ FloatSearch(const segcore::SegmentSmallIndex& segment,
auto dim = field.get_dim();
auto topK = info.topK_;
auto total_count = topK * num_queries;
auto metric_type = GetMetricType(info.metric_type_);
// TODO: optimize
// step 3: small indexing search
std::vector<int64_t> final_uids(total_count, -1);
std::vector<float> final_dis(total_count, std::numeric_limits<float>::max());
// std::vector<int64_t> final_uids(total_count, -1);
// std::vector<float> final_dis(total_count, std::numeric_limits<float>::max());
SubQueryResult final_qr(num_queries, topK, metric_type);
dataset::FloatQueryDataset query_dataset{metric_type, num_queries, topK, dim, query_data};
auto max_indexed_id = indexing_record.get_finished_ack();
const auto& indexing_entry = indexing_record.get_vec_entry(vecfield_offset);
@ -77,20 +81,18 @@ FloatSearch(const segcore::SegmentSmallIndex& segment,
// TODO: use sub_qr
for (int chunk_id = 0; chunk_id < max_indexed_id; ++chunk_id) {
auto bitset = create_bitmap_view(bitmaps_opt, chunk_id);
auto indexing = indexing_entry.get_vec_indexing(chunk_id);
auto dataset = knowhere::GenDataset(num_queries, dim, query_data);
auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id);
auto ans = indexing->Query(dataset, search_conf, bitmap_view);
auto dis = ans->Get<float*>(milvus::knowhere::meta::DISTANCE);
auto uids = ans->Get<int64_t*>(milvus::knowhere::meta::IDS);
auto sub_qr = SearchOnIndex(query_dataset, *indexing, search_conf, bitset);
// convert chunk uid to segment uid
for (int64_t i = 0; i < total_count; ++i) {
auto& x = uids[i];
for (auto& x : sub_qr.mutable_labels()) {
if (x != -1) {
x += chunk_id * indexing_entry.get_chunk_size();
}
}
segcore::merge_into(num_queries, topK, final_dis.data(), final_uids.data(), dis, uids);
final_qr.merge(sub_qr);
}
using segcore::FloatVector;
auto vec_ptr = record.get_entity<FloatVector>(vecfield_offset);
@ -100,37 +102,28 @@ FloatSearch(const segcore::SegmentSmallIndex& segment,
Assert(vec_chunk_size == indexing_entry.get_chunk_size());
auto max_chunk = upper_div(ins_barrier, vec_chunk_size);
// TODO: use sub_qr
for (int chunk_id = max_indexed_id; chunk_id < max_chunk; ++chunk_id) {
std::vector<int64_t> buf_uids(total_count, -1);
std::vector<float> buf_dis(total_count, std::numeric_limits<float>::max());
auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id);
// should be not visitable
faiss::float_maxheap_array_t buf = {(size_t)num_queries, (size_t)topK, buf_uids.data(), buf_dis.data()};
auto& chunk = vec_ptr->get_chunk(chunk_id);
auto element_begin = chunk_id * vec_chunk_size;
auto element_end = std::min(ins_barrier, (chunk_id + 1) * vec_chunk_size);
auto chunk_size = element_end - element_begin;
auto nsize = element_end - element_begin;
auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id);
// TODO: make it wrapped
faiss::knn_L2sqr(query_data, chunk.data(), dim, num_queries, nsize, &buf, bitmap_view);
Assert(buf_uids.size() == total_count);
auto sub_qr = FloatSearchBruteForce(query_dataset, chunk.data(), chunk_size, bitmap_view);
// convert chunk uid to segment uid
for (auto& x : buf_uids) {
for (auto& x : sub_qr.mutable_labels()) {
if (x != -1) {
x += chunk_id * vec_chunk_size;
}
}
segcore::merge_into(num_queries, topK, final_dis.data(), final_uids.data(), buf_dis.data(), buf_uids.data());
final_qr.merge(sub_qr);
}
results.result_distances_ = std::move(final_dis);
results.internal_seg_offsets_ = std::move(final_uids);
results.result_distances_ = std::move(final_qr.mutable_values());
results.internal_seg_offsets_ = std::move(final_qr.mutable_labels());
results.topK_ = topK;
results.num_queries_ = num_queries;
@ -168,14 +161,13 @@ BinarySearch(const segcore::SegmentSmallIndex& segment,
Assert(field.get_data_type() == DataType::VECTOR_BINARY);
auto dim = field.get_dim();
auto code_size = dim / 8;
auto topK = info.topK_;
auto total_count = topK * num_queries;
// step 3: small indexing search
// TODO: this is too intrusive
// TODO: use QuerySubResult instead
query::dataset::BinaryQueryDataset query_dataset{metric_type, num_queries, topK, code_size, query_data};
query::dataset::BinaryQueryDataset query_dataset{metric_type, num_queries, topK, dim, query_data};
using segcore::BinaryVector;
auto vec_ptr = record.get_entity<BinaryVector>(vecfield_offset);

View File

@ -16,11 +16,13 @@
#include <queue>
#include "SubQueryResult.h"
#include <faiss/utils/distances.h>
namespace milvus::query {
SubQueryResult
BinarySearchBruteForceFast(MetricType metric_type,
int64_t code_size,
int64_t dim,
const uint8_t* binary_chunk,
int64_t chunk_size,
int64_t topk,
@ -31,6 +33,7 @@ BinarySearchBruteForceFast(MetricType metric_type,
float* result_distances = sub_result.get_values();
idx_t* result_labels = sub_result.get_labels();
int64_t code_size = dim / 8;
const idx_t block_size = chunk_size;
bool use_heap = true;
@ -95,14 +98,26 @@ BinarySearchBruteForceFast(MetricType metric_type,
return sub_result;
}
void
FloatSearchBruteForceFast(MetricType metric_type,
const float* chunk_data,
int64_t chunk_size,
float* result_distances,
idx_t* result_labels,
const faiss::BitsetView& bitset) {
// TODO
// Exhaustive (non-indexed) top-k search of the float queries in query_dataset against
// the chunk_size vectors stored at chunk_data, filtered through bitset.
// Results are written straight into the returned SubQueryResult's label/value buffers.
SubQueryResult
FloatSearchBruteForce(const dataset::FloatQueryDataset& query_dataset,
                      const float* chunk_data,
                      int64_t chunk_size,
                      const faiss::BitsetView& bitset) {
    const auto metric_type = query_dataset.metric_type;
    const auto num_queries = query_dataset.num_queries;
    const auto topk = query_dataset.topk;
    const auto dim = query_dataset.dim;

    SubQueryResult sub_qr(num_queries, topk, metric_type);
    const auto heap_count = static_cast<size_t>(num_queries);
    const auto heap_width = static_cast<size_t>(topk);

    if (metric_type != MetricType::METRIC_L2) {
        // NOTE(review): every metric other than L2 takes the inner-product path --
        // confirm callers only ever pass L2 or inner product here.
        faiss::float_minheap_array_t ip_heaps{heap_count, heap_width, sub_qr.get_labels(), sub_qr.get_values()};
        faiss::knn_inner_product(query_dataset.query_data, chunk_data, dim, num_queries, chunk_size, &ip_heaps,
                                 bitset);
    } else {
        faiss::float_maxheap_array_t l2_heaps{heap_count, heap_width, sub_qr.get_labels(), sub_qr.get_values()};
        faiss::knn_L2sqr(query_dataset.query_data, chunk_data, dim, num_queries, chunk_size, &l2_heaps, bitset);
    }
    return sub_qr;
}
SubQueryResult
@ -111,7 +126,7 @@ BinarySearchBruteForce(const dataset::BinaryQueryDataset& query_dataset,
int64_t chunk_size,
const faiss::BitsetView& bitset) {
// TODO: refactor the internal function
return BinarySearchBruteForceFast(query_dataset.metric_type, query_dataset.code_size, binary_chunk, chunk_size,
return BinarySearchBruteForceFast(query_dataset.metric_type, query_dataset.dim, binary_chunk, chunk_size,
query_dataset.topk, query_dataset.num_queries, query_dataset.query_data, bitset);
}
} // namespace milvus::query

View File

@ -14,25 +14,20 @@
#include "segcore/ConcurrentVector.h"
#include "common/Schema.h"
#include "query/SubQueryResult.h"
#include "query/helper.h"
namespace milvus::query {
using MetricType = faiss::MetricType;
namespace dataset {
// Parameters describing one batch of binary-vector queries.
// Initialised positionally by callers, so the field order matters.
struct BinaryQueryDataset {
// Metric used to compare the vectors.
MetricType metric_type;
// Number of query vectors in the batch.
int64_t num_queries;
// Number of results requested per query.
int64_t topk;
// NOTE(review): presumably the size of one vector in bytes (callers compute dim / 8) -- confirm.
int64_t code_size;
// Pointer to the raw query bytes; presumably num_queries * code_size bytes -- confirm.
const uint8_t* query_data;
};
} // namespace dataset
SubQueryResult
BinarySearchBruteForce(const dataset::BinaryQueryDataset& query_dataset,
const uint8_t* binary_chunk,
int64_t chunk_size,
const faiss::BitsetView& bitset = nullptr);
const faiss::BitsetView& bitset);
SubQueryResult
FloatSearchBruteForce(const dataset::FloatQueryDataset& query_dataset,
const float* chunk_data,
int64_t chunk_size,
const faiss::BitsetView& bitset);
} // namespace milvus::query

View File

@ -0,0 +1,40 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "SearchOnIndex.h"
namespace milvus::query {
// Runs the float queries in query_dataset against a knowhere vector index and copies
// the distances and ids it reports into a freshly built SubQueryResult.
SubQueryResult
SearchOnIndex(const dataset::FloatQueryDataset& query_dataset,
              const knowhere::VecIndex& indexing,
              const knowhere::Config& search_conf,
              const faiss::BitsetView& bitset) {
    const auto nq = query_dataset.num_queries;
    const auto k = query_dataset.topk;

    // Wrap the raw query buffer in the dataset object that Query() takes.
    auto query_ds = knowhere::GenDataset(nq, query_dataset.dim, query_dataset.query_data);

    // NOTE: the VecIndex Query API is missing a const qualifier,
    // so const_cast is used here as a workaround.
    auto result = const_cast<knowhere::VecIndex&>(indexing).Query(query_ds, search_conf, bitset);

    const auto distances = result->Get<float*>(milvus::knowhere::meta::DISTANCE);
    const auto ids = result->Get<int64_t*>(milvus::knowhere::meta::IDS);

    SubQueryResult sub_qr(nq, k, query_dataset.metric_type);
    const auto total = nq * k;  // one slot per (query, rank) pair
    std::copy_n(distances, total, sub_qr.get_values());
    std::copy_n(ids, total, sub_qr.get_labels());
    return sub_qr;
}
} // namespace milvus::query

View File

@ -0,0 +1,27 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "query/SubQueryResult.h"
#include "query/helper.h"
#include "knowhere/index/vector_index/VecIndex.h"
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
#include "utils/Json.h"
namespace milvus::query {
// Searches a knowhere vector index with a batch of float queries.
//   query_dataset: metric, query count, topk, dimension and raw query data
//   indexing:      index to query (taken by const reference)
//   search_conf:   knowhere search configuration forwarded to the index
//   bitset:        bitset view forwarded to the index query
// Returns the reported distances and ids packed into a SubQueryResult.
SubQueryResult
SearchOnIndex(const dataset::FloatQueryDataset& query_dataset,
const knowhere::VecIndex& indexing,
const knowhere::Config& search_conf,
const faiss::BitsetView& bitset);
} // namespace milvus::query

View File

@ -0,0 +1,35 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "common/Types.h"
namespace milvus::query {
namespace dataset {
// Parameters describing one batch of float-vector queries.
// Initialised positionally by callers, so the field order matters.
struct FloatQueryDataset {
// Metric used to compare the vectors.
MetricType metric_type;
// Number of query vectors in the batch.
int64_t num_queries;
// Number of results requested per query.
int64_t topk;
// Number of float components per vector.
int64_t dim;
// Pointer to the query data; presumably num_queries * dim floats -- confirm.
const float* query_data;
};
// Parameters describing one batch of binary-vector queries.
// Initialised positionally by callers, so the field order matters.
struct BinaryQueryDataset {
// Metric used to compare the vectors.
MetricType metric_type;
// Number of query vectors in the batch.
int64_t num_queries;
// Number of results requested per query.
int64_t topk;
// Vector dimension; the brute-force search derives the per-vector byte size as dim / 8.
int64_t dim;
// Pointer to the raw query bytes; presumably num_queries * dim / 8 bytes -- confirm.
const uint8_t* query_data;
};
} // namespace dataset
} // namespace milvus::query

View File

@ -22,6 +22,7 @@
#include "utils/EasyAssert.h"
#include "utils/tools.h"
#include <boost/container/vector.hpp>
#include "common/Types.h"
namespace milvus::segcore {
@ -213,10 +214,15 @@ class ConcurrentVector : public ConcurrentVectorImpl<Type, true> {
class VectorTrait {};
// Trait tag for float vector fields.
class FloatVector : public VectorTrait {
public:
// Element type of the stored data.
using embedded_type = float;
// NOTE(review): named metric_type but holds a DataType value, not a MetricType -- confirm intent.
static constexpr auto metric_type = DataType::VECTOR_FLOAT;
};
// Trait tag for binary vector fields.
class BinaryVector : public VectorTrait {
public:
// Element type of the stored data.
using embedded_type = uint8_t;
// NOTE(review): named metric_type but holds a DataType value, not a MetricType -- confirm intent.
static constexpr auto metric_type = DataType::VECTOR_BINARY;
};
template <>

View File

@ -45,7 +45,7 @@ VecIndexingEntry::get_build_conf() const {
return knowhere::Config{{knowhere::meta::DIM, field_meta_.get_dim()},
{knowhere::IndexParams::nlist, 100},
{knowhere::IndexParams::nprobe, 4},
{knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{knowhere::Metric::TYPE, MetricTypeToName(field_meta_.get_metric_type())},
{knowhere::meta::DEVICEID, 0}};
}
@ -55,7 +55,7 @@ VecIndexingEntry::get_search_conf(int top_K) const {
{knowhere::meta::TOPK, top_K},
{knowhere::IndexParams::nlist, 100},
{knowhere::IndexParams::nprobe, 4},
{knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{knowhere::Metric::TYPE, MetricTypeToName(field_meta_.get_metric_type())},
{knowhere::meta::DEVICEID, 0}};
}

View File

@ -246,17 +246,16 @@ TEST(Indexing, BinaryBruteForce) {
schema->AddField("age", DataType::INT64);
auto dataset = DataGen(schema, N, 10);
auto bin_vec = dataset.get_col<uint8_t>(0);
auto line_sizeof = schema->operator[](0).get_sizeof();
auto query_data = 1024 * line_sizeof + bin_vec.data();
auto query_data = 1024 * dim / 8 + bin_vec.data();
query::dataset::BinaryQueryDataset query_dataset{
faiss::MetricType::METRIC_Jaccard, //
num_queries, //
topk, //
line_sizeof, //
dim, //
query_data //
};
auto sub_result = query::BinarySearchBruteForce(query_dataset, bin_vec.data(), N);
auto sub_result = query::BinarySearchBruteForce(query_dataset, bin_vec.data(), N, nullptr);
QueryResult qr;
qr.num_queries_ = num_queries;

View File

@ -312,6 +312,51 @@ TEST(Query, ExecTerm) {
// for(auto x: )
}
// Searching a segment into which no rows were ever inserted must fill every result
// slot with the "no hit" sentinels: offset -1 and distance float-max (L2 metric).
TEST(Query, ExecEmpty) {
    using namespace milvus::query;
    using namespace milvus::segcore;

    // Single source of truth for the vector dimension used by the schema and the queries.
    constexpr auto dim = 16;

    auto schema = std::make_shared<Schema>();
    schema->AddField("age", DataType::FLOAT);
    schema->AddField("fakevec", DataType::VECTOR_FLOAT, dim, MetricType::METRIC_L2);
    // The raw string below is kept flush-left on purpose: its exact bytes are the DSL input.
    std::string dsl = R"({
"bool": {
"must": [
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5
}
}
}
]
}
})";

    // No PreInsert/Insert here: the segment intentionally stays empty.
    auto segment = CreateSegment(schema);
    auto plan = CreatePlan(*schema, dsl);
    auto num_queries = 5;
    auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, 1024);
    auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());

    QueryResult qr;
    Timestamp time = 1000000;
    std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
    segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr);
    std::cout << QueryResultToJson(qr);

    for (auto offset : qr.internal_seg_offsets_) {
        ASSERT_EQ(offset, -1);
    }
    for (auto distance : qr.result_distances_) {
        ASSERT_EQ(distance, std::numeric_limits<float>::max());
    }
}
TEST(Query, ExecWithoutPredicate) {
using namespace milvus::query;
using namespace milvus::segcore;
@ -336,13 +381,13 @@ TEST(Query, ExecWithoutPredicate) {
]
}
})";
auto plan = CreatePlan(*schema, dsl);
int64_t N = 1000 * 1000;
auto dataset = DataGen(schema, N);
auto segment = CreateSegment(schema);
segment->PreInsert(N);
segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_);
auto plan = CreatePlan(*schema, dsl);
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
@ -397,6 +442,47 @@ TEST(Query, ExecWithoutPredicate) {
ASSERT_EQ(json.dump(2), ref.dump(2));
}
// Smoke test for inner-product search: inserts N generated vectors, then searches with
// the first num_queries inserted vectors as queries and prints the result as JSON.
// NOTE(review): the test only prints the result; it makes no assertion on the ranking.
TEST(Indexing, InnerProduct) {
    int64_t N = 100000;
    constexpr auto dim = 16;
    auto num_queries = 5;

    auto schema = std::make_shared<Schema>();
    // The raw string below is kept flush-left on purpose: its exact bytes are the DSL input.
    // The number of results per query is set by "topk" inside the DSL.
    std::string dsl = R"({
"bool": {
"must": [
{
"vector": {
"normalized": {
"metric_type": "IP",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5
}
}
}
]
}
})";
    // The field name starts with "normalized": DataGen keys on that prefix to scale each
    // generated vector to unit length.
    schema->AddField("normalized", DataType::VECTOR_FLOAT, dim, MetricType::METRIC_INNER_PRODUCT);
    auto dataset = DataGen(schema, N);
    auto segment = CreateSegment(schema);
    auto plan = CreatePlan(*schema, dsl);
    segment->PreInsert(N);
    segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_);

    // Reuse the head of the inserted column as the query vectors; the blob must be
    // described with the same dim the schema was declared with.
    auto col = dataset.get_col<float>(0);
    auto ph_group_raw = CreatePlaceholderGroupFromBlob(num_queries, dim, col.data());
    auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());

    std::vector<Timestamp> ts{static_cast<Timestamp>(N) * 2};
    const auto* ptr = ph_group.get();
    QueryResult qr;
    segment->Search(plan.get(), &ptr, ts.data(), 1, qr);
    std::cout << QueryResultToJson(qr).dump(2);
}
TEST(Query, FillSegment) {
namespace pb = milvus::proto;
pb::schema::CollectionSchema proto;

View File

@ -16,6 +16,9 @@
#include <cstring>
#include "segcore/SegmentBase.h"
#include "Constants.h"
#include <boost/algorithm/string/predicate.hpp>
using boost::algorithm::starts_with;
namespace milvus::segcore {
struct GeneratedData {
@ -92,11 +95,25 @@ DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42) {
switch (field.get_data_type()) {
case engine::DataType::VECTOR_FLOAT: {
auto dim = field.get_dim();
vector<float> data(dim * N);
for (auto& x : data) {
x = distr(er) + offset;
vector<float> final;
bool is_ip = starts_with(field.get_name(), "normalized");
for (int n = 0; n < N; ++n) {
vector<float> data(dim);
float sum = 0;
for (auto& x : data) {
x = distr(er) + offset;
sum += x * x;
}
if (is_ip) {
sum = sqrt(sum);
for (auto& x : data) {
x /= sum;
}
}
final.insert(final.end(), data.begin(), data.end());
}
insert_cols(data);
insert_cols(final);
break;
}
case engine::DataType::VECTOR_BINARY: {
@ -111,9 +128,9 @@ DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42) {
}
case engine::DataType::INT64: {
vector<int64_t> data(N);
int64_t index = 0;
// begin with counter
if (field.get_name().rfind("counter", 0) == 0) {
if (starts_with(field.get_name(), "counter")) {
int64_t index = 0;
for (auto& x : data) {
x = index++;
}