enhance: [2.5][cp] speed up search iterator stage 1 (#38678)

pr: https://github.com/milvus-io/milvus/pull/37947
issue: https://github.com/milvus-io/milvus/issues/37548

Signed-off-by: Patrick Weizhi Xu <weizhi.xu@zilliz.com>
(cherry picked from commit 9016c4adcd765c0766b01e7e5d465c915e176a6f)
Patrick Weizhi Xu 2024-12-27 18:48:52 +08:00 committed by GitHub
parent d48a33d76b
commit ef400227ad
23 changed files with 1970 additions and 117 deletions

go.mod

@@ -64,6 +64,7 @@ require (
github.com/cenkalti/backoff/v4 v4.2.1
github.com/cockroachdb/redact v1.1.3
github.com/goccy/go-json v0.10.3
github.com/google/uuid v1.6.0
github.com/greatroar/blobloom v0.0.0-00010101000000-000000000000
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/jolestar/go-commons-pool/v2 v2.1.2
@@ -144,7 +145,6 @@ require (
github.com/golang/snappy v0.0.4 // indirect
github.com/google/flatbuffers v2.0.8+incompatible // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.5 // indirect
github.com/gorilla/websocket v1.4.2 // indirect


@@ -24,6 +24,12 @@
namespace milvus {
struct SearchIteratorV2Info {
std::string token = "";
uint32_t batch_size = 0;
std::optional<float> last_bound = std::nullopt;
};
struct SearchInfo {
int64_t topk_{0};
int64_t group_size_{1};
@@ -36,6 +42,7 @@ struct SearchInfo {
tracer::TraceContext trace_ctx_;
bool materialized_view_involved = false;
bool iterative_filter_execution = false;
std::optional<SearchIteratorV2Info> iterator_v2_info_ = std::nullopt;
};
using SearchInfoPtr = std::shared_ptr<SearchInfo>;
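For orientation, a hedged sketch (not part of the patch) of how these fields are meant to be filled across consecutive iterator requests: the token identifies one iteration session and is echoed back unchanged after the first call, while last_bound carries the filtering cursor forward between batches.

#include <optional>
#include <string>

#include "common/QueryInfo.h"

// Hypothetical helper: builds the SearchInfo for one iterator-v2 batch.
// topk_ mirrors batch_size purely for compatibility with existing plumbing.
milvus::SearchInfo
MakeIteratorBatch(const std::string& token, uint32_t batch_size,
                  std::optional<float> last_bound) {
    milvus::SearchInfo info;
    info.topk_ = batch_size;
    info.metric_type_ = "L2";  // any supported metric
    info.iterator_v2_info_ = milvus::SearchIteratorV2Info{
        .token = token,
        .batch_size = batch_size,
    };
    info.iterator_v2_info_->last_bound = last_bound;
    return info;
}

// First request: no last_bound yet. Follow-up requests pass the bound taken
// from the previous batch so already-returned results are filtered out:
//   auto first = MakeIteratorBatch(token, 100, std::nullopt);
//   auto next  = MakeIteratorBatch(token, 100, /*last_bound=*/0.42f);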


@@ -362,4 +362,38 @@ ReadDataFromFD(int fd, void* buf, size_t size, size_t chunk_size) {
}
}
bool
CheckAndUpdateKnowhereRangeSearchParam(const SearchInfo& search_info,
const int64_t topk,
const MetricType& metric_type,
knowhere::Json& search_config) {
const auto radius =
index::GetValueFromConfig<float>(search_info.search_params_, RADIUS);
if (!radius.has_value()) {
return false;
}
search_config[RADIUS] = radius.value();
// `range_search_k` is only used as one of the conditions for iterator early termination;
// it is not guaranteed to return exactly `range_search_k` results, which may be more or fewer.
// Setting it to -1 returns all results within the range.
search_config[knowhere::meta::RANGE_SEARCH_K] = topk;
const auto range_filter =
GetValueFromConfig<float>(search_info.search_params_, RANGE_FILTER);
if (range_filter.has_value()) {
search_config[RANGE_FILTER] = range_filter.value();
CheckRangeSearchParam(
search_config[RADIUS], search_config[RANGE_FILTER], metric_type);
}
const auto page_retain_order =
GetValueFromConfig<bool>(search_info.search_params_, PAGE_RETAIN_ORDER);
if (page_retain_order.has_value()) {
search_config[knowhere::meta::RETAIN_ITERATOR_ORDER] =
page_retain_order.value();
}
return true;
}
} // namespace milvus::index
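To make the new helper's contract concrete, here is a standalone sketch of its effect (a simplified stand-in, assuming the usual key spellings "radius"/"range_filter"/"range_search_k" behind the RADIUS, RANGE_FILTER, and knowhere::meta::RANGE_SEARCH_K constants, and that knowhere::Json is an nlohmann-style JSON object; the PAGE_RETAIN_ORDER knob is omitted for brevity):

#include <cstdint>
#include <iostream>

#include <nlohmann/json.hpp>

// Copies the range-search knobs from the user's search params into the
// knowhere config. Returns false (config untouched) when no "radius" is
// given, in which case the caller falls back to a plain top-k search.
bool
UpdateRangeSearchConfig(const nlohmann::json& params,
                        int64_t topk,
                        nlohmann::json& cfg) {
    if (!params.contains("radius")) {
        return false;
    }
    cfg["radius"] = params["radius"];
    // early-termination hint only; the search may yield more or fewer hits
    cfg["range_search_k"] = topk;
    if (params.contains("range_filter")) {
        // the real helper also validates radius vs. range_filter per metric
        cfg["range_filter"] = params["range_filter"];
    }
    return true;
}

int main() {
    nlohmann::json params = {{"radius", 10.0}, {"range_filter", 2.0}};
    nlohmann::json cfg;
    if (UpdateRangeSearchConfig(params, /*topk=*/100, cfg)) {
        std::cout << cfg.dump() << '\n';
        // {"radius":10.0,"range_filter":2.0,"range_search_k":100}
    }
}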


@@ -30,6 +30,8 @@
#include "common/Types.h"
#include "common/FieldData.h"
#include "common/QueryInfo.h"
#include "common/RangeSearchHelper.h"
#include "index/IndexInfo.h"
#include "storage/Types.h"
@@ -147,4 +149,10 @@ AssembleIndexDatas(std::map<std::string, FieldDataChannelPtr>& index_datas,
void
ReadDataFromFD(int fd, void* buf, size_t size, size_t chunk_size = 0x7ffff000);
bool
CheckAndUpdateKnowhereRangeSearchParam(const SearchInfo& search_info,
const int64_t topk,
const MetricType& metric_type,
knowhere::Json& search_config);
} // namespace milvus::index


@@ -266,32 +266,9 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
search_config[DISK_ANN_PREFIX_PATH] = local_index_path_prefix;
auto final = [&] {
auto radius =
GetValueFromConfig<float>(search_info.search_params_, RADIUS);
if (radius.has_value()) {
search_config[RADIUS] = radius.value();
// `range_search_k` is only used as one of the conditions for iterator early termination.
// not gurantee to return exactly `range_search_k` results, which may be more or less.
// set it to -1 will return all results in the range.
search_config[knowhere::meta::RANGE_SEARCH_K] = topk;
auto range_filter = GetValueFromConfig<float>(
search_info.search_params_, RANGE_FILTER);
if (range_filter.has_value()) {
search_config[RANGE_FILTER] = range_filter.value();
CheckRangeSearchParam(search_config[RADIUS],
search_config[RANGE_FILTER],
GetMetricType());
}
auto page_retain_order = GetValueFromConfig<bool>(
search_info.search_params_, PAGE_RETAIN_ORDER);
if (page_retain_order.has_value()) {
search_config[knowhere::meta::RETAIN_ITERATOR_ORDER] =
page_retain_order.value();
}
if (CheckAndUpdateKnowhereRangeSearchParam(
search_info, topk, GetMetricType(), search_config)) {
auto res = index_.RangeSearch(dataset, search_config, bitset);
if (!res.has_value()) {
PanicInfo(ErrorCode::UnexpectedError,
fmt::format("failed to range search: {}: {}",


@@ -380,16 +380,8 @@ VectorMemIndex<T>::Query(const DatasetPtr dataset,
// TODO :: check dim of search data
auto final = [&] {
auto index_type = GetIndexType();
if (CheckKeyInConfig(search_conf, RADIUS)) {
if (CheckKeyInConfig(search_conf, RANGE_FILTER)) {
CheckRangeSearchParam(search_conf[RADIUS],
search_conf[RANGE_FILTER],
GetMetricType());
}
// `range_search_k` is only used as one of the conditions for iterator early termination.
// not gurantee to return exactly `range_search_k` results, which may be more or less.
// set it to -1 will return all results in the range.
search_conf[knowhere::meta::RANGE_SEARCH_K] = topk;
if (CheckAndUpdateKnowhereRangeSearchParam(
search_info, topk, GetMetricType(), search_conf)) {
milvus::tracer::AddEvent("start_knowhere_index_range_search");
auto res = index_.RangeSearch(dataset, search_conf, bitset);
milvus::tracer::AddEvent("finish_knowhere_index_range_search");


@@ -0,0 +1,362 @@
// Copyright (C) 2019-2024 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "query/CachedSearchIterator.h"
#include "query/SearchBruteForce.h"
#include <algorithm>
namespace milvus::query {
CachedSearchIterator::CachedSearchIterator(
const milvus::index::VectorIndex& index,
const knowhere::DataSetPtr& query_ds,
const SearchInfo& search_info,
const BitsetView& bitset) {
if (query_ds == nullptr) {
PanicInfo(ErrorCode::UnexpectedError,
"Query dataset is nullptr, cannot initialize iterator");
}
nq_ = query_ds->GetRows();
Init(search_info);
auto search_json = index.PrepareSearchParams(search_info);
index::CheckAndUpdateKnowhereRangeSearchParam(
search_info, batch_size_, index.GetMetricType(), search_json);
auto expected_iterators =
index.VectorIterators(query_ds, search_json, bitset);
if (expected_iterators.has_value()) {
iterators_ = std::move(expected_iterators.value());
} else {
PanicInfo(ErrorCode::UnexpectedError,
"Failed to create iterators from index");
}
}
CachedSearchIterator::CachedSearchIterator(
const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
const milvus::DataType& data_type) {
nq_ = query_ds.num_queries;
Init(search_info);
auto expected_iterators = GetBruteForceSearchIterators(
query_ds, raw_ds, search_info, index_info, bitset, data_type);
if (expected_iterators.has_value()) {
iterators_ = std::move(expected_iterators.value());
} else {
PanicInfo(ErrorCode::UnexpectedError,
"Failed to create iterators from index");
}
}
void
CachedSearchIterator::InitializeChunkedIterators(
const dataset::SearchDataset& query_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
const milvus::DataType& data_type,
const GetChunkDataFunc& get_chunk_data) {
int64_t offset = 0;
chunked_heaps_.resize(nq_);
for (int64_t chunk_id = 0; chunk_id < num_chunks_; ++chunk_id) {
auto [chunk_data, chunk_size] = get_chunk_data(chunk_id);
auto sub_data = query::dataset::RawDataset{
offset, query_ds.dim, chunk_size, chunk_data};
auto expected_iterators = GetBruteForceSearchIterators(
query_ds, sub_data, search_info, index_info, bitset, data_type);
if (expected_iterators.has_value()) {
auto& chunk_iterators = expected_iterators.value();
iterators_.insert(iterators_.end(),
std::make_move_iterator(chunk_iterators.begin()),
std::make_move_iterator(chunk_iterators.end()));
} else {
PanicInfo(ErrorCode::UnexpectedError,
"Failed to create iterators from index");
}
offset += chunk_size;
}
}
CachedSearchIterator::CachedSearchIterator(
const dataset::SearchDataset& query_ds,
const segcore::VectorBase* vec_data,
const int64_t row_count,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
const milvus::DataType& data_type) {
if (vec_data == nullptr) {
PanicInfo(ErrorCode::UnexpectedError,
"Vector data is nullptr, cannot initialize iterator");
}
if (row_count <= 0) {
PanicInfo(ErrorCode::UnexpectedError,
"Number of rows is 0, cannot initialize iterator");
}
const int64_t vec_size_per_chunk = vec_data->get_size_per_chunk();
num_chunks_ = upper_div(row_count, vec_size_per_chunk);
nq_ = query_ds.num_queries;
Init(search_info);
iterators_.reserve(nq_ * num_chunks_);
InitializeChunkedIterators(
query_ds,
search_info,
index_info,
bitset,
data_type,
[&vec_data, vec_size_per_chunk, row_count](
int64_t chunk_id) -> std::pair<const void*, int64_t> {
const auto chunk_data = vec_data->get_chunk_data(chunk_id);
int64_t chunk_size = std::min(
vec_size_per_chunk, row_count - chunk_id * vec_size_per_chunk);
return {chunk_data, chunk_size};
});
}
CachedSearchIterator::CachedSearchIterator(
const std::shared_ptr<ChunkedColumnBase>& column,
const dataset::SearchDataset& query_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
const milvus::DataType& data_type) {
if (column == nullptr) {
PanicInfo(ErrorCode::UnexpectedError,
"Column is nullptr, cannot initialize iterator");
}
num_chunks_ = column->num_chunks();
nq_ = query_ds.num_queries;
Init(search_info);
iterators_.reserve(nq_ * num_chunks_);
InitializeChunkedIterators(
query_ds,
search_info,
index_info,
bitset,
data_type,
[&column](int64_t chunk_id) {
const char* chunk_data = column->Data(chunk_id);
int64_t chunk_size = column->chunk_row_nums(chunk_id);
return std::make_pair(static_cast<const void*>(chunk_data),
chunk_size);
});
}
void
CachedSearchIterator::NextBatch(const SearchInfo& search_info,
SearchResult& search_result) {
if (iterators_.empty()) {
return;
}
if (iterators_.size() != nq_ * num_chunks_) {
PanicInfo(ErrorCode::UnexpectedError,
"Iterator size mismatch, expect %d, but got %d",
nq_ * num_chunks_,
iterators_.size());
}
ValidateSearchInfo(search_info);
search_result.total_nq_ = nq_;
search_result.unity_topK_ = batch_size_;
search_result.seg_offsets_.resize(nq_ * batch_size_);
search_result.distances_.resize(nq_ * batch_size_);
for (size_t query_idx = 0; query_idx < nq_; ++query_idx) {
auto rst = GetBatchedNextResults(query_idx, search_info);
WriteSingleQuerySearchResult(
search_result, query_idx, rst, search_info.round_decimal_);
}
}
void
CachedSearchIterator::ValidateSearchInfo(const SearchInfo& search_info) {
if (!search_info.iterator_v2_info_.has_value()) {
PanicInfo(ErrorCode::UnexpectedError,
"Iterator v2 SearchInfo is not set");
}
auto iterator_v2_info = search_info.iterator_v2_info_.value();
if (iterator_v2_info.batch_size != batch_size_) {
PanicInfo(ErrorCode::UnexpectedError,
"Batch size mismatch, expect %d, but got %d",
batch_size_,
iterator_v2_info.batch_size);
}
}
std::optional<CachedSearchIterator::DisIdPair>
CachedSearchIterator::GetNextValidResult(
const size_t iterator_idx,
const std::optional<float>& last_bound,
const std::optional<float>& radius,
const std::optional<float>& range_filter) {
auto& iterator = iterators_[iterator_idx];
while (iterator->HasNext()) {
auto result = ConvertIteratorResult(iterator->Next());
if (IsValid(result, last_bound, radius, range_filter)) {
return result;
}
}
return std::nullopt;
}
// TODO: Optimize this method
void
CachedSearchIterator::MergeChunksResults(
size_t query_idx,
const std::optional<float>& last_bound,
const std::optional<float>& radius,
const std::optional<float>& range_filter,
std::vector<DisIdPair>& rst) {
auto& heap = chunked_heaps_[query_idx];
if (heap.empty()) {
for (size_t chunk_id = 0; chunk_id < num_chunks_; ++chunk_id) {
const size_t iterator_idx = query_idx + chunk_id * nq_;
if (auto next_result = GetNextValidResult(
iterator_idx, last_bound, radius, range_filter);
next_result.has_value()) {
heap.emplace(iterator_idx, next_result.value());
}
}
}
while (!heap.empty() && rst.size() < batch_size_) {
const auto [iterator_idx, cur_rst] = heap.top();
heap.pop();
// last_bound may change between NextBatch calls; discard results that are no longer valid
if (!IsValid(cur_rst, last_bound, radius, range_filter)) {
continue;
}
rst.emplace_back(cur_rst);
if (auto next_result = GetNextValidResult(
iterator_idx, last_bound, radius, range_filter);
next_result.has_value()) {
heap.emplace(iterator_idx, next_result.value());
}
}
}
std::vector<CachedSearchIterator::DisIdPair>
CachedSearchIterator::GetBatchedNextResults(size_t query_idx,
const SearchInfo& search_info) {
auto last_bound = ConvertIncomingDistance(
search_info.iterator_v2_info_.value().last_bound);
auto radius = ConvertIncomingDistance(
index::GetValueFromConfig<float>(search_info.search_params_, RADIUS));
auto range_filter =
ConvertIncomingDistance(index::GetValueFromConfig<float>(
search_info.search_params_, RANGE_FILTER));
std::vector<DisIdPair> rst;
rst.reserve(batch_size_);
if (num_chunks_ == 1) {
auto& iterator = iterators_[query_idx];
while (iterator->HasNext() && rst.size() < batch_size_) {
auto result = ConvertIteratorResult(iterator->Next());
if (IsValid(result, last_bound, radius, range_filter)) {
rst.emplace_back(result);
}
}
} else {
MergeChunksResults(query_idx, last_bound, radius, range_filter, rst);
}
std::sort(rst.begin(), rst.end());
if (sign_ == -1) {
std::for_each(rst.begin(), rst.end(), [this](DisIdPair& x) {
x.first = x.first * sign_;
});
}
while (rst.size() < batch_size_) {
rst.emplace_back(1.0f / 0.0f, -1);
}
return rst;
}
void
CachedSearchIterator::WriteSingleQuerySearchResult(
SearchResult& search_result,
const size_t idx,
std::vector<DisIdPair>& rst,
const int64_t round_decimal) {
const float multiplier = pow(10.0, round_decimal);
std::transform(rst.begin(),
rst.end(),
search_result.distances_.begin() + idx * batch_size_,
[multiplier, round_decimal](DisIdPair& x) {
if (round_decimal != -1) {
x.first =
std::round(x.first * multiplier) / multiplier;
}
return x.first;
});
std::transform(rst.begin(),
rst.end(),
search_result.seg_offsets_.begin() + idx * batch_size_,
[](const DisIdPair& x) { return x.second; });
}
void
CachedSearchIterator::Init(const SearchInfo& search_info) {
if (!search_info.iterator_v2_info_.has_value()) {
PanicInfo(ErrorCode::UnexpectedError,
"Iterator v2 info is not set, cannot initialize iterator");
}
auto iterator_v2_info = search_info.iterator_v2_info_.value();
if (iterator_v2_info.batch_size == 0) {
PanicInfo(ErrorCode::UnexpectedError,
"Batch size is 0, cannot initialize iterator");
}
batch_size_ = iterator_v2_info.batch_size;
if (search_info.metric_type_.empty()) {
PanicInfo(ErrorCode::UnexpectedError,
"Metric type is empty, cannot initialize iterator");
}
if (PositivelyRelated(search_info.metric_type_)) {
sign_ = -1;
} else {
sign_ = 1;
}
if (nq_ == 0) {
PanicInfo(ErrorCode::UnexpectedError,
"Number of queries is 0, cannot initialize iterator");
}
// disable multi-query for now
if (nq_ > 1) {
PanicInfo(
ErrorCode::UnexpectedError,
"Number of queries is greater than 1, cannot initialize iterator");
}
}
} // namespace milvus::query
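For reference, MergeChunksResults above is a classic k-way merge: a min-heap keyed by the sign-normalized (distance, id) pair holds at most one pending candidate per chunk iterator and is refilled from whichever chunk produced the popped element. A minimal standalone sketch of the same technique, with plain sorted vectors standing in for knowhere iterators (the real code additionally re-validates entries against last_bound/radius/range_filter and keeps the heap alive across NextBatch calls):

#include <cstdint>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>

using DisId = std::pair<float, int64_t>;

// Each "chunk" yields (distance, id) pairs already sorted ascending,
// mirroring a per-chunk iterator after sign normalization.
std::vector<DisId>
MergeTopBatch(const std::vector<std::vector<DisId>>& chunks,
              size_t batch_size) {
    using Entry = std::pair<size_t, DisId>;  // (chunk index, candidate)
    auto cmp = [](const Entry& l, const Entry& r) {
        return l.second > r.second;  // min-heap on (distance, id)
    };
    std::priority_queue<Entry, std::vector<Entry>, decltype(cmp)> heap(cmp);
    std::vector<size_t> cursor(chunks.size(), 0);
    for (size_t c = 0; c < chunks.size(); ++c) {
        if (!chunks[c].empty()) {
            heap.emplace(c, chunks[c][cursor[c]++]);
        }
    }
    std::vector<DisId> out;
    while (!heap.empty() && out.size() < batch_size) {
        auto [c, rst] = heap.top();
        heap.pop();
        out.push_back(rst);  // globally smallest remaining distance
        if (cursor[c] < chunks[c].size()) {
            heap.emplace(c, chunks[c][cursor[c]++]);  // refill from same chunk
        }
    }
    return out;
}

int main() {
    // chunk 0: (0.1, 7), (0.9, 8); chunk 1: (0.4, 3)
    auto batch = MergeTopBatch({{{0.1f, 7}, {0.9f, 8}}, {{0.4f, 3}}}, 2);
    for (const auto& [d, id] : batch) {
        std::cout << d << " -> " << id << '\n';  // 0.1 -> 7, then 0.4 -> 3
    }
}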


@@ -0,0 +1,182 @@
// Copyright (C) 2019-2024 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <utility>
#include "common/BitsetView.h"
#include "common/QueryInfo.h"
#include "common/QueryResult.h"
#include "query/helper.h"
#include "segcore/ConcurrentVector.h"
#include "index/VectorIndex.h"
namespace milvus::query {
// This class caches the search results from Knowhere search iterators
// and filters them based on last_bound, radius, and range_filter.
// It provides a number of constructors to support different scenarios,
// including growing/sealed and chunked/non-chunked segments.
//
// It does not care about the TopK in search_info;
// the topk in SearchResult is set to batch_size for compatibility.
//
// TODO: introduce the pool of results in the near future
// TODO: replace VectorIterator class
class CachedSearchIterator {
public:
// For sealed segment with vector index
CachedSearchIterator(const milvus::index::VectorIndex& index,
const knowhere::DataSetPtr& dataset,
const SearchInfo& search_info,
const BitsetView& bitset);
// For sealed segment, BF
CachedSearchIterator(const dataset::SearchDataset& dataset,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
const milvus::DataType& data_type);
// For growing segment with chunked data, BF
CachedSearchIterator(const dataset::SearchDataset& dataset,
const segcore::VectorBase* vec_data,
const int64_t row_count,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
const milvus::DataType& data_type);
// For sealed segment with chunked data, BF
CachedSearchIterator(const std::shared_ptr<ChunkedColumnBase>& column,
const dataset::SearchDataset& dataset,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
const milvus::DataType& data_type);
// This method fetches the next batch of search results based on the provided search information
// and updates the search_result object with the new batch of results.
void
NextBatch(const SearchInfo& search_info, SearchResult& search_result);
// Disable copy and move
CachedSearchIterator(const CachedSearchIterator&) = delete;
CachedSearchIterator&
operator=(const CachedSearchIterator&) = delete;
CachedSearchIterator(CachedSearchIterator&&) = delete;
CachedSearchIterator&
operator=(CachedSearchIterator&&) = delete;
private:
using DisIdPair = std::pair<float, int64_t>;
using IterIdx = size_t;
using IterIdDisIdPair = std::pair<IterIdx, DisIdPair>;
using GetChunkDataFunc =
std::function<std::pair<const void*, int64_t>(int64_t)>;
int64_t batch_size_ = 0;
std::vector<knowhere::IndexNode::IteratorPtr> iterators_;
int8_t sign_ = 1;
size_t num_chunks_ = 1;
size_t nq_ = 0;
struct IterIdDisIdPairComparator {
bool
operator()(const IterIdDisIdPair& lhs, const IterIdDisIdPair& rhs) {
if (lhs.second.first == rhs.second.first) {
return lhs.second.second > rhs.second.second;
}
return lhs.second.first > rhs.second.first;
}
};
std::vector<std::priority_queue<IterIdDisIdPair,
std::vector<IterIdDisIdPair>,
IterIdDisIdPairComparator>>
chunked_heaps_;
inline bool
IsValid(const DisIdPair& result,
const std::optional<float>& last_bound,
const std::optional<float>& radius,
const std::optional<float>& range_filter) {
const float dist = result.first;
const bool is_valid =
!last_bound.has_value() || dist > last_bound.value();
if (!radius.has_value()) {
return is_valid;
}
if (!range_filter.has_value()) {
return is_valid && dist < radius.value();
}
return is_valid && dist < radius.value() &&
dist >= range_filter.value();
}
inline DisIdPair
ConvertIteratorResult(const std::pair<int64_t, float>& iter_rst) {
DisIdPair rst;
rst.first = iter_rst.second * sign_;
rst.second = iter_rst.first;
return rst;
}
inline std::optional<float>
ConvertIncomingDistance(std::optional<float> dist) {
if (dist.has_value()) {
dist = dist.value() * sign_;
}
return dist;
}
std::optional<DisIdPair>
GetNextValidResult(size_t iterator_idx,
const std::optional<float>& last_bound,
const std::optional<float>& radius,
const std::optional<float>& range_filter);
void
MergeChunksResults(size_t query_idx,
const std::optional<float>& last_bound,
const std::optional<float>& radius,
const std::optional<float>& range_filter,
std::vector<DisIdPair>& rst);
void
ValidateSearchInfo(const SearchInfo& search_info);
std::vector<DisIdPair>
GetBatchedNextResults(size_t query_idx, const SearchInfo& search_info);
void
WriteSingleQuerySearchResult(SearchResult& search_result,
const size_t idx,
std::vector<DisIdPair>& rst,
const int64_t round_decimal);
void
Init(const SearchInfo& search_info);
void
InitializeChunkedIterators(
const dataset::SearchDataset& dataset,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
const milvus::DataType& data_type,
const GetChunkDataFunc& get_chunk_data);
};
} // namespace milvus::query
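A usage sketch for the sealed-segment-with-index path (hedged: index, query_ds, info, and bitset are assumed to be prepared elsewhere, with info.iterator_v2_info_ populated as in the ProtoParser change below). It illustrates the NextBatch contract: each call fills nq * batch_size slots, padding the tail with seg_offset -1 and +inf distance, so -1 doubles as the end-of-results marker:

#include "query/CachedSearchIterator.h"

// Hypothetical drain loop over one query (nq == 1 is the only case allowed).
void
DrainIterator(const milvus::index::VectorIndex& index,
              const knowhere::DataSetPtr& query_ds,
              const milvus::SearchInfo& info,
              const milvus::BitsetView& bitset) {
    milvus::query::CachedSearchIterator iter(index, query_ds, info, bitset);
    const auto batch_size = info.iterator_v2_info_->batch_size;
    while (true) {
        milvus::SearchResult result;
        iter.NextBatch(info, result);
        size_t valid = 0;
        while (valid < batch_size && result.seg_offsets_[valid] != -1) {
            ++valid;  // a -1 offset marks the start of the padding
        }
        if (valid == 0) {
            break;  // iterator exhausted
        }
        // consume result.distances_[0..valid) and result.seg_offsets_[0..valid);
        // in the real flow the client echoes last_bound back for the next call
    }
}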


@@ -93,6 +93,20 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
search_info.strict_group_size_ =
query_info_proto.strict_group_size();
}
if (query_info_proto.has_search_iterator_v2_info()) {
auto& iterator_v2_info_proto =
query_info_proto.search_iterator_v2_info();
search_info.iterator_v2_info_ = SearchIteratorV2Info{
.token = iterator_v2_info_proto.token(),
.batch_size = iterator_v2_info_proto.batch_size(),
};
if (iterator_v2_info_proto.has_last_bound()) {
search_info.iterator_v2_info_->last_bound =
iterator_v2_info_proto.last_bound();
}
}
return search_info;
};


@@ -226,45 +226,66 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
return sub_result;
}
SubSearchResult
BruteForceSearchIterators(const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type) {
auto nq = query_ds.num_queries;
auto [query_dataset, base_dataset] =
PrepareBFDataSet(query_ds, raw_ds, data_type);
auto search_cfg = PrepareBFSearchParams(search_info, index_info);
knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
iterators_val;
knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
DispatchBruteForceIteratorByDataType(const knowhere::DataSetPtr& base_dataset,
const knowhere::DataSetPtr& query_dataset,
const knowhere::Json& config,
const BitsetView& bitset,
const milvus::DataType& data_type) {
switch (data_type) {
case DataType::VECTOR_FLOAT:
iterators_val = knowhere::BruteForce::AnnIterator<float>(
base_dataset, query_dataset, search_cfg, bitset);
return knowhere::BruteForce::AnnIterator<float>(
base_dataset, query_dataset, config, bitset);
break;
case DataType::VECTOR_FLOAT16:
//todo: if knowhere support real fp16/bf16 bf, change it
iterators_val = knowhere::BruteForce::AnnIterator<float>(
base_dataset, query_dataset, search_cfg, bitset);
return knowhere::BruteForce::AnnIterator<float>(
base_dataset, query_dataset, config, bitset);
break;
case DataType::VECTOR_BFLOAT16:
//todo: if knowhere support real fp16/bf16 bf, change it
iterators_val = knowhere::BruteForce::AnnIterator<float>(
base_dataset, query_dataset, search_cfg, bitset);
return knowhere::BruteForce::AnnIterator<float>(
base_dataset, query_dataset, config, bitset);
break;
case DataType::VECTOR_SPARSE_FLOAT:
iterators_val = knowhere::BruteForce::AnnIterator<
return knowhere::BruteForce::AnnIterator<
knowhere::sparse::SparseRow<float>>(
base_dataset, query_dataset, search_cfg, bitset);
base_dataset, query_dataset, config, bitset);
break;
default:
PanicInfo(ErrorCode::Unsupported,
"Unsupported dataType for chunk brute force iterator:{}",
data_type);
}
}
knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
GetBruteForceSearchIterators(
const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type) {
auto nq = query_ds.num_queries;
auto [query_dataset, base_dataset] =
PrepareBFDataSet(query_ds, raw_ds, data_type);
auto search_cfg = PrepareBFSearchParams(search_info, index_info);
return DispatchBruteForceIteratorByDataType(
base_dataset, query_dataset, search_cfg, bitset, data_type);
}
SubSearchResult
PackBruteForceSearchIteratorsIntoSubResult(
const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type) {
auto nq = query_ds.num_queries;
auto iterators_val = GetBruteForceSearchIterators(
query_ds, raw_ds, search_info, index_info, bitset, data_type);
if (iterators_val.has_value()) {
AssertInfo(
iterators_val.value().size() == nq,


@@ -31,12 +31,22 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
const BitsetView& bitset,
DataType data_type);
knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
GetBruteForceSearchIterators(
const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type);
SubSearchResult
BruteForceSearchIterators(const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type);
PackBruteForceSearchIteratorsIntoSubResult(
const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type);
} // namespace milvus::query


@@ -18,6 +18,7 @@
#include "knowhere/comp/index_param.h"
#include "knowhere/config.h"
#include "log/Log.h"
#include "query/CachedSearchIterator.h"
#include "query/SearchBruteForce.h"
#include "query/SearchOnIndex.h"
#include "exec/operator/Utils.h"
@@ -125,6 +126,19 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
// step 3: brute force search where small indexing is unavailable
auto vec_ptr = record.get_data_base(vecfield_id);
if (info.iterator_v2_info_.has_value()) {
CachedSearchIterator cached_iter(search_dataset,
vec_ptr,
active_count,
info,
index_info,
bitset,
data_type);
cached_iter.NextBatch(info, search_result);
return;
}
auto vec_size_per_chunk = vec_ptr->get_size_per_chunk();
auto max_chunk = upper_div(active_count, vec_size_per_chunk);
@@ -140,12 +154,13 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
auto sub_data = query::dataset::RawDataset{
element_begin, dim, size_per_chunk, chunk_data};
if (milvus::exec::UseVectorIterator(info)) {
auto sub_qr = BruteForceSearchIterators(search_dataset,
sub_data,
info,
index_info,
bitset,
data_type);
auto sub_qr =
PackBruteForceSearchIteratorsIntoSubResult(search_dataset,
sub_data,
info,
index_info,
bitset,
data_type);
final_qr.merge(sub_qr);
} else {
auto sub_qr = BruteForceSearch(search_dataset,


@@ -11,6 +11,7 @@
#include "SearchOnIndex.h"
#include "exec/operator/Utils.h"
#include "CachedSearchIterator.h"
namespace milvus::query {
void
@@ -26,14 +27,23 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset,
auto dataset =
knowhere::GenDataSet(num_queries, dim, search_dataset.query_data);
dataset->SetIsSparse(is_sparse);
if (!milvus::exec::PrepareVectorIteratorsFromIndex(search_conf,
num_queries,
dataset,
search_result,
bitset,
indexing)) {
indexing.Query(dataset, search_conf, bitset, search_result);
if (milvus::exec::PrepareVectorIteratorsFromIndex(search_conf,
num_queries,
dataset,
search_result,
bitset,
indexing)) {
return;
}
if (search_conf.iterator_v2_info_.has_value()) {
auto iter =
CachedSearchIterator(indexing, dataset, search_conf, bitset);
iter.NextBatch(search_conf, search_result);
return;
}
indexing.Query(dataset, search_conf, bitset, search_result);
}
} // namespace milvus::query


@@ -18,6 +18,7 @@
#include "common/QueryInfo.h"
#include "common/Types.h"
#include "mmap/Column.h"
#include "query/CachedSearchIterator.h"
#include "query/SearchBruteForce.h"
#include "query/SearchOnSealed.h"
#include "query/helper.h"
@@ -55,13 +56,20 @@ SearchOnSealedIndex(const Schema& schema,
dataset->SetIsSparse(is_sparse);
auto vec_index =
dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get());
if (search_info.iterator_v2_info_.has_value()) {
CachedSearchIterator cached_iter(
*vec_index, dataset, search_info, bitset);
cached_iter.NextBatch(search_info, search_result);
return;
}
if (!milvus::exec::PrepareVectorIteratorsFromIndex(search_info,
num_queries,
dataset,
search_result,
bitset,
*vec_index)) {
auto index_type = vec_index->GetIndexType();
vec_index->Query(dataset, search_info, bitset, search_result);
float* distances = search_result.distances_.data();
auto total_num = num_queries * topK;
@@ -104,6 +112,14 @@ SearchOnSealed(const Schema& schema,
auto data_type = field.get_data_type();
CheckBruteForceSearchParam(field, search_info);
if (search_info.iterator_v2_info_.has_value()) {
CachedSearchIterator cached_iter(
column, query_dataset, search_info, index_info, bitview, data_type);
cached_iter.NextBatch(search_info, result);
return;
}
auto num_chunk = column->num_chunks();
SubSearchResult final_qr(num_queries,
@@ -115,17 +131,16 @@ SearchOnSealed(const Schema& schema,
for (int i = 0; i < num_chunk; ++i) {
auto vec_data = column->Data(i);
auto chunk_size = column->chunk_row_nums(i);
const uint8_t* bitset_ptr = nullptr;
auto data_id = offset;
auto raw_dataset =
query::dataset::RawDataset{offset, dim, chunk_size, vec_data};
if (milvus::exec::UseVectorIterator(search_info)) {
auto sub_qr = BruteForceSearchIterators(query_dataset,
raw_dataset,
search_info,
index_info,
bitview,
data_type);
auto sub_qr =
PackBruteForceSearchIteratorsIntoSubResult(query_dataset,
raw_dataset,
search_info,
index_info,
bitview,
data_type);
final_qr.merge(sub_qr);
} else {
auto sub_qr = BruteForceSearch(query_dataset,
@@ -136,7 +151,6 @@ SearchOnSealed(const Schema& schema,
data_type);
final_qr.merge(sub_qr);
}
offset += chunk_size;
}
if (milvus::exec::UseVectorIterator(search_info)) {
@@ -181,14 +195,23 @@ SearchOnSealed(const Schema& schema,
CheckBruteForceSearchParam(field, search_info);
auto raw_dataset = query::dataset::RawDataset{0, dim, row_count, vec_data};
if (milvus::exec::UseVectorIterator(search_info)) {
auto sub_qr = BruteForceSearchIterators(query_dataset,
raw_dataset,
search_info,
index_info,
bitset,
data_type);
auto sub_qr = PackBruteForceSearchIteratorsIntoSubResult(query_dataset,
raw_dataset,
search_info,
index_info,
bitset,
data_type);
result.AssembleChunkVectorIterators(
num_queries, 1, {0}, sub_qr.chunk_iterators());
} else if (search_info.iterator_v2_info_.has_value()) {
CachedSearchIterator cached_iter(query_dataset,
raw_dataset,
search_info,
index_info,
bitset,
data_type);
cached_iter.NextBatch(search_info, result);
return;
} else {
auto sub_qr = BruteForceSearch(query_dataset,
raw_dataset,


@@ -89,6 +89,7 @@ set(MILVUS_TEST_FILES
test_chunked_segment.cpp
test_chunked_column.cpp
test_rust_result.cpp
test_cached_search_iterator.cpp
)
if ( INDEX_ENGINE STREQUAL "cardinal" )


@@ -143,12 +143,13 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
AssertMatch(ref, ans);
}
auto result3 = BruteForceSearchIterators(query_dataset,
raw_dataset,
search_info,
index_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
auto result3 = PackBruteForceSearchIteratorsIntoSubResult(
query_dataset,
raw_dataset,
search_info,
index_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
auto iterators = result3.chunk_iterators();
for (int i = 0; i < nq; i++) {
auto it = iterators[i];


@@ -0,0 +1,797 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <gtest/gtest.h>
#include <memory>
#include <random>
#include <unordered_set>
#include "common/BitsetView.h"
#include "common/QueryInfo.h"
#include "common/QueryResult.h"
#include "common/Utils.h"
#include "index/Index.h"
#include "knowhere/comp/index_param.h"
#include "query/CachedSearchIterator.h"
#include "index/VectorIndex.h"
#include "index/IndexFactory.h"
#include "knowhere/dataset.h"
#include "query/helper.h"
#include "segcore/ConcurrentVector.h"
#include "segcore/InsertRecord.h"
#include "mmap/ChunkedColumn.h"
#include "test_utils/DataGen.h"
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
using namespace milvus::index;
namespace {
constexpr int64_t kDim = 16;
constexpr int64_t kNumVectors = 1000;
constexpr int64_t kNumQueries = 1;
constexpr int64_t kBatchSize = 100;
constexpr size_t kSizePerChunk = 128;
constexpr size_t kHnswM = 24;
constexpr size_t kHnswEfConstruction = 360;
constexpr size_t kHnswEf = 128;
const MetricType kMetricType = knowhere::metric::L2;
} // namespace
enum class ConstructorType {
VectorIndex = 0,
RawData,
VectorBase,
ChunkedColumn
};
static const std::vector<ConstructorType> kConstructorTypes = {
ConstructorType::VectorIndex,
ConstructorType::RawData,
ConstructorType::VectorBase,
ConstructorType::ChunkedColumn,
};
static const std::vector<MetricType> kMetricTypes = {
knowhere::metric::L2,
knowhere::metric::IP,
knowhere::metric::COSINE,
};
// tests in this class must not run concurrently
class CachedSearchIteratorTest
: public ::testing::TestWithParam<std::tuple<ConstructorType, MetricType>> {
private:
protected:
SearchInfo
GetDefaultNormalSearchInfo() {
return SearchInfo{
.topk_ = kBatchSize,
.round_decimal_ = -1,
.metric_type_ = std::get<1>(GetParam()),
.search_params_ =
{
{knowhere::indexparam::EF, std::to_string(kHnswEf)},
},
.iterator_v2_info_ =
SearchIteratorV2Info{
.batch_size = kBatchSize,
},
};
}
static DataType data_type_;
static int64_t dim_;
static int64_t nb_;
static int64_t nq_;
static FixedVector<float> base_dataset_;
static FixedVector<float> query_dataset_;
static IndexBasePtr index_hnsw_l2_;
static IndexBasePtr index_hnsw_ip_;
static IndexBasePtr index_hnsw_cos_;
static knowhere::DataSetPtr knowhere_query_dataset_;
static dataset::SearchDataset search_dataset_;
static std::unique_ptr<ConcurrentVector<milvus::FloatVector>> vector_base_;
static std::shared_ptr<ChunkedColumn> column_;
static std::vector<std::vector<char>> column_data_;
IndexBase* index_hnsw_ = nullptr;
MetricType metric_type_ = kMetricType;
std::unique_ptr<CachedSearchIterator>
DispatchIterator(const ConstructorType& constructor_type,
const SearchInfo& search_info,
const BitsetView& bitset) {
switch (constructor_type) {
case ConstructorType::VectorIndex:
return std::make_unique<CachedSearchIterator>(
dynamic_cast<const VectorIndex&>(*index_hnsw_),
knowhere_query_dataset_,
search_info,
bitset);
case ConstructorType::RawData:
return std::make_unique<CachedSearchIterator>(
search_dataset_,
dataset::RawDataset{0, dim_, nb_, base_dataset_.data()},
search_info,
std::map<std::string, std::string>{},
bitset,
data_type_);
case ConstructorType::VectorBase:
return std::make_unique<CachedSearchIterator>(
search_dataset_,
vector_base_.get(),
nb_,
search_info,
std::map<std::string, std::string>{},
bitset,
data_type_);
case ConstructorType::ChunkedColumn:
return std::make_unique<CachedSearchIterator>(
column_,
search_dataset_,
search_info,
std::map<std::string, std::string>{},
bitset,
data_type_);
default:
return nullptr;
}
}
// use last distance of the first batch as range_filter
// use first distance of the last batch as radius
std::pair<float, float>
GetRadiusAndRangeFilter() {
const size_t num_rnds = (nb_ + kBatchSize - 1) / kBatchSize;
SearchResult search_result;
float radius, range_filter;
bool get_radius_success = false;
bool get_range_filter_success = false;
SearchInfo search_info = GetDefaultNormalSearchInfo();
auto iterator =
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
for (size_t rnd = 0; rnd < num_rnds; ++rnd) {
iterator->NextBatch(search_info, search_result);
if (rnd == 0) {
for (int64_t i = kBatchSize - 1; i >= 0; --i) {
if (search_result.seg_offsets_[i] != -1) {
range_filter = search_result.distances_[i];
get_range_filter_success = true;
break;
}
}
} else {
for (size_t i = 0; i < kBatchSize; ++i) {
if (search_result.seg_offsets_[i] != -1) {
radius = search_result.distances_[i];
get_radius_success = true;
break;
}
}
}
}
if (!get_radius_success || !get_range_filter_success) {
throw std::runtime_error("Failed to get radius and range filter");
}
return {radius, range_filter};
}
static void
BuildIndex() {
auto dataset = knowhere::GenDataSet(nb_, dim_, base_dataset_.data());
for (const auto& metric_type : kMetricTypes) {
milvus::index::CreateIndexInfo create_index_info;
create_index_info.field_type = data_type_;
create_index_info.metric_type = metric_type;
create_index_info.index_engine_version =
knowhere::Version::GetCurrentVersion().VersionNumber();
auto build_conf = knowhere::Json{
{knowhere::meta::METRIC_TYPE, metric_type},
{knowhere::meta::DIM, std::to_string(dim_)},
{knowhere::indexparam::M, std::to_string(kHnswM)},
{knowhere::indexparam::EFCONSTRUCTION,
std::to_string(kHnswEfConstruction)}};
create_index_info.index_type = knowhere::IndexEnum::INDEX_HNSW;
if (metric_type == knowhere::metric::L2) {
index_hnsw_l2_ =
milvus::index::IndexFactory::GetInstance().CreateIndex(
create_index_info,
milvus::storage::FileManagerContext());
index_hnsw_l2_->BuildWithDataset(dataset, build_conf);
ASSERT_EQ(index_hnsw_l2_->Count(), nb_);
} else if (metric_type == knowhere::metric::IP) {
index_hnsw_ip_ =
milvus::index::IndexFactory::GetInstance().CreateIndex(
create_index_info,
milvus::storage::FileManagerContext());
index_hnsw_ip_->BuildWithDataset(dataset, build_conf);
ASSERT_EQ(index_hnsw_ip_->Count(), nb_);
} else if (metric_type == knowhere::metric::COSINE) {
index_hnsw_cos_ =
milvus::index::IndexFactory::GetInstance().CreateIndex(
create_index_info,
milvus::storage::FileManagerContext());
index_hnsw_cos_->BuildWithDataset(dataset, build_conf);
ASSERT_EQ(index_hnsw_cos_->Count(), nb_);
} else {
FAIL() << "Unsupported metric type: " << metric_type;
}
}
}
static void
SetUpVectorBase() {
vector_base_ = std::make_unique<ConcurrentVector<milvus::FloatVector>>(
dim_, kSizePerChunk);
vector_base_->set_data_raw(0, base_dataset_.data(), nb_);
ASSERT_EQ(vector_base_->num_chunk(),
(nb_ + kSizePerChunk - 1) / kSizePerChunk);
}
static void
SetUpChunkedColumn() {
column_ = std::make_unique<ChunkedColumn>();
const size_t num_chunks_ = (nb_ + kSizePerChunk - 1) / kSizePerChunk;
column_data_.resize(num_chunks_);
size_t offset = 0;
for (size_t i = 0; i < num_chunks_; ++i) {
const size_t rows = std::min(nb_ - offset, kSizePerChunk);
const size_t chunk_bitset_size = (rows + 7) / 8;
const size_t buf_size =
chunk_bitset_size + rows * dim_ * sizeof(float);
auto& chunk_data = column_data_[i];
chunk_data.resize(buf_size);
memcpy(chunk_data.data() + chunk_bitset_size,
base_dataset_.cbegin() + offset * dim_,
rows * dim_ * sizeof(float));
column_->AddChunk(std::make_shared<FixedWidthChunk>(
rows, dim_, chunk_data.data(), buf_size, sizeof(float), false));
offset += rows;
}
}
static void
SetUpTestSuite() {
auto schema = std::make_shared<Schema>();
auto fakevec_id = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, dim_, kMetricType);
// generate base dataset
base_dataset_ =
segcore::DataGen(schema, nb_).get_col<float>(fakevec_id);
// generate query dataset
query_dataset_ = {base_dataset_.cbegin(),
base_dataset_.cbegin() + nq_ * dim_};
knowhere_query_dataset_ =
knowhere::GenDataSet(nq_, dim_, query_dataset_.data());
search_dataset_ = dataset::SearchDataset{
.metric_type = kMetricType,
.num_queries = nq_,
.topk = kBatchSize,
.round_decimal = -1,
.dim = dim_,
.query_data = query_dataset_.data(),
};
BuildIndex();
SetUpVectorBase();
SetUpChunkedColumn();
}
static void
TearDownTestSuite() {
base_dataset_.clear();
query_dataset_.clear();
index_hnsw_l2_.reset();
index_hnsw_ip_.reset();
index_hnsw_cos_.reset();
knowhere_query_dataset_.reset();
vector_base_.reset();
column_.reset();
}
void
SetUp() override {
auto metric_type = std::get<1>(GetParam());
if (metric_type == knowhere::metric::L2) {
metric_type_ = knowhere::metric::L2;
search_dataset_.metric_type = knowhere::metric::L2;
index_hnsw_ = index_hnsw_l2_.get();
} else if (metric_type == knowhere::metric::IP) {
metric_type_ = knowhere::metric::IP;
search_dataset_.metric_type = knowhere::metric::IP;
index_hnsw_ = index_hnsw_ip_.get();
} else if (metric_type == knowhere::metric::COSINE) {
metric_type_ = knowhere::metric::COSINE;
search_dataset_.metric_type = knowhere::metric::COSINE;
index_hnsw_ = index_hnsw_cos_.get();
} else {
FAIL() << "Unsupported metric type: " << metric_type;
}
}
void
TearDown() override {
}
};
// initialize static variables
DataType CachedSearchIteratorTest::data_type_ = DataType::VECTOR_FLOAT;
int64_t CachedSearchIteratorTest::dim_ = kDim;
int64_t CachedSearchIteratorTest::nb_ = kNumVectors;
int64_t CachedSearchIteratorTest::nq_ = kNumQueries;
IndexBasePtr CachedSearchIteratorTest::index_hnsw_l2_ = nullptr;
IndexBasePtr CachedSearchIteratorTest::index_hnsw_ip_ = nullptr;
IndexBasePtr CachedSearchIteratorTest::index_hnsw_cos_ = nullptr;
knowhere::DataSetPtr CachedSearchIteratorTest::knowhere_query_dataset_ =
nullptr;
dataset::SearchDataset CachedSearchIteratorTest::search_dataset_;
FixedVector<float> CachedSearchIteratorTest::base_dataset_;
FixedVector<float> CachedSearchIteratorTest::query_dataset_;
std::unique_ptr<ConcurrentVector<milvus::FloatVector>>
CachedSearchIteratorTest::vector_base_ = nullptr;
std::shared_ptr<ChunkedColumn> CachedSearchIteratorTest::column_ = nullptr;
std::vector<std::vector<char>> CachedSearchIteratorTest::column_data_;
/********* Testcases Start **********/
TEST_P(CachedSearchIteratorTest, NextBatchNormal) {
SearchInfo search_info = GetDefaultNormalSearchInfo();
const std::vector<size_t> kBatchSizes = {
1, 7, 43, 99, 100, 101, 1000, 1005};
for (size_t batch_size : kBatchSizes) {
std::cout << "batch_size: " << batch_size << std::endl;
search_info.iterator_v2_info_->batch_size = batch_size;
auto iterator =
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
SearchResult search_result;
iterator->NextBatch(search_info, search_result);
for (size_t i = 0; i < nq_; ++i) {
std::unordered_set<int64_t> seg_offsets;
size_t cnt = 0;
for (size_t j = 0; j < batch_size; ++j) {
if (search_result.seg_offsets_[i * batch_size + j] == -1) {
break;
}
++cnt;
seg_offsets.insert(
search_result.seg_offsets_[i * batch_size + j]);
}
EXPECT_EQ(seg_offsets.size(), cnt);
if (metric_type_ == knowhere::metric::L2) {
EXPECT_EQ(search_result.distances_[i * batch_size], 0);
}
}
EXPECT_EQ(search_result.unity_topK_, batch_size);
EXPECT_EQ(search_result.total_nq_, nq_);
EXPECT_EQ(search_result.seg_offsets_.size(), nq_ * batch_size);
EXPECT_EQ(search_result.distances_.size(), nq_ * batch_size);
}
}
TEST_P(CachedSearchIteratorTest, NextBatchDistBound) {
SearchInfo search_info = GetDefaultNormalSearchInfo();
const size_t batch_size = kBatchSize;
const float dist_bound_factor = PositivelyRelated(metric_type_) ? 0.5 : 1.5;
float dist_bound = 0;
{
auto iterator =
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
SearchResult search_result;
iterator->NextBatch(search_info, search_result);
bool found_dist_bound = false;
// use the last distance of the first query * factor as the dist bound
for (int64_t j = static_cast<int64_t>(batch_size) - 1; j >= 0; --j) {
if (search_result.seg_offsets_[j] != -1) {
dist_bound = search_result.distances_[j] * dist_bound_factor;
found_dist_bound = true;
break;
}
}
ASSERT_TRUE(found_dist_bound);
search_info.iterator_v2_info_->last_bound = dist_bound;
for (size_t rnd = 1; rnd < (nb_ + batch_size - 1) / batch_size; ++rnd) {
iterator->NextBatch(search_info, search_result);
for (size_t i = 0; i < nq_; ++i) {
for (size_t j = 0; j < batch_size; ++j) {
if (search_result.seg_offsets_[i * batch_size + j] == -1) {
break;
}
if (PositivelyRelated(metric_type_)) {
EXPECT_LT(search_result.distances_[i * batch_size + j],
dist_bound);
} else {
EXPECT_GT(search_result.distances_[i * batch_size + j],
dist_bound);
}
}
}
}
}
}
TEST_P(CachedSearchIteratorTest, NextBatchDistBoundEmptyResults) {
SearchInfo search_info = GetDefaultNormalSearchInfo();
const size_t batch_size = kBatchSize;
const float dist_bound = PositivelyRelated(metric_type_)
? -std::numeric_limits<float>::max()
: std::numeric_limits<float>::max();
auto iterator =
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
SearchResult search_result;
search_info.iterator_v2_info_->last_bound = dist_bound;
size_t total_cnt = 0;
for (size_t rnd = 0; rnd < (nb_ + batch_size - 1) / batch_size; ++rnd) {
iterator->NextBatch(search_info, search_result);
for (size_t i = 0; i < nq_; ++i) {
for (size_t j = 0; j < batch_size; ++j) {
if (search_result.seg_offsets_[i * batch_size + j] == -1) {
break;
}
++total_cnt;
}
}
}
EXPECT_EQ(total_cnt, 0);
}
TEST_P(CachedSearchIteratorTest, NextBatchRangeSearchRadius) {
const size_t num_rnds = (nb_ + kBatchSize - 1) / kBatchSize;
const auto [radius, range_filter] = GetRadiusAndRangeFilter();
SearchResult search_result;
SearchInfo search_info = GetDefaultNormalSearchInfo();
search_info.search_params_[knowhere::meta::RADIUS] = radius;
auto iterator =
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
for (size_t rnd = 0; rnd < num_rnds; ++rnd) {
iterator->NextBatch(search_info, search_result);
for (size_t i = 0; i < nq_; ++i) {
for (size_t j = 0; j < kBatchSize; ++j) {
if (search_result.seg_offsets_[i * kBatchSize + j] == -1) {
break;
}
float dist = search_result.distances_[i * kBatchSize + j];
if (PositivelyRelated(metric_type_)) {
ASSERT_GT(dist, radius);
} else {
ASSERT_LT(dist, radius);
}
}
}
}
}
TEST_P(CachedSearchIteratorTest, NextBatchRangeSearchRadiusAndRangeFilter) {
const size_t num_rnds = (nb_ + kBatchSize - 1) / kBatchSize;
const auto [radius, range_filter] = GetRadiusAndRangeFilter();
SearchResult search_result;
SearchInfo search_info = GetDefaultNormalSearchInfo();
search_info.search_params_[knowhere::meta::RADIUS] = radius;
search_info.search_params_[knowhere::meta::RANGE_FILTER] = range_filter;
auto iterator =
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
for (size_t rnd = 0; rnd < num_rnds; ++rnd) {
iterator->NextBatch(search_info, search_result);
for (size_t i = 0; i < nq_; ++i) {
for (size_t j = 0; j < kBatchSize; ++j) {
if (search_result.seg_offsets_[i * kBatchSize + j] == -1) {
break;
}
float dist = search_result.distances_[i * kBatchSize + j];
if (PositivelyRelated(metric_type_)) {
ASSERT_GT(dist, radius);
ASSERT_LE(dist, range_filter);
} else {
ASSERT_LT(dist, radius);
ASSERT_GE(dist, range_filter);
}
}
}
}
}
TEST_P(CachedSearchIteratorTest,
NextBatchRangeSearchLastBoundRadiusRangeFilter) {
const size_t num_rnds = (nb_ + kBatchSize - 1) / kBatchSize;
const auto [radius, range_filter] = GetRadiusAndRangeFilter();
SearchResult search_result;
const float diff = (radius + range_filter) / 2;
const std::vector<float> last_bounds = {radius - diff,
radius,
radius + diff,
range_filter,
range_filter + diff};
SearchInfo search_info = GetDefaultNormalSearchInfo();
search_info.search_params_[knowhere::meta::RADIUS] = radius;
search_info.search_params_[knowhere::meta::RANGE_FILTER] = range_filter;
for (float last_bound : last_bounds) {
search_info.iterator_v2_info_->last_bound = last_bound;
auto iterator =
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
for (size_t rnd = 0; rnd < num_rnds; ++rnd) {
iterator->NextBatch(search_info, search_result);
for (size_t i = 0; i < nq_; ++i) {
for (size_t j = 0; j < kBatchSize; ++j) {
if (search_result.seg_offsets_[i * kBatchSize + j] == -1) {
break;
}
float dist = search_result.distances_[i * kBatchSize + j];
if (PositivelyRelated(metric_type_)) {
ASSERT_LE(dist, last_bound);
ASSERT_GT(dist, radius);
ASSERT_LE(dist, range_filter);
} else {
ASSERT_GT(dist, last_bound);
ASSERT_LT(dist, radius);
ASSERT_GE(dist, range_filter);
}
}
}
}
}
}
TEST_P(CachedSearchIteratorTest, NextBatchZeroBatchSize) {
SearchInfo search_info = GetDefaultNormalSearchInfo();
auto iterator =
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
SearchResult search_result;
search_info.iterator_v2_info_->batch_size = 0;
EXPECT_THROW(iterator->NextBatch(search_info, search_result), SegcoreError);
}
TEST_P(CachedSearchIteratorTest, NextBatchDiffBatchSizeComparedToInit) {
SearchInfo search_info = GetDefaultNormalSearchInfo();
auto iterator =
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
SearchResult search_result;
search_info.iterator_v2_info_->batch_size = kBatchSize + 1;
EXPECT_THROW(iterator->NextBatch(search_info, search_result), SegcoreError);
}
TEST_P(CachedSearchIteratorTest, NextBatchEmptySearchInfo) {
SearchInfo search_info = GetDefaultNormalSearchInfo();
auto iterator =
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
SearchResult search_result;
SearchInfo empty_search_info;
EXPECT_THROW(iterator->NextBatch(empty_search_info, search_result),
SegcoreError);
}
TEST_P(CachedSearchIteratorTest, NextBatchEmptyIteratorV2Info) {
SearchInfo search_info = GetDefaultNormalSearchInfo();
auto iterator =
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
SearchResult search_result;
search_info.iterator_v2_info_ = std::nullopt;
EXPECT_THROW(iterator->NextBatch(search_info, search_result), SegcoreError);
}
TEST_P(CachedSearchIteratorTest, NextBatchAllBatchesNormal) {
SearchInfo search_info = GetDefaultNormalSearchInfo();
const std::vector<size_t> kBatchSizes = {
1, 7, 43, 99, 100, 101, 1000, 1005};
for (size_t batch_size : kBatchSizes) {
search_info.iterator_v2_info_->batch_size = batch_size;
auto iterator =
DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
size_t total_cnt = 0;
for (size_t rnd = 0; rnd < (nb_ + batch_size - 1) / batch_size; ++rnd) {
SearchResult search_result;
iterator->NextBatch(search_info, search_result);
for (size_t i = 0; i < nq_; ++i) {
std::unordered_set<int64_t> seg_offsets;
size_t cnt = 0;
for (size_t j = 0; j < batch_size; ++j) {
if (search_result.seg_offsets_[i * batch_size + j] == -1) {
break;
}
++cnt;
seg_offsets.insert(
search_result.seg_offsets_[i * batch_size + j]);
}
total_cnt += cnt;
// check no duplicate
EXPECT_EQ(seg_offsets.size(), cnt);
// only check if the first distance of the first batch is 0
if (rnd == 0 && metric_type_ == knowhere::metric::L2) {
EXPECT_EQ(search_result.distances_[i * batch_size], 0);
}
}
EXPECT_EQ(search_result.unity_topK_, batch_size);
EXPECT_EQ(search_result.total_nq_, nq_);
EXPECT_EQ(search_result.seg_offsets_.size(), nq_ * batch_size);
EXPECT_EQ(search_result.distances_.size(), nq_ * batch_size);
}
if (std::get<0>(GetParam()) == ConstructorType::VectorIndex) {
EXPECT_GE(total_cnt, nb_ * nq_ * 0.9);
} else {
EXPECT_EQ(total_cnt, nb_ * nq_);
}
}
}
TEST_P(CachedSearchIteratorTest, ConstructorWithInvalidSearchInfo) {
EXPECT_THROW(
DispatchIterator(std::get<0>(GetParam()), SearchInfo{}, nullptr),
SegcoreError);
EXPECT_THROW(
DispatchIterator(
std::get<0>(GetParam()), SearchInfo{.metric_type_ = ""}, nullptr),
SegcoreError);
EXPECT_THROW(DispatchIterator(std::get<0>(GetParam()),
SearchInfo{.metric_type_ = metric_type_},
nullptr),
SegcoreError);
EXPECT_THROW(DispatchIterator(std::get<0>(GetParam()),
SearchInfo{.metric_type_ = metric_type_,
.iterator_v2_info_ = {}},
nullptr),
SegcoreError);
EXPECT_THROW(
DispatchIterator(std::get<0>(GetParam()),
SearchInfo{.metric_type_ = metric_type_,
.iterator_v2_info_ =
SearchIteratorV2Info{.batch_size = 0}},
nullptr),
SegcoreError);
}
TEST_P(CachedSearchIteratorTest, ConstructorWithInvalidParams) {
SearchInfo search_info = GetDefaultNormalSearchInfo();
if (std::get<0>(GetParam()) == ConstructorType::VectorIndex) {
EXPECT_THROW(auto iterator = std::make_unique<CachedSearchIterator>(
dynamic_cast<const VectorIndex&>(*index_hnsw_),
nullptr,
search_info,
nullptr),
SegcoreError);
EXPECT_THROW(auto iterator = std::make_unique<CachedSearchIterator>(
dynamic_cast<const VectorIndex&>(*index_hnsw_),
std::make_shared<knowhere::DataSet>(),
search_info,
nullptr),
SegcoreError);
} else if (std::get<0>(GetParam()) == ConstructorType::RawData) {
EXPECT_THROW(
auto iterator = std::make_unique<CachedSearchIterator>(
dataset::SearchDataset{},
dataset::RawDataset{0, dim_, nb_, base_dataset_.data()},
search_info,
std::map<std::string, std::string>{},
nullptr,
data_type_),
SegcoreError);
} else if (std::get<0>(GetParam()) == ConstructorType::VectorBase) {
EXPECT_THROW(auto iterator = std::make_unique<CachedSearchIterator>(
dataset::SearchDataset{},
vector_base_.get(),
nb_,
search_info,
std::map<std::string, std::string>{},
nullptr,
data_type_),
SegcoreError);
EXPECT_THROW(auto iterator = std::make_unique<CachedSearchIterator>(
search_dataset_,
nullptr,
nb_,
search_info,
std::map<std::string, std::string>{},
nullptr,
data_type_),
SegcoreError);
EXPECT_THROW(auto iterator = std::make_unique<CachedSearchIterator>(
search_dataset_,
vector_base_.get(),
0,
search_info,
std::map<std::string, std::string>{},
nullptr,
data_type_),
SegcoreError);
} else if (std::get<0>(GetParam()) == ConstructorType::ChunkedColumn) {
EXPECT_THROW(auto iterator = std::make_unique<CachedSearchIterator>(
nullptr,
search_dataset_,
search_info,
std::map<std::string, std::string>{},
nullptr,
data_type_),
SegcoreError);
EXPECT_THROW(auto iterator = std::make_unique<CachedSearchIterator>(
column_,
dataset::SearchDataset{},
search_info,
std::map<std::string, std::string>{},
nullptr,
data_type_),
SegcoreError);
}
}
/********* Testcases End **********/
INSTANTIATE_TEST_SUITE_P(
CachedSearchIteratorTests,
CachedSearchIteratorTest,
::testing::Combine(::testing::ValuesIn(kConstructorTypes),
::testing::ValuesIn(kMetricTypes)),
[](const testing::TestParamInfo<std::tuple<ConstructorType, MetricType>>&
info) {
std::string constructor_type_str;
ConstructorType constructor_type = std::get<0>(info.param);
MetricType metric_type = std::get<1>(info.param);
switch (constructor_type) {
case ConstructorType::VectorIndex:
constructor_type_str = "VectorIndex";
break;
case ConstructorType::RawData:
constructor_type_str = "RawData";
break;
case ConstructorType::VectorBase:
constructor_type_str = "VectorBase";
break;
case ConstructorType::ChunkedColumn:
constructor_type_str = "ChunkedColumn";
break;
default:
constructor_type_str = "Unknown constructor type";
};
if (metric_type == knowhere::metric::L2) {
constructor_type_str += "_L2";
} else if (metric_type == knowhere::metric::IP) {
constructor_type_str += "_IP";
} else if (metric_type == knowhere::metric::COSINE) {
constructor_type_str += "_COSINE";
} else {
constructor_type_str += "_Unknown";
}
return constructor_type_str;
});


@@ -55,6 +55,12 @@ message Array {
schema.DataType element_type = 3;
}
message SearchIteratorV2Info {
string token = 1;
uint32 batch_size = 2;
optional float last_bound = 3;
}
message QueryInfo {
int64 topk = 1;
string metric_type = 3;
@@ -67,6 +73,7 @@ message QueryInfo {
double bm25_avgdl = 10;
int64 query_field_id =11;
string hints = 12;
optional SearchIteratorV2Info search_iterator_v2_info = 13;
}
message ColumnInfo {


@@ -26,6 +26,7 @@ import (
"time"
"github.com/cockroachdb/errors"
"github.com/google/uuid"
"github.com/hashicorp/golang-lru/v2/expirable"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/atomic"
@@ -304,6 +305,13 @@ func (node *Proxy) Init() error {
node.enableMaterializedView = Params.CommonCfg.EnableMaterializedView.GetAsBool()
// Enable the internal rand pool for UUIDv4 generation.
// EnableRandPool is NOT thread-safe, so it must be called before the service starts,
// while there is no possibility that New or any other UUIDv4 generation function runs concurrently.
// Only the proxy generates UUIDs for now, and one Milvus process has exactly one proxy.
uuid.EnableRandPool()
log.Debug("enable rand pool for UUIDv4 generation")
log.Info("init proxy done", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("Address", node.address))
return nil
}

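For context on why this is worth doing at init time: the pooled mode batches crypto/rand reads so each UUID draw is much cheaper, at the cost of the enable/disable toggles not being thread-safe. A minimal, self-contained sketch (a hypothetical standalone program, not part of this change) exercising the same pattern:

package main

import (
	"fmt"
	"time"

	"github.com/google/uuid"
)

func main() {
	// Must run before any goroutine generates UUIDs: enabling the pool is not thread-safe.
	uuid.EnableRandPool()
	start := time.Now()
	for i := 0; i < 1_000_000; i++ {
		_ = uuid.New() // each call draws from the shared rand pool
	}
	fmt.Println("generated 1e6 UUIDv4 in", time.Since(start))
}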
View File

@@ -8,6 +8,7 @@ import (
"strings"
"github.com/cockroachdb/errors"
"github.com/google/uuid"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@@ -82,6 +83,81 @@ type SearchInfo struct {
isIterator bool
}
func parseSearchIteratorV2Info(searchParamsPair []*commonpb.KeyValuePair, groupByFieldId int64, isIterator bool, offset int64, queryTopK *int64) (*planpb.SearchIteratorV2Info, error) {
isIteratorV2Str, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterV2Key, searchParamsPair)
isIteratorV2, _ := strconv.ParseBool(isIteratorV2Str)
if !isIteratorV2 {
return nil, nil
}
// iteratorV1 and iteratorV2 should be set together for compatibility
if !isIterator {
return nil, fmt.Errorf("both %s and %s must be set in the SDK", IteratorField, SearchIterV2Key)
}
// disable groupBy when using iteratorV2,
// matching the behavior of V1
if isIteratorV2 && groupByFieldId > 0 {
return nil, merr.WrapErrParameterInvalid("", "",
"GroupBy is not permitted when using a search iterator")
}
// disable offset when doing iteratorV2
if isIteratorV2 && offset > 0 {
return nil, merr.WrapErrParameterInvalid("", "",
"Setting an offset is not permitted when using a search iterator v2")
}
// parse the token; generate one if it does not exist
token, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterIdKey, searchParamsPair)
if token == "" {
generatedToken, err := uuid.NewRandom()
if err != nil {
return nil, err
}
token = generatedToken.String()
} else {
// validate that the existing token is a well-formed UUID
if _, err := uuid.Parse(token); err != nil {
return nil, fmt.Errorf("invalid token format")
}
}
// parse the batch size; a non-zero value is required
batchSizeStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterBatchSizeKey, searchParamsPair)
if batchSizeStr == "" {
return nil, fmt.Errorf("batch size is required")
}
batchSize, err := strconv.ParseInt(batchSizeStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("batch size is invalid, %w", err)
}
// use the same validation logic as topk
if err := validateLimit(batchSize); err != nil {
return nil, fmt.Errorf("batch size is invalid, %w", err)
}
*queryTopK = batchSize // for compatibility
// prepare plan iterator v2 info proto
planIteratorV2Info := &planpb.SearchIteratorV2Info{
Token: token,
BatchSize: uint32(batchSize),
}
// append optional last bound if applicable
lastBoundStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterLastBoundKey, searchParamsPair)
if lastBoundStr != "" {
lastBound, err := strconv.ParseFloat(lastBoundStr, 32)
if err != nil {
return nil, fmt.Errorf("failed to parse input last bound, %w", err)
}
lastBoundFloat32 := float32(lastBound)
planIteratorV2Info.LastBound = &lastBoundFloat32 // address of a local copy, escapes to the heap
}
return planIteratorV2Info, nil
}
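// For illustration only (not part of this patch): assuming the SDK sends the
// three keys below on the first iterator-v2 request (no token yet, so the
// proxy generates one), the parse would go roughly like this:
//
//	params := []*commonpb.KeyValuePair{
//		{Key: IteratorField, Value: "True"},
//		{Key: SearchIterV2Key, Value: "True"},
//		{Key: SearchIterBatchSizeKey, Value: "100"},
//	}
//	var topK int64
//	info, err := parseSearchIteratorV2Info(params, 0 /* groupByFieldId */, true /* isIterator */, 0 /* offset */, &topK)
//	// err == nil; info.Token is a fresh UUIDv4; info.BatchSize == 100; topK == 100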
// parseSearchInfo returns QueryInfo and offset
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) *SearchInfo {
var topK int64
@@ -196,16 +272,22 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
"Not allowed to do range-search when doing search-group-by")}
}
planSearchIteratorV2Info, err := parseSearchIteratorV2Info(searchParamsPair, groupByFieldId, isIterator, offset, &queryTopK)
if err != nil {
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("parse iterator v2 info failed: %w", err)}
}
return &SearchInfo{
planInfo: &planpb.QueryInfo{
Topk: queryTopK,
MetricType: metricType,
SearchParams: searchParamStr,
RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId,
GroupSize: groupSize,
StrictGroupSize: strictGroupSize,
Hints: hints,
Topk: queryTopK,
MetricType: metricType,
SearchParams: searchParamStr,
RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId,
GroupSize: groupSize,
StrictGroupSize: strictGroupSize,
Hints: hints,
SearchIteratorV2Info: planSearchIteratorV2Info,
},
offset: offset,
isIterator: isIterator,

View File

@@ -69,6 +69,11 @@ const (
OffsetKey = "offset"
LimitKey = "limit"
SearchIterV2Key = "search_iter_v2"
SearchIterBatchSizeKey = "search_iter_batch_size"
SearchIterLastBoundKey = "search_iter_last_bound"
SearchIterIdKey = "search_iter_id"
InsertTaskName = "InsertTask"
CreateCollectionTaskName = "CreateCollectionTask"
DropCollectionTaskName = "DropCollectionTask"

View File

@@ -28,6 +28,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
@@ -590,12 +591,15 @@ func (t *searchTask) Execute(ctx context.Context) error {
return nil
}
func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, queryInfo *planpb.QueryInfo, isAdvance bool) (*milvuspb.SearchResults, error) {
func getMetricType(toReduceResults []*internalpb.SearchResults) string {
metricType := ""
if len(toReduceResults) >= 1 {
metricType = toReduceResults[0].GetMetricType()
}
return metricType
}
func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, metricType string, queryInfo *planpb.QueryInfo, isAdvance bool) (*milvuspb.SearchResults, error) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "reduceResults")
defer sp.End()
@@ -631,6 +635,24 @@ func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*inter
return result, nil
}
// getLastBound finds the last bound based on the reduced results and the metric type.
// It only supports nq == 1 and is used by search iterator v2.
func getLastBound(result *milvuspb.SearchResults, incomingLastBound *float32, metricType string) float32 {
numScores := len(result.Results.Scores)
if numScores > 0 && result.GetResults().GetNumQueries() == 1 {
return result.Results.Scores[numScores-1]
}
// if no results found and incoming last bound is not nil, return it
if incomingLastBound != nil {
return *incomingLastBound
}
// if no results are found and this is the first call, return the loosest bound for the metric
if metric.PositivelyRelated(metricType) {
return math.MaxFloat32
}
return -math.MaxFloat32
}
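// Behavior sketch (illustrative, not part of this patch). Given a reduced
// single-query page whose scores are sorted descending, e.g. [5, 4, 3]:
//
//	getLastBound(page, nil, metric.IP)    == 3                // worst score on the page
//	getLastBound(empty, &prev, metric.IP) == *prev            // carry the previous bound forward
//	getLastBound(empty, nil, metric.IP)   == math.MaxFloat32  // first page, positively related metric
//	getLastBound(empty, nil, metric.L2)   == -math.MaxFloat32 // first page, negatively related metric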
func (t *searchTask) PostExecute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-PostExecute")
defer sp.End()
@@ -670,6 +692,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
return err
}
metricType := getMetricType(toReduceResults)
// reduce
if t.SearchRequest.GetIsAdvanced() {
multipleInternalResults := make([][]*internalpb.SearchResults, len(t.SearchRequest.GetSubReqs()))
@@ -696,16 +719,12 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs()))
for index, internalResults := range multipleInternalResults {
subReq := t.SearchRequest.GetSubReqs()[index]
metricType := ""
if len(internalResults) >= 1 {
metricType = internalResults[0].GetMetricType()
}
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), t.queryInfos[index], true)
subMetricType := getMetricType(internalResults)
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), subMetricType, t.queryInfos[index], true)
if err != nil {
return err
}
t.reScorers[index].setMetricType(metricType)
t.reScorers[index].setMetricType(subMetricType)
t.reScorers[index].reScore(result)
multipleMilvusResults[index] = result
}
@@ -721,7 +740,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
return err
}
} else {
t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), t.queryInfos[0], false)
t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), metricType, t.queryInfos[0], false)
if err != nil {
return err
}
@@ -751,6 +770,14 @@
}
t.result.Results.OutputFields = t.userOutputFields
t.result.CollectionName = t.request.GetCollectionName()
if t.isIterator && len(t.queryInfos) == 1 && t.queryInfos[0] != nil {
if iterInfo := t.queryInfos[0].GetSearchIteratorV2Info(); iterInfo != nil {
t.result.Results.SearchIteratorV2Results = &schemapb.SearchIteratorV2Results{
Token: iterInfo.GetToken(),
LastBound: getLastBound(t.result, iterInfo.LastBound, metricType),
}
}
}
if t.isIterator && t.request.GetGuaranteeTimestamp() == 0 {
// first page for iteration, need to set up sessionTs for iterator
t.result.SessionTs = getMaxMvccTsFromChannels(t.queryChannelsTs, t.BeginTs())

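Taken together, the proxy now echoes the iterator token and the page's last bound back to the client, which resubmits both on the next request. A rough sketch of that round trip, where doSearchPage is a hypothetical helper issuing one SearchRequest with the given search_iter_id and search_iter_last_bound params:

// Hypothetical client-side loop; only the token/last-bound plumbing is shown.
var token string
var lastBound *float32
for {
	resp := doSearchPage(token, lastBound) // assumed helper wrapping one milvuspb.SearchRequest
	iter := resp.GetResults().GetSearchIteratorV2Results()
	token = iter.GetToken() // reuse the server-issued token on every later page
	lb := iter.GetLastBound()
	lastBound = &lb // resume from where this page ended
	if len(resp.GetResults().GetScores()) == 0 {
		break // the range is exhausted
	}
}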
View File

@@ -18,12 +18,14 @@ package proxy
import (
"context"
"fmt"
"math"
"strconv"
"strings"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
@@ -103,9 +105,124 @@ func TestSearchTask_PostExecute(t *testing.T) {
assert.Equal(t, qt.resultSizeInsufficient, true)
assert.Equal(t, qt.isTopkReduce, false)
})
t.Run("test search iterator v2", func(t *testing.T) {
const (
kRows = 10
kToken = "test-token"
)
collName := "test_collection_search_iterator_v2" + funcutil.GenRandomStr()
collSchema := createColl(t, collName, rc)
createIteratorSearchTask := func(t *testing.T, metricType string, rows int) *searchTask {
ids := make([]int64, rows)
for i := range ids {
ids[i] = int64(i)
}
resultIDs := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
}
scores := make([]float32, rows)
// proxy needs to reverse the score for negatively related metrics
for i := range scores {
if metric.PositivelyRelated(metricType) {
scores[i] = float32(len(scores) - i)
} else {
scores[i] = -float32(i + 1)
}
}
resultData := &schemapb.SearchResultData{
Ids: resultIDs,
Scores: scores,
NumQueries: 1,
}
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
Nq: 1,
},
schema: newSchemaInfo(collSchema),
request: &milvuspb.SearchRequest{
CollectionName: collName,
},
queryInfos: []*planpb.QueryInfo{{
SearchIteratorV2Info: &planpb.SearchIteratorV2Info{
Token: kToken,
BatchSize: 1,
},
}},
result: &milvuspb.SearchResults{
Results: resultData,
},
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
tr: timerecord.NewTimeRecorder("search"),
isIterator: true,
}
bytes, err := proto.Marshal(resultData)
assert.NoError(t, err)
qt.resultBuf.Insert(&internalpb.SearchResults{
MetricType: metricType,
SlicedBlob: bytes,
})
return qt
}
t.Run("test search iterator v2", func(t *testing.T) {
metrics := []string{metric.L2, metric.IP, metric.COSINE, metric.BM25}
for _, metricType := range metrics {
qt := createIteratorSearchTask(t, metricType, kRows)
err = qt.PostExecute(ctx)
assert.NoError(t, err)
assert.Equal(t, kToken, qt.result.Results.SearchIteratorV2Results.Token)
if metric.PositivelyRelated(metricType) {
assert.Equal(t, float32(1), qt.result.Results.SearchIteratorV2Results.LastBound)
} else {
assert.Equal(t, float32(kRows), qt.result.Results.SearchIteratorV2Results.LastBound)
}
}
})
t.Run("test search iterator v2 with empty result", func(t *testing.T) {
metrics := []string{metric.L2, metric.IP, metric.COSINE, metric.BM25}
for _, metricType := range metrics {
qt := createIteratorSearchTask(t, metricType, 0)
err = qt.PostExecute(ctx)
assert.NoError(t, err)
assert.Equal(t, kToken, qt.result.Results.SearchIteratorV2Results.Token)
if metric.PositivelyRelated(metricType) {
assert.Equal(t, float32(math.MaxFloat32), qt.result.Results.SearchIteratorV2Results.LastBound)
} else {
assert.Equal(t, float32(-math.MaxFloat32), qt.result.Results.SearchIteratorV2Results.LastBound)
}
}
})
t.Run("test search iterator v2 with empty result and incoming last bound", func(t *testing.T) {
metrics := []string{metric.L2, metric.IP, metric.COSINE, metric.BM25}
kLastBound := float32(10)
for _, metricType := range metrics {
qt := createIteratorSearchTask(t, metricType, 0)
qt.queryInfos[0].SearchIteratorV2Info.LastBound = &kLastBound
err = qt.PostExecute(ctx)
assert.NoError(t, err)
assert.Equal(t, kToken, qt.result.Results.SearchIteratorV2Results.Token)
assert.Equal(t, kLastBound, qt.result.Results.SearchIteratorV2Results.LastBound)
}
})
})
}
func createColl(t *testing.T, name string, rc types.RootCoordClient) {
func createColl(t *testing.T, name string, rc types.RootCoordClient) *schemapb.CollectionSchema {
schema := constructCollectionSchema(testInt64Field, testFloatVecField, testVecDim, name)
marshaledSchema, err := proto.Marshal(schema)
require.NoError(t, err)
@@ -126,6 +243,8 @@ func createColl(t *testing.T, name string, rc types.RootCoordClient) {
require.NoError(t, createColT.PreExecute(ctx))
require.NoError(t, createColT.Execute(ctx))
require.NoError(t, createColT.PostExecute(ctx))
return schema
}
func getBaseSearchParams() []*commonpb.KeyValuePair {
@@ -2599,6 +2718,157 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
assert.True(t, strings.Contains(searchInfo.parseError.Error(), "failed to parse input group size"))
}
})
t.Run("check search iterator v2", func(t *testing.T) {
kBatchSize := uint32(10)
generateValidParamsForSearchIteratorV2 := func() []*commonpb.KeyValuePair {
param := getValidSearchParams()
return append(param,
&commonpb.KeyValuePair{
Key: SearchIterV2Key,
Value: "True",
},
&commonpb.KeyValuePair{
Key: IteratorField,
Value: "True",
},
&commonpb.KeyValuePair{
Key: SearchIterBatchSizeKey,
Value: fmt.Sprintf("%d", kBatchSize),
},
)
}
t.Run("iteratorV2 normal", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
searchInfo := parseSearchInfo(param, nil, nil)
assert.NoError(t, searchInfo.parseError)
assert.NotNil(t, searchInfo.planInfo)
assert.NotEmpty(t, searchInfo.planInfo.SearchIteratorV2Info.Token)
assert.Equal(t, kBatchSize, searchInfo.planInfo.SearchIteratorV2Info.BatchSize)
assert.Len(t, searchInfo.planInfo.SearchIteratorV2Info.Token, 36)
assert.Equal(t, int64(kBatchSize), searchInfo.planInfo.GetTopk()) // compatibility
})
t.Run("iteratorV2 without isIterator", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
resetSearchParamsValue(param, IteratorField, "False")
searchInfo := parseSearchInfo(param, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "both")
})
t.Run("iteratorV2 with groupBy", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
param = append(param, &commonpb.KeyValuePair{
Key: GroupByFieldKey,
Value: "string_field",
})
fields := make([]*schemapb.FieldSchema, 0)
fields = append(fields, &schemapb.FieldSchema{
FieldID: int64(101),
Name: "string_field",
})
schema := &schemapb.CollectionSchema{
Fields: fields,
}
searchInfo := parseSearchInfo(param, schema, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "roupBy")
})
t.Run("iteratorV2 with offset", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
param = append(param, &commonpb.KeyValuePair{
Key: OffsetKey,
Value: "10",
})
searchInfo := parseSearchInfo(param, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "offset")
})
t.Run("iteratorV2 invalid token", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
param = append(param, &commonpb.KeyValuePair{
Key: SearchIterIdKey,
Value: "invalid_token",
})
searchInfo := parseSearchInfo(param, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "invalid token format")
})
t.Run("iteratorV2 passed token must be same", func(t *testing.T) {
token, err := uuid.NewRandom()
assert.NoError(t, err)
param := generateValidParamsForSearchIteratorV2()
param = append(param, &commonpb.KeyValuePair{
Key: SearchIterIdKey,
Value: token.String(),
})
searchInfo := parseSearchInfo(param, nil, nil)
assert.NoError(t, searchInfo.parseError)
assert.NotEmpty(t, searchInfo.planInfo.SearchIteratorV2Info.Token)
assert.Equal(t, token.String(), searchInfo.planInfo.SearchIteratorV2Info.Token)
})
t.Run("iteratorV2 batch size", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
resetSearchParamsValue(param, SearchIterBatchSizeKey, "1.123")
searchInfo := parseSearchInfo(param, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "batch size is invalid")
})
t.Run("iteratorV2 batch size", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
resetSearchParamsValue(param, SearchIterBatchSizeKey, "")
searchInfo := parseSearchInfo(param, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "batch size is required")
})
t.Run("iteratorV2 batch size negative", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
resetSearchParamsValue(param, SearchIterBatchSizeKey, "-1")
searchInfo := parseSearchInfo(param, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "batch size is invalid")
})
t.Run("iteratorV2 batch size too large", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
resetSearchParamsValue(param, SearchIterBatchSizeKey, fmt.Sprintf("%d", Params.QuotaConfig.TopKLimit.GetAsInt64()+1))
searchInfo := parseSearchInfo(param, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "batch size is invalid")
})
t.Run("iteratorV2 last bound", func(t *testing.T) {
kLastBound := float32(1.123)
param := generateValidParamsForSearchIteratorV2()
param = append(param, &commonpb.KeyValuePair{
Key: SearchIterLastBoundKey,
Value: fmt.Sprintf("%f", kLastBound),
})
searchInfo := parseSearchInfo(param, nil, nil)
assert.NoError(t, searchInfo.parseError)
assert.NotNil(t, searchInfo.planInfo)
assert.Equal(t, kLastBound, *searchInfo.planInfo.SearchIteratorV2Info.LastBound)
})
t.Run("iteratorV2 invalid last bound", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
param = append(param, &commonpb.KeyValuePair{
Key: SearchIterLastBoundKey,
Value: "xxx",
})
searchInfo := parseSearchInfo(param, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "failed to parse input last bound")
})
})
}
func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {