feat: LRU cache implementation (#32567)

issue: https://github.com/milvus-io/milvus/issues/32783
This PR implements the LRU cache developed on branch lru-dev.

Signed-off-by: sunby <sunbingyi1992@gmail.com>
Co-authored-by: chyezh <chyezh@outlook.com>
Co-authored-by: MrPresent-Han <chun.han@zilliz.com>
Co-authored-by: Ted Xu <ted.xu@zilliz.com>
Co-authored-by: jaime <yun.zhang@zilliz.com>
Co-authored-by: wayblink <anyang.wang@zilliz.com>
Bingyi Sun 2024-05-06 20:29:30 +08:00 committed by GitHub
parent 37a99ca23e
commit fecd9c21ba
74 changed files with 3510 additions and 976 deletions
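
For context before the diffs: the PR's cache evicts the least recently used entry when capacity is reached. Below is a minimal sketch of that policy only (illustrative; the PR's actual cache additionally manages segment loading and eviction, and none of the names below come from this PR):

#include <cstddef>
#include <list>
#include <optional>
#include <unordered_map>
#include <utility>

// Minimal LRU cache: the front of the list is the most recently used entry,
// the back is the least recently used and gets evicted when capacity is hit.
template <typename K, typename V>
class LruCache {
 public:
    explicit LruCache(size_t capacity) : capacity_(capacity) {
    }

    std::optional<V>
    Get(const K& key) {
        auto it = index_.find(key);
        if (it == index_.end()) {
            return std::nullopt;
        }
        items_.splice(items_.begin(), items_, it->second);  // mark as recently used
        return it->second->second;
    }

    void
    Put(const K& key, V value) {
        if (auto it = index_.find(key); it != index_.end()) {
            it->second->second = std::move(value);
            items_.splice(items_.begin(), items_, it->second);
            return;
        }
        if (items_.size() >= capacity_) {  // evict the least recently used entry
            index_.erase(items_.back().first);
            items_.pop_back();
        }
        items_.emplace_front(key, std::move(value));
        index_[key] = items_.begin();
    }

 private:
    using Item = std::pair<K, V>;
    size_t capacity_;
    std::list<Item> items_;  // ordered MRU -> LRU
    std::unordered_map<K, typename std::list<Item>::iterator> index_;
};

Both operations are O(1): the hash map finds the list node, and splice reorders it without copying.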

View File

@ -367,6 +367,7 @@ queryNode:
maxQueueLength: 16 # Maximum length of task queue in flowgraph
maxParallelism: 1024 # Maximum number of tasks executed in parallel in the flowgraph
enableSegmentPrune: false # use partition prune function on shard delegator
useStreamComputing: false
ip: # if not specified, use the first unicastable address
port: 21123
grpc:

View File

@ -502,6 +502,11 @@ VectorDiskAnnIndex<T>::update_load_json(const Config& config) {
}
}
if (config.contains(kMmapFilepath)) {
load_config.erase(kMmapFilepath);
load_config[kEnableMmap] = true;
}
return load_config;
}
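
The hunk above rewrites the DiskANN load config so the index is deserialized through mmap rather than from a plain file path. Config supports contains/erase/operator[] like nlohmann::json; assuming those semantics, the rewrite in isolation looks like this (the key strings are local stand-ins for the milvus constants):

#include <cassert>
#include <nlohmann/json.hpp>

int main() {
    // Local stand-ins for the milvus constants kMmapFilepath / kEnableMmap.
    const char* kMmapFilepath = "mmap_filepath";
    const char* kEnableMmap = "enable_mmap";

    nlohmann::json load_config = {{kMmapFilepath, "/tmp/index.mmap"}};
    if (load_config.contains(kMmapFilepath)) {
        load_config.erase(kMmapFilepath);  // drop the on-disk path...
        load_config[kEnableMmap] = true;   // ...and ask for mmap deserialization
    }
    assert(load_config[kEnableMmap] == true);
}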

View File

@ -18,7 +18,6 @@
#include <unistd.h>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <filesystem>
#include <memory>
@ -33,7 +32,6 @@
#include "index/Index.h"
#include "index/IndexInfo.h"
#include "index/Meta.h"
#include "index/Utils.h"
#include "common/EasyAssert.h"
#include "config/ConfigKnowhere.h"
@ -44,16 +42,15 @@
#include "common/FieldData.h"
#include "common/File.h"
#include "common/Slice.h"
#include "common/Tracer.h"
#include "common/RangeSearchHelper.h"
#include "common/Utils.h"
#include "log/Log.h"
#include "mmap/Types.h"
#include "storage/DataCodec.h"
#include "storage/MemFileManagerImpl.h"
#include "storage/ThreadPools.h"
#include "storage/space.h"
#include "storage/Util.h"
#include "storage/prometheus_client.h"
namespace milvus::index {
@ -733,7 +730,8 @@ void VectorMemIndex<T>::LoadFromFile(const Config& config) {
}
LOG_INFO("load with slice meta: {}", !slice_meta_filepath.empty());
std::chrono::duration<double> load_duration_sum;
std::chrono::duration<double> write_disk_duration_sum;
if (!slice_meta_filepath
.empty()) { // load with the slice meta info, then we can load batch by batch
std::string index_file_prefix = slice_meta_filepath.substr(
@ -751,15 +749,20 @@ void VectorMemIndex<T>::LoadFromFile(const Config& config) {
std::string prefix = item[NAME];
int slice_num = item[SLICE_NUM];
auto total_len = static_cast<size_t>(item[TOTAL_LEN]);
auto HandleBatch = [&](int index) {
auto start_load2_mem = std::chrono::system_clock::now();
auto batch_data = file_manager_->LoadIndexToMemory(batch);
load_duration_sum +=
(std::chrono::system_clock::now() - start_load2_mem);
for (int j = index - batch.size() + 1; j <= index; j++) {
std::string file_name = GenSlicedFileName(prefix, j);
AssertInfo(batch_data.find(file_name) != batch_data.end(),
"lost index slice data");
auto data = batch_data[file_name];
auto start_write_file = std::chrono::system_clock::now();
auto written = file.Write(data->Data(), data->Size());
write_disk_duration_sum +=
(std::chrono::system_clock::now() - start_write_file);
AssertInfo(
written == data->Size(),
fmt::format("failed to write index data to disk {}: {}",
@ -784,24 +787,46 @@ void VectorMemIndex<T>::LoadFromFile(const Config& config) {
}
}
} else {
//1. load files into memory
auto start_load_files2_mem = std::chrono::system_clock::now();
auto result = file_manager_->LoadIndexToMemory(std::vector<std::string>(
pending_index_files.begin(), pending_index_files.end()));
load_duration_sum +=
(std::chrono::system_clock::now() - start_load_files2_mem);
//2. write data into files
auto start_write_file = std::chrono::system_clock::now();
for (auto& [_, index_data] : result) {
file.Write(index_data->Data(), index_data->Size());
}
write_disk_duration_sum +=
(std::chrono::system_clock::now() - start_write_file);
}
milvus::storage::internal_storage_download_duration.Observe(
std::chrono::duration_cast<std::chrono::milliseconds>(load_duration_sum)
.count());
milvus::storage::internal_storage_write_disk_duration.Observe(
std::chrono::duration_cast<std::chrono::milliseconds>(
write_disk_duration_sum)
.count());
file.Close();
LOG_INFO("load index into Knowhere...");
auto conf = config;
conf.erase(kMmapFilepath);
conf[kEnableMmap] = true;
auto start_deserialize = std::chrono::system_clock::now();
auto stat = index_.DeserializeFromFile(filepath.value(), conf);
auto deserialize_duration =
std::chrono::system_clock::now() - start_deserialize;
if (stat != knowhere::Status::success) {
PanicInfo(ErrorCode::UnexpectedError,
"failed to Deserialize index: {}",
KnowhereStatusString(stat));
}
milvus::storage::internal_storage_deserialize_duration.Observe(
std::chrono::duration_cast<std::chrono::milliseconds>(
deserialize_duration)
.count());
auto dim = index_.Dim();
this->SetDim(index_.Dim());
@ -811,7 +836,18 @@ void VectorMemIndex<T>::LoadFromFile(const Config& config) {
"failed to unlink mmap index file {}: {}",
filepath.value(),
strerror(errno));
LOG_INFO("load vector index done");
LOG_INFO(
"load vector index done, mmap_file_path:{}, download_duration:{}, "
"write_files_duration:{}, deserialize_duration:{}",
filepath.value(),
std::chrono::duration_cast<std::chrono::milliseconds>(load_duration_sum)
.count(),
std::chrono::duration_cast<std::chrono::milliseconds>(
write_disk_duration_sum)
.count(),
std::chrono::duration_cast<std::chrono::milliseconds>(
deserialize_duration)
.count());
}
template <typename T>
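
The load path above wraps each stage (download, write to disk, deserialize) in the same timing pattern: take a start timestamp, accumulate the elapsed duration, and report milliseconds to a histogram. The pattern reduced to a standalone sketch (ObserveMs is a hypothetical stand-in for the prometheus Observe call):

#include <chrono>
#include <cstdint>
#include <iostream>

// Hypothetical stand-in for the prometheus histogram's Observe(ms) call.
void
ObserveMs(const char* stage, int64_t ms) {
    std::cout << stage << ": " << ms << " ms\n";
}

template <typename Fn>
void
TimedStage(const char* stage, Fn&& fn) {
    auto start = std::chrono::system_clock::now();
    fn();  // the stage body: download, write to disk, deserialize, ...
    auto elapsed = std::chrono::system_clock::now() - start;
    ObserveMs(stage,
              std::chrono::duration_cast<std::chrono::milliseconds>(elapsed)
                  .count());
}

int main() {
    TimedStage("download", [] { /* file_manager_->LoadIndexToMemory(...) */ });
    TimedStage("write_disk", [] { /* file.Write(...) */ });
    TimedStage("deserialize", [] { /* index_.DeserializeFromFile(...) */ });
}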

View File

@ -25,6 +25,7 @@ set(SEGCORE_FILES
SegmentSealedImpl.cpp
FieldIndexing.cpp
Reduce.cpp
StreamReduce.cpp
metrics_c.cpp
plan_c.cpp
reduce_c.cpp
@ -36,7 +37,8 @@ set(SEGCORE_FILES
segcore_init_c.cpp
TimestampIndex.cpp
Utils.cpp
ConcurrentVector.cpp)
ConcurrentVector.cpp
ReduceUtils.cpp)
add_library(milvus_segcore SHARED ${SEGCORE_FILES})
target_link_libraries(milvus_segcore milvus_query milvus_bitset milvus_exec ${OpenMP_CXX_FLAGS} milvus-storage)

View File

@ -41,10 +41,10 @@ class ThreadSafeVector {
template <typename... Args>
void
emplace_to_at_least(int64_t size, Args... args) {
std::lock_guard lck(mutex_);
if (size <= size_) {
return;
}
std::lock_guard lck(mutex_);
while (vec_.size() < size) {
vec_.emplace_back(std::forward<Args...>(args...));
++size_;
@ -52,24 +52,25 @@ class ThreadSafeVector {
}
const Type&
operator[](int64_t index) const {
std::shared_lock lck(mutex_);
AssertInfo(index < size_,
fmt::format(
"index out of range, index={}, size_={}", index, size_));
std::shared_lock lck(mutex_);
return vec_[index];
}
Type&
operator[](int64_t index) {
std::shared_lock lck(mutex_);
AssertInfo(index < size_,
fmt::format(
"index out of range, index={}, size_={}", index, size_));
std::shared_lock lck(mutex_);
return vec_[index];
}
int64_t
size() const {
std::lock_guard lck(mutex_);
return size_;
}
@ -81,7 +82,7 @@ class ThreadSafeVector {
}
private:
std::atomic<int64_t> size_ = 0;
int64_t size_ = 0;
std::deque<Type> vec_;
mutable std::shared_mutex mutex_;
};
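
The ThreadSafeVector hunks above fix a check-then-act race: the size check and the bounds assert now execute under the same lock as the access, and size_ drops its atomic wrapper because every access is already mutex-guarded. The locking discipline in miniature (a sketch, not the milvus class):

#include <cassert>
#include <cstdint>
#include <deque>
#include <mutex>
#include <shared_mutex>

class GuardedVector {
 public:
    void
    grow_to(int64_t target) {
        std::lock_guard lck(mutex_);  // exclusive: check and mutation are one step
        while (static_cast<int64_t>(vec_.size()) < target) {
            vec_.emplace_back(0);
            ++size_;
        }
    }

    int64_t
    at(int64_t i) const {
        std::shared_lock lck(mutex_);  // shared: concurrent readers allowed
        assert(i < size_);             // checked under the same lock guarding size_
        return vec_[i];
    }

 private:
    int64_t size_ = 0;  // plain integer: mutex_ already guards every access
    std::deque<int64_t> vec_;
    mutable std::shared_mutex mutex_;
};

int main() {
    GuardedVector v;
    v.grow_to(3);
    return v.at(2) == 0 ? 0 : 1;
}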

View File

@ -598,6 +598,11 @@ struct InsertRecord {
fields_data_.clear();
}
bool
empty() const {
return pk2offset_->empty();
}
public:
ConcurrentVector<Timestamp> timestamps_;
ConcurrentVector<idx_t> row_ids_;

View File

@ -19,6 +19,7 @@
#include "Utils.h"
#include "common/EasyAssert.h"
#include "pkVisitor.h"
#include "ReduceUtils.h"
namespace milvus::segcore {
@ -130,7 +131,6 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
void
ReduceHelper::FillPrimaryKey() {
std::vector<SearchResult*> valid_search_results;
// get primary keys for duplicates removal
uint32_t valid_index = 0;
for (auto& search_result : search_results_) {
@ -368,7 +368,7 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
search_result_data->set_all_search_count(all_search_count);
// `result_pairs` contains the SearchResult and result_offset info, used for filling output fields
std::vector<std::pair<SearchResult*, int64_t>> result_pairs(result_count);
std::vector<MergeBase> result_pairs(result_count);
// reserve space for pks
auto primary_field_id =
@ -461,14 +461,14 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
group_by_values[loc] =
search_result->group_by_values_.value()[ki];
// set result offset to fill output fields data
result_pairs[loc] = std::make_pair(search_result, ki);
result_pairs[loc] = {&search_result->output_fields_data_, ki};
}
}
// update result topKs
search_result_data->mutable_topks()->Set(qi - nq_begin, topk_count);
}
AssembleGroupByValues(search_result_data, group_by_values);
AssembleGroupByValues(search_result_data, group_by_values, plan_);
AssertInfo(search_result_data->scores_size() == result_count,
"wrong scores size, size = " +
@ -498,89 +498,4 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
return buffer;
}
void
ReduceHelper::AssembleGroupByValues(
std::unique_ptr<milvus::proto::schema::SearchResultData>& search_result,
const std::vector<GroupByValueType>& group_by_vals) {
auto group_by_field_id = plan_->plan_node_->search_info_.group_by_field_id_;
if (group_by_field_id.has_value() && group_by_vals.size() > 0) {
auto group_by_values_field =
std::make_unique<milvus::proto::schema::ScalarField>();
auto group_by_field =
plan_->schema_.operator[](group_by_field_id.value());
DataType group_by_data_type = group_by_field.get_data_type();
int group_by_val_size = group_by_vals.size();
switch (group_by_data_type) {
case DataType::INT8: {
auto field_data = group_by_values_field->mutable_int_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
int8_t val = std::get<int8_t>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
}
break;
}
case DataType::INT16: {
auto field_data = group_by_values_field->mutable_int_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
int16_t val = std::get<int16_t>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
}
break;
}
case DataType::INT32: {
auto field_data = group_by_values_field->mutable_int_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
int32_t val = std::get<int32_t>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
}
break;
}
case DataType::INT64: {
auto field_data = group_by_values_field->mutable_long_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
int64_t val = std::get<int64_t>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
}
break;
}
case DataType::BOOL: {
auto field_data = group_by_values_field->mutable_bool_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
bool val = std::get<bool>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
}
break;
}
case DataType::VARCHAR: {
auto field_data = group_by_values_field->mutable_string_data();
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
std::string val =
std::move(std::get<std::string>(group_by_vals[idx]));
*(field_data->mutable_data()->Add()) = val;
}
break;
}
default: {
PanicInfo(
DataTypeInvalid,
fmt::format("unsupported datatype for group_by operations ",
group_by_data_type));
}
}
search_result->mutable_group_by_field_value()->set_type(
milvus::proto::schema::DataType(group_by_data_type));
search_result->mutable_group_by_field_value()
->mutable_scalars()
->MergeFrom(*group_by_values_field.get());
return;
}
}
} // namespace milvus::segcore

View File

@ -22,6 +22,7 @@
#include "common/QueryResult.h"
#include "query/PlanImpl.h"
#include "ReduceStructure.h"
#include "segment_c.h"
namespace milvus::segcore {
@ -55,18 +56,22 @@ class ReduceHelper {
return search_result_data_blobs_.release();
}
private:
void
Initialize();
protected:
void
FilterInvalidSearchResult(SearchResult* search_result);
void
RefreshSearchResult();
void
FillPrimaryKey();
void
RefreshSearchResult();
ReduceResultData();
private:
void
Initialize();
void
FillEntryData();
@ -76,44 +81,34 @@ class ReduceHelper {
int64_t topk,
int64_t& result_offset);
void
ReduceResultData();
std::vector<char>
GetSearchResultDataSlice(int slice_index_);
void
AssembleGroupByValues(
std::unique_ptr<milvus::proto::schema::SearchResultData>& search_result,
const std::vector<GroupByValueType>& group_by_vals);
private:
protected:
std::vector<SearchResult*>& search_results_;
milvus::query::Plan* plan_;
std::vector<int64_t> slice_nqs_;
std::vector<int64_t> slice_topKs_;
int64_t total_nq_;
int64_t num_segments_;
int64_t num_slices_;
std::vector<int64_t> slice_nqs_prefix_sum_;
// dim0: num_segments_; dim1: total_nq_; dim2: offset
std::vector<std::vector<std::vector<int64_t>>> final_search_records_;
// output
std::unique_ptr<SearchResultDataBlobs> search_result_data_blobs_;
// Used for merge results,
// define these here to avoid allocating them for each query
std::vector<SearchResultPair> pairs_;
int64_t num_segments_;
std::vector<int64_t> slice_topKs_;
std::priority_queue<SearchResultPair*,
std::vector<SearchResultPair*>,
SearchResultPairComparator>
heap_;
// Used for merge results,
// define these here to avoid allocating them for each query
std::vector<SearchResultPair> pairs_;
std::unordered_set<milvus::PkType> pk_set_;
std::unordered_set<milvus::GroupByValueType> group_by_val_set_;
// dim0: num_segments_; dim1: total_nq_; dim2: offset
std::vector<std::vector<std::vector<int64_t>>> final_search_records_;
private:
std::vector<int64_t> slice_nqs_;
int64_t total_nq_;
// output
std::unique_ptr<SearchResultDataBlobs> search_result_data_blobs_;
};
} // namespace milvus::segcore

View File

@ -0,0 +1,106 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
//
// Created by zilliz on 2024/3/26.
//
#include "ReduceUtils.h"
namespace milvus::segcore {
void
AssembleGroupByValues(
std::unique_ptr<milvus::proto::schema::SearchResultData>& search_result,
const std::vector<GroupByValueType>& group_by_vals,
milvus::query::Plan* plan) {
auto group_by_field_id = plan->plan_node_->search_info_.group_by_field_id_;
if (group_by_field_id.has_value() && group_by_vals.size() > 0) {
auto group_by_values_field =
std::make_unique<milvus::proto::schema::ScalarField>();
auto group_by_field =
plan->schema_.operator[](group_by_field_id.value());
DataType group_by_data_type = group_by_field.get_data_type();
int group_by_val_size = group_by_vals.size();
switch (group_by_data_type) {
case DataType::INT8: {
auto field_data = group_by_values_field->mutable_int_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
int8_t val = std::get<int8_t>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
}
break;
}
case DataType::INT16: {
auto field_data = group_by_values_field->mutable_int_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
int16_t val = std::get<int16_t>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
}
break;
}
case DataType::INT32: {
auto field_data = group_by_values_field->mutable_int_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
int32_t val = std::get<int32_t>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
}
break;
}
case DataType::INT64: {
auto field_data = group_by_values_field->mutable_long_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
int64_t val = std::get<int64_t>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
}
break;
}
case DataType::BOOL: {
auto field_data = group_by_values_field->mutable_bool_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
bool val = std::get<bool>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
}
break;
}
case DataType::VARCHAR: {
auto field_data = group_by_values_field->mutable_string_data();
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
std::string val =
std::move(std::get<std::string>(group_by_vals[idx]));
*(field_data->mutable_data()->Add()) = val;
}
break;
}
default: {
PanicInfo(
DataTypeInvalid,
fmt::format("unsupported datatype for group_by operations ",
group_by_data_type));
}
}
search_result->mutable_group_by_field_value()->set_type(
milvus::proto::schema::DataType(group_by_data_type));
search_result->mutable_group_by_field_value()
->mutable_scalars()
->MergeFrom(*group_by_values_field.get());
return;
}
}
} // namespace milvus::segcore

View File

@ -0,0 +1,26 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "pb/schema.pb.h"
#include "common/Types.h"
#include "query/PlanImpl.h"
namespace milvus::segcore {
void
AssembleGroupByValues(
std::unique_ptr<milvus::proto::schema::SearchResultData>& search_result,
const std::vector<GroupByValueType>& group_by_vals,
milvus::query::Plan* plan);
}
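
AssembleGroupByValues switches on the field's DataType and unpacks each GroupByValueType (a variant) before filling the matching proto array. The equivalent type dispatch, shown self-contained via std::visit with a local variant type and no milvus dependencies:

#include <cstdint>
#include <iostream>
#include <string>
#include <type_traits>
#include <variant>
#include <vector>

using GroupByVal =
    std::variant<int8_t, int16_t, int32_t, int64_t, bool, std::string>;

// Route every value into the bucket for its runtime type, mirroring how
// AssembleGroupByValues fills the int/long/bool/string proto arrays.
void
Assemble(const std::vector<GroupByVal>& vals) {
    std::vector<int64_t> numeric;      // INT8/16/32/64 and BOOL widen losslessly
    std::vector<std::string> strings;  // VARCHAR
    for (const auto& v : vals) {
        std::visit(
            [&](const auto& x) {
                using T = std::decay_t<decltype(x)>;
                if constexpr (std::is_same_v<T, std::string>) {
                    strings.push_back(x);
                } else {
                    numeric.push_back(static_cast<int64_t>(x));
                }
            },
            v);
    }
    std::cout << numeric.size() << " numeric, " << strings.size()
              << " string group-by values\n";
}

int main() {
    Assemble({int64_t{7}, std::string("group-a"), true, int32_t{42}});
}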

View File

@ -105,6 +105,10 @@ SegmentSealedImpl::LoadVecIndex(const LoadIndexInfo& info) {
") than other column's row count (" +
std::to_string(num_rows_.value()) + ")");
}
LOG_INFO(
"Before setting field_bit for field index, fieldID:{}. segmentID:{}, ",
info.field_id,
id_);
if (get_bit(field_data_ready_bitset_, field_id)) {
fields_.erase(field_id);
set_bit(field_data_ready_bitset_, field_id, false);
@ -118,6 +122,9 @@ SegmentSealedImpl::LoadVecIndex(const LoadIndexInfo& info) {
metric_type,
std::move(const_cast<LoadIndexInfo&>(info).index));
set_bit(index_ready_bitset_, field_id, true);
LOG_INFO("Has load vec index done, fieldID:{}. segmentID:{}, ",
info.field_id,
id_);
}
void
@ -1021,7 +1028,8 @@ SegmentSealedImpl::bulk_subscript(SystemFieldType system_type,
int64_t count,
void* output) const {
AssertInfo(is_system_field_ready(),
"System field isn't ready when do bulk_insert");
"System field isn't ready when do bulk_insert, segID:{}",
id_);
switch (system_type) {
case SystemFieldType::Timestamp:
AssertInfo(
@ -1647,5 +1655,20 @@ SegmentSealedImpl::generate_interim_index(const FieldId field_id) {
return false;
}
}
void
SegmentSealedImpl::RemoveFieldFile(const FieldId field_id) {
auto cc = storage::ChunkCacheSingleton::GetInstance().GetChunkCache();
if (cc == nullptr) {
return;
}
for (const auto& iter : field_data_info_.field_infos) {
if (iter.second.field_id == field_id.get()) {
for (const auto& binlog : iter.second.insert_files) {
cc->Remove(binlog);
}
return;
}
}
}
} // namespace milvus::segcore

View File

@ -88,6 +88,9 @@ class SegmentSealedImpl : public SegmentSealed {
DataType
GetFieldDataType(FieldId fieldId) const override;
void
RemoveFieldFile(const FieldId field_id);
public:
size_t
GetMemoryUsageInBytes() const override {

View File

@ -0,0 +1,690 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "StreamReduce.h"
#include "SegmentInterface.h"
#include "segcore/Utils.h"
#include "Reduce.h"
#include "segcore/pkVisitor.h"
#include "segcore/ReduceUtils.h"
namespace milvus::segcore {
void
StreamReducerHelper::FillEntryData() {
for (auto search_result : search_results_to_merge_) {
auto segment = static_cast<milvus::segcore::SegmentInterface*>(
search_result->segment_);
segment->FillTargetEntry(plan_, *search_result);
}
}
void
StreamReducerHelper::AssembleMergedResult() {
if (search_results_to_merge_.size() > 0) {
std::unique_ptr<MergedSearchResult> new_merged_result =
std::make_unique<MergedSearchResult>();
std::vector<PkType> new_merged_pks;
std::vector<float> new_merged_distances;
std::vector<GroupByValueType> new_merged_groupBy_vals;
std::vector<MergeBase> merge_output_data_bases;
std::vector<int64_t> new_result_offsets;
bool need_handle_groupBy =
plan_->plan_node_->search_info_.group_by_field_id_.has_value();
int valid_size = 0;
std::vector<int> real_topKs(total_nq_);
for (int i = 0; i < num_slice_; i++) {
auto nq_begin = slice_nqs_prefix_sum_[i];
auto nq_end = slice_nqs_prefix_sum_[i + 1];
int64_t result_count = 0;
for (auto search_result : search_results_to_merge_) {
AssertInfo(
search_result->topk_per_nq_prefix_sum_.size() ==
search_result->total_nq_ + 1,
"incorrect topk_per_nq_prefix_sum_ size in search result");
result_count +=
search_result->topk_per_nq_prefix_sum_[nq_end] -
search_result->topk_per_nq_prefix_sum_[nq_begin];
}
if (merged_search_result->has_result_) {
result_count +=
merged_search_result->topk_per_nq_prefix_sum_[nq_end] -
merged_search_result->topk_per_nq_prefix_sum_[nq_begin];
}
int nq_base_offset = valid_size;
valid_size += result_count;
new_merged_pks.resize(valid_size);
new_merged_distances.resize(valid_size);
merge_output_data_bases.resize(valid_size);
new_result_offsets.resize(valid_size);
if (need_handle_groupBy) {
new_merged_groupBy_vals.resize(valid_size);
}
for (auto qi = nq_begin; qi < nq_end; qi++) {
for (auto search_result : search_results_to_merge_) {
AssertInfo(search_result != nullptr,
"null search result when reorganize");
if (search_result->result_offsets_.size() == 0) {
continue;
}
auto topK_start =
search_result->topk_per_nq_prefix_sum_[qi];
auto topK_end =
search_result->topk_per_nq_prefix_sum_[qi + 1];
for (auto ki = topK_start; ki < topK_end; ki++) {
auto loc = search_result->result_offsets_[ki];
AssertInfo(loc < result_count && loc >= 0,
"invalid loc when GetSearchResultDataSlice, "
"loc = " +
std::to_string(loc) +
", result_count = " +
std::to_string(result_count));
new_merged_pks[nq_base_offset + loc] =
search_result->primary_keys_[ki];
new_merged_distances[nq_base_offset + loc] =
search_result->distances_[ki];
if (need_handle_groupBy) {
new_merged_groupBy_vals[nq_base_offset + loc] =
search_result->group_by_values_.value()[ki];
}
merge_output_data_bases[nq_base_offset + loc] = {
&search_result->output_fields_data_, ki};
new_result_offsets[nq_base_offset + loc] = loc;
real_topKs[qi]++;
}
}
if (merged_search_result->has_result_) {
auto topK_start =
merged_search_result->topk_per_nq_prefix_sum_[qi];
auto topK_end =
merged_search_result->topk_per_nq_prefix_sum_[qi + 1];
for (auto ki = topK_start; ki < topK_end; ki++) {
auto loc = merged_search_result->reduced_offsets_[ki];
AssertInfo(loc < result_count && loc >= 0,
"invalid loc when GetSearchResultDataSlice, "
"loc = " +
std::to_string(loc) +
", result_count = " +
std::to_string(result_count));
new_merged_pks[nq_base_offset + loc] =
merged_search_result->primary_keys_[ki];
new_merged_distances[nq_base_offset + loc] =
merged_search_result->distances_[ki];
if (need_handle_groupBy) {
new_merged_groupBy_vals[nq_base_offset + loc] =
merged_search_result->group_by_values_
.value()[ki];
}
merge_output_data_bases[nq_base_offset + loc] = {
&merged_search_result->output_fields_data_, ki};
new_result_offsets[nq_base_offset + loc] = loc;
real_topKs[qi]++;
}
}
}
}
new_merged_result->primary_keys_ = std::move(new_merged_pks);
new_merged_result->distances_ = std::move(new_merged_distances);
if (need_handle_groupBy) {
new_merged_result->group_by_values_ =
std::move(new_merged_groupBy_vals);
}
new_merged_result->topk_per_nq_prefix_sum_.resize(total_nq_ + 1);
std::partial_sum(
real_topKs.begin(),
real_topKs.end(),
new_merged_result->topk_per_nq_prefix_sum_.begin() + 1);
new_merged_result->result_offsets_ = std::move(new_result_offsets);
for (auto field_id : plan_->target_entries_) {
auto& field_meta = plan_->schema_[field_id];
auto field_data =
MergeDataArray(merge_output_data_bases, field_meta);
if (field_meta.get_data_type() == DataType::ARRAY) {
field_data->mutable_scalars()
->mutable_array_data()
->set_element_type(
proto::schema::DataType(field_meta.get_element_type()));
}
new_merged_result->output_fields_data_[field_id] =
std::move(field_data);
}
merged_search_result = std::move(new_merged_result);
merged_search_result->has_result_ = true;
}
}
void
StreamReducerHelper::MergeReduce() {
FilterSearchResults();
FillPrimaryKeys();
InitializeReduceRecords();
ReduceResultData();
RefreshSearchResult();
FillEntryData();
AssembleMergedResult();
CleanReduceStatus();
}
void*
StreamReducerHelper::SerializeMergedResult() {
std::unique_ptr<SearchResultDataBlobs> search_result_blobs =
std::make_unique<milvus::segcore::SearchResultDataBlobs>();
AssertInfo(num_slice_ > 0,
"Wrong state for num_slice in streamReducer, num_slice:{}",
num_slice_);
search_result_blobs->blobs.resize(num_slice_);
for (int i = 0; i < num_slice_; i++) {
auto proto = GetSearchResultDataSlice(i);
search_result_blobs->blobs[i] = proto;
}
return search_result_blobs.release();
}
void
StreamReducerHelper::ReduceResultData() {
if (search_results_to_merge_.size() > 0) {
for (int i = 0; i < num_segments_; i++) {
auto search_result = search_results_to_merge_[i];
auto result_count = search_result->get_total_result_count();
AssertInfo(search_result != nullptr,
"search result must not equal to nullptr");
AssertInfo(search_result->distances_.size() == result_count,
"incorrect search result distance size");
AssertInfo(search_result->seg_offsets_.size() == result_count,
"incorrect search result seg offset size");
AssertInfo(search_result->primary_keys_.size() == result_count,
"incorrect search result primary key size");
}
for (int64_t slice_index = 0; slice_index < slice_nqs_.size();
slice_index++) {
auto nq_begin = slice_nqs_prefix_sum_[slice_index];
auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];
int64_t offset = 0;
for (int64_t qi = nq_begin; qi < nq_end; qi++) {
StreamReduceSearchResultForOneNQ(
qi, slice_topKs_[slice_index], offset);
}
}
}
}
void
StreamReducerHelper::FilterSearchResults() {
uint32_t valid_index = 0;
for (auto& search_result : search_results_to_merge_) {
// skip when results num is 0
AssertInfo(search_result != nullptr,
"search_result to merge cannot be nullptr, there must be "
"sth wrong in the code");
if (search_result->unity_topK_ == 0) {
continue;
}
FilterInvalidSearchResult(search_result);
search_results_to_merge_[valid_index++] = search_result;
}
search_results_to_merge_.resize(valid_index);
num_segments_ = search_results_to_merge_.size();
}
void
StreamReducerHelper::InitializeReduceRecords() {
// init final_search_records_
if (merged_search_result->has_result_) {
final_search_records_.resize(num_segments_ + 1);
} else {
final_search_records_.resize(num_segments_);
}
for (auto& search_record : final_search_records_) {
search_record.resize(total_nq_);
}
}
void
StreamReducerHelper::FillPrimaryKeys() {
for (auto& search_result : search_results_to_merge_) {
auto segment = static_cast<SegmentInterface*>(search_result->segment_);
if (search_result->get_total_result_count() > 0) {
segment->FillPrimaryKeys(plan_, *search_result);
}
}
}
void
StreamReducerHelper::FilterInvalidSearchResult(SearchResult* search_result) {
auto total_nq = search_result->total_nq_;
auto topK = search_result->unity_topK_;
AssertInfo(search_result->seg_offsets_.size() == total_nq * topK,
"wrong seg offsets size, size = " +
std::to_string(search_result->seg_offsets_.size()) +
", expected size = " + std::to_string(total_nq * topK));
AssertInfo(search_result->distances_.size() == total_nq * topK,
"wrong distances size, size = " +
std::to_string(search_result->distances_.size()) +
", expected size = " + std::to_string(total_nq * topK));
std::vector<int64_t> real_topKs(total_nq, 0);
uint32_t valid_index = 0;
auto segment = static_cast<SegmentInterface*>(search_result->segment_);
auto& offsets = search_result->seg_offsets_;
auto& distances = search_result->distances_;
if (search_result->group_by_values_.has_value()) {
AssertInfo(search_result->distances_.size() ==
search_result->group_by_values_.value().size(),
"wrong group_by_values size, size:{}, expected size:{} ",
search_result->group_by_values_.value().size(),
search_result->distances_.size());
}
for (auto i = 0; i < total_nq; ++i) {
for (auto j = 0; j < topK; ++j) {
auto index = i * topK + j;
if (offsets[index] != INVALID_SEG_OFFSET) {
AssertInfo(0 <= offsets[index] &&
offsets[index] < segment->get_row_count(),
fmt::format("invalid offset {}, segment {} with "
"rows num {}, data or index corruption",
offsets[index],
segment->get_segment_id(),
segment->get_row_count()));
real_topKs[i]++;
offsets[valid_index] = offsets[index];
distances[valid_index] = distances[index];
if (search_result->group_by_values_.has_value())
search_result->group_by_values_.value()[valid_index] =
search_result->group_by_values_.value()[index];
valid_index++;
}
}
}
offsets.resize(valid_index);
distances.resize(valid_index);
if (search_result->group_by_values_.has_value())
search_result->group_by_values_.value().resize(valid_index);
search_result->topk_per_nq_prefix_sum_.resize(total_nq + 1);
std::partial_sum(real_topKs.begin(),
real_topKs.end(),
search_result->topk_per_nq_prefix_sum_.begin() + 1);
}
void
StreamReducerHelper::StreamReduceSearchResultForOneNQ(int64_t qi,
int64_t topK,
int64_t& offset) {
//1. clear heap of elements left over from the previous query
while (!heap_.empty()) {
heap_.pop();
}
pk_set_.clear();
group_by_val_set_.clear();
//2. push new search results into sort-heap
for (int i = 0; i < num_segments_; i++) {
auto search_result = search_results_to_merge_[i];
auto offset_beg = search_result->topk_per_nq_prefix_sum_[qi];
auto offset_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
if (offset_beg == offset_end) {
continue;
}
auto primary_key = search_result->primary_keys_[offset_beg];
auto distance = search_result->distances_[offset_beg];
if (search_result->group_by_values_.has_value()) {
AssertInfo(
search_result->group_by_values_.value().size() > offset_beg,
"Wrong size for group_by_values size to "
"ReduceSearchResultForOneNQ:{}, not enough for"
"required offset_beg:{}",
search_result->group_by_values_.value().size(),
offset_beg);
}
auto result_pair = std::make_shared<StreamSearchResultPair>(
primary_key,
distance,
search_result,
nullptr,
i,
offset_beg,
offset_end,
search_result->group_by_values_.has_value() &&
search_result->group_by_values_.value().size() > offset_beg
? std::make_optional(
search_result->group_by_values_.value().at(offset_beg))
: std::nullopt);
heap_.push(result_pair);
}
if (heap_.empty()) {
return;
}
//3. if the merged_search_result has previous data
//push merged search result into the heap
if (merged_search_result->has_result_) {
auto merged_off_begin =
merged_search_result->topk_per_nq_prefix_sum_[qi];
auto merged_off_end =
merged_search_result->topk_per_nq_prefix_sum_[qi + 1];
if (merged_off_end > merged_off_begin) {
auto merged_pk =
merged_search_result->primary_keys_[merged_off_begin];
auto merged_distance =
merged_search_result->distances_[merged_off_begin];
auto merged_result_pair = std::make_shared<StreamSearchResultPair>(
merged_pk,
merged_distance,
nullptr,
merged_search_result.get(),
num_segments_, //use last index as the merged segment index
merged_off_begin,
merged_off_end,
merged_search_result->group_by_values_.has_value() &&
merged_search_result->group_by_values_.value().size() >
merged_off_begin
? std::make_optional(
merged_search_result->group_by_values_.value().at(
merged_off_begin))
: std::nullopt);
heap_.push(merged_result_pair);
}
}
//4. pop heap to sort
int count = 0;
while (count < topK && !heap_.empty()) {
auto pilot = heap_.top();
heap_.pop();
auto seg_index = pilot->segment_index_;
auto pk = pilot->primary_key_;
if (pk == INVALID_PK) {
break; // valid search results for this nq have run out, break to next
}
if (pk_set_.count(pk) == 0) {
bool skip_for_group_by = false;
if (pilot->group_by_value_.has_value()) {
if (group_by_val_set_.count(pilot->group_by_value_.value()) >
0) {
skip_for_group_by = true;
}
}
if (!skip_for_group_by) {
final_search_records_[seg_index][qi].push_back(pilot->offset_);
if (pilot->search_result_ != nullptr) {
pilot->search_result_->result_offsets_.push_back(offset++);
} else {
merged_search_result->reduced_offsets_.push_back(offset++);
}
pk_set_.insert(pk);
if (pilot->group_by_value_.has_value()) {
group_by_val_set_.insert(pilot->group_by_value_.value());
}
count++;
}
}
pilot->advance();
if (pilot->primary_key_ != INVALID_PK) {
heap_.push(pilot);
}
}
}
void
StreamReducerHelper::RefreshSearchResult() {
//1. refresh new input results
for (int i = 0; i < num_segments_; i++) {
std::vector<int64_t> real_topKs(total_nq_, 0);
auto search_result = search_results_to_merge_[i];
if (search_result->result_offsets_.size() > 0) {
uint32_t final_size = 0;
for (int j = 0; j < total_nq_; j++) {
final_size += final_search_records_[i][j].size();
}
std::vector<milvus::PkType> reduced_pks(final_size);
std::vector<float> reduced_distances(final_size);
std::vector<int64_t> reduced_seg_offsets(final_size);
std::vector<GroupByValueType> reduced_group_by_values(final_size);
uint32_t final_index = 0;
for (int j = 0; j < total_nq_; j++) {
for (auto offset : final_search_records_[i][j]) {
reduced_pks[final_index] =
search_result->primary_keys_[offset];
reduced_distances[final_index] =
search_result->distances_[offset];
reduced_seg_offsets[final_index] =
search_result->seg_offsets_[offset];
if (search_result->group_by_values_.has_value())
reduced_group_by_values[final_index] =
search_result->group_by_values_.value()[offset];
final_index++;
real_topKs[j]++;
}
}
search_result->primary_keys_.swap(reduced_pks);
search_result->distances_.swap(reduced_distances);
search_result->seg_offsets_.swap(reduced_seg_offsets);
if (search_result->group_by_values_.has_value()) {
search_result->group_by_values_.value().swap(
reduced_group_by_values);
}
}
std::partial_sum(real_topKs.begin(),
real_topKs.end(),
search_result->topk_per_nq_prefix_sum_.begin() + 1);
}
//2. refresh merged search result possibly
if (merged_search_result->has_result_) {
std::vector<int64_t> real_topKs(total_nq_, 0);
if (merged_search_result->reduced_offsets_.size() > 0) {
uint32_t final_size = merged_search_result->reduced_offsets_.size();
std::vector<milvus::PkType> reduced_pks(final_size);
std::vector<float> reduced_distances(final_size);
std::vector<int64_t> reduced_seg_offsets(final_size);
std::vector<GroupByValueType> reduced_group_by_values(final_size);
uint32_t final_index = 0;
for (int j = 0; j < total_nq_; j++) {
for (auto offset : final_search_records_[num_segments_][j]) {
reduced_pks[final_index] =
merged_search_result->primary_keys_[offset];
reduced_distances[final_index] =
merged_search_result->distances_[offset];
if (merged_search_result->group_by_values_.has_value())
reduced_group_by_values[final_index] =
merged_search_result->group_by_values_
.value()[offset];
final_index++;
real_topKs[j]++;
}
}
merged_search_result->primary_keys_.swap(reduced_pks);
merged_search_result->distances_.swap(reduced_distances);
if (merged_search_result->group_by_values_.has_value()) {
merged_search_result->group_by_values_.value().swap(
reduced_group_by_values);
}
}
std::partial_sum(
real_topKs.begin(),
real_topKs.end(),
merged_search_result->topk_per_nq_prefix_sum_.begin() + 1);
}
}
std::vector<char>
StreamReducerHelper::GetSearchResultDataSlice(int slice_index) {
auto nq_begin = slice_nqs_prefix_sum_[slice_index];
auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];
auto search_result_data =
std::make_unique<milvus::proto::schema::SearchResultData>();
// set unify_topK and total_nq
search_result_data->set_top_k(slice_topKs_[slice_index]);
search_result_data->set_num_queries(nq_end - nq_begin);
search_result_data->mutable_topks()->Resize(nq_end - nq_begin, 0);
int64_t result_count = 0;
if (merged_search_result->has_result_) {
AssertInfo(
nq_begin < merged_search_result->topk_per_nq_prefix_sum_.size(),
"nq_begin is incorrect for reduce, nq_begin:{}, topk_size:{}",
nq_begin,
merged_search_result->topk_per_nq_prefix_sum_.size());
AssertInfo(
nq_end < merged_search_result->topk_per_nq_prefix_sum_.size(),
"nq_end is incorrect for reduce, nq_end:{}, topk_size:{}",
nq_end,
merged_search_result->topk_per_nq_prefix_sum_.size());
result_count = merged_search_result->topk_per_nq_prefix_sum_[nq_end] -
merged_search_result->topk_per_nq_prefix_sum_[nq_begin];
}
// `result_pairs` contains the SearchResult and result_offset info, used for filling output fields
std::vector<MergeBase> result_pairs(result_count);
// reserve space for pks
auto primary_field_id =
plan_->schema_.get_primary_field_id().value_or(milvus::FieldId(-1));
AssertInfo(primary_field_id.get() != INVALID_FIELD_ID, "Primary key is -1");
auto pk_type = plan_->schema_[primary_field_id].get_data_type();
switch (pk_type) {
case milvus::DataType::INT64: {
auto ids = std::make_unique<milvus::proto::schema::LongArray>();
ids->mutable_data()->Resize(result_count, 0);
search_result_data->mutable_ids()->set_allocated_int_id(
ids.release());
break;
}
case milvus::DataType::VARCHAR: {
auto ids = std::make_unique<milvus::proto::schema::StringArray>();
std::vector<std::string> string_pks(result_count);
// TODO: prevent mem copy
*ids->mutable_data() = {string_pks.begin(), string_pks.end()};
search_result_data->mutable_ids()->set_allocated_str_id(
ids.release());
break;
}
default: {
PanicInfo(DataTypeInvalid,
fmt::format("unsupported primary key type {}", pk_type));
}
}
// reserve space for distances
search_result_data->mutable_scores()->Resize(result_count, 0);
//reserve space for group_by_values
std::vector<GroupByValueType> group_by_values;
if (plan_->plan_node_->search_info_.group_by_field_id_.has_value()) {
group_by_values.resize(result_count);
}
// fill pks and distances
for (auto qi = nq_begin; qi < nq_end; qi++) {
int64_t topk_count = 0;
AssertInfo(merged_search_result != nullptr,
"null merged search result when reorganize");
if (!merged_search_result->has_result_ ||
merged_search_result->result_offsets_.size() == 0) {
continue;
}
auto topk_start = merged_search_result->topk_per_nq_prefix_sum_[qi];
auto topk_end = merged_search_result->topk_per_nq_prefix_sum_[qi + 1];
topk_count += topk_end - topk_start;
for (auto ki = topk_start; ki < topk_end; ki++) {
auto loc = merged_search_result->result_offsets_[ki];
AssertInfo(loc < result_count && loc >= 0,
"invalid loc when GetSearchResultDataSlice, loc = " +
std::to_string(loc) +
", result_count = " + std::to_string(result_count));
// set result pks
switch (pk_type) {
case milvus::DataType::INT64: {
search_result_data->mutable_ids()
->mutable_int_id()
->mutable_data()
->Set(loc,
std::visit(
Int64PKVisitor{},
merged_search_result->primary_keys_[ki]));
break;
}
case milvus::DataType::VARCHAR: {
*search_result_data->mutable_ids()
->mutable_str_id()
->mutable_data()
->Mutable(loc) =
std::visit(StrPKVisitor{},
merged_search_result->primary_keys_[ki]);
break;
}
default: {
PanicInfo(DataTypeInvalid,
fmt::format("unsupported primary key type {}",
pk_type));
}
}
search_result_data->mutable_scores()->Set(
loc, merged_search_result->distances_[ki]);
// set group by values
if (merged_search_result->group_by_values_.has_value() &&
ki < merged_search_result->group_by_values_.value().size())
group_by_values[loc] =
merged_search_result->group_by_values_.value()[ki];
// set result offset to fill output fields data
result_pairs[loc] = {&merged_search_result->output_fields_data_,
ki};
}
// update result topKs
search_result_data->mutable_topks()->Set(qi - nq_begin, topk_count);
}
AssembleGroupByValues(search_result_data, group_by_values, plan_);
AssertInfo(search_result_data->scores_size() == result_count,
"wrong scores size, size = " +
std::to_string(search_result_data->scores_size()) +
", expected size = " + std::to_string(result_count));
// set output fields
for (auto field_id : plan_->target_entries_) {
auto& field_meta = plan_->schema_[field_id];
auto field_data =
milvus::segcore::MergeDataArray(result_pairs, field_meta);
if (field_meta.get_data_type() == DataType::ARRAY) {
field_data->mutable_scalars()
->mutable_array_data()
->set_element_type(
proto::schema::DataType(field_meta.get_element_type()));
}
search_result_data->mutable_fields_data()->AddAllocated(
field_data.release());
}
// SearchResultData to blob
auto size = search_result_data->ByteSizeLong();
auto buffer = std::vector<char>(size);
search_result_data->SerializePartialToArray(buffer.data(), size);
return buffer;
}
void
StreamReducerHelper::CleanReduceStatus() {
this->final_search_records_.clear();
this->merged_search_result->reduced_offsets_.clear();
}
} // namespace milvus::segcore
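
At its core, StreamReduceSearchResultForOneNQ is a k-way merge with deduplication: each per-segment result list is already sorted by distance, a cursor per list goes into a min-heap, and the global best is popped repeatedly while previously seen primary keys are skipped. A standalone sketch of that core (ascending distances; group-by dedup and the merged-result cursor omitted):

#include <cstdint>
#include <iostream>
#include <queue>
#include <unordered_set>
#include <vector>

struct Hit {
    int64_t pk;
    float distance;
};

// A cursor into one segment's sorted result list.
struct Cursor {
    const std::vector<Hit>* hits;
    size_t pos;
    bool
    operator>(const Cursor& o) const {
        return (*hits)[pos].distance > (*o.hits)[o.pos].distance;
    }
};

// Merge k distance-sorted lists, keep at most topk hits, skip duplicate pks.
std::vector<Hit>
MergeTopK(const std::vector<std::vector<Hit>>& lists, size_t topk) {
    std::priority_queue<Cursor, std::vector<Cursor>, std::greater<>> heap;
    for (const auto& l : lists) {
        if (!l.empty()) {
            heap.push(Cursor{&l, 0});
        }
    }
    std::unordered_set<int64_t> seen;
    std::vector<Hit> out;
    while (out.size() < topk && !heap.empty()) {
        Cursor c = heap.top();
        heap.pop();
        const Hit& h = (*c.hits)[c.pos];
        if (seen.insert(h.pk).second) {  // dedup by pk, as the reducer does
            out.push_back(h);
        }
        if (++c.pos < c.hits->size()) {  // advance, like StreamSearchResultPair::advance
            heap.push(c);
        }
    }
    return out;
}

int main() {
    std::vector<std::vector<Hit>> lists = {{{1, 0.1f}, {2, 0.4f}},
                                           {{1, 0.2f}, {3, 0.3f}}};
    for (const auto& h : MergeTopK(lists, 3)) {
        std::cout << "pk=" << h.pk << " distance=" << h.distance << "\n";
    }
}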

View File

@ -0,0 +1,222 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <queue>
#include <unordered_set>
#include "common/Types.h"
#include "segcore/segment_c.h"
#include "query/PlanImpl.h"
#include "common/QueryResult.h"
#include "segcore/ReduceStructure.h"
#include "common/EasyAssert.h"
namespace milvus::segcore {
class MergedSearchResult {
public:
bool has_result_;
std::vector<PkType> primary_keys_;
std::vector<float> distances_;
std::optional<std::vector<GroupByValueType>> group_by_values_;
// set output fields data when filling target entity
std::map<FieldId, std::unique_ptr<milvus::DataArray>> output_fields_data_;
// used for reduce, filter invalid pk, get real topks count
std::vector<size_t> topk_per_nq_prefix_sum_;
// fill data during reducing search result
std::vector<int64_t> result_offsets_;
std::vector<int64_t> reduced_offsets_;
};
struct StreamSearchResultPair {
milvus::PkType primary_key_;
float distance_;
milvus::SearchResult* search_result_;
MergedSearchResult* merged_result_;
int64_t segment_index_;
int64_t offset_;
int64_t offset_rb_;
std::optional<milvus::GroupByValueType> group_by_value_;
StreamSearchResultPair(milvus::PkType primary_key,
float distance,
SearchResult* result,
int64_t index,
int64_t lb,
int64_t rb)
: StreamSearchResultPair(primary_key,
distance,
result,
nullptr,
index,
lb,
rb,
std::nullopt) {
}
StreamSearchResultPair(
milvus::PkType primary_key,
float distance,
SearchResult* result,
MergedSearchResult* merged_result,
int64_t index,
int64_t lb,
int64_t rb,
std::optional<milvus::GroupByValueType> group_by_value)
: primary_key_(std::move(primary_key)),
distance_(distance),
search_result_(result),
merged_result_(merged_result),
segment_index_(index),
offset_(lb),
offset_rb_(rb),
group_by_value_(group_by_value) {
AssertInfo(
search_result_ != nullptr || merged_result_ != nullptr,
"For a valid StreamSearchResult pair, "
"at least one of merged_result_ or search_result_ is not nullptr");
}
bool
operator>(const StreamSearchResultPair& other) const {
if (std::fabs(distance_ - other.distance_) < 0.0000000119) {
return primary_key_ < other.primary_key_;
}
return distance_ > other.distance_;
}
void
advance() {
offset_++;
if (offset_ < offset_rb_) {
if (search_result_ != nullptr) {
primary_key_ = search_result_->primary_keys_.at(offset_);
distance_ = search_result_->distances_.at(offset_);
if (search_result_->group_by_values_.has_value() &&
offset_ < search_result_->group_by_values_.value().size()) {
group_by_value_ =
search_result_->group_by_values_.value().at(offset_);
}
} else {
primary_key_ = merged_result_->primary_keys_.at(offset_);
distance_ = merged_result_->distances_.at(offset_);
if (merged_result_->group_by_values_.has_value() &&
offset_ < merged_result_->group_by_values_.value().size()) {
group_by_value_ =
merged_result_->group_by_values_.value().at(offset_);
}
}
} else {
primary_key_ = INVALID_PK;
distance_ = std::numeric_limits<float>::min();
}
}
};
struct StreamSearchResultPairComparator {
bool
operator()(const std::shared_ptr<StreamSearchResultPair> lhs,
const std::shared_ptr<StreamSearchResultPair> rhs) const {
return (*rhs.get()) > (*lhs.get());
}
};
class StreamReducerHelper {
public:
explicit StreamReducerHelper(milvus::query::Plan* plan,
int64_t* slice_nqs,
int64_t* slice_topKs,
int64_t slice_num)
: plan_(plan),
slice_nqs_(slice_nqs, slice_nqs + slice_num),
slice_topKs_(slice_topKs, slice_topKs + slice_num) {
AssertInfo(slice_nqs_.size() > 0, "empty_nqs");
AssertInfo(slice_nqs_.size() == slice_topKs_.size(),
"unaligned slice_nqs and slice_topKs");
merged_search_result = std::make_unique<MergedSearchResult>();
merged_search_result->has_result_ = false;
num_slice_ = slice_nqs_.size();
slice_nqs_prefix_sum_.resize(num_slice_ + 1);
std::partial_sum(slice_nqs_.begin(),
slice_nqs_.end(),
slice_nqs_prefix_sum_.begin() + 1);
total_nq_ = slice_nqs_prefix_sum_[num_slice_];
}
void
SetSearchResultsToMerge(std::vector<SearchResult*>& search_results) {
search_results_to_merge_ = search_results;
num_segments_ = search_results_to_merge_.size();
AssertInfo(num_segments_ > 0, "empty search result");
}
public:
void
MergeReduce();
void*
SerializeMergedResult();
protected:
void
FilterSearchResults();
void
InitializeReduceRecords();
void
FillPrimaryKeys();
void
FilterInvalidSearchResult(SearchResult* search_result);
void
ReduceResultData();
private:
void
RefreshSearchResult();
void
StreamReduceSearchResultForOneNQ(int64_t qi, int64_t topK, int64_t& offset);
void
FillEntryData();
void
AssembleMergedResult();
std::vector<char>
GetSearchResultDataSlice(int slice_index);
void
CleanReduceStatus();
std::unique_ptr<MergedSearchResult> merged_search_result;
milvus::query::Plan* plan_;
std::vector<int64_t> slice_nqs_;
std::vector<int64_t> slice_topKs_;
std::vector<SearchResult*> search_results_to_merge_;
int64_t num_segments_{0};
int64_t num_slice_{0};
std::vector<int64_t> slice_nqs_prefix_sum_;
std::priority_queue<std::shared_ptr<StreamSearchResultPair>,
std::vector<std::shared_ptr<StreamSearchResultPair>>,
StreamSearchResultPairComparator>
heap_;
std::unordered_set<milvus::PkType> pk_set_;
std::unordered_set<milvus::GroupByValueType> group_by_val_set_;
std::vector<std::vector<std::vector<int64_t>>> final_search_records_;
int64_t total_nq_{0};
};
} // namespace milvus::segcore
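
The constructor above turns slice_nqs into a prefix-sum array so that slice i covers the half-open query range [prefix[i], prefix[i+1]) and the last entry is total_nq_. The arithmetic in isolation:

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
    std::vector<int64_t> slice_nqs = {2, 3, 1};  // queries per slice
    std::vector<int64_t> prefix(slice_nqs.size() + 1, 0);
    std::partial_sum(slice_nqs.begin(), slice_nqs.end(), prefix.begin() + 1);
    // prefix = {0, 2, 5, 6}; total_nq = prefix.back() = 6
    for (size_t i = 0; i < slice_nqs.size(); ++i) {
        std::cout << "slice " << i << " covers queries [" << prefix[i] << ", "
                  << prefix[i + 1] << ")\n";
    }
}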

View File

@ -530,19 +530,17 @@ CreateDataArrayFrom(const void* data_raw,
// TODO remove merge dataArray, instead fill target entity when get data slice
std::unique_ptr<DataArray>
MergeDataArray(
std::vector<std::pair<milvus::SearchResult*, int64_t>>& result_offsets,
const FieldMeta& field_meta) {
MergeDataArray(std::vector<MergeBase>& merge_bases,
const FieldMeta& field_meta) {
auto data_type = field_meta.get_data_type();
auto data_array = std::make_unique<DataArray>();
data_array->set_field_id(field_meta.get_id().get());
data_array->set_type(static_cast<milvus::proto::schema::DataType>(
field_meta.get_data_type()));
for (auto& result_pair : result_offsets) {
auto src_field_data =
result_pair.first->output_fields_data_[field_meta.get_id()].get();
auto src_offset = result_pair.second;
for (auto& merge_base : merge_bases) {
auto src_field_data = merge_base.get_field_data(field_meta.get_id());
auto src_offset = merge_base.getOffset();
AssertInfo(data_type == DataType(src_field_data->type()),
"merge field data type not consistent");
if (field_meta.is_vector()) {
@ -650,7 +648,6 @@ MergeDataArray(
}
}
}
return data_array;
}

View File

@ -77,10 +77,35 @@ CreateDataArrayFrom(const void* data_raw,
const FieldMeta& field_meta);
// TODO remove merge dataArray, instead fill target entity when get data slice
struct MergeBase {
private:
std::map<FieldId, std::unique_ptr<milvus::DataArray>>* output_fields_data_;
size_t offset_;
public:
MergeBase() {
}
MergeBase(std::map<FieldId, std::unique_ptr<milvus::DataArray>>*
output_fields_data,
size_t offset)
: output_fields_data_(output_fields_data), offset_(offset) {
}
size_t
getOffset() const {
return offset_;
}
milvus::DataArray*
get_field_data(FieldId fieldId) const {
return (*output_fields_data_)[fieldId].get();
}
};
std::unique_ptr<DataArray>
MergeDataArray(
std::vector<std::pair<milvus::SearchResult*, int64_t>>& result_offsets,
const FieldMeta& field_meta);
MergeDataArray(std::vector<MergeBase>& merge_bases,
const FieldMeta& field_meta);
template <bool is_sealed>
std::shared_ptr<DeletedRecord::TmpBitmap>

View File

@ -24,7 +24,7 @@ struct Int64PKVisitor {
};
template <>
int64_t
inline int64_t
Int64PKVisitor::operator()<int64_t>(int64_t t) const {
return t;
}
@ -38,7 +38,7 @@ struct StrPKVisitor {
};
template <>
std::string
inline std::string
StrPKVisitor::operator()<std::string>(std::string t) const {
return t;
}
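
The inline keyword added here matters because these are full specializations defined in a header: unlike the primary template, a full specialization is an ordinary function definition, so without inline every translation unit including pkVisitor.h would emit its own copy and break the one-definition rule at link time. A minimal reproduction of the pattern (demo names only):

// pk_visitor_demo.h -- safe to include from many .cpp files
#pragma once
#include <cstdint>
#include <stdexcept>

struct IntVisitor {
    // The primary member template may live in a header: templates are
    // exempt from the one-definition rule across translation units.
    template <typename T>
    int64_t
    operator()(T) const {
        throw std::runtime_error("unsupported primary key type");
    }
};

// A full specialization is an ordinary function definition, so it needs
// 'inline' for the linker to merge the copy each translation unit emits.
template <>
inline int64_t
IntVisitor::operator()<int64_t>(int64_t t) const {
    return t;
}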

View File

@ -15,10 +15,64 @@
#include "common/EasyAssert.h"
#include "query/Plan.h"
#include "segcore/reduce_c.h"
#include "segcore/StreamReduce.h"
#include "segcore/Utils.h"
using SearchResult = milvus::SearchResult;
CStatus
NewStreamReducer(CSearchPlan c_plan,
int64_t* slice_nqs,
int64_t* slice_topKs,
int64_t num_slices,
CSearchStreamReducer* stream_reducer) {
try {
//convert the search plan
auto plan = static_cast<milvus::query::Plan*>(c_plan);
auto stream_reduce_helper =
std::make_unique<milvus::segcore::StreamReducerHelper>(
plan, slice_nqs, slice_topKs, num_slices);
*stream_reducer = stream_reduce_helper.release();
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(&e);
}
}
CStatus
StreamReduce(CSearchStreamReducer c_stream_reducer,
CSearchResult* c_search_results,
int64_t num_segments) {
try {
auto stream_reducer =
static_cast<milvus::segcore::StreamReducerHelper*>(
c_stream_reducer);
std::vector<SearchResult*> search_results(num_segments);
for (int i = 0; i < num_segments; i++) {
search_results[i] = static_cast<SearchResult*>(c_search_results[i]);
}
stream_reducer->SetSearchResultsToMerge(search_results);
stream_reducer->MergeReduce();
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(&e);
}
}
CStatus
GetStreamReduceResult(CSearchStreamReducer c_stream_reducer,
CSearchResultDataBlobs* c_search_result_data_blobs) {
try {
auto stream_reducer =
static_cast<milvus::segcore::StreamReducerHelper*>(
c_stream_reducer);
*c_search_result_data_blobs = stream_reducer->SerializeMergedResult();
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(&e);
}
}
CStatus
ReduceSearchResultsAndFillData(CSearchResultDataBlobs* cSearchResultDataBlobs,
CSearchPlan c_plan,
@ -81,3 +135,13 @@ DeleteSearchResultDataBlobs(CSearchResultDataBlobs cSearchResultDataBlobs) {
cSearchResultDataBlobs);
delete search_result_data_blobs;
}
void
DeleteStreamSearchReducer(CSearchStreamReducer c_stream_reducer) {
if (c_stream_reducer == nullptr) {
return;
}
auto stream_reducer =
static_cast<milvus::segcore::StreamReducerHelper*>(c_stream_reducer);
delete stream_reducer;
}
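
Taken together, the new C API suggests this lifecycle: create a reducer for a fixed set of result slices, feed it successive batches of per-segment search results, then serialize the merged state once. A hedged usage sketch (CStatus checks elided; c_plan and the result batches are assumed to come from the existing search C API):

// Sketch only: CStatus return values should be checked in real code.
#include "segcore/reduce_c.h"

void
StreamReduceExample(CSearchPlan c_plan,
                    CSearchResult* first_batch,
                    int64_t first_batch_size,
                    CSearchResult* second_batch,
                    int64_t second_batch_size) {
    int64_t slice_nqs[] = {2, 2};  // hypothetical slicing of the queries
    int64_t slice_topKs[] = {10, 10};

    CSearchStreamReducer reducer = nullptr;
    NewStreamReducer(c_plan, slice_nqs, slice_topKs, 2, &reducer);

    // Results may arrive in waves; each call folds one batch into the merged state.
    StreamReduce(reducer, first_batch, first_batch_size);
    StreamReduce(reducer, second_batch, second_batch_size);

    // Serialize the accumulated result once, then release both objects.
    CSearchResultDataBlobs blobs = nullptr;
    GetStreamReduceResult(reducer, &blobs);

    DeleteSearchResultDataBlobs(blobs);
    DeleteStreamSearchReducer(reducer);
}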

View File

@ -18,6 +18,23 @@ extern "C" {
#include "segcore/segment_c.h"
typedef void* CSearchResultDataBlobs;
typedef void* CSearchStreamReducer;
CStatus
NewStreamReducer(CSearchPlan c_plan,
int64_t* slice_nqs,
int64_t* slice_topKs,
int64_t num_slices,
CSearchStreamReducer* stream_reducer);
CStatus
StreamReduce(CSearchStreamReducer c_stream_reducer,
CSearchResult* c_search_results,
int64_t num_segments);
CStatus
GetStreamReduceResult(CSearchStreamReducer c_stream_reducer,
CSearchResultDataBlobs* c_search_result_data_blobs);
CStatus
ReduceSearchResultsAndFillData(CSearchResultDataBlobs* cSearchResultDataBlobs,
@ -36,6 +53,9 @@ GetSearchResultDataBlob(CProto* searchResultDataBlob,
void
DeleteSearchResultDataBlobs(CSearchResultDataBlobs cSearchResultDataBlobs);
void
DeleteStreamSearchReducer(CSearchStreamReducer c_stream_reducer);
#ifdef __cplusplus
}
#endif

View File

@ -476,3 +476,10 @@ WarmupChunkCache(CSegmentInterface c_segment, int64_t field_id) {
return milvus::FailureCStatus(milvus::UnexpectedError, e.what());
}
}
void
RemoveFieldFile(CSegmentInterface c_segment, int64_t field_id) {
auto segment =
reinterpret_cast<milvus::segcore::SegmentSealedImpl*>(c_segment);
segment->RemoveFieldFile(milvus::FieldId(field_id));
}

View File

@ -156,6 +156,9 @@ Delete(CSegmentInterface c_segment,
const uint64_t ids_size,
const uint64_t* timestamps);
void
RemoveFieldFile(CSegmentInterface c_segment, int64_t field_id);
#ifdef __cplusplus
}
#endif

View File

@ -103,17 +103,28 @@ DeserializeFileData(const std::shared_ptr<uint8_t[]> input_data,
int64_t length) {
auto binlog_reader = std::make_shared<BinlogReader>(input_data, length);
auto medium_type = ReadMediumType(binlog_reader);
auto start_deserialize = std::chrono::system_clock::now();
std::unique_ptr<DataCodec> res;
switch (medium_type) {
case StorageType::Remote: {
return DeserializeRemoteFileData(binlog_reader);
res = DeserializeRemoteFileData(binlog_reader);
break;
}
case StorageType::LocalDisk: {
return DeserializeLocalFileData(binlog_reader);
res = DeserializeLocalFileData(binlog_reader);
break;
}
default:
PanicInfo(DataFormatBroken,
fmt::format("unsupported medium type {}", medium_type));
}
auto deserialize_duration =
std::chrono::system_clock::now() - start_deserialize;
LOG_INFO("DeserializeFileData_deserialize_duration_ms:{}",
std::chrono::duration_cast<std::chrono::milliseconds>(
deserialize_duration)
.count());
return res;
}
} // namespace milvus::storage

View File

@ -148,6 +148,24 @@ DEFINE_PROMETHEUS_COUNTER(internal_storage_op_count_remove_fail,
internal_storage_op_count,
removeFailMap)
//load metrics
std::map<std::string, std::string> downloadDurationLabels{{"type", "download"}};
std::map<std::string, std::string> writeDiskDurationLabels{
{"type", "write_disk"}};
std::map<std::string, std::string> deserializeDurationLabels{
{"type", "deserialize"}};
DEFINE_PROMETHEUS_HISTOGRAM_FAMILY(internal_storage_load_duration,
"[cpp]durations of load segment")
DEFINE_PROMETHEUS_HISTOGRAM(internal_storage_download_duration,
internal_storage_load_duration,
downloadDurationLabels)
DEFINE_PROMETHEUS_HISTOGRAM(internal_storage_write_disk_duration,
internal_storage_load_duration,
writeDiskDurationLabels)
DEFINE_PROMETHEUS_HISTOGRAM(internal_storage_deserialize_duration,
internal_storage_load_duration,
deserializeDurationLabels)
// mmap metrics
std::map<std::string, std::string> mmapAllocatedSpaceAnonLabel = {
{"type", "anon"}};

View File

@ -115,6 +115,11 @@ DECLARE_PROMETHEUS_COUNTER(internal_storage_op_count_list_fail);
DECLARE_PROMETHEUS_COUNTER(internal_storage_op_count_remove_suc);
DECLARE_PROMETHEUS_COUNTER(internal_storage_op_count_remove_fail);
DECLARE_PROMETHEUS_HISTOGRAM_FAMILY(internal_storage_load_duration);
DECLARE_PROMETHEUS_HISTOGRAM(internal_storage_download_duration);
DECLARE_PROMETHEUS_HISTOGRAM(internal_storage_write_disk_duration);
DECLARE_PROMETHEUS_HISTOGRAM(internal_storage_deserialize_duration);
// mmap metrics
DECLARE_PROMETHEUS_HISTOGRAM_FAMILY(internal_mmap_allocated_space_bytes);
DECLARE_PROMETHEUS_HISTOGRAM(internal_mmap_allocated_space_bytes_anon);
@ -122,4 +127,5 @@ DECLARE_PROMETHEUS_HISTOGRAM(internal_mmap_allocated_space_bytes_file);
DECLARE_PROMETHEUS_GAUGE_FAMILY(internal_mmap_in_used_space_bytes);
DECLARE_PROMETHEUS_GAUGE(internal_mmap_in_used_space_bytes_anon);
DECLARE_PROMETHEUS_GAUGE(internal_mmap_in_used_space_bytes_file);
} // namespace milvus::storage
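Taken together, the metric changes above follow a labeled-family pattern: one internal_storage_load_duration histogram family, with a "type" label selecting the download, write-disk, or deserialize phase of a segment load. A minimal sketch of the same pattern in Go, using the Prometheus Go client rather than the C++ macros in this PR — the names and buckets below are illustrative, not the ones this PR defines:

package metrics

import (
    "time"

    "github.com/prometheus/client_golang/prometheus"
)

// One histogram family for all load phases; the "type" label picks the phase.
// Buckets are placeholders, not the ones milvus uses.
var loadDuration = prometheus.NewHistogramVec(
    prometheus.HistogramOpts{
        Name:    "internal_storage_load_duration",
        Help:    "durations of load segment",
        Buckets: prometheus.ExponentialBuckets(1, 2, 12), // milliseconds
    },
    []string{"type"},
)

func init() {
    prometheus.MustRegister(loadDuration)
}

// timePhase records how long fn took under the given phase label,
// mirroring the start/duration bookkeeping added around the download,
// write-disk and deserialize steps in the C++ changes.
func timePhase(phase string, fn func()) {
    start := time.Now()
    fn()
    loadDuration.WithLabelValues(phase).
        Observe(float64(time.Since(start).Milliseconds()))
}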

View File

@ -26,6 +26,7 @@ set(MILVUS_TEST_FILES
test_concurrent_vector.cpp
test_c_api.cpp
test_expr_materialized_view.cpp
test_c_stream_reduce.cpp
test_expr.cpp
test_float16.cpp
test_growing.cpp

View File

@ -43,6 +43,7 @@
#include "plan/PlanNode.h"
#include "exec/expression/Expr.h"
#include "segcore/load_index_c.h"
#include "test_utils/c_api_test_utils.h"
namespace chrono = std::chrono;
@ -59,16 +60,6 @@ namespace {
const int64_t ROW_COUNT = 10 * 1000;
const int64_t BIAS = 4200;
CStatus
CSearch(CSegmentInterface c_segment,
CSearchPlan c_plan,
CPlaceholderGroup c_placeholder_group,
uint64_t timestamp,
CSearchResult* result) {
return Search(
{}, c_segment, c_plan, c_placeholder_group, timestamp, result);
}
CStatus
CRetrieve(CSegmentInterface c_segment,
CRetrievePlan c_plan,
@ -83,32 +74,6 @@ CRetrieve(CSegmentInterface c_segment,
false);
}
const char*
get_default_schema_config() {
static std::string conf = R"(name: "default-collection"
fields: <
fieldID: 100
name: "fakevec"
data_type: FloatVector
type_params: <
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
fieldID: 101
name: "age"
data_type: Int64
is_primary_key: true
>)";
static std::string fake_conf = "";
return conf.c_str();
}
const char*
get_float16_schema_config() {
static std::string conf = R"(name: "float16-collection"
@ -212,52 +177,6 @@ generate_data(int N) {
}
return std::make_tuple(raw_data, timestamps, uids);
}
std::string
generate_max_float_query_data(int all_nq, int max_float_nq) {
assert(max_float_nq <= all_nq);
namespace ser = milvus::proto::common;
int dim = DIM;
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::FloatVector);
for (int i = 0; i < all_nq; ++i) {
std::vector<float> vec;
if (i < max_float_nq) {
for (int d = 0; d < dim; ++d) {
vec.push_back(std::numeric_limits<float>::max());
}
} else {
for (int d = 0; d < dim; ++d) {
vec.push_back(1);
}
}
value->add_values(vec.data(), vec.size() * sizeof(float));
}
auto blob = raw_group.SerializeAsString();
return blob;
}
std::string
generate_query_data(int nq) {
namespace ser = milvus::proto::common;
std::default_random_engine e(67);
int dim = DIM;
std::normal_distribution<double> dis(0.0, 1.0);
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::FloatVector);
for (int i = 0; i < nq; ++i) {
std::vector<float> vec;
for (int d = 0; d < dim; ++d) {
vec.push_back(dis(e));
}
value->add_values(vec.data(), vec.size() * sizeof(float));
}
auto blob = raw_group.SerializeAsString();
return blob;
}
std::string
generate_query_data_float16(int nq) {
@ -300,7 +219,7 @@ generate_query_data_bfloat16(int nq) {
auto blob = raw_group.SerializeAsString();
return blob;
}
// create an enum covering schema::DataType::BinaryVector and schema::DataType::FloatVector
enum VectorType {
BinaryVector = 0,
FloatVector = 1,
@ -1558,27 +1477,6 @@ TEST(CApiTest, GetRealCount) {
DeleteSegment(segment);
}
void
CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) {
auto nq = ((SearchResult*)results[0])->total_nq_;
std::unordered_set<PkType> pk_set;
for (int qi = 0; qi < nq; qi++) {
pk_set.clear();
for (size_t i = 0; i < results.size(); i++) {
auto search_result = (SearchResult*)results[i];
ASSERT_EQ(nq, search_result->total_nq_);
auto topk_beg = search_result->topk_per_nq_prefix_sum_[qi];
auto topk_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
for (size_t ki = topk_beg; ki < topk_end; ki++) {
ASSERT_NE(search_result->seg_offsets_[ki], INVALID_SEG_OFFSET);
auto ret = pk_set.insert(search_result->primary_keys_[ki]);
ASSERT_TRUE(ret.second);
}
}
}
}
TEST(CApiTest, ReduceNullResult) {
auto collection = NewCollection(get_default_schema_config());
CSegmentInterface segment;

View File

@ -0,0 +1,324 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <gtest/gtest.h>
#include "test_utils/DataGen.h"
#include "test_utils/c_api_test_utils.h"
TEST(CApiTest, StreamReduce) {
int N = 300;
int topK = 100;
int num_queries = 2;
auto collection = NewCollection(get_default_schema_config());
//1. set up segments
CSegmentInterface segment_1;
auto status = NewSegment(collection, Growing, -1, &segment_1);
ASSERT_EQ(status.error_code, Success);
CSegmentInterface segment_2;
status = NewSegment(collection, Growing, -1, &segment_2);
ASSERT_EQ(status.error_code, Success);
//2. insert data into segments
auto schema = ((milvus::segcore::Collection*)collection)->get_schema();
auto dataset_1 = DataGen(schema, N, 55, 0, 1, 10, true);
int64_t offset_1;
PreInsert(segment_1, N, &offset_1);
auto insert_data_1 = serialize(dataset_1.raw_);
auto ins_res_1 = Insert(segment_1,
offset_1,
N,
dataset_1.row_ids_.data(),
dataset_1.timestamps_.data(),
insert_data_1.data(),
insert_data_1.size());
ASSERT_EQ(ins_res_1.error_code, Success);
auto dataset_2 = DataGen(schema, N, 66, 0, 1, 10, true);
int64_t offset_2;
PreInsert(segment_2, N, &offset_2);
auto insert_data_2 = serialize(dataset_2.raw_);
auto ins_res_2 = Insert(segment_2,
offset_2,
N,
dataset_2.row_ids_.data(),
dataset_2.timestamps_.data(),
insert_data_2.data(),
insert_data_2.size());
ASSERT_EQ(ins_res_2.error_code, Success);
//3. search two segments
auto fmt = boost::format(R"(vector_anns: <
field_id: 100
query_info: <
topk: %1%
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0">
output_field_ids: 100)") %
topK;
auto serialized_expr_plan = fmt.str();
auto blob = generate_query_data(num_queries);
void* plan = nullptr;
auto binary_plan =
translate_text_plan_to_binary_plan(serialized_expr_plan.data());
status = CreateSearchPlanByExpr(
collection, binary_plan.data(), binary_plan.size(), &plan);
ASSERT_EQ(status.error_code, Success);
void* placeholderGroup = nullptr;
status = ParsePlaceholderGroup(
plan, blob.data(), blob.length(), &placeholderGroup);
ASSERT_EQ(status.error_code, Success);
std::vector<CPlaceholderGroup> placeholderGroups;
placeholderGroups.push_back(placeholderGroup);
dataset_1.timestamps_.clear();
dataset_1.timestamps_.push_back(1);
dataset_2.timestamps_.clear();
dataset_2.timestamps_.push_back(1);
CSearchResult res1;
CSearchResult res2;
auto stats1 = CSearch(
segment_1, plan, placeholderGroup, dataset_1.timestamps_[N - 1], &res1);
ASSERT_EQ(stats1.error_code, Success);
auto stats2 = CSearch(
segment_2, plan, placeholderGroup, dataset_2.timestamps_[N - 1], &res2);
ASSERT_EQ(stats2.error_code, Success);
//4. stream reduce two search results
auto slice_nqs = std::vector<int64_t>{num_queries / 2, num_queries / 2};
if (num_queries == 1) {
slice_nqs = std::vector<int64_t>{num_queries};
}
auto slice_topKs = std::vector<int64_t>{topK, topK};
if (topK == 1) {
slice_topKs = std::vector<int64_t>{topK, topK};
}
//5. set up stream reducer
CSearchStreamReducer c_search_stream_reducer;
NewStreamReducer(plan,
slice_nqs.data(),
slice_topKs.data(),
slice_nqs.size(),
&c_search_stream_reducer);
StreamReduce(c_search_stream_reducer, &res1, 1);
StreamReduce(c_search_stream_reducer, &res2, 1);
CSearchResultDataBlobs c_search_result_data_blobs;
GetStreamReduceResult(c_search_stream_reducer, &c_search_result_data_blobs);
SearchResultDataBlobs* search_result_data_blob =
(SearchResultDataBlobs*)(c_search_result_data_blobs);
//6. check
for (size_t i = 0; i < slice_nqs.size(); i++) {
milvus::proto::schema::SearchResultData search_result_data;
auto suc = search_result_data.ParseFromArray(
search_result_data_blob->blobs[i].data(),
search_result_data_blob->blobs[i].size());
ASSERT_TRUE(suc);
ASSERT_EQ(search_result_data.num_queries(), slice_nqs[i]);
ASSERT_EQ(search_result_data.top_k(), slice_topKs[i]);
ASSERT_EQ(search_result_data.ids().int_id().data_size(),
search_result_data.topks().at(0) * slice_nqs[i]);
ASSERT_EQ(search_result_data.scores().size(),
search_result_data.topks().at(0) * slice_nqs[i]);
ASSERT_EQ(search_result_data.topks().size(), slice_nqs[i]);
for (auto real_topk : search_result_data.topks()) {
ASSERT_LE(real_topk, slice_topKs[i]);
}
}
DeleteSearchResultDataBlobs(c_search_result_data_blobs);
DeleteSearchPlan(plan);
DeletePlaceholderGroup(placeholderGroup);
DeleteSearchResult(res1);
DeleteSearchResult(res2);
DeleteCollection(collection);
DeleteSegment(segment_1);
DeleteSegment(segment_2);
DeleteStreamSearchReducer(c_search_stream_reducer);
DeleteStreamSearchReducer(nullptr);
}
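The test above exercises the reducer incrementally: each CSearchResult batch is folded in via StreamReduce, and only GetStreamReduceResult materializes the merged blobs, so memory stays bounded by nq x topK rather than by the number of contributing segments. The core merge step — folding a batch into a running per-query top-k while deduplicating primary keys — can be sketched in Go (a toy illustration under assumed types, not the segcore implementation):

package streamreduce

import "sort"

// hit is one search hit; under an L2 metric a smaller score is better.
type hit struct {
    pk    int64
    score float32
}

// mergeTopK folds one more batch of hits into the running top-k for a
// single query slice: duplicate primary keys are dropped (mirroring the
// pk_set checks in the tests) and the result is trimmed back to k.
func mergeTopK(acc, batch []hit, k int) []hit {
    seen := make(map[int64]struct{}, len(acc)+len(batch))
    merged := make([]hit, 0, len(acc)+len(batch))
    for _, part := range [][]hit{acc, batch} {
        for _, h := range part {
            if _, dup := seen[h.pk]; dup {
                continue
            }
            seen[h.pk] = struct{}{}
            merged = append(merged, h)
        }
    }
    sort.Slice(merged, func(i, j int) bool { return merged[i].score < merged[j].score })
    if len(merged) > k {
        merged = merged[:k]
    }
    return merged
}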
TEST(CApiTest, StreamReduceGroupBY) {
int N = 300;
int topK = 100;
int num_queries = 2;
int dim = 16;
namespace schema = milvus::proto::schema;
void* c_collection;
//1. set up schema and collection
{
schema::CollectionSchema collection_schema;
auto pk_field_schema = collection_schema.add_fields();
pk_field_schema->set_name("pk_field");
pk_field_schema->set_fieldid(100);
pk_field_schema->set_data_type(schema::DataType::Int64);
pk_field_schema->set_is_primary_key(true);
auto i8_field_schema = collection_schema.add_fields();
i8_field_schema->set_name("int8_field");
i8_field_schema->set_fieldid(101);
i8_field_schema->set_data_type(schema::DataType::Int8);
i8_field_schema->set_is_primary_key(false);
auto i16_field_schema = collection_schema.add_fields();
i16_field_schema->set_name("int16_field");
i16_field_schema->set_fieldid(102);
i16_field_schema->set_data_type(schema::DataType::Int16);
i16_field_schema->set_is_primary_key(false);
auto i32_field_schema = collection_schema.add_fields();
i32_field_schema->set_name("int32_field");
i32_field_schema->set_fieldid(103);
i32_field_schema->set_data_type(schema::DataType::Int32);
i32_field_schema->set_is_primary_key(false);
auto str_field_schema = collection_schema.add_fields();
str_field_schema->set_name("str_field");
str_field_schema->set_fieldid(104);
str_field_schema->set_data_type(schema::DataType::VarChar);
auto str_type_params = str_field_schema->add_type_params();
str_type_params->set_key(MAX_LENGTH);
str_type_params->set_value(std::to_string(64));
str_field_schema->set_is_primary_key(false);
auto vec_field_schema = collection_schema.add_fields();
vec_field_schema->set_name("fake_vec");
vec_field_schema->set_fieldid(105);
vec_field_schema->set_data_type(schema::DataType::FloatVector);
auto metric_type_param = vec_field_schema->add_index_params();
metric_type_param->set_key("metric_type");
metric_type_param->set_value(knowhere::metric::L2);
auto dim_param = vec_field_schema->add_type_params();
dim_param->set_key("dim");
dim_param->set_value(std::to_string(dim));
c_collection = NewCollection(&collection_schema, knowhere::metric::L2);
}
CSegmentInterface segment;
auto status = NewSegment(c_collection, Growing, -1, &segment);
ASSERT_EQ(status.error_code, Success);
//2. generate data and insert
auto c_schema = ((milvus::segcore::Collection*)c_collection)->get_schema();
auto dataset = DataGen(c_schema, N);
int64_t offset;
PreInsert(segment, N, &offset);
auto insert_data = serialize(dataset.raw_);
auto ins_res = Insert(segment,
offset,
N,
dataset.row_ids_.data(),
dataset.timestamps_.data(),
insert_data.data(),
insert_data.size());
ASSERT_EQ(ins_res.error_code, Success);
//3. search
auto fmt = boost::format(R"(vector_anns: <
field_id: 105
query_info: <
topk: %1%
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
group_by_field_id: 101
>
placeholder_tag: "$0">
output_field_ids: 100)") %
topK;
auto serialized_expr_plan = fmt.str();
auto blob = generate_query_data(num_queries);
void* plan = nullptr;
auto binary_plan =
translate_text_plan_to_binary_plan(serialized_expr_plan.data());
status = CreateSearchPlanByExpr(
c_collection, binary_plan.data(), binary_plan.size(), &plan);
ASSERT_EQ(status.error_code, Success);
void* placeholderGroup = nullptr;
status = ParsePlaceholderGroup(
plan, blob.data(), blob.length(), &placeholderGroup);
ASSERT_EQ(status.error_code, Success);
std::vector<CPlaceholderGroup> placeholderGroups;
placeholderGroups.push_back(placeholderGroup);
dataset.timestamps_.clear();
dataset.timestamps_.push_back(1);
CSearchResult res1;
CSearchResult res2;
auto res = CSearch(
segment, plan, placeholderGroup, dataset.timestamps_[N - 1], &res1);
ASSERT_EQ(res.error_code, Success);
res = CSearch(
segment, plan, placeholderGroup, dataset.timestamps_[N - 1], &res2);
ASSERT_EQ(res.error_code, Success);
//4. set up stream reducer
auto slice_nqs = std::vector<int64_t>{num_queries / 2, num_queries / 2};
if (num_queries == 1) {
slice_nqs = std::vector<int64_t>{num_queries};
}
auto slice_topKs = std::vector<int64_t>{topK, topK};
if (topK == 1) {
slice_topKs = std::vector<int64_t>{topK, topK};
}
CSearchStreamReducer c_search_stream_reducer;
NewStreamReducer(plan,
slice_nqs.data(),
slice_topKs.data(),
slice_nqs.size(),
&c_search_stream_reducer);
//5. stream reduce
StreamReduce(c_search_stream_reducer, &res1, 1);
StreamReduce(c_search_stream_reducer, &res2, 1);
CSearchResultDataBlobs c_search_result_data_blobs;
GetStreamReduceResult(c_search_stream_reducer, &c_search_result_data_blobs);
SearchResultDataBlobs* search_result_data_blob =
(SearchResultDataBlobs*)(c_search_result_data_blobs);
//6. check result
for (size_t i = 0; i < slice_nqs.size(); i++) {
milvus::proto::schema::SearchResultData search_result_data;
auto suc = search_result_data.ParseFromArray(
search_result_data_blob->blobs[i].data(),
search_result_data_blob->blobs[i].size());
ASSERT_TRUE(suc);
ASSERT_EQ(search_result_data.num_queries(), slice_nqs[i]);
ASSERT_EQ(search_result_data.top_k(), slice_topKs[i]);
ASSERT_EQ(search_result_data.ids().int_id().data_size(),
search_result_data.topks().at(0) * slice_nqs[i]);
ASSERT_EQ(search_result_data.scores().size(),
search_result_data.topks().at(0) * slice_nqs[i]);
ASSERT_TRUE(search_result_data.has_group_by_field_value());
// check real topks
ASSERT_EQ(search_result_data.topks().size(), slice_nqs[i]);
for (auto real_topk : search_result_data.topks()) {
ASSERT_LE(real_topk, slice_topKs[i]);
}
}
DeleteSearchResultDataBlobs(c_search_result_data_blobs);
DeleteSearchPlan(plan);
DeletePlaceholderGroup(placeholderGroup);
DeleteSearchResult(res1);
DeleteSearchResult(res2);
DeleteCollection(c_collection);
DeleteSegment(segment);
DeleteStreamSearchReducer(c_search_stream_reducer);
DeleteStreamSearchReducer(nullptr);
}

View File

@ -234,7 +234,8 @@ struct GeneratedData {
uint64_t seed,
uint64_t ts_offset,
int repeat_count,
int array_len);
int array_len,
bool random_pk);
friend GeneratedData
DataGenForJsonArray(SchemaPtr schema,
int64_t N,
@ -292,9 +293,10 @@ inline GeneratedData DataGen(SchemaPtr schema,
uint64_t seed = 42,
uint64_t ts_offset = 0,
int repeat_count = 1,
int array_len = 10) {
int array_len = 10,
bool random_pk = false) {
using std::vector;
std::default_random_engine er(seed);
std::default_random_engine random(seed);
std::normal_distribution<> distr(0, 1);
int offset = 0;
@ -343,7 +345,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
Assert(dim % 8 == 0);
vector<uint8_t> data(dim / 8 * N);
for (auto& x : data) {
x = er();
x = random();
}
insert_cols(data, N, field_meta);
break;
@ -352,7 +354,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
auto dim = field_meta.get_dim();
vector<float16> final(dim * N);
for (auto& x : final) {
x = float16(distr(er) + offset);
x = float16(distr(random) + offset);
}
insert_cols(final, N, field_meta);
break;
@ -371,7 +373,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
auto dim = field_meta.get_dim();
vector<bfloat16> final(dim * N);
for (auto& x : final) {
x = bfloat16(distr(er) + offset);
x = bfloat16(distr(random) + offset);
}
insert_cols(final, N, field_meta);
break;
@ -387,7 +389,12 @@ inline GeneratedData DataGen(SchemaPtr schema,
case DataType::INT64: {
vector<int64_t> data(N);
for (int i = 0; i < N; i++) {
data[i] = i / repeat_count;
if (random_pk && schema->get_primary_field_id()->get() ==
field_id.get()) {
data[i] = random();
} else {
data[i] = i / repeat_count;
}
}
insert_cols(data, N, field_meta);
break;
@ -395,7 +402,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
case DataType::INT32: {
vector<int> data(N);
for (auto& x : data) {
x = er() % (2 * N);
x = random() % (2 * N);
}
insert_cols(data, N, field_meta);
break;
@ -403,7 +410,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
case DataType::INT16: {
vector<int16_t> data(N);
for (auto& x : data) {
x = er() % (2 * N);
x = random() % (2 * N);
}
insert_cols(data, N, field_meta);
break;
@ -411,7 +418,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
case DataType::INT8: {
vector<int8_t> data(N);
for (auto& x : data) {
x = er() % (2 * N);
x = random() % (2 * N);
}
insert_cols(data, N, field_meta);
break;
@ -419,7 +426,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
case DataType::FLOAT: {
vector<float> data(N);
for (auto& x : data) {
x = distr(er);
x = distr(random);
}
insert_cols(data, N, field_meta);
break;
@ -427,7 +434,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
case DataType::DOUBLE: {
vector<double> data(N);
for (auto& x : data) {
x = distr(er);
x = distr(random);
}
insert_cols(data, N, field_meta);
break;
@ -435,7 +442,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
case DataType::VARCHAR: {
vector<std::string> data(N);
for (int i = 0; i < N / repeat_count; i++) {
auto str = std::to_string(er());
auto str = std::to_string(random());
for (int j = 0; j < repeat_count; j++) {
data[i * repeat_count + j] = str;
}
@ -446,11 +453,12 @@ inline GeneratedData DataGen(SchemaPtr schema,
case DataType::JSON: {
vector<std::string> data(N);
for (int i = 0; i < N / repeat_count; i++) {
auto str =
R"({"int":)" + std::to_string(er()) + R"(,"double":)" +
std::to_string(static_cast<double>(er())) +
R"(,"string":")" + std::to_string(er()) +
R"(","bool": true)" + R"(, "array": [1,2,3])" + "}";
auto str = R"({"int":)" + std::to_string(random()) +
R"(,"double":)" +
std::to_string(static_cast<double>(random())) +
R"(,"string":")" + std::to_string(random()) +
R"(","bool": true)" + R"(, "array": [1,2,3])" +
"}";
data[i] = str;
}
insert_cols(data, N, field_meta);
@ -465,7 +473,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
for (int j = 0; j < array_len; j++) {
field_data.mutable_bool_data()->add_data(
static_cast<bool>(er()));
static_cast<bool>(random()));
}
data[i] = field_data;
}
@ -479,7 +487,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
for (int j = 0; j < array_len; j++) {
field_data.mutable_int_data()->add_data(
static_cast<int>(er()));
static_cast<int>(random()));
}
data[i] = field_data;
}
@ -490,7 +498,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
milvus::proto::schema::ScalarField field_data;
for (int j = 0; j < array_len; j++) {
field_data.mutable_long_data()->add_data(
static_cast<int64_t>(er()));
static_cast<int64_t>(random()));
}
data[i] = field_data;
}
@ -503,7 +511,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
for (int j = 0; j < array_len; j++) {
field_data.mutable_string_data()->add_data(
std::to_string(er()));
std::to_string(random()));
}
data[i] = field_data;
}
@ -515,7 +523,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
for (int j = 0; j < array_len; j++) {
field_data.mutable_float_data()->add_data(
static_cast<float>(er()));
static_cast<float>(random()));
}
data[i] = field_data;
}
@ -527,7 +535,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
for (int j = 0; j < array_len; j++) {
field_data.mutable_double_data()->add_data(
static_cast<double>(er()));
static_cast<double>(random()));
}
data[i] = field_data;
}
@ -1226,4 +1234,21 @@ NewCollection(const char* schema_proto_blob,
return (void*)collection.release();
}
inline CCollection
NewCollection(const milvus::proto::schema::CollectionSchema* schema,
MetricType metric_type = knowhere::metric::L2) {
auto collection = std::make_unique<milvus::segcore::Collection>(schema);
milvus::proto::segcore::CollectionIndexMeta col_index_meta;
for (auto field : collection->get_schema()->get_fields()) {
auto field_index_meta = col_index_meta.add_index_metas();
auto index_param = field_index_meta->add_index_params();
index_param->set_key("metric_type");
index_param->set_value(metric_type);
field_index_meta->set_fieldid(field.first.get());
}
collection->set_index_meta(
std::make_shared<CollectionIndexMeta>(col_index_meta));
return (void*)collection.release();
}
} // namespace milvus::segcore

View File

@ -63,6 +63,7 @@ generate_max_float_query_data(int all_nq, int max_float_nq) {
auto blob = raw_group.SerializeAsString();
return blob;
}
std::string
generate_query_data(int nq) {
namespace ser = milvus::proto::common;
@ -102,7 +103,8 @@ CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) {
auto ret = pk_set.insert(search_result->primary_keys_[ki]);
ASSERT_TRUE(ret.second);
if (search_result->group_by_values_.value().size() > ki) {
if (search_result->group_by_values_.has_value() &&
search_result->group_by_values_.value().size() > ki) {
auto group_by_val =
search_result->group_by_values_.value()[ki];
ASSERT_TRUE(group_by_val_set.count(group_by_val) == 0);
@ -112,4 +114,41 @@ CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) {
}
}
}
const char*
get_default_schema_config() {
static std::string conf = R"(name: "default-collection"
fields: <
fieldID: 100
name: "fakevec"
data_type: FloatVector
type_params: <
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
fieldID: 101
name: "age"
data_type: Int64
is_primary_key: true
>)";
static std::string fake_conf = "";
return conf.c_str();
}
CStatus
CSearch(CSegmentInterface c_segment,
CSearchPlan c_plan,
CPlaceholderGroup c_placeholder_group,
uint64_t timestamp,
CSearchResult* result) {
return Search(
{}, c_segment, c_plan, c_placeholder_group, timestamp, result);
}
} // namespace

View File

@ -371,6 +371,9 @@ func (gc *garbageCollector) checkDroppedSegmentGC(segment *SegmentInfo,
) bool {
log := log.With(zap.Int64("segmentID", segment.ID))
if !gc.isExpire(segment.GetDroppedAt()) {
return false
}
isCompacted := childSegment != nil || segment.GetCompacted()
if isCompacted {
// For compact A, B -> C, don't GC A or B if C is not indexed,
@ -382,10 +385,6 @@ func (gc *garbageCollector) checkDroppedSegmentGC(segment *SegmentInfo,
zap.Int64("child segment ID", childSegment.GetID()))
return false
}
} else {
if !gc.isExpire(segment.GetDroppedAt()) {
return false
}
}
segInsertChannel := segment.GetInsertChannel()
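The reordering above makes the drop-timestamp expiry check unconditional: every dropped segment must be expired before GC, and a compacted segment must additionally wait until its compaction target is indexed. Condensed into a standalone predicate (hypothetical helper parameters, Go):

// canGCDroppedSegment condenses the post-change decision order:
// an unexpired dropped segment is never collected, and a compacted
// segment additionally waits until its compaction target is indexed
// (for "compact A, B -> C", keep A and B until C is indexed).
func canGCDroppedSegment(expired, compacted, childIndexed bool) bool {
    if !expired {
        return false // now checked first, compacted or not
    }
    if compacted && !childIndexed {
        return false
    }
    return true
}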

View File

@ -1308,6 +1308,7 @@ func TestGarbageCollector_clearETCD(t *testing.T) {
},
},
},
// cannot be dropped: not yet expired.
segID + 6: {
SegmentInfo: &datapb.SegmentInfo{
ID: segID + 6,
@ -1343,6 +1344,24 @@ func TestGarbageCollector_clearETCD(t *testing.T) {
Compacted: true,
},
},
// can be dropped: expired and compacted
segID + 8: {
SegmentInfo: &datapb.SegmentInfo{
ID: segID + 8,
CollectionID: collID,
PartitionID: partID,
InsertChannel: "dmlChannel",
NumOfRows: 2000,
State: commonpb.SegmentState_Dropped,
MaxRowNum: 65535,
DroppedAt: uint64(time.Now().Add(-7 * 24 * time.Hour).UnixNano()),
CompactionFrom: nil,
DmlPosition: &msgpb.MsgPosition{
Timestamp: 900,
},
Compacted: true,
},
},
} {
m.segments.SetSegment(segID, segment)
}
@ -1390,9 +1409,11 @@ func TestGarbageCollector_clearETCD(t *testing.T) {
segF := gc.meta.GetSegment(segID + 5)
assert.NotNil(t, segF)
segG := gc.meta.GetSegment(segID + 6)
assert.Nil(t, segG)
assert.NotNil(t, segG)
segH := gc.meta.GetSegment(segID + 7)
assert.NotNil(t, segH)
segG = gc.meta.GetSegment(segID + 8)
assert.Nil(t, segG)
err := gc.meta.indexMeta.AddSegmentIndex(&model.SegmentIndex{
SegmentID: segID + 4,
CollectionID: collID,

View File

@ -2750,6 +2750,9 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
}
rateCol.Add(internalpb.RateType_DMLUpsert.String(), float64(it.upsertMsg.InsertMsg.Size()+it.upsertMsg.DeleteMsg.Size()))
if merr.Ok(it.result.GetStatus()) {
metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeUpsert, dbName, username).Add(float64(v))
}
metrics.ProxyFunctionCall.WithLabelValues(nodeID, method,
metrics.SuccessLabel, dbName, collectionName).Inc()
successCnt := it.result.UpsertCnt - int64(len(it.result.ErrIndex))

View File

@ -131,6 +131,7 @@ type shardDelegator struct {
// because growing segment meta is stored in segmentManager/distribution/pkOracle/excludeSegments,
// a lock must be taken before modifying this meta info so that add/remove growing stays atomic
growingSegmentLock sync.RWMutex
partitionStatsMut sync.RWMutex
}
// getLogger returns the zap logger with pre-defined shard attributes.
@ -217,8 +218,12 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest
return nil, err
}
if paramtable.Get().QueryNodeCfg.EnableSegmentPrune.GetAsBool() {
PruneSegments(ctx, sd.partitionStats, req.GetReq(), nil, sd.collection.Schema(), sealed,
PruneInfo{filterRatio: paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
func() {
sd.partitionStatsMut.RLock()
defer sd.partitionStatsMut.RUnlock()
PruneSegments(ctx, sd.partitionStats, req.GetReq(), nil, sd.collection.Schema(), sealed,
PruneInfo{filterRatio: paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
}()
}
tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest)
@ -477,7 +482,11 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
}
if paramtable.Get().QueryNodeCfg.EnableSegmentPrune.GetAsBool() {
PruneSegments(ctx, sd.partitionStats, nil, req.GetReq(), sd.collection.Schema(), sealed, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
func() {
sd.partitionStatsMut.RLock()
defer sd.partitionStatsMut.RUnlock()
PruneSegments(ctx, sd.partitionStats, nil, req.GetReq(), sd.collection.Schema(), sealed, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
}()
}
sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) })
@ -804,7 +813,14 @@ func (sd *shardDelegator) maybeReloadPartitionStats(ctx context.Context, partIDs
log.Info("failed to find valid partition stats file for partition", zap.Int64("partitionID", partID))
continue
}
partStats, exists := sd.partitionStats[partID]
var partStats *storage.PartitionStatsSnapshot
var exists bool
func() {
sd.partitionStatsMut.RLock()
defer sd.partitionStatsMut.RUnlock()
partStats, exists = sd.partitionStats[partID]
}()
if !exists || (exists && partStats.GetVersion() < maxVersion) {
statsBytes, err := sd.chunkManager.Read(ctx, maxVersionFilePath)
if err != nil {
@ -816,8 +832,13 @@ func (sd *shardDelegator) maybeReloadPartitionStats(ctx context.Context, partIDs
log.Error("failed to parse partition stats from bytes", zap.Int("bytes_length", len(statsBytes)))
continue
}
sd.partitionStats[partID] = partStats
partStats.SetVersion(maxVersion)
func() {
sd.partitionStatsMut.Lock()
defer sd.partitionStatsMut.Unlock()
sd.partitionStats[partID] = partStats
}()
log.Info("Updated partitionStats for partition", zap.Int64("partitionID", partID))
}
}
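All three touched call sites guard sd.partitionStats the same way: the map access is wrapped in an immediately invoked closure so the deferred unlock fires as soon as the closure returns, instead of pinning the lock across the surrounding I/O. The idiom in isolation (a sketch with illustrative names, not the delegator's actual types):

package delegator

import "sync"

// partStats stands in for storage.PartitionStatsSnapshot.
type partStats struct{ version int64 }

type delegatorLike struct {
    mu             sync.RWMutex
    partitionStats map[int64]*partStats
}

// maybeReload sketches the locking pattern: read under RLock inside a
// closure, do the slow work (reads, parsing) with the lock released,
// then write back under the write lock in another closure.
func (d *delegatorLike) maybeReload(partID, maxVersion int64, fetch func() *partStats) {
    var cur *partStats
    var exists bool
    func() { // scoped read lock
        d.mu.RLock()
        defer d.mu.RUnlock()
        cur, exists = d.partitionStats[partID]
    }()
    if exists && cur.version >= maxVersion {
        return // already up to date
    }
    fresh := fetch() // slow path runs without holding the lock
    func() { // scoped write lock
        d.mu.Lock()
        defer d.mu.Unlock()
        d.partitionStats[partID] = fresh
    }()
}

Scoping the lock this way keeps chunkManager reads and stats parsing off the critical section while still making the version check and the map update individually consistent.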

View File

@ -138,14 +138,14 @@ func (sd *shardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) {
if ok := sd.VerifyExcludedSegments(segmentID, typeutil.MaxTimestamp); !ok {
log.Warn("try to insert data into released segment, skip it", zap.Int64("segmentID", segmentID))
sd.growingSegmentLock.Unlock()
growing.Release()
growing.Release(context.Background())
continue
}
if !sd.pkOracle.Exists(growing, paramtable.GetNodeID()) {
// register created growing segment after insert to avoid adding an empty growing segment to the delegator
sd.pkOracle.Register(growing, paramtable.GetNodeID())
sd.segmentManager.Put(segments.SegmentTypeGrowing, growing)
sd.segmentManager.Put(context.Background(), segments.SegmentTypeGrowing, growing)
sd.addGrowing(SegmentEntry{
NodeID: paramtable.GetNodeID(),
SegmentID: segmentID,
@ -378,7 +378,7 @@ func (sd *shardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.Segm
// clear loaded growing segments
for _, segment := range loaded {
segment.Release()
segment.Release(ctx)
}
return err
}

View File

@ -630,13 +630,13 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
growing0.EXPECT().ID().Return(1)
growing0.EXPECT().Partition().Return(10)
growing0.EXPECT().Type().Return(segments.SegmentTypeGrowing)
growing0.EXPECT().Release()
growing0.EXPECT().Release(context.Background())
growing1 := segments.NewMockSegment(s.T())
growing1.EXPECT().ID().Return(2)
growing1.EXPECT().Partition().Return(10)
growing1.EXPECT().Type().Return(segments.SegmentTypeGrowing)
growing1.EXPECT().Release()
growing1.EXPECT().Release(context.Background())
mockErr := merr.WrapErrServiceInternal("mock")
@ -1068,7 +1068,7 @@ func (s *DelegatorDataSuite) TestSyncTargetVersion() {
ms.EXPECT().Indexes().Return(nil)
ms.EXPECT().Shard().Return(s.vchannelName)
ms.EXPECT().Level().Return(datapb.SegmentLevel_L1)
s.manager.Segment.Put(segments.SegmentTypeGrowing, ms)
s.manager.Segment.Put(context.Background(), segments.SegmentTypeGrowing, ms)
}
s.delegator.SyncTargetVersion(int64(5), []int64{1}, []int64{2}, []int64{3, 4}, &msgpb.MsgPosition{})

View File

@ -28,6 +28,39 @@ func (_m *MockShardDelegator) EXPECT() *MockShardDelegator_Expecter {
return &MockShardDelegator_Expecter{mock: &_m.Mock}
}
// AddExcludedSegments provides a mock function with given fields: excludeInfo
func (_m *MockShardDelegator) AddExcludedSegments(excludeInfo map[int64]uint64) {
_m.Called(excludeInfo)
}
// MockShardDelegator_AddExcludedSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddExcludedSegments'
type MockShardDelegator_AddExcludedSegments_Call struct {
*mock.Call
}
// AddExcludedSegments is a helper method to define mock.On call
// - excludeInfo map[int64]uint64
func (_e *MockShardDelegator_Expecter) AddExcludedSegments(excludeInfo interface{}) *MockShardDelegator_AddExcludedSegments_Call {
return &MockShardDelegator_AddExcludedSegments_Call{Call: _e.mock.On("AddExcludedSegments", excludeInfo)}
}
func (_c *MockShardDelegator_AddExcludedSegments_Call) Run(run func(excludeInfo map[int64]uint64)) *MockShardDelegator_AddExcludedSegments_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(map[int64]uint64))
})
return _c
}
func (_c *MockShardDelegator_AddExcludedSegments_Call) Return() *MockShardDelegator_AddExcludedSegments_Call {
_c.Call.Return()
return _c
}
func (_c *MockShardDelegator_AddExcludedSegments_Call) RunAndReturn(run func(map[int64]uint64)) *MockShardDelegator_AddExcludedSegments_Call {
_c.Call.Return(run)
return _c
}
// Close provides a mock function with given fields:
func (_m *MockShardDelegator) Close() {
_m.Called()
@ -253,39 +286,6 @@ func (_c *MockShardDelegator_GetTargetVersion_Call) RunAndReturn(run func() int6
return _c
}
// AddExcludedSegments provides a mock function with given fields: excludeInfo
func (_m *MockShardDelegator) AddExcludedSegments(excludeInfo map[int64]uint64) {
_m.Called(excludeInfo)
}
// MockShardDelegator_AddExcludedSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddExcludedSegments'
type MockShardDelegator_AddExcludedSegments_Call struct {
*mock.Call
}
// AddExcludedSegments is a helper method to define mock.On call
// - excludeInfo map[int64]uint64
func (_e *MockShardDelegator_Expecter) AddExcludedSegments(excludeInfo interface{}) *MockShardDelegator_AddExcludedSegments_Call {
return &MockShardDelegator_AddExcludedSegments_Call{Call: _e.mock.On("AddExcludedSegments", excludeInfo)}
}
func (_c *MockShardDelegator_AddExcludedSegments_Call) Run(run func(excludeInfo map[int64]uint64)) *MockShardDelegator_AddExcludedSegments_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(map[int64]uint64))
})
return _c
}
func (_c *MockShardDelegator_AddExcludedSegments_Call) Return() *MockShardDelegator_AddExcludedSegments_Call {
_c.Call.Return()
return _c
}
func (_c *MockShardDelegator_AddExcludedSegments_Call) RunAndReturn(run func(map[int64]uint64)) *MockShardDelegator_AddExcludedSegments_Call {
_c.Call.Return(run)
return _c
}
// LoadGrowing provides a mock function with given fields: ctx, infos, version
func (_m *MockShardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error {
ret := _m.Called(ctx, infos, version)

View File

@ -163,6 +163,12 @@ func (node *QueryNode) loadIndex(ctx context.Context, req *querypb.LoadSegmentsR
continue
}
if localSegment.IsLazyLoad() {
localSegment.SetLoadInfo(info)
localSegment.SetNeedUpdatedVersion(req.GetVersion())
node.manager.DiskCache.MarkItemNeedReload(ctx, localSegment.ID())
return nil
}
err := node.loader.LoadIndex(ctx, localSegment, info, req.Version)
if err != nil {
log.Warn("failed to load index", zap.Error(err))
@ -425,11 +431,11 @@ func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.Ge
results, readSegments, err = segments.StatisticStreaming(ctx, node.manager, req.Req.GetCollectionID(), req.Req.GetPartitionIDs(), req.GetSegmentIDs())
}
defer node.manager.Segment.Unpin(readSegments)
if err != nil {
log.Warn("get segments statistics failed", zap.Error(err))
return nil, err
}
defer node.manager.Segment.Unpin(readSegments)
return segmentStatsResponse(results), nil
}

View File

@ -47,11 +47,11 @@ func HandleCStatus(ctx context.Context, status *C.CStatus, extraInfo string, fie
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
log.Ctx(ctx).With(fields...).
log := log.Ctx(ctx).With(fields...).
WithOptions(zap.AddCallerSkip(1)) // Add caller stack to show HandleCStatus caller
err := merr.SegcoreError(int32(errorCode), errorMsg)
log.Warn("CStatus returns err", zap.Error(err))
log.Warn("CStatus returns err", zap.Error(err), zap.String("extra", extraInfo))
return err
}

View File

@ -148,13 +148,6 @@ func IncreaseVersion(version int64) SegmentAction {
}
}
type actionType int32
const (
removeAction actionType = iota
addAction
)
type Manager struct {
Collection CollectionManager
Segment SegmentManager
@ -173,18 +166,19 @@ func NewManager() *Manager {
}
manager.DiskCache = cache.NewCacheBuilder[int64, Segment]().WithLazyScavenger(func(key int64) int64 {
return int64(segMgr.sealedSegments[key].ResourceUsageEstimate().DiskSize)
}, diskCap).WithLoader(func(key int64) (Segment, bool) {
log.Debug("cache missed segment", zap.Int64("segmentID", key))
segMgr.mu.RLock()
defer segMgr.mu.RUnlock()
segment, ok := segMgr.sealedSegments[key]
if !ok {
// the segment has been released, just ignore it
return nil, false
segment := segMgr.GetWithType(key, SegmentTypeSealed)
if segment == nil {
return 0
}
return int64(segment.ResourceUsageEstimate().DiskSize)
}, diskCap).WithLoader(func(ctx context.Context, key int64) (Segment, error) {
log.Debug("cache missed segment", zap.Int64("segmentID", key))
segment := segMgr.GetWithType(key, SegmentTypeSealed)
if segment == nil {
// the segment has been released, just ignore it
log.Warn("segment is not found when loading", zap.Int64("segmentID", key))
return nil, merr.ErrSegmentNotFound
}
info := segment.LoadInfo()
_, err, _ := sf.Do(fmt.Sprint(segment.ID()), func() (nop interface{}, err error) {
cacheLoadRecord := metricsutil.NewCacheLoadRecord(getSegmentMetricLabel(segment))
@ -197,23 +191,50 @@ func NewManager() *Manager {
if collection == nil {
return nil, merr.WrapErrCollectionNotLoaded(segment.Collection(), "failed to load segment fields")
}
err = manager.Loader.LoadSegment(context.Background(), segment.(*LocalSegment), info, LoadStatusMapped)
err = manager.Loader.LoadLazySegment(ctx, segment.(*LocalSegment), info)
return nil, err
})
if err != nil {
log.Warn("cache sealed segment failed", zap.Error(err))
return nil, false
return nil, err
}
return segment, true
}).WithFinalizer(func(key int64, segment Segment) error {
log.Debug("evict segment from cache", zap.Int64("segmentID", key))
return segment, nil
}).WithFinalizer(func(ctx context.Context, key int64, segment Segment) error {
log.Ctx(ctx).Debug("evict segment from cache", zap.Int64("segmentID", key))
cacheEvictRecord := metricsutil.NewCacheEvictRecord(getSegmentMetricLabel(segment))
cacheEvictRecord.WithBytes(segment.ResourceUsageEstimate().DiskSize)
defer cacheEvictRecord.Finish(nil)
segment.Release(WithReleaseScope(ReleaseScopeData))
segment.Release(ctx, WithReleaseScope(ReleaseScopeData))
return nil
}).WithReloader(func(ctx context.Context, key int64) (Segment, error) {
segment := segMgr.GetWithType(key, SegmentTypeSealed)
if segment == nil {
// the segment has been released, just ignore it
log.Debug("segment is not found when reloading", zap.Int64("segmentID", key))
return nil, merr.ErrSegmentNotFound
}
localSegment := segment.(*LocalSegment)
err := manager.Loader.LoadIndex(ctx, localSegment, segment.LoadInfo(), segment.NeedUpdatedVersion())
if err != nil {
log.Warn("reload segment failed", zap.Int64("segmentID", key), zap.Error(err))
return nil, merr.ErrSegmentLoadFailed
}
if err := localSegment.RemoveUnusedFieldFiles(); err != nil {
log.Warn("remove unused field files failed", zap.Int64("segmentID", key), zap.Error(err))
return nil, merr.ErrSegmentReduplicate
}
return segment, nil
}).Build()
segMgr.registerReleaseCallback(func(s Segment) {
if s.Type() == SegmentTypeSealed {
manager.DiskCache.Expire(context.Background(), s.ID())
}
})
return manager
}
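This block is the heart of the LRU change: the disk cache is assembled through a fluent builder whose hooks split the segment lifecycle apart — a lazy scavenger prices each entry by estimated disk size, a loader services cache misses by lazily loading the segment, a finalizer releases segment data on eviction, and a reloader re-loads an index when a newer version has been marked. Stripped of milvus specifics, the builder shape can be sketched like this (a naive toy — no LRU ordering, no size accounting, no singleflight):

package cache

import (
    "context"
    "sync"
)

type Loader[K comparable, V any] func(ctx context.Context, key K) (V, error)
type Finalizer[K comparable, V any] func(ctx context.Context, key K, v V) error

// Builder mimics the fluent WithLoader/WithFinalizer style used above.
type Builder[K comparable, V any] struct {
    loader    Loader[K, V]
    finalizer Finalizer[K, V]
}

func NewBuilder[K comparable, V any]() *Builder[K, V] { return &Builder[K, V]{} }

func (b *Builder[K, V]) WithLoader(l Loader[K, V]) *Builder[K, V] { b.loader = l; return b }

func (b *Builder[K, V]) WithFinalizer(f Finalizer[K, V]) *Builder[K, V] { b.finalizer = f; return b }

func (b *Builder[K, V]) Build() *Cache[K, V] {
    return &Cache[K, V]{items: map[K]V{}, loader: b.loader, finalizer: b.finalizer}
}

// Cache is deliberately naive: it only shows where the hooks fire.
type Cache[K comparable, V any] struct {
    mu        sync.Mutex
    items     map[K]V
    loader    Loader[K, V]
    finalizer Finalizer[K, V]
}

// GetOrLoad returns the cached value, invoking the loader on a miss
// (the real cache deduplicates concurrent loads and enforces capacity).
func (c *Cache[K, V]) GetOrLoad(ctx context.Context, key K) (V, error) {
    c.mu.Lock()
    defer c.mu.Unlock()
    if v, ok := c.items[key]; ok {
        return v, nil
    }
    v, err := c.loader(ctx, key)
    if err != nil {
        var zero V
        return zero, err
    }
    c.items[key] = v
    return v, nil
}

// Expire evicts a key and runs the finalizer — the release callback
// registered above calls this for sealed segments.
func (c *Cache[K, V]) Expire(ctx context.Context, key K) error {
    c.mu.Lock()
    defer c.mu.Unlock()
    v, ok := c.items[key]
    if !ok {
        return nil
    }
    delete(c.items, key)
    if c.finalizer != nil {
        return c.finalizer(ctx, key, v)
    }
    return nil
}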
@ -225,7 +246,7 @@ type SegmentManager interface {
// Put puts the given segments in,
// and increases the ref count of the corresponding collection,
// duplicate segments will not increase the ref count
Put(segmentType SegmentType, segments ...Segment)
Put(ctx context.Context, segmentType SegmentType, segments ...Segment)
UpdateBy(action SegmentAction, filters ...SegmentFilter) int
Get(segmentID typeutil.UniqueID) Segment
GetWithType(segmentID typeutil.UniqueID, typ SegmentType) Segment
@ -242,9 +263,9 @@ type SegmentManager interface {
// Remove removes the given segment,
// and decreases the ref count of the corresponding collection,
// will not decrease the ref count if the given segment does not exist
Remove(segmentID typeutil.UniqueID, scope querypb.DataScope) (int, int)
RemoveBy(filters ...SegmentFilter) (int, int)
Clear()
Remove(ctx context.Context, segmentID typeutil.UniqueID, scope querypb.DataScope) (int, int)
RemoveBy(ctx context.Context, filters ...SegmentFilter) (int, int)
Clear(ctx context.Context)
}
var _ SegmentManager = (*segmentManager)(nil)
@ -255,6 +276,9 @@ type segmentManager struct {
growingSegments map[typeutil.UniqueID]Segment
sealedSegments map[typeutil.UniqueID]Segment
// releaseCallback is invoked when a segment is released.
releaseCallback func(s Segment)
}
func NewSegmentManager() *segmentManager {
@ -265,7 +289,7 @@ func NewSegmentManager() *segmentManager {
return mgr
}
func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) {
func (mgr *segmentManager) Put(ctx context.Context, segmentType SegmentType, segments ...Segment) {
var replacedSegment []Segment
mgr.mu.Lock()
defer mgr.mu.Unlock()
@ -278,7 +302,7 @@ func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) {
default:
panic("unexpected segment type")
}
log := log.Ctx(ctx)
for _, segment := range segments {
oldSegment, ok := targetMap[segment.ID()]
@ -290,7 +314,7 @@ func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) {
zap.Int64("newVersion", segment.Version()),
)
// delete redundant segment
segment.Release()
segment.Release(ctx)
continue
}
replacedSegment = append(replacedSegment, oldSegment)
@ -313,7 +337,7 @@ func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) {
if len(replacedSegment) > 0 {
go func() {
for _, segment := range replacedSegment {
remove(segment)
mgr.remove(ctx, segment)
}
}()
}
@ -563,7 +587,7 @@ func (mgr *segmentManager) Empty() bool {
// returns true if the segment exists,
// false otherwise
func (mgr *segmentManager) Remove(segmentID typeutil.UniqueID, scope querypb.DataScope) (int, int) {
func (mgr *segmentManager) Remove(ctx context.Context, segmentID typeutil.UniqueID, scope querypb.DataScope) (int, int) {
mgr.mu.Lock()
var removeGrowing, removeSealed int
@ -596,11 +620,11 @@ func (mgr *segmentManager) Remove(segmentID typeutil.UniqueID, scope querypb.Dat
mgr.mu.Unlock()
if growing != nil {
remove(growing)
mgr.remove(ctx, growing)
}
if sealed != nil {
remove(sealed)
mgr.remove(ctx, sealed)
}
return removeGrowing, removeSealed
@ -628,7 +652,7 @@ func (mgr *segmentManager) removeSegmentWithType(typ SegmentType, segmentID type
return nil
}
func (mgr *segmentManager) RemoveBy(filters ...SegmentFilter) (int, int) {
func (mgr *segmentManager) RemoveBy(ctx context.Context, filters ...SegmentFilter) (int, int) {
mgr.mu.Lock()
var removeSegments []Segment
@ -651,28 +675,34 @@ func (mgr *segmentManager) RemoveBy(filters ...SegmentFilter) (int, int) {
mgr.mu.Unlock()
for _, s := range removeSegments {
remove(s)
mgr.remove(ctx, s)
}
return removeGrowing, removeSealed
}
func (mgr *segmentManager) Clear() {
func (mgr *segmentManager) Clear(ctx context.Context) {
mgr.mu.Lock()
defer mgr.mu.Unlock()
for id, segment := range mgr.growingSegments {
delete(mgr.growingSegments, id)
remove(segment)
mgr.remove(ctx, segment)
}
for id, segment := range mgr.sealedSegments {
delete(mgr.sealedSegments, id)
remove(segment)
mgr.remove(ctx, segment)
}
mgr.updateMetric()
}
// registerReleaseCallback registers the callback function when a segment is released.
// TODO: bad implementation, kept to stay consistent with DiskCache; needs refactoring.
func (mgr *segmentManager) registerReleaseCallback(callback func(s Segment)) {
mgr.releaseCallback = callback
}
func (mgr *segmentManager) updateMetric() {
// update collection and partition metrics
collections, partiations := make(typeutil.Set[int64]), make(typeutil.Set[int64])
@ -688,8 +718,11 @@ func (mgr *segmentManager) updateMetric() {
metrics.QueryNodeNumPartitions.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(partiations.Len()))
}
func remove(segment Segment) bool {
segment.Release()
func (mgr *segmentManager) remove(ctx context.Context, segment Segment) bool {
segment.Release(ctx)
if mgr.releaseCallback != nil {
mgr.releaseCallback(segment)
}
metrics.QueryNodeNumSegments.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()),

View File

@ -62,7 +62,7 @@ func (s *ManagerSuite) SetupTest() {
s.Require().NoError(err)
s.segments = append(s.segments, segment)
s.mgr.Put(s.types[i], segment)
s.mgr.Put(context.Background(), s.types[i], segment)
}
}
@ -82,7 +82,7 @@ func (s *ManagerSuite) TestGetBy() {
segments := s.mgr.GetBy(WithType(typ))
s.Contains(lo.Map(segments, func(segment Segment, _ int) int64 { return segment.ID() }), s.segmentIDs[i])
}
s.mgr.Clear()
s.mgr.Clear(context.Background())
for _, typ := range s.types {
segments := s.mgr.GetBy(WithType(typ))
@ -101,7 +101,7 @@ func (s *ManagerSuite) TestRemoveGrowing() {
for i, id := range s.segmentIDs {
isGrowing := s.types[i] == SegmentTypeGrowing
s.mgr.Remove(id, querypb.DataScope_Streaming)
s.mgr.Remove(context.Background(), id, querypb.DataScope_Streaming)
s.Equal(s.mgr.Get(id) == nil, isGrowing)
}
}
@ -110,21 +110,21 @@ func (s *ManagerSuite) TestRemoveSealed() {
for i, id := range s.segmentIDs {
isSealed := s.types[i] == SegmentTypeSealed
s.mgr.Remove(id, querypb.DataScope_Historical)
s.mgr.Remove(context.Background(), id, querypb.DataScope_Historical)
s.Equal(s.mgr.Get(id) == nil, isSealed)
}
}
func (s *ManagerSuite) TestRemoveAll() {
for _, id := range s.segmentIDs {
s.mgr.Remove(id, querypb.DataScope_All)
s.mgr.Remove(context.Background(), id, querypb.DataScope_All)
s.Nil(s.mgr.Get(id))
}
}
func (s *ManagerSuite) TestRemoveBy() {
for _, id := range s.segmentIDs {
s.mgr.RemoveBy(WithID(id))
s.mgr.RemoveBy(context.Background(), WithID(id))
s.Nil(s.mgr.Get(id))
}
}

View File

@ -261,13 +261,57 @@ func (_c *MockLoader_LoadIndex_Call) RunAndReturn(run func(context.Context, *Loc
return _c
}
// LoadSegment provides a mock function with given fields: ctx, segment, loadInfo, loadStatus
func (_m *MockLoader) LoadSegment(ctx context.Context, segment *LocalSegment, loadInfo *querypb.SegmentLoadInfo, loadStatus LoadStatus) error {
ret := _m.Called(ctx, segment, loadInfo, loadStatus)
// LoadLazySegment provides a mock function with given fields: ctx, segment, loadInfo
func (_m *MockLoader) LoadLazySegment(ctx context.Context, segment *LocalSegment, loadInfo *querypb.SegmentLoadInfo) error {
ret := _m.Called(ctx, segment, loadInfo)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *LocalSegment, *querypb.SegmentLoadInfo, LoadStatus) error); ok {
r0 = rf(ctx, segment, loadInfo, loadStatus)
if rf, ok := ret.Get(0).(func(context.Context, *LocalSegment, *querypb.SegmentLoadInfo) error); ok {
r0 = rf(ctx, segment, loadInfo)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockLoader_LoadLazySegment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadLazySegment'
type MockLoader_LoadLazySegment_Call struct {
*mock.Call
}
// LoadLazySegment is a helper method to define mock.On call
// - ctx context.Context
// - segment *LocalSegment
// - loadInfo *querypb.SegmentLoadInfo
func (_e *MockLoader_Expecter) LoadLazySegment(ctx interface{}, segment interface{}, loadInfo interface{}) *MockLoader_LoadLazySegment_Call {
return &MockLoader_LoadLazySegment_Call{Call: _e.mock.On("LoadLazySegment", ctx, segment, loadInfo)}
}
func (_c *MockLoader_LoadLazySegment_Call) Run(run func(ctx context.Context, segment *LocalSegment, loadInfo *querypb.SegmentLoadInfo)) *MockLoader_LoadLazySegment_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*LocalSegment), args[2].(*querypb.SegmentLoadInfo))
})
return _c
}
func (_c *MockLoader_LoadLazySegment_Call) Return(_a0 error) *MockLoader_LoadLazySegment_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockLoader_LoadLazySegment_Call) RunAndReturn(run func(context.Context, *LocalSegment, *querypb.SegmentLoadInfo) error) *MockLoader_LoadLazySegment_Call {
_c.Call.Return(run)
return _c
}
// LoadSegment provides a mock function with given fields: ctx, segment, loadInfo
func (_m *MockLoader) LoadSegment(ctx context.Context, segment *LocalSegment, loadInfo *querypb.SegmentLoadInfo) error {
ret := _m.Called(ctx, segment, loadInfo)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *LocalSegment, *querypb.SegmentLoadInfo) error); ok {
r0 = rf(ctx, segment, loadInfo)
} else {
r0 = ret.Error(0)
}
@ -284,14 +328,13 @@ type MockLoader_LoadSegment_Call struct {
// - ctx context.Context
// - segment *LocalSegment
// - loadInfo *querypb.SegmentLoadInfo
// - loadStatus LoadStatus
func (_e *MockLoader_Expecter) LoadSegment(ctx interface{}, segment interface{}, loadInfo interface{}, loadStatus interface{}) *MockLoader_LoadSegment_Call {
return &MockLoader_LoadSegment_Call{Call: _e.mock.On("LoadSegment", ctx, segment, loadInfo, loadStatus)}
func (_e *MockLoader_Expecter) LoadSegment(ctx interface{}, segment interface{}, loadInfo interface{}) *MockLoader_LoadSegment_Call {
return &MockLoader_LoadSegment_Call{Call: _e.mock.On("LoadSegment", ctx, segment, loadInfo)}
}
func (_c *MockLoader_LoadSegment_Call) Run(run func(ctx context.Context, segment *LocalSegment, loadInfo *querypb.SegmentLoadInfo, loadStatus LoadStatus)) *MockLoader_LoadSegment_Call {
func (_c *MockLoader_LoadSegment_Call) Run(run func(ctx context.Context, segment *LocalSegment, loadInfo *querypb.SegmentLoadInfo)) *MockLoader_LoadSegment_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*LocalSegment), args[2].(*querypb.SegmentLoadInfo), args[3].(LoadStatus))
run(args[0].(context.Context), args[1].(*LocalSegment), args[2].(*querypb.SegmentLoadInfo))
})
return _c
}
@ -301,7 +344,7 @@ func (_c *MockLoader_LoadSegment_Call) Return(_a0 error) *MockLoader_LoadSegment
return _c
}
func (_c *MockLoader_LoadSegment_Call) RunAndReturn(run func(context.Context, *LocalSegment, *querypb.SegmentLoadInfo, LoadStatus) error) *MockLoader_LoadSegment_Call {
func (_c *MockLoader_LoadSegment_Call) RunAndReturn(run func(context.Context, *LocalSegment, *querypb.SegmentLoadInfo) error) *MockLoader_LoadSegment_Call {
_c.Call.Return(run)
return _c
}

View File

@ -709,47 +709,6 @@ func (_c *MockSegment_LoadInfo_Call) RunAndReturn(run func() *querypb.SegmentLoa
return _c
}
// LoadStatus provides a mock function with given fields:
func (_m *MockSegment) LoadStatus() LoadStatus {
ret := _m.Called()
var r0 LoadStatus
if rf, ok := ret.Get(0).(func() LoadStatus); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(LoadStatus)
}
return r0
}
// MockSegment_LoadStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadStatus'
type MockSegment_LoadStatus_Call struct {
*mock.Call
}
// LoadStatus is a helper method to define mock.On call
func (_e *MockSegment_Expecter) LoadStatus() *MockSegment_LoadStatus_Call {
return &MockSegment_LoadStatus_Call{Call: _e.mock.On("LoadStatus")}
}
func (_c *MockSegment_LoadStatus_Call) Run(run func()) *MockSegment_LoadStatus_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockSegment_LoadStatus_Call) Return(_a0 LoadStatus) *MockSegment_LoadStatus_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockSegment_LoadStatus_Call) RunAndReturn(run func() LoadStatus) *MockSegment_LoadStatus_Call {
_c.Call.Return(run)
return _c
}
// MayPkExist provides a mock function with given fields: pk
func (_m *MockSegment) MayPkExist(pk storage.PrimaryKey) bool {
ret := _m.Called(pk)
@ -833,6 +792,47 @@ func (_c *MockSegment_MemSize_Call) RunAndReturn(run func() int64) *MockSegment_
return _c
}
// NeedUpdatedVersion provides a mock function with given fields:
func (_m *MockSegment) NeedUpdatedVersion() int64 {
ret := _m.Called()
var r0 int64
if rf, ok := ret.Get(0).(func() int64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int64)
}
return r0
}
// MockSegment_NeedUpdatedVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NeedUpdatedVersion'
type MockSegment_NeedUpdatedVersion_Call struct {
*mock.Call
}
// NeedUpdatedVersion is a helper method to define mock.On call
func (_e *MockSegment_Expecter) NeedUpdatedVersion() *MockSegment_NeedUpdatedVersion_Call {
return &MockSegment_NeedUpdatedVersion_Call{Call: _e.mock.On("NeedUpdatedVersion")}
}
func (_c *MockSegment_NeedUpdatedVersion_Call) Run(run func()) *MockSegment_NeedUpdatedVersion_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockSegment_NeedUpdatedVersion_Call) Return(_a0 int64) *MockSegment_NeedUpdatedVersion_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockSegment_NeedUpdatedVersion_Call) RunAndReturn(run func() int64) *MockSegment_NeedUpdatedVersion_Call {
_c.Call.Return(run)
return _c
}
// Partition provides a mock function with given fields:
func (_m *MockSegment) Partition() int64 {
ret := _m.Called()
@ -888,72 +888,41 @@ func (_m *MockSegment) PinIfNotReleased() error {
return r0
}
// MockSegment_RLock_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RLock'
type MockSegment_RLock_Call struct {
// MockSegment_PinIfNotReleased_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PinIfNotReleased'
type MockSegment_PinIfNotReleased_Call struct {
*mock.Call
}
// RLock is a helper method to define mock.On call
func (_e *MockSegment_Expecter) RLock() *MockSegment_RLock_Call {
return &MockSegment_RLock_Call{Call: _e.mock.On("RLock")}
// PinIfNotReleased is a helper method to define mock.On call
func (_e *MockSegment_Expecter) PinIfNotReleased() *MockSegment_PinIfNotReleased_Call {
return &MockSegment_PinIfNotReleased_Call{Call: _e.mock.On("PinIfNotReleased")}
}
func (_c *MockSegment_RLock_Call) Run(run func()) *MockSegment_RLock_Call {
func (_c *MockSegment_PinIfNotReleased_Call) Run(run func()) *MockSegment_PinIfNotReleased_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockSegment_RLock_Call) Return(_a0 error) *MockSegment_RLock_Call {
func (_c *MockSegment_PinIfNotReleased_Call) Return(_a0 error) *MockSegment_PinIfNotReleased_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockSegment_RLock_Call) RunAndReturn(run func() error) *MockSegment_RLock_Call {
func (_c *MockSegment_PinIfNotReleased_Call) RunAndReturn(run func() error) *MockSegment_PinIfNotReleased_Call {
_c.Call.Return(run)
return _c
}
// Unpin provides a mock function with given fields:
func (_m *MockSegment) Unpin() {
_m.Called()
}
// MockSegment_RUnlock_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RUnlock'
type MockSegment_RUnlock_Call struct {
*mock.Call
}
// RUnlock is a helper method to define mock.On call
func (_e *MockSegment_Expecter) RUnlock() *MockSegment_RUnlock_Call {
return &MockSegment_RUnlock_Call{Call: _e.mock.On("RUnlock")}
}
func (_c *MockSegment_RUnlock_Call) Run(run func()) *MockSegment_RUnlock_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockSegment_RUnlock_Call) Return() *MockSegment_RUnlock_Call {
_c.Call.Return()
return _c
}
func (_c *MockSegment_RUnlock_Call) RunAndReturn(run func()) *MockSegment_RUnlock_Call {
_c.Call.Return(run)
return _c
}
// Release provides a mock function with given fields: opts
func (_m *MockSegment) Release(opts ...releaseOption) {
// Release provides a mock function with given fields: ctx, opts
func (_m *MockSegment) Release(ctx context.Context, opts ...releaseOption) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx)
_ca = append(_ca, _va...)
_m.Called(_ca...)
}
@ -964,21 +933,22 @@ type MockSegment_Release_Call struct {
}
// Release is a helper method to define mock.On call
// - ctx context.Context
// - opts ...releaseOption
func (_e *MockSegment_Expecter) Release(opts ...interface{}) *MockSegment_Release_Call {
func (_e *MockSegment_Expecter) Release(ctx interface{}, opts ...interface{}) *MockSegment_Release_Call {
return &MockSegment_Release_Call{Call: _e.mock.On("Release",
append([]interface{}{}, opts...)...)}
append([]interface{}{ctx}, opts...)...)}
}
func (_c *MockSegment_Release_Call) Run(run func(opts ...releaseOption)) *MockSegment_Release_Call {
func (_c *MockSegment_Release_Call) Run(run func(ctx context.Context, opts ...releaseOption)) *MockSegment_Release_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]releaseOption, len(args)-0)
for i, a := range args[0:] {
variadicArgs := make([]releaseOption, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(releaseOption)
}
}
run(variadicArgs...)
run(args[0].(context.Context), variadicArgs...)
})
return _c
}
@ -988,7 +958,48 @@ func (_c *MockSegment_Release_Call) Return() *MockSegment_Release_Call {
return _c
}
func (_c *MockSegment_Release_Call) RunAndReturn(run func(...releaseOption)) *MockSegment_Release_Call {
func (_c *MockSegment_Release_Call) RunAndReturn(run func(context.Context, ...releaseOption)) *MockSegment_Release_Call {
_c.Call.Return(run)
return _c
}
// RemoveUnusedFieldFiles provides a mock function with given fields:
func (_m *MockSegment) RemoveUnusedFieldFiles() error {
ret := _m.Called()
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// MockSegment_RemoveUnusedFieldFiles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveUnusedFieldFiles'
type MockSegment_RemoveUnusedFieldFiles_Call struct {
*mock.Call
}
// RemoveUnusedFieldFiles is a helper method to define mock.On call
func (_e *MockSegment_Expecter) RemoveUnusedFieldFiles() *MockSegment_RemoveUnusedFieldFiles_Call {
return &MockSegment_RemoveUnusedFieldFiles_Call{Call: _e.mock.On("RemoveUnusedFieldFiles")}
}
func (_c *MockSegment_RemoveUnusedFieldFiles_Call) Run(run func()) *MockSegment_RemoveUnusedFieldFiles_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockSegment_RemoveUnusedFieldFiles_Call) Return(_a0 error) *MockSegment_RemoveUnusedFieldFiles_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockSegment_RemoveUnusedFieldFiles_Call) RunAndReturn(run func() error) *MockSegment_RemoveUnusedFieldFiles_Call {
_c.Call.Return(run)
return _c
}
@ -1440,6 +1451,38 @@ func (_c *MockSegment_Type_Call) RunAndReturn(run func() commonpb.SegmentState)
return _c
}
// Unpin provides a mock function with given fields:
func (_m *MockSegment) Unpin() {
_m.Called()
}
// MockSegment_Unpin_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Unpin'
type MockSegment_Unpin_Call struct {
*mock.Call
}
// Unpin is a helper method to define mock.On call
func (_e *MockSegment_Expecter) Unpin() *MockSegment_Unpin_Call {
return &MockSegment_Unpin_Call{Call: _e.mock.On("Unpin")}
}
func (_c *MockSegment_Unpin_Call) Run(run func()) *MockSegment_Unpin_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockSegment_Unpin_Call) Return() *MockSegment_Unpin_Call {
_c.Call.Return()
return _c
}
func (_c *MockSegment_Unpin_Call) RunAndReturn(run func()) *MockSegment_Unpin_Call {
_c.Call.Return(run)
return _c
}
// UpdateBloomFilter provides a mock function with given fields: pks
func (_m *MockSegment) UpdateBloomFilter(pks []storage.PrimaryKey) {
_m.Called(pks)

View File

@ -3,7 +3,10 @@
package segments
import (
context "context"
commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
mock "github.com/stretchr/testify/mock"
querypb "github.com/milvus-io/milvus/internal/proto/querypb"
@ -22,9 +25,9 @@ func (_m *MockSegmentManager) EXPECT() *MockSegmentManager_Expecter {
return &MockSegmentManager_Expecter{mock: &_m.Mock}
}
// Clear provides a mock function with given fields:
func (_m *MockSegmentManager) Clear() {
_m.Called()
// Clear provides a mock function with given fields: ctx
func (_m *MockSegmentManager) Clear(ctx context.Context) {
_m.Called(ctx)
}
// MockSegmentManager_Clear_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Clear'
@ -33,13 +36,14 @@ type MockSegmentManager_Clear_Call struct {
}
// Clear is a helper method to define mock.On call
func (_e *MockSegmentManager_Expecter) Clear() *MockSegmentManager_Clear_Call {
return &MockSegmentManager_Clear_Call{Call: _e.mock.On("Clear")}
// - ctx context.Context
func (_e *MockSegmentManager_Expecter) Clear(ctx interface{}) *MockSegmentManager_Clear_Call {
return &MockSegmentManager_Clear_Call{Call: _e.mock.On("Clear", ctx)}
}
func (_c *MockSegmentManager_Clear_Call) Run(run func()) *MockSegmentManager_Clear_Call {
func (_c *MockSegmentManager_Clear_Call) Run(run func(ctx context.Context)) *MockSegmentManager_Clear_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
run(args[0].(context.Context))
})
return _c
}
@ -49,7 +53,7 @@ func (_c *MockSegmentManager_Clear_Call) Return() *MockSegmentManager_Clear_Call
return _c
}
func (_c *MockSegmentManager_Clear_Call) RunAndReturn(run func()) *MockSegmentManager_Clear_Call {
func (_c *MockSegmentManager_Clear_Call) RunAndReturn(run func(context.Context)) *MockSegmentManager_Clear_Call {
_c.Call.Return(run)
return _c
}
@ -465,14 +469,14 @@ func (_c *MockSegmentManager_GetWithType_Call) RunAndReturn(run func(int64, comm
return _c
}
// Put provides a mock function with given fields: segmentType, segments
func (_m *MockSegmentManager) Put(segmentType commonpb.SegmentState, segments ...Segment) {
// Put provides a mock function with given fields: ctx, segmentType, segments
func (_m *MockSegmentManager) Put(ctx context.Context, segmentType commonpb.SegmentState, segments ...Segment) {
_va := make([]interface{}, len(segments))
for _i := range segments {
_va[_i] = segments[_i]
}
var _ca []interface{}
_ca = append(_ca, segmentType)
_ca = append(_ca, ctx, segmentType)
_ca = append(_ca, _va...)
_m.Called(_ca...)
}
@ -483,22 +487,23 @@ type MockSegmentManager_Put_Call struct {
}
// Put is a helper method to define mock.On call
// - ctx context.Context
// - segmentType commonpb.SegmentState
// - segments ...Segment
func (_e *MockSegmentManager_Expecter) Put(segmentType interface{}, segments ...interface{}) *MockSegmentManager_Put_Call {
func (_e *MockSegmentManager_Expecter) Put(ctx interface{}, segmentType interface{}, segments ...interface{}) *MockSegmentManager_Put_Call {
return &MockSegmentManager_Put_Call{Call: _e.mock.On("Put",
append([]interface{}{segmentType}, segments...)...)}
append([]interface{}{ctx, segmentType}, segments...)...)}
}
func (_c *MockSegmentManager_Put_Call) Run(run func(segmentType commonpb.SegmentState, segments ...Segment)) *MockSegmentManager_Put_Call {
func (_c *MockSegmentManager_Put_Call) Run(run func(ctx context.Context, segmentType commonpb.SegmentState, segments ...Segment)) *MockSegmentManager_Put_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]Segment, len(args)-1)
for i, a := range args[1:] {
variadicArgs := make([]Segment, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(Segment)
}
}
run(args[0].(commonpb.SegmentState), variadicArgs...)
run(args[0].(context.Context), args[1].(commonpb.SegmentState), variadicArgs...)
})
return _c
}
@ -508,28 +513,28 @@ func (_c *MockSegmentManager_Put_Call) Return() *MockSegmentManager_Put_Call {
return _c
}
func (_c *MockSegmentManager_Put_Call) RunAndReturn(run func(commonpb.SegmentState, ...Segment)) *MockSegmentManager_Put_Call {
func (_c *MockSegmentManager_Put_Call) RunAndReturn(run func(context.Context, commonpb.SegmentState, ...Segment)) *MockSegmentManager_Put_Call {
_c.Call.Return(run)
return _c
}
// Remove provides a mock function with given fields: segmentID, scope
func (_m *MockSegmentManager) Remove(segmentID int64, scope querypb.DataScope) (int, int) {
ret := _m.Called(segmentID, scope)
// Remove provides a mock function with given fields: ctx, segmentID, scope
func (_m *MockSegmentManager) Remove(ctx context.Context, segmentID int64, scope querypb.DataScope) (int, int) {
ret := _m.Called(ctx, segmentID, scope)
var r0 int
var r1 int
if rf, ok := ret.Get(0).(func(int64, querypb.DataScope) (int, int)); ok {
return rf(segmentID, scope)
if rf, ok := ret.Get(0).(func(context.Context, int64, querypb.DataScope) (int, int)); ok {
return rf(ctx, segmentID, scope)
}
if rf, ok := ret.Get(0).(func(int64, querypb.DataScope) int); ok {
r0 = rf(segmentID, scope)
if rf, ok := ret.Get(0).(func(context.Context, int64, querypb.DataScope) int); ok {
r0 = rf(ctx, segmentID, scope)
} else {
r0 = ret.Get(0).(int)
}
if rf, ok := ret.Get(1).(func(int64, querypb.DataScope) int); ok {
r1 = rf(segmentID, scope)
if rf, ok := ret.Get(1).(func(context.Context, int64, querypb.DataScope) int); ok {
r1 = rf(ctx, segmentID, scope)
} else {
r1 = ret.Get(1).(int)
}
@ -543,15 +548,16 @@ type MockSegmentManager_Remove_Call struct {
}
// Remove is a helper method to define mock.On call
// - ctx context.Context
// - segmentID int64
// - scope querypb.DataScope
func (_e *MockSegmentManager_Expecter) Remove(segmentID interface{}, scope interface{}) *MockSegmentManager_Remove_Call {
return &MockSegmentManager_Remove_Call{Call: _e.mock.On("Remove", segmentID, scope)}
func (_e *MockSegmentManager_Expecter) Remove(ctx interface{}, segmentID interface{}, scope interface{}) *MockSegmentManager_Remove_Call {
return &MockSegmentManager_Remove_Call{Call: _e.mock.On("Remove", ctx, segmentID, scope)}
}
func (_c *MockSegmentManager_Remove_Call) Run(run func(segmentID int64, scope querypb.DataScope)) *MockSegmentManager_Remove_Call {
func (_c *MockSegmentManager_Remove_Call) Run(run func(ctx context.Context, segmentID int64, scope querypb.DataScope)) *MockSegmentManager_Remove_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(querypb.DataScope))
run(args[0].(context.Context), args[1].(int64), args[2].(querypb.DataScope))
})
return _c
}
@ -561,34 +567,35 @@ func (_c *MockSegmentManager_Remove_Call) Return(_a0 int, _a1 int) *MockSegmentM
return _c
}
func (_c *MockSegmentManager_Remove_Call) RunAndReturn(run func(int64, querypb.DataScope) (int, int)) *MockSegmentManager_Remove_Call {
func (_c *MockSegmentManager_Remove_Call) RunAndReturn(run func(context.Context, int64, querypb.DataScope) (int, int)) *MockSegmentManager_Remove_Call {
_c.Call.Return(run)
return _c
}
// RemoveBy provides a mock function with given fields: filters
func (_m *MockSegmentManager) RemoveBy(filters ...SegmentFilter) (int, int) {
// RemoveBy provides a mock function with given fields: ctx, filters
func (_m *MockSegmentManager) RemoveBy(ctx context.Context, filters ...SegmentFilter) (int, int) {
_va := make([]interface{}, len(filters))
for _i := range filters {
_va[_i] = filters[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 int
var r1 int
if rf, ok := ret.Get(0).(func(...SegmentFilter) (int, int)); ok {
return rf(filters...)
if rf, ok := ret.Get(0).(func(context.Context, ...SegmentFilter) (int, int)); ok {
return rf(ctx, filters...)
}
if rf, ok := ret.Get(0).(func(...SegmentFilter) int); ok {
r0 = rf(filters...)
if rf, ok := ret.Get(0).(func(context.Context, ...SegmentFilter) int); ok {
r0 = rf(ctx, filters...)
} else {
r0 = ret.Get(0).(int)
}
if rf, ok := ret.Get(1).(func(...SegmentFilter) int); ok {
r1 = rf(filters...)
if rf, ok := ret.Get(1).(func(context.Context, ...SegmentFilter) int); ok {
r1 = rf(ctx, filters...)
} else {
r1 = ret.Get(1).(int)
}
@ -602,21 +609,22 @@ type MockSegmentManager_RemoveBy_Call struct {
}
// RemoveBy is a helper method to define mock.On call
// - ctx context.Context
// - filters ...SegmentFilter
func (_e *MockSegmentManager_Expecter) RemoveBy(filters ...interface{}) *MockSegmentManager_RemoveBy_Call {
func (_e *MockSegmentManager_Expecter) RemoveBy(ctx interface{}, filters ...interface{}) *MockSegmentManager_RemoveBy_Call {
return &MockSegmentManager_RemoveBy_Call{Call: _e.mock.On("RemoveBy",
append([]interface{}{}, filters...)...)}
append([]interface{}{ctx}, filters...)...)}
}
func (_c *MockSegmentManager_RemoveBy_Call) Run(run func(filters ...SegmentFilter)) *MockSegmentManager_RemoveBy_Call {
func (_c *MockSegmentManager_RemoveBy_Call) Run(run func(ctx context.Context, filters ...SegmentFilter)) *MockSegmentManager_RemoveBy_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]SegmentFilter, len(args)-0)
for i, a := range args[0:] {
variadicArgs := make([]SegmentFilter, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(SegmentFilter)
}
}
run(variadicArgs...)
run(args[0].(context.Context), variadicArgs...)
})
return _c
}
@ -626,7 +634,7 @@ func (_c *MockSegmentManager_RemoveBy_Call) Return(_a0 int, _a1 int) *MockSegmen
return _c
}
func (_c *MockSegmentManager_RemoveBy_Call) RunAndReturn(run func(...SegmentFilter) (int, int)) *MockSegmentManager_RemoveBy_Call {
func (_c *MockSegmentManager_RemoveBy_Call) RunAndReturn(run func(context.Context, ...SegmentFilter) (int, int)) *MockSegmentManager_RemoveBy_Call {
_c.Call.Return(run)
return _c
}

View File

@ -39,8 +39,11 @@ type SearchResult struct {
cSearchResult C.CSearchResult
}
// searchResultDataBlobs is the CSearchResultsDataBlobs in C++
type searchResultDataBlobs = C.CSearchResultDataBlobs
// SearchResultDataBlobs is the CSearchResultDataBlobs in C++
type (
SearchResultDataBlobs = C.CSearchResultDataBlobs
StreamSearchReducer = C.CSearchStreamReducer
)
// RetrieveResult contains a pointer to the retrieve result in C++ memory
type RetrieveResult struct {
@ -71,9 +74,58 @@ func ParseSliceInfo(originNQs []int64, originTopKs []int64, nqPerSlice int64) *S
return sInfo
}
func NewStreamReducer(ctx context.Context,
plan *SearchPlan,
sliceNQs []int64,
sliceTopKs []int64,
) (StreamSearchReducer, error) {
if plan.cSearchPlan == nil {
return nil, fmt.Errorf("nil search plan")
}
if len(sliceNQs) == 0 {
return nil, fmt.Errorf("empty slice nqs is not allowed")
}
if len(sliceNQs) != len(sliceTopKs) {
return nil, fmt.Errorf("unaligned sliceNQs(len=%d) and sliceTopKs(len=%d)", len(sliceNQs), len(sliceTopKs))
}
cSliceNQSPtr := (*C.int64_t)(&sliceNQs[0])
cSliceTopKSPtr := (*C.int64_t)(&sliceTopKs[0])
cNumSlices := C.int64_t(len(sliceNQs))
var streamReducer StreamSearchReducer
status := C.NewStreamReducer(plan.cSearchPlan, cSliceNQSPtr, cSliceTopKSPtr, cNumSlices, &streamReducer)
if err := HandleCStatus(ctx, &status, "MergeSearchResultsWithOutputFields failed"); err != nil {
return nil, err
}
return streamReducer, nil
}
func StreamReduceSearchResult(ctx context.Context,
newResult *SearchResult, streamReducer StreamSearchReducer,
) error {
cSearchResults := make([]C.CSearchResult, 0)
cSearchResults = append(cSearchResults, newResult.cSearchResult)
cSearchResultPtr := &cSearchResults[0]
status := C.StreamReduce(streamReducer, cSearchResultPtr, 1)
if err := HandleCStatus(ctx, &status, "StreamReduceSearchResult failed"); err != nil {
return err
}
return nil
}
func GetStreamReduceResult(ctx context.Context, streamReducer StreamSearchReducer) (SearchResultDataBlobs, error) {
var cSearchResultDataBlobs SearchResultDataBlobs
status := C.GetStreamReduceResult(streamReducer, &cSearchResultDataBlobs)
if err := HandleCStatus(ctx, &status, "ReduceSearchResultsAndFillData failed"); err != nil {
return nil, err
}
return cSearchResultDataBlobs, nil
}
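Taken together, NewStreamReducer, StreamReduceSearchResult, GetStreamReduceResult, and DeleteStreamReduceHelper (defined below) form a create/feed/collect/free lifecycle around the C-side reducer. A minimal sketch of that wiring, assuming the caller already owns the plan and the per-segment results; streamReduceAll is a hypothetical helper, not part of this change:
// streamReduceAll is a hypothetical helper showing the reducer lifecycle:
// create, feed per-segment results one at a time, then collect the blobs.
func streamReduceAll(ctx context.Context, plan *SearchPlan,
    sliceNQs, sliceTopKs []int64, results []*SearchResult,
) (SearchResultDataBlobs, error) {
    reducer, err := NewStreamReducer(ctx, plan, sliceNQs, sliceTopKs)
    if err != nil {
        return nil, err
    }
    // the reducer owns C memory and must be freed exactly once
    defer DeleteStreamReduceHelper(reducer)
    for _, res := range results {
        // fold one segment result into the reducer's merged state
        if err := StreamReduceSearchResult(ctx, res, reducer); err != nil {
            return nil, err
        }
    }
    // serialize the merged state into result blobs
    return GetStreamReduceResult(ctx, reducer)
}
The caller still owns the returned blobs (freed with DeleteSearchResultDataBlobs) and the per-segment results (freed with DeleteSearchResults).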
func ReduceSearchResultsAndFillData(ctx context.Context, plan *SearchPlan, searchResults []*SearchResult,
numSegments int64, sliceNQs []int64, sliceTopKs []int64,
) (searchResultDataBlobs, error) {
) (SearchResultDataBlobs, error) {
if plan.cSearchPlan == nil {
return nil, fmt.Errorf("nil search plan")
}
@ -98,7 +150,7 @@ func ReduceSearchResultsAndFillData(ctx context.Context, plan *SearchPlan, searc
cSliceNQSPtr := (*C.int64_t)(&sliceNQs[0])
cSliceTopKSPtr := (*C.int64_t)(&sliceTopKs[0])
cNumSlices := C.int64_t(len(sliceNQs))
var cSearchResultDataBlobs searchResultDataBlobs
var cSearchResultDataBlobs SearchResultDataBlobs
status := C.ReduceSearchResultsAndFillData(&cSearchResultDataBlobs, plan.cSearchPlan, cSearchResultPtr,
cNumSegments, cSliceNQSPtr, cSliceTopKSPtr, cNumSlices)
if err := HandleCStatus(ctx, &status, "ReduceSearchResultsAndFillData failed"); err != nil {
@ -107,7 +159,7 @@ func ReduceSearchResultsAndFillData(ctx context.Context, plan *SearchPlan, searc
return cSearchResultDataBlobs, nil
}
func GetSearchResultDataBlob(ctx context.Context, cSearchResultDataBlobs searchResultDataBlobs, blobIndex int) ([]byte, error) {
func GetSearchResultDataBlob(ctx context.Context, cSearchResultDataBlobs SearchResultDataBlobs, blobIndex int) ([]byte, error) {
var blob C.CProto
status := C.GetSearchResultDataBlob(&blob, cSearchResultDataBlobs, C.int32_t(blobIndex))
if err := HandleCStatus(ctx, &status, "marshal failed"); err != nil {
@ -116,10 +168,14 @@ func GetSearchResultDataBlob(ctx context.Context, cSearchResultDataBlobs searchR
return GetCProtoBlob(&blob), nil
}
func DeleteSearchResultDataBlobs(cSearchResultDataBlobs searchResultDataBlobs) {
func DeleteSearchResultDataBlobs(cSearchResultDataBlobs SearchResultDataBlobs) {
C.DeleteSearchResultDataBlobs(cSearchResultDataBlobs)
}
func DeleteStreamReduceHelper(cStreamReduceHelper StreamSearchReducer) {
C.DeleteStreamSearchReducer(cStreamReduceHelper)
}
func DeleteSearchResults(results []*SearchResult) {
if len(results) == 0 {
return

View File

@ -83,6 +83,7 @@ func (suite *ReduceSuite) SetupTest() {
CollectionID: suite.collectionID,
PartitionID: suite.partitionID,
InsertChannel: "dml",
NumOfRows: int64(msgLength),
Level: datapb.SegmentLevel_Legacy,
},
)
@ -104,7 +105,7 @@ func (suite *ReduceSuite) SetupTest() {
}
func (suite *ReduceSuite) TearDownTest() {
suite.segment.Release()
suite.segment.Release(context.Background())
DeleteCollection(suite.collection)
ctx := context.Background()
suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath)

View File

@ -520,16 +520,18 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
selected := make([]int, 0, ret.GetAllRetrieveCount())
var limit int = -1
if param.limit != typeutil.Unlimited && !param.mergeStopForBest {
loopEnd = int(param.limit)
limit = int(param.limit)
}
idSet := make(map[interface{}]struct{})
cursors := make([]int64, len(validRetrieveResults))
var availableCount int
var retSize int64
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
for j := 0; j < loopEnd; j++ {
for j := 0; j < loopEnd && (limit == -1 || availableCount < limit); j++ {
sel, drainOneResult := typeutil.SelectMinPK(param.limit, validRetrieveResults, cursors)
if sel == -1 || (param.mergeStopForBest && drainOneResult) {
break
@ -542,6 +544,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
selectedOffsets[sel] = append(selectedOffsets[sel], validRetrieveResults[sel].GetOffset()[cursors[sel]])
selectedIndexes[sel] = append(selectedIndexes[sel], cursors[sel])
idSet[pk] = struct{}{}
availableCount++
} else {
// primary keys duplicate
skipDupCnt++

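The availableCount guard above changes the merge loop's exit condition: it now stops at whichever comes first, loopEnd iterations or limit unique primary keys, while duplicate keys only increment skipDupCnt. A standalone sketch of the same guard under simplified assumptions (int64 keys stand in for primary keys; mergeUnique is hypothetical):
// mergeUnique keeps at most limit unique keys (limit == -1 means unbounded),
// scanning no more than loopEnd candidates; duplicates are skipped and do
// not consume the limit.
func mergeUnique(pks []int64, loopEnd, limit int) []int64 {
    idSet := make(map[int64]struct{})
    out := make([]int64, 0)
    for j := 0; j < len(pks) && j < loopEnd && (limit == -1 || len(out) < limit); j++ {
        pk := pks[j]
        if _, dup := idSet[pk]; dup {
            continue // duplicate primary key: skipped, limit untouched
        }
        idSet[pk] = struct{}{}
        out = append(out, pk)
    }
    return out
}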
View File

@ -20,8 +20,10 @@ import (
"context"
"fmt"
"sync"
"time"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
@ -44,11 +46,7 @@ func retrieveOnSegments(ctx context.Context, mgr *Manager, segments []Segment, s
result *segcorepb.RetrieveResults
segment Segment
}
var (
resultCh = make(chan segmentResult, len(segments))
errs = make([]error, len(segments))
wg sync.WaitGroup
)
resultCh := make(chan segmentResult, len(segments))
plan.ignoreNonPk = len(segments) > 1 && req.GetReq().GetLimit() != typeutil.Unlimited && plan.ShouldIgnoreNonPk()
@ -57,25 +55,29 @@ func retrieveOnSegments(ctx context.Context, mgr *Manager, segments []Segment, s
label = metrics.GrowingSegmentLabel
}
retriever := func(s Segment) error {
retriever := func(ctx context.Context, s Segment) error {
tr := timerecord.NewTimeRecorder("retrieveOnSegments")
result, err := s.Retrieve(ctx, plan)
if err != nil {
return err
}
resultCh <- segmentResult{
result,
s,
}
if err != nil {
return err
}
metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryLabel, label).Observe(float64(tr.ElapseSpan().Milliseconds()))
return nil
}
for i, segment := range segments {
wg.Add(1)
go func(seg Segment, i int) {
defer wg.Done()
errGroup, ctx := errgroup.WithContext(ctx)
for _, segment := range segments {
seg := segment
errGroup.Go(func() error {
if ctx.Err() != nil {
return ctx.Err()
}
// record search time and cache miss
var err error
accessRecord := metricsutil.NewQuerySegmentAccessRecord(getSegmentMetricLabel(seg))
@ -84,26 +86,26 @@ func retrieveOnSegments(ctx context.Context, mgr *Manager, segments []Segment, s
}()
if seg.IsLazyLoad() {
var timeout time.Duration
timeout, err = lazyloadWaitTimeout(ctx)
if err != nil {
return err
}
var missing bool
missing, err = mgr.DiskCache.Do(seg.ID(), retriever)
missing, err = mgr.DiskCache.DoWait(ctx, seg.ID(), timeout, retriever)
if missing {
accessRecord.CacheMissing()
}
} else {
err = retriever(seg)
return err
}
if err != nil {
errs[i] = err
}
}(segment, i)
return retriever(ctx, seg)
})
}
wg.Wait()
err := errGroup.Wait()
close(resultCh)
for _, err := range errs {
if err != nil {
return nil, nil, err
}
if err != nil {
return nil, nil, err
}
var retrieveSegments []Segment
@ -163,9 +165,12 @@ func retrieveOnSegmentsWithStream(ctx context.Context, segments []Segment, segTy
// Retrieve will retrieve all the validated target segments
func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *querypb.QueryRequest) ([]*segcorepb.RetrieveResults, []Segment, error) {
if ctx.Err() != nil {
return nil, nil, ctx.Err()
}
var err error
var SegType commonpb.SegmentState
var retrieveResults []*segcorepb.RetrieveResults
var retrieveSegments []Segment
segIDs := req.GetSegmentIDs()
@ -181,7 +186,7 @@ func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *qu
}
if err != nil {
return retrieveResults, retrieveSegments, err
return nil, nil, err
}
return retrieveOnSegments(ctx, manager, retrieveSegments, SegType, plan, req)

View File

@ -92,6 +92,7 @@ func (suite *RetrieveSuite) SetupTest() {
CollectionID: suite.collectionID,
PartitionID: suite.partitionID,
InsertChannel: "dml",
NumOfRows: int64(msgLength),
Level: datapb.SegmentLevel_Legacy,
},
)
@ -132,13 +133,13 @@ func (suite *RetrieveSuite) SetupTest() {
err = suite.growing.Insert(ctx, insertMsg.RowIDs, insertMsg.Timestamps, insertRecord)
suite.Require().NoError(err)
suite.manager.Segment.Put(SegmentTypeSealed, suite.sealed)
suite.manager.Segment.Put(SegmentTypeGrowing, suite.growing)
suite.manager.Segment.Put(context.Background(), SegmentTypeSealed, suite.sealed)
suite.manager.Segment.Put(context.Background(), SegmentTypeGrowing, suite.growing)
}
func (suite *RetrieveSuite) TearDownTest() {
suite.sealed.Release()
suite.growing.Release()
suite.sealed.Release(context.Background())
suite.growing.Release(context.Background())
DeleteCollection(suite.collection)
ctx := context.Background()
suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath)
@ -249,7 +250,7 @@ func (suite *RetrieveSuite) TestRetrieveNilSegment() {
plan, err := genSimpleRetrievePlan(suite.collection)
suite.NoError(err)
suite.sealed.Release()
suite.sealed.Release(context.Background())
req := &querypb.QueryRequest{
Req: &internalpb.RetrieveRequest{
CollectionID: suite.collectionID,

View File

@ -20,13 +20,18 @@ import (
"context"
"fmt"
"sync"
"time"
"go.uber.org/atomic"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querynodev2/segments/metricsutil"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
)
@ -34,30 +39,20 @@ import (
// searchOnSegments performs search on listed segments
// all segment ids are validated before calling this function
func searchSegments(ctx context.Context, mgr *Manager, segments []Segment, segType SegmentType, searchReq *SearchRequest) ([]*SearchResult, error) {
var (
// results variables
resultCh = make(chan *SearchResult, len(segments))
errs = make([]error, len(segments))
wg sync.WaitGroup
// For log only
mu sync.Mutex
segmentsWithoutIndex []int64
)
searchLabel := metrics.SealedSegmentLabel
if segType == commonpb.SegmentState_Growing {
searchLabel = metrics.GrowingSegmentLabel
}
searcher := func(s Segment) error {
resultCh := make(chan *SearchResult, len(segments))
searcher := func(ctx context.Context, s Segment) error {
// record search time
tr := timerecord.NewTimeRecorder("searchOnSegments")
searchResult, err := s.Search(ctx, searchReq)
resultCh <- searchResult
if err != nil {
return err
}
resultCh <- searchResult
// update metrics
elapsed := tr.ElapseSpan().Milliseconds()
metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
@ -68,35 +63,41 @@ func searchSegments(ctx context.Context, mgr *Manager, segments []Segment, segTy
}
// calling segment search in goroutines
for i, segment := range segments {
wg.Add(1)
go func(seg Segment, i int) {
defer wg.Done()
if !seg.ExistIndex(searchReq.searchFieldID) {
mu.Lock()
segmentsWithoutIndex = append(segmentsWithoutIndex, seg.ID())
mu.Unlock()
errGroup, ctx := errgroup.WithContext(ctx)
segmentsWithoutIndex := make([]int64, 0)
for _, segment := range segments {
seg := segment
if !seg.ExistIndex(searchReq.searchFieldID) {
segmentsWithoutIndex = append(segmentsWithoutIndex, seg.ID())
}
errGroup.Go(func() error {
if ctx.Err() != nil {
return ctx.Err()
}
var err error
accessRecord := metricsutil.NewSearchSegmentAccessRecord(getSegmentMetricLabel(seg))
defer func() {
accessRecord.Finish(err)
}()
if seg.IsLazyLoad() {
var timeout time.Duration
timeout, err = lazyloadWaitTimeout(ctx)
if err != nil {
return err
}
var missing bool
missing, err = mgr.DiskCache.Do(seg.ID(), searcher)
missing, err = mgr.DiskCache.DoWait(ctx, seg.ID(), timeout, searcher)
if missing {
accessRecord.CacheMissing()
}
} else {
err = searcher(seg)
return err
}
if err != nil {
errs[i] = err
}
}(segment, i)
return searcher(ctx, seg)
})
}
wg.Wait()
err := errGroup.Wait()
close(resultCh)
searchResults := make([]*SearchResult, 0, len(segments))
@ -104,11 +105,9 @@ func searchSegments(ctx context.Context, mgr *Manager, segments []Segment, segTy
searchResults = append(searchResults, result)
}
for _, err := range errs {
if err != nil {
DeleteSearchResults(searchResults)
return nil, err
}
if err != nil {
DeleteSearchResults(searchResults)
return nil, err
}
if len(segmentsWithoutIndex) > 0 {
@ -118,11 +117,115 @@ func searchSegments(ctx context.Context, mgr *Manager, segments []Segment, segTy
return searchResults, nil
}
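The rewrite above replaces the WaitGroup-plus-error-slice fan-out with golang.org/x/sync/errgroup, so the first failing segment cancels the shared context and the remaining goroutines can bail out before doing work. A minimal sketch of the pattern with a hypothetical work function:
// fanOut runs work on every id concurrently; the first error cancels ctx
// for the rest. work is a hypothetical stand-in for searcher/retriever.
func fanOut(ctx context.Context, ids []int64, work func(context.Context, int64) error) error {
    errGroup, ctx := errgroup.WithContext(ctx)
    for _, id := range ids {
        id := id // capture the loop variable (pre-Go 1.22 semantics)
        errGroup.Go(func() error {
            if ctx.Err() != nil {
                return ctx.Err() // a sibling already failed; skip the work
            }
            return work(ctx, id)
        })
    }
    // Wait returns the first non-nil error after all goroutines exit,
    // which is why resultCh can be closed right after it.
    return errGroup.Wait()
}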
// searchSegmentsStreamly performs search on listed segments in a stream mode instead of a batch mode
// all segment ids are validated before calling this function
func searchSegmentsStreamly(ctx context.Context,
mgr *Manager,
segments []Segment,
searchReq *SearchRequest,
streamReduce func(result *SearchResult) error,
) error {
searchLabel := metrics.SealedSegmentLabel
searchResultsToClear := make([]*SearchResult, 0)
var reduceMutex sync.Mutex
var sumReduceDuration atomic.Duration
searcher := func(ctx context.Context, seg Segment) error {
// record search time
tr := timerecord.NewTimeRecorder("searchOnSegments")
searchResult, searchErr := seg.Search(ctx, searchReq)
searchDuration := tr.RecordSpan().Milliseconds()
if searchErr != nil {
return searchErr
}
reduceMutex.Lock()
searchResultsToClear = append(searchResultsToClear, searchResult)
reducedErr := streamReduce(searchResult)
reduceMutex.Unlock()
reduceDuration := tr.RecordSpan()
if reducedErr != nil {
return reducedErr
}
sumReduceDuration.Add(reduceDuration)
// update metrics
metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.SearchLabel, searchLabel).Observe(float64(searchDuration))
metrics.QueryNodeSegmentSearchLatencyPerVector.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.SearchLabel, searchLabel).Observe(float64(searchDuration) / float64(searchReq.getNumOfQuery()))
return nil
}
// calling segment search in goroutines
errGroup, ctx := errgroup.WithContext(ctx)
log := log.Ctx(ctx)
for _, segment := range segments {
seg := segment
errGroup.Go(func() error {
if ctx.Err() != nil {
return ctx.Err()
}
var err error
accessRecord := metricsutil.NewSearchSegmentAccessRecord(getSegmentMetricLabel(seg))
defer func() {
accessRecord.Finish(err)
}()
if seg.IsLazyLoad() {
log.Debug("before doing stream search in DiskCache", zap.Int64("segID", seg.ID()))
var timeout time.Duration
timeout, err = lazyloadWaitTimeout(ctx)
if err != nil {
return err
}
var missing bool
missing, err = mgr.DiskCache.DoWait(ctx, seg.ID(), timeout, searcher)
if missing {
accessRecord.CacheMissing()
}
if err != nil {
log.Error("failed to do search for disk cache", zap.Int64("seg_id", seg.ID()), zap.Error(err))
}
log.Debug("after doing stream search in DiskCache", zap.Int64("segID", seg.ID()), zap.Error(err))
return err
}
return searcher(ctx, seg)
})
}
err := errGroup.Wait()
DeleteSearchResults(searchResultsToClear)
if err != nil {
return err
}
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.SearchLabel,
metrics.ReduceSegments,
metrics.StreamReduce).Observe(float64(sumReduceDuration.Load().Milliseconds()))
log.Debug("stream reduce sum duration:", zap.Duration("duration", sumReduceDuration.Load()))
return nil
}
func lazyloadWaitTimeout(ctx context.Context) (time.Duration, error) {
timeout := params.Params.QueryNodeCfg.LazyLoadWaitTimeout.GetAsDuration(time.Millisecond)
deadline, ok := ctx.Deadline()
if ok {
remain := time.Until(deadline)
if remain <= 0 {
return -1, merr.WrapErrServiceInternal("search context deadline exceeded")
} else if remain < timeout {
timeout = remain
}
}
return timeout, nil
}
// SearchHistorical will search on the target segments in historical.
// if segIDs is not specified, it will search on all the historical segments specified by partIDs.
// if segIDs is specified, it will only search on the segments specified by the segIDs.
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
func SearchHistorical(ctx context.Context, manager *Manager, searchReq *SearchRequest, collID int64, partIDs []int64, segIDs []int64) ([]*SearchResult, []Segment, error) {
if ctx.Err() != nil {
return nil, nil, ctx.Err()
}
segments, err := validateOnHistorical(ctx, manager, collID, partIDs, segIDs)
if err != nil {
return nil, nil, err
@ -134,6 +237,10 @@ func SearchHistorical(ctx context.Context, manager *Manager, searchReq *SearchRe
// SearchStreaming will search all the target segments in streaming
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
func SearchStreaming(ctx context.Context, manager *Manager, searchReq *SearchRequest, collID int64, partIDs []int64, segIDs []int64) ([]*SearchResult, []Segment, error) {
if ctx.Err() != nil {
return nil, nil, ctx.Err()
}
segments, err := validateOnStream(ctx, manager, collID, partIDs, segIDs)
if err != nil {
return nil, nil, err
@ -141,3 +248,21 @@ func SearchStreaming(ctx context.Context, manager *Manager, searchReq *SearchReq
searchResults, err := searchSegments(ctx, manager, segments, SegmentTypeGrowing, searchReq)
return searchResults, segments, err
}
func SearchHistoricalStreamly(ctx context.Context, manager *Manager, searchReq *SearchRequest,
collID int64, partIDs []int64, segIDs []int64, streamReduce func(result *SearchResult) error,
) ([]Segment, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
segments, err := validateOnHistorical(ctx, manager, collID, partIDs, segIDs)
if err != nil {
return segments, err
}
err = searchSegmentsStreamly(ctx, manager, segments, searchReq, streamReduce)
if err != nil {
return segments, err
}
return segments, nil
}
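SearchHistoricalStreamly is the streaming counterpart of SearchHistorical: instead of accumulating per-segment results for one batch reduce, it folds each result into the caller-supplied streamReduce closure as it arrives. A hedged sketch of the expected call shape; searchHistoricalStreaming is hypothetical, and the reducer is assumed to come from NewStreamReducer and to be freed by the caller with DeleteStreamReduceHelper:
// searchHistoricalStreaming is a hypothetical wrapper tying the stream
// reducer to the historical search path.
func searchHistoricalStreaming(ctx context.Context, manager *Manager, searchReq *SearchRequest,
    reducer StreamSearchReducer, collID int64, partIDs, segIDs []int64,
) (SearchResultDataBlobs, []Segment, error) {
    streamReduce := func(result *SearchResult) error {
        // fold one segment's result into the shared reducer; these calls
        // are serialized by searchSegmentsStreamly's reduce mutex
        return StreamReduceSearchResult(ctx, result, reducer)
    }
    segments, err := SearchHistoricalStreamly(ctx, manager, searchReq, collID, partIDs, segIDs, streamReduce)
    if err != nil {
        return nil, segments, err
    }
    blobs, err := GetStreamReduceResult(ctx, reducer)
    return blobs, segments, err
}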

View File

@ -83,6 +83,7 @@ func (suite *SearchSuite) SetupTest() {
CollectionID: suite.collectionID,
PartitionID: suite.partitionID,
InsertChannel: "dml",
NumOfRows: int64(msgLength),
Level: datapb.SegmentLevel_Legacy,
},
)
@ -122,12 +123,12 @@ func (suite *SearchSuite) SetupTest() {
suite.Require().NoError(err)
suite.growing.Insert(ctx, insertMsg.RowIDs, insertMsg.Timestamps, insertRecord)
suite.manager.Segment.Put(SegmentTypeSealed, suite.sealed)
suite.manager.Segment.Put(SegmentTypeGrowing, suite.growing)
suite.manager.Segment.Put(context.Background(), SegmentTypeSealed, suite.sealed)
suite.manager.Segment.Put(context.Background(), SegmentTypeGrowing, suite.growing)
}
func (suite *SearchSuite) TearDownTest() {
suite.sealed.Release()
suite.sealed.Release(context.Background())
DeleteCollection(suite.collection)
ctx := context.Background()
suite.chunkManager.RemoveWithPrefix(ctx, paramtable.Get().MinioCfg.RootPath.GetValue())

View File

@ -29,6 +29,7 @@ import (
"context"
"fmt"
"io"
"strconv"
"strings"
"unsafe"
@ -47,6 +48,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
"github.com/milvus-io/milvus/internal/querynodev2/segments/state"
"github.com/milvus-io/milvus/internal/storage"
@ -73,50 +75,59 @@ var ErrSegmentUnhealthy = errors.New("segment unhealthy")
type IndexedFieldInfo struct {
FieldBinlog *datapb.FieldBinlog
IndexInfo *querypb.FieldIndexInfo
LazyLoad bool
IsLoaded bool
}
type baseSegment struct {
collection *Collection
version *atomic.Int64
// the load status of the segment,
// only transitions below are allowed:
// 1. LoadStatusMeta <-> LoadStatusMapped
// 2. LoadStatusMeta <-> LoadStatusInMemory
loadStatus *atomic.String
segmentType SegmentType
bloomFilterSet *pkoracle.BloomFilterSet
loadInfo *querypb.SegmentLoadInfo
loadInfo *atomic.Pointer[querypb.SegmentLoadInfo]
isLazyLoad bool
resourceUsageCache *atomic.Pointer[ResourceUsage]
needUpdatedVersion *atomic.Int64 // only for lazy load mode update index
}
func newBaseSegment(collection *Collection, segmentType SegmentType, version int64, loadInfo *querypb.SegmentLoadInfo) baseSegment {
return baseSegment{
collection: collection,
loadInfo: loadInfo,
version: atomic.NewInt64(version),
loadStatus: atomic.NewString(string(LoadStatusMeta)),
segmentType: segmentType,
bloomFilterSet: pkoracle.NewBloomFilterSet(loadInfo.GetSegmentID(), loadInfo.GetPartitionID(), segmentType),
collection: collection,
loadInfo: atomic.NewPointer[querypb.SegmentLoadInfo](loadInfo),
version: atomic.NewInt64(version),
isLazyLoad: isLazyLoad(collection, segmentType),
segmentType: segmentType,
bloomFilterSet: pkoracle.NewBloomFilterSet(loadInfo.GetSegmentID(), loadInfo.GetPartitionID(), segmentType),
resourceUsageCache: atomic.NewPointer[ResourceUsage](nil),
needUpdatedVersion: atomic.NewInt64(0),
}
}
// isLazyLoad checks if the segment is lazy load
func isLazyLoad(collection *Collection, segmentType SegmentType) bool {
return segmentType == SegmentTypeSealed && // only sealed segment enable lazy load
(common.IsCollectionLazyLoadEnabled(collection.Schema().Properties...) || // collection level lazy load
(!common.HasLazyload(collection.Schema().Properties) &&
params.Params.QueryNodeCfg.LazyLoadEnabled.GetAsBool())) // global level lazy load
}
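The precedence here is: an explicit collection-level lazyload property always wins, and only when the key is absent does the global queryNode setting apply; growing segments never lazy load. Distilled into a hedged sketch (lazyLoadEnabled is hypothetical, with a *bool standing in for the property lookup):
// lazyLoadEnabled mirrors the precedence above; the sealed-only
// restriction is the caller's check.
func lazyLoadEnabled(collectionSetting *bool, globalDefault bool) bool {
    if collectionSetting != nil {
        return *collectionSetting // collection property present: it decides
    }
    return globalDefault // property absent: fall back to queryNode config
}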
// ID returns the identity number.
func (s *baseSegment) ID() int64 {
return s.loadInfo.GetSegmentID()
return s.loadInfo.Load().GetSegmentID()
}
func (s *baseSegment) Collection() int64 {
return s.loadInfo.GetCollectionID()
return s.loadInfo.Load().GetCollectionID()
}
func (s *baseSegment) GetCollection() *Collection {
return s.collection
}
func (s *baseSegment) Partition() int64 {
return s.loadInfo.GetPartitionID()
return s.loadInfo.Load().GetPartitionID()
}
func (s *baseSegment) DatabaseName() string {
@ -128,7 +139,7 @@ func (s *baseSegment) ResourceGroup() string {
}
func (s *baseSegment) Shard() string {
return s.loadInfo.GetInsertChannel()
return s.loadInfo.Load().GetInsertChannel()
}
func (s *baseSegment) Type() SegmentType {
@ -136,11 +147,11 @@ func (s *baseSegment) Type() SegmentType {
}
func (s *baseSegment) Level() datapb.SegmentLevel {
return s.loadInfo.GetLevel()
return s.loadInfo.Load().GetLevel()
}
func (s *baseSegment) StartPosition() *msgpb.MsgPosition {
return s.loadInfo.GetStartPosition()
return s.loadInfo.Load().GetStartPosition()
}
func (s *baseSegment) Version() int64 {
@ -151,16 +162,8 @@ func (s *baseSegment) CASVersion(old, newVersion int64) bool {
return s.version.CompareAndSwap(old, newVersion)
}
func (s *baseSegment) LoadStatus() LoadStatus {
return LoadStatus(s.loadStatus.Load())
}
func (s *baseSegment) LoadInfo() *querypb.SegmentLoadInfo {
if s.segmentType == SegmentTypeGrowing {
// Growing segment do not have load info.
return nil
}
return s.loadInfo
return s.loadInfo.Load()
}
func (s *baseSegment) UpdateBloomFilter(pks []storage.PrimaryKey) {
@ -185,7 +188,7 @@ func (s *baseSegment) ResourceUsageEstimate() ResourceUsage {
return *cache
}
usage, err := getResourceUsageEstimateOfSegment(s.collection.Schema(), s.loadInfo, resourceEstimateFactor{
usage, err := getResourceUsageEstimateOfSegment(s.collection.Schema(), s.LoadInfo(), resourceEstimateFactor{
memoryUsageFactor: 1.0,
memoryIndexUsageFactor: 1.0,
enableTempSegmentIndex: false,
@ -200,7 +203,21 @@ func (s *baseSegment) ResourceUsageEstimate() ResourceUsage {
return *usage
}
func (s *baseSegment) IsLazyLoad() bool { return s.isLazyLoad }
func (s *baseSegment) IsLazyLoad() bool {
return s.isLazyLoad
}
func (s *baseSegment) NeedUpdatedVersion() int64 {
return s.needUpdatedVersion.Load()
}
func (s *baseSegment) SetLoadInfo(loadInfo *querypb.SegmentLoadInfo) {
s.loadInfo.Store(loadInfo)
}
func (s *baseSegment) SetNeedUpdatedVersion(version int64) {
s.needUpdatedVersion.Store(version)
}
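Holding loadInfo behind an atomic.Pointer lets every accessor take a consistent snapshot per call while the lazy-load path publishes a refreshed *querypb.SegmentLoadInfo without a mutex. A hedged sketch of the publish side (refreshLoadInfo is hypothetical; how the refreshed info is built is assumed):
// refreshLoadInfo is a hypothetical helper showing the lock-free publish.
func refreshLoadInfo(s *baseSegment, refreshed *querypb.SegmentLoadInfo, version int64) {
    // atomic publish: concurrent readers of ID()/Shard()/LoadInfo() simply
    // observe the new pointer on their next Load, no locking required
    s.SetLoadInfo(refreshed)
    // record the index version the lazy-load path must pick up on next use
    s.SetNeedUpdatedVersion(version)
}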
type FieldInfo struct {
datapb.FieldBinlog
@ -288,10 +305,9 @@ func NewSegment(ctx context.Context,
insertCount: atomic.NewInt64(0),
}
if segmentType != SegmentTypeSealed {
segment.loadStatus.Store(string(LoadStatusInMemory))
if err := segment.initializeSegment(); err != nil {
return nil, err
}
return segment, nil
}
@ -355,11 +371,57 @@ func NewSegmentV2(
insertCount: atomic.NewInt64(0),
}
if segmentType != SegmentTypeSealed {
segment.loadStatus.Store(string(LoadStatusInMemory))
if err := segment.initializeSegment(); err != nil {
return nil, err
}
return segment, nil
}
func (s *LocalSegment) initializeSegment() error {
loadInfo := s.loadInfo.Load()
indexedFieldInfos, fieldBinlogs := separateIndexAndBinlog(loadInfo)
schemaHelper, _ := typeutil.CreateSchemaHelper(s.collection.Schema())
for fieldID, info := range indexedFieldInfos {
field, err := schemaHelper.GetFieldFromID(fieldID)
if err != nil {
return err
}
indexInfo := info.IndexInfo
s.fieldIndexes.Insert(indexInfo.GetFieldID(), &IndexedFieldInfo{
FieldBinlog: &datapb.FieldBinlog{
FieldID: indexInfo.GetFieldID(),
},
IndexInfo: indexInfo,
IsLoaded: false,
})
if !typeutil.IsVectorType(field.GetDataType()) && !s.HasRawData(fieldID) {
s.fields.Insert(fieldID, &FieldInfo{
FieldBinlog: *info.FieldBinlog,
RowCount: loadInfo.GetNumOfRows(),
})
}
}
return segment, nil
for _, binlogs := range fieldBinlogs {
s.fields.Insert(binlogs.FieldID, &FieldInfo{
FieldBinlog: *binlogs,
RowCount: loadInfo.GetNumOfRows(),
})
}
// Update the insert count when initializing the segment and update the metrics.
s.insertCount.Store(loadInfo.GetNumOfRows())
metrics.QueryNodeNumEntities.WithLabelValues(
s.DatabaseName(),
fmt.Sprint(paramtable.GetNodeID()),
fmt.Sprint(s.Collection()),
fmt.Sprint(s.Partition()),
s.Type().String(),
strconv.FormatInt(int64(len(s.Indexes())), 10),
).Add(float64(loadInfo.GetNumOfRows()))
return nil
}
// PinIfNotReleased acquires the `ptrLock` and returns true if the pointer is valid
@ -426,10 +488,6 @@ func (s *LocalSegment) LastDeltaTimestamp() uint64 {
return s.lastDeltaTimestamp.Load()
}
func (s *LocalSegment) addIndex(fieldID int64, info *IndexedFieldInfo) {
s.fieldIndexes.Insert(fieldID, info)
}
func (s *LocalSegment) GetIndex(fieldID int64) *IndexedFieldInfo {
info, _ := s.fieldIndexes.Get(fieldID)
return info
@ -464,7 +522,7 @@ func (s *LocalSegment) Indexes() []*IndexedFieldInfo {
func (s *LocalSegment) ResetIndexesLazyLoad(lazyState bool) {
for _, indexInfo := range s.Indexes() {
indexInfo.LazyLoad = lazyState
indexInfo.IsLoaded = lazyState
}
}
@ -720,6 +778,15 @@ func (s *LocalSegment) Insert(ctx context.Context, rowIDs []int64, timestamps []
}
s.insertCount.Add(int64(numOfRow))
metrics.QueryNodeNumEntities.WithLabelValues(
s.DatabaseName(),
fmt.Sprint(paramtable.GetNodeID()),
fmt.Sprint(s.Collection()),
fmt.Sprint(s.Partition()),
s.Type().String(),
strconv.FormatInt(int64(len(s.Indexes())), 10),
).Add(float64(numOfRow))
s.rowNum.Store(-1)
s.memSize.Store(-1)
return nil
@ -801,7 +868,11 @@ func (s *LocalSegment) Delete(ctx context.Context, primaryKeys []storage.Primary
}
// -------------------------------------------------------------------------------------- interfaces for sealed segment
func (s *LocalSegment) LoadMultiFieldData(ctx context.Context, rowCount int64, fields []*datapb.FieldBinlog) error {
func (s *LocalSegment) LoadMultiFieldData(ctx context.Context) error {
loadInfo := s.loadInfo.Load()
rowCount := loadInfo.GetNumOfRows()
fields := loadInfo.GetBinlogPaths()
if !s.ptrLock.RLockIf(state.IsNotReleased) {
return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
@ -859,7 +930,6 @@ func (s *LocalSegment) LoadMultiFieldData(ctx context.Context, rowCount int64, f
return err
}
s.insertCount.Store(rowCount)
log.Info("load mutil field done",
zap.Int64("row count", rowCount),
zap.Int64("segmentID", s.ID()))
@ -867,43 +937,7 @@ func (s *LocalSegment) LoadMultiFieldData(ctx context.Context, rowCount int64, f
return nil
}
type loadOptions struct {
LoadStatus LoadStatus
}
func newLoadOptions() *loadOptions {
return &loadOptions{
LoadStatus: LoadStatusInMemory,
}
}
type loadOption func(*loadOptions)
func WithLoadStatus(loadStatus LoadStatus) loadOption {
return func(options *loadOptions) {
options.LoadStatus = loadStatus
}
}
func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCount int64, field *datapb.FieldBinlog, opts ...loadOption) error {
options := newLoadOptions()
for _, opt := range opts {
opt(options)
}
if field != nil {
s.fields.Insert(fieldID, &FieldInfo{
FieldBinlog: *field,
RowCount: rowCount,
})
}
if options.LoadStatus == LoadStatusMeta {
return nil
}
s.loadStatus.Store(string(options.LoadStatus))
func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCount int64, field *datapb.FieldBinlog) error {
if !s.ptrLock.RLockIf(state.IsNotReleased) {
return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
@ -941,7 +975,9 @@ func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCoun
}
}
mmapEnabled := options.LoadStatus == LoadStatusMapped
collection := s.collection
mmapEnabled := common.IsFieldMmapEnabled(collection.Schema(), fieldID) ||
(!common.FieldHasMmapKey(collection.Schema(), fieldID) && params.Params.QueryNodeCfg.MmapEnabled.GetAsBool())
loadFieldDataInfo.appendMMapDirPath(paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue())
loadFieldDataInfo.enableMmap(fieldID, mmapEnabled)
@ -970,7 +1006,6 @@ func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCoun
return err
}
s.insertCount.Store(rowCount)
log.Info("load field done")
return nil
@ -1194,12 +1229,7 @@ func (s *LocalSegment) LoadDeltaData(ctx context.Context, deltaData *storage.Del
return nil
}
func (s *LocalSegment) LoadIndex(ctx context.Context, indexInfo *querypb.FieldIndexInfo, fieldType schemapb.DataType, opts ...loadOption) error {
options := newLoadOptions()
for _, opt := range opts {
opt(options)
}
func (s *LocalSegment) LoadIndex(ctx context.Context, indexInfo *querypb.FieldIndexInfo, fieldType schemapb.DataType) error {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", s.Collection()),
zap.Int64("partitionID", s.Partition()),
@ -1208,19 +1238,9 @@ func (s *LocalSegment) LoadIndex(ctx context.Context, indexInfo *querypb.FieldIn
zap.Int64("indexID", indexInfo.GetIndexID()),
)
if options.LoadStatus == LoadStatusMeta {
s.addIndex(indexInfo.GetFieldID(), &IndexedFieldInfo{
FieldBinlog: &datapb.FieldBinlog{
FieldID: indexInfo.GetFieldID(),
},
IndexInfo: indexInfo,
LazyLoad: true,
})
return nil
}
old := s.GetIndex(indexInfo.GetFieldID())
// the index loaded
if old != nil && old.IndexInfo.GetIndexID() == indexInfo.GetIndexID() && !old.LazyLoad {
if old != nil && old.IndexInfo.GetIndexID() == indexInfo.GetIndexID() && old.IsLoaded {
log.Warn("index already loaded")
return nil
}
@ -1228,20 +1248,23 @@ func (s *LocalSegment) LoadIndex(ctx context.Context, indexInfo *querypb.FieldIn
ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, fmt.Sprintf("LoadIndex-%d-%d", s.ID(), indexInfo.GetFieldID()))
defer sp.End()
tr := timerecord.NewTimeRecorder("loadIndex")
// 1. create the C-side load index info
loadIndexInfo, err := newLoadIndexInfo(ctx)
if err != nil {
return err
}
defer deleteLoadIndexInfo(loadIndexInfo)
if paramtable.Get().CommonCfg.EnableStorageV2.GetAsBool() {
uri, err := typeutil_internal.GetStorageURI(paramtable.Get().CommonCfg.StorageScheme.GetValue(), paramtable.Get().CommonCfg.StoragePathPrefix.GetValue(), s.ID())
if err != nil {
return err
}
loadIndexInfo.appendStorageInfo(uri, indexInfo.IndexStoreVersion)
}
newLoadIndexInfoSpan := tr.RecordSpan()
// 2. append the index meta and file info to the load index info
err = loadIndexInfo.appendLoadIndexInfo(ctx, indexInfo, s.Collection(), s.Partition(), s.ID(), fieldType)
if err != nil {
if loadIndexInfo.cleanLocalData(ctx) != nil {
@ -1255,16 +1278,27 @@ func (s *LocalSegment) LoadIndex(ctx context.Context, indexInfo *querypb.FieldIn
errMsg := fmt.Sprintln("updateSegmentIndex failed, illegal segment type ", s.segmentType, "segmentID = ", s.ID())
return errors.New(errMsg)
}
appendLoadIndexInfoSpan := tr.RecordSpan()
// 3. update the segment's index info with the loaded index
err = s.UpdateIndexInfo(ctx, indexInfo, loadIndexInfo)
if err != nil {
return err
}
updateIndexInfoSpan := tr.RecordSpan()
if !typeutil.IsVectorType(fieldType) || s.HasRawData(indexInfo.GetFieldID()) {
return nil
}
// 4. warm up the chunk cache when the vector field has no raw data in memory
s.WarmupChunkCache(ctx, indexInfo.GetFieldID())
warmupChunkCacheSpan := tr.RecordSpan()
log.Info("Finish loading index",
zap.Duration("newLoadIndexInfoSpan", newLoadIndexInfoSpan),
zap.Duration("appendLoadIndexInfoSpan", appendLoadIndexInfoSpan),
zap.Duration("updateIndexInfoSpan", updateIndexInfoSpan),
zap.Duration("updateIndexInfoSpan", warmupChunkCacheSpan),
)
return nil
}
@ -1294,12 +1328,12 @@ func (s *LocalSegment) UpdateIndexInfo(ctx context.Context, indexInfo *querypb.F
return err
}
s.addIndex(indexInfo.GetFieldID(), &IndexedFieldInfo{
s.fieldIndexes.Insert(indexInfo.GetFieldID(), &IndexedFieldInfo{
FieldBinlog: &datapb.FieldBinlog{
FieldID: indexInfo.GetFieldID(),
},
IndexInfo: indexInfo,
LazyLoad: false,
IsLoaded: true,
})
log.Info("updateSegmentIndex done")
return nil
@ -1399,7 +1433,7 @@ func WithReleaseScope(scope ReleaseScope) releaseOption {
}
}
func (s *LocalSegment) Release(opts ...releaseOption) {
func (s *LocalSegment) Release(ctx context.Context, opts ...releaseOption) {
options := newReleaseOptions()
for _, opt := range opts {
opt(options)
@ -1411,7 +1445,7 @@ func (s *LocalSegment) Release(opts ...releaseOption) {
// release will never fail
defer stateLockGuard.Done(nil)
log := log.With(zap.Int64("collectionID", s.Collection()),
log := log.Ctx(ctx).With(zap.Int64("collectionID", s.Collection()),
zap.Int64("partitionID", s.Partition()),
zap.Int64("segmentID", s.ID()),
zap.String("segmentType", s.segmentType.String()),
@ -1421,10 +1455,8 @@ func (s *LocalSegment) Release(opts ...releaseOption) {
// wait all read ops finished
ptr := s.ptr
if options.Scope == ReleaseScopeData {
s.loadStatus.Store(string(LoadStatusMeta))
C.ClearSegmentData(ptr)
s.ResetIndexesLazyLoad(true)
log.Debug("release segment data done and the field indexes info has been set lazy load=true")
s.ReleaseSegmentData()
log.Info("release segment data done and the field indexes info has been set lazy load=true")
return
}
@ -1439,6 +1471,14 @@ func (s *LocalSegment) Release(opts ...releaseOption) {
log.Info("delete segment from memory")
}
// ReleaseSegmentData releases the segment data.
func (s *LocalSegment) ReleaseSegmentData() {
C.ClearSegmentData(s.ptr)
for _, indexInfo := range s.Indexes() {
indexInfo.IsLoaded = false
}
}
// StartLoadData starts the loading process of the segment.
func (s *LocalSegment) StartLoadData() (state.LoadStateLockGuard, error) {
return s.ptrLock.StartLoadData()
@ -1455,3 +1495,34 @@ func (s *LocalSegment) startRelease(scope ReleaseScope) state.LoadStateLockGuard
panic(fmt.Sprintf("unexpected release scope %d", scope))
}
}
func (s *LocalSegment) RemoveFieldFile(fieldId int64) {
C.RemoveFieldFile(s.ptr, C.int64_t(fieldId))
}
func (s *LocalSegment) RemoveUnusedFieldFiles() error {
schema := s.collection.Schema()
indexInfos, _ := separateIndexAndBinlog(s.LoadInfo())
for _, indexInfo := range indexInfos {
need, err := s.indexNeedLoadRawData(schema, indexInfo)
if err != nil {
return err
}
if !need {
s.RemoveFieldFile(indexInfo.IndexInfo.FieldID)
}
}
return nil
}
func (s *LocalSegment) indexNeedLoadRawData(schema *schemapb.CollectionSchema, indexInfo *IndexedFieldInfo) (bool, error) {
schemaHelper, err := typeutil.CreateSchemaHelper(schema)
if err != nil {
return false, err
}
fieldSchema, err := schemaHelper.GetFieldFromID(indexInfo.IndexInfo.FieldID)
if err != nil {
return false, err
}
return !typeutil.IsVectorType(fieldSchema.DataType) && s.HasRawData(indexInfo.IndexInfo.FieldID), nil
}

View File

@ -27,14 +27,6 @@ import (
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type LoadStatus string
const (
LoadStatusMeta LoadStatus = "meta"
LoadStatusMapped LoadStatus = "mapped"
LoadStatusInMemory LoadStatus = "in_memory"
)
// ResourceUsage is used to estimate the resource usage of a sealed segment.
type ResourceUsage struct {
MemorySize uint64
@ -60,7 +52,6 @@ type Segment interface {
StartPosition() *msgpb.MsgPosition
Type() SegmentType
Level() datapb.SegmentLevel
LoadStatus() LoadStatus
LoadInfo() *querypb.SegmentLoadInfo
// PinIfNotReleased the segment to prevent it from being released
PinIfNotReleased() error
@ -87,7 +78,7 @@ type Segment interface {
Delete(ctx context.Context, primaryKeys []storage.PrimaryKey, timestamps []typeutil.Timestamp) error
LoadDeltaData(ctx context.Context, deltaData *storage.DeleteData) error
LastDeltaTimestamp() uint64
Release(opts ...releaseOption)
Release(ctx context.Context, opts ...releaseOption)
// Bloom filter related
UpdateBloomFilter(pks []storage.PrimaryKey)
@ -99,4 +90,8 @@ type Segment interface {
RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error)
IsLazyLoad() bool
ResetIndexesLazyLoad(lazyState bool)
// lazy load related
NeedUpdatedVersion() int64
RemoveUnusedFieldFiles() error
}

View File

@ -63,8 +63,6 @@ func NewL0Segment(collection *Collection,
}
// level 0 segments are always in memory
segment.loadStatus.Store(string(LoadStatusInMemory))
return segment, nil
}
@ -164,10 +162,14 @@ func (s *L0Segment) DeleteRecords() ([]storage.PrimaryKey, []uint64) {
return s.pks, s.tss
}
func (s *L0Segment) Release(opts ...releaseOption) {
func (s *L0Segment) Release(ctx context.Context, opts ...releaseOption) {
s.dataGuard.Lock()
defer s.dataGuard.Unlock()
s.pks = nil
s.tss = nil
}
func (s *L0Segment) RemoveUnusedFieldFiles() error {
panic("not implemented")
}

View File

@ -82,7 +82,11 @@ type Loader interface {
LoadSegment(ctx context.Context,
segment *LocalSegment,
loadInfo *querypb.SegmentLoadInfo,
loadStatus LoadStatus,
) error
LoadLazySegment(ctx context.Context,
segment *LocalSegment,
loadInfo *querypb.SegmentLoadInfo,
) error
}
@ -171,7 +175,7 @@ func (loader *segmentLoaderV2) Load(ctx context.Context,
loaded := typeutil.NewConcurrentMap[int64, Segment]()
defer func() {
newSegments.Range(func(_ int64, s Segment) bool {
s.Release()
s.Release(context.Background())
return true
})
debug.FreeOSMemory()
@ -214,7 +218,7 @@ func (loader *segmentLoaderV2) Load(ctx context.Context,
if loadInfo.GetLevel() == datapb.SegmentLevel_L0 {
err = loader.LoadDelta(ctx, collectionID, segment.(*LocalSegment))
} else {
err = loader.LoadSegment(ctx, segment.(*LocalSegment), loadInfo, LoadStatusInMemory)
err = loader.LoadSegment(ctx, segment.(*LocalSegment), loadInfo)
}
if err != nil {
log.Warn("load segment failed when load data into memory",
@ -224,7 +228,7 @@ func (loader *segmentLoaderV2) Load(ctx context.Context,
)
return err
}
loader.manager.Segment.Put(segmentType, segment)
loader.manager.Segment.Put(ctx, segmentType, segment)
newSegments.GetAndRemove(segmentID)
loaded.Insert(segmentID, segment)
log.Info("load segment done", zap.Int64("segmentID", segmentID))
@ -374,7 +378,6 @@ func (loader *segmentLoaderV2) loadBloomFilter(ctx context.Context, segmentID in
func (loader *segmentLoaderV2) LoadSegment(ctx context.Context,
segment *LocalSegment,
loadInfo *querypb.SegmentLoadInfo,
loadstatus LoadStatus,
) (err error) {
// TODO: we should create a transaction-like api to load segment for segment interface,
// but not do many things in segment loader.
@ -461,7 +464,7 @@ func (loader *segmentLoaderV2) LoadSegment(ctx context.Context,
return err
}
} else {
if err := segment.LoadMultiFieldData(ctx, loadInfo.GetNumOfRows(), loadInfo.BinlogPaths); err != nil {
if err := segment.LoadMultiFieldData(ctx); err != nil {
return err
}
}
@ -480,6 +483,13 @@ func (loader *segmentLoaderV2) LoadSegment(ctx context.Context,
return loader.LoadDelta(ctx, segment.Collection(), segment)
}
func (loader *segmentLoaderV2) LoadLazySegment(ctx context.Context,
segment *LocalSegment,
loadInfo *querypb.SegmentLoadInfo,
) (err error) {
return merr.ErrOperationNotSupported
}
func (loader *segmentLoaderV2) loadSealedSegmentFields(ctx context.Context, segment *LocalSegment, fields *typeutil.ConcurrentMap[int64, *schemapb.FieldSchema], rowCount int64) error {
runningGroup, _ := errgroup.WithContext(ctx)
fields.Range(func(fieldID int64, field *schemapb.FieldSchema) bool {
@ -597,14 +607,20 @@ func (loader *segmentLoader) Load(ctx context.Context,
// continue to wait other task done
log.Info("start loading...", zap.Int("segmentNum", len(segments)), zap.Int("afterFilter", len(infos)))
// Check memory & storage limit
resource, concurrencyLevel, err := loader.requestResource(ctx, infos...)
if err != nil {
log.Warn("request resource failed", zap.Error(err))
return nil, err
var err error
var resource LoadResource
var concurrencyLevel int
coll := loader.manager.Collection.Get(collectionID)
if !isLazyLoad(coll, segmentType) {
// Check memory & storage limit
// no need to check resource for lazy load here
resource, concurrencyLevel, err = loader.requestResource(ctx, infos...)
if err != nil {
log.Warn("request resource failed", zap.Error(err))
return nil, err
}
defer loader.freeRequest(resource)
}
defer loader.freeRequest(resource)
newSegments := typeutil.NewConcurrentMap[int64, Segment]()
loaded := typeutil.NewConcurrentMap[int64, Segment]()
defer func() {
@ -613,13 +629,12 @@ func (loader *segmentLoader) Load(ctx context.Context,
zap.Int64("segmentID", segmentID),
zap.Error(err),
)
s.Release()
s.Release(context.Background())
return true
})
debug.FreeOSMemory()
}()
loadStatus := LoadStatusInMemory
collection := loader.manager.Collection.Get(collectionID)
if collection == nil {
err := merr.WrapErrCollectionNotFound(collectionID)
@ -627,11 +642,6 @@ func (loader *segmentLoader) Load(ctx context.Context,
return nil, err
}
if common.IsCollectionLazyLoadEnabled(collection.Schema().Properties...) ||
(!common.HasLazyload(collection.Schema().Properties) && params.Params.QueryNodeCfg.LazyLoadEnabled.GetAsBool()) {
loadStatus = LoadStatusMeta
}
for _, info := range infos {
loadInfo := info
@ -674,17 +684,21 @@ func (loader *segmentLoader) Load(ctx context.Context,
tr := timerecord.NewTimeRecorder("loadDurationPerSegment")
logger.Info("load segment...")
// L0 segment has no index or data to be load
// L0 segment has no index or data to be loaded.
if loadInfo.GetLevel() != datapb.SegmentLevel_L0 {
if err = loader.LoadSegment(ctx, segment.(*LocalSegment), loadInfo, loadStatus); err != nil {
return errors.Wrap(err, "At LoadSegment")
s := segment.(*LocalSegment)
// lazy load segments do not load data the first time.
if !s.IsLazyLoad() {
if err = loader.LoadSegment(ctx, s, loadInfo); err != nil {
return errors.Wrap(err, "At LoadSegment")
}
}
}
if err = loader.LoadDeltaLogs(ctx, segment, loadInfo.GetDeltalogs()); err != nil {
return errors.Wrap(err, "At LoadDeltaLogs")
}
loader.manager.Segment.Put(segmentType, segment)
loader.manager.Segment.Put(ctx, segmentType, segment)
newSegments.GetAndRemove(segmentID)
loaded.Insert(segmentID, segment)
loader.notifyLoadFinish(loadInfo)
@ -975,23 +989,45 @@ func separateIndexAndBinlog(loadInfo *querypb.SegmentLoadInfo) (map[int64]*Index
return indexedFieldInfos, fieldBinlogs
}
func (loader *segmentLoader) loadSealedSegment(ctx context.Context, loadInfo *querypb.SegmentLoadInfo, segment *LocalSegment, collection *Collection, loadStatus LoadStatus) error {
func (loader *segmentLoader) loadSealedSegment(ctx context.Context, loadInfo *querypb.SegmentLoadInfo, segment *LocalSegment) error {
// TODO: we should create a transaction-like api to load segment for segment interface,
// but not do many things in segment loader.
stateLockGuard, err := segment.StartLoadData()
// segment can not do load now.
if err != nil {
return err
}
if stateLockGuard == nil {
return nil
}
defer func() {
if err != nil {
// Release partial loaded segment data if load failed.
segment.ReleaseSegmentData()
}
stateLockGuard.Done(err)
}()
collection := segment.GetCollection()
indexedFieldInfos, fieldBinlogs := separateIndexAndBinlog(loadInfo)
schemaHelper, _ := typeutil.CreateSchemaHelper(collection.Schema())
if err := segment.AddFieldDataInfo(ctx, loadInfo.GetNumOfRows(), loadInfo.GetBinlogPaths()); err != nil {
return err
}
log := log.Ctx(ctx).With(zap.Int64("segmentID", segment.ID()))
tr := timerecord.NewTimeRecorder("segmentLoader.LoadIndex")
log.Info("load fields...",
tr := timerecord.NewTimeRecorder("segmentLoader.loadSealedSegment")
log.Info("Start loading fields...",
zap.Int64s("indexedFields", lo.Keys(indexedFieldInfos)),
)
if err := loader.loadFieldsIndex(ctx, schemaHelper, segment, loadInfo.GetNumOfRows(), indexedFieldInfos, WithLoadStatus(loadStatus)); err != nil {
if err := loader.loadFieldsIndex(ctx, schemaHelper, segment, loadInfo.GetNumOfRows(), indexedFieldInfos); err != nil {
return err
}
metrics.QueryNodeLoadIndexLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(tr.ElapseSpan().Milliseconds()))
loadFieldsIndexSpan := tr.RecordSpan()
metrics.QueryNodeLoadIndexLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(loadFieldsIndexSpan.Milliseconds()))
// 2. complement raw data for the scalar fields without raw data
for fieldID, info := range indexedFieldInfos {
field, err := schemaHelper.GetFieldFromID(fieldID)
if err != nil {
@ -1002,49 +1038,39 @@ func (loader *segmentLoader) loadSealedSegment(ctx context.Context, loadInfo *qu
zap.Int64("fieldID", fieldID),
zap.String("index", info.IndexInfo.GetIndexName()),
)
status := loadStatus
if status != LoadStatusMeta {
status = LoadStatusMapped
}
// for a scalar index's raw data, load via mmap only, not into memory
if err = segment.LoadFieldData(ctx, fieldID, loadInfo.GetNumOfRows(), info.FieldBinlog, WithLoadStatus(status)); err != nil {
if err = segment.LoadFieldData(ctx, fieldID, loadInfo.GetNumOfRows(), info.FieldBinlog); err != nil {
log.Warn("load raw data failed", zap.Int64("fieldID", fieldID), zap.Error(err))
return err
}
}
}
if err := loadSealedSegmentFields(ctx, collection, segment, fieldBinlogs, loadInfo.GetNumOfRows(), WithLoadStatus(loadStatus)); err != nil {
complementScalarDataSpan := tr.RecordSpan()
if err := loadSealedSegmentFields(ctx, collection, segment, fieldBinlogs, loadInfo.GetNumOfRows()); err != nil {
return err
}
loadRawDataSpan := tr.RecordSpan()
// 4. rectify the entry numbers of binlogs in very rare cases
// https://github.com/milvus-io/milvus/issues/23654
// legacy entry num = 0
if err := loader.patchEntryNumber(ctx, segment, loadInfo); err != nil {
return err
}
patchEntryNumberSpan := tr.RecordSpan()
log.Info("Finish loading segment",
zap.Duration("loadFieldsIndexSpan", loadFieldsIndexSpan),
zap.Duration("complementScalarDataSpan", complementScalarDataSpan),
zap.Duration("loadRawDataSpan", loadRawDataSpan),
zap.Duration("patchEntryNumberSpan", patchEntryNumberSpan),
)
return nil
}
func (loader *segmentLoader) LoadSegment(ctx context.Context,
segment *LocalSegment,
loadInfo *querypb.SegmentLoadInfo,
loadStatus LoadStatus,
) (err error) {
// TODO: we should create a transaction-like api to load segment for segment interface,
// but not do many things in segment loader.
stateLockGuard, err := segment.StartLoadData()
// the segment cannot start loading right now.
if err != nil {
return err
}
defer func() {
// segment is already loaded.
// TODO: if stateLockGuard is nil, we should not call LoadSegment anymore,
// but the current Load flow is not clear enough to do an actual state transition; keep the previous logic to avoid introducing bugs.
if stateLockGuard != nil {
stateLockGuard.Done(err)
}
}()
log := log.Ctx(ctx).With(
zap.Int64("collectionID", segment.Collection()),
zap.Int64("partitionID", segment.Partition()),
@ -1068,15 +1094,11 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context,
defer debug.FreeOSMemory()
if segment.Type() == SegmentTypeSealed {
if loadStatus == LoadStatusMeta {
segment.baseSegment.isLazyLoad = true
segment.baseSegment.loadInfo = loadInfo
}
if err := loader.loadSealedSegment(ctx, loadInfo, segment, collection, loadStatus); err != nil {
if err := loader.loadSealedSegment(ctx, loadInfo, segment); err != nil {
return err
}
} else {
if err := segment.LoadMultiFieldData(ctx, loadInfo.GetNumOfRows(), loadInfo.BinlogPaths); err != nil {
if err := segment.LoadMultiFieldData(ctx); err != nil {
return err
}
}
@ -1090,10 +1112,24 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context,
return err
}
}
return nil
}
func (loader *segmentLoader) LoadLazySegment(ctx context.Context,
segment *LocalSegment,
loadInfo *querypb.SegmentLoadInfo,
) (err error) {
infos := []*querypb.SegmentLoadInfo{loadInfo}
resource, _, err := loader.requestResource(ctx, infos...)
log := log.Ctx(ctx)
if err != nil {
log.Warn("request resource failed", zap.Error(err))
return merr.ErrServiceResourceInsufficient
}
defer loader.freeRequest(resource)
return loader.LoadSegment(ctx, segment, loadInfo)
}
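For orientation, a hedged sketch of the intended call pattern: Load now leaves a lazy segment's data unloaded, and a later caller materializes it on demand. The helper ensureLoaded below is hypothetical; only IsLazyLoad, LoadInfo, and LoadLazySegment come from this diff.
// Hypothetical helper (package segments): materialize a lazy segment before use.
func ensureLoaded(ctx context.Context, loader *segmentLoader, seg *LocalSegment) error {
	if !seg.IsLazyLoad() {
		return nil // data was loaded eagerly by Load
	}
	// LoadLazySegment reserves resources before loading and frees them after;
	// on exhaustion it returns merr.ErrServiceResourceInsufficient so the
	// caller can wait and retry.
	return loader.LoadLazySegment(ctx, seg, seg.LoadInfo())
}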
func (loader *segmentLoader) filterPKStatsBinlogs(fieldBinlogs []*datapb.FieldBinlog, pkFieldID int64) ([]string, storage.StatsLogType) {
result := make([]string, 0)
for _, fieldBinlog := range fieldBinlogs {
@ -1114,27 +1150,16 @@ func (loader *segmentLoader) filterPKStatsBinlogs(fieldBinlogs []*datapb.FieldBi
return result, storage.DefaultStatsType
}
func loadSealedSegmentFields(ctx context.Context, collection *Collection, segment *LocalSegment, fields []*datapb.FieldBinlog, rowCount int64, opts ...loadOption) error {
options := newLoadOptions()
for _, opt := range opts {
opt(options)
}
func loadSealedSegmentFields(ctx context.Context, collection *Collection, segment *LocalSegment, fields []*datapb.FieldBinlog, rowCount int64) error {
runningGroup, _ := errgroup.WithContext(ctx)
for _, field := range fields {
opts := opts
fieldBinLog := field
fieldID := field.FieldID
mmapEnabled := common.IsFieldMmapEnabled(collection.Schema(), fieldID) ||
(!common.FieldHasMmapKey(collection.Schema(), fieldID) && params.Params.QueryNodeCfg.MmapEnabled.GetAsBool())
if mmapEnabled && options.LoadStatus == LoadStatusInMemory {
opts = append(opts, WithLoadStatus(LoadStatusMapped))
}
runningGroup.Go(func() error {
return segment.LoadFieldData(ctx,
fieldID,
rowCount,
fieldBinLog,
opts...,
)
})
}
@ -1157,7 +1182,6 @@ func (loader *segmentLoader) loadFieldsIndex(ctx context.Context,
segment *LocalSegment,
numRows int64,
indexedFieldInfos map[int64]*IndexedFieldInfo,
opts ...loadOption,
) error {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", segment.Collection()),
@ -1168,7 +1192,9 @@ func (loader *segmentLoader) loadFieldsIndex(ctx context.Context,
for fieldID, fieldInfo := range indexedFieldInfos {
indexInfo := fieldInfo.IndexInfo
err := loader.loadFieldIndex(ctx, segment, indexInfo, opts...)
tr := timerecord.NewTimeRecorder("loadFieldIndex")
err := loader.loadFieldIndex(ctx, segment, indexInfo)
loadFieldIndexSpan := tr.RecordSpan()
if err != nil {
return err
}
@ -1177,6 +1203,7 @@ func (loader *segmentLoader) loadFieldsIndex(ctx context.Context,
zap.Int64("fieldID", fieldID),
zap.Any("binlog", fieldInfo.FieldBinlog.Binlogs),
zap.Int32("current_index_version", fieldInfo.IndexInfo.GetCurrentIndexVersion()),
zap.Duration("load_duration", loadFieldIndexSpan),
)
// set average row data size of variable field
@ -1195,7 +1222,7 @@ func (loader *segmentLoader) loadFieldsIndex(ctx context.Context,
return nil
}
func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalSegment, indexInfo *querypb.FieldIndexInfo, opts ...loadOption) error {
func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalSegment, indexInfo *querypb.FieldIndexInfo) error {
filteredPaths := make([]string, 0, len(indexInfo.IndexFilePaths))
for _, indexPath := range indexInfo.IndexFilePaths {
@ -1215,7 +1242,7 @@ func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalS
return merr.WrapErrCollectionNotLoaded(segment.Collection(), "failed to load field index")
}
return segment.LoadIndex(ctx, indexInfo, fieldType, opts...)
return segment.LoadIndex(ctx, indexInfo, fieldType)
}
func (loader *segmentLoader) loadBloomFilter(ctx context.Context, segmentID int64, bfs *pkoracle.BloomFilterSet,
@ -1526,20 +1553,20 @@ func getResourceUsageEstimateOfSegment(schema *schemapb.CollectionSchema, loadIn
loadInfo.GetSegmentID(),
fieldIndexInfo.GetBuildID())
}
segmentMemorySize += neededMemSize
if mmapEnabled {
segmentDiskSize += neededMemSize + neededDiskSize
} else {
segmentMemorySize += neededMemSize
segmentDiskSize += neededDiskSize
}
} else {
mmapEnabled = common.IsFieldMmapEnabled(schema, fieldID) ||
(!common.FieldHasMmapKey(schema, fieldID) && params.Params.QueryNodeCfg.MmapEnabled.GetAsBool())
binlogSize := uint64(getBinlogDataSize(fieldBinlog))
segmentMemorySize += binlogSize
if mmapEnabled {
segmentDiskSize += binlogSize
} else {
segmentMemorySize += binlogSize
if multiplyFactor.enableTempSegmentIndex {
segmentMemorySize += uint64(float64(binlogSize) * multiplyFactor.tempSegmentIndexFactor)
}
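To make the accounting concrete, a runnable sketch of the raw-binlog branch above (numbers and variable names are illustrative, not the Milvus API):
package main

import "fmt"

func main() {
	binlogSize := uint64(100 << 20) // assume a 100 MiB raw field binlog
	mmapEnabled := false
	enableTempSegmentIndex := true
	tempSegmentIndexFactor := 0.5 // assumed factor

	var memSize, diskSize uint64
	if mmapEnabled {
		diskSize += binlogSize // mmapped raw data is charged to disk
	} else {
		memSize += binlogSize // resident raw data is charged to memory
		if enableTempSegmentIndex {
			// an interim index over the raw data adds a fraction of its size
			memSize += uint64(float64(binlogSize) * tempSegmentIndexFactor)
		}
	}
	fmt.Println(memSize>>20, "MiB memory,", diskSize>>20, "MiB disk") // 150 MiB memory, 0 MiB disk
}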

View File

@ -27,6 +27,7 @@ import (
"github.com/apache/arrow/go/v12/arrow/memory"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"go.uber.org/atomic"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
@ -97,7 +98,7 @@ func (suite *SegmentLoaderSuite) SetupTest() {
func (suite *SegmentLoaderSuite) TearDownTest() {
ctx := context.Background()
for i := 0; i < suite.segmentNum; i++ {
suite.manager.Segment.Remove(suite.segmentID+int64(i), querypb.DataScope_All)
suite.manager.Segment.Remove(context.Background(), suite.segmentID+int64(i), querypb.DataScope_All)
}
suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath)
}
@ -450,7 +451,7 @@ func (suite *SegmentLoaderSuite) TestLoadDupDeltaLogs() {
seg := segment.(*LocalSegment)
// nothing happens as the delta logs have all been applied,
// so the released segment won't cause an error
seg.Release()
seg.Release(ctx)
loadInfos[i].Deltalogs[0].Binlogs[0].TimestampTo--
err := suite.loader.LoadDeltaLogs(ctx, seg, loadInfos[i].GetDeltalogs())
suite.NoError(err)
@ -459,7 +460,6 @@ func (suite *SegmentLoaderSuite) TestLoadDupDeltaLogs() {
func (suite *SegmentLoaderSuite) TestLoadIndex() {
ctx := context.Background()
segment := &LocalSegment{}
loadInfo := &querypb.SegmentLoadInfo{
SegmentID: 1,
PartitionID: suite.partitionID,
@ -470,6 +470,11 @@ func (suite *SegmentLoaderSuite) TestLoadIndex() {
},
},
}
segment := &LocalSegment{
baseSegment: baseSegment{
loadInfo: atomic.NewPointer[querypb.SegmentLoadInfo](loadInfo),
},
}
err := suite.loader.LoadIndex(ctx, segment, loadInfo, 0)
suite.ErrorIs(err, merr.ErrIndexNotFound)
@ -866,7 +871,7 @@ func (suite *SegmentLoaderV2Suite) SetupTest() {
func (suite *SegmentLoaderV2Suite) TearDownTest() {
ctx := context.Background()
for i := 0; i < suite.segmentNum; i++ {
suite.manager.Segment.Remove(suite.segmentID+int64(i), querypb.DataScope_All)
suite.manager.Segment.Remove(context.Background(), suite.segmentID+int64(i), querypb.DataScope_All)
}
suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath)
paramtable.Get().CommonCfg.EnableStorageV2.SwapTempValue("false")

View File

@ -71,6 +71,7 @@ func (suite *SegmentSuite) SetupTest() {
PartitionID: suite.partitionID,
InsertChannel: "dml",
Level: datapb.SegmentLevel_Legacy,
NumOfRows: int64(msgLength),
BinlogPaths: []*datapb.FieldBinlog{
{
FieldID: 101,
@ -123,14 +124,14 @@ func (suite *SegmentSuite) SetupTest() {
err = suite.growing.Insert(ctx, insertMsg.RowIDs, insertMsg.Timestamps, insertRecord)
suite.Require().NoError(err)
suite.manager.Segment.Put(SegmentTypeSealed, suite.sealed)
suite.manager.Segment.Put(SegmentTypeGrowing, suite.growing)
suite.manager.Segment.Put(context.Background(), SegmentTypeSealed, suite.sealed)
suite.manager.Segment.Put(context.Background(), SegmentTypeGrowing, suite.growing)
}
func (suite *SegmentSuite) TearDownTest() {
ctx := context.Background()
suite.sealed.Release()
suite.growing.Release()
suite.sealed.Release(context.Background())
suite.growing.Release(context.Background())
DeleteCollection(suite.collection)
suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath)
}
@ -139,7 +140,7 @@ func (suite *SegmentSuite) TestLoadInfo() {
// sealed segment has load info
suite.NotNil(suite.sealed.LoadInfo())
// growing segment now carries load info as well
suite.Nil(suite.growing.LoadInfo())
suite.NotNil(suite.growing.LoadInfo())
}
func (suite *SegmentSuite) TestResourceUsageEstimate() {
@ -196,8 +197,11 @@ func (suite *SegmentSuite) TestCASVersion() {
suite.Equal(curVersion+1, segment.Version())
}
func (suite *SegmentSuite) TestSegmentRemoveUnusedFieldFiles() {
}
func (suite *SegmentSuite) TestSegmentReleased() {
suite.sealed.Release()
suite.sealed.Release(context.Background())
sealed := suite.sealed.(*LocalSegment)

View File

@ -182,3 +182,13 @@ func getSegmentMetricLabel(segment Segment) metricsutil.SegmentLabel {
ResourceGroup: segment.ResourceGroup(),
}
}
func FilterZeroValuesFromSlice(intVals []int64) []int64 {
var result []int64
for _, value := range intVals {
if value != 0 {
result = append(result, value)
}
}
return result
}

View File

@ -0,0 +1,20 @@
package segments
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFilterZeroValuesFromSlice(t *testing.T) {
var ints []int64
ints = append(ints, 10)
ints = append(ints, 0)
ints = append(ints, 5)
ints = append(ints, 13)
ints = append(ints, 0)
filteredInts := FilterZeroValuesFromSlice(ints)
assert.Equal(t, 3, len(filteredInts))
assert.EqualValues(t, []int64{10, 5, 13}, filteredInts)
}

View File

@ -484,7 +484,7 @@ func (node *QueryNode) Stop() error {
node.dispClient.Close()
}
if node.manager != nil {
node.manager.Segment.Clear()
node.manager.Segment.Clear(context.Background())
}
node.CloseSegcore()

View File

@ -236,7 +236,7 @@ func (suite *QueryNodeSuite) TestStop() {
},
)
suite.NoError(err)
suite.node.manager.Segment.Put(segments.SegmentTypeSealed, segment)
suite.node.manager.Segment.Put(context.Background(), segments.SegmentTypeSealed, segment)
err = suite.node.Stop()
suite.NoError(err)
suite.True(suite.node.manager.Segment.Empty())

View File

@ -301,7 +301,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
defer func() {
if err != nil {
// remove legacy growing
node.manager.Segment.RemoveBy(segments.WithChannel(channel.GetChannelName()),
node.manager.Segment.RemoveBy(ctx, segments.WithChannel(channel.GetChannelName()),
segments.WithType(segments.SegmentTypeGrowing))
}
}()
@ -357,8 +357,8 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
delegator.Close()
node.pipelineManager.Remove(req.GetChannelName())
node.manager.Segment.RemoveBy(segments.WithChannel(req.GetChannelName()), segments.WithType(segments.SegmentTypeGrowing))
node.manager.Segment.RemoveBy(segments.WithChannel(req.GetChannelName()), segments.WithLevel(datapb.SegmentLevel_L0))
node.manager.Segment.RemoveBy(ctx, segments.WithChannel(req.GetChannelName()), segments.WithType(segments.SegmentTypeGrowing))
node.manager.Segment.RemoveBy(ctx, segments.WithChannel(req.GetChannelName()), segments.WithLevel(datapb.SegmentLevel_L0))
node.tSafeManager.Remove(ctx, req.GetChannelName())
node.manager.Collection.Unref(req.GetCollectionID(), 1)
@ -547,7 +547,7 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, req *querypb.Release
log.Info("start to release segments")
sealedCount := 0
for _, id := range req.GetSegmentIDs() {
_, count := node.manager.Segment.Remove(id, req.GetScope())
_, count := node.manager.Segment.Remove(ctx, id, req.GetScope())
sealedCount += count
}
node.manager.Collection.Unref(req.GetCollectionID(), uint32(sealedCount))
@ -660,7 +660,13 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe
return resp, nil
}
task := tasks.NewSearchTask(searchCtx, collection, node.manager, req, node.serverID)
var task tasks.Task
if paramtable.Get().QueryNodeCfg.UseStreamComputing.GetAsBool() {
task = tasks.NewStreamingSearchTask(searchCtx, collection, node.manager, req, node.serverID)
} else {
task = tasks.NewSearchTask(searchCtx, collection, node.manager, req, node.serverID)
}
if err := node.scheduler.Add(task); err != nil {
log.Warn("failed to search channel", zap.Error(err))
resp.Status = merr.Status(err)
@ -683,7 +689,7 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel, metrics.FromLeader).Inc()
resp = task.Result()
resp = task.SearchResult()
resp.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
resp.GetCostAggregation().TotalNQ = node.scheduler.GetWaitingTaskTotalNQ()
return resp, nil
@ -767,7 +773,8 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
}
reduceLatency := tr.RecordSpan()
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards).
metrics.QueryNodeReduceLatency.
WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards, metrics.BatchReduce).
Observe(float64(reduceLatency.Milliseconds()))
collector.Rate.Add(metricsinfo.NQPerSecond, float64(req.GetReq().GetNq()))
@ -914,7 +921,8 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
}, nil
}
reduceLatency := tr.RecordSpan()
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.ReduceShards).
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()),
metrics.QueryLabel, metrics.ReduceShards, metrics.BatchReduce).
Observe(float64(reduceLatency.Milliseconds()))
collector.Rate.Add(metricsinfo.NQPerSecond, 1)

View File

@ -160,7 +160,6 @@ func (suite *ServiceSuite) TearDownTest() {
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
suite.node.chunkManager.RemoveWithPrefix(ctx, suite.rootPath)
suite.node.Stop()
suite.etcdClient.Close()
}
@ -474,9 +473,9 @@ func (suite *ServiceSuite) TestUnsubDmChannels_Normal() {
l0Segment.EXPECT().Type().Return(commonpb.SegmentState_Sealed)
l0Segment.EXPECT().Indexes().Return(nil)
l0Segment.EXPECT().Shard().Return(suite.vchannel)
l0Segment.EXPECT().Release().Return()
l0Segment.EXPECT().Release(ctx).Return()
suite.node.manager.Segment.Put(segments.SegmentTypeSealed, l0Segment)
suite.node.manager.Segment.Put(ctx, segments.SegmentTypeSealed, l0Segment)
// data
req := &querypb.UnsubDmChannelRequest{
@ -1364,6 +1363,48 @@ func (suite *ServiceSuite) TestSearchSegments_Normal() {
suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode())
}
func (suite *ServiceSuite) TestStreamingSearch() {
ctx := context.Background()
// pre
suite.TestWatchDmChannelsInt64()
suite.TestLoadSegments_Int64()
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.UseStreamComputing.Key, "true")
creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType)
req := &querypb.SearchRequest{
Req: creq,
FromShardLeader: true,
DmlChannels: []string{suite.vchannel},
TotalChannelNum: 2,
SegmentIDs: suite.validSegmentIDs,
Scope: querypb.DataScope_Historical,
}
suite.NoError(err)
rsp, err := suite.node.SearchSegments(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode())
}
func (suite *ServiceSuite) TestStreamingSearchGrowing() {
ctx := context.Background()
// pre
suite.TestWatchDmChannelsInt64()
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.UseStreamComputing.Key, "true")
creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType)
req := &querypb.SearchRequest{
Req: creq,
FromShardLeader: true,
DmlChannels: []string{suite.vchannel},
TotalChannelNum: 2,
Scope: querypb.DataScope_Streaming,
}
suite.NoError(err)
rsp, err := suite.node.SearchSegments(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode())
}
// Test Query
func (suite *ServiceSuite) genCQueryRequest(nq int64, indexType string, schema *schemapb.CollectionSchema) (*internalpb.RetrieveRequest, error) {
expr, err := genSimpleRetrievePlanExpr(schema)

View File

@ -5,6 +5,7 @@ import (
"math/rand"
"time"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/pkg/util/timerecord"
)
@ -114,6 +115,10 @@ func (t *MockTask) MergeWith(t2 Task) bool {
return false
}
func (t *MockTask) SearchResult() *internalpb.SearchResults {
return nil
}
func (t *MockTask) NQ() int64 {
return t.nq
}

View File

@ -3,6 +3,7 @@ package tasks
import (
"context"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/util/streamrpc"
@ -86,3 +87,7 @@ func (t *QueryStreamTask) Wait() error {
func (t *QueryStreamTask) NQ() int64 {
return 1
}
func (t *QueryStreamTask) SearchResult() *internalpb.SearchResults {
return nil
}

View File

@ -91,6 +91,10 @@ func (t *QueryTask) PreExecute() error {
return nil
}
func (t *QueryTask) SearchResult() *internalpb.SearchResults {
return nil
}
// Execute the task, only call once.
func (t *QueryTask) Execute() error {
if t.scheduleSpan != nil {
@ -125,7 +129,8 @@ func (t *QueryTask) Execute() error {
metrics.QueryNodeReduceLatency.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryLabel,
metrics.ReduceSegments).Observe(float64(time.Since(beforeReduce).Milliseconds()))
metrics.ReduceSegments,
metrics.BatchReduce).Observe(float64(time.Since(beforeReduce).Milliseconds()))
if err != nil {
return err
}

View File

@ -2,6 +2,8 @@ package tasks
// TODO: rename this file into search_task.go
import "C"
import (
"bytes"
"context"
@ -233,7 +235,8 @@ func (t *SearchTask) Execute() error {
metrics.QueryNodeReduceLatency.WithLabelValues(
fmt.Sprint(t.GetNodeID()),
metrics.SearchLabel,
metrics.ReduceSegments).
metrics.ReduceSegments,
metrics.BatchReduce).
Observe(float64(tr.RecordSpan().Milliseconds()))
for i := range t.originNqs {
blob, err := segments.GetSearchResultDataBlob(t.ctx, blobs, i)
@ -333,7 +336,7 @@ func (t *SearchTask) Wait() error {
return <-t.notifier
}
func (t *SearchTask) Result() *internalpb.SearchResults {
func (t *SearchTask) SearchResult() *internalpb.SearchResults {
if t.result != nil {
channelsMvcc := make(map[string]uint64)
for _, ch := range t.req.GetDmlChannels() {
@ -382,3 +385,214 @@ func (t *SearchTask) combinePlaceHolderGroups() error {
t.placeholderGroup, _ = proto.Marshal(ret)
return nil
}
type StreamingSearchTask struct {
SearchTask
others []*StreamingSearchTask
resultBlobs segments.SearchResultDataBlobs
streamReducer segments.StreamSearchReducer
}
func NewStreamingSearchTask(ctx context.Context,
collection *segments.Collection,
manager *segments.Manager,
req *querypb.SearchRequest,
serverID int64,
) *StreamingSearchTask {
ctx, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "schedule")
return &StreamingSearchTask{
SearchTask: SearchTask{
ctx: ctx,
collection: collection,
segmentManager: manager,
req: req,
merged: false,
groupSize: 1,
topk: req.GetReq().GetTopk(),
nq: req.GetReq().GetNq(),
placeholderGroup: req.GetReq().GetPlaceholderGroup(),
originTopks: []int64{req.GetReq().GetTopk()},
originNqs: []int64{req.GetReq().GetNq()},
notifier: make(chan error, 1),
tr: timerecord.NewTimeRecorderWithTrace(ctx, "searchTask"),
scheduleSpan: span,
serverID: serverID,
},
}
}
func (t *StreamingSearchTask) MergeWith(other Task) bool {
return false
}
func (t *StreamingSearchTask) Execute() error {
log := log.Ctx(t.ctx).With(
zap.Int64("collectionID", t.collection.ID()),
zap.String("shard", t.req.GetDmlChannels()[0]),
)
// 0. prepare search req
if t.scheduleSpan != nil {
t.scheduleSpan.End()
}
tr := timerecord.NewTimeRecorderWithTrace(t.ctx, "SearchTask")
req := t.req
t.combinePlaceHolderGroups()
searchReq, err := segments.NewSearchRequest(t.ctx, t.collection, req, t.placeholderGroup)
if err != nil {
return err
}
defer searchReq.Delete()
// 1. search&&reduce or streaming-search&&streaming-reduce
metricType := searchReq.Plan().GetMetricType()
if req.GetScope() == querypb.DataScope_Historical {
streamReduceFunc := func(result *segments.SearchResult) error {
reduceErr := t.streamReduce(t.ctx, searchReq.Plan(), result, t.originNqs, t.originTopks)
return reduceErr
}
pinnedSegments, err := segments.SearchHistoricalStreamly(
t.ctx,
t.segmentManager,
searchReq,
req.GetReq().GetCollectionID(),
nil,
req.GetSegmentIDs(),
streamReduceFunc)
defer segments.DeleteStreamReduceHelper(t.streamReducer)
defer t.segmentManager.Segment.Unpin(pinnedSegments)
if err != nil {
log.Error("Failed to search sealed segments streamly", zap.Error(err))
return err
}
t.resultBlobs, err = segments.GetStreamReduceResult(t.ctx, t.streamReducer)
defer segments.DeleteSearchResultDataBlobs(t.resultBlobs)
if err != nil {
log.Error("Failed to get stream-reduced search result")
return err
}
} else if req.GetScope() == querypb.DataScope_Streaming {
results, pinnedSegments, err := segments.SearchStreaming(
t.ctx,
t.segmentManager,
searchReq,
req.GetReq().GetCollectionID(),
nil,
req.GetSegmentIDs(),
)
defer segments.DeleteSearchResults(results)
defer t.segmentManager.Segment.Unpin(pinnedSegments)
if err != nil {
return err
}
if t.maybeReturnForEmptyResults(results, metricType, tr) {
return nil
}
tr.RecordSpan()
t.resultBlobs, err = segments.ReduceSearchResultsAndFillData(
t.ctx,
searchReq.Plan(),
results,
int64(len(results)),
t.originNqs,
t.originTopks,
)
if err != nil {
log.Warn("failed to reduce search results", zap.Error(err))
return err
}
defer segments.DeleteSearchResultDataBlobs(t.resultBlobs)
metrics.QueryNodeReduceLatency.WithLabelValues(
fmt.Sprint(t.GetNodeID()),
metrics.SearchLabel,
metrics.ReduceSegments,
metrics.BatchReduce).
Observe(float64(tr.RecordSpan().Milliseconds()))
}
// 2. reorganize blobs to original search request
for i := range t.originNqs {
blob, err := segments.GetSearchResultDataBlob(t.ctx, t.resultBlobs, i)
if err != nil {
return err
}
var task *StreamingSearchTask
if i == 0 {
task = t
} else {
task = t.others[i-1]
}
// Note: blob is unsafe memory returned from C, so copy it before use
bs := make([]byte, len(blob))
copy(bs, blob)
task.result = &internalpb.SearchResults{
Base: &commonpb.MsgBase{
SourceID: t.GetNodeID(),
},
Status: merr.Success(),
MetricType: metricType,
NumQueries: t.originNqs[i],
TopK: t.originTopks[i],
SlicedBlob: bs,
SlicedOffset: 1,
SlicedNumCount: 1,
CostAggregation: &internalpb.CostAggregation{
ServiceTime: tr.ElapseSpan().Milliseconds(),
},
}
}
return nil
}
func (t *StreamingSearchTask) maybeReturnForEmptyResults(results []*segments.SearchResult,
metricType string, tr *timerecord.TimeRecorder,
) bool {
if len(results) == 0 {
for i := range t.originNqs {
var task *StreamingSearchTask
if i == 0 {
task = t
} else {
task = t.others[i-1]
}
task.result = &internalpb.SearchResults{
Base: &commonpb.MsgBase{
SourceID: t.GetNodeID(),
},
Status: merr.Success(),
MetricType: metricType,
NumQueries: t.originNqs[i],
TopK: t.originTopks[i],
SlicedOffset: 1,
SlicedNumCount: 1,
CostAggregation: &internalpb.CostAggregation{
ServiceTime: tr.ElapseSpan().Milliseconds(),
},
}
}
return true
}
return false
}
func (t *StreamingSearchTask) streamReduce(ctx context.Context,
plan *segments.SearchPlan,
newResult *segments.SearchResult,
sliceNQs []int64,
sliceTopKs []int64,
) error {
if t.streamReducer == nil {
var err error
t.streamReducer, err = segments.NewStreamReducer(ctx, plan, sliceNQs, sliceTopKs)
if err != nil {
log.Error("Fail to init stream reducer, return")
return err
}
}
return segments.StreamReduceSearchResult(ctx, newResult, t.streamReducer)
}

View File

@ -1,5 +1,7 @@
package tasks
import "github.com/milvus-io/milvus/internal/proto/internalpb"
const (
schedulePolicyNameFIFO = "fifo"
schedulePolicyNameUserTaskPolling = "user-task-polling"
@ -104,4 +106,6 @@ type Task interface {
// Return the NQ of task.
NQ() int64
SearchResult() *internalpb.SearchResults
}

View File

@ -66,6 +66,9 @@ const (
ReduceSegments = "segments"
ReduceShards = "shards"
BatchReduce = "batch_reduce"
StreamReduce = "stream_reduce"
Pending = "pending"
Executing = "executing"
Done = "done"
@ -96,6 +99,7 @@ const (
requestScope = "scope"
fullMethodLabelName = "full_method"
reduceLevelName = "reduce_level"
reduceType = "reduce_type"
lockName = "lock_name"
lockSource = "lock_source"
lockType = "lock_type"

View File

@ -229,6 +229,7 @@ var (
nodeIDLabelName,
queryTypeLabelName,
reduceLevelName,
reduceType,
})
QueryNodeLoadSegmentLatency = prometheus.NewHistogramVec(

View File

@ -18,13 +18,15 @@ package cache
import (
"container/list"
"fmt"
"context"
"sync"
"time"
"go.uber.org/atomic"
"golang.org/x/sync/singleflight"
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/lock"
"github.com/milvus-io/milvus/pkg/util/merr"
)
@ -35,14 +37,15 @@ var (
)
type cacheItem[K comparable, V any] struct {
key K
value V
pinCount atomic.Int32
key K
value V
pinCount atomic.Int32
needReload bool
}
type (
Loader[K comparable, V any] func(key K) (V, bool)
Finalizer[K comparable, V any] func(key K, value V) error
Loader[K comparable, V any] func(ctx context.Context, key K) (V, error)
Finalizer[K comparable, V any] func(ctx context.Context, key K, value V) error
)
// Scavenger records cache occupancy and decides whether eviction is necessary.
@ -57,18 +60,26 @@ type Scavenger[K comparable] interface {
Collect(key K) (bool, func(K) bool)
// Throw records entry removals.
Throw(key K)
// Spare returns a collector function based on the given key.
// The collector can be invoked repeatedly; each invocation tests whether there would be enough
// room for all the pending entries if the passed entry were evicted. Typically, the collector returns
// true several times before it returns false.
Spare(key K) func(K) bool
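// Replace re-weighs an existing key. It reports whether the new weight fits once the old
// weight is released; if not, it returns a collector for choosing victims, and if it does,
// it applies the change and returns a rollback that restores the old weight.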
Replace(key K) (bool, func(K) bool, func())
}
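A runnable toy illustration of the Collect/Throw contract, assuming the package lives at github.com/milvus-io/milvus/pkg/util/cache:
package main

import (
	"fmt"

	"github.com/milvus-io/milvus/pkg/util/cache" // assumed import path
)

func main() {
	// Each key weighs its own value; total capacity is 10 units.
	s := cache.NewLazyScavenger(func(k int) int64 { return int64(k) }, 10)

	ok, _ := s.Collect(4) // 4 units fit: admitted immediately
	fmt.Println(ok)       // true

	ok, collector := s.Collect(9) // 9 more units do not fit
	if !ok {
		fmt.Println(collector(4)) // true: evicting key 4 frees enough room
		s.Throw(4)                // record the actual removal
		ok, _ = s.Collect(9)      // retry after eviction
		fmt.Println(ok)           // true
	}
}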
type LazyScavenger[K comparable] struct {
capacity int64
size int64
weight func(K) int64
weights map[K]int64
}
func NewLazyScavenger[K comparable](weight func(K) int64, capacity int64) *LazyScavenger[K] {
return &LazyScavenger[K]{
capacity: capacity,
weight: weight,
weights: make(map[K]int64),
}
}
@ -77,16 +88,47 @@ func (s *LazyScavenger[K]) Collect(key K) (bool, func(K) bool) {
if s.size+w > s.capacity {
needCollect := s.size + w - s.capacity
return false, func(key K) bool {
needCollect -= s.weight(key)
needCollect -= s.weights[key]
return needCollect <= 0
}
}
s.size += w
s.weights[key] = w
return true, nil
}
func (s *LazyScavenger[K]) Replace(key K) (bool, func(K) bool, func()) {
pw := s.weights[key]
w := s.weight(key)
if s.size-pw+w > s.capacity {
needCollect := s.size - pw + w - s.capacity
return false, func(key K) bool {
needCollect -= s.weights[key]
return needCollect <= 0
}, nil
}
s.size += w - pw
s.weights[key] = w
return true, nil, func() {
s.size -= w - pw
s.weights[key] = pw
}
}
func (s *LazyScavenger[K]) Throw(key K) {
s.size -= s.weight(key)
if w, ok := s.weights[key]; ok {
s.size -= w
delete(s.weights, key)
}
}
func (s *LazyScavenger[K]) Spare(key K) func(K) bool {
w := s.weight(key)
available := s.capacity - s.size + w
return func(k K) bool {
available -= s.weight(k)
return available >= 0
}
}
type Stats struct {
@ -104,15 +146,21 @@ type Cache[K comparable, V any] interface {
// completes.
// Throws `ErrNoSuchItem` if the key is not found or cannot be loaded by the given loader.
// Throws `ErrNotEnoughSpace` if there is no room for the operation.
Do(key K, doer func(V) error) (missing bool, err error)
Do(ctx context.Context, key K, doer func(V) error) (missing bool, err error)
// Do the operation `doer` on the given key `key`. The key is kept in the cache until the operation
// completes. The function waits for `timeout` if there is not enough space for the given key.
// Throws `ErrNoSuchItem` if the key is not found or cannot be loaded by the given loader.
// Throws `ErrTimeOut` if timed out.
DoWait(key K, timeout time.Duration, doer func(V) error) (missing bool, err error)
DoWait(ctx context.Context, key K, timeout time.Duration, doer func(context.Context, V) error) (missing bool, err error)
// Get stats
Stats() *Stats
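// MarkItemNeedReload marks the cached item so that the reloader (if any) refreshes it
// on its next access once it is no longer pinned. Returns whether the key is currently cached.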
MarkItemNeedReload(ctx context.Context, key K) bool
// Expire removes the item from the cache.
// Returns true if the item is not in use and is removed immediately, or if the item is no longer in the cache.
// Returns false if the item is in use; it is then marked for reload and expired lazily.
Expire(ctx context.Context, key K) (evicted bool)
}
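A runnable sketch of the reworked API, under the same assumed import path; the loader/finalizer signatures match the context-aware forms introduced above:
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/milvus-io/milvus/pkg/util/cache" // assumed import path
)

func main() {
	c := cache.NewCacheBuilder[string, int]().
		WithCapacity(2).
		WithLoader(func(ctx context.Context, key string) (int, error) {
			return len(key), nil // invoked on cache miss
		}).
		WithFinalizer(func(ctx context.Context, key string, value int) error {
			fmt.Println("evicted", key)
			return nil
		}).
		Build()

	ctx := context.Background()
	// DoWait pins "abc" while the doer runs, waiting up to 1s for room.
	missing, err := c.DoWait(ctx, "abc", time.Second, func(ctx context.Context, v int) error {
		fmt.Println(v) // 3
		return nil
	})
	fmt.Println(missing, err) // true <nil>: first access is a miss

	c.MarkItemNeedReload(ctx, "abc") // reload lazily on next access
	c.Expire(ctx, "abc")             // true: unpinned, evicted immediately
}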
type Waiter[K comparable] struct {
@ -131,21 +179,23 @@ func newWaiter[K comparable](key K) Waiter[K] {
type lruCache[K comparable, V any] struct {
rwlock sync.RWMutex
// the value is *cacheItem[V]
items map[K]*list.Element
accessList *list.List
loaderSingleFlight singleflight.Group
stats *Stats
waitQueue *list.List
items map[K]*list.Element
accessList *list.List
loaderKeyLocks *lock.KeyLock[K]
stats *Stats
waitQueue *list.List
loader Loader[K, V]
finalizer Finalizer[K, V]
scavenger Scavenger[K]
reloader Loader[K, V]
}
type CacheBuilder[K comparable, V any] struct {
loader Loader[K, V]
finalizer Finalizer[K, V]
scavenger Scavenger[K]
reloader Loader[K, V]
}
func NewCacheBuilder[K comparable, V any]() *CacheBuilder[K, V] {
@ -186,30 +236,37 @@ func (b *CacheBuilder[K, V]) WithCapacity(capacity int64) *CacheBuilder[K, V] {
return b
}
func (b *CacheBuilder[K, V]) WithReloader(reloader Loader[K, V]) *CacheBuilder[K, V] {
b.reloader = reloader
return b
}
func (b *CacheBuilder[K, V]) Build() Cache[K, V] {
return newLRUCache(b.loader, b.finalizer, b.scavenger)
return newLRUCache(b.loader, b.finalizer, b.scavenger, b.reloader)
}
func newLRUCache[K comparable, V any](
loader Loader[K, V],
finalizer Finalizer[K, V],
scavenger Scavenger[K],
reloader Loader[K, V],
) Cache[K, V] {
return &lruCache[K, V]{
items: make(map[K]*list.Element),
accessList: list.New(),
waitQueue: list.New(),
loaderSingleFlight: singleflight.Group{},
stats: new(Stats),
loader: loader,
finalizer: finalizer,
scavenger: scavenger,
items: make(map[K]*list.Element),
accessList: list.New(),
waitQueue: list.New(),
loaderKeyLocks: lock.NewKeyLock[K](),
stats: new(Stats),
loader: loader,
finalizer: finalizer,
scavenger: scavenger,
reloader: reloader,
}
}
// Do picks up an item from the cache and executes doer. The entry of interest is guaranteed to stay in the cache while doer is executing.
func (c *lruCache[K, V]) Do(key K, doer func(V) error) (bool, error) {
item, missing, err := c.getAndPin(key)
func (c *lruCache[K, V]) Do(ctx context.Context, key K, doer func(V) error) (bool, error) {
item, missing, err := c.getAndPin(ctx, key)
if err != nil {
return missing, err
}
@ -217,8 +274,11 @@ func (c *lruCache[K, V]) Do(key K, doer func(V) error) (bool, error) {
return missing, doer(item.value)
}
func (c *lruCache[K, V]) DoWait(key K, timeout time.Duration, doer func(V) error) (bool, error) {
timedWait := func(cond *sync.Cond, timeout time.Duration) bool {
func (c *lruCache[K, V]) DoWait(ctx context.Context, key K, timeout time.Duration, doer func(context.Context, V) error) (bool, error) {
timedWait := func(cond *sync.Cond, timeout time.Duration) error {
if timeout <= 0 {
return ErrTimeOut // timed out
}
c := make(chan struct{})
go func() {
cond.L.Lock()
@ -228,25 +288,34 @@ func (c *lruCache[K, V]) DoWait(key K, timeout time.Duration, doer func(V) error
}()
select {
case <-c:
return false // completed normally
return nil // completed normally
case <-time.After(timeout):
return true // timed out
return ErrTimeOut // timed out
case <-ctx.Done():
return ctx.Err() // context timeout
}
}
var ele *list.Element
defer func() {
if ele != nil {
c.rwlock.Lock()
c.waitQueue.Remove(ele)
c.rwlock.Unlock()
}
}()
start := time.Now()
log := log.Ctx(ctx).With(zap.Any("key", key))
for {
item, missing, err := c.getAndPin(key)
item, missing, err := c.getAndPin(ctx, key)
if err == nil {
if ele != nil {
c.rwlock.Lock()
c.waitQueue.Remove(ele)
c.rwlock.Unlock()
}
defer c.Unpin(key)
return missing, doer(item.value)
} else if err != ErrNotEnoughSpace {
return missing, doer(ctx, item.value)
} else if err == ErrNotEnoughSpace {
log.Warn("Failed to get disk cache for segment, wait and try again")
} else if err == merr.ErrServiceResourceInsufficient {
log.Warn("Failed to load segment for insufficient resource, wait and try later")
} else if err != nil {
return true, err
}
if ele == nil {
@ -254,12 +323,17 @@ func (c *lruCache[K, V]) DoWait(key K, timeout time.Duration, doer func(V) error
c.rwlock.Lock()
waiter := newWaiter(key)
ele = c.waitQueue.PushBack(&waiter)
log.Info("push waiter into waiter queue", zap.Any("key", key))
c.rwlock.Unlock()
}
// Wait for the key to be available
timeLeft := time.Until(start.Add(timeout))
if timeLeft <= 0 || timedWait(ele.Value.(*Waiter[K]).c, timeLeft) {
return true, ErrTimeOut
if err = timedWait(ele.Value.(*Waiter[K]).c, timeLeft); err != nil {
log.Warn("failed to get item for key",
zap.Any("key", key),
zap.Int("wait_len", c.waitQueue.Len()),
zap.Error(err))
return true, err
}
}
}
@ -277,77 +351,95 @@ func (c *lruCache[K, V]) Unpin(key K) {
}
item := e.Value.(*cacheItem[K, V])
item.pinCount.Dec()
if item.pinCount.Load() == 0 {
c.notifyWaiters()
}
}
func (c *lruCache[K, V]) notifyWaiters() {
if c.waitQueue.Len() > 0 {
log := log.With(zap.Any("UnPinedKey", key))
if item.pinCount.Load() == 0 && c.waitQueue.Len() > 0 {
log.Debug("Unpin item to zero ref, trigger activating waiters")
// Notify waiters
// collector := c.scavenger.Spare(key)
for e := c.waitQueue.Front(); e != nil; e = e.Next() {
w := e.Value.(*Waiter[K])
log.Info("try to activate waiter", zap.Any("activated_waiter_key", w.key))
w.c.Broadcast()
// we try our best to activate as many waiters as possible each time
}
} else {
log.Debug("Miss to trigger activating waiters",
zap.Int32("PinCount", item.pinCount.Load()),
zap.Int("wait_len", c.waitQueue.Len()))
}
}
func (c *lruCache[K, V]) peekAndPin(key K) *cacheItem[K, V] {
func (c *lruCache[K, V]) peekAndPin(ctx context.Context, key K) *cacheItem[K, V] {
c.rwlock.Lock()
defer c.rwlock.Unlock()
e, ok := c.items[key]
log := log.Ctx(ctx)
if ok {
item := e.Value.(*cacheItem[K, V])
if item.needReload && item.pinCount.Load() == 0 {
ok, _, retback := c.scavenger.Replace(key)
if ok {
// there is room for reload and no one is using the item
if c.reloader != nil {
reloaded, err := c.reloader(ctx, key)
if err == nil {
item.value = reloaded
} else if retback != nil {
retback()
}
}
item.needReload = false
}
}
c.accessList.MoveToFront(e)
item.pinCount.Inc()
log.Debug("peeked item success",
zap.Int32("PinCount", item.pinCount.Load()),
zap.Any("key", key))
return item
}
log.Debug("failed to peek item", zap.Any("key", key))
return nil
}
// GetAndPin gets and pins the given key if it exists
func (c *lruCache[K, V]) getAndPin(key K) (*cacheItem[K, V], bool, error) {
if item := c.peekAndPin(key); item != nil {
func (c *lruCache[K, V]) getAndPin(ctx context.Context, key K) (*cacheItem[K, V], bool, error) {
if item := c.peekAndPin(ctx, key); item != nil {
c.stats.HitCount.Inc()
return item, false, nil
}
log := log.Ctx(ctx)
c.stats.MissCount.Inc()
if c.loader != nil {
// Try to scavenge if there is room; if not, fail fast.
// Note that the test is not accurate since we are not locking `loader` here.
if _, ok := c.tryScavenge(key); !ok {
log.Warn("getAndPin ran into scavenge failure, return", zap.Any("key", key))
return nil, true, ErrNotEnoughSpace
}
strKey := fmt.Sprint(key)
item, err, _ := c.loaderSingleFlight.Do(strKey, func() (interface{}, error) {
if item := c.peekAndPin(key); item != nil {
return item, nil
}
timer := time.Now()
value, ok := c.loader(key)
c.stats.TotalLoadTimeMs.Add(uint64(time.Since(timer).Milliseconds()))
if !ok {
c.stats.LoadFailCount.Inc()
return nil, ErrNoSuchItem
}
c.stats.LoadSuccessCount.Inc()
item, err := c.setAndPin(key, value)
if err != nil {
return nil, err
}
return item, nil
})
if err == nil {
return item.(*cacheItem[K, V]), true, nil
c.loaderKeyLocks.Lock(key)
defer c.loaderKeyLocks.Unlock(key)
if item := c.peekAndPin(ctx, key); item != nil {
return item, false, nil
}
timer := time.Now()
value, err := c.loader(ctx, key)
c.stats.TotalLoadTimeMs.Add(uint64(time.Since(timer).Milliseconds()))
if err != nil {
c.stats.LoadFailCount.Inc()
log.Debug("loader failed for key", zap.Any("key", key))
return nil, true, err
}
return nil, true, err
}
c.stats.LoadSuccessCount.Inc()
item, err := c.setAndPin(ctx, key, value)
if err != nil {
log.Debug("setAndPin failed for key", zap.Any("key", key), zap.Error(err))
return nil, true, err
}
return item, true, nil
}
return nil, true, ErrNoSuchItem
}
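The miss path above is a per-key double-checked load (the KeyLock replaces the previous singleflight); a condensed generic sketch of its shape, with illustrative names:
// loadOnce mirrors getAndPin's miss path: peek, lock the key, peek again, load.
func loadOnce[K comparable, V any](
	peek func(K) (V, bool), // optimistic lookup (pins on hit)
	lock, unlock func(K), // per-key lock serializing loads of one key
	load func(K) (V, error), // the expensive loader
	key K,
) (V, error) {
	if v, ok := peek(key); ok {
		return v, nil // fast path: already cached
	}
	lock(key)
	defer unlock(key)
	if v, ok := peek(key); ok {
		return v, nil // a concurrent loader filled it while we waited
	}
	return load(key) // at most one loader runs per key
}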
@ -381,7 +473,7 @@ func (c *lruCache[K, V]) lockfreeTryScavenge(key K) ([]K, bool) {
}
// for cache miss
func (c *lruCache[K, V]) setAndPin(key K, value V) (*cacheItem[K, V], error) {
func (c *lruCache[K, V]) setAndPin(ctx context.Context, key K, value V) (*cacheItem[K, V], error) {
c.rwlock.Lock()
defer c.rwlock.Unlock()
@ -390,32 +482,69 @@ func (c *lruCache[K, V]) setAndPin(key K, value V) (*cacheItem[K, V], error) {
// tryScavenge is done again since the load call is lock free.
toEvict, ok := c.lockfreeTryScavenge(key)
log := log.Ctx(ctx)
if !ok {
if c.finalizer != nil {
c.finalizer(key, value)
log.Warn("setAndPin ran into scavenge failure, release data for", zap.Any("key", key))
c.finalizer(ctx, key, value)
}
return nil, ErrNotEnoughSpace
}
for _, ek := range toEvict {
e := c.items[ek]
delete(c.items, ek)
c.accessList.Remove(e)
c.scavenger.Throw(ek)
c.stats.EvictionCount.Inc()
if c.finalizer != nil {
item := e.Value.(*cacheItem[K, V])
timer := time.Now()
c.finalizer(ek, item.value)
c.stats.TotalFinalizeTimeMs.Add(uint64(time.Since(timer).Milliseconds()))
}
c.evict(ctx, ek)
log.Debug("cache evicting", zap.Any("key", ek), zap.Any("by", key))
}
c.scavenger.Collect(key)
e := c.accessList.PushFront(item)
c.items[item.key] = e
log.Debug("setAndPin set up item", zap.Any("item.key", item.key),
zap.Int32("pinCount", item.pinCount.Load()))
return item, nil
}
func (c *lruCache[K, V]) Expire(ctx context.Context, key K) (evicted bool) {
c.rwlock.Lock()
defer c.rwlock.Unlock()
e, ok := c.items[key]
if !ok {
return true
}
item := e.Value.(*cacheItem[K, V])
if item.pinCount.Load() == 0 {
c.evict(ctx, key)
return true
}
item.needReload = true
return false
}
func (c *lruCache[K, V]) evict(ctx context.Context, key K) {
c.stats.EvictionCount.Inc()
e := c.items[key]
delete(c.items, key)
c.accessList.Remove(e)
c.scavenger.Throw(key)
if c.finalizer != nil {
item := e.Value.(*cacheItem[K, V])
c.finalizer(ctx, key, item.value)
}
}
func (c *lruCache[K, V]) MarkItemNeedReload(ctx context.Context, key K) bool {
c.rwlock.Lock()
defer c.rwlock.Unlock()
if e, ok := c.items[key]; ok {
item := e.Value.(*cacheItem[K, V])
item.needReload = true
return true
}
return false
}

View File

@ -1,7 +1,7 @@
package cache
import (
"math/rand"
"context"
"sync"
"testing"
"time"
@ -12,8 +12,8 @@ import (
)
func TestLRUCache(t *testing.T) {
cacheBuilder := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) {
return key, true
cacheBuilder := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
return key, nil
})
t.Run("test loader", func(t *testing.T) {
@ -21,7 +21,7 @@ func TestLRUCache(t *testing.T) {
cache := cacheBuilder.WithCapacity(int64(size)).Build()
for i := 0; i < size; i++ {
missing, err := cache.Do(i, func(v int) error {
missing, err := cache.Do(context.Background(), i, func(v int) error {
assert.Equal(t, i, v)
return nil
})
@ -33,13 +33,13 @@ func TestLRUCache(t *testing.T) {
t.Run("test finalizer", func(t *testing.T) {
size := 10
finalizeSeq := make([]int, 0)
cache := cacheBuilder.WithCapacity(int64(size)).WithFinalizer(func(key, value int) error {
cache := cacheBuilder.WithCapacity(int64(size)).WithFinalizer(func(ctx context.Context, key, value int) error {
finalizeSeq = append(finalizeSeq, key)
return nil
}).Build()
for i := 0; i < size*2; i++ {
missing, err := cache.Do(i, func(v int) error {
missing, err := cache.Do(context.Background(), i, func(v int) error {
assert.Equal(t, i, v)
return nil
})
@ -50,7 +50,7 @@ func TestLRUCache(t *testing.T) {
// Hit the cache again, there should be no swap-out
for i := size; i < size*2; i++ {
missing, err := cache.Do(i, func(v int) error {
missing, err := cache.Do(context.Background(), i, func(v int) error {
assert.Equal(t, i, v)
return nil
})
@ -65,13 +65,13 @@ func TestLRUCache(t *testing.T) {
sumCapacity := 20 // inserting 1 to 19, capacity is set to sum of 20, expecting (19) at last.
cache := cacheBuilder.WithLazyScavenger(func(key int) int64 {
return int64(key)
}, int64(sumCapacity)).WithFinalizer(func(key, value int) error {
}, int64(sumCapacity)).WithFinalizer(func(ctx context.Context, key, value int) error {
finalizeSeq = append(finalizeSeq, key)
return nil
}).Build()
for i := 0; i < 20; i++ {
missing, err := cache.Do(i, func(v int) error {
missing, err := cache.Do(context.Background(), i, func(v int) error {
assert.Equal(t, i, v)
return nil
})
@ -84,7 +84,7 @@ func TestLRUCache(t *testing.T) {
t.Run("test do negative", func(t *testing.T) {
cache := cacheBuilder.Build()
theErr := errors.New("error")
missing, err := cache.Do(-1, func(v int) error {
missing, err := cache.Do(context.Background(), -1, func(v int) error {
return theErr
})
assert.True(t, missing)
@ -96,13 +96,13 @@ func TestLRUCache(t *testing.T) {
sumCapacity := 20 // inserting 1 to 19, capacity is set to sum of 20, expecting (19) at last.
cache := cacheBuilder.WithLazyScavenger(func(key int) int64 {
return int64(key)
}, int64(sumCapacity)).WithFinalizer(func(key, value int) error {
}, int64(sumCapacity)).WithFinalizer(func(ctx context.Context, key, value int) error {
finalizeSeq = append(finalizeSeq, key)
return nil
}).Build()
for i := 0; i < 20; i++ {
missing, err := cache.Do(i, func(v int) error {
missing, err := cache.Do(context.Background(), i, func(v int) error {
assert.Equal(t, i, v)
return nil
})
@ -110,7 +110,7 @@ func TestLRUCache(t *testing.T) {
assert.NoError(t, err)
}
assert.Equal(t, []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, finalizeSeq)
missing, err := cache.Do(100, func(v int) error {
missing, err := cache.Do(context.Background(), 100, func(v int) error {
return nil
})
assert.True(t, missing)
@ -118,28 +118,52 @@ func TestLRUCache(t *testing.T) {
})
t.Run("test load negative", func(t *testing.T) {
cache := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) {
cache := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
if key < 0 {
return 0, false
return 0, ErrNoSuchItem
}
return key, true
return key, nil
}).Build()
missing, err := cache.Do(0, func(v int) error {
missing, err := cache.Do(context.Background(), 0, func(v int) error {
return nil
})
assert.True(t, missing)
assert.NoError(t, err)
missing, err = cache.Do(-1, func(v int) error {
missing, err = cache.Do(context.Background(), -1, func(v int) error {
return nil
})
assert.True(t, missing)
assert.Equal(t, ErrNoSuchItem, err)
})
t.Run("test reloader", func(t *testing.T) {
cache := cacheBuilder.WithReloader(func(ctx context.Context, key int) (int, error) {
return -key, nil
}).Build()
_, err := cache.Do(context.Background(), 1, func(i int) error { return nil })
assert.NoError(t, err)
exist := cache.MarkItemNeedReload(context.Background(), 1)
assert.True(t, exist)
cache.Do(context.Background(), 1, func(i int) error {
assert.Equal(t, -1, i)
return nil
})
})
t.Run("test mark", func(t *testing.T) {
cache := cacheBuilder.WithCapacity(1).Build()
exist := cache.MarkItemNeedReload(context.Background(), 1)
assert.False(t, exist)
_, err := cache.Do(context.Background(), 1, func(i int) error { return nil })
assert.NoError(t, err)
exist = cache.MarkItemNeedReload(context.Background(), 1)
assert.True(t, exist)
})
}
func TestStats(t *testing.T) {
cacheBuilder := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) {
return key, true
cacheBuilder := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
return key, nil
})
t.Run("test loader", func(t *testing.T) {
@ -155,7 +179,7 @@ func TestStats(t *testing.T) {
assert.Equal(t, uint64(0), stats.LoadFailCount.Load())
for i := 0; i < size; i++ {
_, err := cache.Do(i, func(v int) error {
_, err := cache.Do(context.Background(), i, func(v int) error {
assert.Equal(t, i, v)
return nil
})
@ -170,7 +194,7 @@ func TestStats(t *testing.T) {
assert.Equal(t, uint64(0), stats.LoadFailCount.Load())
for i := 0; i < size; i++ {
_, err := cache.Do(i, func(v int) error {
_, err := cache.Do(context.Background(), i, func(v int) error {
assert.Equal(t, i, v)
return nil
})
@ -184,7 +208,7 @@ func TestStats(t *testing.T) {
assert.Equal(t, uint64(0), stats.LoadFailCount.Load())
for i := size; i < size*2; i++ {
_, err := cache.Do(i, func(v int) error {
_, err := cache.Do(context.Background(), i, func(v int) error {
assert.Equal(t, i, v)
return nil
})
@ -202,9 +226,9 @@ func TestStats(t *testing.T) {
func TestLRUCacheConcurrency(t *testing.T) {
t.Run("test race condition", func(t *testing.T) {
numEvict := new(atomic.Int32)
cache := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) {
return key, true
}).WithCapacity(10).WithFinalizer(func(key, value int) error {
cache := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
return key, nil
}).WithCapacity(10).WithFinalizer(func(ctx context.Context, key, value int) error {
numEvict.Add(1)
return nil
}).Build()
@ -215,7 +239,7 @@ func TestLRUCacheConcurrency(t *testing.T) {
go func(i int) {
defer wg.Done()
for j := 0; j < 100; j++ {
_, err := cache.Do(j, func(v int) error {
_, err := cache.Do(context.Background(), j, func(v int) error {
return nil
})
assert.NoError(t, err)
@ -226,9 +250,9 @@ func TestLRUCacheConcurrency(t *testing.T) {
})
t.Run("test not enough space", func(t *testing.T) {
cache := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) {
return key, true
}).WithCapacity(1).WithFinalizer(func(key, value int) error {
cache := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
return key, nil
}).WithCapacity(1).WithFinalizer(func(ctx context.Context, key, value int) error {
return nil
}).Build()
@ -236,13 +260,13 @@ func TestLRUCacheConcurrency(t *testing.T) {
var wg1 sync.WaitGroup // Make sure goroutine is started
wg.Add(1)
wg1.Add(1)
go cache.Do(1000, func(v int) error {
go cache.Do(context.Background(), 1000, func(v int) error {
wg1.Done()
wg.Wait()
return nil
})
wg1.Wait()
_, err := cache.Do(1001, func(v int) error {
_, err := cache.Do(context.Background(), 1001, func(v int) error {
return nil
})
wg.Done()
@ -250,9 +274,9 @@ func TestLRUCacheConcurrency(t *testing.T) {
})
t.Run("test time out", func(t *testing.T) {
cache := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) {
return key, true
}).WithCapacity(1).WithFinalizer(func(key, value int) error {
cache := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
return key, nil
}).WithCapacity(1).WithFinalizer(func(ctx context.Context, key, value int) error {
return nil
}).Build()
@ -260,13 +284,13 @@ func TestLRUCacheConcurrency(t *testing.T) {
var wg1 sync.WaitGroup // Make sure goroutine is started
wg.Add(1)
wg1.Add(1)
go cache.Do(1000, func(v int) error {
go cache.Do(context.Background(), 1000, func(v int) error {
wg1.Done()
wg.Wait()
return nil
})
wg1.Wait()
missing, err := cache.DoWait(1001, time.Nanosecond, func(v int) error {
missing, err := cache.DoWait(context.Background(), 1001, time.Nanosecond, func(ctx context.Context, v int) error {
return nil
})
wg.Done()
@ -275,22 +299,22 @@ func TestLRUCacheConcurrency(t *testing.T) {
})
t.Run("test wait", func(t *testing.T) {
cache := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) {
return key, true
}).WithCapacity(1).WithFinalizer(func(key, value int) error {
cache := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
return key, nil
}).WithCapacity(1).WithFinalizer(func(ctx context.Context, key, value int) error {
return nil
}).Build()
var wg1 sync.WaitGroup // Make sure goroutine is started
wg1.Add(1)
go cache.Do(1000, func(v int) error {
go cache.Do(context.Background(), 1000, func(v int) error {
wg1.Done()
time.Sleep(time.Second)
return nil
})
wg1.Wait()
missing, err := cache.DoWait(1001, time.Second*2, func(v int) error {
missing, err := cache.DoWait(context.Background(), 1001, time.Second*2, func(ctx context.Context, v int) error {
return nil
})
assert.True(t, missing)
@ -299,9 +323,9 @@ func TestLRUCacheConcurrency(t *testing.T) {
t.Run("test wait race condition", func(t *testing.T) {
numEvict := new(atomic.Int32)
cache := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) {
return key, true
}).WithCapacity(5).WithFinalizer(func(key, value int) error {
cache := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
return key, nil
}).WithCapacity(5).WithFinalizer(func(ctx context.Context, key, value int) error {
numEvict.Add(1)
return nil
}).Build()
@ -311,9 +335,8 @@ func TestLRUCacheConcurrency(t *testing.T) {
wg.Add(1)
go func(i int) {
defer wg.Done()
for j := 0; j < 20; j++ {
_, err := cache.DoWait(j, time.Second, func(v int) error {
time.Sleep(time.Duration(rand.Intn(3)))
for j := 0; j < 100; j++ {
_, err := cache.DoWait(context.Background(), j, 2*time.Second, func(ctx context.Context, v int) error {
return nil
})
assert.NoError(t, err)
@ -322,4 +345,93 @@ func TestLRUCacheConcurrency(t *testing.T) {
}
wg.Wait()
})
t.Run("test concurrent reload and mark", func(t *testing.T) {
cache := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
return key, nil
}).WithCapacity(5).WithFinalizer(func(ctx context.Context, key, value int) error {
return nil
}).WithReloader(func(ctx context.Context, key int) (int, error) {
return key, nil
}).Build()
for i := 0; i < 100; i++ {
cache.DoWait(context.Background(), i, 2*time.Second, func(ctx context.Context, v int) error { return nil })
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
for i := 0; i < 10; i++ {
for j := 0; j < 100; j++ {
cache.MarkItemNeedReload(context.Background(), j)
}
}
}()
go func() {
defer wg.Done()
for i := 0; i < 10; i++ {
for j := 0; j < 100; j++ {
cache.DoWait(context.Background(), j, 2*time.Second, func(ctx context.Context, v int) error { return nil })
}
}
}()
wg.Wait()
})
t.Run("test expire", func(t *testing.T) {
cache := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
return key, nil
}).WithCapacity(5).WithFinalizer(func(ctx context.Context, key, value int) error {
return nil
}).WithReloader(func(ctx context.Context, key int) (int, error) {
return key, nil
}).Build()
for i := 0; i < 100; i++ {
cache.DoWait(context.Background(), i, 2*time.Second, func(ctx context.Context, v int) error { return nil })
}
evicted := 0
for i := 0; i < 100; i++ {
if cache.Expire(context.Background(), i) {
evicted++
}
}
assert.Equal(t, 100, evicted)
// no item should be evicted while it is in use.
for i := 0; i < 5; i++ {
cache.DoWait(context.Background(), i, 2*time.Second, func(ctx context.Context, v int) error { return nil })
}
wg := sync.WaitGroup{}
wg.Add(5)
for i := 0; i < 5; i++ {
go func(i int) {
defer wg.Done()
cache.DoWait(context.Background(), i, 2*time.Second, func(ctx context.Context, v int) error {
time.Sleep(2 * time.Second)
return nil
})
}(i)
}
time.Sleep(1 * time.Second)
evicted = 0
for i := 0; i < 5; i++ {
if cache.Expire(context.Background(), i) {
evicted++
}
}
assert.Zero(t, evicted)
wg.Wait()
evicted = 0
for i := 0; i < 5; i++ {
if cache.Expire(context.Background(), i) {
evicted++
}
}
assert.Equal(t, 5, evicted)
})
}

View File

@ -43,6 +43,7 @@ var (
ErrServiceQuotaExceeded = newMilvusError("quota exceeded", 9, false)
ErrServiceUnimplemented = newMilvusError("service unimplemented", 10, false)
ErrServiceTimeTickLongDelay = newMilvusError("time tick long delay", 11, false)
ErrServiceResourceInsufficient = newMilvusError("service resource insufficient", 12, true)
// Collection related
ErrCollectionNotFound = newMilvusError("collection not found", 100, false)
@ -85,6 +86,7 @@ var (
ErrSegmentNotLoaded = newMilvusError("segment not loaded", 601, false)
ErrSegmentLack = newMilvusError("segment lacks", 602, false)
ErrSegmentReduplicate = newMilvusError("segment reduplicates", 603, false)
ErrSegmentLoadFailed = newMilvusError("segment load failed", 604, false)
// Index related
ErrIndexNotFound = newMilvusError("index not found", 700, false)
@ -167,6 +169,15 @@ var (
// Search/Query related
ErrInconsistentRequery = newMilvusError("inconsistent requery result", 2200, true)
// Compaction
ErrCompactionReadDeltaLogErr = newMilvusError("fail to read delta log", 2300, false)
ErrClusteringCompactionClusterNotSupport = newMilvusError("milvus cluster not support clustering compaction", 2301, false)
ErrClusteringCompactionCollectionNotSupport = newMilvusError("collection not support clustering compaction", 2302, false)
ErrClusteringCompactionCollectionIsCompacting = newMilvusError("collection is compacting", 2303, false)
// General
ErrOperationNotSupported = newMilvusError("unsupported operation", 3000, false)
)
type milvusError struct {

View File

@ -678,6 +678,14 @@ func WrapErrSegmentsNotFound(ids []int64, msg ...string) error {
return err
}
func WrapErrSegmentLoadFailed(id int64, msg ...string) error {
err := wrapFields(ErrSegmentLoadFailed, value("segment", id))
if len(msg) > 0 {
err = errors.Wrap(err, strings.Join(msg, "->"))
}
return err
}
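A small sketch of classifying the new error; import paths are assumed from the repo layout:
package main

import (
	"fmt"

	"github.com/cockroachdb/errors"

	"github.com/milvus-io/milvus/pkg/util/merr" // assumed import path
)

func main() {
	err := merr.WrapErrSegmentLoadFailed(42, "binlog missing") // hypothetical message
	// merr errors wrap their sentinel, so errors.Is classifies them.
	fmt.Println(errors.Is(err, merr.ErrSegmentLoadFailed)) // true
}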
func WrapErrSegmentNotLoaded(id int64, msg ...string) error {
err := wrapFields(ErrSegmentNotLoaded, value("segment", id))
if len(msg) > 0 {

View File

@ -1994,7 +1994,8 @@ type queryNodeConfig struct {
MmapDirPath ParamItem `refreshable:"false"`
MmapEnabled ParamItem `refreshable:"false"`
LazyLoadEnabled ParamItem `refreshable:"false"`
LazyLoadEnabled ParamItem `refreshable:"false"`
LazyLoadWaitTimeout ParamItem `refreshable:"false"`
// chunk cache
ReadAheadPolicy ParamItem `refreshable:"false"`
@ -2045,6 +2046,7 @@ type queryNodeConfig struct {
MemoryIndexLoadPredictMemoryUsageFactor ParamItem `refreshable:"true"`
EnableSegmentPrune ParamItem `refreshable:"false"`
DefaultSegmentFilterRatio ParamItem `refreshable:"false"`
UseStreamComputing ParamItem `refreshable:"false"`
}
func (p *queryNodeConfig) init(base *BaseTable) {
@ -2237,6 +2239,13 @@ func (p *queryNodeConfig) init(base *BaseTable) {
Export: true,
}
p.LazyLoadEnabled.Init(base.mgr)
p.LazyLoadWaitTimeout = ParamItem{
Key: "queryNode.lazyloadWaitTimeout",
Version: "2.4.0",
DefaultValue: "30000",
Doc: "max wait timeout duration in milliseconds before start to do lazyload search and retrieve",
}
p.LazyLoadWaitTimeout.Init(base.mgr)
p.ReadAheadPolicy = ParamItem{
Key: "queryNode.cache.readAheadPolicy",
@ -2569,6 +2578,13 @@ user-task-polling:
Doc: "filter ratio used for pruning segments when searching",
}
p.DefaultSegmentFilterRatio.Init(base.mgr)
p.UseStreamComputing = ParamItem{
Key: "queryNode.useStreamComputing",
Version: "2.4.0",
DefaultValue: "false",
Doc: "use stream search mode when searching or querying",
}
p.UseStreamComputing.Init(base.mgr)
}
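A sketch of reading the two new knobs at runtime; the accessors follow the ParamItem pattern used throughout paramtable, but treat the exact calls as assumptions:
package main

import (
	"fmt"
	"time"

	"github.com/milvus-io/milvus/pkg/util/paramtable" // assumed import path
)

func main() {
	paramtable.Init()
	qn := &paramtable.Get().QueryNodeCfg
	// 30000 ms by default, read as a time.Duration for wait loops.
	fmt.Println(qn.LazyLoadWaitTimeout.GetAsDuration(time.Millisecond))
	// false by default; flips SearchSegments to StreamingSearchTask.
	fmt.Println(qn.UseStreamComputing.GetAsBool())
}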
// /////////////////////////////////////////////////////////////////////////////