From 83379977b2c0563710bb8bc11bf203720f4ae404 Mon Sep 17 00:00:00 2001 From: FluorineDog Date: Tue, 24 Nov 2020 19:09:57 +0800 Subject: [PATCH] Enable structured index Signed-off-by: FluorineDog --- internal/core/src/cache/DataObj.h | 3 + .../index/structured_index/StructuredIndex.h | 2 +- .../structured_index_simple/StructuredIndex.h | 89 ++++++++ .../StructuredIndexFlat-inl.h | 153 ++++++++++++++ .../StructuredIndexFlat.h | 80 +++++++ .../StructuredIndexSort-inl.h | 199 ++++++++++++++++++ .../StructuredIndexSort.h | 80 +++++++ internal/core/src/query/Search.cpp | 33 ++- .../src/query/generated/ExecExprVisitor.h | 4 +- .../src/query/visitors/ExecExprVisitor.cpp | 62 ++++-- .../query/visitors/ExecPlanNodeVisitor.cpp | 3 +- .../src/query/visitors/ShowExprVisitor.cpp | 4 +- internal/core/src/segcore/IndexingEntry.cpp | 52 ++++- internal/core/src/segcore/IndexingEntry.h | 86 ++++++-- internal/core/src/segcore/InsertRecord.cpp | 36 +++- .../core/src/segcore/SegmentSmallIndex.cpp | 4 +- internal/core/unittest/CMakeLists.txt | 2 +- internal/core/unittest/test_bitmap.cpp | 26 +++ internal/core/unittest/test_c_api.cpp | 2 +- internal/core/unittest/test_indexing.cpp | 12 +- internal/core/unittest/test_query.cpp | 173 +++++++++++++-- internal/core/unittest/test_utils/DataGen.h | 22 +- 22 files changed, 1040 insertions(+), 87 deletions(-) create mode 100644 internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndex.h create mode 100644 internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexFlat-inl.h create mode 100644 internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexFlat.h create mode 100644 internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexSort-inl.h create mode 100644 internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexSort.h create mode 100644 internal/core/unittest/test_bitmap.cpp diff --git 
a/internal/core/src/cache/DataObj.h b/internal/core/src/cache/DataObj.h index 3aea7dea24..24b8638aab 100644 --- a/internal/core/src/cache/DataObj.h +++ b/internal/core/src/cache/DataObj.h @@ -20,6 +20,9 @@ class DataObj { public: virtual int64_t Size() = 0; + + public: + virtual ~DataObj() = default; }; using DataObjPtr = std::shared_ptr; diff --git a/internal/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndex.h b/internal/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndex.h index 6ad310ac43..c53a0decf6 100644 --- a/internal/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndex.h +++ b/internal/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndex.h @@ -20,7 +20,7 @@ namespace milvus { namespace knowhere { -enum OperatorType { LT = 0, LE = 1, GT = 3, GE = 4 }; +enum class OperatorType { LT = 0, LE = 1, GT = 3, GE = 4 }; static std::map s_map_operator_type = { {"LT", OperatorType::LT}, diff --git a/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndex.h b/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndex.h new file mode 100644 index 0000000000..1ff77c633d --- /dev/null +++ b/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndex.h @@ -0,0 +1,89 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License
+
+#pragma once
+
+#include
+#include
+#include
+#include "faiss/utils/ConcurrentBitset.h"
+#include "knowhere/index/Index.h"
+#include
+
+namespace milvus {
+namespace knowhere::scalar {
+
+// Comparison selector for the one-sided Range() query.
+// The enumerator values are not contiguous: GT is 3, there is no enumerator with value 2.
+enum class OperatorType { LT = 0, LE = 1, GT = 3, GE = 4 };
+
+// Lookup table from an operator's textual name ("LT", "LE", "GT", "GE") to its OperatorType.
+// It is declared `static` at namespace scope inside a header, so every translation unit that
+// includes this file gets its own private, mutable copy.
+static std::map s_map_operator_type = {
+    {"LT", OperatorType::LT},
+    {"LE", OperatorType::LE},
+    {"GT", OperatorType::GT},
+    {"GE", OperatorType::GE},
+};
+
+// One indexed element: a value (a_) paired with an index (idx_).
+// Every comparison operator below looks at a_ only; idx_ never takes part in ordering or equality.
+template
+struct IndexStructure {
+    // Value-less element: both members are initialised from 0.
+    IndexStructure() : a_(0), idx_(0) {
+    }
+    // Element carrying only a value; idx_ is left at 0.
+    explicit IndexStructure(const T a) : a_(a), idx_(0) {
+    }
+    // Element carrying a value and the index it is associated with.
+    IndexStructure(const T a, const size_t idx) : a_(a), idx_(idx) {
+    }
+    bool
+    operator<(const IndexStructure& b) const {
+        return a_ < b.a_;
+    }
+    bool
+    operator<=(const IndexStructure& b) const {
+        return a_ <= b.a_;
+    }
+    bool
+    operator>(const IndexStructure& b) const {
+        return a_ > b.a_;
+    }
+    bool
+    operator>=(const IndexStructure& b) const {
+        return a_ >= b.a_;
+    }
+    bool
+    operator==(const IndexStructure& b) const {
+        return a_ == b.a_;
+    }
+    T a_;         // the indexed value
+    size_t idx_;  // index associated with the value; ignored by all comparisons
+};
+// Result type of the query methods: a heap-allocated boost::dynamic_bitset held by unique_ptr.
+using TargetBitmap = boost::dynamic_bitset<>;
+using TargetBitmapPtr = std::unique_ptr;
+
+// Abstract interface of a scalar index over values of type T. It derives from Index and adds
+// only pure virtual members, so every query below must be supplied by a concrete subclass.
+template
+class StructuredIndex : public Index {
+ public:
+    // Builds the index from the n values starting at `values`.
+    virtual void
+    Build(const size_t n, const T* values) = 0;
+
+    // Membership query against the n values starting at `values`.
+    // NOTE(review): presumably marks the elements equal to any of the values - confirm in the subclasses.
+    virtual const TargetBitmapPtr
+    In(const size_t n, const T* values) = 0;
+
+    // Negated membership query against the n values starting at `values`.
+    // NOTE(review): presumably the complement of In() - confirm in the subclasses.
+    virtual const TargetBitmapPtr
+    NotIn(const size_t n, const T* values) = 0;
+
+    // One-sided range query: compares against `value` with the operator selected by `op`.
+    virtual const TargetBitmapPtr
+    Range(const T value, const OperatorType op) = 0;
+
+    // Two-sided range query; each bound carries its own inclusive/exclusive flag.
+    virtual const TargetBitmapPtr
+    Range(const T lower_bound_value, bool lb_inclusive, const T upper_bound_value, bool ub_inclusive) = 0;
+};
+
+// Shared-ownership handle for a StructuredIndex.
+template
+using StructuredIndexPtr = std::shared_ptr>;
+}  // namespace knowhere::scalar
+}  // namespace milvus
diff --git 
b/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexFlat-inl.h
new file mode 100644
index 0000000000..bee0bbdf98
--- /dev/null
+++ b/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexFlat-inl.h
@@ -0,0 +1,170 @@
+// Copyright (C) 2019-2020 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <utility>
+#include "knowhere/common/Log.h"
+#include "knowhere/index/structured_index_simple/StructuredIndexFlat.h"
+
+namespace milvus {
+namespace knowhere::scalar {
+
+// Creates an empty index; is_built_ stays false until Build() or build() runs.
+template <typename T>
+StructuredIndexFlat<T>::StructuredIndexFlat() : is_built_(false), data_() {
+}
+
+// Creates an index over values[0, n).
+template <typename T>
+StructuredIndexFlat<T>::StructuredIndexFlat(const size_t n, const T* values) : is_built_(false) {
+    // Qualified call, as in StructuredIndexSort: no virtual dispatch from inside a constructor.
+    StructuredIndexFlat<T>::Build(n, values);
+}
+
+template <typename T>
+StructuredIndexFlat<T>::~StructuredIndexFlat() {
+}
+
+// Replaces the index content with values[0, n). Element i is stored with idx_ == i, which is the
+// bit position the query methods report it under.
+template <typename T>
+void
+StructuredIndexFlat<T>::Build(const size_t n, const T* values) {
+    // Drop earlier content first: appending would restart idx_ at 0 and make the result bitmaps
+    // (sized by data_.size()) disagree with the stored idx_ values.
+    data_.clear();
+    data_.reserve(n);
+    for (size_t i = 0; i < n; ++i) {
+        data_.emplace_back(values[i], i);
+    }
+    is_built_ = true;
+}
+
+// Lazy-build hook called by every query when is_built_ is false. The flat index keeps its
+// elements unsorted, so there is nothing to reorder; this only validates and flips the flag.
+// Throws (KNOWHERE_THROW_MSG) when there is no data to build from, like StructuredIndexSort::build().
+template <typename T>
+void
+StructuredIndexFlat<T>::build() {
+    if (is_built_) {
+        return;
+    }
+    if (data_.empty()) {
+        KNOWHERE_THROW_MSG("StructuredIndexFlat cannot build null values!");
+    }
+    is_built_ = true;
+}
+
+// Returns a bitmap of data_.size() bits in which bit idx_ is set for every stored element whose
+// value equals one of values[0, n). Linear scan: O(n * data_.size()).
+template <typename T>
+const TargetBitmapPtr
+StructuredIndexFlat<T>::In(const size_t n, const T* values) {
+    if (!is_built_) {
+        build();
+    }
+    TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
+    for (size_t i = 0; i < n; ++i) {
+        // data_ holds IndexStructure objects, not pointers, so members are reached with '.'.
+        for (const auto& elem : data_) {
+            if (elem.a_ == values[i]) {
+                bitset->set(elem.idx_);
+            }
+        }
+    }
+    return bitset;
+}
+
+// Complement of In(): every bit starts set and bit idx_ is cleared for each stored element whose
+// value equals one of values[0, n).
+template <typename T>
+const TargetBitmapPtr
+StructuredIndexFlat<T>::NotIn(const size_t n, const T* values) {
+    if (!is_built_) {
+        build();
+    }
+    // dynamic_bitset(size, value) copies the low bits of an integer, so passing `true` would set
+    // bit 0 only. Build a zeroed bitmap and set() all bits explicitly instead.
+    TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
+    bitset->set();
+    for (size_t i = 0; i < n; ++i) {
+        for (const auto& elem : data_) {
+            if (elem.a_ == values[i]) {
+                bitset->reset(elem.idx_);
+            }
+        }
+    }
+    return bitset;
+}
+
+// One-sided range query: sets bit idx_ for every stored element e with `e.a_ <op> value`.
+// Throws (KNOWHERE_THROW_MSG) when op is not one of LT, LE, GT, GE.
+template <typename T>
+const TargetBitmapPtr
+StructuredIndexFlat<T>::Range(const T value, const OperatorType op) {
+    if (!is_built_) {
+        build();
+    }
+    TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
+    // Visit the elements themselves; the end iterator is never dereferenced.
+    for (const auto& elem : data_) {
+        bool hit = false;
+        switch (op) {
+            case OperatorType::LT:
+                hit = elem.a_ < value;
+                break;
+            case OperatorType::LE:
+                hit = elem.a_ <= value;
+                break;
+            case OperatorType::GT:
+                hit = elem.a_ > value;
+                break;
+            case OperatorType::GE:
+                hit = elem.a_ >= value;
+                break;
+            default:
+                KNOWHERE_THROW_MSG("Invalid OperatorType:" + std::to_string((int)op) + "!");
+        }
+        if (hit) {
+            bitset->set(elem.idx_);
+        }
+    }
+    return bitset;
+}
+
+// Two-sided range query: sets bit idx_ for every stored element between the two bounds, each bound
+// being inclusive or exclusive per its flag. Bounds given in the wrong order are swapped together
+// with their flags.
+template <typename T>
+const TargetBitmapPtr
+StructuredIndexFlat<T>::Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) {
+    if (!is_built_) {
+        build();
+    }
+    TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
+    if (lower_bound_value > upper_bound_value) {
+        std::swap(lower_bound_value, upper_bound_value);
+        std::swap(lb_inclusive, ub_inclusive);
+    }
+    for (const auto& elem : data_) {
+        const bool above_lower = lb_inclusive ? elem.a_ >= lower_bound_value : elem.a_ > lower_bound_value;
+        const bool below_upper = ub_inclusive ? elem.a_ <= upper_bound_value : elem.a_ < upper_bound_value;
+        if (above_lower && below_upper) {
+            bitset->set(elem.idx_);
+        }
+    }
+    return bitset;
+}
+
+}  // namespace knowhere::scalar
+}  // namespace milvus
diff --git a/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexFlat.h b/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexFlat.h
new file mode 100644
index 0000000000..37097b8faf
--- /dev/null
+++ b/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexFlat.h
@@ -0,0 +1,80 @@
+// Copyright (C) 2019-2020 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. 
See the License for the specific language governing permissions and limitations under the License
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include "knowhere/common/Exception.h"
+#include "knowhere/index/structured_index_simple/StructuredIndex.h"
+
+namespace milvus {
+namespace knowhere::scalar {
+
+// Concrete StructuredIndex that keeps its elements in a std::vector (data_) and answers
+// In / NotIn / Range queries over them. Only the trivial accessors are defined inline here;
+// the remaining members are expected to come from the "-inl.h" file included at the bottom.
+template
+class StructuredIndexFlat : public StructuredIndex {
+ public:
+    StructuredIndexFlat();
+    // Convenience constructor taking the n values starting at `values`.
+    StructuredIndexFlat(const size_t n, const T* values);
+    ~StructuredIndexFlat();
+
+    // NOTE(review): Serialize() and Load() are declared as overrides here, but no definition of
+    // either is visible in StructuredIndexFlat-inl.h as added by this patch - confirm they exist
+    // before this template is instantiated.
+    BinarySet
+    Serialize(const Config& config = Config()) override;
+
+    void
+    Load(const BinarySet& index_binary) override;
+
+    // Builds the index from the n values starting at `values`.
+    void
+    Build(const size_t n, const T* values) override;
+
+    // Parameterless build step; the query methods call it when IsBuilt() is false.
+    void
+    build();
+
+    const TargetBitmapPtr
+    In(const size_t n, const T* values) override;
+
+    const TargetBitmapPtr
+    NotIn(const size_t n, const T* values) override;
+
+    // One-sided range query selected by `op`.
+    const TargetBitmapPtr
+    Range(const T value, const OperatorType op) override;
+
+    // Two-sided range query; each bound has its own inclusive flag.
+    const TargetBitmapPtr
+    Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) override;
+
+    // Read-only view of the stored elements.
+    const std::vector>&
+    GetData() {
+        return data_;
+    }
+
+    // Returns the number of stored elements (an element count, not a byte size).
+    int64_t
+    Size() override {
+        return (int64_t)data_.size();
+    }
+
+    bool
+    IsBuilt() const {
+        return is_built_;
+    }
+
+ private:
+    bool is_built_;  // false until a build has completed
+    std::vector> data_;  // the indexed elements
+};
+
+// Shared-ownership handle for a StructuredIndexFlat.
+template
+using StructuredIndexFlatPtr = std::shared_ptr>;
+}  // namespace knowhere::scalar
+}  // namespace milvus
+
+#include "knowhere/index/structured_index_simple/StructuredIndexFlat-inl.h"
diff --git a/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexSort-inl.h b/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexSort-inl.h
new file mode 100644
index 0000000000..9f7647a40b
--- /dev/null
+++ b/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexSort-inl.h
@@ -0,0 +1,199 @@
+// Copyright (C) 2019-2020 Zilliz. All rights reserved. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License
+
+#include <algorithm>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <utility>
+#include "knowhere/common/Log.h"
+#include "knowhere/index/structured_index_simple/StructuredIndexSort.h"
+
+namespace milvus {
+namespace knowhere::scalar {
+
+// Creates an empty index; is_built_ stays false until Build(), build() or Load() succeeds.
+template <typename T>
+StructuredIndexSort<T>::StructuredIndexSort() : is_built_(false), data_() {
+}
+
+// Creates a sorted index over values[0, n).
+template <typename T>
+StructuredIndexSort<T>::StructuredIndexSort(const size_t n, const T* values) : is_built_(false) {
+    // Qualified call: no virtual dispatch from inside a constructor.
+    StructuredIndexSort<T>::Build(n, values);
+}
+
+template <typename T>
+StructuredIndexSort<T>::~StructuredIndexSort() {
+}
+
+// Replaces the index content with values[0, n) and sorts it by value. Element i is stored with
+// idx_ == i, which is the bit position the query methods report it under.
+// Throws (via build()) when n == 0.
+template <typename T>
+void
+StructuredIndexSort<T>::Build(const size_t n, const T* values) {
+    // Start from scratch. Appending to an already built index would restart idx_ at 0, and
+    // build() would return early on is_built_ and leave the new tail unsorted, which breaks the
+    // binary searches used by every query.
+    data_.clear();
+    is_built_ = false;
+    data_.reserve(n);
+    for (size_t i = 0; i < n; ++i) {
+        data_.emplace_back(values[i], i);
+    }
+    build();
+}
+
+// Sorts data_ by value (IndexStructure::operator< compares a_ only) and marks the index built.
+// No-op when already built; throws (KNOWHERE_THROW_MSG) when there is no data.
+template <typename T>
+void
+StructuredIndexSort<T>::build() {
+    if (is_built_)
+        return;
+    if (data_.size() == 0) {
+        // todo: throw an exception
+        KNOWHERE_THROW_MSG("StructuredIndexSort cannot build null values!");
+    }
+    std::sort(data_.begin(), data_.end());
+    is_built_ = true;
+}
+
+// Serializes the sorted elements into two blobs: "index_data" (a raw byte copy of data_) and
+// "index_length" (the element count as a raw size_t). `config` is not used.
+template <typename T>
+BinarySet
+StructuredIndexSort<T>::Serialize(const milvus::knowhere::Config& config) {
+    if (!is_built_) {
+        build();
+    }
+
+    auto index_data_size = data_.size() * sizeof(IndexStructure<T>);
+    std::shared_ptr<uint8_t[]> index_data(new uint8_t[index_data_size]);
+    memcpy(index_data.get(), data_.data(), index_data_size);
+
+    std::shared_ptr<uint8_t[]> index_length(new uint8_t[sizeof(size_t)]);
+    auto index_size = data_.size();
+    memcpy(index_length.get(), &index_size, sizeof(size_t));
+
+    BinarySet res_set;
+    res_set.Append("index_data", index_data, index_data_size);
+    res_set.Append("index_length", index_length, sizeof(size_t));
+    return res_set;
+}
+
+// Restores the index from the blobs written by Serialize(). Any failure is reported through
+// KNOHWERE_ERROR_MSG and leaves is_built_ unchanged; nothing is thrown to the caller.
+template <typename T>
+void
+StructuredIndexSort<T>::Load(const milvus::knowhere::BinarySet& index_binary) {
+    try {
+        auto index_length = index_binary.GetByName("index_length");
+        auto index_data = index_binary.GetByName("index_data");
+        if (index_length == nullptr || index_data == nullptr) {
+            KNOWHERE_THROW_MSG("StructuredIndexSort Load: missing binary!");
+        }
+
+        // The length blob must be exactly one size_t; copying `size` bytes blindly could write
+        // past index_size.
+        if (static_cast<size_t>(index_length->size) != sizeof(size_t)) {
+            KNOWHERE_THROW_MSG("StructuredIndexSort Load: bad index_length size!");
+        }
+        size_t index_size = 0;
+        memcpy(&index_size, index_length->data.get(), sizeof(size_t));
+
+        // The data blob must hold exactly index_size elements, otherwise the copy below would
+        // overrun data_ or leave part of it unset. Divide instead of multiplying so a corrupt
+        // count cannot overflow.
+        const auto data_bytes = static_cast<size_t>(index_data->size);
+        if (data_bytes % sizeof(IndexStructure<T>) != 0 || data_bytes / sizeof(IndexStructure<T>) != index_size) {
+            KNOWHERE_THROW_MSG("StructuredIndexSort Load: index_data size mismatch!");
+        }
+
+        data_.resize(index_size);
+        memcpy(data_.data(), index_data->data.get(), data_bytes);
+        is_built_ = true;
+    } catch (...) {
+        KNOHWERE_ERROR_MSG("StructuredIndexSort Load failed!");
+    }
+}
+
+// Returns a bitmap of data_.size() bits in which bit idx_ is set for every stored element whose
+// value equals one of values[0, n). Each value is located by binary search on the sorted data.
+template <typename T>
+const TargetBitmapPtr
+StructuredIndexSort<T>::In(const size_t n, const T* values) {
+    if (!is_built_) {
+        build();
+    }
+    TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
+    for (size_t i = 0; i < n; ++i) {
+        // [first, second) is the run of elements whose value equals values[i].
+        const auto range = std::equal_range(data_.begin(), data_.end(), IndexStructure<T>(values[i]));
+        for (auto lb = range.first; lb != range.second; ++lb) {
+            if (lb->a_ != *(values + i)) {
+                LOG_KNOWHERE_ERROR_ << "error happens in StructuredIndexSort::In, experted value is: "
+                                    << *(values + i) << ", but real value is: " << lb->a_;
+            }
+            bitset->set(lb->idx_);
+        }
+    }
+    return bitset;
+}
+
+// Complement of In(): every bit starts set and bit idx_ is cleared for each stored element whose
+// value equals one of values[0, n).
+template <typename T>
+const TargetBitmapPtr
+StructuredIndexSort<T>::NotIn(const size_t n, const T* values) {
+    if (!is_built_) {
+        build();
+    }
+    // dynamic_bitset(size, value) copies the low bits of an integer, so passing `true` would set
+    // bit 0 only. Build a zeroed bitmap and set() all bits explicitly instead.
+    TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
+    bitset->set();
+    for (size_t i = 0; i < n; ++i) {
+        const auto range = std::equal_range(data_.begin(), data_.end(), IndexStructure<T>(values[i]));
+        for (auto lb = range.first; lb != range.second; ++lb) {
+            if (lb->a_ != *(values + i)) {
+                LOG_KNOWHERE_ERROR_ << "error happens in StructuredIndexSort::NotIn, experted value is: "
+                                    << *(values + i) << ", but real value is: " << lb->a_;
+            }
+            bitset->reset(lb->idx_);
+        }
+    }
+    return bitset;
+}
+
+// One-sided range query: sets bit idx_ for every stored element e with `e.a_ <op> value`.
+// The matching elements form a prefix (LT, LE) or suffix (GT, GE) of the sorted data, found by
+// binary search. Throws (KNOWHERE_THROW_MSG) when op is not one of LT, LE, GT, GE.
+template <typename T>
+const TargetBitmapPtr
+StructuredIndexSort<T>::Range(const T value, const OperatorType op) {
+    if (!is_built_) {
+        build();
+    }
+    TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
+    auto lb = data_.begin();
+    auto ub = data_.end();
+    switch (op) {
+        case OperatorType::LT:
+            ub = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
+            break;
+        case OperatorType::LE:
+            ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
+            break;
+        case OperatorType::GT:
+            lb = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
+            break;
+        case OperatorType::GE:
+            lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
+            break;
+        default:
+            KNOWHERE_THROW_MSG("Invalid OperatorType:" + std::to_string((int)op) + "!");
+    }
+    for (; lb < ub; ++lb) {
+        bitset->set(lb->idx_);
+    }
+    return bitset;
+}
+
+// Two-sided range query: sets bit idx_ for every stored element between the two bounds, each bound
+// being inclusive or exclusive per its flag. Bounds given in the wrong order are swapped together
+// with their flags.
+template <typename T>
+const TargetBitmapPtr
+StructuredIndexSort<T>::Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) {
+    if (!is_built_) {
+        build();
+    }
+    TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
+    if (lower_bound_value > upper_bound_value) {
+        std::swap(lower_bound_value, upper_bound_value);
+        std::swap(lb_inclusive, ub_inclusive);
+    }
+    auto lb = data_.begin();
+    auto ub = data_.end();
+    if (lb_inclusive) {
+        lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(lower_bound_value));
+    } else {
+        lb = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(lower_bound_value));
+    }
+    if (ub_inclusive) {
+        ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(upper_bound_value));
+    } else {
+        ub = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(upper_bound_value));
+    }
+    // When the interval is empty (e.g. equal bounds with an exclusive side) lb can end up at or
+    // past ub; the `<` comparison then simply skips the loop.
+    for (; lb < ub; ++lb) {
+        bitset->set(lb->idx_);
+    }
+    return bitset;
+}
+
+}  // namespace knowhere::scalar
+}  // namespace milvus
diff --git 
a/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexSort.h b/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexSort.h new file mode 100644 index 0000000000..ee0c05a47a --- /dev/null +++ b/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexSort.h @@ -0,0 +1,80 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include +#include +#include "knowhere/common/Exception.h" +#include "knowhere/index/structured_index_simple/StructuredIndex.h" + +namespace milvus { +namespace knowhere::scalar { + +template +class StructuredIndexSort : public StructuredIndex { + public: + StructuredIndexSort(); + StructuredIndexSort(const size_t n, const T* values); + ~StructuredIndexSort(); + + BinarySet + Serialize(const Config& config = Config()) override; + + void + Load(const BinarySet& index_binary) override; + + void + Build(const size_t n, const T* values) override; + + void + build(); + + const TargetBitmapPtr + In(size_t n, const T* values) override; + + const TargetBitmapPtr + NotIn(size_t n, const T* values) override; + + const TargetBitmapPtr + Range(T value, OperatorType op) override; + + const TargetBitmapPtr + Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) override; + + const std::vector>& + GetData() { + return data_; + } + + int64_t + 
Size() override { + return (int64_t)data_.size(); + } + + bool + IsBuilt() const { + return is_built_; + } + + private: + bool is_built_; + std::vector> data_; +}; + +template +using StructuredIndexSortPtr = std::shared_ptr>; +} // namespace knowhere::scalar +} // namespace milvus + +#include "knowhere/index/structured_index_simple/StructuredIndexSort-inl.h" diff --git a/internal/core/src/query/Search.cpp b/internal/core/src/query/Search.cpp index 1a38d1dea3..148a927a46 100644 --- a/internal/core/src/query/Search.cpp +++ b/internal/core/src/query/Search.cpp @@ -15,7 +15,9 @@ create_bitmap_view(std::optional bitmaps_opt, int64_t chunk auto& bitmaps = *bitmaps_opt.value(); auto& src_vec = bitmaps.at(chunk_id); auto dst = std::make_shared(src_vec.size()); - boost::to_block_range(src_vec, dst->mutable_data()); + auto iter = reinterpret_cast(dst->mutable_data()); + + boost::to_block_range(src_vec, iter); return dst; } @@ -28,9 +30,9 @@ QueryBruteForceImpl(const SegmentSmallIndex& segment, Timestamp timestamp, std::optional bitmaps_opt, QueryResult& results) { - auto& record = segment.get_insert_record(); auto& schema = segment.get_schema(); auto& indexing_record = segment.get_indexing_record(); + auto& record = segment.get_insert_record(); // step 1: binary search to find the barrier of the snapshot auto ins_barrier = get_barrier(record, timestamp); auto max_chunk = upper_div(ins_barrier, DefaultElementPerChunk); @@ -48,7 +50,6 @@ QueryBruteForceImpl(const SegmentSmallIndex& segment, Assert(vecfield_offset_opt.has_value()); auto vecfield_offset = vecfield_offset_opt.value(); auto& field = schema[vecfield_offset]; - auto vec_ptr = record.get_vec_entity(vecfield_offset); Assert(field.get_data_type() == DataType::VECTOR_FLOAT); auto dim = field.get_dim(); @@ -61,31 +62,45 @@ QueryBruteForceImpl(const SegmentSmallIndex& segment, std::vector final_dis(total_count, std::numeric_limits::max()); auto max_indexed_id = indexing_record.get_finished_ack(); - const auto& 
indexing_entry = indexing_record.get_indexing(vecfield_offset); + const auto& indexing_entry = indexing_record.get_vec_entry(vecfield_offset); auto search_conf = indexing_entry.get_search_conf(topK); for (int chunk_id = 0; chunk_id < max_indexed_id; ++chunk_id) { - auto indexing = indexing_entry.get_indexing(chunk_id); - auto src_data = vec_ptr->get_chunk(chunk_id).data(); - auto dataset = knowhere::GenDataset(num_queries, dim, src_data); + auto indexing = indexing_entry.get_vec_indexing(chunk_id); + auto dataset = knowhere::GenDataset(num_queries, dim, query_data); auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id); auto ans = indexing->Query(dataset, search_conf, bitmap_view); auto dis = ans->Get(milvus::knowhere::meta::DISTANCE); auto uids = ans->Get(milvus::knowhere::meta::IDS); + // convert chunk uid to segment uid + for (int64_t i = 0; i < total_count; ++i) { + auto& x = uids[i]; + if (x != -1) { + x += chunk_id * DefaultElementPerChunk; + } + } merge_into(num_queries, topK, final_dis.data(), final_uids.data(), dis, uids); } + auto vec_ptr = record.get_vec_entity(vecfield_offset); // step 4: brute force search where small indexing is unavailable for (int chunk_id = max_indexed_id; chunk_id < max_chunk; ++chunk_id) { std::vector buf_uids(total_count, -1); std::vector buf_dis(total_count, std::numeric_limits::max()); faiss::float_maxheap_array_t buf = {(size_t)num_queries, (size_t)topK, buf_uids.data(), buf_dis.data()}; - auto src_data = vec_ptr->get_chunk(chunk_id).data(); + auto& chunk = vec_ptr->get_chunk(chunk_id); auto nsize = chunk_id != max_chunk - 1 ? 
DefaultElementPerChunk : ins_barrier - chunk_id * DefaultElementPerChunk; auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id); - faiss::knn_L2sqr(query_data, src_data, dim, num_queries, nsize, &buf, bitmap_view); + faiss::knn_L2sqr(query_data, chunk.data(), dim, num_queries, nsize, &buf, bitmap_view); + Assert(buf_uids.size() == total_count); + // convert chunk uid to segment uid + for (auto& x : buf_uids) { + if (x != -1) { + x += chunk_id * DefaultElementPerChunk; + } + } merge_into(num_queries, topK, final_dis.data(), final_uids.data(), buf_dis.data(), buf_uids.data()); } diff --git a/internal/core/src/query/generated/ExecExprVisitor.h b/internal/core/src/query/generated/ExecExprVisitor.h index e5327f7e71..86cd6b7ee0 100644 --- a/internal/core/src/query/generated/ExecExprVisitor.h +++ b/internal/core/src/query/generated/ExecExprVisitor.h @@ -37,9 +37,9 @@ class ExecExprVisitor : ExprVisitor { } public: - template + template auto - ExecRangeVisitorImpl(RangeExprImpl& expr_scp, Func func) -> RetType; + ExecRangeVisitorImpl(RangeExprImpl& expr, IndexFunc func, ElementFunc element_func) -> RetType; template auto diff --git a/internal/core/src/query/visitors/ExecExprVisitor.cpp b/internal/core/src/query/visitors/ExecExprVisitor.cpp index 889485393e..e82d0723af 100644 --- a/internal/core/src/query/visitors/ExecExprVisitor.cpp +++ b/internal/core/src/query/visitors/ExecExprVisitor.cpp @@ -55,9 +55,10 @@ ExecExprVisitor::visit(TermExpr& expr) { PanicInfo("unimplemented"); } -template +template auto -ExecExprVisitor::ExecRangeVisitorImpl(RangeExprImpl& expr, Func func) -> RetType { +ExecExprVisitor::ExecRangeVisitorImpl(RangeExprImpl& expr, IndexFunc index_func, ElementFunc element_func) + -> RetType { auto& records = segment_.get_insert_record(); auto data_type = expr.data_type_; auto& schema = segment_.get_schema(); @@ -67,15 +68,28 @@ ExecExprVisitor::ExecRangeVisitorImpl(RangeExprImpl& expr, Func func) -> RetT auto& field_meta = schema[field_offset]; auto 
vec_ptr = records.get_scalar_entity(field_offset); auto& vec = *vec_ptr; + auto& indexing_record = segment_.get_indexing_record(); + const segcore::ScalarIndexingEntry& entry = indexing_record.get_scalar_entry(field_offset); + RetType results(vec.chunk_size()); - for (auto chunk_id = 0; chunk_id < vec.chunk_size(); ++chunk_id) { + auto indexing_barrier = indexing_record.get_finished_ack(); + for (auto chunk_id = 0; chunk_id < indexing_barrier; ++chunk_id) { + auto& result = results[chunk_id]; + auto indexing = entry.get_indexing(chunk_id); + auto data = index_func(indexing); + result = ~std::move(*data); + Assert(result.size() == segcore::DefaultElementPerChunk); + } + + for (auto chunk_id = indexing_barrier; chunk_id < vec.chunk_size(); ++chunk_id) { auto& result = results[chunk_id]; result.resize(segcore::DefaultElementPerChunk); auto chunk = vec.get_chunk(chunk_id); const T* data = chunk.data(); for (int index = 0; index < segcore::DefaultElementPerChunk; ++index) { - result[index] = func(data[index]); + result[index] = element_func(data[index]); } + Assert(result.size() == segcore::DefaultElementPerChunk); } return results; } @@ -89,6 +103,8 @@ ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType { auto conditions = expr.conditions_; std::sort(conditions.begin(), conditions.end()); using OpType = RangeExpr::OpType; + using Index = knowhere::scalar::StructuredIndex; + using Operator = knowhere::scalar::OperatorType; if (conditions.size() == 1) { auto cond = conditions[0]; // auto [op, val] = cond; // strange bug on capture @@ -96,27 +112,39 @@ ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType { auto val = std::get<1>(cond); switch (op) { case OpType::Equal: { - return ExecRangeVisitorImpl(expr, [val](T x) { return !(x == val); }); + auto index_func = [val](Index* index) { return index->In(1, &val); }; + return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x == val); }); } case OpType::NotEqual: { - 
return ExecRangeVisitorImpl(expr, [val](T x) { return !(x != val); }); + auto index_func = [val](Index* index) { + // Note: index->NotIn() is buggy, investigating + // this is a workaround + auto res = index->In(1, &val); + *res = ~std::move(*res); + return res; + }; + return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x != val); }); } case OpType::GreaterEqual: { - return ExecRangeVisitorImpl(expr, [val](T x) { return !(x >= val); }); + auto index_func = [val](Index* index) { return index->Range(val, Operator::GE); }; + return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x >= val); }); } case OpType::GreaterThan: { - return ExecRangeVisitorImpl(expr, [val](T x) { return !(x > val); }); + auto index_func = [val](Index* index) { return index->Range(val, Operator::GT); }; + return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x > val); }); } case OpType::LessEqual: { - return ExecRangeVisitorImpl(expr, [val](T x) { return !(x <= val); }); + auto index_func = [val](Index* index) { return index->Range(val, Operator::LE); }; + return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x <= val); }); } case OpType::LessThan: { - return ExecRangeVisitorImpl(expr, [val](T x) { return !(x < val); }); + auto index_func = [val](Index* index) { return index->Range(val, Operator::LT); }; + return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x < val); }); } default: { PanicInfo("unsupported range node"); @@ -131,13 +159,17 @@ ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType { auto ops = std::make_tuple(op1, op2); if (false) { } else if (ops == std::make_tuple(OpType::GreaterThan, OpType::LessThan)) { - return ExecRangeVisitorImpl(expr, [val1, val2](T x) { return !(val1 < x && x < val2); }); + auto index_func = [val1, val2](Index* index) { return index->Range(val1, false, val2, false); }; + return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 < x && x < 
val2); }); } else if (ops == std::make_tuple(OpType::GreaterThan, OpType::LessEqual)) { - return ExecRangeVisitorImpl(expr, [val1, val2](T x) { return !(val1 < x && x <= val2); }); + auto index_func = [val1, val2](Index* index) { return index->Range(val1, false, val2, true); }; + return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 < x && x <= val2); }); } else if (ops == std::make_tuple(OpType::GreaterEqual, OpType::LessThan)) { - return ExecRangeVisitorImpl(expr, [val1, val2](T x) { return !(val1 <= x && x < val2); }); + auto index_func = [val1, val2](Index* index) { return index->Range(val1, true, val2, false); }; + return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 <= x && x < val2); }); } else if (ops == std::make_tuple(OpType::GreaterEqual, OpType::LessEqual)) { - return ExecRangeVisitorImpl(expr, [val1, val2](T x) { return !(val1 <= x && x <= val2); }); + auto index_func = [val1, val2](Index* index) { return index->Range(val1, true, val2, true); }; + return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 <= x && x <= val2); }); } else { PanicInfo("unsupported range node"); } @@ -157,7 +189,7 @@ ExecExprVisitor::visit(RangeExpr& expr) { // ret = ExecRangeVisitorDispatcher(expr); // break; //} - case DataType::BOOL: + // case DataType::BOOL: case DataType::INT8: { ret = ExecRangeVisitorDispatcher(expr); break; diff --git a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp index f35f7e9367..e973e8532e 100644 --- a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp @@ -55,8 +55,9 @@ ExecPlanNodeVisitor::visit(FloatVectorANNS& node) { auto bitmap = ExecExprVisitor(*segment).call_child(*node.predicate_.value()); auto ptr = &bitmap; QueryBruteForceImpl(*segment, node.query_info_, src_data, num_queries, timestamp_, ptr, ret); + } else { + 
QueryBruteForceImpl(*segment, node.query_info_, src_data, num_queries, timestamp_, std::nullopt, ret); } - QueryBruteForceImpl(*segment, node.query_info_, src_data, num_queries, timestamp_, std::nullopt, ret); ret_ = ret; } diff --git a/internal/core/src/query/visitors/ShowExprVisitor.cpp b/internal/core/src/query/visitors/ShowExprVisitor.cpp index 0c837d5933..0a5de4eb57 100644 --- a/internal/core/src/query/visitors/ShowExprVisitor.cpp +++ b/internal/core/src/query/visitors/ShowExprVisitor.cpp @@ -115,8 +115,8 @@ ShowExprVisitor::visit(TermExpr& expr) { return TermExtract(expr); case DataType::FLOAT: return TermExtract(expr); - case DataType::BOOL: - return TermExtract(expr); + // case DataType::BOOL: + // return TermExtract(expr); default: PanicInfo("unsupported type"); } diff --git a/internal/core/src/segcore/IndexingEntry.cpp b/internal/core/src/segcore/IndexingEntry.cpp index 44bc2a43d2..dc7b14cc16 100644 --- a/internal/core/src/segcore/IndexingEntry.cpp +++ b/internal/core/src/segcore/IndexingEntry.cpp @@ -5,7 +5,7 @@ namespace milvus::segcore { void -IndexingEntry::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) { +VecIndexingEntry::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) { // TODO assert(field_meta_.get_data_type() == DataType::VECTOR_FLOAT); @@ -30,7 +30,7 @@ IndexingEntry::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBas } knowhere::Config -IndexingEntry::get_build_conf() const { +VecIndexingEntry::get_build_conf() const { return knowhere::Config{{knowhere::meta::DIM, field_meta_.get_dim()}, {knowhere::IndexParams::nlist, 100}, {knowhere::IndexParams::nprobe, 4}, @@ -39,7 +39,7 @@ IndexingEntry::get_build_conf() const { } knowhere::Config -IndexingEntry::get_search_conf(int top_K) const { +VecIndexingEntry::get_search_conf(int top_K) const { return knowhere::Config{{knowhere::meta::DIM, field_meta_.get_dim()}, {knowhere::meta::TOPK, top_K}, {knowhere::IndexParams::nlist, 
100}, @@ -65,10 +65,54 @@ IndexingRecord::UpdateResourceAck(int64_t chunk_ack, const InsertRecord& record) // std::thread([this, old_ack, chunk_ack, &record] { for (auto& [field_offset, entry] : entries_) { auto vec_base = record.entity_vec_[field_offset].get(); - entry.BuildIndexRange(old_ack, chunk_ack, vec_base); + entry->BuildIndexRange(old_ack, chunk_ack, vec_base); } finished_ack_.AddSegment(old_ack, chunk_ack); // }).detach(); } +template +void +ScalarIndexingEntry::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) { + auto dim = field_meta_.get_dim(); + + auto source = dynamic_cast*>(vec_base); + Assert(source); + auto chunk_size = source->chunk_size(); + assert(ack_end <= chunk_size); + data_.grow_to_at_least(ack_end); + for (int chunk_id = ack_beg; chunk_id < ack_end; chunk_id++) { + const auto& chunk = source->get_chunk(chunk_id); + // build index for chunk + // TODO + Assert(chunk.size() == DefaultElementPerChunk); + auto indexing = std::make_unique>(); + indexing->Build(DefaultElementPerChunk, chunk.data()); + data_[chunk_id] = std::move(indexing); + } +} + +std::unique_ptr +CreateIndex(const FieldMeta& field_meta) { + if (field_meta.is_vector()) { + return std::make_unique(field_meta); + } + switch (field_meta.get_data_type()) { + case DataType::INT8: + return std::make_unique>(field_meta); + case DataType::INT16: + return std::make_unique>(field_meta); + case DataType::INT32: + return std::make_unique>(field_meta); + case DataType::INT64: + return std::make_unique>(field_meta); + case DataType::FLOAT: + return std::make_unique>(field_meta); + case DataType::DOUBLE: + return std::make_unique>(field_meta); + default: + PanicInfo("unsupported"); + } +} + } // namespace milvus::segcore diff --git a/internal/core/src/segcore/IndexingEntry.h b/internal/core/src/segcore/IndexingEntry.h index 6cc61df36d..89d5458e96 100644 --- a/internal/core/src/segcore/IndexingEntry.h +++ b/internal/core/src/segcore/IndexingEntry.h @@ -5,6 +5,7 
@@ #include #include "InsertRecord.h" #include +#include namespace milvus::segcore { @@ -14,35 +15,68 @@ class IndexingEntry { public: explicit IndexingEntry(const FieldMeta& field_meta) : field_meta_(field_meta) { } - - // concurrent - knowhere::VecIndex* - get_indexing(int64_t chunk_id) const { - return data_.at(chunk_id).get(); - } + IndexingEntry(const IndexingEntry&) = delete; + IndexingEntry& + operator=(const IndexingEntry&) = delete; // Do this in parallel - void - BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base); + virtual void + BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) = 0; const FieldMeta& get_field_meta() { return field_meta_; } + protected: + // additional info + const FieldMeta& field_meta_; +}; +template +class ScalarIndexingEntry : public IndexingEntry { + public: + using IndexingEntry::IndexingEntry; + + void + BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) override; + + // concurrent + knowhere::scalar::StructuredIndex* + get_indexing(int64_t chunk_id) const { + Assert(!field_meta_.is_vector()); + return data_.at(chunk_id).get(); + } + + private: + tbb::concurrent_vector>> data_; +}; + +class VecIndexingEntry : public IndexingEntry { + public: + using IndexingEntry::IndexingEntry; + + void + BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) override; + + // concurrent + knowhere::VecIndex* + get_vec_indexing(int64_t chunk_id) const { + Assert(field_meta_.is_vector()); + return data_.at(chunk_id).get(); + } + knowhere::Config get_build_conf() const; knowhere::Config get_search_conf(int top_k) const; - private: - // additional info - const FieldMeta& field_meta_; - private: tbb::concurrent_vector> data_; }; +std::unique_ptr +CreateIndex(const FieldMeta& field_meta); + class IndexingRecord { public: explicit IndexingRecord(const Schema& schema) : schema_(schema) { @@ -53,9 +87,7 @@ class IndexingRecord { Initialize() { int 
offset = 0; for (auto& field : schema_) { - if (field.is_vector()) { - entries_.try_emplace(offset, field); - } + entries_.try_emplace(offset, CreateIndex(field)); ++offset; } assert(offset == schema_.size()); @@ -72,9 +104,25 @@ class IndexingRecord { } const IndexingEntry& - get_indexing(int i) const { - assert(entries_.count(i)); - return entries_.at(i); + get_entry(int field_offset) const { + assert(entries_.count(field_offset)); + return *entries_.at(field_offset); + } + + const VecIndexingEntry& + get_vec_entry(int field_offset) const { + auto& entry = get_entry(field_offset); + auto ptr = dynamic_cast(&entry); + AssertInfo(ptr, "invalid indexing"); + return *ptr; + } + template + auto + get_scalar_entry(int field_offset) const -> const ScalarIndexingEntry& { + auto& entry = get_entry(field_offset); + auto ptr = dynamic_cast*>(&entry); + AssertInfo(ptr, "invalid indexing"); + return *ptr; } private: @@ -89,7 +137,7 @@ class IndexingRecord { private: // field_offset => indexing - std::map entries_; + std::map> entries_; }; } // namespace milvus::segcore \ No newline at end of file diff --git a/internal/core/src/segcore/InsertRecord.cpp b/internal/core/src/segcore/InsertRecord.cpp index ecbe7cedd3..fb65b61624 100644 --- a/internal/core/src/segcore/InsertRecord.cpp +++ b/internal/core/src/segcore/InsertRecord.cpp @@ -7,9 +7,39 @@ InsertRecord::InsertRecord(const Schema& schema) : uids_(1), timestamps_(1) { if (field.is_vector()) { Assert(field.get_data_type() == DataType::VECTOR_FLOAT); entity_vec_.emplace_back(std::make_shared>(field.get_dim())); - } else { - Assert(field.get_data_type() == DataType::INT32); - entity_vec_.emplace_back(std::make_shared>()); + continue; + } + switch (field.get_data_type()) { + case DataType::INT8: { + entity_vec_.emplace_back(std::make_shared>()); + break; + } + case DataType::INT16: { + entity_vec_.emplace_back(std::make_shared>()); + break; + } + case DataType::INT32: { + entity_vec_.emplace_back(std::make_shared>()); + break; 
+ } + + case DataType::INT64: { + entity_vec_.emplace_back(std::make_shared>()); + break; + } + + case DataType::FLOAT: { + entity_vec_.emplace_back(std::make_shared>()); + break; + } + + case DataType::DOUBLE: { + entity_vec_.emplace_back(std::make_shared>()); + break; + } + default: { + PanicInfo("unsupported"); + } } } } diff --git a/internal/core/src/segcore/SegmentSmallIndex.cpp b/internal/core/src/segcore/SegmentSmallIndex.cpp index 3af53741d6..951795987a 100644 --- a/internal/core/src/segcore/SegmentSmallIndex.cpp +++ b/internal/core/src/segcore/SegmentSmallIndex.cpp @@ -168,7 +168,7 @@ SegmentSmallIndex::Insert(int64_t reserved_begin, } record_.ack_responder_.AddSegment(reserved_begin, reserved_begin + size); - // indexing_record_.UpdateResourceAck(record_.ack_responder_.GetAck() / DefaultElementPerChunk); + indexing_record_.UpdateResourceAck(record_.ack_responder_.GetAck() / DefaultElementPerChunk, record_); return Status::OK(); } @@ -280,7 +280,7 @@ SegmentSmallIndex::BuildIndex(IndexMetaPtr remote_index_meta) { if (record_.ack_responder_.GetAck() < 1024 * 4) { return Status(SERVER_BUILD_INDEX_ERROR, "too few elements"); } - // AssertInfo(false, "unimplemented"); + AssertInfo(false, "unimplemented"); return Status::OK(); #if 0 index_meta_ = remote_index_meta; diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt index 8ae28c95c5..4210b43f54 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -10,7 +10,7 @@ set(MILVUS_TEST_FILES test_indexing.cpp test_query.cpp test_expr.cpp - ) + test_bitmap.cpp) add_executable(all_tests ${MILVUS_TEST_FILES} ) diff --git a/internal/core/unittest/test_bitmap.cpp b/internal/core/unittest/test_bitmap.cpp new file mode 100644 index 0000000000..1cb67e2b28 --- /dev/null +++ b/internal/core/unittest/test_bitmap.cpp @@ -0,0 +1,26 @@ +#include +#include "test_utils/DataGen.h" +#include "knowhere/index/structured_index_simple/StructuredIndexSort.h" + 
+TEST(Bitmap, Naive) { + using namespace milvus; + using namespace milvus::segcore; + using namespace milvus::query; + auto schema = std::make_shared(); + schema->AddField("height", DataType::FLOAT); + int N = 10000; + auto raw_data = DataGen(schema, N); + auto vec = raw_data.get_col(0); + auto sort_index = std::make_shared>(); + sort_index->Build(N, vec.data()); + { + auto res = sort_index->Range(0, knowhere::scalar::OperatorType::LT); + double count = res->count(); + ASSERT_NEAR(count / N, 0.5, 0.01); + } + { + auto res = sort_index->Range(-1, false, 1, true); + double count = res->count(); + ASSERT_NEAR(count / N, 0.682, 0.01); + } +} \ No newline at end of file diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 4b28f4fa03..dd0f3abd2a 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -192,7 +192,7 @@ TEST(CApiTest, BuildIndexTest) { // TODO: add index ptr Close(segment); - BuildIndex(collection, segment); + // BuildIndex(collection, segment); const char* dsl_string = R"( { diff --git a/internal/core/unittest/test_indexing.cpp b/internal/core/unittest/test_indexing.cpp index ed01cc5a8c..a55dcc09b3 100644 --- a/internal/core/unittest/test_indexing.cpp +++ b/internal/core/unittest/test_indexing.cpp @@ -76,15 +76,13 @@ TEST(Indexing, SmartBruteForce) { auto query_data = raw; - vector final_uids(total_count); + vector final_uids(total_count, -1); vector final_dis(total_count, std::numeric_limits::max()); for (int beg = 0; beg < N; beg += DefaultElementPerChunk) { vector buf_uids(total_count, -1); vector buf_dis(total_count, std::numeric_limits::max()); - faiss::float_maxheap_array_t buf = {queries, TOPK, buf_uids.data(), buf_dis.data()}; - auto end = beg + DefaultElementPerChunk; if (end > N) { end = N; @@ -93,12 +91,10 @@ TEST(Indexing, SmartBruteForce) { auto src_data = raw + beg * DIM; faiss::knn_L2sqr(query_data, src_data, DIM, queries, nsize, &buf, nullptr); - if (beg 
== 0) { - final_uids = buf_uids; - final_dis = buf_dis; - } else { - merge_into(queries, TOPK, final_dis.data(), final_uids.data(), buf_dis.data(), buf_uids.data()); + for (auto& x : buf_uids) { + x = uids[x + beg]; } + merge_into(queries, TOPK, final_dis.data(), final_uids.data(), buf_dis.data(), buf_uids.data()); } for (int qn = 0; qn < queries; ++qn) { diff --git a/internal/core/unittest/test_query.cpp b/internal/core/unittest/test_query.cpp index d3d97def8c..ff7a3ab317 100644 --- a/internal/core/unittest/test_query.cpp +++ b/internal/core/unittest/test_query.cpp @@ -8,6 +8,7 @@ #include "query/generated/ShowPlanNodeVisitor.h" #include "query/generated/ExecPlanNodeVisitor.h" #include "query/PlanImpl.h" +#include "segcore/SegmentSmallIndex.h" using namespace milvus; using namespace milvus::query; @@ -148,29 +149,165 @@ TEST(Query, ParsePlaceholderGroup) { auto schema = std::make_shared(); schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); auto plan = CreatePlan(*schema, dsl_string); - int num_queries = 10; + int64_t num_queries = 100000; int dim = 16; - std::default_random_engine e; - std::normal_distribution dis(0, 1); - ser::PlaceholderGroup raw_group; - auto value = raw_group.add_placeholders(); - value->set_tag("$0"); - value->set_type(ser::PlaceholderType::VECTOR_FLOAT); - for (int i = 0; i < num_queries; ++i) { - std::vector vec; - for (int d = 0; d < dim; ++d) { - vec.push_back(dis(e)); - } - // std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float)); - value->add_values(vec.data(), vec.size() * sizeof(float)); - } + auto raw_group = CreatePlaceholderGroup(num_queries, dim); auto blob = raw_group.SerializeAsString(); - // ser::PlaceholderGroup new_group; - // new_group.ParseFromString() auto placeholder = ParsePlaceholderGroup(plan.get(), blob); } -TEST(Query, Exec) { +TEST(Query, ExecWithPredicate) { using namespace milvus::query; using namespace milvus::segcore; + auto schema = std::make_shared(); + 
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); + schema->AddField("age", DataType::FLOAT); + std::string dsl = R"({ + "bool": { + "must": [ + { + "range": { + "age": { + "GE": -1, + "LT": 1 + } + } + }, + { + "vector": { + "fakevec": { + "metric_type": "L2", + "params": { + "nprobe": 10 + }, + "query": "$0", + "topk": 5 + } + } + } + ] + } + })"; + int64_t N = 1000 * 1000; + auto dataset = DataGen(schema, N); + auto segment = std::make_unique(schema); + segment->PreInsert(N); + segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_); + + auto plan = CreatePlan(*schema, dsl); + auto num_queries = 5; + auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024); + auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + QueryResult qr; + Timestamp time = 1000000; + std::vector ph_group_arr = {ph_group.get()}; + segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr); + std::vector> results; + int topk = 5; + for (int q = 0; q < num_queries; ++q) { + std::vector result; + for (int k = 0; k < topk; ++k) { + int index = q * topk + k; + result.emplace_back(std::to_string(qr.result_ids_[index]) + "->" + + std::to_string(qr.result_distances_[index])); + } + results.emplace_back(std::move(result)); + } + + auto ref = Json::parse(R"([ + [ + [ + "980486->3.149221", + "318367->3.661235", + "302798->4.553688", + "321424->4.757450", + "565529->5.083780" + ], + [ + "233390->7.931535", + "238958->8.109344", + "230645->8.439169", + "901939->8.658772", + "380328->8.731251" + ], + [ + "897246->3.749835", + "750683->3.897577", + "857598->4.230977", + "299009->4.379639", + "440010->4.454046" + ], + [ + "840855->4.782170", + "709627->5.063170", + "72322->5.166143", + "107142->5.180207", + "948403->5.247065" + ], + [ + "810401->3.926393", + "46575->4.054171", + "201740->4.274491", + "669040->4.399628", + "231500->4.831223" + ] + ] +])"); + + Json json{results}; + ASSERT_EQ(json, ref); +} + +TEST(Query, 
ExecWihtoutPredicate) { + using namespace milvus::query; + using namespace milvus::segcore; + auto schema = std::make_shared(); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); + schema->AddField("age", DataType::FLOAT); + std::string dsl = R"({ + "bool": { + "must": [ + { + "vector": { + "fakevec": { + "metric_type": "L2", + "params": { + "nprobe": 10 + }, + "query": "$0", + "topk": 5 + } + } + } + ] + } + })"; + int64_t N = 1000 * 1000; + auto dataset = DataGen(schema, N); + auto segment = std::make_unique(schema); + segment->PreInsert(N); + segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_); + + auto plan = CreatePlan(*schema, dsl); + auto num_queries = 5; + auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024); + auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + QueryResult qr; + Timestamp time = 1000000; + std::vector ph_group_arr = {ph_group.get()}; + segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr); + std::vector> results; + int topk = 5; + for (int q = 0; q < num_queries; ++q) { + std::vector result; + for (int k = 0; k < topk; ++k) { + int index = q * topk + k; + result.emplace_back(std::to_string(qr.result_ids_[index]) + "->" + + std::to_string(qr.result_distances_[index])); + } + results.emplace_back(std::move(result)); + } + + Json json{results}; + std::cout << json.dump(2); } diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index d5d7a3f6c8..787d4d38e2 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -111,4 +111,24 @@ DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42) { return std::move(res); } -} // namespace milvus::segcore \ No newline at end of file +inline auto +CreatePlaceholderGroup(int64_t num_queries, int dim, int64_t seed = 42) { + namespace ser = milvus::proto::service; + ser::PlaceholderGroup raw_group; + auto 
value = raw_group.add_placeholders(); + value->set_tag("$0"); + value->set_type(ser::PlaceholderType::VECTOR_FLOAT); + std::normal_distribution dis(0, 1); + std::default_random_engine e(seed); + for (int i = 0; i < num_queries; ++i) { + std::vector vec; + for (int d = 0; d < dim; ++d) { + vec.push_back(dis(e)); + } + // std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float)); + value->add_values(vec.data(), vec.size() * sizeof(float)); + } + return raw_group; +} + +} // namespace milvus::segcore