mirror of https://github.com/milvus-io/milvus.git
enhance: [2.5][cp] speed up search iterator stage 1 (#38678)
pr: https://github.com/milvus-io/milvus/pull/37947 issue: https://github.com/milvus-io/milvus/issues/37548 Signed-off-by: Patrick Weizhi Xu <weizhi.xu@zilliz.com> (cherry picked from commit 9016c4adcd765c0766b01e7e5d465c915e176a6f)pull/38837/head
parent
d48a33d76b
commit
ef400227ad
2
go.mod
2
go.mod
|
@ -64,6 +64,7 @@ require (
|
|||
github.com/cenkalti/backoff/v4 v4.2.1
|
||||
github.com/cockroachdb/redact v1.1.3
|
||||
github.com/goccy/go-json v0.10.3
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/greatroar/blobloom v0.0.0-00010101000000-000000000000
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7
|
||||
github.com/jolestar/go-commons-pool/v2 v2.1.2
|
||||
|
@ -144,7 +145,6 @@ require (
|
|||
github.com/golang/snappy v0.0.4 // indirect
|
||||
github.com/google/flatbuffers v2.0.8+incompatible // indirect
|
||||
github.com/google/s2a-go v0.1.7 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.12.5 // indirect
|
||||
github.com/gorilla/websocket v1.4.2 // indirect
|
||||
|
|
|
@ -24,6 +24,12 @@
|
|||
|
||||
namespace milvus {
|
||||
|
||||
// Parameters of a V2 search-iterator session. Populated from the query plan
// proto (ProtoParser uses designated initializers, so member order matters)
// and consumed by CachedSearchIterator.
struct SearchIteratorV2Info {
    // Opaque token identifying one iterator session across NextBatch calls.
    std::string token = "";
    // Number of results to return per batch; 0 means unset/invalid.
    uint32_t batch_size = 0;
    // Distance bound of the last result already delivered to the client;
    // used to skip results already returned in earlier batches.
    std::optional<float> last_bound = std::nullopt;
};
|
||||
|
||||
struct SearchInfo {
|
||||
int64_t topk_{0};
|
||||
int64_t group_size_{1};
|
||||
|
@ -36,6 +42,7 @@ struct SearchInfo {
|
|||
tracer::TraceContext trace_ctx_;
|
||||
bool materialized_view_involved = false;
|
||||
bool iterative_filter_execution = false;
|
||||
std::optional<SearchIteratorV2Info> iterator_v2_info_ = std::nullopt;
|
||||
};
|
||||
|
||||
using SearchInfoPtr = std::shared_ptr<SearchInfo>;
|
||||
|
|
|
@ -362,4 +362,38 @@ ReadDataFromFD(int fd, void* buf, size_t size, size_t chunk_size) {
|
|||
}
|
||||
}
|
||||
|
||||
// Copies Knowhere range-search parameters (radius / range_filter /
// page_retain_order) from `search_info.search_params_` into `search_config`.
// Returns true iff RADIUS is present, i.e. iff the caller should issue a
// range search; returns false (leaving `search_config` untouched except for
// nothing) when this is a plain top-k search.
bool
CheckAndUpdateKnowhereRangeSearchParam(const SearchInfo& search_info,
                                       const int64_t topk,
                                       const MetricType& metric_type,
                                       knowhere::Json& search_config) {
    const auto radius =
        index::GetValueFromConfig<float>(search_info.search_params_, RADIUS);
    if (!radius.has_value()) {
        // No radius => not a range-search request.
        return false;
    }

    search_config[RADIUS] = radius.value();
    // `range_search_k` is only used as one of the conditions for iterator
    // early termination; it is not guaranteed to return exactly
    // `range_search_k` results (may be more or less). A value of -1 would
    // return all results in the range; here it is capped with the caller's
    // topk.
    search_config[knowhere::meta::RANGE_SEARCH_K] = topk;

    const auto range_filter =
        GetValueFromConfig<float>(search_info.search_params_, RANGE_FILTER);
    if (range_filter.has_value()) {
        search_config[RANGE_FILTER] = range_filter.value();
        // Sanity-checks that radius/range_filter are consistently ordered
        // for the given metric type.
        CheckRangeSearchParam(
            search_config[RADIUS], search_config[RANGE_FILTER], metric_type);
    }

    const auto page_retain_order =
        GetValueFromConfig<bool>(search_info.search_params_, PAGE_RETAIN_ORDER);
    if (page_retain_order.has_value()) {
        search_config[knowhere::meta::RETAIN_ITERATOR_ORDER] =
            page_retain_order.value();
    }
    return true;
}
|
||||
|
||||
} // namespace milvus::index
|
||||
|
|
|
@ -30,6 +30,8 @@
|
|||
|
||||
#include "common/Types.h"
|
||||
#include "common/FieldData.h"
|
||||
#include "common/QueryInfo.h"
|
||||
#include "common/RangeSearchHelper.h"
|
||||
#include "index/IndexInfo.h"
|
||||
#include "storage/Types.h"
|
||||
|
||||
|
@ -147,4 +149,10 @@ AssembleIndexDatas(std::map<std::string, FieldDataChannelPtr>& index_datas,
|
|||
void
|
||||
ReadDataFromFD(int fd, void* buf, size_t size, size_t chunk_size = 0x7ffff000);
|
||||
|
||||
bool
|
||||
CheckAndUpdateKnowhereRangeSearchParam(const SearchInfo& search_info,
|
||||
const int64_t topk,
|
||||
const MetricType& metric_type,
|
||||
knowhere::Json& search_config);
|
||||
|
||||
} // namespace milvus::index
|
||||
|
|
|
@ -266,32 +266,9 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
|
|||
search_config[DISK_ANN_PREFIX_PATH] = local_index_path_prefix;
|
||||
|
||||
auto final = [&] {
|
||||
auto radius =
|
||||
GetValueFromConfig<float>(search_info.search_params_, RADIUS);
|
||||
if (radius.has_value()) {
|
||||
search_config[RADIUS] = radius.value();
|
||||
// `range_search_k` is only used as one of the conditions for iterator early termination.
|
||||
// not guaranteed to return exactly `range_search_k` results, which may be more or less.
|
||||
// set it to -1 will return all results in the range.
|
||||
search_config[knowhere::meta::RANGE_SEARCH_K] = topk;
|
||||
auto range_filter = GetValueFromConfig<float>(
|
||||
search_info.search_params_, RANGE_FILTER);
|
||||
if (range_filter.has_value()) {
|
||||
search_config[RANGE_FILTER] = range_filter.value();
|
||||
CheckRangeSearchParam(search_config[RADIUS],
|
||||
search_config[RANGE_FILTER],
|
||||
GetMetricType());
|
||||
}
|
||||
|
||||
auto page_retain_order = GetValueFromConfig<bool>(
|
||||
search_info.search_params_, PAGE_RETAIN_ORDER);
|
||||
if (page_retain_order.has_value()) {
|
||||
search_config[knowhere::meta::RETAIN_ITERATOR_ORDER] =
|
||||
page_retain_order.value();
|
||||
}
|
||||
|
||||
if (CheckAndUpdateKnowhereRangeSearchParam(
|
||||
search_info, topk, GetMetricType(), search_config)) {
|
||||
auto res = index_.RangeSearch(dataset, search_config, bitset);
|
||||
|
||||
if (!res.has_value()) {
|
||||
PanicInfo(ErrorCode::UnexpectedError,
|
||||
fmt::format("failed to range search: {}: {}",
|
||||
|
|
|
@ -380,16 +380,8 @@ VectorMemIndex<T>::Query(const DatasetPtr dataset,
|
|||
// TODO :: check dim of search data
|
||||
auto final = [&] {
|
||||
auto index_type = GetIndexType();
|
||||
if (CheckKeyInConfig(search_conf, RADIUS)) {
|
||||
if (CheckKeyInConfig(search_conf, RANGE_FILTER)) {
|
||||
CheckRangeSearchParam(search_conf[RADIUS],
|
||||
search_conf[RANGE_FILTER],
|
||||
GetMetricType());
|
||||
}
|
||||
// `range_search_k` is only used as one of the conditions for iterator early termination.
|
||||
// not guaranteed to return exactly `range_search_k` results, which may be more or less.
|
||||
// set it to -1 will return all results in the range.
|
||||
search_conf[knowhere::meta::RANGE_SEARCH_K] = topk;
|
||||
if (CheckAndUpdateKnowhereRangeSearchParam(
|
||||
search_info, topk, GetMetricType(), search_conf)) {
|
||||
milvus::tracer::AddEvent("start_knowhere_index_range_search");
|
||||
auto res = index_.RangeSearch(dataset, search_conf, bitset);
|
||||
milvus::tracer::AddEvent("finish_knowhere_index_range_search");
|
||||
|
|
|
@ -0,0 +1,362 @@
|
|||
// Copyright (C) 2019-2024 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "query/CachedSearchIterator.h"

#include <algorithm>
#include <cmath>
#include <limits>

#include "query/SearchBruteForce.h"
|
||||
|
||||
namespace milvus::query {
|
||||
|
||||
// Sealed segment with a vector index: obtain per-query Knowhere iterators
// from the index, after forwarding any range-search parameters into the
// search config (batch_size_ acts as range_search_k for early termination).
CachedSearchIterator::CachedSearchIterator(
    const milvus::index::VectorIndex& index,
    const knowhere::DataSetPtr& query_ds,
    const SearchInfo& search_info,
    const BitsetView& bitset) {
    if (query_ds == nullptr) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Query dataset is nullptr, cannot initialize iterator");
    }
    nq_ = query_ds->GetRows();
    Init(search_info);

    auto search_json = index.PrepareSearchParams(search_info);
    index::CheckAndUpdateKnowhereRangeSearchParam(
        search_info, batch_size_, index.GetMetricType(), search_json);

    auto maybe_iterators =
        index.VectorIterators(query_ds, search_json, bitset);
    if (!maybe_iterators.has_value()) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Failed to create iterators from index");
    }
    iterators_ = std::move(maybe_iterators.value());
}
|
||||
|
||||
// Sealed segment without an index: brute-force iterators over one
// contiguous raw dataset (single chunk).
CachedSearchIterator::CachedSearchIterator(
    const dataset::SearchDataset& query_ds,
    const dataset::RawDataset& raw_ds,
    const SearchInfo& search_info,
    const std::map<std::string, std::string>& index_info,
    const BitsetView& bitset,
    const milvus::DataType& data_type) {
    nq_ = query_ds.num_queries;
    Init(search_info);

    auto maybe_iterators = GetBruteForceSearchIterators(
        query_ds, raw_ds, search_info, index_info, bitset, data_type);
    if (!maybe_iterators.has_value()) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Failed to create iterators from index");
    }
    iterators_ = std::move(maybe_iterators.value());
}
|
||||
|
||||
void
|
||||
CachedSearchIterator::InitializeChunkedIterators(
|
||||
const dataset::SearchDataset& query_ds,
|
||||
const SearchInfo& search_info,
|
||||
const std::map<std::string, std::string>& index_info,
|
||||
const BitsetView& bitset,
|
||||
const milvus::DataType& data_type,
|
||||
const GetChunkDataFunc& get_chunk_data) {
|
||||
int64_t offset = 0;
|
||||
chunked_heaps_.resize(nq_);
|
||||
for (int64_t chunk_id = 0; chunk_id < num_chunks_; ++chunk_id) {
|
||||
auto [chunk_data, chunk_size] = get_chunk_data(chunk_id);
|
||||
auto sub_data = query::dataset::RawDataset{
|
||||
offset, query_ds.dim, chunk_size, chunk_data};
|
||||
|
||||
auto expected_iterators = GetBruteForceSearchIterators(
|
||||
query_ds, sub_data, search_info, index_info, bitset, data_type);
|
||||
if (expected_iterators.has_value()) {
|
||||
auto& chunk_iterators = expected_iterators.value();
|
||||
iterators_.insert(iterators_.end(),
|
||||
std::make_move_iterator(chunk_iterators.begin()),
|
||||
std::make_move_iterator(chunk_iterators.end()));
|
||||
} else {
|
||||
PanicInfo(ErrorCode::UnexpectedError,
|
||||
"Failed to create iterators from index");
|
||||
}
|
||||
offset += chunk_size;
|
||||
}
|
||||
}
|
||||
|
||||
// Growing segment: brute-force iterators over the chunked ConcurrentVector
// storage, one iterator per (chunk, query).
CachedSearchIterator::CachedSearchIterator(
    const dataset::SearchDataset& query_ds,
    const segcore::VectorBase* vec_data,
    const int64_t row_count,
    const SearchInfo& search_info,
    const std::map<std::string, std::string>& index_info,
    const BitsetView& bitset,
    const milvus::DataType& data_type) {
    if (vec_data == nullptr) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Vector data is nullptr, cannot initialize iterator");
    }

    if (row_count <= 0) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Number of rows is 0, cannot initialize iterator");
    }

    const int64_t vec_size_per_chunk = vec_data->get_size_per_chunk();
    // Ceiling division: the last chunk may be partially filled.
    num_chunks_ = upper_div(row_count, vec_size_per_chunk);
    nq_ = query_ds.num_queries;
    Init(search_info);

    iterators_.reserve(nq_ * num_chunks_);
    InitializeChunkedIterators(
        query_ds,
        search_info,
        index_info,
        bitset,
        data_type,
        // Chunk accessor: returns {data pointer, #rows}; the min() clamps
        // the final chunk to the remaining rows.
        [&vec_data, vec_size_per_chunk, row_count](
            int64_t chunk_id) -> std::pair<const void*, int64_t> {
            const auto chunk_data = vec_data->get_chunk_data(chunk_id);
            int64_t chunk_size = std::min(
                vec_size_per_chunk, row_count - chunk_id * vec_size_per_chunk);
            return {chunk_data, chunk_size};
        });
}
|
||||
|
||||
// Sealed segment with chunked column data (no index): brute-force iterators
// over each chunk of `column`, one iterator per (chunk, query).
CachedSearchIterator::CachedSearchIterator(
    const std::shared_ptr<ChunkedColumnBase>& column,
    const dataset::SearchDataset& query_ds,
    const SearchInfo& search_info,
    const std::map<std::string, std::string>& index_info,
    const BitsetView& bitset,
    const milvus::DataType& data_type) {
    if (column == nullptr) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Column is nullptr, cannot initialize iterator");
    }

    num_chunks_ = column->num_chunks();
    nq_ = query_ds.num_queries;
    Init(search_info);

    iterators_.reserve(nq_ * num_chunks_);
    InitializeChunkedIterators(
        query_ds,
        search_info,
        index_info,
        bitset,
        data_type,
        // Chunk accessor backed by the chunked column.
        [&column](int64_t chunk_id) {
            const char* chunk_data = column->Data(chunk_id);
            int64_t chunk_size = column->chunk_row_nums(chunk_id);
            return std::make_pair(static_cast<const void*>(chunk_data),
                                  chunk_size);
        });
}
|
||||
|
||||
void
|
||||
CachedSearchIterator::NextBatch(const SearchInfo& search_info,
|
||||
SearchResult& search_result) {
|
||||
if (iterators_.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (iterators_.size() != nq_ * num_chunks_) {
|
||||
PanicInfo(ErrorCode::UnexpectedError,
|
||||
"Iterator size mismatch, expect %d, but got %d",
|
||||
nq_ * num_chunks_,
|
||||
iterators_.size());
|
||||
}
|
||||
|
||||
ValidateSearchInfo(search_info);
|
||||
|
||||
search_result.total_nq_ = nq_;
|
||||
search_result.unity_topK_ = batch_size_;
|
||||
search_result.seg_offsets_.resize(nq_ * batch_size_);
|
||||
search_result.distances_.resize(nq_ * batch_size_);
|
||||
|
||||
for (size_t query_idx = 0; query_idx < nq_; ++query_idx) {
|
||||
auto rst = GetBatchedNextResults(query_idx, search_info);
|
||||
WriteSingleQuerySearchResult(
|
||||
search_result, query_idx, rst, search_info.round_decimal_);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
CachedSearchIterator::ValidateSearchInfo(const SearchInfo& search_info) {
|
||||
if (!search_info.iterator_v2_info_.has_value()) {
|
||||
PanicInfo(ErrorCode::UnexpectedError,
|
||||
"Iterator v2 SearchInfo is not set");
|
||||
}
|
||||
|
||||
auto iterator_v2_info = search_info.iterator_v2_info_.value();
|
||||
if (iterator_v2_info.batch_size != batch_size_) {
|
||||
PanicInfo(ErrorCode::UnexpectedError,
|
||||
"Batch size mismatch, expect %d, but got %d",
|
||||
batch_size_,
|
||||
iterator_v2_info.batch_size);
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<CachedSearchIterator::DisIdPair>
|
||||
CachedSearchIterator::GetNextValidResult(
|
||||
const size_t iterator_idx,
|
||||
const std::optional<float>& last_bound,
|
||||
const std::optional<float>& radius,
|
||||
const std::optional<float>& range_filter) {
|
||||
auto& iterator = iterators_[iterator_idx];
|
||||
while (iterator->HasNext()) {
|
||||
auto result = ConvertIteratorResult(iterator->Next());
|
||||
if (IsValid(result, last_bound, radius, range_filter)) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// TODO: Optimize this method
// K-way merges the per-chunk iterator streams of one query via a min-heap of
// (iterator index, best pending result). The heap persists across NextBatch
// calls in chunked_heaps_[query_idx] and is lazily (re-)seeded whenever it
// is empty. Appends at most batch_size_ results to `rst`.
void
CachedSearchIterator::MergeChunksResults(
    size_t query_idx,
    const std::optional<float>& last_bound,
    const std::optional<float>& radius,
    const std::optional<float>& range_filter,
    std::vector<DisIdPair>& rst) {
    auto& heap = chunked_heaps_[query_idx];

    if (heap.empty()) {
        // Seed the heap with the first valid result of each chunk's iterator.
        for (size_t chunk_id = 0; chunk_id < num_chunks_; ++chunk_id) {
            // Iterators are laid out chunk-major: index = query + chunk * nq.
            const size_t iterator_idx = query_idx + chunk_id * nq_;
            if (auto next_result = GetNextValidResult(
                    iterator_idx, last_bound, radius, range_filter);
                next_result.has_value()) {
                heap.emplace(iterator_idx, next_result.value());
            }
        }
    }

    while (!heap.empty() && rst.size() < batch_size_) {
        const auto [iterator_idx, cur_rst] = heap.top();
        heap.pop();

        // last_bound may change between NextBatch calls, discard any invalid results
        // NOTE(review): on this discard path the source iterator is NOT
        // re-polled, so its remaining results are dropped until the heap
        // empties and is re-seeded — confirm this is intended.
        if (!IsValid(cur_rst, last_bound, radius, range_filter)) {
            continue;
        }
        rst.emplace_back(cur_rst);

        // Refill the heap from the iterator whose result was just consumed.
        if (auto next_result = GetNextValidResult(
                iterator_idx, last_bound, radius, range_filter);
            next_result.has_value()) {
            heap.emplace(iterator_idx, next_result.value());
        }
    }
}
|
||||
|
||||
std::vector<CachedSearchIterator::DisIdPair>
|
||||
CachedSearchIterator::GetBatchedNextResults(size_t query_idx,
|
||||
const SearchInfo& search_info) {
|
||||
auto last_bound = ConvertIncomingDistance(
|
||||
search_info.iterator_v2_info_.value().last_bound);
|
||||
auto radius = ConvertIncomingDistance(
|
||||
index::GetValueFromConfig<float>(search_info.search_params_, RADIUS));
|
||||
auto range_filter =
|
||||
ConvertIncomingDistance(index::GetValueFromConfig<float>(
|
||||
search_info.search_params_, RANGE_FILTER));
|
||||
|
||||
std::vector<DisIdPair> rst;
|
||||
rst.reserve(batch_size_);
|
||||
|
||||
if (num_chunks_ == 1) {
|
||||
auto& iterator = iterators_[query_idx];
|
||||
while (iterator->HasNext() && rst.size() < batch_size_) {
|
||||
auto result = ConvertIteratorResult(iterator->Next());
|
||||
if (IsValid(result, last_bound, radius, range_filter)) {
|
||||
rst.emplace_back(result);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MergeChunksResults(query_idx, last_bound, radius, range_filter, rst);
|
||||
}
|
||||
std::sort(rst.begin(), rst.end());
|
||||
if (sign_ == -1) {
|
||||
std::for_each(rst.begin(), rst.end(), [this](DisIdPair& x) {
|
||||
x.first = x.first * sign_;
|
||||
});
|
||||
}
|
||||
while (rst.size() < batch_size_) {
|
||||
rst.emplace_back(1.0f / 0.0f, -1);
|
||||
}
|
||||
return rst;
|
||||
}
|
||||
|
||||
void
|
||||
CachedSearchIterator::WriteSingleQuerySearchResult(
|
||||
SearchResult& search_result,
|
||||
const size_t idx,
|
||||
std::vector<DisIdPair>& rst,
|
||||
const int64_t round_decimal) {
|
||||
const float multiplier = pow(10.0, round_decimal);
|
||||
|
||||
std::transform(rst.begin(),
|
||||
rst.end(),
|
||||
search_result.distances_.begin() + idx * batch_size_,
|
||||
[multiplier, round_decimal](DisIdPair& x) {
|
||||
if (round_decimal != -1) {
|
||||
x.first =
|
||||
std::round(x.first * multiplier) / multiplier;
|
||||
}
|
||||
return x.first;
|
||||
});
|
||||
|
||||
std::transform(rst.begin(),
|
||||
rst.end(),
|
||||
search_result.seg_offsets_.begin() + idx * batch_size_,
|
||||
[](const DisIdPair& x) { return x.second; });
|
||||
}
|
||||
|
||||
void
|
||||
CachedSearchIterator::Init(const SearchInfo& search_info) {
|
||||
if (!search_info.iterator_v2_info_.has_value()) {
|
||||
PanicInfo(ErrorCode::UnexpectedError,
|
||||
"Iterator v2 info is not set, cannot initialize iterator");
|
||||
}
|
||||
|
||||
auto iterator_v2_info = search_info.iterator_v2_info_.value();
|
||||
if (iterator_v2_info.batch_size == 0) {
|
||||
PanicInfo(ErrorCode::UnexpectedError,
|
||||
"Batch size is 0, cannot initialize iterator");
|
||||
}
|
||||
batch_size_ = iterator_v2_info.batch_size;
|
||||
|
||||
if (search_info.metric_type_.empty()) {
|
||||
PanicInfo(ErrorCode::UnexpectedError,
|
||||
"Metric type is empty, cannot initialize iterator");
|
||||
}
|
||||
if (PositivelyRelated(search_info.metric_type_)) {
|
||||
sign_ = -1;
|
||||
} else {
|
||||
sign_ = 1;
|
||||
}
|
||||
|
||||
if (nq_ == 0) {
|
||||
PanicInfo(ErrorCode::UnexpectedError,
|
||||
"Number of queries is 0, cannot initialize iterator");
|
||||
}
|
||||
|
||||
// disable multi-query for now
|
||||
if (nq_ > 1) {
|
||||
PanicInfo(
|
||||
ErrorCode::UnexpectedError,
|
||||
"Number of queries is greater than 1, cannot initialize iterator");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace milvus::query
|
|
@ -0,0 +1,182 @@
|
|||
// Copyright (C) 2019-2024 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <utility>
|
||||
#include "common/BitsetView.h"
|
||||
#include "common/QueryInfo.h"
|
||||
#include "common/QueryResult.h"
|
||||
#include "query/helper.h"
|
||||
#include "segcore/ConcurrentVector.h"
|
||||
#include "index/VectorIndex.h"
|
||||
|
||||
namespace milvus::query {
|
||||
|
||||
// This class is used to cache the search results from Knowhere
// search iterators and filter the results based on the last_bound,
// radius and range_filter.
// It provides a number of constructors to support different scenarios,
// including growing/sealed, chunked/non-chunked.
//
// It does not care about TopK in search_info
// The topk in SearchResult will be set to the batch_size for compatibility
//
// TODO: introduce the pool of results in the near future
// TODO: replace VectorIterator class
class CachedSearchIterator {
 public:
    // For sealed segment with vector index
    CachedSearchIterator(const milvus::index::VectorIndex& index,
                         const knowhere::DataSetPtr& dataset,
                         const SearchInfo& search_info,
                         const BitsetView& bitset);

    // For sealed segment, BF
    CachedSearchIterator(const dataset::SearchDataset& dataset,
                         const dataset::RawDataset& raw_ds,
                         const SearchInfo& search_info,
                         const std::map<std::string, std::string>& index_info,
                         const BitsetView& bitset,
                         const milvus::DataType& data_type);

    // For growing segment with chunked data, BF
    CachedSearchIterator(const dataset::SearchDataset& dataset,
                         const segcore::VectorBase* vec_data,
                         const int64_t row_count,
                         const SearchInfo& search_info,
                         const std::map<std::string, std::string>& index_info,
                         const BitsetView& bitset,
                         const milvus::DataType& data_type);

    // For sealed segment with chunked data, BF
    CachedSearchIterator(const std::shared_ptr<ChunkedColumnBase>& column,
                         const dataset::SearchDataset& dataset,
                         const SearchInfo& search_info,
                         const std::map<std::string, std::string>& index_info,
                         const BitsetView& bitset,
                         const milvus::DataType& data_type);

    // This method fetches the next batch of search results based on the provided search information
    // and updates the search_result object with the new batch of results.
    void
    NextBatch(const SearchInfo& search_info, SearchResult& search_result);

    // Disable copy and move
    CachedSearchIterator(const CachedSearchIterator&) = delete;
    CachedSearchIterator&
    operator=(const CachedSearchIterator&) = delete;
    CachedSearchIterator(CachedSearchIterator&&) = delete;
    CachedSearchIterator&
    operator=(CachedSearchIterator&&) = delete;

 private:
    // (distance, segment offset); distances are stored sign-adjusted so that
    // smaller is always better internally.
    using DisIdPair = std::pair<float, int64_t>;
    using IterIdx = size_t;
    using IterIdDisIdPair = std::pair<IterIdx, DisIdPair>;
    // Maps a chunk id to {chunk data pointer, chunk row count}.
    using GetChunkDataFunc =
        std::function<std::pair<const void*, int64_t>(int64_t)>;

    // Results returned per NextBatch call; fixed at Init time.
    int64_t batch_size_ = 0;
    // One Knowhere iterator per (chunk, query), chunk-major layout.
    std::vector<knowhere::IndexNode::IteratorPtr> iterators_;
    // -1 for similarity metrics (larger is better), 1 otherwise.
    int8_t sign_ = 1;
    size_t num_chunks_ = 1;
    size_t nq_ = 0;

    // Min-heap ordering on sign-adjusted distance; ties broken by offset.
    struct IterIdDisIdPairComparator {
        bool
        operator()(const IterIdDisIdPair& lhs, const IterIdDisIdPair& rhs) {
            if (lhs.second.first == rhs.second.first) {
                return lhs.second.second > rhs.second.second;
            }
            return lhs.second.first > rhs.second.first;
        }
    };
    // Per-query merge heaps for the multi-chunk case; persist across
    // NextBatch calls.
    std::vector<std::priority_queue<IterIdDisIdPair,
                                    std::vector<IterIdDisIdPair>,
                                    IterIdDisIdPairComparator>>
        chunked_heaps_;

    // A result is valid when it lies strictly beyond last_bound and, if a
    // radius is set, inside the [range_filter, radius) window — all in the
    // sign-adjusted distance space.
    inline bool
    IsValid(const DisIdPair& result,
            const std::optional<float>& last_bound,
            const std::optional<float>& radius,
            const std::optional<float>& range_filter) {
        const float dist = result.first;
        const bool is_valid =
            !last_bound.has_value() || dist > last_bound.value();

        if (!radius.has_value()) {
            return is_valid;
        }

        if (!range_filter.has_value()) {
            return is_valid && dist < radius.value();
        }

        return is_valid && dist < radius.value() &&
               dist >= range_filter.value();
    }

    // Converts a Knowhere (offset, distance) pair into a sign-adjusted
    // DisIdPair.
    inline DisIdPair
    ConvertIteratorResult(const std::pair<int64_t, float>& iter_rst) {
        DisIdPair rst;
        rst.first = iter_rst.second * sign_;
        rst.second = iter_rst.first;
        return rst;
    }

    // Applies the sign adjustment to an externally supplied distance bound.
    inline std::optional<float>
    ConvertIncomingDistance(std::optional<float> dist) {
        if (dist.has_value()) {
            dist = dist.value() * sign_;
        }
        return dist;
    }

    std::optional<DisIdPair>
    GetNextValidResult(size_t iterator_idx,
                       const std::optional<float>& last_bound,
                       const std::optional<float>& radius,
                       const std::optional<float>& range_filter);

    void
    MergeChunksResults(size_t query_idx,
                       const std::optional<float>& last_bound,
                       const std::optional<float>& radius,
                       const std::optional<float>& range_filter,
                       std::vector<DisIdPair>& rst);

    void
    ValidateSearchInfo(const SearchInfo& search_info);

    std::vector<DisIdPair>
    GetBatchedNextResults(size_t query_idx, const SearchInfo& search_info);

    void
    WriteSingleQuerySearchResult(SearchResult& search_result,
                                 const size_t idx,
                                 std::vector<DisIdPair>& rst,
                                 const int64_t round_decimal);

    void
    Init(const SearchInfo& search_info);

    void
    InitializeChunkedIterators(
        const dataset::SearchDataset& dataset,
        const SearchInfo& search_info,
        const std::map<std::string, std::string>& index_info,
        const BitsetView& bitset,
        const milvus::DataType& data_type,
        const GetChunkDataFunc& get_chunk_data);
};
|
||||
} // namespace milvus::query
|
|
@ -93,6 +93,20 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
|
|||
search_info.strict_group_size_ =
|
||||
query_info_proto.strict_group_size();
|
||||
}
|
||||
|
||||
if (query_info_proto.has_search_iterator_v2_info()) {
|
||||
auto& iterator_v2_info_proto =
|
||||
query_info_proto.search_iterator_v2_info();
|
||||
search_info.iterator_v2_info_ = SearchIteratorV2Info{
|
||||
.token = iterator_v2_info_proto.token(),
|
||||
.batch_size = iterator_v2_info_proto.batch_size(),
|
||||
};
|
||||
if (iterator_v2_info_proto.has_last_bound()) {
|
||||
search_info.iterator_v2_info_->last_bound =
|
||||
iterator_v2_info_proto.last_bound();
|
||||
}
|
||||
}
|
||||
|
||||
return search_info;
|
||||
};
|
||||
|
||||
|
|
|
@ -226,45 +226,66 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
|
|||
return sub_result;
|
||||
}
|
||||
|
||||
SubSearchResult
|
||||
BruteForceSearchIterators(const dataset::SearchDataset& query_ds,
|
||||
const dataset::RawDataset& raw_ds,
|
||||
const SearchInfo& search_info,
|
||||
const std::map<std::string, std::string>& index_info,
|
||||
const BitsetView& bitset,
|
||||
DataType data_type) {
|
||||
auto nq = query_ds.num_queries;
|
||||
auto [query_dataset, base_dataset] =
|
||||
PrepareBFDataSet(query_ds, raw_ds, data_type);
|
||||
auto search_cfg = PrepareBFSearchParams(search_info, index_info);
|
||||
|
||||
knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
|
||||
iterators_val;
|
||||
knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
|
||||
DispatchBruteForceIteratorByDataType(const knowhere::DataSetPtr& base_dataset,
|
||||
const knowhere::DataSetPtr& query_dataset,
|
||||
const knowhere::Json& config,
|
||||
const BitsetView& bitset,
|
||||
const milvus::DataType& data_type) {
|
||||
switch (data_type) {
|
||||
case DataType::VECTOR_FLOAT:
|
||||
iterators_val = knowhere::BruteForce::AnnIterator<float>(
|
||||
base_dataset, query_dataset, search_cfg, bitset);
|
||||
return knowhere::BruteForce::AnnIterator<float>(
|
||||
base_dataset, query_dataset, config, bitset);
|
||||
break;
|
||||
case DataType::VECTOR_FLOAT16:
|
||||
//todo: if knowhere support real fp16/bf16 bf, change it
|
||||
iterators_val = knowhere::BruteForce::AnnIterator<float>(
|
||||
base_dataset, query_dataset, search_cfg, bitset);
|
||||
return knowhere::BruteForce::AnnIterator<float>(
|
||||
base_dataset, query_dataset, config, bitset);
|
||||
break;
|
||||
case DataType::VECTOR_BFLOAT16:
|
||||
//todo: if knowhere support real fp16/bf16 bf, change it
|
||||
iterators_val = knowhere::BruteForce::AnnIterator<float>(
|
||||
base_dataset, query_dataset, search_cfg, bitset);
|
||||
return knowhere::BruteForce::AnnIterator<float>(
|
||||
base_dataset, query_dataset, config, bitset);
|
||||
break;
|
||||
case DataType::VECTOR_SPARSE_FLOAT:
|
||||
iterators_val = knowhere::BruteForce::AnnIterator<
|
||||
return knowhere::BruteForce::AnnIterator<
|
||||
knowhere::sparse::SparseRow<float>>(
|
||||
base_dataset, query_dataset, search_cfg, bitset);
|
||||
base_dataset, query_dataset, config, bitset);
|
||||
break;
|
||||
default:
|
||||
PanicInfo(ErrorCode::Unsupported,
|
||||
"Unsupported dataType for chunk brute force iterator:{}",
|
||||
data_type);
|
||||
}
|
||||
}
|
||||
|
||||
knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
|
||||
GetBruteForceSearchIterators(
|
||||
const dataset::SearchDataset& query_ds,
|
||||
const dataset::RawDataset& raw_ds,
|
||||
const SearchInfo& search_info,
|
||||
const std::map<std::string, std::string>& index_info,
|
||||
const BitsetView& bitset,
|
||||
DataType data_type) {
|
||||
auto nq = query_ds.num_queries;
|
||||
auto [query_dataset, base_dataset] =
|
||||
PrepareBFDataSet(query_ds, raw_ds, data_type);
|
||||
auto search_cfg = PrepareBFSearchParams(search_info, index_info);
|
||||
return DispatchBruteForceIteratorByDataType(
|
||||
base_dataset, query_dataset, search_cfg, bitset, data_type);
|
||||
}
|
||||
|
||||
SubSearchResult
|
||||
PackBruteForceSearchIteratorsIntoSubResult(
|
||||
const dataset::SearchDataset& query_ds,
|
||||
const dataset::RawDataset& raw_ds,
|
||||
const SearchInfo& search_info,
|
||||
const std::map<std::string, std::string>& index_info,
|
||||
const BitsetView& bitset,
|
||||
DataType data_type) {
|
||||
auto nq = query_ds.num_queries;
|
||||
auto iterators_val = GetBruteForceSearchIterators(
|
||||
query_ds, raw_ds, search_info, index_info, bitset, data_type);
|
||||
if (iterators_val.has_value()) {
|
||||
AssertInfo(
|
||||
iterators_val.value().size() == nq,
|
||||
|
|
|
@ -31,12 +31,22 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
|
|||
const BitsetView& bitset,
|
||||
DataType data_type);
|
||||
|
||||
knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
|
||||
GetBruteForceSearchIterators(
|
||||
const dataset::SearchDataset& query_ds,
|
||||
const dataset::RawDataset& raw_ds,
|
||||
const SearchInfo& search_info,
|
||||
const std::map<std::string, std::string>& index_info,
|
||||
const BitsetView& bitset,
|
||||
DataType data_type);
|
||||
|
||||
SubSearchResult
|
||||
BruteForceSearchIterators(const dataset::SearchDataset& query_ds,
|
||||
const dataset::RawDataset& raw_ds,
|
||||
const SearchInfo& search_info,
|
||||
const std::map<std::string, std::string>& index_info,
|
||||
const BitsetView& bitset,
|
||||
DataType data_type);
|
||||
PackBruteForceSearchIteratorsIntoSubResult(
|
||||
const dataset::SearchDataset& query_ds,
|
||||
const dataset::RawDataset& raw_ds,
|
||||
const SearchInfo& search_info,
|
||||
const std::map<std::string, std::string>& index_info,
|
||||
const BitsetView& bitset,
|
||||
DataType data_type);
|
||||
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "knowhere/comp/index_param.h"
|
||||
#include "knowhere/config.h"
|
||||
#include "log/Log.h"
|
||||
#include "query/CachedSearchIterator.h"
|
||||
#include "query/SearchBruteForce.h"
|
||||
#include "query/SearchOnIndex.h"
|
||||
#include "exec/operator/Utils.h"
|
||||
|
@ -125,6 +126,19 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
|||
|
||||
// step 3: brute force search where small indexing is unavailable
|
||||
auto vec_ptr = record.get_data_base(vecfield_id);
|
||||
|
||||
if (info.iterator_v2_info_.has_value()) {
|
||||
CachedSearchIterator cached_iter(search_dataset,
|
||||
vec_ptr,
|
||||
active_count,
|
||||
info,
|
||||
index_info,
|
||||
bitset,
|
||||
data_type);
|
||||
cached_iter.NextBatch(info, search_result);
|
||||
return;
|
||||
}
|
||||
|
||||
auto vec_size_per_chunk = vec_ptr->get_size_per_chunk();
|
||||
auto max_chunk = upper_div(active_count, vec_size_per_chunk);
|
||||
|
||||
|
@ -140,12 +154,13 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
|||
auto sub_data = query::dataset::RawDataset{
|
||||
element_begin, dim, size_per_chunk, chunk_data};
|
||||
if (milvus::exec::UseVectorIterator(info)) {
|
||||
auto sub_qr = BruteForceSearchIterators(search_dataset,
|
||||
sub_data,
|
||||
info,
|
||||
index_info,
|
||||
bitset,
|
||||
data_type);
|
||||
auto sub_qr =
|
||||
PackBruteForceSearchIteratorsIntoSubResult(search_dataset,
|
||||
sub_data,
|
||||
info,
|
||||
index_info,
|
||||
bitset,
|
||||
data_type);
|
||||
final_qr.merge(sub_qr);
|
||||
} else {
|
||||
auto sub_qr = BruteForceSearch(search_dataset,
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include "SearchOnIndex.h"
|
||||
#include "exec/operator/Utils.h"
|
||||
#include "CachedSearchIterator.h"
|
||||
|
||||
namespace milvus::query {
|
||||
void
|
||||
|
@ -26,14 +27,23 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset,
|
|||
auto dataset =
|
||||
knowhere::GenDataSet(num_queries, dim, search_dataset.query_data);
|
||||
dataset->SetIsSparse(is_sparse);
|
||||
if (!milvus::exec::PrepareVectorIteratorsFromIndex(search_conf,
|
||||
num_queries,
|
||||
dataset,
|
||||
search_result,
|
||||
bitset,
|
||||
indexing)) {
|
||||
indexing.Query(dataset, search_conf, bitset, search_result);
|
||||
if (milvus::exec::PrepareVectorIteratorsFromIndex(search_conf,
|
||||
num_queries,
|
||||
dataset,
|
||||
search_result,
|
||||
bitset,
|
||||
indexing)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (search_conf.iterator_v2_info_.has_value()) {
|
||||
auto iter =
|
||||
CachedSearchIterator(indexing, dataset, search_conf, bitset);
|
||||
iter.NextBatch(search_conf, search_result);
|
||||
return;
|
||||
}
|
||||
|
||||
indexing.Query(dataset, search_conf, bitset, search_result);
|
||||
}
|
||||
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "common/QueryInfo.h"
|
||||
#include "common/Types.h"
|
||||
#include "mmap/Column.h"
|
||||
#include "query/CachedSearchIterator.h"
|
||||
#include "query/SearchBruteForce.h"
|
||||
#include "query/SearchOnSealed.h"
|
||||
#include "query/helper.h"
|
||||
|
@ -55,13 +56,20 @@ SearchOnSealedIndex(const Schema& schema,
|
|||
dataset->SetIsSparse(is_sparse);
|
||||
auto vec_index =
|
||||
dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get());
|
||||
|
||||
if (search_info.iterator_v2_info_.has_value()) {
|
||||
CachedSearchIterator cached_iter(
|
||||
*vec_index, dataset, search_info, bitset);
|
||||
cached_iter.NextBatch(search_info, search_result);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!milvus::exec::PrepareVectorIteratorsFromIndex(search_info,
|
||||
num_queries,
|
||||
dataset,
|
||||
search_result,
|
||||
bitset,
|
||||
*vec_index)) {
|
||||
auto index_type = vec_index->GetIndexType();
|
||||
vec_index->Query(dataset, search_info, bitset, search_result);
|
||||
float* distances = search_result.distances_.data();
|
||||
auto total_num = num_queries * topK;
|
||||
|
@ -104,6 +112,14 @@ SearchOnSealed(const Schema& schema,
|
|||
|
||||
auto data_type = field.get_data_type();
|
||||
CheckBruteForceSearchParam(field, search_info);
|
||||
|
||||
if (search_info.iterator_v2_info_.has_value()) {
|
||||
CachedSearchIterator cached_iter(
|
||||
column, query_dataset, search_info, index_info, bitview, data_type);
|
||||
cached_iter.NextBatch(search_info, result);
|
||||
return;
|
||||
}
|
||||
|
||||
auto num_chunk = column->num_chunks();
|
||||
|
||||
SubSearchResult final_qr(num_queries,
|
||||
|
@ -115,17 +131,16 @@ SearchOnSealed(const Schema& schema,
|
|||
for (int i = 0; i < num_chunk; ++i) {
|
||||
auto vec_data = column->Data(i);
|
||||
auto chunk_size = column->chunk_row_nums(i);
|
||||
const uint8_t* bitset_ptr = nullptr;
|
||||
auto data_id = offset;
|
||||
auto raw_dataset =
|
||||
query::dataset::RawDataset{offset, dim, chunk_size, vec_data};
|
||||
if (milvus::exec::UseVectorIterator(search_info)) {
|
||||
auto sub_qr = BruteForceSearchIterators(query_dataset,
|
||||
raw_dataset,
|
||||
search_info,
|
||||
index_info,
|
||||
bitview,
|
||||
data_type);
|
||||
auto sub_qr =
|
||||
PackBruteForceSearchIteratorsIntoSubResult(query_dataset,
|
||||
raw_dataset,
|
||||
search_info,
|
||||
index_info,
|
||||
bitview,
|
||||
data_type);
|
||||
final_qr.merge(sub_qr);
|
||||
} else {
|
||||
auto sub_qr = BruteForceSearch(query_dataset,
|
||||
|
@ -136,7 +151,6 @@ SearchOnSealed(const Schema& schema,
|
|||
data_type);
|
||||
final_qr.merge(sub_qr);
|
||||
}
|
||||
|
||||
offset += chunk_size;
|
||||
}
|
||||
if (milvus::exec::UseVectorIterator(search_info)) {
|
||||
|
@ -181,14 +195,23 @@ SearchOnSealed(const Schema& schema,
|
|||
CheckBruteForceSearchParam(field, search_info);
|
||||
auto raw_dataset = query::dataset::RawDataset{0, dim, row_count, vec_data};
|
||||
if (milvus::exec::UseVectorIterator(search_info)) {
|
||||
auto sub_qr = BruteForceSearchIterators(query_dataset,
|
||||
raw_dataset,
|
||||
search_info,
|
||||
index_info,
|
||||
bitset,
|
||||
data_type);
|
||||
auto sub_qr = PackBruteForceSearchIteratorsIntoSubResult(query_dataset,
|
||||
raw_dataset,
|
||||
search_info,
|
||||
index_info,
|
||||
bitset,
|
||||
data_type);
|
||||
result.AssembleChunkVectorIterators(
|
||||
num_queries, 1, {0}, sub_qr.chunk_iterators());
|
||||
} else if (search_info.iterator_v2_info_.has_value()) {
|
||||
CachedSearchIterator cached_iter(query_dataset,
|
||||
raw_dataset,
|
||||
search_info,
|
||||
index_info,
|
||||
bitset,
|
||||
data_type);
|
||||
cached_iter.NextBatch(search_info, result);
|
||||
return;
|
||||
} else {
|
||||
auto sub_qr = BruteForceSearch(query_dataset,
|
||||
raw_dataset,
|
||||
|
|
|
@ -89,6 +89,7 @@ set(MILVUS_TEST_FILES
|
|||
test_chunked_segment.cpp
|
||||
test_chunked_column.cpp
|
||||
test_rust_result.cpp
|
||||
test_cached_search_iterator.cpp
|
||||
)
|
||||
|
||||
if ( INDEX_ENGINE STREQUAL "cardinal" )
|
||||
|
|
|
@ -143,12 +143,13 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
|
|||
AssertMatch(ref, ans);
|
||||
}
|
||||
|
||||
auto result3 = BruteForceSearchIterators(query_dataset,
|
||||
raw_dataset,
|
||||
search_info,
|
||||
index_info,
|
||||
bitset_view,
|
||||
DataType::VECTOR_SPARSE_FLOAT);
|
||||
auto result3 = PackBruteForceSearchIteratorsIntoSubResult(
|
||||
query_dataset,
|
||||
raw_dataset,
|
||||
search_info,
|
||||
index_info,
|
||||
bitset_view,
|
||||
DataType::VECTOR_SPARSE_FLOAT);
|
||||
auto iterators = result3.chunk_iterators();
|
||||
for (int i = 0; i < nq; i++) {
|
||||
auto it = iterators[i];
|
||||
|
|
|
@ -0,0 +1,797 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <unordered_set>
|
||||
#include "common/BitsetView.h"
|
||||
#include "common/QueryInfo.h"
|
||||
#include "common/QueryResult.h"
|
||||
#include "common/Utils.h"
|
||||
#include "index/Index.h"
|
||||
#include "knowhere/comp/index_param.h"
|
||||
#include "query/CachedSearchIterator.h"
|
||||
#include "index/VectorIndex.h"
|
||||
#include "index/IndexFactory.h"
|
||||
#include "knowhere/dataset.h"
|
||||
#include "query/helper.h"
|
||||
#include "segcore/ConcurrentVector.h"
|
||||
#include "segcore/InsertRecord.h"
|
||||
#include "mmap/ChunkedColumn.h"
|
||||
#include "test_utils/DataGen.h"
|
||||
|
||||
using namespace milvus;
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
using namespace milvus::index;
|
||||
|
||||
namespace {
|
||||
constexpr int64_t kDim = 16;
|
||||
constexpr int64_t kNumVectors = 1000;
|
||||
constexpr int64_t kNumQueries = 1;
|
||||
constexpr int64_t kBatchSize = 100;
|
||||
constexpr size_t kSizePerChunk = 128;
|
||||
constexpr size_t kHnswM = 24;
|
||||
constexpr size_t kHnswEfConstruction = 360;
|
||||
constexpr size_t kHnswEf = 128;
|
||||
|
||||
const MetricType kMetricType = knowhere::metric::L2;
|
||||
} // namespace
|
||||
|
||||
enum class ConstructorType {
|
||||
VectorIndex = 0,
|
||||
RawData,
|
||||
VectorBase,
|
||||
ChunkedColumn
|
||||
};
|
||||
|
||||
static const std::vector<ConstructorType> kConstructorTypes = {
|
||||
ConstructorType::VectorIndex,
|
||||
ConstructorType::RawData,
|
||||
ConstructorType::VectorBase,
|
||||
ConstructorType::ChunkedColumn,
|
||||
};
|
||||
|
||||
static const std::vector<MetricType> kMetricTypes = {
|
||||
knowhere::metric::L2,
|
||||
knowhere::metric::IP,
|
||||
knowhere::metric::COSINE,
|
||||
};
|
||||
|
||||
// this class does not support test concurrently
|
||||
class CachedSearchIteratorTest
|
||||
: public ::testing::TestWithParam<std::tuple<ConstructorType, MetricType>> {
|
||||
private:
|
||||
protected:
|
||||
SearchInfo
|
||||
GetDefaultNormalSearchInfo() {
|
||||
return SearchInfo{
|
||||
.topk_ = kBatchSize,
|
||||
.round_decimal_ = -1,
|
||||
.metric_type_ = std::get<1>(GetParam()),
|
||||
.search_params_ =
|
||||
{
|
||||
{knowhere::indexparam::EF, std::to_string(kHnswEf)},
|
||||
},
|
||||
.iterator_v2_info_ =
|
||||
SearchIteratorV2Info{
|
||||
.batch_size = kBatchSize,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
static DataType data_type_;
|
||||
static int64_t dim_;
|
||||
static int64_t nb_;
|
||||
static int64_t nq_;
|
||||
static FixedVector<float> base_dataset_;
|
||||
static FixedVector<float> query_dataset_;
|
||||
static IndexBasePtr index_hnsw_l2_;
|
||||
static IndexBasePtr index_hnsw_ip_;
|
||||
static IndexBasePtr index_hnsw_cos_;
|
||||
static knowhere::DataSetPtr knowhere_query_dataset_;
|
||||
static dataset::SearchDataset search_dataset_;
|
||||
static std::unique_ptr<ConcurrentVector<milvus::FloatVector>> vector_base_;
|
||||
static std::shared_ptr<ChunkedColumn> column_;
|
||||
static std::vector<std::vector<char>> column_data_;
|
||||
|
||||
IndexBase* index_hnsw_ = nullptr;
|
||||
MetricType metric_type_ = kMetricType;
|
||||
|
||||
std::unique_ptr<CachedSearchIterator>
|
||||
DispatchIterator(const ConstructorType& constructor_type,
|
||||
const SearchInfo& search_info,
|
||||
const BitsetView& bitset) {
|
||||
switch (constructor_type) {
|
||||
case ConstructorType::VectorIndex:
|
||||
return std::make_unique<CachedSearchIterator>(
|
||||
dynamic_cast<const VectorIndex&>(*index_hnsw_),
|
||||
knowhere_query_dataset_,
|
||||
search_info,
|
||||
bitset);
|
||||
|
||||
case ConstructorType::RawData:
|
||||
return std::make_unique<CachedSearchIterator>(
|
||||
search_dataset_,
|
||||
dataset::RawDataset{0, dim_, nb_, base_dataset_.data()},
|
||||
search_info,
|
||||
std::map<std::string, std::string>{},
|
||||
bitset,
|
||||
data_type_);
|
||||
|
||||
case ConstructorType::VectorBase:
|
||||
return std::make_unique<CachedSearchIterator>(
|
||||
search_dataset_,
|
||||
vector_base_.get(),
|
||||
nb_,
|
||||
search_info,
|
||||
std::map<std::string, std::string>{},
|
||||
bitset,
|
||||
data_type_);
|
||||
|
||||
case ConstructorType::ChunkedColumn:
|
||||
return std::make_unique<CachedSearchIterator>(
|
||||
column_,
|
||||
search_dataset_,
|
||||
search_info,
|
||||
std::map<std::string, std::string>{},
|
||||
bitset,
|
||||
data_type_);
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// use last distance of the first batch as range_filter
|
||||
// use first distance of the last batch as radius
|
||||
std::pair<float, float>
|
||||
GetRadiusAndRangeFilter() {
|
||||
const size_t num_rnds = (nb_ + kBatchSize - 1) / kBatchSize;
|
||||
SearchResult search_result;
|
||||
float radius, range_filter;
|
||||
bool get_radius_success = false;
|
||||
bool get_range_filter_sucess = false;
|
||||
SearchInfo search_info = GetDefaultNormalSearchInfo();
|
||||
auto iterator =
|
||||
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
|
||||
for (size_t rnd = 0; rnd < num_rnds; ++rnd) {
|
||||
iterator->NextBatch(search_info, search_result);
|
||||
if (rnd == 0) {
|
||||
for (size_t i = kBatchSize - 1; i >= 0; --i) {
|
||||
if (search_result.seg_offsets_[i] != -1) {
|
||||
range_filter = search_result.distances_[i];
|
||||
get_range_filter_sucess = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < kBatchSize; ++i) {
|
||||
if (search_result.seg_offsets_[i] != -1) {
|
||||
radius = search_result.distances_[i];
|
||||
get_radius_success = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!get_radius_success || !get_range_filter_sucess) {
|
||||
throw std::runtime_error("Failed to get radius and range filter");
|
||||
}
|
||||
return {radius, range_filter};
|
||||
}
|
||||
|
||||
static void
|
||||
BuildIndex() {
|
||||
auto dataset = knowhere::GenDataSet(nb_, dim_, base_dataset_.data());
|
||||
|
||||
for (const auto& metric_type : kMetricTypes) {
|
||||
milvus::index::CreateIndexInfo create_index_info;
|
||||
create_index_info.field_type = data_type_;
|
||||
create_index_info.metric_type = metric_type;
|
||||
create_index_info.index_engine_version =
|
||||
knowhere::Version::GetCurrentVersion().VersionNumber();
|
||||
auto build_conf = knowhere::Json{
|
||||
{knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
|
||||
{knowhere::meta::DIM, std::to_string(dim_)},
|
||||
{knowhere::indexparam::M, std::to_string(kHnswM)},
|
||||
{knowhere::indexparam::EFCONSTRUCTION,
|
||||
std::to_string(kHnswEfConstruction)}};
|
||||
create_index_info.index_type = knowhere::IndexEnum::INDEX_HNSW;
|
||||
if (metric_type == knowhere::metric::L2) {
|
||||
index_hnsw_l2_ =
|
||||
milvus::index::IndexFactory::GetInstance().CreateIndex(
|
||||
create_index_info,
|
||||
milvus::storage::FileManagerContext());
|
||||
index_hnsw_l2_->BuildWithDataset(dataset, build_conf);
|
||||
ASSERT_EQ(index_hnsw_l2_->Count(), nb_);
|
||||
} else if (metric_type == knowhere::metric::IP) {
|
||||
index_hnsw_ip_ =
|
||||
milvus::index::IndexFactory::GetInstance().CreateIndex(
|
||||
create_index_info,
|
||||
milvus::storage::FileManagerContext());
|
||||
index_hnsw_ip_->BuildWithDataset(dataset, build_conf);
|
||||
ASSERT_EQ(index_hnsw_ip_->Count(), nb_);
|
||||
} else if (metric_type == knowhere::metric::COSINE) {
|
||||
index_hnsw_cos_ =
|
||||
milvus::index::IndexFactory::GetInstance().CreateIndex(
|
||||
create_index_info,
|
||||
milvus::storage::FileManagerContext());
|
||||
index_hnsw_cos_->BuildWithDataset(dataset, build_conf);
|
||||
ASSERT_EQ(index_hnsw_cos_->Count(), nb_);
|
||||
} else {
|
||||
FAIL() << "Unsupported metric type: " << metric_type;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
SetUpVectorBase() {
|
||||
vector_base_ = std::make_unique<ConcurrentVector<milvus::FloatVector>>(
|
||||
dim_, kSizePerChunk);
|
||||
vector_base_->set_data_raw(0, base_dataset_.data(), nb_);
|
||||
|
||||
ASSERT_EQ(vector_base_->num_chunk(),
|
||||
(nb_ + kSizePerChunk - 1) / kSizePerChunk);
|
||||
}
|
||||
|
||||
static void
|
||||
SetUpChunkedColumn() {
|
||||
column_ = std::make_unique<ChunkedColumn>();
|
||||
const size_t num_chunks_ = (nb_ + kSizePerChunk - 1) / kSizePerChunk;
|
||||
column_data_.resize(num_chunks_);
|
||||
|
||||
size_t offset = 0;
|
||||
for (size_t i = 0; i < num_chunks_; ++i) {
|
||||
const size_t rows = std::min(nb_ - offset, kSizePerChunk);
|
||||
const size_t chunk_bitset_size = (rows + 7) / 8;
|
||||
const size_t buf_size =
|
||||
chunk_bitset_size + rows * dim_ * sizeof(float);
|
||||
auto& chunk_data = column_data_[i];
|
||||
chunk_data.resize(buf_size);
|
||||
memcpy(chunk_data.data() + chunk_bitset_size,
|
||||
base_dataset_.cbegin() + offset * dim_,
|
||||
rows * dim_ * sizeof(float));
|
||||
column_->AddChunk(std::make_shared<FixedWidthChunk>(
|
||||
rows, dim_, chunk_data.data(), buf_size, sizeof(float), false));
|
||||
offset += rows;
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
SetUpTestSuite() {
|
||||
auto schema = std::make_shared<Schema>();
|
||||
auto fakevec_id = schema->AddDebugField(
|
||||
"fakevec", DataType::VECTOR_FLOAT, dim_, kMetricType);
|
||||
|
||||
// generate base dataset
|
||||
base_dataset_ =
|
||||
segcore::DataGen(schema, nb_).get_col<float>(fakevec_id);
|
||||
|
||||
// generate query dataset
|
||||
query_dataset_ = {base_dataset_.cbegin(),
|
||||
base_dataset_.cbegin() + nq_ * dim_};
|
||||
knowhere_query_dataset_ =
|
||||
knowhere::GenDataSet(nq_, dim_, query_dataset_.data());
|
||||
search_dataset_ = dataset::SearchDataset{
|
||||
.metric_type = kMetricType,
|
||||
.num_queries = nq_,
|
||||
.topk = kBatchSize,
|
||||
.round_decimal = -1,
|
||||
.dim = dim_,
|
||||
.query_data = query_dataset_.data(),
|
||||
};
|
||||
|
||||
BuildIndex();
|
||||
SetUpVectorBase();
|
||||
SetUpChunkedColumn();
|
||||
}
|
||||
|
||||
static void
|
||||
TearDownTestSuite() {
|
||||
base_dataset_.clear();
|
||||
query_dataset_.clear();
|
||||
index_hnsw_l2_.reset();
|
||||
index_hnsw_ip_.reset();
|
||||
index_hnsw_cos_.reset();
|
||||
knowhere_query_dataset_.reset();
|
||||
vector_base_.reset();
|
||||
column_.reset();
|
||||
}
|
||||
|
||||
void
|
||||
SetUp() override {
|
||||
auto metric_type = std::get<1>(GetParam());
|
||||
if (metric_type == knowhere::metric::L2) {
|
||||
metric_type_ = knowhere::metric::L2;
|
||||
search_dataset_.metric_type = knowhere::metric::L2;
|
||||
index_hnsw_ = index_hnsw_l2_.get();
|
||||
} else if (metric_type == knowhere::metric::IP) {
|
||||
metric_type_ = knowhere::metric::IP;
|
||||
search_dataset_.metric_type = knowhere::metric::IP;
|
||||
index_hnsw_ = index_hnsw_ip_.get();
|
||||
} else if (metric_type == knowhere::metric::COSINE) {
|
||||
metric_type_ = knowhere::metric::COSINE;
|
||||
search_dataset_.metric_type = knowhere::metric::COSINE;
|
||||
index_hnsw_ = index_hnsw_cos_.get();
|
||||
} else {
|
||||
FAIL() << "Unsupported metric type: " << metric_type;
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
TearDown() override {
|
||||
}
|
||||
};
|
||||
|
||||
// initialize static variables
|
||||
DataType CachedSearchIteratorTest::data_type_ = DataType::VECTOR_FLOAT;
|
||||
int64_t CachedSearchIteratorTest::dim_ = kDim;
|
||||
int64_t CachedSearchIteratorTest::nb_ = kNumVectors;
|
||||
int64_t CachedSearchIteratorTest::nq_ = kNumQueries;
|
||||
IndexBasePtr CachedSearchIteratorTest::index_hnsw_l2_ = nullptr;
|
||||
IndexBasePtr CachedSearchIteratorTest::index_hnsw_ip_ = nullptr;
|
||||
IndexBasePtr CachedSearchIteratorTest::index_hnsw_cos_ = nullptr;
|
||||
knowhere::DataSetPtr CachedSearchIteratorTest::knowhere_query_dataset_ =
|
||||
nullptr;
|
||||
dataset::SearchDataset CachedSearchIteratorTest::search_dataset_;
|
||||
FixedVector<float> CachedSearchIteratorTest::base_dataset_;
|
||||
FixedVector<float> CachedSearchIteratorTest::query_dataset_;
|
||||
std::unique_ptr<ConcurrentVector<milvus::FloatVector>>
|
||||
CachedSearchIteratorTest::vector_base_ = nullptr;
|
||||
std::shared_ptr<ChunkedColumn> CachedSearchIteratorTest::column_ = nullptr;
|
||||
std::vector<std::vector<char>> CachedSearchIteratorTest::column_data_;
|
||||
|
||||
/********* Testcases Start **********/
|
||||
|
||||
TEST_P(CachedSearchIteratorTest, NextBatchNormal) {
|
||||
SearchInfo search_info = GetDefaultNormalSearchInfo();
|
||||
const std::vector<size_t> kBatchSizes = {
|
||||
1, 7, 43, 99, 100, 101, 1000, 1005};
|
||||
|
||||
for (size_t batch_size : kBatchSizes) {
|
||||
std::cout << "batch_size: " << batch_size << std::endl;
|
||||
search_info.iterator_v2_info_->batch_size = batch_size;
|
||||
auto iterator =
|
||||
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
|
||||
SearchResult search_result;
|
||||
|
||||
iterator->NextBatch(search_info, search_result);
|
||||
|
||||
for (size_t i = 0; i < nq_; ++i) {
|
||||
std::unordered_set<int64_t> seg_offsets;
|
||||
size_t cnt = 0;
|
||||
for (size_t j = 0; j < batch_size; ++j) {
|
||||
if (search_result.seg_offsets_[i * batch_size + j] == -1) {
|
||||
break;
|
||||
}
|
||||
++cnt;
|
||||
seg_offsets.insert(
|
||||
search_result.seg_offsets_[i * batch_size + j]);
|
||||
}
|
||||
EXPECT_EQ(seg_offsets.size(), cnt);
|
||||
if (metric_type_ == knowhere::metric::L2) {
|
||||
EXPECT_EQ(search_result.distances_[i * batch_size], 0);
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(search_result.unity_topK_, batch_size);
|
||||
EXPECT_EQ(search_result.total_nq_, nq_);
|
||||
EXPECT_EQ(search_result.seg_offsets_.size(), nq_ * batch_size);
|
||||
EXPECT_EQ(search_result.distances_.size(), nq_ * batch_size);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachedSearchIteratorTest, NextBatchDistBound) {
|
||||
SearchInfo search_info = GetDefaultNormalSearchInfo();
|
||||
const size_t batch_size = kBatchSize;
|
||||
const float dist_bound_factor = PositivelyRelated(metric_type_) ? 0.5 : 1.5;
|
||||
float dist_bound = 0;
|
||||
|
||||
{
|
||||
auto iterator =
|
||||
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
|
||||
SearchResult search_result;
|
||||
iterator->NextBatch(search_info, search_result);
|
||||
|
||||
bool found_dist_bound = false;
|
||||
// use the last distance of the first query * factor as the dist bound
|
||||
for (size_t j = batch_size - 1; j >= 0; --j) {
|
||||
if (search_result.seg_offsets_[j] != -1) {
|
||||
dist_bound = search_result.distances_[j] * dist_bound_factor;
|
||||
found_dist_bound = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
ASSERT_TRUE(found_dist_bound);
|
||||
|
||||
search_info.iterator_v2_info_->last_bound = dist_bound;
|
||||
for (size_t rnd = 1; rnd < (nb_ + batch_size - 1) / batch_size; ++rnd) {
|
||||
iterator->NextBatch(search_info, search_result);
|
||||
for (size_t i = 0; i < nq_; ++i) {
|
||||
for (size_t j = 0; j < batch_size; ++j) {
|
||||
if (search_result.seg_offsets_[i * batch_size + j] == -1) {
|
||||
break;
|
||||
}
|
||||
if (PositivelyRelated(metric_type_)) {
|
||||
EXPECT_LT(search_result.distances_[i * batch_size + j],
|
||||
dist_bound);
|
||||
} else {
|
||||
EXPECT_GT(search_result.distances_[i * batch_size + j],
|
||||
dist_bound);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachedSearchIteratorTest, NextBatchDistBoundEmptyResults) {
|
||||
SearchInfo search_info = GetDefaultNormalSearchInfo();
|
||||
const size_t batch_size = kBatchSize;
|
||||
const float dist_bound = PositivelyRelated(metric_type_)
|
||||
? -std::numeric_limits<float>::max()
|
||||
: std::numeric_limits<float>::max();
|
||||
|
||||
auto iterator =
|
||||
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
|
||||
SearchResult search_result;
|
||||
|
||||
search_info.iterator_v2_info_->last_bound = dist_bound;
|
||||
size_t total_cnt = 0;
|
||||
for (size_t rnd = 0; rnd < (nb_ + batch_size - 1) / batch_size; ++rnd) {
|
||||
iterator->NextBatch(search_info, search_result);
|
||||
for (size_t i = 0; i < nq_; ++i) {
|
||||
for (size_t j = 0; j < batch_size; ++j) {
|
||||
if (search_result.seg_offsets_[i * batch_size + j] == -1) {
|
||||
break;
|
||||
}
|
||||
++total_cnt;
|
||||
}
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(total_cnt, 0);
|
||||
}
|
||||
|
||||
TEST_P(CachedSearchIteratorTest, NextBatchRangeSearchRadius) {
|
||||
const size_t num_rnds = (nb_ + kBatchSize - 1) / kBatchSize;
|
||||
const auto [radius, range_filter] = GetRadiusAndRangeFilter();
|
||||
SearchResult search_result;
|
||||
|
||||
SearchInfo search_info = GetDefaultNormalSearchInfo();
|
||||
search_info.search_params_[knowhere::meta::RADIUS] = radius;
|
||||
|
||||
auto iterator =
|
||||
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
|
||||
for (size_t rnd = 0; rnd < num_rnds; ++rnd) {
|
||||
iterator->NextBatch(search_info, search_result);
|
||||
for (size_t i = 0; i < nq_; ++i) {
|
||||
for (size_t j = 0; j < kBatchSize; ++j) {
|
||||
if (search_result.seg_offsets_[i * kBatchSize + j] == -1) {
|
||||
break;
|
||||
}
|
||||
float dist = search_result.distances_[i * kBatchSize + j];
|
||||
if (PositivelyRelated(metric_type_)) {
|
||||
ASSERT_GT(dist, radius);
|
||||
} else {
|
||||
ASSERT_LT(dist, radius);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachedSearchIteratorTest, NextBatchRangeSearchRadiusAndRangeFilter) {
|
||||
const size_t num_rnds = (nb_ + kBatchSize - 1) / kBatchSize;
|
||||
const auto [radius, range_filter] = GetRadiusAndRangeFilter();
|
||||
SearchResult search_result;
|
||||
|
||||
SearchInfo search_info = GetDefaultNormalSearchInfo();
|
||||
search_info.search_params_[knowhere::meta::RADIUS] = radius;
|
||||
search_info.search_params_[knowhere::meta::RANGE_FILTER] = range_filter;
|
||||
|
||||
auto iterator =
|
||||
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
|
||||
for (size_t rnd = 0; rnd < num_rnds; ++rnd) {
|
||||
iterator->NextBatch(search_info, search_result);
|
||||
for (size_t i = 0; i < nq_; ++i) {
|
||||
for (size_t j = 0; j < kBatchSize; ++j) {
|
||||
if (search_result.seg_offsets_[i * kBatchSize + j] == -1) {
|
||||
break;
|
||||
}
|
||||
float dist = search_result.distances_[i * kBatchSize + j];
|
||||
if (PositivelyRelated(metric_type_)) {
|
||||
ASSERT_GT(dist, radius);
|
||||
ASSERT_LE(dist, range_filter);
|
||||
} else {
|
||||
ASSERT_LT(dist, radius);
|
||||
ASSERT_GE(dist, range_filter);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachedSearchIteratorTest,
|
||||
NextBatchRangeSearchLastBoundRadiusRangeFilter) {
|
||||
const size_t num_rnds = (nb_ + kBatchSize - 1) / kBatchSize;
|
||||
const auto [radius, range_filter] = GetRadiusAndRangeFilter();
|
||||
SearchResult search_result;
|
||||
const float diff = (radius + range_filter) / 2;
|
||||
const std::vector<float> last_bounds = {radius - diff,
|
||||
radius,
|
||||
radius + diff,
|
||||
range_filter,
|
||||
range_filter + diff};
|
||||
|
||||
SearchInfo search_info = GetDefaultNormalSearchInfo();
|
||||
search_info.search_params_[knowhere::meta::RADIUS] = radius;
|
||||
search_info.search_params_[knowhere::meta::RANGE_FILTER] = range_filter;
|
||||
for (float last_bound : last_bounds) {
|
||||
search_info.iterator_v2_info_->last_bound = last_bound;
|
||||
auto iterator =
|
||||
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
|
||||
for (size_t rnd = 0; rnd < num_rnds; ++rnd) {
|
||||
iterator->NextBatch(search_info, search_result);
|
||||
for (size_t i = 0; i < nq_; ++i) {
|
||||
for (size_t j = 0; j < kBatchSize; ++j) {
|
||||
if (search_result.seg_offsets_[i * kBatchSize + j] == -1) {
|
||||
break;
|
||||
}
|
||||
float dist = search_result.distances_[i * kBatchSize + j];
|
||||
if (PositivelyRelated(metric_type_)) {
|
||||
ASSERT_LE(dist, last_bound);
|
||||
ASSERT_GT(dist, radius);
|
||||
ASSERT_LE(dist, range_filter);
|
||||
} else {
|
||||
ASSERT_GT(dist, last_bound);
|
||||
ASSERT_LT(dist, radius);
|
||||
ASSERT_GE(dist, range_filter);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachedSearchIteratorTest, NextBatchZeroBatchSize) {
|
||||
SearchInfo search_info = GetDefaultNormalSearchInfo();
|
||||
auto iterator =
|
||||
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
|
||||
SearchResult search_result;
|
||||
|
||||
search_info.iterator_v2_info_->batch_size = 0;
|
||||
EXPECT_THROW(iterator->NextBatch(search_info, search_result), SegcoreError);
|
||||
}
|
||||
|
||||
TEST_P(CachedSearchIteratorTest, NextBatchDiffBatchSizeComparedToInit) {
|
||||
SearchInfo search_info = GetDefaultNormalSearchInfo();
|
||||
auto iterator =
|
||||
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
|
||||
SearchResult search_result;
|
||||
|
||||
search_info.iterator_v2_info_->batch_size = kBatchSize + 1;
|
||||
EXPECT_THROW(iterator->NextBatch(search_info, search_result), SegcoreError);
|
||||
}
|
||||
|
||||
TEST_P(CachedSearchIteratorTest, NextBatchEmptySearchInfo) {
|
||||
SearchInfo search_info = GetDefaultNormalSearchInfo();
|
||||
auto iterator =
|
||||
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
|
||||
SearchResult search_result;
|
||||
|
||||
SearchInfo empty_search_info;
|
||||
EXPECT_THROW(iterator->NextBatch(empty_search_info, search_result),
|
||||
SegcoreError);
|
||||
}
|
||||
|
||||
TEST_P(CachedSearchIteratorTest, NextBatchEmptyIteratorV2Info) {
|
||||
SearchInfo search_info = GetDefaultNormalSearchInfo();
|
||||
auto iterator =
|
||||
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
|
||||
SearchResult search_result;
|
||||
|
||||
search_info.iterator_v2_info_ = std::nullopt;
|
||||
EXPECT_THROW(iterator->NextBatch(search_info, search_result), SegcoreError);
|
||||
}
|
||||
|
||||
TEST_P(CachedSearchIteratorTest, NextBatchtAllBatchesNormal) {
|
||||
SearchInfo search_info = GetDefaultNormalSearchInfo();
|
||||
const std::vector<size_t> kBatchSizes = {
|
||||
1, 7, 43, 99, 100, 101, 1000, 1005};
|
||||
// const std::vector<size_t> kBatchSizes = {1005};
|
||||
|
||||
for (size_t batch_size : kBatchSizes) {
|
||||
search_info.iterator_v2_info_->batch_size = batch_size;
|
||||
auto iterator =
|
||||
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
|
||||
size_t total_cnt = 0;
|
||||
|
||||
for (size_t rnd = 0; rnd < (nb_ + batch_size - 1) / batch_size; ++rnd) {
|
||||
SearchResult search_result;
|
||||
iterator->NextBatch(search_info, search_result);
|
||||
for (size_t i = 0; i < nq_; ++i) {
|
||||
std::unordered_set<int64_t> seg_offsets;
|
||||
size_t cnt = 0;
|
||||
for (size_t j = 0; j < batch_size; ++j) {
|
||||
if (search_result.seg_offsets_[i * batch_size + j] == -1) {
|
||||
break;
|
||||
}
|
||||
++cnt;
|
||||
seg_offsets.insert(
|
||||
search_result.seg_offsets_[i * batch_size + j]);
|
||||
}
|
||||
total_cnt += cnt;
|
||||
// check no duplicate
|
||||
EXPECT_EQ(seg_offsets.size(), cnt);
|
||||
|
||||
// only check if the first distance of the first batch is 0
|
||||
if (rnd == 0 && metric_type_ == knowhere::metric::L2) {
|
||||
EXPECT_EQ(search_result.distances_[i * batch_size], 0);
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(search_result.unity_topK_, batch_size);
|
||||
EXPECT_EQ(search_result.total_nq_, nq_);
|
||||
EXPECT_EQ(search_result.seg_offsets_.size(), nq_ * batch_size);
|
||||
EXPECT_EQ(search_result.distances_.size(), nq_ * batch_size);
|
||||
}
|
||||
if (std::get<0>(GetParam()) == ConstructorType::VectorIndex) {
|
||||
EXPECT_GE(total_cnt, nb_ * nq_ * 0.9);
|
||||
} else {
|
||||
EXPECT_EQ(total_cnt, nb_ * nq_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachedSearchIteratorTest, ConstructorWithInvalidSearchInfo) {
|
||||
EXPECT_THROW(
|
||||
DispatchIterator(std::get<0>(GetParam()), SearchInfo{}, nullptr),
|
||||
SegcoreError);
|
||||
|
||||
EXPECT_THROW(
|
||||
DispatchIterator(
|
||||
std::get<0>(GetParam()), SearchInfo{.metric_type_ = ""}, nullptr),
|
||||
SegcoreError);
|
||||
|
||||
EXPECT_THROW(DispatchIterator(std::get<0>(GetParam()),
|
||||
SearchInfo{.metric_type_ = metric_type_},
|
||||
nullptr),
|
||||
SegcoreError);
|
||||
|
||||
EXPECT_THROW(DispatchIterator(std::get<0>(GetParam()),
|
||||
SearchInfo{.metric_type_ = metric_type_,
|
||||
.iterator_v2_info_ = {}},
|
||||
nullptr),
|
||||
SegcoreError);
|
||||
|
||||
EXPECT_THROW(
|
||||
DispatchIterator(std::get<0>(GetParam()),
|
||||
SearchInfo{.metric_type_ = metric_type_,
|
||||
.iterator_v2_info_ =
|
||||
SearchIteratorV2Info{.batch_size = 0}},
|
||||
nullptr),
|
||||
SegcoreError);
|
||||
}
|
||||
|
||||
TEST_P(CachedSearchIteratorTest, ConstructorWithInvalidParams) {
|
||||
SearchInfo search_info = GetDefaultNormalSearchInfo();
|
||||
if (std::get<0>(GetParam()) == ConstructorType::VectorIndex) {
|
||||
EXPECT_THROW(auto iterator = std::make_unique<CachedSearchIterator>(
|
||||
dynamic_cast<const VectorIndex&>(*index_hnsw_),
|
||||
nullptr,
|
||||
search_info,
|
||||
nullptr),
|
||||
SegcoreError);
|
||||
|
||||
EXPECT_THROW(auto iterator = std::make_unique<CachedSearchIterator>(
|
||||
dynamic_cast<const VectorIndex&>(*index_hnsw_),
|
||||
std::make_shared<knowhere::DataSet>(),
|
||||
search_info,
|
||||
nullptr),
|
||||
SegcoreError);
|
||||
} else if (std::get<0>(GetParam()) == ConstructorType::RawData) {
|
||||
EXPECT_THROW(
|
||||
auto iterator = std::make_unique<CachedSearchIterator>(
|
||||
dataset::SearchDataset{},
|
||||
dataset::RawDataset{0, dim_, nb_, base_dataset_.data()},
|
||||
search_info,
|
||||
std::map<std::string, std::string>{},
|
||||
nullptr,
|
||||
data_type_),
|
||||
SegcoreError);
|
||||
} else if (std::get<0>(GetParam()) == ConstructorType::VectorBase) {
|
||||
EXPECT_THROW(auto iterator = std::make_unique<CachedSearchIterator>(
|
||||
dataset::SearchDataset{},
|
||||
vector_base_.get(),
|
||||
nb_,
|
||||
search_info,
|
||||
std::map<std::string, std::string>{},
|
||||
nullptr,
|
||||
data_type_),
|
||||
SegcoreError);
|
||||
|
||||
EXPECT_THROW(auto iterator = std::make_unique<CachedSearchIterator>(
|
||||
search_dataset_,
|
||||
nullptr,
|
||||
nb_,
|
||||
search_info,
|
||||
std::map<std::string, std::string>{},
|
||||
nullptr,
|
||||
data_type_),
|
||||
SegcoreError);
|
||||
|
||||
EXPECT_THROW(auto iterator = std::make_unique<CachedSearchIterator>(
|
||||
search_dataset_,
|
||||
vector_base_.get(),
|
||||
0,
|
||||
search_info,
|
||||
std::map<std::string, std::string>{},
|
||||
nullptr,
|
||||
data_type_),
|
||||
SegcoreError);
|
||||
} else if (std::get<0>(GetParam()) == ConstructorType::ChunkedColumn) {
|
||||
EXPECT_THROW(auto iterator = std::make_unique<CachedSearchIterator>(
|
||||
nullptr,
|
||||
search_dataset_,
|
||||
search_info,
|
||||
std::map<std::string, std::string>{},
|
||||
nullptr,
|
||||
data_type_),
|
||||
SegcoreError);
|
||||
EXPECT_THROW(auto iterator = std::make_unique<CachedSearchIterator>(
|
||||
column_,
|
||||
dataset::SearchDataset{},
|
||||
search_info,
|
||||
std::map<std::string, std::string>{},
|
||||
nullptr,
|
||||
data_type_),
|
||||
SegcoreError);
|
||||
}
|
||||
}
|
||||
|
||||
/********* Testcases End **********/
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
CachedSearchIteratorTests,
|
||||
CachedSearchIteratorTest,
|
||||
::testing::Combine(::testing::ValuesIn(kConstructorTypes),
|
||||
::testing::ValuesIn(kMetricTypes)),
|
||||
[](const testing::TestParamInfo<std::tuple<ConstructorType, MetricType>>&
|
||||
info) {
|
||||
std::string constructor_type_str;
|
||||
ConstructorType constructor_type = std::get<0>(info.param);
|
||||
MetricType metric_type = std::get<1>(info.param);
|
||||
switch (constructor_type) {
|
||||
case ConstructorType::VectorIndex:
|
||||
constructor_type_str = "VectorIndex";
|
||||
break;
|
||||
case ConstructorType::RawData:
|
||||
constructor_type_str = "RawData";
|
||||
break;
|
||||
case ConstructorType::VectorBase:
|
||||
constructor_type_str = "VectorBase";
|
||||
break;
|
||||
case ConstructorType::ChunkedColumn:
|
||||
constructor_type_str = "ChunkedColumn";
|
||||
break;
|
||||
default:
|
||||
constructor_type_str = "Unknown constructor type";
|
||||
};
|
||||
if (metric_type == knowhere::metric::L2) {
|
||||
constructor_type_str += "_L2";
|
||||
} else if (metric_type == knowhere::metric::IP) {
|
||||
constructor_type_str += "_IP";
|
||||
} else if (metric_type == knowhere::metric::COSINE) {
|
||||
constructor_type_str += "_COSINE";
|
||||
} else {
|
||||
constructor_type_str += "_Unknown";
|
||||
}
|
||||
return constructor_type_str;
|
||||
});
|
|
@ -55,6 +55,12 @@ message Array {
|
|||
schema.DataType element_type = 3;
|
||||
}
|
||||
|
||||
message SearchIteratorV2Info {
|
||||
string token = 1;
|
||||
uint32 batch_size = 2;
|
||||
optional float last_bound = 3;
|
||||
}
|
||||
|
||||
message QueryInfo {
|
||||
int64 topk = 1;
|
||||
string metric_type = 3;
|
||||
|
@ -67,6 +73,7 @@ message QueryInfo {
|
|||
double bm25_avgdl = 10;
|
||||
int64 query_field_id =11;
|
||||
string hints = 12;
|
||||
optional SearchIteratorV2Info search_iterator_v2_info = 13;
|
||||
}
|
||||
|
||||
message ColumnInfo {
|
||||
|
|
|
@ -26,6 +26,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.uber.org/atomic"
|
||||
|
@ -304,6 +305,13 @@ func (node *Proxy) Init() error {
|
|||
|
||||
node.enableMaterializedView = Params.CommonCfg.EnableMaterializedView.GetAsBool()
|
||||
|
||||
// Enable internal rand pool for UUIDv4 generation
|
||||
// This is NOT thread-safe and should only be called before the service starts and
|
||||
// there is no possibility that New or any other UUID V4 generation function will be called concurrently
|
||||
// Only proxy generates UUID for now, and one Milvus process only has one proxy
|
||||
uuid.EnableRandPool()
|
||||
log.Debug("enable rand pool for UUIDv4 generation")
|
||||
|
||||
log.Info("init proxy done", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("Address", node.address))
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/google/uuid"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
|
@ -82,6 +83,81 @@ type SearchInfo struct {
|
|||
isIterator bool
|
||||
}
|
||||
|
||||
func parseSearchIteratorV2Info(searchParamsPair []*commonpb.KeyValuePair, groupByFieldId int64, isIterator bool, offset int64, queryTopK *int64) (*planpb.SearchIteratorV2Info, error) {
|
||||
isIteratorV2Str, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterV2Key, searchParamsPair)
|
||||
isIteratorV2, _ := strconv.ParseBool(isIteratorV2Str)
|
||||
if !isIteratorV2 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// iteratorV1 and iteratorV2 should be set together for compatibility
|
||||
if !isIterator {
|
||||
return nil, fmt.Errorf("both %s and %s must be set in the SDK", IteratorField, SearchIterV2Key)
|
||||
}
|
||||
|
||||
// disable groupBy when doing iteratorV2
|
||||
// same behavior with V1
|
||||
if isIteratorV2 && groupByFieldId > 0 {
|
||||
return nil, merr.WrapErrParameterInvalid("", "",
|
||||
"GroupBy is not permitted when using a search iterator")
|
||||
}
|
||||
|
||||
// disable offset when doing iteratorV2
|
||||
if isIteratorV2 && offset > 0 {
|
||||
return nil, merr.WrapErrParameterInvalid("", "",
|
||||
"Setting an offset is not permitted when using a search iterator v2")
|
||||
}
|
||||
|
||||
// parse token, generate if not exist
|
||||
token, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterIdKey, searchParamsPair)
|
||||
if token == "" {
|
||||
generatedToken, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token = generatedToken.String()
|
||||
} else {
|
||||
// Validate existing token is a valid UUID
|
||||
if _, err := uuid.Parse(token); err != nil {
|
||||
return nil, fmt.Errorf("invalid token format")
|
||||
}
|
||||
}
|
||||
|
||||
// parse batch size, required non-zero value
|
||||
batchSizeStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterBatchSizeKey, searchParamsPair)
|
||||
if batchSizeStr == "" {
|
||||
return nil, fmt.Errorf("batch size is required")
|
||||
}
|
||||
batchSize, err := strconv.ParseInt(batchSizeStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch size is invalid, %w", err)
|
||||
}
|
||||
// use the same validation logic as topk
|
||||
if err := validateLimit(batchSize); err != nil {
|
||||
return nil, fmt.Errorf("batch size is invalid, %w", err)
|
||||
}
|
||||
*queryTopK = batchSize // for compatibility
|
||||
|
||||
// prepare plan iterator v2 info proto
|
||||
planIteratorV2Info := &planpb.SearchIteratorV2Info{
|
||||
Token: token,
|
||||
BatchSize: uint32(batchSize),
|
||||
}
|
||||
|
||||
// append optional last bound if applicable
|
||||
lastBoundStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterLastBoundKey, searchParamsPair)
|
||||
if lastBoundStr != "" {
|
||||
lastBound, err := strconv.ParseFloat(lastBoundStr, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse input last bound, %w", err)
|
||||
}
|
||||
lastBoundFloat32 := float32(lastBound)
|
||||
planIteratorV2Info.LastBound = &lastBoundFloat32 // escape pointer
|
||||
}
|
||||
|
||||
return planIteratorV2Info, nil
|
||||
}
|
||||
|
||||
// parseSearchInfo returns QueryInfo and offset
|
||||
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) *SearchInfo {
|
||||
var topK int64
|
||||
|
@ -196,16 +272,22 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
"Not allowed to do range-search when doing search-group-by")}
|
||||
}
|
||||
|
||||
planSearchIteratorV2Info, err := parseSearchIteratorV2Info(searchParamsPair, groupByFieldId, isIterator, offset, &queryTopK)
|
||||
if err != nil {
|
||||
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("parse iterator v2 info failed: %w", err)}
|
||||
}
|
||||
|
||||
return &SearchInfo{
|
||||
planInfo: &planpb.QueryInfo{
|
||||
Topk: queryTopK,
|
||||
MetricType: metricType,
|
||||
SearchParams: searchParamStr,
|
||||
RoundDecimal: roundDecimal,
|
||||
GroupByFieldId: groupByFieldId,
|
||||
GroupSize: groupSize,
|
||||
StrictGroupSize: strictGroupSize,
|
||||
Hints: hints,
|
||||
Topk: queryTopK,
|
||||
MetricType: metricType,
|
||||
SearchParams: searchParamStr,
|
||||
RoundDecimal: roundDecimal,
|
||||
GroupByFieldId: groupByFieldId,
|
||||
GroupSize: groupSize,
|
||||
StrictGroupSize: strictGroupSize,
|
||||
Hints: hints,
|
||||
SearchIteratorV2Info: planSearchIteratorV2Info,
|
||||
},
|
||||
offset: offset,
|
||||
isIterator: isIterator,
|
||||
|
|
|
@ -69,6 +69,11 @@ const (
|
|||
OffsetKey = "offset"
|
||||
LimitKey = "limit"
|
||||
|
||||
SearchIterV2Key = "search_iter_v2"
|
||||
SearchIterBatchSizeKey = "search_iter_batch_size"
|
||||
SearchIterLastBoundKey = "search_iter_last_bound"
|
||||
SearchIterIdKey = "search_iter_id"
|
||||
|
||||
InsertTaskName = "InsertTask"
|
||||
CreateCollectionTaskName = "CreateCollectionTask"
|
||||
DropCollectionTaskName = "DropCollectionTask"
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
||||
|
@ -590,12 +591,15 @@ func (t *searchTask) Execute(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, queryInfo *planpb.QueryInfo, isAdvance bool) (*milvuspb.SearchResults, error) {
|
||||
func getMetricType(toReduceResults []*internalpb.SearchResults) string {
|
||||
metricType := ""
|
||||
if len(toReduceResults) >= 1 {
|
||||
metricType = toReduceResults[0].GetMetricType()
|
||||
}
|
||||
return metricType
|
||||
}
|
||||
|
||||
func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, metricType string, queryInfo *planpb.QueryInfo, isAdvance bool) (*milvuspb.SearchResults, error) {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "reduceResults")
|
||||
defer sp.End()
|
||||
|
||||
|
@ -631,6 +635,24 @@ func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*inter
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// find the last bound based on reduced results and metric type
|
||||
// only support nq == 1, for search iterator v2
|
||||
func getLastBound(result *milvuspb.SearchResults, incomingLastBound *float32, metricType string) float32 {
|
||||
len := len(result.Results.Scores)
|
||||
if len > 0 && result.GetResults().GetNumQueries() == 1 {
|
||||
return result.Results.Scores[len-1]
|
||||
}
|
||||
// if no results found and incoming last bound is not nil, return it
|
||||
if incomingLastBound != nil {
|
||||
return *incomingLastBound
|
||||
}
|
||||
// if no results found and it is the first call, return the closest bound
|
||||
if metric.PositivelyRelated(metricType) {
|
||||
return math.MaxFloat32
|
||||
}
|
||||
return -math.MaxFloat32
|
||||
}
|
||||
|
||||
func (t *searchTask) PostExecute(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-PostExecute")
|
||||
defer sp.End()
|
||||
|
@ -670,6 +692,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
metricType := getMetricType(toReduceResults)
|
||||
// reduce
|
||||
if t.SearchRequest.GetIsAdvanced() {
|
||||
multipleInternalResults := make([][]*internalpb.SearchResults, len(t.SearchRequest.GetSubReqs()))
|
||||
|
@ -696,16 +719,12 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||
multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs()))
|
||||
for index, internalResults := range multipleInternalResults {
|
||||
subReq := t.SearchRequest.GetSubReqs()[index]
|
||||
|
||||
metricType := ""
|
||||
if len(internalResults) >= 1 {
|
||||
metricType = internalResults[0].GetMetricType()
|
||||
}
|
||||
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), t.queryInfos[index], true)
|
||||
subMetricType := getMetricType(internalResults)
|
||||
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), subMetricType, t.queryInfos[index], true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.reScorers[index].setMetricType(metricType)
|
||||
t.reScorers[index].setMetricType(subMetricType)
|
||||
t.reScorers[index].reScore(result)
|
||||
multipleMilvusResults[index] = result
|
||||
}
|
||||
|
@ -721,7 +740,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
} else {
|
||||
t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), t.queryInfos[0], false)
|
||||
t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), metricType, t.queryInfos[0], false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -751,6 +770,14 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||
}
|
||||
t.result.Results.OutputFields = t.userOutputFields
|
||||
t.result.CollectionName = t.request.GetCollectionName()
|
||||
if t.isIterator && len(t.queryInfos) == 1 && t.queryInfos[0] != nil {
|
||||
if iterInfo := t.queryInfos[0].GetSearchIteratorV2Info(); iterInfo != nil {
|
||||
t.result.Results.SearchIteratorV2Results = &schemapb.SearchIteratorV2Results{
|
||||
Token: iterInfo.GetToken(),
|
||||
LastBound: getLastBound(t.result, iterInfo.LastBound, metricType),
|
||||
}
|
||||
}
|
||||
}
|
||||
if t.isIterator && t.request.GetGuaranteeTimestamp() == 0 {
|
||||
// first page for iteration, need to set up sessionTs for iterator
|
||||
t.result.SessionTs = getMaxMvccTsFromChannels(t.queryChannelsTs, t.BeginTs())
|
||||
|
|
|
@ -18,12 +18,14 @@ package proxy
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -103,9 +105,124 @@ func TestSearchTask_PostExecute(t *testing.T) {
|
|||
assert.Equal(t, qt.resultSizeInsufficient, true)
|
||||
assert.Equal(t, qt.isTopkReduce, false)
|
||||
})
|
||||
|
||||
t.Run("test search iterator v2", func(t *testing.T) {
|
||||
const (
|
||||
kRows = 10
|
||||
kToken = "test-token"
|
||||
)
|
||||
|
||||
collName := "test_collection_search_iterator_v2" + funcutil.GenRandomStr()
|
||||
collSchema := createColl(t, collName, rc)
|
||||
|
||||
createIteratorSearchTask := func(t *testing.T, metricType string, rows int) *searchTask {
|
||||
ids := make([]int64, rows)
|
||||
for i := range ids {
|
||||
ids[i] = int64(i)
|
||||
}
|
||||
resultIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: ids,
|
||||
},
|
||||
},
|
||||
}
|
||||
scores := make([]float32, rows)
|
||||
// proxy needs to reverse the score for negatively related metrics
|
||||
for i := range scores {
|
||||
if metric.PositivelyRelated(metricType) {
|
||||
scores[i] = float32(len(scores) - i)
|
||||
} else {
|
||||
scores[i] = -float32(i + 1)
|
||||
}
|
||||
}
|
||||
resultData := &schemapb.SearchResultData{
|
||||
Ids: resultIDs,
|
||||
Scores: scores,
|
||||
NumQueries: 1,
|
||||
}
|
||||
|
||||
qt := &searchTask{
|
||||
ctx: ctx,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
Nq: 1,
|
||||
},
|
||||
schema: newSchemaInfo(collSchema),
|
||||
request: &milvuspb.SearchRequest{
|
||||
CollectionName: collName,
|
||||
},
|
||||
queryInfos: []*planpb.QueryInfo{{
|
||||
SearchIteratorV2Info: &planpb.SearchIteratorV2Info{
|
||||
Token: kToken,
|
||||
BatchSize: 1,
|
||||
},
|
||||
}},
|
||||
result: &milvuspb.SearchResults{
|
||||
Results: resultData,
|
||||
},
|
||||
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
isIterator: true,
|
||||
}
|
||||
bytes, err := proto.Marshal(resultData)
|
||||
assert.NoError(t, err)
|
||||
qt.resultBuf.Insert(&internalpb.SearchResults{
|
||||
MetricType: metricType,
|
||||
SlicedBlob: bytes,
|
||||
})
|
||||
return qt
|
||||
}
|
||||
|
||||
t.Run("test search iterator v2", func(t *testing.T) {
|
||||
metrics := []string{metric.L2, metric.IP, metric.COSINE, metric.BM25}
|
||||
for _, metricType := range metrics {
|
||||
qt := createIteratorSearchTask(t, metricType, kRows)
|
||||
err = qt.PostExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, kToken, qt.result.Results.SearchIteratorV2Results.Token)
|
||||
if metric.PositivelyRelated(metricType) {
|
||||
assert.Equal(t, float32(1), qt.result.Results.SearchIteratorV2Results.LastBound)
|
||||
} else {
|
||||
assert.Equal(t, float32(kRows), qt.result.Results.SearchIteratorV2Results.LastBound)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test search iterator v2 with empty result", func(t *testing.T) {
|
||||
metrics := []string{metric.L2, metric.IP, metric.COSINE, metric.BM25}
|
||||
for _, metricType := range metrics {
|
||||
qt := createIteratorSearchTask(t, metricType, 0)
|
||||
err = qt.PostExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, kToken, qt.result.Results.SearchIteratorV2Results.Token)
|
||||
if metric.PositivelyRelated(metricType) {
|
||||
assert.Equal(t, float32(math.MaxFloat32), qt.result.Results.SearchIteratorV2Results.LastBound)
|
||||
} else {
|
||||
assert.Equal(t, float32(-math.MaxFloat32), qt.result.Results.SearchIteratorV2Results.LastBound)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test search iterator v2 with empty result and incoming last bound", func(t *testing.T) {
|
||||
metrics := []string{metric.L2, metric.IP, metric.COSINE, metric.BM25}
|
||||
kLastBound := float32(10)
|
||||
for _, metricType := range metrics {
|
||||
qt := createIteratorSearchTask(t, metricType, 0)
|
||||
qt.queryInfos[0].SearchIteratorV2Info.LastBound = &kLastBound
|
||||
err = qt.PostExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, kToken, qt.result.Results.SearchIteratorV2Results.Token)
|
||||
assert.Equal(t, kLastBound, qt.result.Results.SearchIteratorV2Results.LastBound)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func createColl(t *testing.T, name string, rc types.RootCoordClient) {
|
||||
func createColl(t *testing.T, name string, rc types.RootCoordClient) *schemapb.CollectionSchema {
|
||||
schema := constructCollectionSchema(testInt64Field, testFloatVecField, testVecDim, name)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
require.NoError(t, err)
|
||||
|
@ -126,6 +243,8 @@ func createColl(t *testing.T, name string, rc types.RootCoordClient) {
|
|||
require.NoError(t, createColT.PreExecute(ctx))
|
||||
require.NoError(t, createColT.Execute(ctx))
|
||||
require.NoError(t, createColT.PostExecute(ctx))
|
||||
|
||||
return schema
|
||||
}
|
||||
|
||||
func getBaseSearchParams() []*commonpb.KeyValuePair {
|
||||
|
@ -2599,6 +2718,157 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||
assert.True(t, strings.Contains(searchInfo.parseError.Error(), "failed to parse input group size"))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check search iterator v2", func(t *testing.T) {
|
||||
kBatchSize := uint32(10)
|
||||
generateValidParamsForSearchIteratorV2 := func() []*commonpb.KeyValuePair {
|
||||
param := getValidSearchParams()
|
||||
return append(param,
|
||||
&commonpb.KeyValuePair{
|
||||
Key: SearchIterV2Key,
|
||||
Value: "True",
|
||||
},
|
||||
&commonpb.KeyValuePair{
|
||||
Key: IteratorField,
|
||||
Value: "True",
|
||||
},
|
||||
&commonpb.KeyValuePair{
|
||||
Key: SearchIterBatchSizeKey,
|
||||
Value: fmt.Sprintf("%d", kBatchSize),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
t.Run("iteratorV2 normal", func(t *testing.T) {
|
||||
param := generateValidParamsForSearchIteratorV2()
|
||||
searchInfo := parseSearchInfo(param, nil, nil)
|
||||
assert.NoError(t, searchInfo.parseError)
|
||||
assert.NotNil(t, searchInfo.planInfo)
|
||||
assert.NotEmpty(t, searchInfo.planInfo.SearchIteratorV2Info.Token)
|
||||
assert.Equal(t, kBatchSize, searchInfo.planInfo.SearchIteratorV2Info.BatchSize)
|
||||
assert.Len(t, searchInfo.planInfo.SearchIteratorV2Info.Token, 36)
|
||||
assert.Equal(t, int64(kBatchSize), searchInfo.planInfo.GetTopk()) // compatibility
|
||||
})
|
||||
|
||||
t.Run("iteratorV2 without isIterator", func(t *testing.T) {
|
||||
param := generateValidParamsForSearchIteratorV2()
|
||||
resetSearchParamsValue(param, IteratorField, "False")
|
||||
searchInfo := parseSearchInfo(param, nil, nil)
|
||||
assert.Error(t, searchInfo.parseError)
|
||||
assert.ErrorContains(t, searchInfo.parseError, "both")
|
||||
})
|
||||
|
||||
t.Run("iteratorV2 with groupBy", func(t *testing.T) {
|
||||
param := generateValidParamsForSearchIteratorV2()
|
||||
param = append(param, &commonpb.KeyValuePair{
|
||||
Key: GroupByFieldKey,
|
||||
Value: "string_field",
|
||||
})
|
||||
fields := make([]*schemapb.FieldSchema, 0)
|
||||
fields = append(fields, &schemapb.FieldSchema{
|
||||
FieldID: int64(101),
|
||||
Name: "string_field",
|
||||
})
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: fields,
|
||||
}
|
||||
searchInfo := parseSearchInfo(param, schema, nil)
|
||||
assert.Error(t, searchInfo.parseError)
|
||||
assert.ErrorContains(t, searchInfo.parseError, "roupBy")
|
||||
})
|
||||
|
||||
t.Run("iteratorV2 with offset", func(t *testing.T) {
|
||||
param := generateValidParamsForSearchIteratorV2()
|
||||
param = append(param, &commonpb.KeyValuePair{
|
||||
Key: OffsetKey,
|
||||
Value: "10",
|
||||
})
|
||||
searchInfo := parseSearchInfo(param, nil, nil)
|
||||
assert.Error(t, searchInfo.parseError)
|
||||
assert.ErrorContains(t, searchInfo.parseError, "offset")
|
||||
})
|
||||
|
||||
t.Run("iteratorV2 invalid token", func(t *testing.T) {
|
||||
param := generateValidParamsForSearchIteratorV2()
|
||||
param = append(param, &commonpb.KeyValuePair{
|
||||
Key: SearchIterIdKey,
|
||||
Value: "invalid_token",
|
||||
})
|
||||
searchInfo := parseSearchInfo(param, nil, nil)
|
||||
assert.Error(t, searchInfo.parseError)
|
||||
assert.ErrorContains(t, searchInfo.parseError, "invalid token format")
|
||||
})
|
||||
|
||||
t.Run("iteratorV2 passed token must be same", func(t *testing.T) {
|
||||
token, err := uuid.NewRandom()
|
||||
assert.NoError(t, err)
|
||||
param := generateValidParamsForSearchIteratorV2()
|
||||
param = append(param, &commonpb.KeyValuePair{
|
||||
Key: SearchIterIdKey,
|
||||
Value: token.String(),
|
||||
})
|
||||
searchInfo := parseSearchInfo(param, nil, nil)
|
||||
assert.NoError(t, searchInfo.parseError)
|
||||
assert.NotEmpty(t, searchInfo.planInfo.SearchIteratorV2Info.Token)
|
||||
assert.Equal(t, token.String(), searchInfo.planInfo.SearchIteratorV2Info.Token)
|
||||
})
|
||||
|
||||
t.Run("iteratorV2 batch size", func(t *testing.T) {
|
||||
param := generateValidParamsForSearchIteratorV2()
|
||||
resetSearchParamsValue(param, SearchIterBatchSizeKey, "1.123")
|
||||
searchInfo := parseSearchInfo(param, nil, nil)
|
||||
assert.Error(t, searchInfo.parseError)
|
||||
assert.ErrorContains(t, searchInfo.parseError, "batch size is invalid")
|
||||
})
|
||||
|
||||
t.Run("iteratorV2 batch size", func(t *testing.T) {
|
||||
param := generateValidParamsForSearchIteratorV2()
|
||||
resetSearchParamsValue(param, SearchIterBatchSizeKey, "")
|
||||
searchInfo := parseSearchInfo(param, nil, nil)
|
||||
assert.Error(t, searchInfo.parseError)
|
||||
assert.ErrorContains(t, searchInfo.parseError, "batch size is required")
|
||||
})
|
||||
|
||||
t.Run("iteratorV2 batch size negative", func(t *testing.T) {
|
||||
param := generateValidParamsForSearchIteratorV2()
|
||||
resetSearchParamsValue(param, SearchIterBatchSizeKey, "-1")
|
||||
searchInfo := parseSearchInfo(param, nil, nil)
|
||||
assert.Error(t, searchInfo.parseError)
|
||||
assert.ErrorContains(t, searchInfo.parseError, "batch size is invalid")
|
||||
})
|
||||
|
||||
t.Run("iteratorV2 batch size too large", func(t *testing.T) {
|
||||
param := generateValidParamsForSearchIteratorV2()
|
||||
resetSearchParamsValue(param, SearchIterBatchSizeKey, fmt.Sprintf("%d", Params.QuotaConfig.TopKLimit.GetAsInt64()+1))
|
||||
searchInfo := parseSearchInfo(param, nil, nil)
|
||||
assert.Error(t, searchInfo.parseError)
|
||||
assert.ErrorContains(t, searchInfo.parseError, "batch size is invalid")
|
||||
})
|
||||
|
||||
t.Run("iteratorV2 last bound", func(t *testing.T) {
|
||||
kLastBound := float32(1.123)
|
||||
param := generateValidParamsForSearchIteratorV2()
|
||||
param = append(param, &commonpb.KeyValuePair{
|
||||
Key: SearchIterLastBoundKey,
|
||||
Value: fmt.Sprintf("%f", kLastBound),
|
||||
})
|
||||
searchInfo := parseSearchInfo(param, nil, nil)
|
||||
assert.NoError(t, searchInfo.parseError)
|
||||
assert.NotNil(t, searchInfo.planInfo)
|
||||
assert.Equal(t, kLastBound, *searchInfo.planInfo.SearchIteratorV2Info.LastBound)
|
||||
})
|
||||
|
||||
t.Run("iteratorV2 invalid last bound", func(t *testing.T) {
|
||||
param := generateValidParamsForSearchIteratorV2()
|
||||
param = append(param, &commonpb.KeyValuePair{
|
||||
Key: SearchIterLastBoundKey,
|
||||
Value: "xxx",
|
||||
})
|
||||
searchInfo := parseSearchInfo(param, nil, nil)
|
||||
assert.Error(t, searchInfo.parseError)
|
||||
assert.ErrorContains(t, searchInfo.parseError, "failed to parse input last bound")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
|
||||
|
|
Loading…
Reference in New Issue