mirror of https://github.com/milvus-io/milvus.git
feat: LRU cache implementation (#32567)
issue: https://github.com/milvus-io/milvus/issues/32783. This PR is the implementation of the LRU cache on branch lru-dev. Signed-off-by: sunby <sunbingyi1992@gmail.com> Co-authored-by: chyezh <chyezh@outlook.com> Co-authored-by: MrPresent-Han <chun.han@zilliz.com> Co-authored-by: Ted Xu <ted.xu@zilliz.com> Co-authored-by: jaime <yun.zhang@zilliz.com> Co-authored-by: wayblink <anyang.wang@zilliz.com> pull/32709/head
parent
37a99ca23e
commit
fecd9c21ba
|
@ -367,6 +367,7 @@ queryNode:
|
|||
maxQueueLength: 16 # Maximum length of task queue in flowgraph
|
||||
maxParallelism: 1024 # Maximum number of tasks executed in parallel in the flowgraph
|
||||
enableSegmentPrune: false # use partition prune function on shard delegator
|
||||
useStreamComputing: false
|
||||
ip: # if not specified, use the first unicastable address
|
||||
port: 21123
|
||||
grpc:
|
||||
|
|
|
@ -502,6 +502,11 @@ VectorDiskAnnIndex<T>::update_load_json(const Config& config) {
|
|||
}
|
||||
}
|
||||
|
||||
if (config.contains(kMmapFilepath)) {
|
||||
load_config.erase(kMmapFilepath);
|
||||
load_config[kEnableMmap] = true;
|
||||
}
|
||||
|
||||
return load_config;
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
|
||||
#include <unistd.h>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <filesystem>
|
||||
#include <memory>
|
||||
|
@ -33,7 +32,6 @@
|
|||
|
||||
#include "index/Index.h"
|
||||
#include "index/IndexInfo.h"
|
||||
#include "index/Meta.h"
|
||||
#include "index/Utils.h"
|
||||
#include "common/EasyAssert.h"
|
||||
#include "config/ConfigKnowhere.h"
|
||||
|
@ -44,16 +42,15 @@
|
|||
#include "common/FieldData.h"
|
||||
#include "common/File.h"
|
||||
#include "common/Slice.h"
|
||||
#include "common/Tracer.h"
|
||||
#include "common/RangeSearchHelper.h"
|
||||
#include "common/Utils.h"
|
||||
#include "log/Log.h"
|
||||
#include "mmap/Types.h"
|
||||
#include "storage/DataCodec.h"
|
||||
#include "storage/MemFileManagerImpl.h"
|
||||
#include "storage/ThreadPools.h"
|
||||
#include "storage/space.h"
|
||||
#include "storage/Util.h"
|
||||
#include "storage/prometheus_client.h"
|
||||
|
||||
namespace milvus::index {
|
||||
|
||||
|
@ -733,7 +730,8 @@ void VectorMemIndex<T>::LoadFromFile(const Config& config) {
|
|||
}
|
||||
|
||||
LOG_INFO("load with slice meta: {}", !slice_meta_filepath.empty());
|
||||
|
||||
std::chrono::duration<double> load_duration_sum;
|
||||
std::chrono::duration<double> write_disk_duration_sum;
|
||||
if (!slice_meta_filepath
|
||||
.empty()) { // load with the slice meta info, then we can load batch by batch
|
||||
std::string index_file_prefix = slice_meta_filepath.substr(
|
||||
|
@ -751,15 +749,20 @@ void VectorMemIndex<T>::LoadFromFile(const Config& config) {
|
|||
std::string prefix = item[NAME];
|
||||
int slice_num = item[SLICE_NUM];
|
||||
auto total_len = static_cast<size_t>(item[TOTAL_LEN]);
|
||||
|
||||
auto HandleBatch = [&](int index) {
|
||||
auto start_load2_mem = std::chrono::system_clock::now();
|
||||
auto batch_data = file_manager_->LoadIndexToMemory(batch);
|
||||
load_duration_sum +=
|
||||
(std::chrono::system_clock::now() - start_load2_mem);
|
||||
for (int j = index - batch.size() + 1; j <= index; j++) {
|
||||
std::string file_name = GenSlicedFileName(prefix, j);
|
||||
AssertInfo(batch_data.find(file_name) != batch_data.end(),
|
||||
"lost index slice data");
|
||||
auto data = batch_data[file_name];
|
||||
auto start_write_file = std::chrono::system_clock::now();
|
||||
auto written = file.Write(data->Data(), data->Size());
|
||||
write_disk_duration_sum +=
|
||||
(std::chrono::system_clock::now() - start_write_file);
|
||||
AssertInfo(
|
||||
written == data->Size(),
|
||||
fmt::format("failed to write index data to disk {}: {}",
|
||||
|
@ -784,24 +787,46 @@ void VectorMemIndex<T>::LoadFromFile(const Config& config) {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
//1. load files into memory
|
||||
auto start_load_files2_mem = std::chrono::system_clock::now();
|
||||
auto result = file_manager_->LoadIndexToMemory(std::vector<std::string>(
|
||||
pending_index_files.begin(), pending_index_files.end()));
|
||||
load_duration_sum +=
|
||||
(std::chrono::system_clock::now() - start_load_files2_mem);
|
||||
//2. write data into files
|
||||
auto start_write_file = std::chrono::system_clock::now();
|
||||
for (auto& [_, index_data] : result) {
|
||||
file.Write(index_data->Data(), index_data->Size());
|
||||
}
|
||||
write_disk_duration_sum +=
|
||||
(std::chrono::system_clock::now() - start_write_file);
|
||||
}
|
||||
milvus::storage::internal_storage_download_duration.Observe(
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(load_duration_sum)
|
||||
.count());
|
||||
milvus::storage::internal_storage_write_disk_duration.Observe(
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
write_disk_duration_sum)
|
||||
.count());
|
||||
file.Close();
|
||||
|
||||
LOG_INFO("load index into Knowhere...");
|
||||
auto conf = config;
|
||||
conf.erase(kMmapFilepath);
|
||||
conf[kEnableMmap] = true;
|
||||
auto start_deserialize = std::chrono::system_clock::now();
|
||||
auto stat = index_.DeserializeFromFile(filepath.value(), conf);
|
||||
auto deserialize_duration =
|
||||
std::chrono::system_clock::now() - start_deserialize;
|
||||
if (stat != knowhere::Status::success) {
|
||||
PanicInfo(ErrorCode::UnexpectedError,
|
||||
"failed to Deserialize index: {}",
|
||||
KnowhereStatusString(stat));
|
||||
}
|
||||
milvus::storage::internal_storage_deserialize_duration.Observe(
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
deserialize_duration)
|
||||
.count());
|
||||
|
||||
auto dim = index_.Dim();
|
||||
this->SetDim(index_.Dim());
|
||||
|
@ -811,7 +836,18 @@ void VectorMemIndex<T>::LoadFromFile(const Config& config) {
|
|||
"failed to unlink mmap index file {}: {}",
|
||||
filepath.value(),
|
||||
strerror(errno));
|
||||
LOG_INFO("load vector index done");
|
||||
LOG_INFO(
|
||||
"load vector index done, mmap_file_path:{}, download_duration:{}, "
|
||||
"write_files_duration:{}, deserialize_duration:{}",
|
||||
filepath.value(),
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(load_duration_sum)
|
||||
.count(),
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
write_disk_duration_sum)
|
||||
.count(),
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
deserialize_duration)
|
||||
.count());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -25,6 +25,7 @@ set(SEGCORE_FILES
|
|||
SegmentSealedImpl.cpp
|
||||
FieldIndexing.cpp
|
||||
Reduce.cpp
|
||||
StreamReduce.cpp
|
||||
metrics_c.cpp
|
||||
plan_c.cpp
|
||||
reduce_c.cpp
|
||||
|
@ -36,7 +37,8 @@ set(SEGCORE_FILES
|
|||
segcore_init_c.cpp
|
||||
TimestampIndex.cpp
|
||||
Utils.cpp
|
||||
ConcurrentVector.cpp)
|
||||
ConcurrentVector.cpp
|
||||
ReduceUtils.cpp)
|
||||
add_library(milvus_segcore SHARED ${SEGCORE_FILES})
|
||||
|
||||
target_link_libraries(milvus_segcore milvus_query milvus_bitset milvus_exec ${OpenMP_CXX_FLAGS} milvus-storage)
|
||||
|
|
|
@ -41,10 +41,10 @@ class ThreadSafeVector {
|
|||
template <typename... Args>
|
||||
void
|
||||
emplace_to_at_least(int64_t size, Args... args) {
|
||||
std::lock_guard lck(mutex_);
|
||||
if (size <= size_) {
|
||||
return;
|
||||
}
|
||||
std::lock_guard lck(mutex_);
|
||||
while (vec_.size() < size) {
|
||||
vec_.emplace_back(std::forward<Args...>(args...));
|
||||
++size_;
|
||||
|
@ -52,24 +52,25 @@ class ThreadSafeVector {
|
|||
}
|
||||
const Type&
|
||||
operator[](int64_t index) const {
|
||||
std::shared_lock lck(mutex_);
|
||||
AssertInfo(index < size_,
|
||||
fmt::format(
|
||||
"index out of range, index={}, size_={}", index, size_));
|
||||
std::shared_lock lck(mutex_);
|
||||
return vec_[index];
|
||||
}
|
||||
|
||||
Type&
|
||||
operator[](int64_t index) {
|
||||
std::shared_lock lck(mutex_);
|
||||
AssertInfo(index < size_,
|
||||
fmt::format(
|
||||
"index out of range, index={}, size_={}", index, size_));
|
||||
std::shared_lock lck(mutex_);
|
||||
return vec_[index];
|
||||
}
|
||||
|
||||
int64_t
|
||||
size() const {
|
||||
std::lock_guard lck(mutex_);
|
||||
return size_;
|
||||
}
|
||||
|
||||
|
@ -81,7 +82,7 @@ class ThreadSafeVector {
|
|||
}
|
||||
|
||||
private:
|
||||
std::atomic<int64_t> size_ = 0;
|
||||
int64_t size_ = 0;
|
||||
std::deque<Type> vec_;
|
||||
mutable std::shared_mutex mutex_;
|
||||
};
|
||||
|
|
|
@ -598,6 +598,11 @@ struct InsertRecord {
|
|||
fields_data_.clear();
|
||||
}
|
||||
|
||||
bool
|
||||
empty() const {
|
||||
return pk2offset_->empty();
|
||||
}
|
||||
|
||||
public:
|
||||
ConcurrentVector<Timestamp> timestamps_;
|
||||
ConcurrentVector<idx_t> row_ids_;
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "Utils.h"
|
||||
#include "common/EasyAssert.h"
|
||||
#include "pkVisitor.h"
|
||||
#include "ReduceUtils.h"
|
||||
|
||||
namespace milvus::segcore {
|
||||
|
||||
|
@ -130,7 +131,6 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
|
|||
|
||||
void
|
||||
ReduceHelper::FillPrimaryKey() {
|
||||
std::vector<SearchResult*> valid_search_results;
|
||||
// get primary keys for duplicates removal
|
||||
uint32_t valid_index = 0;
|
||||
for (auto& search_result : search_results_) {
|
||||
|
@ -368,7 +368,7 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
|
|||
search_result_data->set_all_search_count(all_search_count);
|
||||
|
||||
// `result_pairs` contains the SearchResult and result_offset info, used for filling output fields
|
||||
std::vector<std::pair<SearchResult*, int64_t>> result_pairs(result_count);
|
||||
std::vector<MergeBase> result_pairs(result_count);
|
||||
|
||||
// reserve space for pks
|
||||
auto primary_field_id =
|
||||
|
@ -461,14 +461,14 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
|
|||
group_by_values[loc] =
|
||||
search_result->group_by_values_.value()[ki];
|
||||
// set result offset to fill output fields data
|
||||
result_pairs[loc] = std::make_pair(search_result, ki);
|
||||
result_pairs[loc] = {&search_result->output_fields_data_, ki};
|
||||
}
|
||||
}
|
||||
|
||||
// update result topKs
|
||||
search_result_data->mutable_topks()->Set(qi - nq_begin, topk_count);
|
||||
}
|
||||
AssembleGroupByValues(search_result_data, group_by_values);
|
||||
AssembleGroupByValues(search_result_data, group_by_values, plan_);
|
||||
|
||||
AssertInfo(search_result_data->scores_size() == result_count,
|
||||
"wrong scores size, size = " +
|
||||
|
@ -498,89 +498,4 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
|
|||
return buffer;
|
||||
}
|
||||
|
||||
void
|
||||
ReduceHelper::AssembleGroupByValues(
|
||||
std::unique_ptr<milvus::proto::schema::SearchResultData>& search_result,
|
||||
const std::vector<GroupByValueType>& group_by_vals) {
|
||||
auto group_by_field_id = plan_->plan_node_->search_info_.group_by_field_id_;
|
||||
if (group_by_field_id.has_value() && group_by_vals.size() > 0) {
|
||||
auto group_by_values_field =
|
||||
std::make_unique<milvus::proto::schema::ScalarField>();
|
||||
auto group_by_field =
|
||||
plan_->schema_.operator[](group_by_field_id.value());
|
||||
DataType group_by_data_type = group_by_field.get_data_type();
|
||||
|
||||
int group_by_val_size = group_by_vals.size();
|
||||
switch (group_by_data_type) {
|
||||
case DataType::INT8: {
|
||||
auto field_data = group_by_values_field->mutable_int_data();
|
||||
field_data->mutable_data()->Resize(group_by_val_size, 0);
|
||||
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
|
||||
int8_t val = std::get<int8_t>(group_by_vals[idx]);
|
||||
field_data->mutable_data()->Set(idx, val);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DataType::INT16: {
|
||||
auto field_data = group_by_values_field->mutable_int_data();
|
||||
field_data->mutable_data()->Resize(group_by_val_size, 0);
|
||||
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
|
||||
int16_t val = std::get<int16_t>(group_by_vals[idx]);
|
||||
field_data->mutable_data()->Set(idx, val);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DataType::INT32: {
|
||||
auto field_data = group_by_values_field->mutable_int_data();
|
||||
field_data->mutable_data()->Resize(group_by_val_size, 0);
|
||||
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
|
||||
int32_t val = std::get<int32_t>(group_by_vals[idx]);
|
||||
field_data->mutable_data()->Set(idx, val);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DataType::INT64: {
|
||||
auto field_data = group_by_values_field->mutable_long_data();
|
||||
field_data->mutable_data()->Resize(group_by_val_size, 0);
|
||||
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
|
||||
int64_t val = std::get<int64_t>(group_by_vals[idx]);
|
||||
field_data->mutable_data()->Set(idx, val);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DataType::BOOL: {
|
||||
auto field_data = group_by_values_field->mutable_bool_data();
|
||||
field_data->mutable_data()->Resize(group_by_val_size, 0);
|
||||
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
|
||||
bool val = std::get<bool>(group_by_vals[idx]);
|
||||
field_data->mutable_data()->Set(idx, val);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
auto field_data = group_by_values_field->mutable_string_data();
|
||||
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
|
||||
std::string val =
|
||||
std::move(std::get<std::string>(group_by_vals[idx]));
|
||||
*(field_data->mutable_data()->Add()) = val;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
PanicInfo(
|
||||
DataTypeInvalid,
|
||||
fmt::format("unsupported datatype for group_by operations ",
|
||||
group_by_data_type));
|
||||
}
|
||||
}
|
||||
|
||||
search_result->mutable_group_by_field_value()->set_type(
|
||||
milvus::proto::schema::DataType(group_by_data_type));
|
||||
search_result->mutable_group_by_field_value()
|
||||
->mutable_scalars()
|
||||
->MergeFrom(*group_by_values_field.get());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace milvus::segcore
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "common/QueryResult.h"
|
||||
#include "query/PlanImpl.h"
|
||||
#include "ReduceStructure.h"
|
||||
#include "segment_c.h"
|
||||
|
||||
namespace milvus::segcore {
|
||||
|
||||
|
@ -55,18 +56,22 @@ class ReduceHelper {
|
|||
return search_result_data_blobs_.release();
|
||||
}
|
||||
|
||||
private:
|
||||
void
|
||||
Initialize();
|
||||
|
||||
protected:
|
||||
void
|
||||
FilterInvalidSearchResult(SearchResult* search_result);
|
||||
|
||||
void
|
||||
RefreshSearchResult();
|
||||
|
||||
void
|
||||
FillPrimaryKey();
|
||||
|
||||
void
|
||||
RefreshSearchResult();
|
||||
ReduceResultData();
|
||||
|
||||
private:
|
||||
void
|
||||
Initialize();
|
||||
|
||||
void
|
||||
FillEntryData();
|
||||
|
@ -76,44 +81,34 @@ class ReduceHelper {
|
|||
int64_t topk,
|
||||
int64_t& result_offset);
|
||||
|
||||
void
|
||||
ReduceResultData();
|
||||
|
||||
std::vector<char>
|
||||
GetSearchResultDataSlice(int slice_index_);
|
||||
|
||||
void
|
||||
AssembleGroupByValues(
|
||||
std::unique_ptr<milvus::proto::schema::SearchResultData>& search_result,
|
||||
const std::vector<GroupByValueType>& group_by_vals);
|
||||
|
||||
private:
|
||||
protected:
|
||||
std::vector<SearchResult*>& search_results_;
|
||||
milvus::query::Plan* plan_;
|
||||
|
||||
std::vector<int64_t> slice_nqs_;
|
||||
std::vector<int64_t> slice_topKs_;
|
||||
int64_t total_nq_;
|
||||
int64_t num_segments_;
|
||||
int64_t num_slices_;
|
||||
|
||||
std::vector<int64_t> slice_nqs_prefix_sum_;
|
||||
|
||||
// dim0: num_segments_; dim1: total_nq_; dim2: offset
|
||||
std::vector<std::vector<std::vector<int64_t>>> final_search_records_;
|
||||
|
||||
// output
|
||||
std::unique_ptr<SearchResultDataBlobs> search_result_data_blobs_;
|
||||
|
||||
// Used for merge results,
|
||||
// define these here to avoid allocating them for each query
|
||||
std::vector<SearchResultPair> pairs_;
|
||||
int64_t num_segments_;
|
||||
std::vector<int64_t> slice_topKs_;
|
||||
std::priority_queue<SearchResultPair*,
|
||||
std::vector<SearchResultPair*>,
|
||||
SearchResultPairComparator>
|
||||
heap_;
|
||||
// Used for merge results,
|
||||
// define these here to avoid allocating them for each query
|
||||
std::vector<SearchResultPair> pairs_;
|
||||
std::unordered_set<milvus::PkType> pk_set_;
|
||||
std::unordered_set<milvus::GroupByValueType> group_by_val_set_;
|
||||
// dim0: num_segments_; dim1: total_nq_; dim2: offset
|
||||
std::vector<std::vector<std::vector<int64_t>>> final_search_records_;
|
||||
|
||||
private:
|
||||
std::vector<int64_t> slice_nqs_;
|
||||
int64_t total_nq_;
|
||||
|
||||
// output
|
||||
std::unique_ptr<SearchResultDataBlobs> search_result_data_blobs_;
|
||||
};
|
||||
|
||||
} // namespace milvus::segcore
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
//
|
||||
// Created by zilliz on 2024/3/26.
|
||||
//
|
||||
|
||||
#include "ReduceUtils.h"
|
||||
|
||||
namespace milvus::segcore {
|
||||
|
||||
void
|
||||
AssembleGroupByValues(
|
||||
std::unique_ptr<milvus::proto::schema::SearchResultData>& search_result,
|
||||
const std::vector<GroupByValueType>& group_by_vals,
|
||||
milvus::query::Plan* plan) {
|
||||
auto group_by_field_id = plan->plan_node_->search_info_.group_by_field_id_;
|
||||
if (group_by_field_id.has_value() && group_by_vals.size() > 0) {
|
||||
auto group_by_values_field =
|
||||
std::make_unique<milvus::proto::schema::ScalarField>();
|
||||
auto group_by_field =
|
||||
plan->schema_.operator[](group_by_field_id.value());
|
||||
DataType group_by_data_type = group_by_field.get_data_type();
|
||||
|
||||
int group_by_val_size = group_by_vals.size();
|
||||
switch (group_by_data_type) {
|
||||
case DataType::INT8: {
|
||||
auto field_data = group_by_values_field->mutable_int_data();
|
||||
field_data->mutable_data()->Resize(group_by_val_size, 0);
|
||||
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
|
||||
int8_t val = std::get<int8_t>(group_by_vals[idx]);
|
||||
field_data->mutable_data()->Set(idx, val);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DataType::INT16: {
|
||||
auto field_data = group_by_values_field->mutable_int_data();
|
||||
field_data->mutable_data()->Resize(group_by_val_size, 0);
|
||||
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
|
||||
int16_t val = std::get<int16_t>(group_by_vals[idx]);
|
||||
field_data->mutable_data()->Set(idx, val);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DataType::INT32: {
|
||||
auto field_data = group_by_values_field->mutable_int_data();
|
||||
field_data->mutable_data()->Resize(group_by_val_size, 0);
|
||||
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
|
||||
int32_t val = std::get<int32_t>(group_by_vals[idx]);
|
||||
field_data->mutable_data()->Set(idx, val);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DataType::INT64: {
|
||||
auto field_data = group_by_values_field->mutable_long_data();
|
||||
field_data->mutable_data()->Resize(group_by_val_size, 0);
|
||||
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
|
||||
int64_t val = std::get<int64_t>(group_by_vals[idx]);
|
||||
field_data->mutable_data()->Set(idx, val);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DataType::BOOL: {
|
||||
auto field_data = group_by_values_field->mutable_bool_data();
|
||||
field_data->mutable_data()->Resize(group_by_val_size, 0);
|
||||
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
|
||||
bool val = std::get<bool>(group_by_vals[idx]);
|
||||
field_data->mutable_data()->Set(idx, val);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
auto field_data = group_by_values_field->mutable_string_data();
|
||||
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
|
||||
std::string val =
|
||||
std::move(std::get<std::string>(group_by_vals[idx]));
|
||||
*(field_data->mutable_data()->Add()) = val;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
PanicInfo(
|
||||
DataTypeInvalid,
|
||||
fmt::format("unsupported datatype for group_by operations ",
|
||||
group_by_data_type));
|
||||
}
|
||||
}
|
||||
|
||||
search_result->mutable_group_by_field_value()->set_type(
|
||||
milvus::proto::schema::DataType(group_by_data_type));
|
||||
search_result->mutable_group_by_field_value()
|
||||
->mutable_scalars()
|
||||
->MergeFrom(*group_by_values_field.get());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace milvus::segcore
|
|
@ -0,0 +1,26 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "pb/schema.pb.h"
|
||||
#include "common/Types.h"
|
||||
#include "query/PlanImpl.h"
|
||||
|
||||
namespace milvus::segcore {
|
||||
|
||||
void
|
||||
AssembleGroupByValues(
|
||||
std::unique_ptr<milvus::proto::schema::SearchResultData>& search_result,
|
||||
const std::vector<GroupByValueType>& group_by_vals,
|
||||
milvus::query::Plan* plan);
|
||||
|
||||
}
|
|
@ -105,6 +105,10 @@ SegmentSealedImpl::LoadVecIndex(const LoadIndexInfo& info) {
|
|||
") than other column's row count (" +
|
||||
std::to_string(num_rows_.value()) + ")");
|
||||
}
|
||||
LOG_INFO(
|
||||
"Before setting field_bit for field index, fieldID:{}. segmentID:{}, ",
|
||||
info.field_id,
|
||||
id_);
|
||||
if (get_bit(field_data_ready_bitset_, field_id)) {
|
||||
fields_.erase(field_id);
|
||||
set_bit(field_data_ready_bitset_, field_id, false);
|
||||
|
@ -118,6 +122,9 @@ SegmentSealedImpl::LoadVecIndex(const LoadIndexInfo& info) {
|
|||
metric_type,
|
||||
std::move(const_cast<LoadIndexInfo&>(info).index));
|
||||
set_bit(index_ready_bitset_, field_id, true);
|
||||
LOG_INFO("Has load vec index done, fieldID:{}. segmentID:{}, ",
|
||||
info.field_id,
|
||||
id_);
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -1021,7 +1028,8 @@ SegmentSealedImpl::bulk_subscript(SystemFieldType system_type,
|
|||
int64_t count,
|
||||
void* output) const {
|
||||
AssertInfo(is_system_field_ready(),
|
||||
"System field isn't ready when do bulk_insert");
|
||||
"System field isn't ready when do bulk_insert, segID:{}",
|
||||
id_);
|
||||
switch (system_type) {
|
||||
case SystemFieldType::Timestamp:
|
||||
AssertInfo(
|
||||
|
@ -1647,5 +1655,20 @@ SegmentSealedImpl::generate_interim_index(const FieldId field_id) {
|
|||
return false;
|
||||
}
|
||||
}
|
||||
void
|
||||
SegmentSealedImpl::RemoveFieldFile(const FieldId field_id) {
|
||||
auto cc = storage::ChunkCacheSingleton::GetInstance().GetChunkCache();
|
||||
if (cc == nullptr) {
|
||||
return;
|
||||
}
|
||||
for (const auto& iter : field_data_info_.field_infos) {
|
||||
if (iter.second.field_id == field_id.get()) {
|
||||
for (const auto& binlog : iter.second.insert_files) {
|
||||
cc->Remove(binlog);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace milvus::segcore
|
||||
|
|
|
@ -88,6 +88,9 @@ class SegmentSealedImpl : public SegmentSealed {
|
|||
DataType
|
||||
GetFieldDataType(FieldId fieldId) const override;
|
||||
|
||||
void
|
||||
RemoveFieldFile(const FieldId field_id);
|
||||
|
||||
public:
|
||||
size_t
|
||||
GetMemoryUsageInBytes() const override {
|
||||
|
|
|
@ -0,0 +1,690 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "StreamReduce.h"
|
||||
#include "SegmentInterface.h"
|
||||
#include "segcore/Utils.h"
|
||||
#include "Reduce.h"
|
||||
#include "segcore/pkVisitor.h"
|
||||
#include "segcore/ReduceUtils.h"
|
||||
|
||||
namespace milvus::segcore {
|
||||
|
||||
void
|
||||
StreamReducerHelper::FillEntryData() {
|
||||
for (auto search_result : search_results_to_merge_) {
|
||||
auto segment = static_cast<milvus::segcore::SegmentInterface*>(
|
||||
search_result->segment_);
|
||||
segment->FillTargetEntry(plan_, *search_result);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
StreamReducerHelper::AssembleMergedResult() {
|
||||
if (search_results_to_merge_.size() > 0) {
|
||||
std::unique_ptr<MergedSearchResult> new_merged_result =
|
||||
std::make_unique<MergedSearchResult>();
|
||||
std::vector<PkType> new_merged_pks;
|
||||
std::vector<float> new_merged_distances;
|
||||
std::vector<GroupByValueType> new_merged_groupBy_vals;
|
||||
std::vector<MergeBase> merge_output_data_bases;
|
||||
std::vector<int64_t> new_result_offsets;
|
||||
bool need_handle_groupBy =
|
||||
plan_->plan_node_->search_info_.group_by_field_id_.has_value();
|
||||
int valid_size = 0;
|
||||
std::vector<int> real_topKs(total_nq_);
|
||||
for (int i = 0; i < num_slice_; i++) {
|
||||
auto nq_begin = slice_nqs_prefix_sum_[i];
|
||||
auto nq_end = slice_nqs_prefix_sum_[i + 1];
|
||||
int64_t result_count = 0;
|
||||
for (auto search_result : search_results_to_merge_) {
|
||||
AssertInfo(
|
||||
search_result->topk_per_nq_prefix_sum_.size() ==
|
||||
search_result->total_nq_ + 1,
|
||||
"incorrect topk_per_nq_prefix_sum_ size in search result");
|
||||
result_count +=
|
||||
search_result->topk_per_nq_prefix_sum_[nq_end] -
|
||||
search_result->topk_per_nq_prefix_sum_[nq_begin];
|
||||
}
|
||||
if (merged_search_result->has_result_) {
|
||||
result_count +=
|
||||
merged_search_result->topk_per_nq_prefix_sum_[nq_end] -
|
||||
merged_search_result->topk_per_nq_prefix_sum_[nq_begin];
|
||||
}
|
||||
int nq_base_offset = valid_size;
|
||||
valid_size += result_count;
|
||||
new_merged_pks.resize(valid_size);
|
||||
new_merged_distances.resize(valid_size);
|
||||
merge_output_data_bases.resize(valid_size);
|
||||
new_result_offsets.resize(valid_size);
|
||||
if (need_handle_groupBy) {
|
||||
new_merged_groupBy_vals.resize(valid_size);
|
||||
}
|
||||
for (auto qi = nq_begin; qi < nq_end; qi++) {
|
||||
for (auto search_result : search_results_to_merge_) {
|
||||
AssertInfo(search_result != nullptr,
|
||||
"null search result when reorganize");
|
||||
if (search_result->result_offsets_.size() == 0) {
|
||||
continue;
|
||||
}
|
||||
auto topK_start =
|
||||
search_result->topk_per_nq_prefix_sum_[qi];
|
||||
auto topK_end =
|
||||
search_result->topk_per_nq_prefix_sum_[qi + 1];
|
||||
for (auto ki = topK_start; ki < topK_end; ki++) {
|
||||
auto loc = search_result->result_offsets_[ki];
|
||||
AssertInfo(loc < result_count && loc >= 0,
|
||||
"invalid loc when GetSearchResultDataSlice, "
|
||||
"loc = " +
|
||||
std::to_string(loc) +
|
||||
", result_count = " +
|
||||
std::to_string(result_count));
|
||||
|
||||
new_merged_pks[nq_base_offset + loc] =
|
||||
search_result->primary_keys_[ki];
|
||||
new_merged_distances[nq_base_offset + loc] =
|
||||
search_result->distances_[ki];
|
||||
if (need_handle_groupBy) {
|
||||
new_merged_groupBy_vals[nq_base_offset + loc] =
|
||||
search_result->group_by_values_.value()[ki];
|
||||
}
|
||||
merge_output_data_bases[nq_base_offset + loc] = {
|
||||
&search_result->output_fields_data_, ki};
|
||||
new_result_offsets[nq_base_offset + loc] = loc;
|
||||
real_topKs[qi]++;
|
||||
}
|
||||
}
|
||||
if (merged_search_result->has_result_) {
|
||||
auto topK_start =
|
||||
merged_search_result->topk_per_nq_prefix_sum_[qi];
|
||||
auto topK_end =
|
||||
merged_search_result->topk_per_nq_prefix_sum_[qi + 1];
|
||||
for (auto ki = topK_start; ki < topK_end; ki++) {
|
||||
auto loc = merged_search_result->reduced_offsets_[ki];
|
||||
AssertInfo(loc < result_count && loc >= 0,
|
||||
"invalid loc when GetSearchResultDataSlice, "
|
||||
"loc = " +
|
||||
std::to_string(loc) +
|
||||
", result_count = " +
|
||||
std::to_string(result_count));
|
||||
|
||||
new_merged_pks[nq_base_offset + loc] =
|
||||
merged_search_result->primary_keys_[ki];
|
||||
new_merged_distances[nq_base_offset + loc] =
|
||||
merged_search_result->distances_[ki];
|
||||
if (need_handle_groupBy) {
|
||||
new_merged_groupBy_vals[nq_base_offset + loc] =
|
||||
merged_search_result->group_by_values_
|
||||
.value()[ki];
|
||||
}
|
||||
merge_output_data_bases[nq_base_offset + loc] = {
|
||||
&merged_search_result->output_fields_data_, ki};
|
||||
new_result_offsets[nq_base_offset + loc] = loc;
|
||||
real_topKs[qi]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
new_merged_result->primary_keys_ = std::move(new_merged_pks);
|
||||
new_merged_result->distances_ = std::move(new_merged_distances);
|
||||
if (need_handle_groupBy) {
|
||||
new_merged_result->group_by_values_ =
|
||||
std::move(new_merged_groupBy_vals);
|
||||
}
|
||||
new_merged_result->topk_per_nq_prefix_sum_.resize(total_nq_ + 1);
|
||||
std::partial_sum(
|
||||
real_topKs.begin(),
|
||||
real_topKs.end(),
|
||||
new_merged_result->topk_per_nq_prefix_sum_.begin() + 1);
|
||||
new_merged_result->result_offsets_ = std::move(new_result_offsets);
|
||||
for (auto field_id : plan_->target_entries_) {
|
||||
auto& field_meta = plan_->schema_[field_id];
|
||||
auto field_data =
|
||||
MergeDataArray(merge_output_data_bases, field_meta);
|
||||
if (field_meta.get_data_type() == DataType::ARRAY) {
|
||||
field_data->mutable_scalars()
|
||||
->mutable_array_data()
|
||||
->set_element_type(
|
||||
proto::schema::DataType(field_meta.get_element_type()));
|
||||
}
|
||||
new_merged_result->output_fields_data_[field_id] =
|
||||
std::move(field_data);
|
||||
}
|
||||
merged_search_result = std::move(new_merged_result);
|
||||
merged_search_result->has_result_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
StreamReducerHelper::MergeReduce() {
|
||||
FilterSearchResults();
|
||||
FillPrimaryKeys();
|
||||
InitializeReduceRecords();
|
||||
ReduceResultData();
|
||||
RefreshSearchResult();
|
||||
FillEntryData();
|
||||
AssembleMergedResult();
|
||||
CleanReduceStatus();
|
||||
}
|
||||
|
||||
void*
|
||||
StreamReducerHelper::SerializeMergedResult() {
|
||||
std::unique_ptr<SearchResultDataBlobs> search_result_blobs =
|
||||
std::make_unique<milvus::segcore::SearchResultDataBlobs>();
|
||||
AssertInfo(num_slice_ > 0,
|
||||
"Wrong state for num_slice in streamReducer, num_slice:{}",
|
||||
num_slice_);
|
||||
search_result_blobs->blobs.resize(num_slice_);
|
||||
for (int i = 0; i < num_slice_; i++) {
|
||||
auto proto = GetSearchResultDataSlice(i);
|
||||
search_result_blobs->blobs[i] = proto;
|
||||
}
|
||||
return search_result_blobs.release();
|
||||
}
|
||||
|
||||
void
|
||||
StreamReducerHelper::ReduceResultData() {
|
||||
if (search_results_to_merge_.size() > 0) {
|
||||
for (int i = 0; i < num_segments_; i++) {
|
||||
auto search_result = search_results_to_merge_[i];
|
||||
auto result_count = search_result->get_total_result_count();
|
||||
AssertInfo(search_result != nullptr,
|
||||
"search result must not equal to nullptr");
|
||||
AssertInfo(search_result->distances_.size() == result_count,
|
||||
"incorrect search result distance size");
|
||||
AssertInfo(search_result->seg_offsets_.size() == result_count,
|
||||
"incorrect search result seg offset size");
|
||||
AssertInfo(search_result->primary_keys_.size() == result_count,
|
||||
"incorrect search result primary key size");
|
||||
}
|
||||
for (int64_t slice_index = 0; slice_index < slice_nqs_.size();
|
||||
slice_index++) {
|
||||
auto nq_begin = slice_nqs_prefix_sum_[slice_index];
|
||||
auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];
|
||||
|
||||
int64_t offset = 0;
|
||||
for (int64_t qi = nq_begin; qi < nq_end; qi++) {
|
||||
StreamReduceSearchResultForOneNQ(
|
||||
qi, slice_topKs_[slice_index], offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
StreamReducerHelper::FilterSearchResults() {
|
||||
uint32_t valid_index = 0;
|
||||
for (auto& search_result : search_results_to_merge_) {
|
||||
// skip when results num is 0
|
||||
AssertInfo(search_result != nullptr,
|
||||
"search_result to merge cannot be nullptr, there must be "
|
||||
"sth wrong in the code");
|
||||
if (search_result->unity_topK_ == 0) {
|
||||
continue;
|
||||
}
|
||||
FilterInvalidSearchResult(search_result);
|
||||
search_results_to_merge_[valid_index++] = search_result;
|
||||
}
|
||||
search_results_to_merge_.resize(valid_index);
|
||||
num_segments_ = search_results_to_merge_.size();
|
||||
}
|
||||
|
||||
void
|
||||
StreamReducerHelper::InitializeReduceRecords() {
|
||||
// init final_search_records and final_read_topKs
|
||||
if (merged_search_result->has_result_) {
|
||||
final_search_records_.resize(num_segments_ + 1);
|
||||
} else {
|
||||
final_search_records_.resize(num_segments_);
|
||||
}
|
||||
for (auto& search_record : final_search_records_) {
|
||||
search_record.resize(total_nq_);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
StreamReducerHelper::FillPrimaryKeys() {
|
||||
for (auto& search_result : search_results_to_merge_) {
|
||||
auto segment = static_cast<SegmentInterface*>(search_result->segment_);
|
||||
if (search_result->get_total_result_count() > 0) {
|
||||
segment->FillPrimaryKeys(plan_, *search_result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
StreamReducerHelper::FilterInvalidSearchResult(SearchResult* search_result) {
|
||||
auto total_nq = search_result->total_nq_;
|
||||
auto topK = search_result->unity_topK_;
|
||||
AssertInfo(search_result->seg_offsets_.size() == total_nq * topK,
|
||||
"wrong seg offsets size, size = " +
|
||||
std::to_string(search_result->seg_offsets_.size()) +
|
||||
", expected size = " + std::to_string(total_nq * topK));
|
||||
AssertInfo(search_result->distances_.size() == total_nq * topK,
|
||||
"wrong distances size, size = " +
|
||||
std::to_string(search_result->distances_.size()) +
|
||||
", expected size = " + std::to_string(total_nq * topK));
|
||||
std::vector<int64_t> real_topKs(total_nq, 0);
|
||||
uint32_t valid_index = 0;
|
||||
auto segment = static_cast<SegmentInterface*>(search_result->segment_);
|
||||
auto& offsets = search_result->seg_offsets_;
|
||||
auto& distances = search_result->distances_;
|
||||
if (search_result->group_by_values_.has_value()) {
|
||||
AssertInfo(search_result->distances_.size() ==
|
||||
search_result->group_by_values_.value().size(),
|
||||
"wrong group_by_values size, size:{}, expected size:{} ",
|
||||
search_result->group_by_values_.value().size(),
|
||||
search_result->distances_.size());
|
||||
}
|
||||
|
||||
for (auto i = 0; i < total_nq; ++i) {
|
||||
for (auto j = 0; j < topK; ++j) {
|
||||
auto index = i * topK + j;
|
||||
if (offsets[index] != INVALID_SEG_OFFSET) {
|
||||
AssertInfo(0 <= offsets[index] &&
|
||||
offsets[index] < segment->get_row_count(),
|
||||
fmt::format("invalid offset {}, segment {} with "
|
||||
"rows num {}, data or index corruption",
|
||||
offsets[index],
|
||||
segment->get_segment_id(),
|
||||
segment->get_row_count()));
|
||||
real_topKs[i]++;
|
||||
offsets[valid_index] = offsets[index];
|
||||
distances[valid_index] = distances[index];
|
||||
if (search_result->group_by_values_.has_value())
|
||||
search_result->group_by_values_.value()[valid_index] =
|
||||
search_result->group_by_values_.value()[index];
|
||||
valid_index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
offsets.resize(valid_index);
|
||||
distances.resize(valid_index);
|
||||
if (search_result->group_by_values_.has_value())
|
||||
search_result->group_by_values_.value().resize(valid_index);
|
||||
|
||||
search_result->topk_per_nq_prefix_sum_.resize(total_nq + 1);
|
||||
std::partial_sum(real_topKs.begin(),
|
||||
real_topKs.end(),
|
||||
search_result->topk_per_nq_prefix_sum_.begin() + 1);
|
||||
}
|
||||
|
||||
void
|
||||
StreamReducerHelper::StreamReduceSearchResultForOneNQ(int64_t qi,
|
||||
int64_t topK,
|
||||
int64_t& offset) {
|
||||
//1. clear heap for preceding left elements
|
||||
while (!heap_.empty()) {
|
||||
heap_.pop();
|
||||
}
|
||||
pk_set_.clear();
|
||||
group_by_val_set_.clear();
|
||||
|
||||
//2. push new search results into sort-heap
|
||||
for (int i = 0; i < num_segments_; i++) {
|
||||
auto search_result = search_results_to_merge_[i];
|
||||
auto offset_beg = search_result->topk_per_nq_prefix_sum_[qi];
|
||||
auto offset_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
|
||||
if (offset_beg == offset_end) {
|
||||
continue;
|
||||
}
|
||||
auto primary_key = search_result->primary_keys_[offset_beg];
|
||||
auto distance = search_result->distances_[offset_beg];
|
||||
if (search_result->group_by_values_.has_value()) {
|
||||
AssertInfo(
|
||||
search_result->group_by_values_.value().size() > offset_beg,
|
||||
"Wrong size for group_by_values size to "
|
||||
"ReduceSearchResultForOneNQ:{}, not enough for"
|
||||
"required offset_beg:{}",
|
||||
search_result->group_by_values_.value().size(),
|
||||
offset_beg);
|
||||
}
|
||||
|
||||
auto result_pair = std::make_shared<StreamSearchResultPair>(
|
||||
primary_key,
|
||||
distance,
|
||||
search_result,
|
||||
nullptr,
|
||||
i,
|
||||
offset_beg,
|
||||
offset_end,
|
||||
search_result->group_by_values_.has_value() &&
|
||||
search_result->group_by_values_.value().size() > offset_beg
|
||||
? std::make_optional(
|
||||
search_result->group_by_values_.value().at(offset_beg))
|
||||
: std::nullopt);
|
||||
heap_.push(result_pair);
|
||||
}
|
||||
if (heap_.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
//3. if the merged_search_result has previous data
|
||||
//push merged search result into the heap
|
||||
if (merged_search_result->has_result_) {
|
||||
auto merged_off_begin =
|
||||
merged_search_result->topk_per_nq_prefix_sum_[qi];
|
||||
auto merged_off_end =
|
||||
merged_search_result->topk_per_nq_prefix_sum_[qi + 1];
|
||||
if (merged_off_end > merged_off_begin) {
|
||||
auto merged_pk =
|
||||
merged_search_result->primary_keys_[merged_off_begin];
|
||||
auto merged_distance =
|
||||
merged_search_result->distances_[merged_off_begin];
|
||||
auto merged_result_pair = std::make_shared<StreamSearchResultPair>(
|
||||
merged_pk,
|
||||
merged_distance,
|
||||
nullptr,
|
||||
merged_search_result.get(),
|
||||
num_segments_, //use last index as the merged segment idex
|
||||
merged_off_begin,
|
||||
merged_off_end,
|
||||
merged_search_result->group_by_values_.has_value() &&
|
||||
merged_search_result->group_by_values_.value().size() >
|
||||
merged_off_begin
|
||||
? std::make_optional(
|
||||
merged_search_result->group_by_values_.value().at(
|
||||
merged_off_begin))
|
||||
: std::nullopt);
|
||||
heap_.push(merged_result_pair);
|
||||
}
|
||||
}
|
||||
|
||||
//3. pop heap to sort
|
||||
int count = 0;
|
||||
while (count < topK && !heap_.empty()) {
|
||||
auto pilot = heap_.top();
|
||||
heap_.pop();
|
||||
auto seg_index = pilot->segment_index_;
|
||||
auto pk = pilot->primary_key_;
|
||||
if (pk == INVALID_PK) {
|
||||
break; // valid search result for this nq has been run out, break to next
|
||||
}
|
||||
if (pk_set_.count(pk) == 0) {
|
||||
bool skip_for_group_by = false;
|
||||
if (pilot->group_by_value_.has_value()) {
|
||||
if (group_by_val_set_.count(pilot->group_by_value_.value()) >
|
||||
0) {
|
||||
skip_for_group_by = true;
|
||||
}
|
||||
}
|
||||
if (!skip_for_group_by) {
|
||||
final_search_records_[seg_index][qi].push_back(pilot->offset_);
|
||||
if (pilot->search_result_ != nullptr) {
|
||||
pilot->search_result_->result_offsets_.push_back(offset++);
|
||||
} else {
|
||||
merged_search_result->reduced_offsets_.push_back(offset++);
|
||||
}
|
||||
pk_set_.insert(pk);
|
||||
if (pilot->group_by_value_.has_value()) {
|
||||
group_by_val_set_.insert(pilot->group_by_value_.value());
|
||||
}
|
||||
count++;
|
||||
}
|
||||
}
|
||||
pilot->advance();
|
||||
if (pilot->primary_key_ != INVALID_PK) {
|
||||
heap_.push(pilot);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
StreamReducerHelper::RefreshSearchResult() {
|
||||
//1. refresh new input results
|
||||
for (int i = 0; i < num_segments_; i++) {
|
||||
std::vector<int64_t> real_topKs(total_nq_, 0);
|
||||
auto search_result = search_results_to_merge_[i];
|
||||
if (search_result->result_offsets_.size() > 0) {
|
||||
uint32_t final_size = 0;
|
||||
for (int j = 0; j < total_nq_; j++) {
|
||||
final_size += final_search_records_[i][j].size();
|
||||
}
|
||||
std::vector<milvus::PkType> reduced_pks(final_size);
|
||||
std::vector<float> reduced_distances(final_size);
|
||||
std::vector<int64_t> reduced_seg_offsets(final_size);
|
||||
std::vector<GroupByValueType> reduced_group_by_values(final_size);
|
||||
|
||||
uint32_t final_index = 0;
|
||||
for (int j = 0; j < total_nq_; j++) {
|
||||
for (auto offset : final_search_records_[i][j]) {
|
||||
reduced_pks[final_index] =
|
||||
search_result->primary_keys_[offset];
|
||||
reduced_distances[final_index] =
|
||||
search_result->distances_[offset];
|
||||
reduced_seg_offsets[final_index] =
|
||||
search_result->seg_offsets_[offset];
|
||||
if (search_result->group_by_values_.has_value())
|
||||
reduced_group_by_values[final_index] =
|
||||
search_result->group_by_values_.value()[offset];
|
||||
final_index++;
|
||||
real_topKs[j]++;
|
||||
}
|
||||
}
|
||||
search_result->primary_keys_.swap(reduced_pks);
|
||||
search_result->distances_.swap(reduced_distances);
|
||||
search_result->seg_offsets_.swap(reduced_seg_offsets);
|
||||
if (search_result->group_by_values_.has_value()) {
|
||||
search_result->group_by_values_.value().swap(
|
||||
reduced_group_by_values);
|
||||
}
|
||||
}
|
||||
std::partial_sum(real_topKs.begin(),
|
||||
real_topKs.end(),
|
||||
search_result->topk_per_nq_prefix_sum_.begin() + 1);
|
||||
}
|
||||
|
||||
//2. refresh merged search result possibly
|
||||
if (merged_search_result->has_result_) {
|
||||
std::vector<int64_t> real_topKs(total_nq_, 0);
|
||||
if (merged_search_result->reduced_offsets_.size() > 0) {
|
||||
uint32_t final_size = merged_search_result->reduced_offsets_.size();
|
||||
std::vector<milvus::PkType> reduced_pks(final_size);
|
||||
std::vector<float> reduced_distances(final_size);
|
||||
std::vector<int64_t> reduced_seg_offsets(final_size);
|
||||
std::vector<GroupByValueType> reduced_group_by_values(final_size);
|
||||
|
||||
uint32_t final_index = 0;
|
||||
for (int j = 0; j < total_nq_; j++) {
|
||||
for (auto offset : final_search_records_[num_segments_][j]) {
|
||||
reduced_pks[final_index] =
|
||||
merged_search_result->primary_keys_[offset];
|
||||
reduced_distances[final_index] =
|
||||
merged_search_result->distances_[offset];
|
||||
if (merged_search_result->group_by_values_.has_value())
|
||||
reduced_group_by_values[final_index] =
|
||||
merged_search_result->group_by_values_
|
||||
.value()[offset];
|
||||
final_index++;
|
||||
real_topKs[j]++;
|
||||
}
|
||||
}
|
||||
merged_search_result->primary_keys_.swap(reduced_pks);
|
||||
merged_search_result->distances_.swap(reduced_distances);
|
||||
if (merged_search_result->group_by_values_.has_value()) {
|
||||
merged_search_result->group_by_values_.value().swap(
|
||||
reduced_group_by_values);
|
||||
}
|
||||
}
|
||||
std::partial_sum(
|
||||
real_topKs.begin(),
|
||||
real_topKs.end(),
|
||||
merged_search_result->topk_per_nq_prefix_sum_.begin() + 1);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<char>
|
||||
StreamReducerHelper::GetSearchResultDataSlice(int slice_index) {
|
||||
auto nq_begin = slice_nqs_prefix_sum_[slice_index];
|
||||
auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];
|
||||
|
||||
auto search_result_data =
|
||||
std::make_unique<milvus::proto::schema::SearchResultData>();
|
||||
// set unify_topK and total_nq
|
||||
search_result_data->set_top_k(slice_topKs_[slice_index]);
|
||||
search_result_data->set_num_queries(nq_end - nq_begin);
|
||||
search_result_data->mutable_topks()->Resize(nq_end - nq_begin, 0);
|
||||
|
||||
int64_t result_count = 0;
|
||||
if (merged_search_result->has_result_) {
|
||||
AssertInfo(
|
||||
nq_begin < merged_search_result->topk_per_nq_prefix_sum_.size(),
|
||||
"nq_begin is incorrect for reduce, nq_begin:{}, topk_size:{}",
|
||||
nq_begin,
|
||||
merged_search_result->topk_per_nq_prefix_sum_.size());
|
||||
AssertInfo(
|
||||
nq_end < merged_search_result->topk_per_nq_prefix_sum_.size(),
|
||||
"nq_end is incorrect for reduce, nq_end:{}, topk_size:{}",
|
||||
nq_end,
|
||||
merged_search_result->topk_per_nq_prefix_sum_.size());
|
||||
|
||||
result_count = merged_search_result->topk_per_nq_prefix_sum_[nq_end] -
|
||||
merged_search_result->topk_per_nq_prefix_sum_[nq_begin];
|
||||
}
|
||||
|
||||
// `result_pairs` contains the SearchResult and result_offset info, used for filling output fields
|
||||
std::vector<MergeBase> result_pairs(result_count);
|
||||
|
||||
// reserve space for pks
|
||||
auto primary_field_id =
|
||||
plan_->schema_.get_primary_field_id().value_or(milvus::FieldId(-1));
|
||||
AssertInfo(primary_field_id.get() != INVALID_FIELD_ID, "Primary key is -1");
|
||||
auto pk_type = plan_->schema_[primary_field_id].get_data_type();
|
||||
switch (pk_type) {
|
||||
case milvus::DataType::INT64: {
|
||||
auto ids = std::make_unique<milvus::proto::schema::LongArray>();
|
||||
ids->mutable_data()->Resize(result_count, 0);
|
||||
search_result_data->mutable_ids()->set_allocated_int_id(
|
||||
ids.release());
|
||||
break;
|
||||
}
|
||||
case milvus::DataType::VARCHAR: {
|
||||
auto ids = std::make_unique<milvus::proto::schema::StringArray>();
|
||||
std::vector<std::string> string_pks(result_count);
|
||||
// TODO: prevent mem copy
|
||||
*ids->mutable_data() = {string_pks.begin(), string_pks.end()};
|
||||
search_result_data->mutable_ids()->set_allocated_str_id(
|
||||
ids.release());
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
PanicInfo(DataTypeInvalid,
|
||||
fmt::format("unsupported primary key type {}", pk_type));
|
||||
}
|
||||
}
|
||||
|
||||
// reserve space for distances
|
||||
search_result_data->mutable_scores()->Resize(result_count, 0);
|
||||
|
||||
//reserve space for group_by_values
|
||||
std::vector<GroupByValueType> group_by_values;
|
||||
if (plan_->plan_node_->search_info_.group_by_field_id_.has_value()) {
|
||||
group_by_values.resize(result_count);
|
||||
}
|
||||
|
||||
// fill pks and distances
|
||||
for (auto qi = nq_begin; qi < nq_end; qi++) {
|
||||
int64_t topk_count = 0;
|
||||
AssertInfo(merged_search_result != nullptr,
|
||||
"null merged search result when reorganize");
|
||||
if (!merged_search_result->has_result_ ||
|
||||
merged_search_result->result_offsets_.size() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto topk_start = merged_search_result->topk_per_nq_prefix_sum_[qi];
|
||||
auto topk_end = merged_search_result->topk_per_nq_prefix_sum_[qi + 1];
|
||||
topk_count += topk_end - topk_start;
|
||||
|
||||
for (auto ki = topk_start; ki < topk_end; ki++) {
|
||||
auto loc = merged_search_result->result_offsets_[ki];
|
||||
AssertInfo(loc < result_count && loc >= 0,
|
||||
"invalid loc when GetSearchResultDataSlice, loc = " +
|
||||
std::to_string(loc) +
|
||||
", result_count = " + std::to_string(result_count));
|
||||
// set result pks
|
||||
switch (pk_type) {
|
||||
case milvus::DataType::INT64: {
|
||||
search_result_data->mutable_ids()
|
||||
->mutable_int_id()
|
||||
->mutable_data()
|
||||
->Set(loc,
|
||||
std::visit(
|
||||
Int64PKVisitor{},
|
||||
merged_search_result->primary_keys_[ki]));
|
||||
break;
|
||||
}
|
||||
case milvus::DataType::VARCHAR: {
|
||||
*search_result_data->mutable_ids()
|
||||
->mutable_str_id()
|
||||
->mutable_data()
|
||||
->Mutable(loc) =
|
||||
std::visit(StrPKVisitor{},
|
||||
merged_search_result->primary_keys_[ki]);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
PanicInfo(DataTypeInvalid,
|
||||
fmt::format("unsupported primary key type {}",
|
||||
pk_type));
|
||||
}
|
||||
}
|
||||
|
||||
search_result_data->mutable_scores()->Set(
|
||||
loc, merged_search_result->distances_[ki]);
|
||||
// set group by values
|
||||
if (merged_search_result->group_by_values_.has_value() &&
|
||||
ki < merged_search_result->group_by_values_.value().size())
|
||||
group_by_values[loc] =
|
||||
merged_search_result->group_by_values_.value()[ki];
|
||||
// set result offset to fill output fields data
|
||||
result_pairs[loc] = {&merged_search_result->output_fields_data_,
|
||||
ki};
|
||||
}
|
||||
|
||||
// update result topKs
|
||||
search_result_data->mutable_topks()->Set(qi - nq_begin, topk_count);
|
||||
}
|
||||
AssembleGroupByValues(search_result_data, group_by_values, plan_);
|
||||
|
||||
AssertInfo(search_result_data->scores_size() == result_count,
|
||||
"wrong scores size, size = " +
|
||||
std::to_string(search_result_data->scores_size()) +
|
||||
", expected size = " + std::to_string(result_count));
|
||||
|
||||
// set output fields
|
||||
for (auto field_id : plan_->target_entries_) {
|
||||
auto& field_meta = plan_->schema_[field_id];
|
||||
auto field_data =
|
||||
milvus::segcore::MergeDataArray(result_pairs, field_meta);
|
||||
if (field_meta.get_data_type() == DataType::ARRAY) {
|
||||
field_data->mutable_scalars()
|
||||
->mutable_array_data()
|
||||
->set_element_type(
|
||||
proto::schema::DataType(field_meta.get_element_type()));
|
||||
}
|
||||
search_result_data->mutable_fields_data()->AddAllocated(
|
||||
field_data.release());
|
||||
}
|
||||
|
||||
// SearchResultData to blob
|
||||
auto size = search_result_data->ByteSizeLong();
|
||||
auto buffer = std::vector<char>(size);
|
||||
search_result_data->SerializePartialToArray(buffer.data(), size);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
void
|
||||
StreamReducerHelper::CleanReduceStatus() {
|
||||
this->final_search_records_.clear();
|
||||
this->merged_search_result->reduced_offsets_.clear();
|
||||
}
|
||||
} // namespace milvus::segcore
|
|
@ -0,0 +1,222 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <queue>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "common/Types.h"
|
||||
#include "segcore/segment_c.h"
|
||||
#include "query/PlanImpl.h"
|
||||
#include "common/QueryResult.h"
|
||||
#include "segcore/ReduceStructure.h"
|
||||
#include "common/EasyAssert.h"
|
||||
|
||||
namespace milvus::segcore {
|
||||
class MergedSearchResult {
|
||||
public:
|
||||
bool has_result_;
|
||||
std::vector<PkType> primary_keys_;
|
||||
std::vector<float> distances_;
|
||||
std::optional<std::vector<GroupByValueType>> group_by_values_;
|
||||
|
||||
// set output fields data when filling target entity
|
||||
std::map<FieldId, std::unique_ptr<milvus::DataArray>> output_fields_data_;
|
||||
|
||||
// used for reduce, filter invalid pk, get real topks count
|
||||
std::vector<size_t> topk_per_nq_prefix_sum_;
|
||||
// fill data during reducing search result
|
||||
std::vector<int64_t> result_offsets_;
|
||||
std::vector<int64_t> reduced_offsets_;
|
||||
};
|
||||
|
||||
struct StreamSearchResultPair {
|
||||
milvus::PkType primary_key_;
|
||||
float distance_;
|
||||
milvus::SearchResult* search_result_;
|
||||
MergedSearchResult* merged_result_;
|
||||
int64_t segment_index_;
|
||||
int64_t offset_;
|
||||
int64_t offset_rb_;
|
||||
std::optional<milvus::GroupByValueType> group_by_value_;
|
||||
|
||||
StreamSearchResultPair(milvus::PkType primary_key,
|
||||
float distance,
|
||||
SearchResult* result,
|
||||
int64_t index,
|
||||
int64_t lb,
|
||||
int64_t rb)
|
||||
: StreamSearchResultPair(primary_key,
|
||||
distance,
|
||||
result,
|
||||
nullptr,
|
||||
index,
|
||||
lb,
|
||||
rb,
|
||||
std::nullopt) {
|
||||
}
|
||||
|
||||
StreamSearchResultPair(
|
||||
milvus::PkType primary_key,
|
||||
float distance,
|
||||
SearchResult* result,
|
||||
MergedSearchResult* merged_result,
|
||||
int64_t index,
|
||||
int64_t lb,
|
||||
int64_t rb,
|
||||
std::optional<milvus::GroupByValueType> group_by_value)
|
||||
: primary_key_(std::move(primary_key)),
|
||||
distance_(distance),
|
||||
search_result_(result),
|
||||
merged_result_(merged_result),
|
||||
segment_index_(index),
|
||||
offset_(lb),
|
||||
offset_rb_(rb),
|
||||
group_by_value_(group_by_value) {
|
||||
AssertInfo(
|
||||
search_result_ != nullptr || merged_result_ != nullptr,
|
||||
"For a valid StreamSearchResult pair, "
|
||||
"at least one of merged_result_ or search_result_ is not nullptr");
|
||||
}
|
||||
|
||||
bool
|
||||
operator>(const StreamSearchResultPair& other) const {
|
||||
if (std::fabs(distance_ - other.distance_) < 0.0000000119) {
|
||||
return primary_key_ < other.primary_key_;
|
||||
}
|
||||
return distance_ > other.distance_;
|
||||
}
|
||||
|
||||
void
|
||||
advance() {
|
||||
offset_++;
|
||||
if (offset_ < offset_rb_) {
|
||||
if (search_result_ != nullptr) {
|
||||
primary_key_ = search_result_->primary_keys_.at(offset_);
|
||||
distance_ = search_result_->distances_.at(offset_);
|
||||
if (search_result_->group_by_values_.has_value() &&
|
||||
offset_ < search_result_->group_by_values_.value().size()) {
|
||||
group_by_value_ =
|
||||
search_result_->group_by_values_.value().at(offset_);
|
||||
}
|
||||
} else {
|
||||
primary_key_ = merged_result_->primary_keys_.at(offset_);
|
||||
distance_ = merged_result_->distances_.at(offset_);
|
||||
if (merged_result_->group_by_values_.has_value() &&
|
||||
offset_ < merged_result_->group_by_values_.value().size()) {
|
||||
group_by_value_ =
|
||||
merged_result_->group_by_values_.value().at(offset_);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
primary_key_ = INVALID_PK;
|
||||
distance_ = std::numeric_limits<float>::min();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct StreamSearchResultPairComparator {
|
||||
bool
|
||||
operator()(const std::shared_ptr<StreamSearchResultPair> lhs,
|
||||
const std::shared_ptr<StreamSearchResultPair> rhs) const {
|
||||
return (*rhs.get()) > (*lhs.get());
|
||||
}
|
||||
};
|
||||
|
||||
class StreamReducerHelper {
|
||||
public:
|
||||
explicit StreamReducerHelper(milvus::query::Plan* plan,
|
||||
int64_t* slice_nqs,
|
||||
int64_t* slice_topKs,
|
||||
int64_t slice_num)
|
||||
: plan_(plan),
|
||||
slice_nqs_(slice_nqs, slice_nqs + slice_num),
|
||||
slice_topKs_(slice_topKs, slice_topKs + slice_num) {
|
||||
AssertInfo(slice_nqs_.size() > 0, "empty_nqs");
|
||||
AssertInfo(slice_nqs_.size() == slice_topKs_.size(),
|
||||
"unaligned slice_nqs and slice_topKs");
|
||||
merged_search_result = std::make_unique<MergedSearchResult>();
|
||||
merged_search_result->has_result_ = false;
|
||||
num_slice_ = slice_nqs_.size();
|
||||
slice_nqs_prefix_sum_.resize(num_slice_ + 1);
|
||||
std::partial_sum(slice_nqs_.begin(),
|
||||
slice_nqs_.end(),
|
||||
slice_nqs_prefix_sum_.begin() + 1);
|
||||
total_nq_ = slice_nqs_prefix_sum_[num_slice_];
|
||||
}
|
||||
|
||||
void
|
||||
SetSearchResultsToMerge(std::vector<SearchResult*>& search_results) {
|
||||
search_results_to_merge_ = search_results;
|
||||
num_segments_ = search_results_to_merge_.size();
|
||||
AssertInfo(num_segments_ > 0, "empty search result");
|
||||
}
|
||||
|
||||
public:
|
||||
void
|
||||
MergeReduce();
|
||||
void*
|
||||
SerializeMergedResult();
|
||||
|
||||
protected:
|
||||
void
|
||||
FilterSearchResults();
|
||||
|
||||
void
|
||||
InitializeReduceRecords();
|
||||
|
||||
void
|
||||
FillPrimaryKeys();
|
||||
|
||||
void
|
||||
FilterInvalidSearchResult(SearchResult* search_result);
|
||||
|
||||
void
|
||||
ReduceResultData();
|
||||
|
||||
private:
|
||||
void
|
||||
RefreshSearchResult();
|
||||
|
||||
void
|
||||
StreamReduceSearchResultForOneNQ(int64_t qi, int64_t topK, int64_t& offset);
|
||||
|
||||
void
|
||||
FillEntryData();
|
||||
|
||||
void
|
||||
AssembleMergedResult();
|
||||
|
||||
std::vector<char>
|
||||
GetSearchResultDataSlice(int slice_index);
|
||||
|
||||
void
|
||||
CleanReduceStatus();
|
||||
|
||||
std::unique_ptr<MergedSearchResult> merged_search_result;
|
||||
milvus::query::Plan* plan_;
|
||||
std::vector<int64_t> slice_nqs_;
|
||||
std::vector<int64_t> slice_topKs_;
|
||||
std::vector<SearchResult*> search_results_to_merge_;
|
||||
int64_t num_segments_{0};
|
||||
int64_t num_slice_{0};
|
||||
std::vector<int64_t> slice_nqs_prefix_sum_;
|
||||
std::priority_queue<std::shared_ptr<StreamSearchResultPair>,
|
||||
std::vector<std::shared_ptr<StreamSearchResultPair>>,
|
||||
StreamSearchResultPairComparator>
|
||||
heap_;
|
||||
std::unordered_set<milvus::PkType> pk_set_;
|
||||
std::unordered_set<milvus::GroupByValueType> group_by_val_set_;
|
||||
std::vector<std::vector<std::vector<int64_t>>> final_search_records_;
|
||||
int64_t total_nq_{0};
|
||||
};
|
||||
} // namespace milvus::segcore
|
|
@ -530,19 +530,17 @@ CreateDataArrayFrom(const void* data_raw,
|
|||
|
||||
// TODO remove merge dataArray, instead fill target entity when get data slice
|
||||
std::unique_ptr<DataArray>
|
||||
MergeDataArray(
|
||||
std::vector<std::pair<milvus::SearchResult*, int64_t>>& result_offsets,
|
||||
const FieldMeta& field_meta) {
|
||||
MergeDataArray(std::vector<MergeBase>& merge_bases,
|
||||
const FieldMeta& field_meta) {
|
||||
auto data_type = field_meta.get_data_type();
|
||||
auto data_array = std::make_unique<DataArray>();
|
||||
data_array->set_field_id(field_meta.get_id().get());
|
||||
data_array->set_type(static_cast<milvus::proto::schema::DataType>(
|
||||
field_meta.get_data_type()));
|
||||
|
||||
for (auto& result_pair : result_offsets) {
|
||||
auto src_field_data =
|
||||
result_pair.first->output_fields_data_[field_meta.get_id()].get();
|
||||
auto src_offset = result_pair.second;
|
||||
for (auto& merge_base : merge_bases) {
|
||||
auto src_field_data = merge_base.get_field_data(field_meta.get_id());
|
||||
auto src_offset = merge_base.getOffset();
|
||||
AssertInfo(data_type == DataType(src_field_data->type()),
|
||||
"merge field data type not consistent");
|
||||
if (field_meta.is_vector()) {
|
||||
|
@ -650,7 +648,6 @@ MergeDataArray(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
return data_array;
|
||||
}
|
||||
|
||||
|
|
|
@ -77,10 +77,35 @@ CreateDataArrayFrom(const void* data_raw,
|
|||
const FieldMeta& field_meta);
|
||||
|
||||
// TODO remove merge dataArray, instead fill target entity when get data slice
|
||||
struct MergeBase {
|
||||
private:
|
||||
std::map<FieldId, std::unique_ptr<milvus::DataArray>>* output_fields_data_;
|
||||
size_t offset_;
|
||||
|
||||
public:
|
||||
MergeBase() {
|
||||
}
|
||||
|
||||
MergeBase(std::map<FieldId, std::unique_ptr<milvus::DataArray>>*
|
||||
output_fields_data,
|
||||
size_t offset)
|
||||
: output_fields_data_(output_fields_data), offset_(offset) {
|
||||
}
|
||||
|
||||
size_t
|
||||
getOffset() const {
|
||||
return offset_;
|
||||
}
|
||||
|
||||
milvus::DataArray*
|
||||
get_field_data(FieldId fieldId) const {
|
||||
return (*output_fields_data_)[fieldId].get();
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
MergeDataArray(
|
||||
std::vector<std::pair<milvus::SearchResult*, int64_t>>& result_offsets,
|
||||
const FieldMeta& field_meta);
|
||||
MergeDataArray(std::vector<MergeBase>& merge_bases,
|
||||
const FieldMeta& field_meta);
|
||||
|
||||
template <bool is_sealed>
|
||||
std::shared_ptr<DeletedRecord::TmpBitmap>
|
||||
|
|
|
@ -24,7 +24,7 @@ struct Int64PKVisitor {
|
|||
};
|
||||
|
||||
template <>
|
||||
int64_t
|
||||
inline int64_t
|
||||
Int64PKVisitor::operator()<int64_t>(int64_t t) const {
|
||||
return t;
|
||||
}
|
||||
|
@ -38,7 +38,7 @@ struct StrPKVisitor {
|
|||
};
|
||||
|
||||
template <>
|
||||
std::string
|
||||
inline std::string
|
||||
StrPKVisitor::operator()<std::string>(std::string t) const {
|
||||
return t;
|
||||
}
|
||||
|
|
|
@ -15,10 +15,64 @@
|
|||
#include "common/EasyAssert.h"
|
||||
#include "query/Plan.h"
|
||||
#include "segcore/reduce_c.h"
|
||||
#include "segcore/StreamReduce.h"
|
||||
#include "segcore/Utils.h"
|
||||
|
||||
using SearchResult = milvus::SearchResult;
|
||||
|
||||
// Creates a StreamReducerHelper for the given plan and slice layout and hands
// ownership of it to the caller through `stream_reducer`. Any std::exception
// raised on the way is converted into a failure CStatus.
CStatus
NewStreamReducer(CSearchPlan c_plan,
                 int64_t* slice_nqs,
                 int64_t* slice_topKs,
                 int64_t num_slices,
                 CSearchStreamReducer* stream_reducer) {
    try {
        using Helper = milvus::segcore::StreamReducerHelper;
        auto search_plan = static_cast<milvus::query::Plan*>(c_plan);
        std::unique_ptr<Helper> helper = std::make_unique<Helper>(
            search_plan, slice_nqs, slice_topKs, num_slices);
        // Ownership moves to the C caller from here on.
        *stream_reducer = helper.release();
        return milvus::SuccessCStatus();
    } catch (std::exception& e) {
        return milvus::FailureCStatus(&e);
    }
}
|
||||
|
||||
// Feeds `num_segments` segment search results into the stream reducer and
// runs one merge-reduce round. Any std::exception raised on the way is
// converted into a failure CStatus.
CStatus
StreamReduce(CSearchStreamReducer c_stream_reducer,
             CSearchResult* c_search_results,
             int64_t num_segments) {
    try {
        auto reducer = static_cast<milvus::segcore::StreamReducerHelper*>(
            c_stream_reducer);
        // Convert the opaque C handles back to typed result pointers.
        std::vector<SearchResult*> typed_results(num_segments);
        int idx = 0;
        while (idx < num_segments) {
            typed_results[idx] = static_cast<SearchResult*>(c_search_results[idx]);
            ++idx;
        }
        reducer->SetSearchResultsToMerge(typed_results);
        reducer->MergeReduce();
        return milvus::SuccessCStatus();
    } catch (std::exception& e) {
        return milvus::FailureCStatus(&e);
    }
}
|
||||
|
||||
// Serializes the reducer's merged result and returns the blobs through
// `c_search_result_data_blobs`. Any std::exception raised on the way is
// converted into a failure CStatus.
CStatus
GetStreamReduceResult(CSearchStreamReducer c_stream_reducer,
                      CSearchResultDataBlobs* c_search_result_data_blobs) {
    try {
        using Helper = milvus::segcore::StreamReducerHelper;
        auto reducer = static_cast<Helper*>(c_stream_reducer);
        *c_search_result_data_blobs = reducer->SerializeMergedResult();
        return milvus::SuccessCStatus();
    } catch (std::exception& e) {
        return milvus::FailureCStatus(&e);
    }
}
|
||||
|
||||
CStatus
|
||||
ReduceSearchResultsAndFillData(CSearchResultDataBlobs* cSearchResultDataBlobs,
|
||||
CSearchPlan c_plan,
|
||||
|
@ -81,3 +135,13 @@ DeleteSearchResultDataBlobs(CSearchResultDataBlobs cSearchResultDataBlobs) {
|
|||
cSearchResultDataBlobs);
|
||||
delete search_result_data_blobs;
|
||||
}
|
||||
|
||||
void
|
||||
DeleteStreamSearchReducer(CSearchStreamReducer c_stream_reducer) {
|
||||
if (c_stream_reducer == nullptr) {
|
||||
return;
|
||||
}
|
||||
auto stream_reducer =
|
||||
static_cast<milvus::segcore::StreamReducerHelper*>(c_stream_reducer);
|
||||
delete stream_reducer;
|
||||
}
|
||||
|
|
|
@ -18,6 +18,23 @@ extern "C" {
|
|||
#include "segcore/segment_c.h"
|
||||
|
||||
typedef void* CSearchResultDataBlobs;
|
||||
typedef void* CSearchStreamReducer;
|
||||
|
||||
// Creates a stream reducer for c_plan and stores its opaque handle in
// *stream_reducer. slice_nqs, slice_topKs and num_slices are forwarded to the
// reducer (presumably two arrays of num_slices entries each -- confirm against
// the reducer implementation). Returns a failure status if creation throws.
CStatus
|
||||
NewStreamReducer(CSearchPlan c_plan,
|
||||
int64_t* slice_nqs,
|
||||
|
||||
int64_t* slice_topKs,
|
||||
int64_t num_slices,
|
||||
CSearchStreamReducer* stream_reducer);
|
||||
|
||||
// Merges num_segments search results, read from the c_search_results array,
// into the given stream reducer. Returns a failure status if merging throws.
CStatus
|
||||
StreamReduce(CSearchStreamReducer c_stream_reducer,
|
||||
CSearchResult* c_search_results,
|
||||
int64_t num_segments);
|
||||
|
||||
// Serializes the reducer's merged result and stores it in
// *c_search_result_data_blobs. Returns a failure status if serialization
// throws.
CStatus
|
||||
GetStreamReduceResult(CSearchStreamReducer c_stream_reducer,
|
||||
CSearchResultDataBlobs* c_search_result_data_blobs);
|
||||
|
||||
CStatus
|
||||
ReduceSearchResultsAndFillData(CSearchResultDataBlobs* cSearchResultDataBlobs,
|
||||
|
@ -36,6 +53,9 @@ GetSearchResultDataBlob(CProto* searchResultDataBlob,
|
|||
void
|
||||
DeleteSearchResultDataBlobs(CSearchResultDataBlobs cSearchResultDataBlobs);
|
||||
|
||||
// Destroys a stream reducer handle. Passing a null handle is a no-op.
void
|
||||
DeleteStreamSearchReducer(CSearchStreamReducer c_stream_reducer);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -476,3 +476,10 @@ WarmupChunkCache(CSegmentInterface c_segment, int64_t field_id) {
|
|||
return milvus::FailureCStatus(milvus::UnexpectedError, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
RemoveFieldFile(CSegmentInterface c_segment, int64_t field_id) {
|
||||
auto segment =
|
||||
reinterpret_cast<milvus::segcore::SegmentSealedImpl*>(c_segment);
|
||||
segment->RemoveFieldFile(milvus::FieldId(field_id));
|
||||
}
|
||||
|
|
|
@ -156,6 +156,9 @@ Delete(CSegmentInterface c_segment,
|
|||
const uint64_t ids_size,
|
||||
const uint64_t* timestamps);
|
||||
|
||||
// Removes the file of field `field_id` from the segment behind c_segment.
// NOTE(review): presumably only valid for sealed segments -- confirm.
void
|
||||
RemoveFieldFile(CSegmentInterface c_segment, int64_t field_id);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -103,17 +103,28 @@ DeserializeFileData(const std::shared_ptr<uint8_t[]> input_data,
|
|||
int64_t length) {
|
||||
auto binlog_reader = std::make_shared<BinlogReader>(input_data, length);
|
||||
auto medium_type = ReadMediumType(binlog_reader);
|
||||
auto start_deserialize = std::chrono::system_clock::now();
|
||||
std::unique_ptr<DataCodec> res;
|
||||
switch (medium_type) {
|
||||
case StorageType::Remote: {
|
||||
return DeserializeRemoteFileData(binlog_reader);
|
||||
res = DeserializeRemoteFileData(binlog_reader);
|
||||
break;
|
||||
}
|
||||
case StorageType::LocalDisk: {
|
||||
return DeserializeLocalFileData(binlog_reader);
|
||||
res = DeserializeLocalFileData(binlog_reader);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
PanicInfo(DataFormatBroken,
|
||||
fmt::format("unsupported medium type {}", medium_type));
|
||||
}
|
||||
auto deserialize_duration =
|
||||
std::chrono::system_clock::now() - start_deserialize;
|
||||
LOG_INFO("DeserializeFileData_deserialize_duration_ms:{}",
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
deserialize_duration)
|
||||
.count());
|
||||
return res;
|
||||
}
|
||||
|
||||
} // namespace milvus::storage
|
||||
|
|
|
@ -148,6 +148,24 @@ DEFINE_PROMETHEUS_COUNTER(internal_storage_op_count_remove_fail,
|
|||
internal_storage_op_count,
|
||||
removeFailMap)
|
||||
|
||||
//load metrics
|
||||
std::map<std::string, std::string> downloadDurationLabels{{"type", "download"}};
|
||||
std::map<std::string, std::string> writeDiskDurationLabels{
|
||||
{"type", "write_disk"}};
|
||||
std::map<std::string, std::string> deserializeDurationLabels{
|
||||
{"type", "deserialize"}};
|
||||
DEFINE_PROMETHEUS_HISTOGRAM_FAMILY(internal_storage_load_duration,
|
||||
"[cpp]durations of load segment")
|
||||
DEFINE_PROMETHEUS_HISTOGRAM(internal_storage_download_duration,
|
||||
internal_storage_load_duration,
|
||||
downloadDurationLabels)
|
||||
DEFINE_PROMETHEUS_HISTOGRAM(internal_storage_write_disk_duration,
|
||||
internal_storage_load_duration,
|
||||
writeDiskDurationLabels)
|
||||
DEFINE_PROMETHEUS_HISTOGRAM(internal_storage_deserialize_duration,
|
||||
internal_storage_load_duration,
|
||||
deserializeDurationLabels)
|
||||
|
||||
// mmap metrics
|
||||
std::map<std::string, std::string> mmapAllocatedSpaceAnonLabel = {
|
||||
{"type", "anon"}};
|
||||
|
|
|
@ -115,6 +115,11 @@ DECLARE_PROMETHEUS_COUNTER(internal_storage_op_count_list_fail);
|
|||
DECLARE_PROMETHEUS_COUNTER(internal_storage_op_count_remove_suc);
|
||||
DECLARE_PROMETHEUS_COUNTER(internal_storage_op_count_remove_fail);
|
||||
|
||||
DECLARE_PROMETHEUS_HISTOGRAM_FAMILY(internal_storage_load_duration);
|
||||
DECLARE_PROMETHEUS_HISTOGRAM(internal_storage_download_duration);
|
||||
DECLARE_PROMETHEUS_HISTOGRAM(internal_storage_write_disk_duration);
|
||||
DECLARE_PROMETHEUS_HISTOGRAM(internal_storage_deserialize_duration);
|
||||
|
||||
// mmap metrics
|
||||
DECLARE_PROMETHEUS_HISTOGRAM_FAMILY(internal_mmap_allocated_space_bytes);
|
||||
DECLARE_PROMETHEUS_HISTOGRAM(internal_mmap_allocated_space_bytes_anon);
|
||||
|
@ -122,4 +127,5 @@ DECLARE_PROMETHEUS_HISTOGRAM(internal_mmap_allocated_space_bytes_file);
|
|||
DECLARE_PROMETHEUS_GAUGE_FAMILY(internal_mmap_in_used_space_bytes);
|
||||
DECLARE_PROMETHEUS_GAUGE(internal_mmap_in_used_space_bytes_anon);
|
||||
DECLARE_PROMETHEUS_GAUGE(internal_mmap_in_used_space_bytes_file);
|
||||
|
||||
} // namespace milvus::storage
|
||||
|
|
|
@ -26,6 +26,7 @@ set(MILVUS_TEST_FILES
|
|||
test_concurrent_vector.cpp
|
||||
test_c_api.cpp
|
||||
test_expr_materialized_view.cpp
|
||||
test_c_stream_reduce.cpp
|
||||
test_expr.cpp
|
||||
test_float16.cpp
|
||||
test_growing.cpp
|
||||
|
|
|
@ -43,6 +43,7 @@
|
|||
#include "plan/PlanNode.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "segcore/load_index_c.h"
|
||||
#include "test_utils/c_api_test_utils.h"
|
||||
|
||||
namespace chrono = std::chrono;
|
||||
|
||||
|
@ -59,16 +60,6 @@ namespace {
|
|||
const int64_t ROW_COUNT = 10 * 1000;
|
||||
const int64_t BIAS = 4200;
|
||||
|
||||
CStatus
|
||||
CSearch(CSegmentInterface c_segment,
|
||||
CSearchPlan c_plan,
|
||||
CPlaceholderGroup c_placeholder_group,
|
||||
uint64_t timestamp,
|
||||
CSearchResult* result) {
|
||||
return Search(
|
||||
{}, c_segment, c_plan, c_placeholder_group, timestamp, result);
|
||||
}
|
||||
|
||||
CStatus
|
||||
CRetrieve(CSegmentInterface c_segment,
|
||||
CRetrievePlan c_plan,
|
||||
|
@ -83,32 +74,6 @@ CRetrieve(CSegmentInterface c_segment,
|
|||
false);
|
||||
}
|
||||
|
||||
const char*
|
||||
get_default_schema_config() {
|
||||
static std::string conf = R"(name: "default-collection"
|
||||
fields: <
|
||||
fieldID: 100
|
||||
name: "fakevec"
|
||||
data_type: FloatVector
|
||||
type_params: <
|
||||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
fieldID: 101
|
||||
name: "age"
|
||||
data_type: Int64
|
||||
is_primary_key: true
|
||||
>)";
|
||||
static std::string fake_conf = "";
|
||||
return conf.c_str();
|
||||
}
|
||||
|
||||
const char*
|
||||
get_float16_schema_config() {
|
||||
static std::string conf = R"(name: "float16-collection"
|
||||
|
@ -212,52 +177,6 @@ generate_data(int N) {
|
|||
}
|
||||
return std::make_tuple(raw_data, timestamps, uids);
|
||||
}
|
||||
std::string
|
||||
generate_max_float_query_data(int all_nq, int max_float_nq) {
|
||||
assert(max_float_nq <= all_nq);
|
||||
namespace ser = milvus::proto::common;
|
||||
int dim = DIM;
|
||||
ser::PlaceholderGroup raw_group;
|
||||
auto value = raw_group.add_placeholders();
|
||||
value->set_tag("$0");
|
||||
value->set_type(ser::PlaceholderType::FloatVector);
|
||||
for (int i = 0; i < all_nq; ++i) {
|
||||
std::vector<float> vec;
|
||||
if (i < max_float_nq) {
|
||||
for (int d = 0; d < dim; ++d) {
|
||||
vec.push_back(std::numeric_limits<float>::max());
|
||||
}
|
||||
} else {
|
||||
for (int d = 0; d < dim; ++d) {
|
||||
vec.push_back(1);
|
||||
}
|
||||
}
|
||||
value->add_values(vec.data(), vec.size() * sizeof(float));
|
||||
}
|
||||
auto blob = raw_group.SerializeAsString();
|
||||
return blob;
|
||||
}
|
||||
|
||||
std::string
|
||||
generate_query_data(int nq) {
|
||||
namespace ser = milvus::proto::common;
|
||||
std::default_random_engine e(67);
|
||||
int dim = DIM;
|
||||
std::normal_distribution<double> dis(0.0, 1.0);
|
||||
ser::PlaceholderGroup raw_group;
|
||||
auto value = raw_group.add_placeholders();
|
||||
value->set_tag("$0");
|
||||
value->set_type(ser::PlaceholderType::FloatVector);
|
||||
for (int i = 0; i < nq; ++i) {
|
||||
std::vector<float> vec;
|
||||
for (int d = 0; d < dim; ++d) {
|
||||
vec.push_back(dis(e));
|
||||
}
|
||||
value->add_values(vec.data(), vec.size() * sizeof(float));
|
||||
}
|
||||
auto blob = raw_group.SerializeAsString();
|
||||
return blob;
|
||||
}
|
||||
|
||||
std::string
|
||||
generate_query_data_float16(int nq) {
|
||||
|
@ -300,7 +219,7 @@ generate_query_data_bfloat16(int nq) {
|
|||
auto blob = raw_group.SerializeAsString();
|
||||
return blob;
|
||||
}
|
||||
// 创建枚举,包含schema::DataType::BinaryVector,schema::DataType::FloatVector
|
||||
// create Enum for schema::DataType::BinaryVector,schema::DataType::FloatVector
|
||||
enum VectorType {
|
||||
BinaryVector = 0,
|
||||
FloatVector = 1,
|
||||
|
@ -1558,27 +1477,6 @@ TEST(CApiTest, GetRealCount) {
|
|||
DeleteSegment(segment);
|
||||
}
|
||||
|
||||
void
|
||||
CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) {
|
||||
auto nq = ((SearchResult*)results[0])->total_nq_;
|
||||
|
||||
std::unordered_set<PkType> pk_set;
|
||||
for (int qi = 0; qi < nq; qi++) {
|
||||
pk_set.clear();
|
||||
for (size_t i = 0; i < results.size(); i++) {
|
||||
auto search_result = (SearchResult*)results[i];
|
||||
ASSERT_EQ(nq, search_result->total_nq_);
|
||||
auto topk_beg = search_result->topk_per_nq_prefix_sum_[qi];
|
||||
auto topk_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
|
||||
for (size_t ki = topk_beg; ki < topk_end; ki++) {
|
||||
ASSERT_NE(search_result->seg_offsets_[ki], INVALID_SEG_OFFSET);
|
||||
auto ret = pk_set.insert(search_result->primary_keys_[ki]);
|
||||
ASSERT_TRUE(ret.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CApiTest, ReduceNullResult) {
|
||||
auto collection = NewCollection(get_default_schema_config());
|
||||
CSegmentInterface segment;
|
||||
|
|
|
@ -0,0 +1,324 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "test_utils/DataGen.h"
|
||||
#include "test_utils/c_api_test_utils.h"
|
||||
|
||||
// Exercises the streaming reduce C API end to end: two growing segments are
// filled with generated data, each is searched with the same plan, the two
// search results are fed one at a time into a stream reducer, and the
// serialized per-slice output is checked for consistent sizes.
TEST(CApiTest, StreamReduce) {
    int N = 300;
    int topK = 100;
    int num_queries = 2;
    auto collection = NewCollection(get_default_schema_config());

    //1. set up segments
    CSegmentInterface segment_1;
    auto status = NewSegment(collection, Growing, -1, &segment_1);
    ASSERT_EQ(status.error_code, Success);
    CSegmentInterface segment_2;
    status = NewSegment(collection, Growing, -1, &segment_2);
    ASSERT_EQ(status.error_code, Success);

    //2. insert data into segments
    auto schema =
        static_cast<milvus::segcore::Collection*>(collection)->get_schema();
    auto dataset_1 = DataGen(schema, N, 55, 0, 1, 10, true);
    int64_t offset_1;
    PreInsert(segment_1, N, &offset_1);
    auto insert_data_1 = serialize(dataset_1.raw_);
    auto ins_res_1 = Insert(segment_1,
                            offset_1,
                            N,
                            dataset_1.row_ids_.data(),
                            dataset_1.timestamps_.data(),
                            insert_data_1.data(),
                            insert_data_1.size());
    ASSERT_EQ(ins_res_1.error_code, Success);

    auto dataset_2 = DataGen(schema, N, 66, 0, 1, 10, true);
    int64_t offset_2;
    PreInsert(segment_2, N, &offset_2);
    auto insert_data_2 = serialize(dataset_2.raw_);
    auto ins_res_2 = Insert(segment_2,
                            offset_2,
                            N,
                            dataset_2.row_ids_.data(),
                            dataset_2.timestamps_.data(),
                            insert_data_2.data(),
                            insert_data_2.size());
    ASSERT_EQ(ins_res_2.error_code, Success);

    //3. search two segments
    auto fmt = boost::format(R"(vector_anns: <
field_id: 100
query_info: <
topk: %1%
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0">
output_field_ids: 100)") %
               topK;
    auto serialized_expr_plan = fmt.str();
    auto blob = generate_query_data(num_queries);
    void* plan = nullptr;
    auto binary_plan =
        translate_text_plan_to_binary_plan(serialized_expr_plan.data());
    status = CreateSearchPlanByExpr(
        collection, binary_plan.data(), binary_plan.size(), &plan);
    ASSERT_EQ(status.error_code, Success);
    void* placeholderGroup = nullptr;
    status = ParsePlaceholderGroup(
        plan, blob.data(), blob.length(), &placeholderGroup);
    ASSERT_EQ(status.error_code, Success);

    // Search at the last inserted timestamp of each dataset. The timestamp
    // vectors are left untouched: shrinking them to a single element and then
    // indexing [N - 1] would read past the end of the vector.
    const auto search_ts_1 = dataset_1.timestamps_[N - 1];
    const auto search_ts_2 = dataset_2.timestamps_[N - 1];
    CSearchResult res1;
    CSearchResult res2;
    auto stats1 = CSearch(segment_1, plan, placeholderGroup, search_ts_1, &res1);
    ASSERT_EQ(stats1.error_code, Success);
    auto stats2 = CSearch(segment_2, plan, placeholderGroup, search_ts_2, &res2);
    ASSERT_EQ(stats2.error_code, Success);

    //4. stream reduce two search results
    auto slice_nqs = std::vector<int64_t>{num_queries / 2, num_queries / 2};
    if (num_queries == 1) {
        slice_nqs = std::vector<int64_t>{num_queries};
    }
    auto slice_topKs = std::vector<int64_t>{topK, topK};

    //5. set up stream reducer
    CSearchStreamReducer c_search_stream_reducer = nullptr;
    status = NewStreamReducer(plan,
                              slice_nqs.data(),
                              slice_topKs.data(),
                              slice_nqs.size(),
                              &c_search_stream_reducer);
    ASSERT_EQ(status.error_code, Success);
    status = StreamReduce(c_search_stream_reducer, &res1, 1);
    ASSERT_EQ(status.error_code, Success);
    status = StreamReduce(c_search_stream_reducer, &res2, 1);
    ASSERT_EQ(status.error_code, Success);
    CSearchResultDataBlobs c_search_result_data_blobs = nullptr;
    status = GetStreamReduceResult(c_search_stream_reducer,
                                   &c_search_result_data_blobs);
    ASSERT_EQ(status.error_code, Success);
    auto* search_result_data_blob =
        static_cast<SearchResultDataBlobs*>(c_search_result_data_blobs);

    //6. check
    for (size_t i = 0; i < slice_nqs.size(); i++) {
        milvus::proto::schema::SearchResultData search_result_data;
        auto suc = search_result_data.ParseFromArray(
            search_result_data_blob->blobs[i].data(),
            search_result_data_blob->blobs[i].size());
        ASSERT_TRUE(suc);
        ASSERT_EQ(search_result_data.num_queries(), slice_nqs[i]);
        ASSERT_EQ(search_result_data.top_k(), slice_topKs[i]);
        ASSERT_EQ(search_result_data.ids().int_id().data_size(),
                  search_result_data.topks().at(0) * slice_nqs[i]);
        ASSERT_EQ(search_result_data.scores().size(),
                  search_result_data.topks().at(0) * slice_nqs[i]);

        // every query reports its own real topk, never above the requested one
        ASSERT_EQ(search_result_data.topks().size(), slice_nqs[i]);
        for (auto real_topk : search_result_data.topks()) {
            ASSERT_LE(real_topk, slice_topKs[i]);
        }
    }

    DeleteSearchResultDataBlobs(c_search_result_data_blobs);
    DeleteSearchPlan(plan);
    DeletePlaceholderGroup(placeholderGroup);
    DeleteSearchResult(res1);
    DeleteSearchResult(res2);
    DeleteCollection(collection);
    DeleteSegment(segment_1);
    DeleteSegment(segment_2);
    DeleteStreamSearchReducer(c_search_stream_reducer);
    // deleting a null reducer must be a harmless no-op
    DeleteStreamSearchReducer(nullptr);
}
|
||||
|
||||
// Exercises the streaming reduce C API with a group-by search: one growing
// segment is searched twice with a plan that groups by the int8 field, both
// results are streamed into a reducer, and the serialized per-slice output is
// checked for consistent sizes and for the presence of group-by values.
TEST(CApiTest, StreamReduceGroupBY) {
    int N = 300;
    int topK = 100;
    int num_queries = 2;
    int dim = 16;
    namespace schema = milvus::proto::schema;

    void* c_collection;
    //1. set up schema and collection
    {
        schema::CollectionSchema collection_schema;
        auto pk_field_schema = collection_schema.add_fields();
        pk_field_schema->set_name("pk_field");
        pk_field_schema->set_fieldid(100);
        pk_field_schema->set_data_type(schema::DataType::Int64);
        pk_field_schema->set_is_primary_key(true);

        auto i8_field_schema = collection_schema.add_fields();
        i8_field_schema->set_name("int8_field");
        i8_field_schema->set_fieldid(101);
        i8_field_schema->set_data_type(schema::DataType::Int8);
        i8_field_schema->set_is_primary_key(false);

        auto i16_field_schema = collection_schema.add_fields();
        i16_field_schema->set_name("int16_field");
        i16_field_schema->set_fieldid(102);
        i16_field_schema->set_data_type(schema::DataType::Int16);
        i16_field_schema->set_is_primary_key(false);

        auto i32_field_schema = collection_schema.add_fields();
        i32_field_schema->set_name("int32_field");
        i32_field_schema->set_fieldid(103);
        i32_field_schema->set_data_type(schema::DataType::Int32);
        i32_field_schema->set_is_primary_key(false);

        auto str_field_schema = collection_schema.add_fields();
        str_field_schema->set_name("str_field");
        str_field_schema->set_fieldid(104);
        str_field_schema->set_data_type(schema::DataType::VarChar);
        auto str_type_params = str_field_schema->add_type_params();
        str_type_params->set_key(MAX_LENGTH);
        str_type_params->set_value(std::to_string(64));
        str_field_schema->set_is_primary_key(false);

        auto vec_field_schema = collection_schema.add_fields();
        vec_field_schema->set_name("fake_vec");
        vec_field_schema->set_fieldid(105);
        vec_field_schema->set_data_type(schema::DataType::FloatVector);
        auto metric_type_param = vec_field_schema->add_index_params();
        metric_type_param->set_key("metric_type");
        metric_type_param->set_value(knowhere::metric::L2);
        auto dim_param = vec_field_schema->add_type_params();
        dim_param->set_key("dim");
        dim_param->set_value(std::to_string(dim));
        c_collection = NewCollection(&collection_schema, knowhere::metric::L2);
    }

    CSegmentInterface segment;
    auto status = NewSegment(c_collection, Growing, -1, &segment);
    ASSERT_EQ(status.error_code, Success);

    //2. generate data and insert
    auto c_schema =
        static_cast<milvus::segcore::Collection*>(c_collection)->get_schema();
    auto dataset = DataGen(c_schema, N);
    int64_t offset;
    PreInsert(segment, N, &offset);
    auto insert_data = serialize(dataset.raw_);
    auto ins_res = Insert(segment,
                          offset,
                          N,
                          dataset.row_ids_.data(),
                          dataset.timestamps_.data(),
                          insert_data.data(),
                          insert_data.size());
    ASSERT_EQ(ins_res.error_code, Success);

    //3. search
    auto fmt = boost::format(R"(vector_anns: <
field_id: 105
query_info: <
topk: %1%
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
group_by_field_id: 101
>
placeholder_tag: "$0">
output_field_ids: 100)") %
               topK;
    auto serialized_expr_plan = fmt.str();
    auto blob = generate_query_data(num_queries);
    void* plan = nullptr;
    auto binary_plan =
        translate_text_plan_to_binary_plan(serialized_expr_plan.data());
    status = CreateSearchPlanByExpr(
        c_collection, binary_plan.data(), binary_plan.size(), &plan);
    ASSERT_EQ(status.error_code, Success);

    void* placeholderGroup = nullptr;
    status = ParsePlaceholderGroup(
        plan, blob.data(), blob.length(), &placeholderGroup);
    ASSERT_EQ(status.error_code, Success);

    // Search at the last inserted timestamp. The timestamp vector is left
    // untouched: shrinking it to a single element and then indexing [N - 1]
    // would read past the end of the vector.
    const auto search_ts = dataset.timestamps_[N - 1];
    CSearchResult res1;
    CSearchResult res2;
    auto res = CSearch(segment, plan, placeholderGroup, search_ts, &res1);
    ASSERT_EQ(res.error_code, Success);
    res = CSearch(segment, plan, placeholderGroup, search_ts, &res2);
    ASSERT_EQ(res.error_code, Success);

    //4. set up stream reducer
    auto slice_nqs = std::vector<int64_t>{num_queries / 2, num_queries / 2};
    if (num_queries == 1) {
        slice_nqs = std::vector<int64_t>{num_queries};
    }
    auto slice_topKs = std::vector<int64_t>{topK, topK};
    CSearchStreamReducer c_search_stream_reducer = nullptr;
    status = NewStreamReducer(plan,
                              slice_nqs.data(),
                              slice_topKs.data(),
                              slice_nqs.size(),
                              &c_search_stream_reducer);
    ASSERT_EQ(status.error_code, Success);

    //5. stream reduce
    status = StreamReduce(c_search_stream_reducer, &res1, 1);
    ASSERT_EQ(status.error_code, Success);
    status = StreamReduce(c_search_stream_reducer, &res2, 1);
    ASSERT_EQ(status.error_code, Success);
    CSearchResultDataBlobs c_search_result_data_blobs = nullptr;
    status = GetStreamReduceResult(c_search_stream_reducer,
                                   &c_search_result_data_blobs);
    ASSERT_EQ(status.error_code, Success);
    auto* search_result_data_blob =
        static_cast<SearchResultDataBlobs*>(c_search_result_data_blobs);

    //6. check result
    for (size_t i = 0; i < slice_nqs.size(); i++) {
        milvus::proto::schema::SearchResultData search_result_data;
        auto suc = search_result_data.ParseFromArray(
            search_result_data_blob->blobs[i].data(),
            search_result_data_blob->blobs[i].size());
        ASSERT_TRUE(suc);
        ASSERT_EQ(search_result_data.num_queries(), slice_nqs[i]);
        ASSERT_EQ(search_result_data.top_k(), slice_topKs[i]);
        ASSERT_EQ(search_result_data.ids().int_id().data_size(),
                  search_result_data.topks().at(0) * slice_nqs[i]);
        ASSERT_EQ(search_result_data.scores().size(),
                  search_result_data.topks().at(0) * slice_nqs[i]);
        ASSERT_TRUE(search_result_data.has_group_by_field_value());

        // check real topks
        ASSERT_EQ(search_result_data.topks().size(), slice_nqs[i]);
        for (auto real_topk : search_result_data.topks()) {
            ASSERT_LE(real_topk, slice_topKs[i]);
        }
    }

    DeleteSearchResultDataBlobs(c_search_result_data_blobs);
    DeleteSearchPlan(plan);
    DeletePlaceholderGroup(placeholderGroup);
    DeleteSearchResult(res1);
    DeleteSearchResult(res2);
    DeleteCollection(c_collection);
    DeleteSegment(segment);
    DeleteStreamSearchReducer(c_search_stream_reducer);
    // deleting a null reducer must be a harmless no-op
    DeleteStreamSearchReducer(nullptr);
}
|
|
@ -234,7 +234,8 @@ struct GeneratedData {
|
|||
uint64_t seed,
|
||||
uint64_t ts_offset,
|
||||
int repeat_count,
|
||||
int array_len);
|
||||
int array_len,
|
||||
bool random_pk);
|
||||
friend GeneratedData
|
||||
DataGenForJsonArray(SchemaPtr schema,
|
||||
int64_t N,
|
||||
|
@ -292,9 +293,10 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
uint64_t seed = 42,
|
||||
uint64_t ts_offset = 0,
|
||||
int repeat_count = 1,
|
||||
int array_len = 10) {
|
||||
int array_len = 10,
|
||||
bool random_pk = false) {
|
||||
using std::vector;
|
||||
std::default_random_engine er(seed);
|
||||
std::default_random_engine random(seed);
|
||||
std::normal_distribution<> distr(0, 1);
|
||||
int offset = 0;
|
||||
|
||||
|
@ -343,7 +345,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
Assert(dim % 8 == 0);
|
||||
vector<uint8_t> data(dim / 8 * N);
|
||||
for (auto& x : data) {
|
||||
x = er();
|
||||
x = random();
|
||||
}
|
||||
insert_cols(data, N, field_meta);
|
||||
break;
|
||||
|
@ -352,7 +354,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
auto dim = field_meta.get_dim();
|
||||
vector<float16> final(dim * N);
|
||||
for (auto& x : final) {
|
||||
x = float16(distr(er) + offset);
|
||||
x = float16(distr(random) + offset);
|
||||
}
|
||||
insert_cols(final, N, field_meta);
|
||||
break;
|
||||
|
@ -371,7 +373,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
auto dim = field_meta.get_dim();
|
||||
vector<bfloat16> final(dim * N);
|
||||
for (auto& x : final) {
|
||||
x = bfloat16(distr(er) + offset);
|
||||
x = bfloat16(distr(random) + offset);
|
||||
}
|
||||
insert_cols(final, N, field_meta);
|
||||
break;
|
||||
|
@ -387,7 +389,12 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
case DataType::INT64: {
|
||||
vector<int64_t> data(N);
|
||||
for (int i = 0; i < N; i++) {
|
||||
data[i] = i / repeat_count;
|
||||
if (random_pk && schema->get_primary_field_id()->get() ==
|
||||
field_id.get()) {
|
||||
data[i] = random();
|
||||
} else {
|
||||
data[i] = i / repeat_count;
|
||||
}
|
||||
}
|
||||
insert_cols(data, N, field_meta);
|
||||
break;
|
||||
|
@ -395,7 +402,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
case DataType::INT32: {
|
||||
vector<int> data(N);
|
||||
for (auto& x : data) {
|
||||
x = er() % (2 * N);
|
||||
x = random() % (2 * N);
|
||||
}
|
||||
insert_cols(data, N, field_meta);
|
||||
break;
|
||||
|
@ -403,7 +410,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
case DataType::INT16: {
|
||||
vector<int16_t> data(N);
|
||||
for (auto& x : data) {
|
||||
x = er() % (2 * N);
|
||||
x = random() % (2 * N);
|
||||
}
|
||||
insert_cols(data, N, field_meta);
|
||||
break;
|
||||
|
@ -411,7 +418,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
case DataType::INT8: {
|
||||
vector<int8_t> data(N);
|
||||
for (auto& x : data) {
|
||||
x = er() % (2 * N);
|
||||
x = random() % (2 * N);
|
||||
}
|
||||
insert_cols(data, N, field_meta);
|
||||
break;
|
||||
|
@ -419,7 +426,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
case DataType::FLOAT: {
|
||||
vector<float> data(N);
|
||||
for (auto& x : data) {
|
||||
x = distr(er);
|
||||
x = distr(random);
|
||||
}
|
||||
insert_cols(data, N, field_meta);
|
||||
break;
|
||||
|
@ -427,7 +434,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
case DataType::DOUBLE: {
|
||||
vector<double> data(N);
|
||||
for (auto& x : data) {
|
||||
x = distr(er);
|
||||
x = distr(random);
|
||||
}
|
||||
insert_cols(data, N, field_meta);
|
||||
break;
|
||||
|
@ -435,7 +442,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
case DataType::VARCHAR: {
|
||||
vector<std::string> data(N);
|
||||
for (int i = 0; i < N / repeat_count; i++) {
|
||||
auto str = std::to_string(er());
|
||||
auto str = std::to_string(random());
|
||||
for (int j = 0; j < repeat_count; j++) {
|
||||
data[i * repeat_count + j] = str;
|
||||
}
|
||||
|
@ -446,11 +453,12 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
case DataType::JSON: {
|
||||
vector<std::string> data(N);
|
||||
for (int i = 0; i < N / repeat_count; i++) {
|
||||
auto str =
|
||||
R"({"int":)" + std::to_string(er()) + R"(,"double":)" +
|
||||
std::to_string(static_cast<double>(er())) +
|
||||
R"(,"string":")" + std::to_string(er()) +
|
||||
R"(","bool": true)" + R"(, "array": [1,2,3])" + "}";
|
||||
auto str = R"({"int":)" + std::to_string(random()) +
|
||||
R"(,"double":)" +
|
||||
std::to_string(static_cast<double>(random())) +
|
||||
R"(,"string":")" + std::to_string(random()) +
|
||||
R"(","bool": true)" + R"(, "array": [1,2,3])" +
|
||||
"}";
|
||||
data[i] = str;
|
||||
}
|
||||
insert_cols(data, N, field_meta);
|
||||
|
@ -465,7 +473,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
|
||||
for (int j = 0; j < array_len; j++) {
|
||||
field_data.mutable_bool_data()->add_data(
|
||||
static_cast<bool>(er()));
|
||||
static_cast<bool>(random()));
|
||||
}
|
||||
data[i] = field_data;
|
||||
}
|
||||
|
@ -479,7 +487,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
|
||||
for (int j = 0; j < array_len; j++) {
|
||||
field_data.mutable_int_data()->add_data(
|
||||
static_cast<int>(er()));
|
||||
static_cast<int>(random()));
|
||||
}
|
||||
data[i] = field_data;
|
||||
}
|
||||
|
@ -490,7 +498,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
milvus::proto::schema::ScalarField field_data;
|
||||
for (int j = 0; j < array_len; j++) {
|
||||
field_data.mutable_long_data()->add_data(
|
||||
static_cast<int64_t>(er()));
|
||||
static_cast<int64_t>(random()));
|
||||
}
|
||||
data[i] = field_data;
|
||||
}
|
||||
|
@ -503,7 +511,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
|
||||
for (int j = 0; j < array_len; j++) {
|
||||
field_data.mutable_string_data()->add_data(
|
||||
std::to_string(er()));
|
||||
std::to_string(random()));
|
||||
}
|
||||
data[i] = field_data;
|
||||
}
|
||||
|
@ -515,7 +523,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
|
||||
for (int j = 0; j < array_len; j++) {
|
||||
field_data.mutable_float_data()->add_data(
|
||||
static_cast<float>(er()));
|
||||
static_cast<float>(random()));
|
||||
}
|
||||
data[i] = field_data;
|
||||
}
|
||||
|
@ -527,7 +535,7 @@ inline GeneratedData DataGen(SchemaPtr schema,
|
|||
|
||||
for (int j = 0; j < array_len; j++) {
|
||||
field_data.mutable_double_data()->add_data(
|
||||
static_cast<double>(er()));
|
||||
static_cast<double>(random()));
|
||||
}
|
||||
data[i] = field_data;
|
||||
}
|
||||
|
@ -1226,4 +1234,21 @@ NewCollection(const char* schema_proto_blob,
|
|||
return (void*)collection.release();
|
||||
}
|
||||
|
||||
// Builds a Collection from an already-parsed schema message and attaches an
// index meta that carries one "metric_type" index parameter (set to
// metric_type) for every field of the collection's schema.
//
// Returns the collection as an opaque handle; ownership passes to the caller.
inline CCollection
NewCollection(const milvus::proto::schema::CollectionSchema* schema,
              MetricType metric_type = knowhere::metric::L2) {
    auto collection = std::make_unique<milvus::segcore::Collection>(schema);
    milvus::proto::segcore::CollectionIndexMeta col_index_meta;
    // Bind by const reference: only the field id is read, so copying every
    // field entry on each iteration is wasted work.
    for (const auto& field : collection->get_schema()->get_fields()) {
        auto* field_index_meta = col_index_meta.add_index_metas();
        auto* index_param = field_index_meta->add_index_params();
        index_param->set_key("metric_type");
        index_param->set_value(metric_type);
        field_index_meta->set_fieldid(field.first.get());
    }
    collection->set_index_meta(
        std::make_shared<CollectionIndexMeta>(col_index_meta));
    return static_cast<void*>(collection.release());
}
|
||||
|
||||
} // namespace milvus::segcore
|
||||
|
|
|
@ -63,6 +63,7 @@ generate_max_float_query_data(int all_nq, int max_float_nq) {
|
|||
auto blob = raw_group.SerializeAsString();
|
||||
return blob;
|
||||
}
|
||||
|
||||
std::string
|
||||
generate_query_data(int nq) {
|
||||
namespace ser = milvus::proto::common;
|
||||
|
@ -102,7 +103,8 @@ CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) {
|
|||
auto ret = pk_set.insert(search_result->primary_keys_[ki]);
|
||||
ASSERT_TRUE(ret.second);
|
||||
|
||||
if (search_result->group_by_values_.value().size() > ki) {
|
||||
if (search_result->group_by_values_.has_value() &&
|
||||
search_result->group_by_values_.value().size() > ki) {
|
||||
auto group_by_val =
|
||||
search_result->group_by_values_.value()[ki];
|
||||
ASSERT_TRUE(group_by_val_set.count(group_by_val) == 0);
|
||||
|
@ -112,4 +114,41 @@ CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
const char*
|
||||
get_default_schema_config() {
|
||||
static std::string conf = R"(name: "default-collection"
|
||||
fields: <
|
||||
fieldID: 100
|
||||
name: "fakevec"
|
||||
data_type: FloatVector
|
||||
type_params: <
|
||||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
fieldID: 101
|
||||
name: "age"
|
||||
data_type: Int64
|
||||
is_primary_key: true
|
||||
>)";
|
||||
static std::string fake_conf = "";
|
||||
return conf.c_str();
|
||||
}
|
||||
|
||||
CStatus
|
||||
CSearch(CSegmentInterface c_segment,
|
||||
CSearchPlan c_plan,
|
||||
CPlaceholderGroup c_placeholder_group,
|
||||
uint64_t timestamp,
|
||||
CSearchResult* result) {
|
||||
return Search(
|
||||
{}, c_segment, c_plan, c_placeholder_group, timestamp, result);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
@ -371,6 +371,9 @@ func (gc *garbageCollector) checkDroppedSegmentGC(segment *SegmentInfo,
|
|||
) bool {
|
||||
log := log.With(zap.Int64("segmentID", segment.ID))
|
||||
|
||||
if !gc.isExpire(segment.GetDroppedAt()) {
|
||||
return false
|
||||
}
|
||||
isCompacted := childSegment != nil || segment.GetCompacted()
|
||||
if isCompacted {
|
||||
// For compact A, B -> C, don't GC A or B if C is not indexed,
|
||||
|
@ -382,10 +385,6 @@ func (gc *garbageCollector) checkDroppedSegmentGC(segment *SegmentInfo,
|
|||
zap.Int64("child segment ID", childSegment.GetID()))
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if !gc.isExpire(segment.GetDroppedAt()) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
segInsertChannel := segment.GetInsertChannel()
|
||||
|
|
|
@ -1308,6 +1308,7 @@ func TestGarbageCollector_clearETCD(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
// cannot dropped for not expired.
|
||||
segID + 6: {
|
||||
SegmentInfo: &datapb.SegmentInfo{
|
||||
ID: segID + 6,
|
||||
|
@ -1343,6 +1344,24 @@ func TestGarbageCollector_clearETCD(t *testing.T) {
|
|||
Compacted: true,
|
||||
},
|
||||
},
|
||||
// can be dropped for expired and compacted
|
||||
segID + 8: {
|
||||
SegmentInfo: &datapb.SegmentInfo{
|
||||
ID: segID + 8,
|
||||
CollectionID: collID,
|
||||
PartitionID: partID,
|
||||
InsertChannel: "dmlChannel",
|
||||
NumOfRows: 2000,
|
||||
State: commonpb.SegmentState_Dropped,
|
||||
MaxRowNum: 65535,
|
||||
DroppedAt: uint64(time.Now().Add(-7 * 24 * time.Hour).UnixNano()),
|
||||
CompactionFrom: nil,
|
||||
DmlPosition: &msgpb.MsgPosition{
|
||||
Timestamp: 900,
|
||||
},
|
||||
Compacted: true,
|
||||
},
|
||||
},
|
||||
} {
|
||||
m.segments.SetSegment(segID, segment)
|
||||
}
|
||||
|
@ -1390,9 +1409,11 @@ func TestGarbageCollector_clearETCD(t *testing.T) {
|
|||
segF := gc.meta.GetSegment(segID + 5)
|
||||
assert.NotNil(t, segF)
|
||||
segG := gc.meta.GetSegment(segID + 6)
|
||||
assert.Nil(t, segG)
|
||||
assert.NotNil(t, segG)
|
||||
segH := gc.meta.GetSegment(segID + 7)
|
||||
assert.NotNil(t, segH)
|
||||
segG = gc.meta.GetSegment(segID + 8)
|
||||
assert.Nil(t, segG)
|
||||
err := gc.meta.indexMeta.AddSegmentIndex(&model.SegmentIndex{
|
||||
SegmentID: segID + 4,
|
||||
CollectionID: collID,
|
||||
|
|
|
@ -2750,6 +2750,9 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
|
|||
}
|
||||
|
||||
rateCol.Add(internalpb.RateType_DMLUpsert.String(), float64(it.upsertMsg.InsertMsg.Size()+it.upsertMsg.DeleteMsg.Size()))
|
||||
if merr.Ok(it.result.GetStatus()) {
|
||||
metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeUpsert, dbName, username).Add(float64(v))
|
||||
}
|
||||
metrics.ProxyFunctionCall.WithLabelValues(nodeID, method,
|
||||
metrics.SuccessLabel, dbName, collectionName).Inc()
|
||||
successCnt := it.result.UpsertCnt - int64(len(it.result.ErrIndex))
|
||||
|
|
|
@ -131,6 +131,7 @@ type shardDelegator struct {
|
|||
// cause growing segment meta has been stored in segmentManager/distribution/pkOracle/excludeSegments
|
||||
// in order to make add/remove growing be atomic, need lock before modify these meta info
|
||||
growingSegmentLock sync.RWMutex
|
||||
partitionStatsMut sync.RWMutex
|
||||
}
|
||||
|
||||
// getLogger returns the zap logger with pre-defined shard attributes.
|
||||
|
@ -217,8 +218,12 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest
|
|||
return nil, err
|
||||
}
|
||||
if paramtable.Get().QueryNodeCfg.EnableSegmentPrune.GetAsBool() {
|
||||
PruneSegments(ctx, sd.partitionStats, req.GetReq(), nil, sd.collection.Schema(), sealed,
|
||||
PruneInfo{filterRatio: paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
|
||||
func() {
|
||||
sd.partitionStatsMut.RLock()
|
||||
defer sd.partitionStatsMut.RUnlock()
|
||||
PruneSegments(ctx, sd.partitionStats, req.GetReq(), nil, sd.collection.Schema(), sealed,
|
||||
PruneInfo{filterRatio: paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
|
||||
}()
|
||||
}
|
||||
|
||||
tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest)
|
||||
|
@ -477,7 +482,11 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
|
|||
}
|
||||
|
||||
if paramtable.Get().QueryNodeCfg.EnableSegmentPrune.GetAsBool() {
|
||||
PruneSegments(ctx, sd.partitionStats, nil, req.GetReq(), sd.collection.Schema(), sealed, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
|
||||
func() {
|
||||
sd.partitionStatsMut.RLock()
|
||||
defer sd.partitionStatsMut.RUnlock()
|
||||
PruneSegments(ctx, sd.partitionStats, nil, req.GetReq(), sd.collection.Schema(), sealed, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()})
|
||||
}()
|
||||
}
|
||||
|
||||
sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) })
|
||||
|
@ -804,7 +813,14 @@ func (sd *shardDelegator) maybeReloadPartitionStats(ctx context.Context, partIDs
|
|||
log.Info("failed to find valid partition stats file for partition", zap.Int64("partitionID", partID))
|
||||
continue
|
||||
}
|
||||
partStats, exists := sd.partitionStats[partID]
|
||||
|
||||
var partStats *storage.PartitionStatsSnapshot
|
||||
var exists bool
|
||||
func() {
|
||||
sd.partitionStatsMut.RLock()
|
||||
defer sd.partitionStatsMut.RUnlock()
|
||||
partStats, exists = sd.partitionStats[partID]
|
||||
}()
|
||||
if !exists || (exists && partStats.GetVersion() < maxVersion) {
|
||||
statsBytes, err := sd.chunkManager.Read(ctx, maxVersionFilePath)
|
||||
if err != nil {
|
||||
|
@ -816,8 +832,13 @@ func (sd *shardDelegator) maybeReloadPartitionStats(ctx context.Context, partIDs
|
|||
log.Error("failed to parse partition stats from bytes", zap.Int("bytes_length", len(statsBytes)))
|
||||
continue
|
||||
}
|
||||
sd.partitionStats[partID] = partStats
|
||||
partStats.SetVersion(maxVersion)
|
||||
|
||||
func() {
|
||||
sd.partitionStatsMut.Lock()
|
||||
defer sd.partitionStatsMut.Unlock()
|
||||
sd.partitionStats[partID] = partStats
|
||||
}()
|
||||
log.Info("Updated partitionStats for partition", zap.Int64("partitionID", partID))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -138,14 +138,14 @@ func (sd *shardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) {
|
|||
if ok := sd.VerifyExcludedSegments(segmentID, typeutil.MaxTimestamp); !ok {
|
||||
log.Warn("try to insert data into released segment, skip it", zap.Int64("segmentID", segmentID))
|
||||
sd.growingSegmentLock.Unlock()
|
||||
growing.Release()
|
||||
growing.Release(context.Background())
|
||||
continue
|
||||
}
|
||||
|
||||
if !sd.pkOracle.Exists(growing, paramtable.GetNodeID()) {
|
||||
// register created growing segment after insert, avoid to add empty growing to delegator
|
||||
sd.pkOracle.Register(growing, paramtable.GetNodeID())
|
||||
sd.segmentManager.Put(segments.SegmentTypeGrowing, growing)
|
||||
sd.segmentManager.Put(context.Background(), segments.SegmentTypeGrowing, growing)
|
||||
sd.addGrowing(SegmentEntry{
|
||||
NodeID: paramtable.GetNodeID(),
|
||||
SegmentID: segmentID,
|
||||
|
@ -378,7 +378,7 @@ func (sd *shardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.Segm
|
|||
|
||||
// clear loaded growing segments
|
||||
for _, segment := range loaded {
|
||||
segment.Release()
|
||||
segment.Release(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -630,13 +630,13 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
|
|||
growing0.EXPECT().ID().Return(1)
|
||||
growing0.EXPECT().Partition().Return(10)
|
||||
growing0.EXPECT().Type().Return(segments.SegmentTypeGrowing)
|
||||
growing0.EXPECT().Release()
|
||||
growing0.EXPECT().Release(context.Background())
|
||||
|
||||
growing1 := segments.NewMockSegment(s.T())
|
||||
growing1.EXPECT().ID().Return(2)
|
||||
growing1.EXPECT().Partition().Return(10)
|
||||
growing1.EXPECT().Type().Return(segments.SegmentTypeGrowing)
|
||||
growing1.EXPECT().Release()
|
||||
growing1.EXPECT().Release(context.Background())
|
||||
|
||||
mockErr := merr.WrapErrServiceInternal("mock")
|
||||
|
||||
|
@ -1068,7 +1068,7 @@ func (s *DelegatorDataSuite) TestSyncTargetVersion() {
|
|||
ms.EXPECT().Indexes().Return(nil)
|
||||
ms.EXPECT().Shard().Return(s.vchannelName)
|
||||
ms.EXPECT().Level().Return(datapb.SegmentLevel_L1)
|
||||
s.manager.Segment.Put(segments.SegmentTypeGrowing, ms)
|
||||
s.manager.Segment.Put(context.Background(), segments.SegmentTypeGrowing, ms)
|
||||
}
|
||||
|
||||
s.delegator.SyncTargetVersion(int64(5), []int64{1}, []int64{2}, []int64{3, 4}, &msgpb.MsgPosition{})
|
||||
|
|
|
@ -28,6 +28,39 @@ func (_m *MockShardDelegator) EXPECT() *MockShardDelegator_Expecter {
|
|||
return &MockShardDelegator_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// AddExcludedSegments provides a mock function with given fields: excludeInfo
|
||||
func (_m *MockShardDelegator) AddExcludedSegments(excludeInfo map[int64]uint64) {
|
||||
_m.Called(excludeInfo)
|
||||
}
|
||||
|
||||
// MockShardDelegator_AddExcludedSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddExcludedSegments'
|
||||
type MockShardDelegator_AddExcludedSegments_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// AddExcludedSegments is a helper method to define mock.On call
|
||||
// - excludeInfo map[int64]uint64
|
||||
func (_e *MockShardDelegator_Expecter) AddExcludedSegments(excludeInfo interface{}) *MockShardDelegator_AddExcludedSegments_Call {
|
||||
return &MockShardDelegator_AddExcludedSegments_Call{Call: _e.mock.On("AddExcludedSegments", excludeInfo)}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_AddExcludedSegments_Call) Run(run func(excludeInfo map[int64]uint64)) *MockShardDelegator_AddExcludedSegments_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(map[int64]uint64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_AddExcludedSegments_Call) Return() *MockShardDelegator_AddExcludedSegments_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_AddExcludedSegments_Call) RunAndReturn(run func(map[int64]uint64)) *MockShardDelegator_AddExcludedSegments_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Close provides a mock function with given fields:
|
||||
func (_m *MockShardDelegator) Close() {
|
||||
_m.Called()
|
||||
|
@ -253,39 +286,6 @@ func (_c *MockShardDelegator_GetTargetVersion_Call) RunAndReturn(run func() int6
|
|||
return _c
|
||||
}
|
||||
|
||||
// AddExcludedSegments provides a mock function with given fields: excludeInfo
|
||||
func (_m *MockShardDelegator) AddExcludedSegments(excludeInfo map[int64]uint64) {
|
||||
_m.Called(excludeInfo)
|
||||
}
|
||||
|
||||
// MockShardDelegator_AddExcludedSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddExcludedSegments'
|
||||
type MockShardDelegator_AddExcludedSegments_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// AddExcludedSegments is a helper method to define mock.On call
|
||||
// - excludeInfo map[int64]uint64
|
||||
func (_e *MockShardDelegator_Expecter) AddExcludedSegments(excludeInfo interface{}) *MockShardDelegator_AddExcludedSegments_Call {
|
||||
return &MockShardDelegator_AddExcludedSegments_Call{Call: _e.mock.On("AddExcludedSegments", excludeInfo)}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_AddExcludedSegments_Call) Run(run func(excludeInfo map[int64]uint64)) *MockShardDelegator_AddExcludedSegments_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(map[int64]uint64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_AddExcludedSegments_Call) Return() *MockShardDelegator_AddExcludedSegments_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_AddExcludedSegments_Call) RunAndReturn(run func(map[int64]uint64)) *MockShardDelegator_AddExcludedSegments_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// LoadGrowing provides a mock function with given fields: ctx, infos, version
|
||||
func (_m *MockShardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error {
|
||||
ret := _m.Called(ctx, infos, version)
|
||||
|
|
|
@ -163,6 +163,12 @@ func (node *QueryNode) loadIndex(ctx context.Context, req *querypb.LoadSegmentsR
|
|||
continue
|
||||
}
|
||||
|
||||
if localSegment.IsLazyLoad() {
|
||||
localSegment.SetLoadInfo(info)
|
||||
localSegment.SetNeedUpdatedVersion(req.GetVersion())
|
||||
node.manager.DiskCache.MarkItemNeedReload(ctx, localSegment.ID())
|
||||
return nil
|
||||
}
|
||||
err := node.loader.LoadIndex(ctx, localSegment, info, req.Version)
|
||||
if err != nil {
|
||||
log.Warn("failed to load index", zap.Error(err))
|
||||
|
@ -425,11 +431,11 @@ func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.Ge
|
|||
results, readSegments, err = segments.StatisticStreaming(ctx, node.manager, req.Req.GetCollectionID(), req.Req.GetPartitionIDs(), req.GetSegmentIDs())
|
||||
}
|
||||
|
||||
defer node.manager.Segment.Unpin(readSegments)
|
||||
if err != nil {
|
||||
log.Warn("get segments statistics failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
defer node.manager.Segment.Unpin(readSegments)
|
||||
return segmentStatsResponse(results), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -47,11 +47,11 @@ func HandleCStatus(ctx context.Context, status *C.CStatus, extraInfo string, fie
|
|||
errorMsg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
|
||||
log.Ctx(ctx).With(fields...).
|
||||
log := log.Ctx(ctx).With(fields...).
|
||||
WithOptions(zap.AddCallerSkip(1)) // Add caller stack to show HandleCStatus caller
|
||||
|
||||
err := merr.SegcoreError(int32(errorCode), errorMsg)
|
||||
log.Warn("CStatus returns err", zap.Error(err))
|
||||
log.Warn("CStatus returns err", zap.Error(err), zap.String("extra", extraInfo))
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -148,13 +148,6 @@ func IncreaseVersion(version int64) SegmentAction {
|
|||
}
|
||||
}
|
||||
|
||||
type actionType int32
|
||||
|
||||
const (
|
||||
removeAction actionType = iota
|
||||
addAction
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
Collection CollectionManager
|
||||
Segment SegmentManager
|
||||
|
@ -173,18 +166,19 @@ func NewManager() *Manager {
|
|||
}
|
||||
|
||||
manager.DiskCache = cache.NewCacheBuilder[int64, Segment]().WithLazyScavenger(func(key int64) int64 {
|
||||
return int64(segMgr.sealedSegments[key].ResourceUsageEstimate().DiskSize)
|
||||
}, diskCap).WithLoader(func(key int64) (Segment, bool) {
|
||||
log.Debug("cache missed segment", zap.Int64("segmentID", key))
|
||||
segMgr.mu.RLock()
|
||||
defer segMgr.mu.RUnlock()
|
||||
|
||||
segment, ok := segMgr.sealedSegments[key]
|
||||
if !ok {
|
||||
// the segment has been released, just ignore it
|
||||
return nil, false
|
||||
segment := segMgr.GetWithType(key, SegmentTypeSealed)
|
||||
if segment == nil {
|
||||
return 0
|
||||
}
|
||||
return int64(segment.ResourceUsageEstimate().DiskSize)
|
||||
}, diskCap).WithLoader(func(ctx context.Context, key int64) (Segment, error) {
|
||||
log.Debug("cache missed segment", zap.Int64("segmentID", key))
|
||||
segment := segMgr.GetWithType(key, SegmentTypeSealed)
|
||||
if segment == nil {
|
||||
// the segment has been released, just ignore it
|
||||
log.Warn("segment is not found when loading", zap.Int64("segmentID", key))
|
||||
return nil, merr.ErrSegmentNotFound
|
||||
}
|
||||
|
||||
info := segment.LoadInfo()
|
||||
_, err, _ := sf.Do(fmt.Sprint(segment.ID()), func() (nop interface{}, err error) {
|
||||
cacheLoadRecord := metricsutil.NewCacheLoadRecord(getSegmentMetricLabel(segment))
|
||||
|
@ -197,23 +191,50 @@ func NewManager() *Manager {
|
|||
if collection == nil {
|
||||
return nil, merr.WrapErrCollectionNotLoaded(segment.Collection(), "failed to load segment fields")
|
||||
}
|
||||
err = manager.Loader.LoadSegment(context.Background(), segment.(*LocalSegment), info, LoadStatusMapped)
|
||||
|
||||
err = manager.Loader.LoadLazySegment(ctx, segment.(*LocalSegment), info)
|
||||
return nil, err
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn("cache sealed segment failed", zap.Error(err))
|
||||
return nil, false
|
||||
return nil, err
|
||||
}
|
||||
return segment, true
|
||||
}).WithFinalizer(func(key int64, segment Segment) error {
|
||||
log.Debug("evict segment from cache", zap.Int64("segmentID", key))
|
||||
return segment, nil
|
||||
}).WithFinalizer(func(ctx context.Context, key int64, segment Segment) error {
|
||||
log.Ctx(ctx).Debug("evict segment from cache", zap.Int64("segmentID", key))
|
||||
cacheEvictRecord := metricsutil.NewCacheEvictRecord(getSegmentMetricLabel(segment))
|
||||
cacheEvictRecord.WithBytes(segment.ResourceUsageEstimate().DiskSize)
|
||||
defer cacheEvictRecord.Finish(nil)
|
||||
|
||||
segment.Release(WithReleaseScope(ReleaseScopeData))
|
||||
segment.Release(ctx, WithReleaseScope(ReleaseScopeData))
|
||||
return nil
|
||||
}).WithReloader(func(ctx context.Context, key int64) (Segment, error) {
|
||||
segment := segMgr.GetWithType(key, SegmentTypeSealed)
|
||||
if segment == nil {
|
||||
// the segment has been released, just ignore it
|
||||
log.Debug("segment is not found when reloading", zap.Int64("segmentID", key))
|
||||
return nil, merr.ErrSegmentNotFound
|
||||
}
|
||||
|
||||
localSegment := segment.(*LocalSegment)
|
||||
err := manager.Loader.LoadIndex(ctx, localSegment, segment.LoadInfo(), segment.NeedUpdatedVersion())
|
||||
if err != nil {
|
||||
log.Warn("reload segment failed", zap.Int64("segmentID", key), zap.Error(err))
|
||||
return nil, merr.ErrSegmentLoadFailed
|
||||
}
|
||||
if err := localSegment.RemoveUnusedFieldFiles(); err != nil {
|
||||
log.Warn("remove unused field files failed", zap.Int64("segmentID", key), zap.Error(err))
|
||||
return nil, merr.ErrSegmentReduplicate
|
||||
}
|
||||
|
||||
return segment, nil
|
||||
}).Build()
|
||||
|
||||
segMgr.registerReleaseCallback(func(s Segment) {
|
||||
if s.Type() == SegmentTypeSealed {
|
||||
manager.DiskCache.Expire(context.Background(), s.ID())
|
||||
}
|
||||
})
|
||||
|
||||
return manager
|
||||
}
|
||||
|
||||
|
@ -225,7 +246,7 @@ type SegmentManager interface {
|
|||
// Put puts the given segments in,
|
||||
// and increases the ref count of the corresponding collection,
|
||||
// dup segments will not increase the ref count
|
||||
Put(segmentType SegmentType, segments ...Segment)
|
||||
Put(ctx context.Context, segmentType SegmentType, segments ...Segment)
|
||||
UpdateBy(action SegmentAction, filters ...SegmentFilter) int
|
||||
Get(segmentID typeutil.UniqueID) Segment
|
||||
GetWithType(segmentID typeutil.UniqueID, typ SegmentType) Segment
|
||||
|
@ -242,9 +263,9 @@ type SegmentManager interface {
|
|||
// Remove removes the given segment,
|
||||
// and decreases the ref count of the corresponding collection,
|
||||
// will not decrease the ref count if the given segment not exists
|
||||
Remove(segmentID typeutil.UniqueID, scope querypb.DataScope) (int, int)
|
||||
RemoveBy(filters ...SegmentFilter) (int, int)
|
||||
Clear()
|
||||
Remove(ctx context.Context, segmentID typeutil.UniqueID, scope querypb.DataScope) (int, int)
|
||||
RemoveBy(ctx context.Context, filters ...SegmentFilter) (int, int)
|
||||
Clear(ctx context.Context)
|
||||
}
|
||||
|
||||
var _ SegmentManager = (*segmentManager)(nil)
|
||||
|
@ -255,6 +276,9 @@ type segmentManager struct {
|
|||
|
||||
growingSegments map[typeutil.UniqueID]Segment
|
||||
sealedSegments map[typeutil.UniqueID]Segment
|
||||
|
||||
// releaseCallback is the callback function when a segment is released.
|
||||
releaseCallback func(s Segment)
|
||||
}
|
||||
|
||||
func NewSegmentManager() *segmentManager {
|
||||
|
@ -265,7 +289,7 @@ func NewSegmentManager() *segmentManager {
|
|||
return mgr
|
||||
}
|
||||
|
||||
func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) {
|
||||
func (mgr *segmentManager) Put(ctx context.Context, segmentType SegmentType, segments ...Segment) {
|
||||
var replacedSegment []Segment
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
|
@ -278,7 +302,7 @@ func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) {
|
|||
default:
|
||||
panic("unexpected segment type")
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx)
|
||||
for _, segment := range segments {
|
||||
oldSegment, ok := targetMap[segment.ID()]
|
||||
|
||||
|
@ -290,7 +314,7 @@ func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) {
|
|||
zap.Int64("newVersion", segment.Version()),
|
||||
)
|
||||
// delete redundant segment
|
||||
segment.Release()
|
||||
segment.Release(ctx)
|
||||
continue
|
||||
}
|
||||
replacedSegment = append(replacedSegment, oldSegment)
|
||||
|
@ -313,7 +337,7 @@ func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) {
|
|||
if len(replacedSegment) > 0 {
|
||||
go func() {
|
||||
for _, segment := range replacedSegment {
|
||||
remove(segment)
|
||||
mgr.remove(ctx, segment)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
@ -563,7 +587,7 @@ func (mgr *segmentManager) Empty() bool {
|
|||
|
||||
// returns true if the segment exists,
|
||||
// false otherwise
|
||||
func (mgr *segmentManager) Remove(segmentID typeutil.UniqueID, scope querypb.DataScope) (int, int) {
|
||||
func (mgr *segmentManager) Remove(ctx context.Context, segmentID typeutil.UniqueID, scope querypb.DataScope) (int, int) {
|
||||
mgr.mu.Lock()
|
||||
|
||||
var removeGrowing, removeSealed int
|
||||
|
@ -596,11 +620,11 @@ func (mgr *segmentManager) Remove(segmentID typeutil.UniqueID, scope querypb.Dat
|
|||
mgr.mu.Unlock()
|
||||
|
||||
if growing != nil {
|
||||
remove(growing)
|
||||
mgr.remove(ctx, growing)
|
||||
}
|
||||
|
||||
if sealed != nil {
|
||||
remove(sealed)
|
||||
mgr.remove(ctx, sealed)
|
||||
}
|
||||
|
||||
return removeGrowing, removeSealed
|
||||
|
@ -628,7 +652,7 @@ func (mgr *segmentManager) removeSegmentWithType(typ SegmentType, segmentID type
|
|||
return nil
|
||||
}
|
||||
|
||||
func (mgr *segmentManager) RemoveBy(filters ...SegmentFilter) (int, int) {
|
||||
func (mgr *segmentManager) RemoveBy(ctx context.Context, filters ...SegmentFilter) (int, int) {
|
||||
mgr.mu.Lock()
|
||||
|
||||
var removeSegments []Segment
|
||||
|
@ -651,28 +675,34 @@ func (mgr *segmentManager) RemoveBy(filters ...SegmentFilter) (int, int) {
|
|||
mgr.mu.Unlock()
|
||||
|
||||
for _, s := range removeSegments {
|
||||
remove(s)
|
||||
mgr.remove(ctx, s)
|
||||
}
|
||||
|
||||
return removeGrowing, removeSealed
|
||||
}
|
||||
|
||||
func (mgr *segmentManager) Clear() {
|
||||
func (mgr *segmentManager) Clear(ctx context.Context) {
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
|
||||
for id, segment := range mgr.growingSegments {
|
||||
delete(mgr.growingSegments, id)
|
||||
remove(segment)
|
||||
mgr.remove(ctx, segment)
|
||||
}
|
||||
|
||||
for id, segment := range mgr.sealedSegments {
|
||||
delete(mgr.sealedSegments, id)
|
||||
remove(segment)
|
||||
mgr.remove(ctx, segment)
|
||||
}
|
||||
mgr.updateMetric()
|
||||
}
|
||||
|
||||
// registerReleaseCallback registers the callback function when a segment is released.
|
||||
// TODO: bad implementation for keep consistency with DiskCache, need to be refactor.
|
||||
func (mgr *segmentManager) registerReleaseCallback(callback func(s Segment)) {
|
||||
mgr.releaseCallback = callback
|
||||
}
|
||||
|
||||
func (mgr *segmentManager) updateMetric() {
|
||||
// update collection and partiation metric
|
||||
collections, partiations := make(typeutil.Set[int64]), make(typeutil.Set[int64])
|
||||
|
@ -688,8 +718,11 @@ func (mgr *segmentManager) updateMetric() {
|
|||
metrics.QueryNodeNumPartitions.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(partiations.Len()))
|
||||
}
|
||||
|
||||
func remove(segment Segment) bool {
|
||||
segment.Release()
|
||||
func (mgr *segmentManager) remove(ctx context.Context, segment Segment) bool {
|
||||
segment.Release(ctx)
|
||||
if mgr.releaseCallback != nil {
|
||||
mgr.releaseCallback(segment)
|
||||
}
|
||||
|
||||
metrics.QueryNodeNumSegments.WithLabelValues(
|
||||
fmt.Sprint(paramtable.GetNodeID()),
|
||||
|
|
|
@ -62,7 +62,7 @@ func (s *ManagerSuite) SetupTest() {
|
|||
s.Require().NoError(err)
|
||||
s.segments = append(s.segments, segment)
|
||||
|
||||
s.mgr.Put(s.types[i], segment)
|
||||
s.mgr.Put(context.Background(), s.types[i], segment)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -82,7 +82,7 @@ func (s *ManagerSuite) TestGetBy() {
|
|||
segments := s.mgr.GetBy(WithType(typ))
|
||||
s.Contains(lo.Map(segments, func(segment Segment, _ int) int64 { return segment.ID() }), s.segmentIDs[i])
|
||||
}
|
||||
s.mgr.Clear()
|
||||
s.mgr.Clear(context.Background())
|
||||
|
||||
for _, typ := range s.types {
|
||||
segments := s.mgr.GetBy(WithType(typ))
|
||||
|
@ -101,7 +101,7 @@ func (s *ManagerSuite) TestRemoveGrowing() {
|
|||
for i, id := range s.segmentIDs {
|
||||
isGrowing := s.types[i] == SegmentTypeGrowing
|
||||
|
||||
s.mgr.Remove(id, querypb.DataScope_Streaming)
|
||||
s.mgr.Remove(context.Background(), id, querypb.DataScope_Streaming)
|
||||
s.Equal(s.mgr.Get(id) == nil, isGrowing)
|
||||
}
|
||||
}
|
||||
|
@ -110,21 +110,21 @@ func (s *ManagerSuite) TestRemoveSealed() {
|
|||
for i, id := range s.segmentIDs {
|
||||
isSealed := s.types[i] == SegmentTypeSealed
|
||||
|
||||
s.mgr.Remove(id, querypb.DataScope_Historical)
|
||||
s.mgr.Remove(context.Background(), id, querypb.DataScope_Historical)
|
||||
s.Equal(s.mgr.Get(id) == nil, isSealed)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ManagerSuite) TestRemoveAll() {
|
||||
for _, id := range s.segmentIDs {
|
||||
s.mgr.Remove(id, querypb.DataScope_All)
|
||||
s.mgr.Remove(context.Background(), id, querypb.DataScope_All)
|
||||
s.Nil(s.mgr.Get(id))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ManagerSuite) TestRemoveBy() {
|
||||
for _, id := range s.segmentIDs {
|
||||
s.mgr.RemoveBy(WithID(id))
|
||||
s.mgr.RemoveBy(context.Background(), WithID(id))
|
||||
s.Nil(s.mgr.Get(id))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -261,13 +261,57 @@ func (_c *MockLoader_LoadIndex_Call) RunAndReturn(run func(context.Context, *Loc
|
|||
return _c
|
||||
}
|
||||
|
||||
// LoadSegment provides a mock function with given fields: ctx, segment, loadInfo, loadStatus
|
||||
func (_m *MockLoader) LoadSegment(ctx context.Context, segment *LocalSegment, loadInfo *querypb.SegmentLoadInfo, loadStatus LoadStatus) error {
|
||||
ret := _m.Called(ctx, segment, loadInfo, loadStatus)
|
||||
// LoadLazySegment provides a mock function with given fields: ctx, segment, loadInfo
|
||||
func (_m *MockLoader) LoadLazySegment(ctx context.Context, segment *LocalSegment, loadInfo *querypb.SegmentLoadInfo) error {
|
||||
ret := _m.Called(ctx, segment, loadInfo)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *LocalSegment, *querypb.SegmentLoadInfo, LoadStatus) error); ok {
|
||||
r0 = rf(ctx, segment, loadInfo, loadStatus)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *LocalSegment, *querypb.SegmentLoadInfo) error); ok {
|
||||
r0 = rf(ctx, segment, loadInfo)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockLoader_LoadLazySegment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadLazySegment'
|
||||
type MockLoader_LoadLazySegment_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// LoadLazySegment is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - segment *LocalSegment
|
||||
// - loadInfo *querypb.SegmentLoadInfo
|
||||
func (_e *MockLoader_Expecter) LoadLazySegment(ctx interface{}, segment interface{}, loadInfo interface{}) *MockLoader_LoadLazySegment_Call {
|
||||
return &MockLoader_LoadLazySegment_Call{Call: _e.mock.On("LoadLazySegment", ctx, segment, loadInfo)}
|
||||
}
|
||||
|
||||
func (_c *MockLoader_LoadLazySegment_Call) Run(run func(ctx context.Context, segment *LocalSegment, loadInfo *querypb.SegmentLoadInfo)) *MockLoader_LoadLazySegment_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*LocalSegment), args[2].(*querypb.SegmentLoadInfo))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockLoader_LoadLazySegment_Call) Return(_a0 error) *MockLoader_LoadLazySegment_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockLoader_LoadLazySegment_Call) RunAndReturn(run func(context.Context, *LocalSegment, *querypb.SegmentLoadInfo) error) *MockLoader_LoadLazySegment_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// LoadSegment provides a mock function with given fields: ctx, segment, loadInfo
|
||||
func (_m *MockLoader) LoadSegment(ctx context.Context, segment *LocalSegment, loadInfo *querypb.SegmentLoadInfo) error {
|
||||
ret := _m.Called(ctx, segment, loadInfo)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *LocalSegment, *querypb.SegmentLoadInfo) error); ok {
|
||||
r0 = rf(ctx, segment, loadInfo)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
@ -284,14 +328,13 @@ type MockLoader_LoadSegment_Call struct {
|
|||
// - ctx context.Context
|
||||
// - segment *LocalSegment
|
||||
// - loadInfo *querypb.SegmentLoadInfo
|
||||
// - loadStatus LoadStatus
|
||||
func (_e *MockLoader_Expecter) LoadSegment(ctx interface{}, segment interface{}, loadInfo interface{}, loadStatus interface{}) *MockLoader_LoadSegment_Call {
|
||||
return &MockLoader_LoadSegment_Call{Call: _e.mock.On("LoadSegment", ctx, segment, loadInfo, loadStatus)}
|
||||
func (_e *MockLoader_Expecter) LoadSegment(ctx interface{}, segment interface{}, loadInfo interface{}) *MockLoader_LoadSegment_Call {
|
||||
return &MockLoader_LoadSegment_Call{Call: _e.mock.On("LoadSegment", ctx, segment, loadInfo)}
|
||||
}
|
||||
|
||||
func (_c *MockLoader_LoadSegment_Call) Run(run func(ctx context.Context, segment *LocalSegment, loadInfo *querypb.SegmentLoadInfo, loadStatus LoadStatus)) *MockLoader_LoadSegment_Call {
|
||||
func (_c *MockLoader_LoadSegment_Call) Run(run func(ctx context.Context, segment *LocalSegment, loadInfo *querypb.SegmentLoadInfo)) *MockLoader_LoadSegment_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*LocalSegment), args[2].(*querypb.SegmentLoadInfo), args[3].(LoadStatus))
|
||||
run(args[0].(context.Context), args[1].(*LocalSegment), args[2].(*querypb.SegmentLoadInfo))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
@ -301,7 +344,7 @@ func (_c *MockLoader_LoadSegment_Call) Return(_a0 error) *MockLoader_LoadSegment
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockLoader_LoadSegment_Call) RunAndReturn(run func(context.Context, *LocalSegment, *querypb.SegmentLoadInfo, LoadStatus) error) *MockLoader_LoadSegment_Call {
|
||||
func (_c *MockLoader_LoadSegment_Call) RunAndReturn(run func(context.Context, *LocalSegment, *querypb.SegmentLoadInfo) error) *MockLoader_LoadSegment_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
|
|
@ -709,47 +709,6 @@ func (_c *MockSegment_LoadInfo_Call) RunAndReturn(run func() *querypb.SegmentLoa
|
|||
return _c
|
||||
}
|
||||
|
||||
// LoadStatus provides a mock function with given fields:
|
||||
func (_m *MockSegment) LoadStatus() LoadStatus {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 LoadStatus
|
||||
if rf, ok := ret.Get(0).(func() LoadStatus); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(LoadStatus)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockSegment_LoadStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadStatus'
|
||||
type MockSegment_LoadStatus_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// LoadStatus is a helper method to define mock.On call
|
||||
func (_e *MockSegment_Expecter) LoadStatus() *MockSegment_LoadStatus_Call {
|
||||
return &MockSegment_LoadStatus_Call{Call: _e.mock.On("LoadStatus")}
|
||||
}
|
||||
|
||||
func (_c *MockSegment_LoadStatus_Call) Run(run func()) *MockSegment_LoadStatus_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegment_LoadStatus_Call) Return(_a0 LoadStatus) *MockSegment_LoadStatus_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegment_LoadStatus_Call) RunAndReturn(run func() LoadStatus) *MockSegment_LoadStatus_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// MayPkExist provides a mock function with given fields: pk
|
||||
func (_m *MockSegment) MayPkExist(pk storage.PrimaryKey) bool {
|
||||
ret := _m.Called(pk)
|
||||
|
@ -833,6 +792,47 @@ func (_c *MockSegment_MemSize_Call) RunAndReturn(run func() int64) *MockSegment_
|
|||
return _c
|
||||
}
|
||||
|
||||
// NeedUpdatedVersion provides a mock function with given fields:
|
||||
func (_m *MockSegment) NeedUpdatedVersion() int64 {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 int64
|
||||
if rf, ok := ret.Get(0).(func() int64); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(int64)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockSegment_NeedUpdatedVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NeedUpdatedVersion'
|
||||
type MockSegment_NeedUpdatedVersion_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// NeedUpdatedVersion is a helper method to define mock.On call
|
||||
func (_e *MockSegment_Expecter) NeedUpdatedVersion() *MockSegment_NeedUpdatedVersion_Call {
|
||||
return &MockSegment_NeedUpdatedVersion_Call{Call: _e.mock.On("NeedUpdatedVersion")}
|
||||
}
|
||||
|
||||
func (_c *MockSegment_NeedUpdatedVersion_Call) Run(run func()) *MockSegment_NeedUpdatedVersion_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegment_NeedUpdatedVersion_Call) Return(_a0 int64) *MockSegment_NeedUpdatedVersion_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegment_NeedUpdatedVersion_Call) RunAndReturn(run func() int64) *MockSegment_NeedUpdatedVersion_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Partition provides a mock function with given fields:
|
||||
func (_m *MockSegment) Partition() int64 {
|
||||
ret := _m.Called()
|
||||
|
@ -888,72 +888,41 @@ func (_m *MockSegment) PinIfNotReleased() error {
|
|||
return r0
|
||||
}
|
||||
|
||||
// MockSegment_RLock_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RLock'
|
||||
type MockSegment_RLock_Call struct {
|
||||
// MockSegment_PinIfNotReleased_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PinIfNotReleased'
|
||||
type MockSegment_PinIfNotReleased_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RLock is a helper method to define mock.On call
|
||||
func (_e *MockSegment_Expecter) RLock() *MockSegment_RLock_Call {
|
||||
return &MockSegment_RLock_Call{Call: _e.mock.On("RLock")}
|
||||
// PinIfNotReleased is a helper method to define mock.On call
|
||||
func (_e *MockSegment_Expecter) PinIfNotReleased() *MockSegment_PinIfNotReleased_Call {
|
||||
return &MockSegment_PinIfNotReleased_Call{Call: _e.mock.On("PinIfNotReleased")}
|
||||
}
|
||||
|
||||
func (_c *MockSegment_RLock_Call) Run(run func()) *MockSegment_RLock_Call {
|
||||
func (_c *MockSegment_PinIfNotReleased_Call) Run(run func()) *MockSegment_PinIfNotReleased_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegment_RLock_Call) Return(_a0 error) *MockSegment_RLock_Call {
|
||||
func (_c *MockSegment_PinIfNotReleased_Call) Return(_a0 error) *MockSegment_PinIfNotReleased_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegment_RLock_Call) RunAndReturn(run func() error) *MockSegment_RLock_Call {
|
||||
func (_c *MockSegment_PinIfNotReleased_Call) RunAndReturn(run func() error) *MockSegment_PinIfNotReleased_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Unpin provides a mock function with given fields:
|
||||
func (_m *MockSegment) Unpin() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// MockSegment_RUnlock_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RUnlock'
|
||||
type MockSegment_RUnlock_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RUnlock is a helper method to define mock.On call
|
||||
func (_e *MockSegment_Expecter) RUnlock() *MockSegment_RUnlock_Call {
|
||||
return &MockSegment_RUnlock_Call{Call: _e.mock.On("RUnlock")}
|
||||
}
|
||||
|
||||
func (_c *MockSegment_RUnlock_Call) Run(run func()) *MockSegment_RUnlock_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegment_RUnlock_Call) Return() *MockSegment_RUnlock_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegment_RUnlock_Call) RunAndReturn(run func()) *MockSegment_RUnlock_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Release provides a mock function with given fields: opts
|
||||
func (_m *MockSegment) Release(opts ...releaseOption) {
|
||||
// Release provides a mock function with given fields: ctx, opts
|
||||
func (_m *MockSegment) Release(ctx context.Context, opts ...releaseOption) {
|
||||
_va := make([]interface{}, len(opts))
|
||||
for _i := range opts {
|
||||
_va[_i] = opts[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx)
|
||||
_ca = append(_ca, _va...)
|
||||
_m.Called(_ca...)
|
||||
}
|
||||
|
@ -964,21 +933,22 @@ type MockSegment_Release_Call struct {
|
|||
}
|
||||
|
||||
// Release is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - opts ...releaseOption
|
||||
func (_e *MockSegment_Expecter) Release(opts ...interface{}) *MockSegment_Release_Call {
|
||||
func (_e *MockSegment_Expecter) Release(ctx interface{}, opts ...interface{}) *MockSegment_Release_Call {
|
||||
return &MockSegment_Release_Call{Call: _e.mock.On("Release",
|
||||
append([]interface{}{}, opts...)...)}
|
||||
append([]interface{}{ctx}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockSegment_Release_Call) Run(run func(opts ...releaseOption)) *MockSegment_Release_Call {
|
||||
func (_c *MockSegment_Release_Call) Run(run func(ctx context.Context, opts ...releaseOption)) *MockSegment_Release_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]releaseOption, len(args)-0)
|
||||
for i, a := range args[0:] {
|
||||
variadicArgs := make([]releaseOption, len(args)-1)
|
||||
for i, a := range args[1:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(releaseOption)
|
||||
}
|
||||
}
|
||||
run(variadicArgs...)
|
||||
run(args[0].(context.Context), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
@ -988,7 +958,48 @@ func (_c *MockSegment_Release_Call) Return() *MockSegment_Release_Call {
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegment_Release_Call) RunAndReturn(run func(...releaseOption)) *MockSegment_Release_Call {
|
||||
func (_c *MockSegment_Release_Call) RunAndReturn(run func(context.Context, ...releaseOption)) *MockSegment_Release_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RemoveUnusedFieldFiles provides a mock function with given fields:
|
||||
func (_m *MockSegment) RemoveUnusedFieldFiles() error {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockSegment_RemoveUnusedFieldFiles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveUnusedFieldFiles'
|
||||
type MockSegment_RemoveUnusedFieldFiles_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RemoveUnusedFieldFiles is a helper method to define mock.On call
|
||||
func (_e *MockSegment_Expecter) RemoveUnusedFieldFiles() *MockSegment_RemoveUnusedFieldFiles_Call {
|
||||
return &MockSegment_RemoveUnusedFieldFiles_Call{Call: _e.mock.On("RemoveUnusedFieldFiles")}
|
||||
}
|
||||
|
||||
func (_c *MockSegment_RemoveUnusedFieldFiles_Call) Run(run func()) *MockSegment_RemoveUnusedFieldFiles_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegment_RemoveUnusedFieldFiles_Call) Return(_a0 error) *MockSegment_RemoveUnusedFieldFiles_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegment_RemoveUnusedFieldFiles_Call) RunAndReturn(run func() error) *MockSegment_RemoveUnusedFieldFiles_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
@ -1440,6 +1451,38 @@ func (_c *MockSegment_Type_Call) RunAndReturn(run func() commonpb.SegmentState)
|
|||
return _c
|
||||
}
|
||||
|
||||
// Unpin provides a mock function with given fields:
|
||||
func (_m *MockSegment) Unpin() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// MockSegment_Unpin_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Unpin'
|
||||
type MockSegment_Unpin_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Unpin is a helper method to define mock.On call
|
||||
func (_e *MockSegment_Expecter) Unpin() *MockSegment_Unpin_Call {
|
||||
return &MockSegment_Unpin_Call{Call: _e.mock.On("Unpin")}
|
||||
}
|
||||
|
||||
func (_c *MockSegment_Unpin_Call) Run(run func()) *MockSegment_Unpin_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegment_Unpin_Call) Return() *MockSegment_Unpin_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegment_Unpin_Call) RunAndReturn(run func()) *MockSegment_Unpin_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateBloomFilter provides a mock function with given fields: pks
|
||||
func (_m *MockSegment) UpdateBloomFilter(pks []storage.PrimaryKey) {
|
||||
_m.Called(pks)
|
||||
|
|
|
@ -3,7 +3,10 @@
|
|||
package segments
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
querypb "github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
|
@ -22,9 +25,9 @@ func (_m *MockSegmentManager) EXPECT() *MockSegmentManager_Expecter {
|
|||
return &MockSegmentManager_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Clear provides a mock function with given fields:
|
||||
func (_m *MockSegmentManager) Clear() {
|
||||
_m.Called()
|
||||
// Clear provides a mock function with given fields: ctx
|
||||
func (_m *MockSegmentManager) Clear(ctx context.Context) {
|
||||
_m.Called(ctx)
|
||||
}
|
||||
|
||||
// MockSegmentManager_Clear_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Clear'
|
||||
|
@ -33,13 +36,14 @@ type MockSegmentManager_Clear_Call struct {
|
|||
}
|
||||
|
||||
// Clear is a helper method to define mock.On call
|
||||
func (_e *MockSegmentManager_Expecter) Clear() *MockSegmentManager_Clear_Call {
|
||||
return &MockSegmentManager_Clear_Call{Call: _e.mock.On("Clear")}
|
||||
// - ctx context.Context
|
||||
func (_e *MockSegmentManager_Expecter) Clear(ctx interface{}) *MockSegmentManager_Clear_Call {
|
||||
return &MockSegmentManager_Clear_Call{Call: _e.mock.On("Clear", ctx)}
|
||||
}
|
||||
|
||||
func (_c *MockSegmentManager_Clear_Call) Run(run func()) *MockSegmentManager_Clear_Call {
|
||||
func (_c *MockSegmentManager_Clear_Call) Run(run func(ctx context.Context)) *MockSegmentManager_Clear_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
run(args[0].(context.Context))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
@ -49,7 +53,7 @@ func (_c *MockSegmentManager_Clear_Call) Return() *MockSegmentManager_Clear_Call
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegmentManager_Clear_Call) RunAndReturn(run func()) *MockSegmentManager_Clear_Call {
|
||||
func (_c *MockSegmentManager_Clear_Call) RunAndReturn(run func(context.Context)) *MockSegmentManager_Clear_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
@ -465,14 +469,14 @@ func (_c *MockSegmentManager_GetWithType_Call) RunAndReturn(run func(int64, comm
|
|||
return _c
|
||||
}
|
||||
|
||||
// Put provides a mock function with given fields: segmentType, segments
|
||||
func (_m *MockSegmentManager) Put(segmentType commonpb.SegmentState, segments ...Segment) {
|
||||
// Put provides a mock function with given fields: ctx, segmentType, segments
|
||||
func (_m *MockSegmentManager) Put(ctx context.Context, segmentType commonpb.SegmentState, segments ...Segment) {
|
||||
_va := make([]interface{}, len(segments))
|
||||
for _i := range segments {
|
||||
_va[_i] = segments[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, segmentType)
|
||||
_ca = append(_ca, ctx, segmentType)
|
||||
_ca = append(_ca, _va...)
|
||||
_m.Called(_ca...)
|
||||
}
|
||||
|
@ -483,22 +487,23 @@ type MockSegmentManager_Put_Call struct {
|
|||
}
|
||||
|
||||
// Put is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - segmentType commonpb.SegmentState
|
||||
// - segments ...Segment
|
||||
func (_e *MockSegmentManager_Expecter) Put(segmentType interface{}, segments ...interface{}) *MockSegmentManager_Put_Call {
|
||||
func (_e *MockSegmentManager_Expecter) Put(ctx interface{}, segmentType interface{}, segments ...interface{}) *MockSegmentManager_Put_Call {
|
||||
return &MockSegmentManager_Put_Call{Call: _e.mock.On("Put",
|
||||
append([]interface{}{segmentType}, segments...)...)}
|
||||
append([]interface{}{ctx, segmentType}, segments...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockSegmentManager_Put_Call) Run(run func(segmentType commonpb.SegmentState, segments ...Segment)) *MockSegmentManager_Put_Call {
|
||||
func (_c *MockSegmentManager_Put_Call) Run(run func(ctx context.Context, segmentType commonpb.SegmentState, segments ...Segment)) *MockSegmentManager_Put_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]Segment, len(args)-1)
|
||||
for i, a := range args[1:] {
|
||||
variadicArgs := make([]Segment, len(args)-2)
|
||||
for i, a := range args[2:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(Segment)
|
||||
}
|
||||
}
|
||||
run(args[0].(commonpb.SegmentState), variadicArgs...)
|
||||
run(args[0].(context.Context), args[1].(commonpb.SegmentState), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
@ -508,28 +513,28 @@ func (_c *MockSegmentManager_Put_Call) Return() *MockSegmentManager_Put_Call {
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegmentManager_Put_Call) RunAndReturn(run func(commonpb.SegmentState, ...Segment)) *MockSegmentManager_Put_Call {
|
||||
func (_c *MockSegmentManager_Put_Call) RunAndReturn(run func(context.Context, commonpb.SegmentState, ...Segment)) *MockSegmentManager_Put_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Remove provides a mock function with given fields: segmentID, scope
|
||||
func (_m *MockSegmentManager) Remove(segmentID int64, scope querypb.DataScope) (int, int) {
|
||||
ret := _m.Called(segmentID, scope)
|
||||
// Remove provides a mock function with given fields: ctx, segmentID, scope
|
||||
func (_m *MockSegmentManager) Remove(ctx context.Context, segmentID int64, scope querypb.DataScope) (int, int) {
|
||||
ret := _m.Called(ctx, segmentID, scope)
|
||||
|
||||
var r0 int
|
||||
var r1 int
|
||||
if rf, ok := ret.Get(0).(func(int64, querypb.DataScope) (int, int)); ok {
|
||||
return rf(segmentID, scope)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int64, querypb.DataScope) (int, int)); ok {
|
||||
return rf(ctx, segmentID, scope)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(int64, querypb.DataScope) int); ok {
|
||||
r0 = rf(segmentID, scope)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int64, querypb.DataScope) int); ok {
|
||||
r0 = rf(ctx, segmentID, scope)
|
||||
} else {
|
||||
r0 = ret.Get(0).(int)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(int64, querypb.DataScope) int); ok {
|
||||
r1 = rf(segmentID, scope)
|
||||
if rf, ok := ret.Get(1).(func(context.Context, int64, querypb.DataScope) int); ok {
|
||||
r1 = rf(ctx, segmentID, scope)
|
||||
} else {
|
||||
r1 = ret.Get(1).(int)
|
||||
}
|
||||
|
@ -543,15 +548,16 @@ type MockSegmentManager_Remove_Call struct {
|
|||
}
|
||||
|
||||
// Remove is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - segmentID int64
|
||||
// - scope querypb.DataScope
|
||||
func (_e *MockSegmentManager_Expecter) Remove(segmentID interface{}, scope interface{}) *MockSegmentManager_Remove_Call {
|
||||
return &MockSegmentManager_Remove_Call{Call: _e.mock.On("Remove", segmentID, scope)}
|
||||
func (_e *MockSegmentManager_Expecter) Remove(ctx interface{}, segmentID interface{}, scope interface{}) *MockSegmentManager_Remove_Call {
|
||||
return &MockSegmentManager_Remove_Call{Call: _e.mock.On("Remove", ctx, segmentID, scope)}
|
||||
}
|
||||
|
||||
func (_c *MockSegmentManager_Remove_Call) Run(run func(segmentID int64, scope querypb.DataScope)) *MockSegmentManager_Remove_Call {
|
||||
func (_c *MockSegmentManager_Remove_Call) Run(run func(ctx context.Context, segmentID int64, scope querypb.DataScope)) *MockSegmentManager_Remove_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(int64), args[1].(querypb.DataScope))
|
||||
run(args[0].(context.Context), args[1].(int64), args[2].(querypb.DataScope))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
@ -561,34 +567,35 @@ func (_c *MockSegmentManager_Remove_Call) Return(_a0 int, _a1 int) *MockSegmentM
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegmentManager_Remove_Call) RunAndReturn(run func(int64, querypb.DataScope) (int, int)) *MockSegmentManager_Remove_Call {
|
||||
func (_c *MockSegmentManager_Remove_Call) RunAndReturn(run func(context.Context, int64, querypb.DataScope) (int, int)) *MockSegmentManager_Remove_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RemoveBy provides a mock function with given fields: filters
|
||||
func (_m *MockSegmentManager) RemoveBy(filters ...SegmentFilter) (int, int) {
|
||||
// RemoveBy provides a mock function with given fields: ctx, filters
|
||||
func (_m *MockSegmentManager) RemoveBy(ctx context.Context, filters ...SegmentFilter) (int, int) {
|
||||
_va := make([]interface{}, len(filters))
|
||||
for _i := range filters {
|
||||
_va[_i] = filters[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx)
|
||||
_ca = append(_ca, _va...)
|
||||
ret := _m.Called(_ca...)
|
||||
|
||||
var r0 int
|
||||
var r1 int
|
||||
if rf, ok := ret.Get(0).(func(...SegmentFilter) (int, int)); ok {
|
||||
return rf(filters...)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, ...SegmentFilter) (int, int)); ok {
|
||||
return rf(ctx, filters...)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(...SegmentFilter) int); ok {
|
||||
r0 = rf(filters...)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, ...SegmentFilter) int); ok {
|
||||
r0 = rf(ctx, filters...)
|
||||
} else {
|
||||
r0 = ret.Get(0).(int)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(...SegmentFilter) int); ok {
|
||||
r1 = rf(filters...)
|
||||
if rf, ok := ret.Get(1).(func(context.Context, ...SegmentFilter) int); ok {
|
||||
r1 = rf(ctx, filters...)
|
||||
} else {
|
||||
r1 = ret.Get(1).(int)
|
||||
}
|
||||
|
@ -602,21 +609,22 @@ type MockSegmentManager_RemoveBy_Call struct {
|
|||
}
|
||||
|
||||
// RemoveBy is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - filters ...SegmentFilter
|
||||
func (_e *MockSegmentManager_Expecter) RemoveBy(filters ...interface{}) *MockSegmentManager_RemoveBy_Call {
|
||||
func (_e *MockSegmentManager_Expecter) RemoveBy(ctx interface{}, filters ...interface{}) *MockSegmentManager_RemoveBy_Call {
|
||||
return &MockSegmentManager_RemoveBy_Call{Call: _e.mock.On("RemoveBy",
|
||||
append([]interface{}{}, filters...)...)}
|
||||
append([]interface{}{ctx}, filters...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockSegmentManager_RemoveBy_Call) Run(run func(filters ...SegmentFilter)) *MockSegmentManager_RemoveBy_Call {
|
||||
func (_c *MockSegmentManager_RemoveBy_Call) Run(run func(ctx context.Context, filters ...SegmentFilter)) *MockSegmentManager_RemoveBy_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]SegmentFilter, len(args)-0)
|
||||
for i, a := range args[0:] {
|
||||
variadicArgs := make([]SegmentFilter, len(args)-1)
|
||||
for i, a := range args[1:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(SegmentFilter)
|
||||
}
|
||||
}
|
||||
run(variadicArgs...)
|
||||
run(args[0].(context.Context), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
@ -626,7 +634,7 @@ func (_c *MockSegmentManager_RemoveBy_Call) Return(_a0 int, _a1 int) *MockSegmen
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockSegmentManager_RemoveBy_Call) RunAndReturn(run func(...SegmentFilter) (int, int)) *MockSegmentManager_RemoveBy_Call {
|
||||
func (_c *MockSegmentManager_RemoveBy_Call) RunAndReturn(run func(context.Context, ...SegmentFilter) (int, int)) *MockSegmentManager_RemoveBy_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
|
|
@ -39,8 +39,11 @@ type SearchResult struct {
|
|||
cSearchResult C.CSearchResult
|
||||
}
|
||||
|
||||
// searchResultDataBlobs is the CSearchResultsDataBlobs in C++
|
||||
type searchResultDataBlobs = C.CSearchResultDataBlobs
|
||||
// SearchResultDataBlobs is the CSearchResultsDataBlobs in C++
|
||||
type (
|
||||
SearchResultDataBlobs = C.CSearchResultDataBlobs
|
||||
StreamSearchReducer = C.CSearchStreamReducer
|
||||
)
|
||||
|
||||
// RetrieveResult contains a pointer to the retrieve result in C++ memory
|
||||
type RetrieveResult struct {
|
||||
|
@ -71,9 +74,58 @@ func ParseSliceInfo(originNQs []int64, originTopKs []int64, nqPerSlice int64) *S
|
|||
return sInfo
|
||||
}
|
||||
|
||||
func NewStreamReducer(ctx context.Context,
|
||||
plan *SearchPlan,
|
||||
sliceNQs []int64,
|
||||
sliceTopKs []int64,
|
||||
) (StreamSearchReducer, error) {
|
||||
if plan.cSearchPlan == nil {
|
||||
return nil, fmt.Errorf("nil search plan")
|
||||
}
|
||||
if len(sliceNQs) == 0 {
|
||||
return nil, fmt.Errorf("empty slice nqs is not allowed")
|
||||
}
|
||||
if len(sliceNQs) != len(sliceTopKs) {
|
||||
return nil, fmt.Errorf("unaligned sliceNQs(len=%d) and sliceTopKs(len=%d)", len(sliceNQs), len(sliceTopKs))
|
||||
}
|
||||
cSliceNQSPtr := (*C.int64_t)(&sliceNQs[0])
|
||||
cSliceTopKSPtr := (*C.int64_t)(&sliceTopKs[0])
|
||||
cNumSlices := C.int64_t(len(sliceNQs))
|
||||
|
||||
var streamReducer StreamSearchReducer
|
||||
status := C.NewStreamReducer(plan.cSearchPlan, cSliceNQSPtr, cSliceTopKSPtr, cNumSlices, &streamReducer)
|
||||
if err := HandleCStatus(ctx, &status, "MergeSearchResultsWithOutputFields failed"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return streamReducer, nil
|
||||
}
|
||||
|
||||
func StreamReduceSearchResult(ctx context.Context,
|
||||
newResult *SearchResult, streamReducer StreamSearchReducer,
|
||||
) error {
|
||||
cSearchResults := make([]C.CSearchResult, 0)
|
||||
cSearchResults = append(cSearchResults, newResult.cSearchResult)
|
||||
cSearchResultPtr := &cSearchResults[0]
|
||||
|
||||
status := C.StreamReduce(streamReducer, cSearchResultPtr, 1)
|
||||
if err := HandleCStatus(ctx, &status, "StreamReduceSearchResult failed"); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetStreamReduceResult(ctx context.Context, streamReducer StreamSearchReducer) (SearchResultDataBlobs, error) {
|
||||
var cSearchResultDataBlobs SearchResultDataBlobs
|
||||
status := C.GetStreamReduceResult(streamReducer, &cSearchResultDataBlobs)
|
||||
if err := HandleCStatus(ctx, &status, "ReduceSearchResultsAndFillData failed"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cSearchResultDataBlobs, nil
|
||||
}
|
||||
|
||||
func ReduceSearchResultsAndFillData(ctx context.Context, plan *SearchPlan, searchResults []*SearchResult,
|
||||
numSegments int64, sliceNQs []int64, sliceTopKs []int64,
|
||||
) (searchResultDataBlobs, error) {
|
||||
) (SearchResultDataBlobs, error) {
|
||||
if plan.cSearchPlan == nil {
|
||||
return nil, fmt.Errorf("nil search plan")
|
||||
}
|
||||
|
@ -98,7 +150,7 @@ func ReduceSearchResultsAndFillData(ctx context.Context, plan *SearchPlan, searc
|
|||
cSliceNQSPtr := (*C.int64_t)(&sliceNQs[0])
|
||||
cSliceTopKSPtr := (*C.int64_t)(&sliceTopKs[0])
|
||||
cNumSlices := C.int64_t(len(sliceNQs))
|
||||
var cSearchResultDataBlobs searchResultDataBlobs
|
||||
var cSearchResultDataBlobs SearchResultDataBlobs
|
||||
status := C.ReduceSearchResultsAndFillData(&cSearchResultDataBlobs, plan.cSearchPlan, cSearchResultPtr,
|
||||
cNumSegments, cSliceNQSPtr, cSliceTopKSPtr, cNumSlices)
|
||||
if err := HandleCStatus(ctx, &status, "ReduceSearchResultsAndFillData failed"); err != nil {
|
||||
|
@ -107,7 +159,7 @@ func ReduceSearchResultsAndFillData(ctx context.Context, plan *SearchPlan, searc
|
|||
return cSearchResultDataBlobs, nil
|
||||
}
|
||||
|
||||
func GetSearchResultDataBlob(ctx context.Context, cSearchResultDataBlobs searchResultDataBlobs, blobIndex int) ([]byte, error) {
|
||||
func GetSearchResultDataBlob(ctx context.Context, cSearchResultDataBlobs SearchResultDataBlobs, blobIndex int) ([]byte, error) {
|
||||
var blob C.CProto
|
||||
status := C.GetSearchResultDataBlob(&blob, cSearchResultDataBlobs, C.int32_t(blobIndex))
|
||||
if err := HandleCStatus(ctx, &status, "marshal failed"); err != nil {
|
||||
|
@ -116,10 +168,14 @@ func GetSearchResultDataBlob(ctx context.Context, cSearchResultDataBlobs searchR
|
|||
return GetCProtoBlob(&blob), nil
|
||||
}
|
||||
|
||||
func DeleteSearchResultDataBlobs(cSearchResultDataBlobs searchResultDataBlobs) {
|
||||
func DeleteSearchResultDataBlobs(cSearchResultDataBlobs SearchResultDataBlobs) {
|
||||
C.DeleteSearchResultDataBlobs(cSearchResultDataBlobs)
|
||||
}
|
||||
|
||||
func DeleteStreamReduceHelper(cStreamReduceHelper StreamSearchReducer) {
|
||||
C.DeleteStreamSearchReducer(cStreamReduceHelper)
|
||||
}
|
||||
|
||||
func DeleteSearchResults(results []*SearchResult) {
|
||||
if len(results) == 0 {
|
||||
return
|
||||
|
|
|
@ -83,6 +83,7 @@ func (suite *ReduceSuite) SetupTest() {
|
|||
CollectionID: suite.collectionID,
|
||||
PartitionID: suite.partitionID,
|
||||
InsertChannel: "dml",
|
||||
NumOfRows: int64(msgLength),
|
||||
Level: datapb.SegmentLevel_Legacy,
|
||||
},
|
||||
)
|
||||
|
@ -104,7 +105,7 @@ func (suite *ReduceSuite) SetupTest() {
|
|||
}
|
||||
|
||||
func (suite *ReduceSuite) TearDownTest() {
|
||||
suite.segment.Release()
|
||||
suite.segment.Release(context.Background())
|
||||
DeleteCollection(suite.collection)
|
||||
ctx := context.Background()
|
||||
suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath)
|
||||
|
|
|
@ -520,16 +520,18 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
|
|||
|
||||
selected := make([]int, 0, ret.GetAllRetrieveCount())
|
||||
|
||||
var limit int = -1
|
||||
if param.limit != typeutil.Unlimited && !param.mergeStopForBest {
|
||||
loopEnd = int(param.limit)
|
||||
limit = int(param.limit)
|
||||
}
|
||||
|
||||
idSet := make(map[interface{}]struct{})
|
||||
cursors := make([]int64, len(validRetrieveResults))
|
||||
|
||||
var availableCount int
|
||||
var retSize int64
|
||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||
for j := 0; j < loopEnd; j++ {
|
||||
for j := 0; j < loopEnd && (limit == -1 || availableCount < limit); j++ {
|
||||
sel, drainOneResult := typeutil.SelectMinPK(param.limit, validRetrieveResults, cursors)
|
||||
if sel == -1 || (param.mergeStopForBest && drainOneResult) {
|
||||
break
|
||||
|
@ -542,6 +544,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
|
|||
selectedOffsets[sel] = append(selectedOffsets[sel], validRetrieveResults[sel].GetOffset()[cursors[sel]])
|
||||
selectedIndexes[sel] = append(selectedIndexes[sel], cursors[sel])
|
||||
idSet[pk] = struct{}{}
|
||||
availableCount++
|
||||
} else {
|
||||
// primary keys duplicate
|
||||
skipDupCnt++
|
||||
|
|
|
@ -20,8 +20,10 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
|
@ -44,11 +46,7 @@ func retrieveOnSegments(ctx context.Context, mgr *Manager, segments []Segment, s
|
|||
result *segcorepb.RetrieveResults
|
||||
segment Segment
|
||||
}
|
||||
var (
|
||||
resultCh = make(chan segmentResult, len(segments))
|
||||
errs = make([]error, len(segments))
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
resultCh := make(chan segmentResult, len(segments))
|
||||
|
||||
plan.ignoreNonPk = len(segments) > 1 && req.GetReq().GetLimit() != typeutil.Unlimited && plan.ShouldIgnoreNonPk()
|
||||
|
||||
|
@ -57,25 +55,29 @@ func retrieveOnSegments(ctx context.Context, mgr *Manager, segments []Segment, s
|
|||
label = metrics.GrowingSegmentLabel
|
||||
}
|
||||
|
||||
retriever := func(s Segment) error {
|
||||
retriever := func(ctx context.Context, s Segment) error {
|
||||
tr := timerecord.NewTimeRecorder("retrieveOnSegments")
|
||||
result, err := s.Retrieve(ctx, plan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resultCh <- segmentResult{
|
||||
result,
|
||||
s,
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
|
||||
metrics.QueryLabel, label).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
return nil
|
||||
}
|
||||
|
||||
for i, segment := range segments {
|
||||
wg.Add(1)
|
||||
go func(seg Segment, i int) {
|
||||
defer wg.Done()
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
for _, segment := range segments {
|
||||
seg := segment
|
||||
errGroup.Go(func() error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// record search time and cache miss
|
||||
var err error
|
||||
accessRecord := metricsutil.NewQuerySegmentAccessRecord(getSegmentMetricLabel(seg))
|
||||
|
@ -84,26 +86,26 @@ func retrieveOnSegments(ctx context.Context, mgr *Manager, segments []Segment, s
|
|||
}()
|
||||
|
||||
if seg.IsLazyLoad() {
|
||||
var timeout time.Duration
|
||||
timeout, err = lazyloadWaitTimeout(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var missing bool
|
||||
missing, err = mgr.DiskCache.Do(seg.ID(), retriever)
|
||||
missing, err = mgr.DiskCache.DoWait(ctx, seg.ID(), timeout, retriever)
|
||||
if missing {
|
||||
accessRecord.CacheMissing()
|
||||
}
|
||||
} else {
|
||||
err = retriever(seg)
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
errs[i] = err
|
||||
}
|
||||
}(segment, i)
|
||||
return retriever(ctx, seg)
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
err := errGroup.Wait()
|
||||
close(resultCh)
|
||||
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var retrieveSegments []Segment
|
||||
|
@ -163,9 +165,12 @@ func retrieveOnSegmentsWithStream(ctx context.Context, segments []Segment, segTy
|
|||
|
||||
// Retrieve retrieves results from all the validated target segments
|
||||
func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *querypb.QueryRequest) ([]*segcorepb.RetrieveResults, []Segment, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, nil, ctx.Err()
|
||||
}
|
||||
|
||||
var err error
|
||||
var SegType commonpb.SegmentState
|
||||
var retrieveResults []*segcorepb.RetrieveResults
|
||||
var retrieveSegments []Segment
|
||||
|
||||
segIDs := req.GetSegmentIDs()
|
||||
|
@ -181,7 +186,7 @@ func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *qu
|
|||
}
|
||||
|
||||
if err != nil {
|
||||
return retrieveResults, retrieveSegments, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return retrieveOnSegments(ctx, manager, retrieveSegments, SegType, plan, req)
|
||||
|
|
|
@ -92,6 +92,7 @@ func (suite *RetrieveSuite) SetupTest() {
|
|||
CollectionID: suite.collectionID,
|
||||
PartitionID: suite.partitionID,
|
||||
InsertChannel: "dml",
|
||||
NumOfRows: int64(msgLength),
|
||||
Level: datapb.SegmentLevel_Legacy,
|
||||
},
|
||||
)
|
||||
|
@ -132,13 +133,13 @@ func (suite *RetrieveSuite) SetupTest() {
|
|||
err = suite.growing.Insert(ctx, insertMsg.RowIDs, insertMsg.Timestamps, insertRecord)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
suite.manager.Segment.Put(SegmentTypeSealed, suite.sealed)
|
||||
suite.manager.Segment.Put(SegmentTypeGrowing, suite.growing)
|
||||
suite.manager.Segment.Put(context.Background(), SegmentTypeSealed, suite.sealed)
|
||||
suite.manager.Segment.Put(context.Background(), SegmentTypeGrowing, suite.growing)
|
||||
}
|
||||
|
||||
func (suite *RetrieveSuite) TearDownTest() {
|
||||
suite.sealed.Release()
|
||||
suite.growing.Release()
|
||||
suite.sealed.Release(context.Background())
|
||||
suite.growing.Release(context.Background())
|
||||
DeleteCollection(suite.collection)
|
||||
ctx := context.Background()
|
||||
suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath)
|
||||
|
@ -249,7 +250,7 @@ func (suite *RetrieveSuite) TestRetrieveNilSegment() {
|
|||
plan, err := genSimpleRetrievePlan(suite.collection)
|
||||
suite.NoError(err)
|
||||
|
||||
suite.sealed.Release()
|
||||
suite.sealed.Release(context.Background())
|
||||
req := &querypb.QueryRequest{
|
||||
Req: &internalpb.RetrieveRequest{
|
||||
CollectionID: suite.collectionID,
|
||||
|
|
|
@ -20,13 +20,18 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/params"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments/metricsutil"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
)
|
||||
|
@ -34,30 +39,20 @@ import (
|
|||
// searchSegments performs search on the listed segments
|
||||
// all segment ids are validated before calling this function
|
||||
func searchSegments(ctx context.Context, mgr *Manager, segments []Segment, segType SegmentType, searchReq *SearchRequest) ([]*SearchResult, error) {
|
||||
var (
|
||||
// results variables
|
||||
resultCh = make(chan *SearchResult, len(segments))
|
||||
errs = make([]error, len(segments))
|
||||
wg sync.WaitGroup
|
||||
|
||||
// For log only
|
||||
mu sync.Mutex
|
||||
segmentsWithoutIndex []int64
|
||||
)
|
||||
|
||||
searchLabel := metrics.SealedSegmentLabel
|
||||
if segType == commonpb.SegmentState_Growing {
|
||||
searchLabel = metrics.GrowingSegmentLabel
|
||||
}
|
||||
|
||||
searcher := func(s Segment) error {
|
||||
resultCh := make(chan *SearchResult, len(segments))
|
||||
searcher := func(ctx context.Context, s Segment) error {
|
||||
// record search time
|
||||
tr := timerecord.NewTimeRecorder("searchOnSegments")
|
||||
searchResult, err := s.Search(ctx, searchReq)
|
||||
resultCh <- searchResult
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resultCh <- searchResult
|
||||
// update metrics
|
||||
elapsed := tr.ElapseSpan().Milliseconds()
|
||||
metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
|
||||
|
@ -68,35 +63,41 @@ func searchSegments(ctx context.Context, mgr *Manager, segments []Segment, segTy
|
|||
}
|
||||
|
||||
// calling segment search in goroutines
|
||||
for i, segment := range segments {
|
||||
wg.Add(1)
|
||||
go func(seg Segment, i int) {
|
||||
defer wg.Done()
|
||||
if !seg.ExistIndex(searchReq.searchFieldID) {
|
||||
mu.Lock()
|
||||
segmentsWithoutIndex = append(segmentsWithoutIndex, seg.ID())
|
||||
mu.Unlock()
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
segmentsWithoutIndex := make([]int64, 0)
|
||||
for _, segment := range segments {
|
||||
seg := segment
|
||||
if !seg.ExistIndex(searchReq.searchFieldID) {
|
||||
segmentsWithoutIndex = append(segmentsWithoutIndex, seg.ID())
|
||||
}
|
||||
errGroup.Go(func() error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var err error
|
||||
accessRecord := metricsutil.NewSearchSegmentAccessRecord(getSegmentMetricLabel(seg))
|
||||
defer func() {
|
||||
accessRecord.Finish(err)
|
||||
}()
|
||||
|
||||
if seg.IsLazyLoad() {
|
||||
var timeout time.Duration
|
||||
timeout, err = lazyloadWaitTimeout(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var missing bool
|
||||
missing, err = mgr.DiskCache.Do(seg.ID(), searcher)
|
||||
missing, err = mgr.DiskCache.DoWait(ctx, seg.ID(), timeout, searcher)
|
||||
if missing {
|
||||
accessRecord.CacheMissing()
|
||||
}
|
||||
} else {
|
||||
err = searcher(seg)
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
errs[i] = err
|
||||
}
|
||||
}(segment, i)
|
||||
return searcher(ctx, seg)
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
err := errGroup.Wait()
|
||||
close(resultCh)
|
||||
|
||||
searchResults := make([]*SearchResult, 0, len(segments))
|
||||
|
@ -104,11 +105,9 @@ func searchSegments(ctx context.Context, mgr *Manager, segments []Segment, segTy
|
|||
searchResults = append(searchResults, result)
|
||||
}
|
||||
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
DeleteSearchResults(searchResults)
|
||||
return nil, err
|
||||
}
|
||||
if err != nil {
|
||||
DeleteSearchResults(searchResults)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(segmentsWithoutIndex) > 0 {
|
||||
|
@ -118,11 +117,115 @@ func searchSegments(ctx context.Context, mgr *Manager, segments []Segment, segTy
|
|||
return searchResults, nil
|
||||
}
|
||||
|
||||
// searchSegmentsStreamly performs search on listed segments in a stream mode instead of a batch mode
|
||||
// all segment ids are validated before calling this function
|
||||
func searchSegmentsStreamly(ctx context.Context,
|
||||
mgr *Manager,
|
||||
segments []Segment,
|
||||
searchReq *SearchRequest,
|
||||
streamReduce func(result *SearchResult) error,
|
||||
) error {
|
||||
searchLabel := metrics.SealedSegmentLabel
|
||||
searchResultsToClear := make([]*SearchResult, 0)
|
||||
var reduceMutex sync.Mutex
|
||||
var sumReduceDuration atomic.Duration
|
||||
searcher := func(ctx context.Context, seg Segment) error {
|
||||
// record search time
|
||||
tr := timerecord.NewTimeRecorder("searchOnSegments")
|
||||
searchResult, searchErr := seg.Search(ctx, searchReq)
|
||||
searchDuration := tr.RecordSpan().Milliseconds()
|
||||
if searchErr != nil {
|
||||
return searchErr
|
||||
}
|
||||
reduceMutex.Lock()
|
||||
searchResultsToClear = append(searchResultsToClear, searchResult)
|
||||
reducedErr := streamReduce(searchResult)
|
||||
reduceMutex.Unlock()
|
||||
reduceDuration := tr.RecordSpan()
|
||||
if reducedErr != nil {
|
||||
return reducedErr
|
||||
}
|
||||
sumReduceDuration.Add(reduceDuration)
|
||||
// update metrics
|
||||
metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
|
||||
metrics.SearchLabel, searchLabel).Observe(float64(searchDuration))
|
||||
metrics.QueryNodeSegmentSearchLatencyPerVector.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
|
||||
metrics.SearchLabel, searchLabel).Observe(float64(searchDuration) / float64(searchReq.getNumOfQuery()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// calling segment search in goroutines
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
log := log.Ctx(ctx)
|
||||
for _, segment := range segments {
|
||||
seg := segment
|
||||
errGroup.Go(func() error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var err error
|
||||
accessRecord := metricsutil.NewSearchSegmentAccessRecord(getSegmentMetricLabel(seg))
|
||||
defer func() {
|
||||
accessRecord.Finish(err)
|
||||
}()
|
||||
if seg.IsLazyLoad() {
|
||||
log.Debug("before doing stream search in DiskCache", zap.Int64("segID", seg.ID()))
|
||||
var timeout time.Duration
|
||||
timeout, err = lazyloadWaitTimeout(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var missing bool
|
||||
missing, err = mgr.DiskCache.DoWait(ctx, seg.ID(), timeout, searcher)
|
||||
if missing {
|
||||
accessRecord.CacheMissing()
|
||||
}
|
||||
if err != nil {
|
||||
log.Error("failed to do search for disk cache", zap.Int64("seg_id", seg.ID()), zap.Error(err))
|
||||
}
|
||||
log.Debug("after doing stream search in DiskCache", zap.Int64("segID", seg.ID()), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
return searcher(ctx, seg)
|
||||
})
|
||||
}
|
||||
err := errGroup.Wait()
|
||||
DeleteSearchResults(searchResultsToClear)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
|
||||
metrics.SearchLabel,
|
||||
metrics.ReduceSegments,
|
||||
metrics.StreamReduce).Observe(float64(sumReduceDuration.Load().Milliseconds()))
|
||||
log.Debug("stream reduce sum duration:", zap.Duration("duration", sumReduceDuration.Load()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func lazyloadWaitTimeout(ctx context.Context) (time.Duration, error) {
|
||||
timeout := params.Params.QueryNodeCfg.LazyLoadWaitTimeout.GetAsDuration(time.Millisecond)
|
||||
deadline, ok := ctx.Deadline()
|
||||
if ok {
|
||||
remain := time.Until(deadline)
|
||||
if remain <= 0 {
|
||||
return -1, merr.WrapErrServiceInternal("search context deadline exceeded")
|
||||
} else if remain < timeout {
|
||||
timeout = remain
|
||||
}
|
||||
}
|
||||
return timeout, nil
|
||||
}
|
||||
|
||||
// SearchHistorical searches the target segments among the historical segments.
|
||||
// if segIDs is not specified, it will search on all the historical segments specified by partIDs.
|
||||
// if segIDs is specified, it will only search on the segments specified by the segIDs.
|
||||
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
|
||||
func SearchHistorical(ctx context.Context, manager *Manager, searchReq *SearchRequest, collID int64, partIDs []int64, segIDs []int64) ([]*SearchResult, []Segment, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, nil, ctx.Err()
|
||||
}
|
||||
|
||||
segments, err := validateOnHistorical(ctx, manager, collID, partIDs, segIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
@ -134,6 +237,10 @@ func SearchHistorical(ctx context.Context, manager *Manager, searchReq *SearchRe
|
|||
// SearchStreaming searches all the target streaming segments.
|
||||
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
|
||||
func SearchStreaming(ctx context.Context, manager *Manager, searchReq *SearchRequest, collID int64, partIDs []int64, segIDs []int64) ([]*SearchResult, []Segment, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, nil, ctx.Err()
|
||||
}
|
||||
|
||||
segments, err := validateOnStream(ctx, manager, collID, partIDs, segIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
@ -141,3 +248,21 @@ func SearchStreaming(ctx context.Context, manager *Manager, searchReq *SearchReq
|
|||
searchResults, err := searchSegments(ctx, manager, segments, SegmentTypeGrowing, searchReq)
|
||||
return searchResults, segments, err
|
||||
}
|
||||
|
||||
func SearchHistoricalStreamly(ctx context.Context, manager *Manager, searchReq *SearchRequest,
|
||||
collID int64, partIDs []int64, segIDs []int64, streamReduce func(result *SearchResult) error,
|
||||
) ([]Segment, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
segments, err := validateOnHistorical(ctx, manager, collID, partIDs, segIDs)
|
||||
if err != nil {
|
||||
return segments, err
|
||||
}
|
||||
err = searchSegmentsStreamly(ctx, manager, segments, searchReq, streamReduce)
|
||||
if err != nil {
|
||||
return segments, err
|
||||
}
|
||||
return segments, nil
|
||||
}
|
||||
|
|
|
@ -83,6 +83,7 @@ func (suite *SearchSuite) SetupTest() {
|
|||
CollectionID: suite.collectionID,
|
||||
PartitionID: suite.partitionID,
|
||||
InsertChannel: "dml",
|
||||
NumOfRows: int64(msgLength),
|
||||
Level: datapb.SegmentLevel_Legacy,
|
||||
},
|
||||
)
|
||||
|
@ -122,12 +123,12 @@ func (suite *SearchSuite) SetupTest() {
|
|||
suite.Require().NoError(err)
|
||||
suite.growing.Insert(ctx, insertMsg.RowIDs, insertMsg.Timestamps, insertRecord)
|
||||
|
||||
suite.manager.Segment.Put(SegmentTypeSealed, suite.sealed)
|
||||
suite.manager.Segment.Put(SegmentTypeGrowing, suite.growing)
|
||||
suite.manager.Segment.Put(context.Background(), SegmentTypeSealed, suite.sealed)
|
||||
suite.manager.Segment.Put(context.Background(), SegmentTypeGrowing, suite.growing)
|
||||
}
|
||||
|
||||
func (suite *SearchSuite) TearDownTest() {
|
||||
suite.sealed.Release()
|
||||
suite.sealed.Release(context.Background())
|
||||
DeleteCollection(suite.collection)
|
||||
ctx := context.Background()
|
||||
suite.chunkManager.RemoveWithPrefix(ctx, paramtable.Get().MinioCfg.RootPath.GetValue())
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
|
@ -47,6 +48,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/segcorepb"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/params"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments/state"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
|
@ -73,50 +75,59 @@ var ErrSegmentUnhealthy = errors.New("segment unhealthy")
|
|||
type IndexedFieldInfo struct {
|
||||
FieldBinlog *datapb.FieldBinlog
|
||||
IndexInfo *querypb.FieldIndexInfo
|
||||
LazyLoad bool
|
||||
IsLoaded bool
|
||||
}
|
||||
|
||||
type baseSegment struct {
|
||||
collection *Collection
|
||||
version *atomic.Int64
|
||||
|
||||
// the load status of the segment,
|
||||
// only transitions below are allowed:
|
||||
// 1. LoadStatusMeta <-> LoadStatusMapped
|
||||
// 2. LoadStatusMeta <-> LoadStatusInMemory
|
||||
loadStatus *atomic.String
|
||||
segmentType SegmentType
|
||||
bloomFilterSet *pkoracle.BloomFilterSet
|
||||
loadInfo *querypb.SegmentLoadInfo
|
||||
loadInfo *atomic.Pointer[querypb.SegmentLoadInfo]
|
||||
isLazyLoad bool
|
||||
|
||||
resourceUsageCache *atomic.Pointer[ResourceUsage]
|
||||
|
||||
needUpdatedVersion *atomic.Int64 // only for lazy load mode update index
|
||||
}
|
||||
|
||||
func newBaseSegment(collection *Collection, segmentType SegmentType, version int64, loadInfo *querypb.SegmentLoadInfo) baseSegment {
|
||||
return baseSegment{
|
||||
collection: collection,
|
||||
loadInfo: loadInfo,
|
||||
version: atomic.NewInt64(version),
|
||||
loadStatus: atomic.NewString(string(LoadStatusMeta)),
|
||||
segmentType: segmentType,
|
||||
bloomFilterSet: pkoracle.NewBloomFilterSet(loadInfo.GetSegmentID(), loadInfo.GetPartitionID(), segmentType),
|
||||
|
||||
collection: collection,
|
||||
loadInfo: atomic.NewPointer[querypb.SegmentLoadInfo](loadInfo),
|
||||
version: atomic.NewInt64(version),
|
||||
isLazyLoad: isLazyLoad(collection, segmentType),
|
||||
segmentType: segmentType,
|
||||
bloomFilterSet: pkoracle.NewBloomFilterSet(loadInfo.GetSegmentID(), loadInfo.GetPartitionID(), segmentType),
|
||||
resourceUsageCache: atomic.NewPointer[ResourceUsage](nil),
|
||||
needUpdatedVersion: atomic.NewInt64(0),
|
||||
}
|
||||
}
|
||||
|
||||
// isLazyLoad checks if the segment is lazy load
|
||||
func isLazyLoad(collection *Collection, segmentType SegmentType) bool {
|
||||
return segmentType == SegmentTypeSealed && // only sealed segment enable lazy load
|
||||
(common.IsCollectionLazyLoadEnabled(collection.Schema().Properties...) || // collection level lazy load
|
||||
(!common.HasLazyload(collection.Schema().Properties) &&
|
||||
params.Params.QueryNodeCfg.LazyLoadEnabled.GetAsBool())) // global level lazy load
|
||||
}
|
||||
|
||||
// ID returns the identity number.
|
||||
func (s *baseSegment) ID() int64 {
|
||||
return s.loadInfo.GetSegmentID()
|
||||
return s.loadInfo.Load().GetSegmentID()
|
||||
}
|
||||
|
||||
func (s *baseSegment) Collection() int64 {
|
||||
return s.loadInfo.GetCollectionID()
|
||||
return s.loadInfo.Load().GetCollectionID()
|
||||
}
|
||||
|
||||
func (s *baseSegment) GetCollection() *Collection {
|
||||
return s.collection
|
||||
}
|
||||
|
||||
func (s *baseSegment) Partition() int64 {
|
||||
return s.loadInfo.GetPartitionID()
|
||||
return s.loadInfo.Load().GetPartitionID()
|
||||
}
|
||||
|
||||
func (s *baseSegment) DatabaseName() string {
|
||||
|
@ -128,7 +139,7 @@ func (s *baseSegment) ResourceGroup() string {
|
|||
}
|
||||
|
||||
func (s *baseSegment) Shard() string {
|
||||
return s.loadInfo.GetInsertChannel()
|
||||
return s.loadInfo.Load().GetInsertChannel()
|
||||
}
|
||||
|
||||
func (s *baseSegment) Type() SegmentType {
|
||||
|
@ -136,11 +147,11 @@ func (s *baseSegment) Type() SegmentType {
|
|||
}
|
||||
|
||||
func (s *baseSegment) Level() datapb.SegmentLevel {
|
||||
return s.loadInfo.GetLevel()
|
||||
return s.loadInfo.Load().GetLevel()
|
||||
}
|
||||
|
||||
func (s *baseSegment) StartPosition() *msgpb.MsgPosition {
|
||||
return s.loadInfo.GetStartPosition()
|
||||
return s.loadInfo.Load().GetStartPosition()
|
||||
}
|
||||
|
||||
func (s *baseSegment) Version() int64 {
|
||||
|
@ -151,16 +162,8 @@ func (s *baseSegment) CASVersion(old, newVersion int64) bool {
|
|||
return s.version.CompareAndSwap(old, newVersion)
|
||||
}
|
||||
|
||||
func (s *baseSegment) LoadStatus() LoadStatus {
|
||||
return LoadStatus(s.loadStatus.Load())
|
||||
}
|
||||
|
||||
func (s *baseSegment) LoadInfo() *querypb.SegmentLoadInfo {
|
||||
if s.segmentType == SegmentTypeGrowing {
|
||||
// Growing segment do not have load info.
|
||||
return nil
|
||||
}
|
||||
return s.loadInfo
|
||||
return s.loadInfo.Load()
|
||||
}
|
||||
|
||||
func (s *baseSegment) UpdateBloomFilter(pks []storage.PrimaryKey) {
|
||||
|
@ -185,7 +188,7 @@ func (s *baseSegment) ResourceUsageEstimate() ResourceUsage {
|
|||
return *cache
|
||||
}
|
||||
|
||||
usage, err := getResourceUsageEstimateOfSegment(s.collection.Schema(), s.loadInfo, resourceEstimateFactor{
|
||||
usage, err := getResourceUsageEstimateOfSegment(s.collection.Schema(), s.LoadInfo(), resourceEstimateFactor{
|
||||
memoryUsageFactor: 1.0,
|
||||
memoryIndexUsageFactor: 1.0,
|
||||
enableTempSegmentIndex: false,
|
||||
|
@ -200,7 +203,21 @@ func (s *baseSegment) ResourceUsageEstimate() ResourceUsage {
|
|||
return *usage
|
||||
}
|
||||
|
||||
func (s *baseSegment) IsLazyLoad() bool { return s.isLazyLoad }
|
||||
func (s *baseSegment) IsLazyLoad() bool {
|
||||
return s.isLazyLoad
|
||||
}
|
||||
|
||||
func (s *baseSegment) NeedUpdatedVersion() int64 {
|
||||
return s.needUpdatedVersion.Load()
|
||||
}
|
||||
|
||||
func (s *baseSegment) SetLoadInfo(loadInfo *querypb.SegmentLoadInfo) {
|
||||
s.loadInfo.Store(loadInfo)
|
||||
}
|
||||
|
||||
func (s *baseSegment) SetNeedUpdatedVersion(version int64) {
|
||||
s.needUpdatedVersion.Store(version)
|
||||
}
|
||||
|
||||
type FieldInfo struct {
|
||||
datapb.FieldBinlog
|
||||
|
@ -288,10 +305,9 @@ func NewSegment(ctx context.Context,
|
|||
insertCount: atomic.NewInt64(0),
|
||||
}
|
||||
|
||||
if segmentType != SegmentTypeSealed {
|
||||
segment.loadStatus.Store(string(LoadStatusInMemory))
|
||||
if err := segment.initializeSegment(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return segment, nil
|
||||
}
|
||||
|
||||
|
@ -355,11 +371,57 @@ func NewSegmentV2(
|
|||
insertCount: atomic.NewInt64(0),
|
||||
}
|
||||
|
||||
if segmentType != SegmentTypeSealed {
|
||||
segment.loadStatus.Store(string(LoadStatusInMemory))
|
||||
if err := segment.initializeSegment(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return segment, nil
|
||||
}
|
||||
|
||||
func (s *LocalSegment) initializeSegment() error {
|
||||
loadInfo := s.loadInfo.Load()
|
||||
indexedFieldInfos, fieldBinlogs := separateIndexAndBinlog(loadInfo)
|
||||
schemaHelper, _ := typeutil.CreateSchemaHelper(s.collection.Schema())
|
||||
|
||||
for fieldID, info := range indexedFieldInfos {
|
||||
field, err := schemaHelper.GetFieldFromID(fieldID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
indexInfo := info.IndexInfo
|
||||
s.fieldIndexes.Insert(indexInfo.GetFieldID(), &IndexedFieldInfo{
|
||||
FieldBinlog: &datapb.FieldBinlog{
|
||||
FieldID: indexInfo.GetFieldID(),
|
||||
},
|
||||
IndexInfo: indexInfo,
|
||||
IsLoaded: false,
|
||||
})
|
||||
if !typeutil.IsVectorType(field.GetDataType()) && !s.HasRawData(fieldID) {
|
||||
s.fields.Insert(fieldID, &FieldInfo{
|
||||
FieldBinlog: *info.FieldBinlog,
|
||||
RowCount: loadInfo.GetNumOfRows(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return segment, nil
|
||||
for _, binlogs := range fieldBinlogs {
|
||||
s.fields.Insert(binlogs.FieldID, &FieldInfo{
|
||||
FieldBinlog: *binlogs,
|
||||
RowCount: loadInfo.GetNumOfRows(),
|
||||
})
|
||||
}
|
||||
|
||||
// Update the insert count and the metrics when initializing the segment.
|
||||
s.insertCount.Store(loadInfo.GetNumOfRows())
|
||||
metrics.QueryNodeNumEntities.WithLabelValues(
|
||||
s.DatabaseName(),
|
||||
fmt.Sprint(paramtable.GetNodeID()),
|
||||
fmt.Sprint(s.Collection()),
|
||||
fmt.Sprint(s.Partition()),
|
||||
s.Type().String(),
|
||||
strconv.FormatInt(int64(len(s.Indexes())), 10),
|
||||
).Add(float64(loadInfo.GetNumOfRows()))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PinIfNotReleased acquires the `ptrLock` and returns true if the pointer is valid
|
||||
|
@ -426,10 +488,6 @@ func (s *LocalSegment) LastDeltaTimestamp() uint64 {
|
|||
return s.lastDeltaTimestamp.Load()
|
||||
}
|
||||
|
||||
func (s *LocalSegment) addIndex(fieldID int64, info *IndexedFieldInfo) {
|
||||
s.fieldIndexes.Insert(fieldID, info)
|
||||
}
|
||||
|
||||
func (s *LocalSegment) GetIndex(fieldID int64) *IndexedFieldInfo {
|
||||
info, _ := s.fieldIndexes.Get(fieldID)
|
||||
return info
|
||||
|
@ -464,7 +522,7 @@ func (s *LocalSegment) Indexes() []*IndexedFieldInfo {
|
|||
|
||||
func (s *LocalSegment) ResetIndexesLazyLoad(lazyState bool) {
|
||||
for _, indexInfo := range s.Indexes() {
|
||||
indexInfo.LazyLoad = lazyState
|
||||
indexInfo.IsLoaded = lazyState
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -720,6 +778,15 @@ func (s *LocalSegment) Insert(ctx context.Context, rowIDs []int64, timestamps []
|
|||
}
|
||||
|
||||
s.insertCount.Add(int64(numOfRow))
|
||||
metrics.QueryNodeNumEntities.WithLabelValues(
|
||||
s.DatabaseName(),
|
||||
fmt.Sprint(paramtable.GetNodeID()),
|
||||
fmt.Sprint(s.Collection()),
|
||||
fmt.Sprint(s.Partition()),
|
||||
s.Type().String(),
|
||||
strconv.FormatInt(int64(len(s.Indexes())), 10),
|
||||
).Add(float64(numOfRow))
|
||||
|
||||
s.rowNum.Store(-1)
|
||||
s.memSize.Store(-1)
|
||||
return nil
|
||||
|
@ -801,7 +868,11 @@ func (s *LocalSegment) Delete(ctx context.Context, primaryKeys []storage.Primary
|
|||
}
|
||||
|
||||
// -------------------------------------------------------------------------------------- interfaces for sealed segment
|
||||
func (s *LocalSegment) LoadMultiFieldData(ctx context.Context, rowCount int64, fields []*datapb.FieldBinlog) error {
|
||||
func (s *LocalSegment) LoadMultiFieldData(ctx context.Context) error {
|
||||
loadInfo := s.loadInfo.Load()
|
||||
rowCount := loadInfo.GetNumOfRows()
|
||||
fields := loadInfo.GetBinlogPaths()
|
||||
|
||||
if !s.ptrLock.RLockIf(state.IsNotReleased) {
|
||||
return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
|
||||
}
|
||||
|
@ -859,7 +930,6 @@ func (s *LocalSegment) LoadMultiFieldData(ctx context.Context, rowCount int64, f
|
|||
return err
|
||||
}
|
||||
|
||||
s.insertCount.Store(rowCount)
|
||||
log.Info("load mutil field done",
|
||||
zap.Int64("row count", rowCount),
|
||||
zap.Int64("segmentID", s.ID()))
|
||||
|
@ -867,43 +937,7 @@ func (s *LocalSegment) LoadMultiFieldData(ctx context.Context, rowCount int64, f
|
|||
return nil
|
||||
}
|
||||
|
||||
type loadOptions struct {
|
||||
LoadStatus LoadStatus
|
||||
}
|
||||
|
||||
func newLoadOptions() *loadOptions {
|
||||
return &loadOptions{
|
||||
LoadStatus: LoadStatusInMemory,
|
||||
}
|
||||
}
|
||||
|
||||
type loadOption func(*loadOptions)
|
||||
|
||||
func WithLoadStatus(loadStatus LoadStatus) loadOption {
|
||||
return func(options *loadOptions) {
|
||||
options.LoadStatus = loadStatus
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCount int64, field *datapb.FieldBinlog, opts ...loadOption) error {
|
||||
options := newLoadOptions()
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
if field != nil {
|
||||
s.fields.Insert(fieldID, &FieldInfo{
|
||||
FieldBinlog: *field,
|
||||
RowCount: rowCount,
|
||||
})
|
||||
}
|
||||
|
||||
if options.LoadStatus == LoadStatusMeta {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.loadStatus.Store(string(options.LoadStatus))
|
||||
|
||||
func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCount int64, field *datapb.FieldBinlog) error {
|
||||
if !s.ptrLock.RLockIf(state.IsNotReleased) {
|
||||
return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
|
||||
}
|
||||
|
@ -941,7 +975,9 @@ func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCoun
|
|||
}
|
||||
}
|
||||
|
||||
mmapEnabled := options.LoadStatus == LoadStatusMapped
|
||||
collection := s.collection
|
||||
mmapEnabled := common.IsFieldMmapEnabled(collection.Schema(), fieldID) ||
|
||||
(!common.FieldHasMmapKey(collection.Schema(), fieldID) && params.Params.QueryNodeCfg.MmapEnabled.GetAsBool())
|
||||
loadFieldDataInfo.appendMMapDirPath(paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue())
|
||||
loadFieldDataInfo.enableMmap(fieldID, mmapEnabled)
|
||||
|
||||
|
@ -970,7 +1006,6 @@ func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCoun
|
|||
return err
|
||||
}
|
||||
|
||||
s.insertCount.Store(rowCount)
|
||||
log.Info("load field done")
|
||||
|
||||
return nil
|
||||
|
@ -1194,12 +1229,7 @@ func (s *LocalSegment) LoadDeltaData(ctx context.Context, deltaData *storage.Del
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *LocalSegment) LoadIndex(ctx context.Context, indexInfo *querypb.FieldIndexInfo, fieldType schemapb.DataType, opts ...loadOption) error {
|
||||
options := newLoadOptions()
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
func (s *LocalSegment) LoadIndex(ctx context.Context, indexInfo *querypb.FieldIndexInfo, fieldType schemapb.DataType) error {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", s.Collection()),
|
||||
zap.Int64("partitionID", s.Partition()),
|
||||
|
@ -1208,19 +1238,9 @@ func (s *LocalSegment) LoadIndex(ctx context.Context, indexInfo *querypb.FieldIn
|
|||
zap.Int64("indexID", indexInfo.GetIndexID()),
|
||||
)
|
||||
|
||||
if options.LoadStatus == LoadStatusMeta {
|
||||
s.addIndex(indexInfo.GetFieldID(), &IndexedFieldInfo{
|
||||
FieldBinlog: &datapb.FieldBinlog{
|
||||
FieldID: indexInfo.GetFieldID(),
|
||||
},
|
||||
IndexInfo: indexInfo,
|
||||
LazyLoad: true,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
old := s.GetIndex(indexInfo.GetFieldID())
|
||||
// the index loaded
|
||||
if old != nil && old.IndexInfo.GetIndexID() == indexInfo.GetIndexID() && !old.LazyLoad {
|
||||
if old != nil && old.IndexInfo.GetIndexID() == indexInfo.GetIndexID() && old.IsLoaded {
|
||||
log.Warn("index already loaded")
|
||||
return nil
|
||||
}
|
||||
|
@ -1228,20 +1248,23 @@ func (s *LocalSegment) LoadIndex(ctx context.Context, indexInfo *querypb.FieldIn
|
|||
ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, fmt.Sprintf("LoadIndex-%d-%d", s.ID(), indexInfo.GetFieldID()))
|
||||
defer sp.End()
|
||||
|
||||
tr := timerecord.NewTimeRecorder("loadIndex")
|
||||
// 1.
|
||||
loadIndexInfo, err := newLoadIndexInfo(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer deleteLoadIndexInfo(loadIndexInfo)
|
||||
|
||||
if paramtable.Get().CommonCfg.EnableStorageV2.GetAsBool() {
|
||||
uri, err := typeutil_internal.GetStorageURI(paramtable.Get().CommonCfg.StorageScheme.GetValue(), paramtable.Get().CommonCfg.StoragePathPrefix.GetValue(), s.ID())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loadIndexInfo.appendStorageInfo(uri, indexInfo.IndexStoreVersion)
|
||||
}
|
||||
newLoadIndexInfoSpan := tr.RecordSpan()
|
||||
|
||||
// 2.
|
||||
err = loadIndexInfo.appendLoadIndexInfo(ctx, indexInfo, s.Collection(), s.Partition(), s.ID(), fieldType)
|
||||
if err != nil {
|
||||
if loadIndexInfo.cleanLocalData(ctx) != nil {
|
||||
|
@ -1255,16 +1278,27 @@ func (s *LocalSegment) LoadIndex(ctx context.Context, indexInfo *querypb.FieldIn
|
|||
errMsg := fmt.Sprintln("updateSegmentIndex failed, illegal segment type ", s.segmentType, "segmentID = ", s.ID())
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
appendLoadIndexInfoSpan := tr.RecordSpan()
|
||||
|
||||
// 3.
|
||||
err = s.UpdateIndexInfo(ctx, indexInfo, loadIndexInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateIndexInfoSpan := tr.RecordSpan()
|
||||
if !typeutil.IsVectorType(fieldType) || s.HasRawData(indexInfo.GetFieldID()) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 4.
|
||||
s.WarmupChunkCache(ctx, indexInfo.GetFieldID())
|
||||
warmupChunkCacheSpan := tr.RecordSpan()
|
||||
log.Info("Finish loading index",
|
||||
zap.Duration("newLoadIndexInfoSpan", newLoadIndexInfoSpan),
|
||||
zap.Duration("appendLoadIndexInfoSpan", appendLoadIndexInfoSpan),
|
||||
zap.Duration("updateIndexInfoSpan", updateIndexInfoSpan),
|
||||
zap.Duration("updateIndexInfoSpan", warmupChunkCacheSpan),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1294,12 +1328,12 @@ func (s *LocalSegment) UpdateIndexInfo(ctx context.Context, indexInfo *querypb.F
|
|||
return err
|
||||
}
|
||||
|
||||
s.addIndex(indexInfo.GetFieldID(), &IndexedFieldInfo{
|
||||
s.fieldIndexes.Insert(indexInfo.GetFieldID(), &IndexedFieldInfo{
|
||||
FieldBinlog: &datapb.FieldBinlog{
|
||||
FieldID: indexInfo.GetFieldID(),
|
||||
},
|
||||
IndexInfo: indexInfo,
|
||||
LazyLoad: false,
|
||||
IsLoaded: true,
|
||||
})
|
||||
log.Info("updateSegmentIndex done")
|
||||
return nil
|
||||
|
@ -1399,7 +1433,7 @@ func WithReleaseScope(scope ReleaseScope) releaseOption {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *LocalSegment) Release(opts ...releaseOption) {
|
||||
func (s *LocalSegment) Release(ctx context.Context, opts ...releaseOption) {
|
||||
options := newReleaseOptions()
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
|
@ -1411,7 +1445,7 @@ func (s *LocalSegment) Release(opts ...releaseOption) {
|
|||
// release will never fail
|
||||
defer stateLockGuard.Done(nil)
|
||||
|
||||
log := log.With(zap.Int64("collectionID", s.Collection()),
|
||||
log := log.Ctx(ctx).With(zap.Int64("collectionID", s.Collection()),
|
||||
zap.Int64("partitionID", s.Partition()),
|
||||
zap.Int64("segmentID", s.ID()),
|
||||
zap.String("segmentType", s.segmentType.String()),
|
||||
|
@ -1421,10 +1455,8 @@ func (s *LocalSegment) Release(opts ...releaseOption) {
|
|||
// wait all read ops finished
|
||||
ptr := s.ptr
|
||||
if options.Scope == ReleaseScopeData {
|
||||
s.loadStatus.Store(string(LoadStatusMeta))
|
||||
C.ClearSegmentData(ptr)
|
||||
s.ResetIndexesLazyLoad(true)
|
||||
log.Debug("release segment data done and the field indexes info has been set lazy load=true")
|
||||
s.ReleaseSegmentData()
|
||||
log.Info("release segment data done and the field indexes info has been set lazy load=true")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -1439,6 +1471,14 @@ func (s *LocalSegment) Release(opts ...releaseOption) {
|
|||
log.Info("delete segment from memory")
|
||||
}
|
||||
|
||||
// ReleaseSegmentData releases the segment data.
|
||||
func (s *LocalSegment) ReleaseSegmentData() {
|
||||
C.ClearSegmentData(s.ptr)
|
||||
for _, indexInfo := range s.Indexes() {
|
||||
indexInfo.IsLoaded = false
|
||||
}
|
||||
}
|
||||
|
||||
// StartLoadData starts the loading process of the segment.
|
||||
func (s *LocalSegment) StartLoadData() (state.LoadStateLockGuard, error) {
|
||||
return s.ptrLock.StartLoadData()
|
||||
|
@ -1455,3 +1495,34 @@ func (s *LocalSegment) startRelease(scope ReleaseScope) state.LoadStateLockGuard
|
|||
panic(fmt.Sprintf("unexpected release scope %d", scope))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LocalSegment) RemoveFieldFile(fieldId int64) {
|
||||
C.RemoveFieldFile(s.ptr, C.int64_t(fieldId))
|
||||
}
|
||||
|
||||
func (s *LocalSegment) RemoveUnusedFieldFiles() error {
|
||||
schema := s.collection.Schema()
|
||||
indexInfos, _ := separateIndexAndBinlog(s.LoadInfo())
|
||||
for _, indexInfo := range indexInfos {
|
||||
need, err := s.indexNeedLoadRawData(schema, indexInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !need {
|
||||
s.RemoveFieldFile(indexInfo.IndexInfo.FieldID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *LocalSegment) indexNeedLoadRawData(schema *schemapb.CollectionSchema, indexInfo *IndexedFieldInfo) (bool, error) {
|
||||
schemaHelper, err := typeutil.CreateSchemaHelper(schema)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
fieldSchema, err := schemaHelper.GetFieldFromID(indexInfo.IndexInfo.FieldID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return !typeutil.IsVectorType(fieldSchema.DataType) && s.HasRawData(indexInfo.IndexInfo.FieldID), nil
|
||||
}
|
||||
|
|
|
@ -27,14 +27,6 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type LoadStatus string
|
||||
|
||||
const (
|
||||
LoadStatusMeta LoadStatus = "meta"
|
||||
LoadStatusMapped LoadStatus = "mapped"
|
||||
LoadStatusInMemory LoadStatus = "in_memory"
|
||||
)
|
||||
|
||||
// ResourceUsage is used to estimate the resource usage of a sealed segment.
|
||||
type ResourceUsage struct {
|
||||
MemorySize uint64
|
||||
|
@ -60,7 +52,6 @@ type Segment interface {
|
|||
StartPosition() *msgpb.MsgPosition
|
||||
Type() SegmentType
|
||||
Level() datapb.SegmentLevel
|
||||
LoadStatus() LoadStatus
|
||||
LoadInfo() *querypb.SegmentLoadInfo
|
||||
// PinIfNotReleased the segment to prevent it from being released
|
||||
PinIfNotReleased() error
|
||||
|
@ -87,7 +78,7 @@ type Segment interface {
|
|||
Delete(ctx context.Context, primaryKeys []storage.PrimaryKey, timestamps []typeutil.Timestamp) error
|
||||
LoadDeltaData(ctx context.Context, deltaData *storage.DeleteData) error
|
||||
LastDeltaTimestamp() uint64
|
||||
Release(opts ...releaseOption)
|
||||
Release(ctx context.Context, opts ...releaseOption)
|
||||
|
||||
// Bloom filter related
|
||||
UpdateBloomFilter(pks []storage.PrimaryKey)
|
||||
|
@ -99,4 +90,8 @@ type Segment interface {
|
|||
RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error)
|
||||
IsLazyLoad() bool
|
||||
ResetIndexesLazyLoad(lazyState bool)
|
||||
|
||||
// lazy load related
|
||||
NeedUpdatedVersion() int64
|
||||
RemoveUnusedFieldFiles() error
|
||||
}
|
||||
|
|
|
@ -63,8 +63,6 @@ func NewL0Segment(collection *Collection,
|
|||
}
|
||||
|
||||
// level 0 segments are always in memory
|
||||
segment.loadStatus.Store(string(LoadStatusInMemory))
|
||||
|
||||
return segment, nil
|
||||
}
|
||||
|
||||
|
@ -164,10 +162,14 @@ func (s *L0Segment) DeleteRecords() ([]storage.PrimaryKey, []uint64) {
|
|||
return s.pks, s.tss
|
||||
}
|
||||
|
||||
func (s *L0Segment) Release(opts ...releaseOption) {
|
||||
func (s *L0Segment) Release(ctx context.Context, opts ...releaseOption) {
|
||||
s.dataGuard.Lock()
|
||||
defer s.dataGuard.Unlock()
|
||||
|
||||
s.pks = nil
|
||||
s.tss = nil
|
||||
}
|
||||
|
||||
func (s *L0Segment) RemoveUnusedFieldFiles() error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
|
|
@ -82,7 +82,11 @@ type Loader interface {
|
|||
LoadSegment(ctx context.Context,
|
||||
segment *LocalSegment,
|
||||
loadInfo *querypb.SegmentLoadInfo,
|
||||
loadStatus LoadStatus,
|
||||
) error
|
||||
|
||||
LoadLazySegment(ctx context.Context,
|
||||
segment *LocalSegment,
|
||||
loadInfo *querypb.SegmentLoadInfo,
|
||||
) error
|
||||
}
|
||||
|
||||
|
@ -171,7 +175,7 @@ func (loader *segmentLoaderV2) Load(ctx context.Context,
|
|||
loaded := typeutil.NewConcurrentMap[int64, Segment]()
|
||||
defer func() {
|
||||
newSegments.Range(func(_ int64, s Segment) bool {
|
||||
s.Release()
|
||||
s.Release(context.Background())
|
||||
return true
|
||||
})
|
||||
debug.FreeOSMemory()
|
||||
|
@ -214,7 +218,7 @@ func (loader *segmentLoaderV2) Load(ctx context.Context,
|
|||
if loadInfo.GetLevel() == datapb.SegmentLevel_L0 {
|
||||
err = loader.LoadDelta(ctx, collectionID, segment.(*LocalSegment))
|
||||
} else {
|
||||
err = loader.LoadSegment(ctx, segment.(*LocalSegment), loadInfo, LoadStatusInMemory)
|
||||
err = loader.LoadSegment(ctx, segment.(*LocalSegment), loadInfo)
|
||||
}
|
||||
if err != nil {
|
||||
log.Warn("load segment failed when load data into memory",
|
||||
|
@ -224,7 +228,7 @@ func (loader *segmentLoaderV2) Load(ctx context.Context,
|
|||
)
|
||||
return err
|
||||
}
|
||||
loader.manager.Segment.Put(segmentType, segment)
|
||||
loader.manager.Segment.Put(ctx, segmentType, segment)
|
||||
newSegments.GetAndRemove(segmentID)
|
||||
loaded.Insert(segmentID, segment)
|
||||
log.Info("load segment done", zap.Int64("segmentID", segmentID))
|
||||
|
@ -374,7 +378,6 @@ func (loader *segmentLoaderV2) loadBloomFilter(ctx context.Context, segmentID in
|
|||
func (loader *segmentLoaderV2) LoadSegment(ctx context.Context,
|
||||
segment *LocalSegment,
|
||||
loadInfo *querypb.SegmentLoadInfo,
|
||||
loadstatus LoadStatus,
|
||||
) (err error) {
|
||||
// TODO: we should create a transaction-like api to load segment for segment interface,
|
||||
// but not do many things in segment loader.
|
||||
|
@ -461,7 +464,7 @@ func (loader *segmentLoaderV2) LoadSegment(ctx context.Context,
|
|||
return err
|
||||
}
|
||||
} else {
|
||||
if err := segment.LoadMultiFieldData(ctx, loadInfo.GetNumOfRows(), loadInfo.BinlogPaths); err != nil {
|
||||
if err := segment.LoadMultiFieldData(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -480,6 +483,13 @@ func (loader *segmentLoaderV2) LoadSegment(ctx context.Context,
|
|||
return loader.LoadDelta(ctx, segment.Collection(), segment)
|
||||
}
|
||||
|
||||
func (loader *segmentLoaderV2) LoadLazySegment(ctx context.Context,
|
||||
segment *LocalSegment,
|
||||
loadInfo *querypb.SegmentLoadInfo,
|
||||
) (err error) {
|
||||
return merr.ErrOperationNotSupported
|
||||
}
|
||||
|
||||
func (loader *segmentLoaderV2) loadSealedSegmentFields(ctx context.Context, segment *LocalSegment, fields *typeutil.ConcurrentMap[int64, *schemapb.FieldSchema], rowCount int64) error {
|
||||
runningGroup, _ := errgroup.WithContext(ctx)
|
||||
fields.Range(func(fieldID int64, field *schemapb.FieldSchema) bool {
|
||||
|
@ -597,14 +607,20 @@ func (loader *segmentLoader) Load(ctx context.Context,
|
|||
// continue to wait other task done
|
||||
log.Info("start loading...", zap.Int("segmentNum", len(segments)), zap.Int("afterFilter", len(infos)))
|
||||
|
||||
// Check memory & storage limit
|
||||
resource, concurrencyLevel, err := loader.requestResource(ctx, infos...)
|
||||
if err != nil {
|
||||
log.Warn("request resource failed", zap.Error(err))
|
||||
return nil, err
|
||||
var err error
|
||||
var resource LoadResource
|
||||
var concurrencyLevel int
|
||||
coll := loader.manager.Collection.Get(collectionID)
|
||||
if !isLazyLoad(coll, segmentType) {
|
||||
// Check memory & storage limit
|
||||
// no need to check resource for lazy load here
|
||||
resource, concurrencyLevel, err = loader.requestResource(ctx, infos...)
|
||||
if err != nil {
|
||||
log.Warn("request resource failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
defer loader.freeRequest(resource)
|
||||
}
|
||||
defer loader.freeRequest(resource)
|
||||
|
||||
newSegments := typeutil.NewConcurrentMap[int64, Segment]()
|
||||
loaded := typeutil.NewConcurrentMap[int64, Segment]()
|
||||
defer func() {
|
||||
|
@ -613,13 +629,12 @@ func (loader *segmentLoader) Load(ctx context.Context,
|
|||
zap.Int64("segmentID", segmentID),
|
||||
zap.Error(err),
|
||||
)
|
||||
s.Release()
|
||||
s.Release(context.Background())
|
||||
return true
|
||||
})
|
||||
debug.FreeOSMemory()
|
||||
}()
|
||||
|
||||
loadStatus := LoadStatusInMemory
|
||||
collection := loader.manager.Collection.Get(collectionID)
|
||||
if collection == nil {
|
||||
err := merr.WrapErrCollectionNotFound(collectionID)
|
||||
|
@ -627,11 +642,6 @@ func (loader *segmentLoader) Load(ctx context.Context,
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if common.IsCollectionLazyLoadEnabled(collection.Schema().Properties...) ||
|
||||
(!common.HasLazyload(collection.Schema().Properties) && params.Params.QueryNodeCfg.LazyLoadEnabled.GetAsBool()) {
|
||||
loadStatus = LoadStatusMeta
|
||||
}
|
||||
|
||||
for _, info := range infos {
|
||||
loadInfo := info
|
||||
|
||||
|
@ -674,17 +684,21 @@ func (loader *segmentLoader) Load(ctx context.Context,
|
|||
tr := timerecord.NewTimeRecorder("loadDurationPerSegment")
|
||||
logger.Info("load segment...")
|
||||
|
||||
// L0 segment has no index or data to be load
|
||||
// L0 segment has no index or data to be load.
|
||||
if loadInfo.GetLevel() != datapb.SegmentLevel_L0 {
|
||||
if err = loader.LoadSegment(ctx, segment.(*LocalSegment), loadInfo, loadStatus); err != nil {
|
||||
return errors.Wrap(err, "At LoadSegment")
|
||||
s := segment.(*LocalSegment)
|
||||
// lazy load segment do not load segment at first time.
|
||||
if !s.IsLazyLoad() {
|
||||
if err = loader.LoadSegment(ctx, s, loadInfo); err != nil {
|
||||
return errors.Wrap(err, "At LoadSegment")
|
||||
}
|
||||
}
|
||||
}
|
||||
if err = loader.LoadDeltaLogs(ctx, segment, loadInfo.GetDeltalogs()); err != nil {
|
||||
return errors.Wrap(err, "At LoadDeltaLogs")
|
||||
}
|
||||
|
||||
loader.manager.Segment.Put(segmentType, segment)
|
||||
loader.manager.Segment.Put(ctx, segmentType, segment)
|
||||
newSegments.GetAndRemove(segmentID)
|
||||
loaded.Insert(segmentID, segment)
|
||||
loader.notifyLoadFinish(loadInfo)
|
||||
|
@ -975,23 +989,45 @@ func separateIndexAndBinlog(loadInfo *querypb.SegmentLoadInfo) (map[int64]*Index
|
|||
return indexedFieldInfos, fieldBinlogs
|
||||
}
|
||||
|
||||
func (loader *segmentLoader) loadSealedSegment(ctx context.Context, loadInfo *querypb.SegmentLoadInfo, segment *LocalSegment, collection *Collection, loadStatus LoadStatus) error {
|
||||
func (loader *segmentLoader) loadSealedSegment(ctx context.Context, loadInfo *querypb.SegmentLoadInfo, segment *LocalSegment) error {
|
||||
// TODO: we should create a transaction-like api to load segment for segment interface,
|
||||
// but not do many things in segment loader.
|
||||
stateLockGuard, err := segment.StartLoadData()
|
||||
// segment can not do load now.
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if stateLockGuard == nil {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
// Release partial loaded segment data if load failed.
|
||||
segment.ReleaseSegmentData()
|
||||
}
|
||||
stateLockGuard.Done(err)
|
||||
}()
|
||||
|
||||
collection := segment.GetCollection()
|
||||
|
||||
indexedFieldInfos, fieldBinlogs := separateIndexAndBinlog(loadInfo)
|
||||
schemaHelper, _ := typeutil.CreateSchemaHelper(collection.Schema())
|
||||
|
||||
if err := segment.AddFieldDataInfo(ctx, loadInfo.GetNumOfRows(), loadInfo.GetBinlogPaths()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("segmentID", segment.ID()))
|
||||
tr := timerecord.NewTimeRecorder("segmentLoader.LoadIndex")
|
||||
log.Info("load fields...",
|
||||
tr := timerecord.NewTimeRecorder("segmentLoader.loadSealedSegment")
|
||||
log.Info("Start loading fields...",
|
||||
zap.Int64s("indexedFields", lo.Keys(indexedFieldInfos)),
|
||||
)
|
||||
if err := loader.loadFieldsIndex(ctx, schemaHelper, segment, loadInfo.GetNumOfRows(), indexedFieldInfos, WithLoadStatus(loadStatus)); err != nil {
|
||||
if err := loader.loadFieldsIndex(ctx, schemaHelper, segment, loadInfo.GetNumOfRows(), indexedFieldInfos); err != nil {
|
||||
return err
|
||||
}
|
||||
metrics.QueryNodeLoadIndexLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
loadFieldsIndexSpan := tr.RecordSpan()
|
||||
metrics.QueryNodeLoadIndexLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(loadFieldsIndexSpan))
|
||||
|
||||
// 2. complement raw data for the scalar fields without raw data
|
||||
for fieldID, info := range indexedFieldInfos {
|
||||
field, err := schemaHelper.GetFieldFromID(fieldID)
|
||||
if err != nil {
|
||||
|
@ -1002,49 +1038,39 @@ func (loader *segmentLoader) loadSealedSegment(ctx context.Context, loadInfo *qu
|
|||
zap.Int64("fieldID", fieldID),
|
||||
zap.String("index", info.IndexInfo.GetIndexName()),
|
||||
)
|
||||
status := loadStatus
|
||||
if status != LoadStatusMeta {
|
||||
status = LoadStatusMapped
|
||||
}
|
||||
// for scalar index's raw data, only load to mmap not memory
|
||||
if err = segment.LoadFieldData(ctx, fieldID, loadInfo.GetNumOfRows(), info.FieldBinlog, WithLoadStatus(status)); err != nil {
|
||||
if err = segment.LoadFieldData(ctx, fieldID, loadInfo.GetNumOfRows(), info.FieldBinlog); err != nil {
|
||||
log.Warn("load raw data failed", zap.Int64("fieldID", fieldID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := loadSealedSegmentFields(ctx, collection, segment, fieldBinlogs, loadInfo.GetNumOfRows(), WithLoadStatus(loadStatus)); err != nil {
|
||||
complementScalarDataSpan := tr.RecordSpan()
|
||||
if err := loadSealedSegmentFields(ctx, collection, segment, fieldBinlogs, loadInfo.GetNumOfRows()); err != nil {
|
||||
return err
|
||||
}
|
||||
loadRawDataSpan := tr.RecordSpan()
|
||||
|
||||
// 4. rectify entries number for binlog in very rare cases
|
||||
// https://github.com/milvus-io/milvus/23654
|
||||
// legacy entry num = 0
|
||||
if err := loader.patchEntryNumber(ctx, segment, loadInfo); err != nil {
|
||||
return err
|
||||
}
|
||||
patchEntryNumberSpan := tr.RecordSpan()
|
||||
log.Info("Finish loading segment",
|
||||
zap.Duration("loadFieldsIndexSpan", loadFieldsIndexSpan),
|
||||
zap.Duration("complementScalarDataSpan", complementScalarDataSpan),
|
||||
zap.Duration("loadRawDataSpan", loadRawDataSpan),
|
||||
zap.Duration("patchEntryNumberSpan", patchEntryNumberSpan),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (loader *segmentLoader) LoadSegment(ctx context.Context,
|
||||
segment *LocalSegment,
|
||||
loadInfo *querypb.SegmentLoadInfo,
|
||||
loadStatus LoadStatus,
|
||||
) (err error) {
|
||||
// TODO: we should create a transaction-like api to load segment for segment interface,
|
||||
// but not do many things in segment loader.
|
||||
stateLockGuard, err := segment.StartLoadData()
|
||||
// segment can not do load now.
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
// segment is already loaded.
|
||||
// TODO: if stateLockGuard is nil, we should not call LoadSegment anymore.
|
||||
// but current Load is not clear enough to do an actual state transition, keep previous logic to avoid introduced bug.
|
||||
if stateLockGuard != nil {
|
||||
stateLockGuard.Done(err)
|
||||
}
|
||||
}()
|
||||
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", segment.Collection()),
|
||||
zap.Int64("partitionID", segment.Partition()),
|
||||
|
@ -1068,15 +1094,11 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context,
|
|||
defer debug.FreeOSMemory()
|
||||
|
||||
if segment.Type() == SegmentTypeSealed {
|
||||
if loadStatus == LoadStatusMeta {
|
||||
segment.baseSegment.isLazyLoad = true
|
||||
segment.baseSegment.loadInfo = loadInfo
|
||||
}
|
||||
if err := loader.loadSealedSegment(ctx, loadInfo, segment, collection, loadStatus); err != nil {
|
||||
if err := loader.loadSealedSegment(ctx, loadInfo, segment); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := segment.LoadMultiFieldData(ctx, loadInfo.GetNumOfRows(), loadInfo.BinlogPaths); err != nil {
|
||||
if err := segment.LoadMultiFieldData(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -1090,10 +1112,24 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context,
|
|||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (loader *segmentLoader) LoadLazySegment(ctx context.Context,
|
||||
segment *LocalSegment,
|
||||
loadInfo *querypb.SegmentLoadInfo,
|
||||
) (err error) {
|
||||
infos := []*querypb.SegmentLoadInfo{loadInfo}
|
||||
resource, _, err := loader.requestResource(ctx, infos...)
|
||||
log := log.Ctx(ctx)
|
||||
if err != nil {
|
||||
log.Warn("request resource failed", zap.Error(err))
|
||||
return merr.ErrServiceResourceInsufficient
|
||||
}
|
||||
defer loader.freeRequest(resource)
|
||||
return loader.LoadSegment(ctx, segment, loadInfo)
|
||||
}
|
||||
|
||||
func (loader *segmentLoader) filterPKStatsBinlogs(fieldBinlogs []*datapb.FieldBinlog, pkFieldID int64) ([]string, storage.StatsLogType) {
|
||||
result := make([]string, 0)
|
||||
for _, fieldBinlog := range fieldBinlogs {
|
||||
|
@ -1114,27 +1150,16 @@ func (loader *segmentLoader) filterPKStatsBinlogs(fieldBinlogs []*datapb.FieldBi
|
|||
return result, storage.DefaultStatsType
|
||||
}
|
||||
|
||||
func loadSealedSegmentFields(ctx context.Context, collection *Collection, segment *LocalSegment, fields []*datapb.FieldBinlog, rowCount int64, opts ...loadOption) error {
|
||||
options := newLoadOptions()
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
func loadSealedSegmentFields(ctx context.Context, collection *Collection, segment *LocalSegment, fields []*datapb.FieldBinlog, rowCount int64) error {
|
||||
runningGroup, _ := errgroup.WithContext(ctx)
|
||||
for _, field := range fields {
|
||||
opts := opts
|
||||
fieldBinLog := field
|
||||
fieldID := field.FieldID
|
||||
mmapEnabled := common.IsFieldMmapEnabled(collection.Schema(), fieldID) ||
|
||||
(!common.FieldHasMmapKey(collection.Schema(), fieldID) && params.Params.QueryNodeCfg.MmapEnabled.GetAsBool())
|
||||
if mmapEnabled && options.LoadStatus == LoadStatusInMemory {
|
||||
opts = append(opts, WithLoadStatus(LoadStatusMapped))
|
||||
}
|
||||
runningGroup.Go(func() error {
|
||||
return segment.LoadFieldData(ctx,
|
||||
fieldID,
|
||||
rowCount,
|
||||
fieldBinLog,
|
||||
opts...,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
@ -1157,7 +1182,6 @@ func (loader *segmentLoader) loadFieldsIndex(ctx context.Context,
|
|||
segment *LocalSegment,
|
||||
numRows int64,
|
||||
indexedFieldInfos map[int64]*IndexedFieldInfo,
|
||||
opts ...loadOption,
|
||||
) error {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", segment.Collection()),
|
||||
|
@ -1168,7 +1192,9 @@ func (loader *segmentLoader) loadFieldsIndex(ctx context.Context,
|
|||
|
||||
for fieldID, fieldInfo := range indexedFieldInfos {
|
||||
indexInfo := fieldInfo.IndexInfo
|
||||
err := loader.loadFieldIndex(ctx, segment, indexInfo, opts...)
|
||||
tr := timerecord.NewTimeRecorder("loadFieldIndex")
|
||||
err := loader.loadFieldIndex(ctx, segment, indexInfo)
|
||||
loadFieldIndexSpan := tr.RecordSpan()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1177,6 +1203,7 @@ func (loader *segmentLoader) loadFieldsIndex(ctx context.Context,
|
|||
zap.Int64("fieldID", fieldID),
|
||||
zap.Any("binlog", fieldInfo.FieldBinlog.Binlogs),
|
||||
zap.Int32("current_index_version", fieldInfo.IndexInfo.GetCurrentIndexVersion()),
|
||||
zap.Duration("load_duration", loadFieldIndexSpan),
|
||||
)
|
||||
|
||||
// set average row data size of variable field
|
||||
|
@ -1195,7 +1222,7 @@ func (loader *segmentLoader) loadFieldsIndex(ctx context.Context,
|
|||
return nil
|
||||
}
|
||||
|
||||
func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalSegment, indexInfo *querypb.FieldIndexInfo, opts ...loadOption) error {
|
||||
func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalSegment, indexInfo *querypb.FieldIndexInfo) error {
|
||||
filteredPaths := make([]string, 0, len(indexInfo.IndexFilePaths))
|
||||
|
||||
for _, indexPath := range indexInfo.IndexFilePaths {
|
||||
|
@ -1215,7 +1242,7 @@ func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalS
|
|||
return merr.WrapErrCollectionNotLoaded(segment.Collection(), "failed to load field index")
|
||||
}
|
||||
|
||||
return segment.LoadIndex(ctx, indexInfo, fieldType, opts...)
|
||||
return segment.LoadIndex(ctx, indexInfo, fieldType)
|
||||
}
|
||||
|
||||
func (loader *segmentLoader) loadBloomFilter(ctx context.Context, segmentID int64, bfs *pkoracle.BloomFilterSet,
|
||||
|
@ -1526,20 +1553,20 @@ func getResourceUsageEstimateOfSegment(schema *schemapb.CollectionSchema, loadIn
|
|||
loadInfo.GetSegmentID(),
|
||||
fieldIndexInfo.GetBuildID())
|
||||
}
|
||||
segmentMemorySize += neededMemSize
|
||||
if mmapEnabled {
|
||||
segmentDiskSize += neededMemSize + neededDiskSize
|
||||
} else {
|
||||
segmentMemorySize += neededMemSize
|
||||
segmentDiskSize += neededDiskSize
|
||||
}
|
||||
} else {
|
||||
mmapEnabled = common.IsFieldMmapEnabled(schema, fieldID) ||
|
||||
(!common.FieldHasMmapKey(schema, fieldID) && params.Params.QueryNodeCfg.MmapEnabled.GetAsBool())
|
||||
binlogSize := uint64(getBinlogDataSize(fieldBinlog))
|
||||
segmentMemorySize += binlogSize
|
||||
if mmapEnabled {
|
||||
segmentDiskSize += binlogSize
|
||||
} else {
|
||||
segmentMemorySize += binlogSize
|
||||
if multiplyFactor.enableTempSegmentIndex {
|
||||
segmentMemorySize += uint64(float64(binlogSize) * multiplyFactor.tempSegmentIndexFactor)
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ import (
|
|||
"github.com/apache/arrow/go/v12/arrow/memory"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
|
@ -97,7 +98,7 @@ func (suite *SegmentLoaderSuite) SetupTest() {
|
|||
func (suite *SegmentLoaderSuite) TearDownTest() {
|
||||
ctx := context.Background()
|
||||
for i := 0; i < suite.segmentNum; i++ {
|
||||
suite.manager.Segment.Remove(suite.segmentID+int64(i), querypb.DataScope_All)
|
||||
suite.manager.Segment.Remove(context.Background(), suite.segmentID+int64(i), querypb.DataScope_All)
|
||||
}
|
||||
suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath)
|
||||
}
|
||||
|
@ -450,7 +451,7 @@ func (suite *SegmentLoaderSuite) TestLoadDupDeltaLogs() {
|
|||
seg := segment.(*LocalSegment)
|
||||
// nothing would happen as the delta logs have been all applied,
|
||||
// so the released segment won't cause error
|
||||
seg.Release()
|
||||
seg.Release(ctx)
|
||||
loadInfos[i].Deltalogs[0].Binlogs[0].TimestampTo--
|
||||
err := suite.loader.LoadDeltaLogs(ctx, seg, loadInfos[i].GetDeltalogs())
|
||||
suite.NoError(err)
|
||||
|
@ -459,7 +460,6 @@ func (suite *SegmentLoaderSuite) TestLoadDupDeltaLogs() {
|
|||
|
||||
func (suite *SegmentLoaderSuite) TestLoadIndex() {
|
||||
ctx := context.Background()
|
||||
segment := &LocalSegment{}
|
||||
loadInfo := &querypb.SegmentLoadInfo{
|
||||
SegmentID: 1,
|
||||
PartitionID: suite.partitionID,
|
||||
|
@ -470,6 +470,11 @@ func (suite *SegmentLoaderSuite) TestLoadIndex() {
|
|||
},
|
||||
},
|
||||
}
|
||||
segment := &LocalSegment{
|
||||
baseSegment: baseSegment{
|
||||
loadInfo: atomic.NewPointer[querypb.SegmentLoadInfo](loadInfo),
|
||||
},
|
||||
}
|
||||
|
||||
err := suite.loader.LoadIndex(ctx, segment, loadInfo, 0)
|
||||
suite.ErrorIs(err, merr.ErrIndexNotFound)
|
||||
|
@ -866,7 +871,7 @@ func (suite *SegmentLoaderV2Suite) SetupTest() {
|
|||
func (suite *SegmentLoaderV2Suite) TearDownTest() {
|
||||
ctx := context.Background()
|
||||
for i := 0; i < suite.segmentNum; i++ {
|
||||
suite.manager.Segment.Remove(suite.segmentID+int64(i), querypb.DataScope_All)
|
||||
suite.manager.Segment.Remove(context.Background(), suite.segmentID+int64(i), querypb.DataScope_All)
|
||||
}
|
||||
suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath)
|
||||
paramtable.Get().CommonCfg.EnableStorageV2.SwapTempValue("false")
|
||||
|
|
|
@ -71,6 +71,7 @@ func (suite *SegmentSuite) SetupTest() {
|
|||
PartitionID: suite.partitionID,
|
||||
InsertChannel: "dml",
|
||||
Level: datapb.SegmentLevel_Legacy,
|
||||
NumOfRows: int64(msgLength),
|
||||
BinlogPaths: []*datapb.FieldBinlog{
|
||||
{
|
||||
FieldID: 101,
|
||||
|
@ -123,14 +124,14 @@ func (suite *SegmentSuite) SetupTest() {
|
|||
err = suite.growing.Insert(ctx, insertMsg.RowIDs, insertMsg.Timestamps, insertRecord)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
suite.manager.Segment.Put(SegmentTypeSealed, suite.sealed)
|
||||
suite.manager.Segment.Put(SegmentTypeGrowing, suite.growing)
|
||||
suite.manager.Segment.Put(context.Background(), SegmentTypeSealed, suite.sealed)
|
||||
suite.manager.Segment.Put(context.Background(), SegmentTypeGrowing, suite.growing)
|
||||
}
|
||||
|
||||
func (suite *SegmentSuite) TearDownTest() {
|
||||
ctx := context.Background()
|
||||
suite.sealed.Release()
|
||||
suite.growing.Release()
|
||||
suite.sealed.Release(context.Background())
|
||||
suite.growing.Release(context.Background())
|
||||
DeleteCollection(suite.collection)
|
||||
suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath)
|
||||
}
|
||||
|
@ -139,7 +140,7 @@ func (suite *SegmentSuite) TestLoadInfo() {
|
|||
// sealed segment has load info
|
||||
suite.NotNil(suite.sealed.LoadInfo())
|
||||
// growing segment has no load info
|
||||
suite.Nil(suite.growing.LoadInfo())
|
||||
suite.NotNil(suite.growing.LoadInfo())
|
||||
}
|
||||
|
||||
func (suite *SegmentSuite) TestResourceUsageEstimate() {
|
||||
|
@ -196,8 +197,11 @@ func (suite *SegmentSuite) TestCASVersion() {
|
|||
suite.Equal(curVersion+1, segment.Version())
|
||||
}
|
||||
|
||||
func (suite *SegmentSuite) TestSegmentRemoveUnusedFieldFiles() {
|
||||
}
|
||||
|
||||
func (suite *SegmentSuite) TestSegmentReleased() {
|
||||
suite.sealed.Release()
|
||||
suite.sealed.Release(context.Background())
|
||||
|
||||
sealed := suite.sealed.(*LocalSegment)
|
||||
|
||||
|
|
|
@ -182,3 +182,13 @@ func getSegmentMetricLabel(segment Segment) metricsutil.SegmentLabel {
|
|||
ResourceGroup: segment.ResourceGroup(),
|
||||
}
|
||||
}
|
||||
|
||||
// FilterZeroValuesFromSlice returns the non-zero elements of intVals in their
// original order. The input slice is not modified. When no element is
// non-zero (including a nil or empty input) the result is a nil slice.
func FilterZeroValuesFromSlice(intVals []int64) []int64 {
	var nonZero []int64
	for idx := 0; idx < len(intVals); idx++ {
		if intVals[idx] == 0 {
			continue
		}
		nonZero = append(nonZero, intVals[idx])
	}
	return nonZero
}
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
package segments
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFilterZeroValuesFromSlice(t *testing.T) {
|
||||
var ints []int64
|
||||
ints = append(ints, 10)
|
||||
ints = append(ints, 0)
|
||||
ints = append(ints, 5)
|
||||
ints = append(ints, 13)
|
||||
ints = append(ints, 0)
|
||||
|
||||
filteredInts := FilterZeroValuesFromSlice(ints)
|
||||
assert.Equal(t, 3, len(filteredInts))
|
||||
assert.EqualValues(t, []int64{10, 5, 13}, filteredInts)
|
||||
}
|
|
@ -484,7 +484,7 @@ func (node *QueryNode) Stop() error {
|
|||
node.dispClient.Close()
|
||||
}
|
||||
if node.manager != nil {
|
||||
node.manager.Segment.Clear()
|
||||
node.manager.Segment.Clear(context.Background())
|
||||
}
|
||||
|
||||
node.CloseSegcore()
|
||||
|
|
|
@ -236,7 +236,7 @@ func (suite *QueryNodeSuite) TestStop() {
|
|||
},
|
||||
)
|
||||
suite.NoError(err)
|
||||
suite.node.manager.Segment.Put(segments.SegmentTypeSealed, segment)
|
||||
suite.node.manager.Segment.Put(context.Background(), segments.SegmentTypeSealed, segment)
|
||||
err = suite.node.Stop()
|
||||
suite.NoError(err)
|
||||
suite.True(suite.node.manager.Segment.Empty())
|
||||
|
|
|
@ -301,7 +301,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
|
|||
defer func() {
|
||||
if err != nil {
|
||||
// remove legacy growing
|
||||
node.manager.Segment.RemoveBy(segments.WithChannel(channel.GetChannelName()),
|
||||
node.manager.Segment.RemoveBy(ctx, segments.WithChannel(channel.GetChannelName()),
|
||||
segments.WithType(segments.SegmentTypeGrowing))
|
||||
}
|
||||
}()
|
||||
|
@ -357,8 +357,8 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
|
|||
delegator.Close()
|
||||
|
||||
node.pipelineManager.Remove(req.GetChannelName())
|
||||
node.manager.Segment.RemoveBy(segments.WithChannel(req.GetChannelName()), segments.WithType(segments.SegmentTypeGrowing))
|
||||
node.manager.Segment.RemoveBy(segments.WithChannel(req.GetChannelName()), segments.WithLevel(datapb.SegmentLevel_L0))
|
||||
node.manager.Segment.RemoveBy(ctx, segments.WithChannel(req.GetChannelName()), segments.WithType(segments.SegmentTypeGrowing))
|
||||
node.manager.Segment.RemoveBy(ctx, segments.WithChannel(req.GetChannelName()), segments.WithLevel(datapb.SegmentLevel_L0))
|
||||
node.tSafeManager.Remove(ctx, req.GetChannelName())
|
||||
|
||||
node.manager.Collection.Unref(req.GetCollectionID(), 1)
|
||||
|
@ -547,7 +547,7 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, req *querypb.Release
|
|||
log.Info("start to release segments")
|
||||
sealedCount := 0
|
||||
for _, id := range req.GetSegmentIDs() {
|
||||
_, count := node.manager.Segment.Remove(id, req.GetScope())
|
||||
_, count := node.manager.Segment.Remove(ctx, id, req.GetScope())
|
||||
sealedCount += count
|
||||
}
|
||||
node.manager.Collection.Unref(req.GetCollectionID(), uint32(sealedCount))
|
||||
|
@ -660,7 +660,13 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe
|
|||
return resp, nil
|
||||
}
|
||||
|
||||
task := tasks.NewSearchTask(searchCtx, collection, node.manager, req, node.serverID)
|
||||
var task tasks.Task
|
||||
if paramtable.Get().QueryNodeCfg.UseStreamComputing.GetAsBool() {
|
||||
task = tasks.NewStreamingSearchTask(searchCtx, collection, node.manager, req, node.serverID)
|
||||
} else {
|
||||
task = tasks.NewSearchTask(searchCtx, collection, node.manager, req, node.serverID)
|
||||
}
|
||||
|
||||
if err := node.scheduler.Add(task); err != nil {
|
||||
log.Warn("failed to search channel", zap.Error(err))
|
||||
resp.Status = merr.Status(err)
|
||||
|
@ -683,7 +689,7 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe
|
|||
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
|
||||
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel, metrics.FromLeader).Inc()
|
||||
|
||||
resp = task.Result()
|
||||
resp = task.SearchResult()
|
||||
resp.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
|
||||
resp.GetCostAggregation().TotalNQ = node.scheduler.GetWaitingTaskTotalNQ()
|
||||
return resp, nil
|
||||
|
@ -767,7 +773,8 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
|
|||
}
|
||||
|
||||
reduceLatency := tr.RecordSpan()
|
||||
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards).
|
||||
metrics.QueryNodeReduceLatency.
|
||||
WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards, metrics.BatchReduce).
|
||||
Observe(float64(reduceLatency.Milliseconds()))
|
||||
|
||||
collector.Rate.Add(metricsinfo.NQPerSecond, float64(req.GetReq().GetNq()))
|
||||
|
@ -914,7 +921,8 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
|
|||
}, nil
|
||||
}
|
||||
reduceLatency := tr.RecordSpan()
|
||||
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.ReduceShards).
|
||||
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()),
|
||||
metrics.QueryLabel, metrics.ReduceShards, metrics.BatchReduce).
|
||||
Observe(float64(reduceLatency.Milliseconds()))
|
||||
|
||||
collector.Rate.Add(metricsinfo.NQPerSecond, 1)
|
||||
|
|
|
@ -160,7 +160,6 @@ func (suite *ServiceSuite) TearDownTest() {
|
|||
suite.NoError(err)
|
||||
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
|
||||
suite.node.chunkManager.RemoveWithPrefix(ctx, suite.rootPath)
|
||||
|
||||
suite.node.Stop()
|
||||
suite.etcdClient.Close()
|
||||
}
|
||||
|
@ -474,9 +473,9 @@ func (suite *ServiceSuite) TestUnsubDmChannels_Normal() {
|
|||
l0Segment.EXPECT().Type().Return(commonpb.SegmentState_Sealed)
|
||||
l0Segment.EXPECT().Indexes().Return(nil)
|
||||
l0Segment.EXPECT().Shard().Return(suite.vchannel)
|
||||
l0Segment.EXPECT().Release().Return()
|
||||
l0Segment.EXPECT().Release(ctx).Return()
|
||||
|
||||
suite.node.manager.Segment.Put(segments.SegmentTypeSealed, l0Segment)
|
||||
suite.node.manager.Segment.Put(ctx, segments.SegmentTypeSealed, l0Segment)
|
||||
|
||||
// data
|
||||
req := &querypb.UnsubDmChannelRequest{
|
||||
|
@ -1364,6 +1363,48 @@ func (suite *ServiceSuite) TestSearchSegments_Normal() {
|
|||
suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode())
|
||||
}
|
||||
|
||||
func (suite *ServiceSuite) TestStreamingSearch() {
|
||||
ctx := context.Background()
|
||||
// pre
|
||||
suite.TestWatchDmChannelsInt64()
|
||||
suite.TestLoadSegments_Int64()
|
||||
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.UseStreamComputing.Key, "true")
|
||||
creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType)
|
||||
req := &querypb.SearchRequest{
|
||||
Req: creq,
|
||||
FromShardLeader: true,
|
||||
DmlChannels: []string{suite.vchannel},
|
||||
TotalChannelNum: 2,
|
||||
SegmentIDs: suite.validSegmentIDs,
|
||||
Scope: querypb.DataScope_Historical,
|
||||
}
|
||||
suite.NoError(err)
|
||||
|
||||
rsp, err := suite.node.SearchSegments(ctx, req)
|
||||
suite.NoError(err)
|
||||
suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode())
|
||||
}
|
||||
|
||||
func (suite *ServiceSuite) TestStreamingSearchGrowing() {
|
||||
ctx := context.Background()
|
||||
// pre
|
||||
suite.TestWatchDmChannelsInt64()
|
||||
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.UseStreamComputing.Key, "true")
|
||||
creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType)
|
||||
req := &querypb.SearchRequest{
|
||||
Req: creq,
|
||||
FromShardLeader: true,
|
||||
DmlChannels: []string{suite.vchannel},
|
||||
TotalChannelNum: 2,
|
||||
Scope: querypb.DataScope_Streaming,
|
||||
}
|
||||
suite.NoError(err)
|
||||
|
||||
rsp, err := suite.node.SearchSegments(ctx, req)
|
||||
suite.NoError(err)
|
||||
suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode())
|
||||
}
|
||||
|
||||
// Test Query
|
||||
func (suite *ServiceSuite) genCQueryRequest(nq int64, indexType string, schema *schemapb.CollectionSchema) (*internalpb.RetrieveRequest, error) {
|
||||
expr, err := genSimpleRetrievePlanExpr(schema)
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
)
|
||||
|
||||
|
@ -114,6 +115,10 @@ func (t *MockTask) MergeWith(t2 Task) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (t *MockTask) SearchResult() *internalpb.SearchResults {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *MockTask) NQ() int64 {
|
||||
return t.nq
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package tasks
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/util/streamrpc"
|
||||
|
@ -86,3 +87,7 @@ func (t *QueryStreamTask) Wait() error {
|
|||
func (t *QueryStreamTask) NQ() int64 {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (t *QueryStreamTask) SearchResult() *internalpb.SearchResults {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -91,6 +91,10 @@ func (t *QueryTask) PreExecute() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *QueryTask) SearchResult() *internalpb.SearchResults {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute the task, only call once.
|
||||
func (t *QueryTask) Execute() error {
|
||||
if t.scheduleSpan != nil {
|
||||
|
@ -125,7 +129,8 @@ func (t *QueryTask) Execute() error {
|
|||
metrics.QueryNodeReduceLatency.WithLabelValues(
|
||||
fmt.Sprint(paramtable.GetNodeID()),
|
||||
metrics.QueryLabel,
|
||||
metrics.ReduceSegments).Observe(float64(time.Since(beforeReduce).Milliseconds()))
|
||||
metrics.ReduceSegments,
|
||||
metrics.BatchReduce).Observe(float64(time.Since(beforeReduce).Milliseconds()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -2,6 +2,8 @@ package tasks
|
|||
|
||||
// TODO: rename this file into search_task.go
|
||||
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
|
@ -233,7 +235,8 @@ func (t *SearchTask) Execute() error {
|
|||
metrics.QueryNodeReduceLatency.WithLabelValues(
|
||||
fmt.Sprint(t.GetNodeID()),
|
||||
metrics.SearchLabel,
|
||||
metrics.ReduceSegments).
|
||||
metrics.ReduceSegments,
|
||||
metrics.BatchReduce).
|
||||
Observe(float64(tr.RecordSpan().Milliseconds()))
|
||||
for i := range t.originNqs {
|
||||
blob, err := segments.GetSearchResultDataBlob(t.ctx, blobs, i)
|
||||
|
@ -333,7 +336,7 @@ func (t *SearchTask) Wait() error {
|
|||
return <-t.notifier
|
||||
}
|
||||
|
||||
func (t *SearchTask) Result() *internalpb.SearchResults {
|
||||
func (t *SearchTask) SearchResult() *internalpb.SearchResults {
|
||||
if t.result != nil {
|
||||
channelsMvcc := make(map[string]uint64)
|
||||
for _, ch := range t.req.GetDmlChannels() {
|
||||
|
@ -382,3 +385,214 @@ func (t *SearchTask) combinePlaceHolderGroups() error {
|
|||
t.placeholderGroup, _ = proto.Marshal(ret)
|
||||
return nil
|
||||
}
|
||||
|
||||
type StreamingSearchTask struct {
|
||||
SearchTask
|
||||
others []*StreamingSearchTask
|
||||
resultBlobs segments.SearchResultDataBlobs
|
||||
streamReducer segments.StreamSearchReducer
|
||||
}
|
||||
|
||||
func NewStreamingSearchTask(ctx context.Context,
|
||||
collection *segments.Collection,
|
||||
manager *segments.Manager,
|
||||
req *querypb.SearchRequest,
|
||||
serverID int64,
|
||||
) *StreamingSearchTask {
|
||||
ctx, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "schedule")
|
||||
return &StreamingSearchTask{
|
||||
SearchTask: SearchTask{
|
||||
ctx: ctx,
|
||||
collection: collection,
|
||||
segmentManager: manager,
|
||||
req: req,
|
||||
merged: false,
|
||||
groupSize: 1,
|
||||
topk: req.GetReq().GetTopk(),
|
||||
nq: req.GetReq().GetNq(),
|
||||
placeholderGroup: req.GetReq().GetPlaceholderGroup(),
|
||||
originTopks: []int64{req.GetReq().GetTopk()},
|
||||
originNqs: []int64{req.GetReq().GetNq()},
|
||||
notifier: make(chan error, 1),
|
||||
tr: timerecord.NewTimeRecorderWithTrace(ctx, "searchTask"),
|
||||
scheduleSpan: span,
|
||||
serverID: serverID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *StreamingSearchTask) MergeWith(other Task) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *StreamingSearchTask) Execute() error {
|
||||
log := log.Ctx(t.ctx).With(
|
||||
zap.Int64("collectionID", t.collection.ID()),
|
||||
zap.String("shard", t.req.GetDmlChannels()[0]),
|
||||
)
|
||||
// 0. prepare search req
|
||||
if t.scheduleSpan != nil {
|
||||
t.scheduleSpan.End()
|
||||
}
|
||||
tr := timerecord.NewTimeRecorderWithTrace(t.ctx, "SearchTask")
|
||||
req := t.req
|
||||
t.combinePlaceHolderGroups()
|
||||
searchReq, err := segments.NewSearchRequest(t.ctx, t.collection, req, t.placeholderGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer searchReq.Delete()
|
||||
|
||||
// 1. search&&reduce or streaming-search&&streaming-reduce
|
||||
metricType := searchReq.Plan().GetMetricType()
|
||||
if req.GetScope() == querypb.DataScope_Historical {
|
||||
streamReduceFunc := func(result *segments.SearchResult) error {
|
||||
reduceErr := t.streamReduce(t.ctx, searchReq.Plan(), result, t.originNqs, t.originTopks)
|
||||
return reduceErr
|
||||
}
|
||||
pinnedSegments, err := segments.SearchHistoricalStreamly(
|
||||
t.ctx,
|
||||
t.segmentManager,
|
||||
searchReq,
|
||||
req.GetReq().GetCollectionID(),
|
||||
nil,
|
||||
req.GetSegmentIDs(),
|
||||
streamReduceFunc)
|
||||
defer segments.DeleteStreamReduceHelper(t.streamReducer)
|
||||
defer t.segmentManager.Segment.Unpin(pinnedSegments)
|
||||
if err != nil {
|
||||
log.Error("Failed to search sealed segments streamly", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
t.resultBlobs, err = segments.GetStreamReduceResult(t.ctx, t.streamReducer)
|
||||
defer segments.DeleteSearchResultDataBlobs(t.resultBlobs)
|
||||
if err != nil {
|
||||
log.Error("Failed to get stream-reduced search result")
|
||||
return err
|
||||
}
|
||||
} else if req.GetScope() == querypb.DataScope_Streaming {
|
||||
results, pinnedSegments, err := segments.SearchStreaming(
|
||||
t.ctx,
|
||||
t.segmentManager,
|
||||
searchReq,
|
||||
req.GetReq().GetCollectionID(),
|
||||
nil,
|
||||
req.GetSegmentIDs(),
|
||||
)
|
||||
defer segments.DeleteSearchResults(results)
|
||||
defer t.segmentManager.Segment.Unpin(pinnedSegments)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t.maybeReturnForEmptyResults(results, metricType, tr) {
|
||||
return nil
|
||||
}
|
||||
tr.RecordSpan()
|
||||
t.resultBlobs, err = segments.ReduceSearchResultsAndFillData(
|
||||
t.ctx,
|
||||
searchReq.Plan(),
|
||||
results,
|
||||
int64(len(results)),
|
||||
t.originNqs,
|
||||
t.originTopks,
|
||||
)
|
||||
if err != nil {
|
||||
log.Warn("failed to reduce search results", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
defer segments.DeleteSearchResultDataBlobs(t.resultBlobs)
|
||||
metrics.QueryNodeReduceLatency.WithLabelValues(
|
||||
fmt.Sprint(t.GetNodeID()),
|
||||
metrics.SearchLabel,
|
||||
metrics.ReduceSegments,
|
||||
metrics.BatchReduce).
|
||||
Observe(float64(tr.RecordSpan().Milliseconds()))
|
||||
}
|
||||
|
||||
// 2. reorganize blobs to original search request
|
||||
for i := range t.originNqs {
|
||||
blob, err := segments.GetSearchResultDataBlob(t.ctx, t.resultBlobs, i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var task *StreamingSearchTask
|
||||
if i == 0 {
|
||||
task = t
|
||||
} else {
|
||||
task = t.others[i-1]
|
||||
}
|
||||
|
||||
// Note: blob is unsafe because get from C
|
||||
bs := make([]byte, len(blob))
|
||||
copy(bs, blob)
|
||||
|
||||
task.result = &internalpb.SearchResults{
|
||||
Base: &commonpb.MsgBase{
|
||||
SourceID: t.GetNodeID(),
|
||||
},
|
||||
Status: merr.Success(),
|
||||
MetricType: metricType,
|
||||
NumQueries: t.originNqs[i],
|
||||
TopK: t.originTopks[i],
|
||||
SlicedBlob: bs,
|
||||
SlicedOffset: 1,
|
||||
SlicedNumCount: 1,
|
||||
CostAggregation: &internalpb.CostAggregation{
|
||||
ServiceTime: tr.ElapseSpan().Milliseconds(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *StreamingSearchTask) maybeReturnForEmptyResults(results []*segments.SearchResult,
|
||||
metricType string, tr *timerecord.TimeRecorder,
|
||||
) bool {
|
||||
if len(results) == 0 {
|
||||
for i := range t.originNqs {
|
||||
var task *StreamingSearchTask
|
||||
if i == 0 {
|
||||
task = t
|
||||
} else {
|
||||
task = t.others[i-1]
|
||||
}
|
||||
|
||||
task.result = &internalpb.SearchResults{
|
||||
Base: &commonpb.MsgBase{
|
||||
SourceID: t.GetNodeID(),
|
||||
},
|
||||
Status: merr.Success(),
|
||||
MetricType: metricType,
|
||||
NumQueries: t.originNqs[i],
|
||||
TopK: t.originTopks[i],
|
||||
SlicedOffset: 1,
|
||||
SlicedNumCount: 1,
|
||||
CostAggregation: &internalpb.CostAggregation{
|
||||
ServiceTime: tr.ElapseSpan().Milliseconds(),
|
||||
},
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *StreamingSearchTask) streamReduce(ctx context.Context,
|
||||
plan *segments.SearchPlan,
|
||||
newResult *segments.SearchResult,
|
||||
sliceNQs []int64,
|
||||
sliceTopKs []int64,
|
||||
) error {
|
||||
if t.streamReducer == nil {
|
||||
var err error
|
||||
t.streamReducer, err = segments.NewStreamReducer(ctx, plan, sliceNQs, sliceTopKs)
|
||||
if err != nil {
|
||||
log.Error("Fail to init stream reducer, return")
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return segments.StreamReduceSearchResult(ctx, newResult, t.streamReducer)
|
||||
}
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
package tasks
|
||||
|
||||
import "github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
|
||||
const (
|
||||
schedulePolicyNameFIFO = "fifo"
|
||||
schedulePolicyNameUserTaskPolling = "user-task-polling"
|
||||
|
@ -104,4 +106,6 @@ type Task interface {
|
|||
|
||||
// Return the NQ of task.
|
||||
NQ() int64
|
||||
|
||||
SearchResult() *internalpb.SearchResults
|
||||
}
|
||||
|
|
|
@ -66,6 +66,9 @@ const (
|
|||
ReduceSegments = "segments"
|
||||
ReduceShards = "shards"
|
||||
|
||||
BatchReduce = "batch_reduce"
|
||||
StreamReduce = "stream_reduce"
|
||||
|
||||
Pending = "pending"
|
||||
Executing = "executing"
|
||||
Done = "done"
|
||||
|
@ -96,6 +99,7 @@ const (
|
|||
requestScope = "scope"
|
||||
fullMethodLabelName = "full_method"
|
||||
reduceLevelName = "reduce_level"
|
||||
reduceType = "reduce_type"
|
||||
lockName = "lock_name"
|
||||
lockSource = "lock_source"
|
||||
lockType = "lock_type"
|
||||
|
|
|
@ -229,6 +229,7 @@ var (
|
|||
nodeIDLabelName,
|
||||
queryTypeLabelName,
|
||||
reduceLevelName,
|
||||
reduceType,
|
||||
})
|
||||
|
||||
QueryNodeLoadSegmentLatency = prometheus.NewHistogramVec(
|
||||
|
|
|
@ -18,13 +18,15 @@ package cache
|
|||
|
||||
import (
|
||||
"container/list"
|
||||
"fmt"
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/lock"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
|
@ -35,14 +37,15 @@ var (
|
|||
)
|
||||
|
||||
type cacheItem[K comparable, V any] struct {
|
||||
key K
|
||||
value V
|
||||
pinCount atomic.Int32
|
||||
key K
|
||||
value V
|
||||
pinCount atomic.Int32
|
||||
needReload bool
|
||||
}
|
||||
|
||||
type (
|
||||
Loader[K comparable, V any] func(key K) (V, bool)
|
||||
Finalizer[K comparable, V any] func(key K, value V) error
|
||||
Loader[K comparable, V any] func(ctx context.Context, key K) (V, error)
|
||||
Finalizer[K comparable, V any] func(ctx context.Context, key K, value V) error
|
||||
)
|
||||
|
||||
// Scavenger records the occupation of the cache and decides whether to evict if necessary.
|
||||
|
@ -57,18 +60,26 @@ type Scavenger[K comparable] interface {
|
|||
Collect(key K) (bool, func(K) bool)
|
||||
// Throw records entry removals.
|
||||
Throw(key K)
|
||||
// Spare returns a collector function based on given key.
|
||||
// The collector is a function which can be invoked repeatedly; each invocation will test if there is enough
|
||||
// room for all the pending entries if the thrown entry is evicted. Typically, the collector will get multiple true
|
||||
// before it gets a false.
|
||||
Spare(key K) func(K) bool
|
||||
Replace(key K) (bool, func(K) bool, func())
|
||||
}
|
||||
|
||||
type LazyScavenger[K comparable] struct {
|
||||
capacity int64
|
||||
size int64
|
||||
weight func(K) int64
|
||||
weights map[K]int64
|
||||
}
|
||||
|
||||
func NewLazyScavenger[K comparable](weight func(K) int64, capacity int64) *LazyScavenger[K] {
|
||||
return &LazyScavenger[K]{
|
||||
capacity: capacity,
|
||||
weight: weight,
|
||||
weights: make(map[K]int64),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -77,16 +88,47 @@ func (s *LazyScavenger[K]) Collect(key K) (bool, func(K) bool) {
|
|||
if s.size+w > s.capacity {
|
||||
needCollect := s.size + w - s.capacity
|
||||
return false, func(key K) bool {
|
||||
needCollect -= s.weight(key)
|
||||
needCollect -= s.weights[key]
|
||||
return needCollect <= 0
|
||||
}
|
||||
}
|
||||
s.size += w
|
||||
s.weights[key] = w
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *LazyScavenger[K]) Replace(key K) (bool, func(K) bool, func()) {
|
||||
pw := s.weights[key]
|
||||
w := s.weight(key)
|
||||
if s.size-pw+w > s.capacity {
|
||||
needCollect := s.size - pw + w - s.capacity
|
||||
return false, func(key K) bool {
|
||||
needCollect -= s.weights[key]
|
||||
return needCollect <= 0
|
||||
}, nil
|
||||
}
|
||||
s.size += w - pw
|
||||
s.weights[key] = w
|
||||
return true, nil, func() {
|
||||
s.size -= w - pw
|
||||
s.weights[key] = pw
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LazyScavenger[K]) Throw(key K) {
|
||||
s.size -= s.weight(key)
|
||||
if w, ok := s.weights[key]; ok {
|
||||
s.size -= w
|
||||
delete(s.weights, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LazyScavenger[K]) Spare(key K) func(K) bool {
|
||||
w := s.weight(key)
|
||||
available := s.capacity - s.size + w
|
||||
return func(k K) bool {
|
||||
available -= s.weight(k)
|
||||
return available >= 0
|
||||
}
|
||||
}
|
||||
|
||||
type Stats struct {
|
||||
|
@ -104,15 +146,21 @@ type Cache[K comparable, V any] interface {
|
|||
// completes.
|
||||
// Throws `ErrNoSuchItem` if the key is not found or not able to be loaded from given loader.
|
||||
// Throws `ErrNotEnoughSpace` if there is no room for the operation.
|
||||
Do(key K, doer func(V) error) (missing bool, err error)
|
||||
Do(ctx context.Context, key K, doer func(V) error) (missing bool, err error)
|
||||
// Do the operation `doer` on the given key `key`. The key is kept in the cache until the operation
|
||||
// completes. The function waits for `timeout` if there is not enough space for the given key.
|
||||
// Throws `ErrNoSuchItem` if the key is not found or not able to be loaded from given loader.
|
||||
// Throws `ErrTimeOut` if timed out.
|
||||
DoWait(key K, timeout time.Duration, doer func(V) error) (missing bool, err error)
|
||||
|
||||
DoWait(ctx context.Context, key K, timeout time.Duration, doer func(context.Context, V) error) (missing bool, err error)
|
||||
// Get stats
|
||||
Stats() *Stats
|
||||
|
||||
MarkItemNeedReload(ctx context.Context, key K) bool
|
||||
|
||||
// Expire removes the item from the cache.
|
||||
// Returns true if the item is not in use and is removed immediately, or if the item is not in the cache.
|
||||
// Returns false if the item is in use; it is then marked as needing reload, i.e. a lazy expire is applied.
|
||||
Expire(ctx context.Context, key K) (evicted bool)
|
||||
}
|
||||
|
||||
type Waiter[K comparable] struct {
|
||||
|
@ -131,21 +179,23 @@ func newWaiter[K comparable](key K) Waiter[K] {
|
|||
type lruCache[K comparable, V any] struct {
|
||||
rwlock sync.RWMutex
|
||||
// the value is *cacheItem[V]
|
||||
items map[K]*list.Element
|
||||
accessList *list.List
|
||||
loaderSingleFlight singleflight.Group
|
||||
stats *Stats
|
||||
waitQueue *list.List
|
||||
items map[K]*list.Element
|
||||
accessList *list.List
|
||||
loaderKeyLocks *lock.KeyLock[K]
|
||||
stats *Stats
|
||||
waitQueue *list.List
|
||||
|
||||
loader Loader[K, V]
|
||||
finalizer Finalizer[K, V]
|
||||
scavenger Scavenger[K]
|
||||
reloader Loader[K, V]
|
||||
}
|
||||
|
||||
type CacheBuilder[K comparable, V any] struct {
|
||||
loader Loader[K, V]
|
||||
finalizer Finalizer[K, V]
|
||||
scavenger Scavenger[K]
|
||||
reloader Loader[K, V]
|
||||
}
|
||||
|
||||
func NewCacheBuilder[K comparable, V any]() *CacheBuilder[K, V] {
|
||||
|
@ -186,30 +236,37 @@ func (b *CacheBuilder[K, V]) WithCapacity(capacity int64) *CacheBuilder[K, V] {
|
|||
return b
|
||||
}
|
||||
|
||||
func (b *CacheBuilder[K, V]) WithReloader(reloader Loader[K, V]) *CacheBuilder[K, V] {
|
||||
b.reloader = reloader
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *CacheBuilder[K, V]) Build() Cache[K, V] {
|
||||
return newLRUCache(b.loader, b.finalizer, b.scavenger)
|
||||
return newLRUCache(b.loader, b.finalizer, b.scavenger, b.reloader)
|
||||
}
|
||||
|
||||
func newLRUCache[K comparable, V any](
|
||||
loader Loader[K, V],
|
||||
finalizer Finalizer[K, V],
|
||||
scavenger Scavenger[K],
|
||||
reloader Loader[K, V],
|
||||
) Cache[K, V] {
|
||||
return &lruCache[K, V]{
|
||||
items: make(map[K]*list.Element),
|
||||
accessList: list.New(),
|
||||
waitQueue: list.New(),
|
||||
loaderSingleFlight: singleflight.Group{},
|
||||
stats: new(Stats),
|
||||
loader: loader,
|
||||
finalizer: finalizer,
|
||||
scavenger: scavenger,
|
||||
items: make(map[K]*list.Element),
|
||||
accessList: list.New(),
|
||||
waitQueue: list.New(),
|
||||
loaderKeyLocks: lock.NewKeyLock[K](),
|
||||
stats: new(Stats),
|
||||
loader: loader,
|
||||
finalizer: finalizer,
|
||||
scavenger: scavenger,
|
||||
reloader: reloader,
|
||||
}
|
||||
}
|
||||
|
||||
// Do picks up an item from cache and executes doer. The entry of interest is guaranteed to be in the cache while doer is executing.
|
||||
func (c *lruCache[K, V]) Do(key K, doer func(V) error) (bool, error) {
|
||||
item, missing, err := c.getAndPin(key)
|
||||
func (c *lruCache[K, V]) Do(ctx context.Context, key K, doer func(V) error) (bool, error) {
|
||||
item, missing, err := c.getAndPin(ctx, key)
|
||||
if err != nil {
|
||||
return missing, err
|
||||
}
|
||||
|
@ -217,8 +274,11 @@ func (c *lruCache[K, V]) Do(key K, doer func(V) error) (bool, error) {
|
|||
return missing, doer(item.value)
|
||||
}
|
||||
|
||||
func (c *lruCache[K, V]) DoWait(key K, timeout time.Duration, doer func(V) error) (bool, error) {
|
||||
timedWait := func(cond *sync.Cond, timeout time.Duration) bool {
|
||||
func (c *lruCache[K, V]) DoWait(ctx context.Context, key K, timeout time.Duration, doer func(context.Context, V) error) (bool, error) {
|
||||
timedWait := func(cond *sync.Cond, timeout time.Duration) error {
|
||||
if timeout <= 0 {
|
||||
return ErrTimeOut // timed out
|
||||
}
|
||||
c := make(chan struct{})
|
||||
go func() {
|
||||
cond.L.Lock()
|
||||
|
@ -228,25 +288,34 @@ func (c *lruCache[K, V]) DoWait(key K, timeout time.Duration, doer func(V) error
|
|||
}()
|
||||
select {
|
||||
case <-c:
|
||||
return false // completed normally
|
||||
return nil // completed normally
|
||||
case <-time.After(timeout):
|
||||
return true // timed out
|
||||
return ErrTimeOut // timed out
|
||||
case <-ctx.Done():
|
||||
return ctx.Err() // context timeout
|
||||
}
|
||||
}
|
||||
|
||||
var ele *list.Element
|
||||
defer func() {
|
||||
if ele != nil {
|
||||
c.rwlock.Lock()
|
||||
c.waitQueue.Remove(ele)
|
||||
c.rwlock.Unlock()
|
||||
}
|
||||
}()
|
||||
start := time.Now()
|
||||
log := log.Ctx(ctx).With(zap.Any("key", key))
|
||||
for {
|
||||
item, missing, err := c.getAndPin(key)
|
||||
item, missing, err := c.getAndPin(ctx, key)
|
||||
if err == nil {
|
||||
if ele != nil {
|
||||
c.rwlock.Lock()
|
||||
c.waitQueue.Remove(ele)
|
||||
c.rwlock.Unlock()
|
||||
}
|
||||
defer c.Unpin(key)
|
||||
return missing, doer(item.value)
|
||||
} else if err != ErrNotEnoughSpace {
|
||||
return missing, doer(ctx, item.value)
|
||||
} else if err == ErrNotEnoughSpace {
|
||||
log.Warn("Failed to get disk cache for segment, wait and try again")
|
||||
} else if err == merr.ErrServiceResourceInsufficient {
|
||||
log.Warn("Failed to load segment for insufficient resource, wait and try later")
|
||||
} else if err != nil {
|
||||
return true, err
|
||||
}
|
||||
if ele == nil {
|
||||
|
@ -254,12 +323,17 @@ func (c *lruCache[K, V]) DoWait(key K, timeout time.Duration, doer func(V) error
|
|||
c.rwlock.Lock()
|
||||
waiter := newWaiter(key)
|
||||
ele = c.waitQueue.PushBack(&waiter)
|
||||
log.Info("push waiter into waiter queue", zap.Any("key", key))
|
||||
c.rwlock.Unlock()
|
||||
}
|
||||
// Wait for the key to be available
|
||||
timeLeft := time.Until(start.Add(timeout))
|
||||
if timeLeft <= 0 || timedWait(ele.Value.(*Waiter[K]).c, timeLeft) {
|
||||
return true, ErrTimeOut
|
||||
if err = timedWait(ele.Value.(*Waiter[K]).c, timeLeft); err != nil {
|
||||
log.Warn("failed to get item for key",
|
||||
zap.Any("key", key),
|
||||
zap.Int("wait_len", c.waitQueue.Len()),
|
||||
zap.Error(err))
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -277,77 +351,95 @@ func (c *lruCache[K, V]) Unpin(key K) {
|
|||
}
|
||||
item := e.Value.(*cacheItem[K, V])
|
||||
item.pinCount.Dec()
|
||||
if item.pinCount.Load() == 0 {
|
||||
c.notifyWaiters()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *lruCache[K, V]) notifyWaiters() {
|
||||
if c.waitQueue.Len() > 0 {
|
||||
log := log.With(zap.Any("UnPinedKey", key))
|
||||
if item.pinCount.Load() == 0 && c.waitQueue.Len() > 0 {
|
||||
log.Debug("Unpin item to zero ref, trigger activating waiters")
|
||||
// Notify waiters
|
||||
// collector := c.scavenger.Spare(key)
|
||||
for e := c.waitQueue.Front(); e != nil; e = e.Next() {
|
||||
w := e.Value.(*Waiter[K])
|
||||
log.Info("try to activate waiter", zap.Any("activated_waiter_key", w.key))
|
||||
w.c.Broadcast()
|
||||
// we try best to activate as many waiters as possible every time
|
||||
}
|
||||
} else {
|
||||
log.Debug("Miss to trigger activating waiters",
|
||||
zap.Int32("PinCount", item.pinCount.Load()),
|
||||
zap.Int("wait_len", c.waitQueue.Len()))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *lruCache[K, V]) peekAndPin(key K) *cacheItem[K, V] {
|
||||
func (c *lruCache[K, V]) peekAndPin(ctx context.Context, key K) *cacheItem[K, V] {
|
||||
c.rwlock.Lock()
|
||||
defer c.rwlock.Unlock()
|
||||
e, ok := c.items[key]
|
||||
log := log.Ctx(ctx)
|
||||
if ok {
|
||||
item := e.Value.(*cacheItem[K, V])
|
||||
if item.needReload && item.pinCount.Load() == 0 {
|
||||
ok, _, retback := c.scavenger.Replace(key)
|
||||
if ok {
|
||||
// there is room for reload and no one is using the item
|
||||
if c.reloader != nil {
|
||||
reloaded, err := c.reloader(ctx, key)
|
||||
if err == nil {
|
||||
item.value = reloaded
|
||||
} else if retback != nil {
|
||||
retback()
|
||||
}
|
||||
}
|
||||
item.needReload = false
|
||||
}
|
||||
}
|
||||
c.accessList.MoveToFront(e)
|
||||
item.pinCount.Inc()
|
||||
log.Debug("peeked item success",
|
||||
zap.Int32("PinCount", item.pinCount.Load()),
|
||||
zap.Any("key", key))
|
||||
return item
|
||||
}
|
||||
log.Debug("failed to peek item", zap.Any("key", key))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAndPin gets and pins the given key if it exists
|
||||
func (c *lruCache[K, V]) getAndPin(key K) (*cacheItem[K, V], bool, error) {
|
||||
if item := c.peekAndPin(key); item != nil {
|
||||
func (c *lruCache[K, V]) getAndPin(ctx context.Context, key K) (*cacheItem[K, V], bool, error) {
|
||||
if item := c.peekAndPin(ctx, key); item != nil {
|
||||
c.stats.HitCount.Inc()
|
||||
return item, false, nil
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx)
|
||||
c.stats.MissCount.Inc()
|
||||
if c.loader != nil {
|
||||
// Try scavenge if there is room. If not, fail fast.
|
||||
// Note that the test is not accurate since we are not locking `loader` here.
|
||||
if _, ok := c.tryScavenge(key); !ok {
|
||||
log.Warn("getAndPin ran into scavenge failure, return", zap.Any("key", key))
|
||||
return nil, true, ErrNotEnoughSpace
|
||||
}
|
||||
|
||||
strKey := fmt.Sprint(key)
|
||||
item, err, _ := c.loaderSingleFlight.Do(strKey, func() (interface{}, error) {
|
||||
if item := c.peekAndPin(key); item != nil {
|
||||
return item, nil
|
||||
}
|
||||
|
||||
timer := time.Now()
|
||||
value, ok := c.loader(key)
|
||||
c.stats.TotalLoadTimeMs.Add(uint64(time.Since(timer).Milliseconds()))
|
||||
|
||||
if !ok {
|
||||
c.stats.LoadFailCount.Inc()
|
||||
return nil, ErrNoSuchItem
|
||||
}
|
||||
|
||||
c.stats.LoadSuccessCount.Inc()
|
||||
item, err := c.setAndPin(key, value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return item, nil
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
return item.(*cacheItem[K, V]), true, nil
|
||||
c.loaderKeyLocks.Lock(key)
|
||||
defer c.loaderKeyLocks.Unlock(key)
|
||||
if item := c.peekAndPin(ctx, key); item != nil {
|
||||
return item, false, nil
|
||||
}
|
||||
timer := time.Now()
|
||||
value, err := c.loader(ctx, key)
|
||||
c.stats.TotalLoadTimeMs.Add(uint64(time.Since(timer).Milliseconds()))
|
||||
if err != nil {
|
||||
c.stats.LoadFailCount.Inc()
|
||||
log.Debug("loader failed for key", zap.Any("key", key))
|
||||
return nil, true, err
|
||||
}
|
||||
return nil, true, err
|
||||
}
|
||||
|
||||
c.stats.LoadSuccessCount.Inc()
|
||||
item, err := c.setAndPin(ctx, key, value)
|
||||
if err != nil {
|
||||
log.Debug("setAndPin failed for key", zap.Any("key", key), zap.Error(err))
|
||||
return nil, true, err
|
||||
}
|
||||
return item, true, nil
|
||||
}
|
||||
return nil, true, ErrNoSuchItem
|
||||
}
|
||||
|
||||
|
@ -381,7 +473,7 @@ func (c *lruCache[K, V]) lockfreeTryScavenge(key K) ([]K, bool) {
|
|||
}
|
||||
|
||||
// for cache miss
|
||||
func (c *lruCache[K, V]) setAndPin(key K, value V) (*cacheItem[K, V], error) {
|
||||
func (c *lruCache[K, V]) setAndPin(ctx context.Context, key K, value V) (*cacheItem[K, V], error) {
|
||||
c.rwlock.Lock()
|
||||
defer c.rwlock.Unlock()
|
||||
|
||||
|
@ -390,32 +482,69 @@ func (c *lruCache[K, V]) setAndPin(key K, value V) (*cacheItem[K, V], error) {
|
|||
|
||||
// tryScavenge is done again since the load call is lock free.
|
||||
toEvict, ok := c.lockfreeTryScavenge(key)
|
||||
|
||||
log := log.Ctx(ctx)
|
||||
if !ok {
|
||||
if c.finalizer != nil {
|
||||
c.finalizer(key, value)
|
||||
log.Warn("setAndPin ran into scavenge failure, release data for", zap.Any("key", key))
|
||||
c.finalizer(ctx, key, value)
|
||||
}
|
||||
return nil, ErrNotEnoughSpace
|
||||
}
|
||||
|
||||
for _, ek := range toEvict {
|
||||
e := c.items[ek]
|
||||
delete(c.items, ek)
|
||||
c.accessList.Remove(e)
|
||||
c.scavenger.Throw(ek)
|
||||
c.stats.EvictionCount.Inc()
|
||||
|
||||
if c.finalizer != nil {
|
||||
item := e.Value.(*cacheItem[K, V])
|
||||
timer := time.Now()
|
||||
c.finalizer(ek, item.value)
|
||||
c.stats.TotalFinalizeTimeMs.Add(uint64(time.Since(timer).Milliseconds()))
|
||||
}
|
||||
c.evict(ctx, ek)
|
||||
log.Debug("cache evicting", zap.Any("key", ek), zap.Any("by", key))
|
||||
}
|
||||
|
||||
c.scavenger.Collect(key)
|
||||
e := c.accessList.PushFront(item)
|
||||
c.items[item.key] = e
|
||||
|
||||
log.Debug("setAndPin set up item", zap.Any("item.key", item.key),
|
||||
zap.Int32("pinCount", item.pinCount.Load()))
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func (c *lruCache[K, V]) Expire(ctx context.Context, key K) (evicted bool) {
|
||||
c.rwlock.Lock()
|
||||
defer c.rwlock.Unlock()
|
||||
|
||||
e, ok := c.items[key]
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
item := e.Value.(*cacheItem[K, V])
|
||||
if item.pinCount.Load() == 0 {
|
||||
c.evict(ctx, key)
|
||||
return true
|
||||
}
|
||||
|
||||
item.needReload = true
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *lruCache[K, V]) evict(ctx context.Context, key K) {
|
||||
c.stats.EvictionCount.Inc()
|
||||
e := c.items[key]
|
||||
delete(c.items, key)
|
||||
c.accessList.Remove(e)
|
||||
c.scavenger.Throw(key)
|
||||
|
||||
if c.finalizer != nil {
|
||||
item := e.Value.(*cacheItem[K, V])
|
||||
c.finalizer(ctx, key, item.value)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *lruCache[K, V]) MarkItemNeedReload(ctx context.Context, key K) bool {
|
||||
c.rwlock.Lock()
|
||||
defer c.rwlock.Unlock()
|
||||
|
||||
if e, ok := c.items[key]; ok {
|
||||
item := e.Value.(*cacheItem[K, V])
|
||||
item.needReload = true
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -12,8 +12,8 @@ import (
|
|||
)
|
||||
|
||||
func TestLRUCache(t *testing.T) {
|
||||
cacheBuilder := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) {
|
||||
return key, true
|
||||
cacheBuilder := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
|
||||
return key, nil
|
||||
})
|
||||
|
||||
t.Run("test loader", func(t *testing.T) {
|
||||
|
@ -21,7 +21,7 @@ func TestLRUCache(t *testing.T) {
|
|||
cache := cacheBuilder.WithCapacity(int64(size)).Build()
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
missing, err := cache.Do(i, func(v int) error {
|
||||
missing, err := cache.Do(context.Background(), i, func(v int) error {
|
||||
assert.Equal(t, i, v)
|
||||
return nil
|
||||
})
|
||||
|
@ -33,13 +33,13 @@ func TestLRUCache(t *testing.T) {
|
|||
t.Run("test finalizer", func(t *testing.T) {
|
||||
size := 10
|
||||
finalizeSeq := make([]int, 0)
|
||||
cache := cacheBuilder.WithCapacity(int64(size)).WithFinalizer(func(key, value int) error {
|
||||
cache := cacheBuilder.WithCapacity(int64(size)).WithFinalizer(func(ctx context.Context, key, value int) error {
|
||||
finalizeSeq = append(finalizeSeq, key)
|
||||
return nil
|
||||
}).Build()
|
||||
|
||||
for i := 0; i < size*2; i++ {
|
||||
missing, err := cache.Do(i, func(v int) error {
|
||||
missing, err := cache.Do(context.Background(), i, func(v int) error {
|
||||
assert.Equal(t, i, v)
|
||||
return nil
|
||||
})
|
||||
|
@ -50,7 +50,7 @@ func TestLRUCache(t *testing.T) {
|
|||
|
||||
// Hit the cache again, there should be no swap-out
|
||||
for i := size; i < size*2; i++ {
|
||||
missing, err := cache.Do(i, func(v int) error {
|
||||
missing, err := cache.Do(context.Background(), i, func(v int) error {
|
||||
assert.Equal(t, i, v)
|
||||
return nil
|
||||
})
|
||||
|
@ -65,13 +65,13 @@ func TestLRUCache(t *testing.T) {
|
|||
sumCapacity := 20 // inserting 1 to 19, capacity is set to sum of 20, expecting (19) at last.
|
||||
cache := cacheBuilder.WithLazyScavenger(func(key int) int64 {
|
||||
return int64(key)
|
||||
}, int64(sumCapacity)).WithFinalizer(func(key, value int) error {
|
||||
}, int64(sumCapacity)).WithFinalizer(func(ctx context.Context, key, value int) error {
|
||||
finalizeSeq = append(finalizeSeq, key)
|
||||
return nil
|
||||
}).Build()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
missing, err := cache.Do(i, func(v int) error {
|
||||
missing, err := cache.Do(context.Background(), i, func(v int) error {
|
||||
assert.Equal(t, i, v)
|
||||
return nil
|
||||
})
|
||||
|
@ -84,7 +84,7 @@ func TestLRUCache(t *testing.T) {
|
|||
t.Run("test do negative", func(t *testing.T) {
|
||||
cache := cacheBuilder.Build()
|
||||
theErr := errors.New("error")
|
||||
missing, err := cache.Do(-1, func(v int) error {
|
||||
missing, err := cache.Do(context.Background(), -1, func(v int) error {
|
||||
return theErr
|
||||
})
|
||||
assert.True(t, missing)
|
||||
|
@ -96,13 +96,13 @@ func TestLRUCache(t *testing.T) {
|
|||
sumCapacity := 20 // inserting 1 to 19, capacity is set to sum of 20, expecting (19) at last.
|
||||
cache := cacheBuilder.WithLazyScavenger(func(key int) int64 {
|
||||
return int64(key)
|
||||
}, int64(sumCapacity)).WithFinalizer(func(key, value int) error {
|
||||
}, int64(sumCapacity)).WithFinalizer(func(ctx context.Context, key, value int) error {
|
||||
finalizeSeq = append(finalizeSeq, key)
|
||||
return nil
|
||||
}).Build()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
missing, err := cache.Do(i, func(v int) error {
|
||||
missing, err := cache.Do(context.Background(), i, func(v int) error {
|
||||
assert.Equal(t, i, v)
|
||||
return nil
|
||||
})
|
||||
|
@ -110,7 +110,7 @@ func TestLRUCache(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, finalizeSeq)
|
||||
missing, err := cache.Do(100, func(v int) error {
|
||||
missing, err := cache.Do(context.Background(), 100, func(v int) error {
|
||||
return nil
|
||||
})
|
||||
assert.True(t, missing)
|
||||
|
@ -118,28 +118,52 @@ func TestLRUCache(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("test load negative", func(t *testing.T) {
|
||||
cache := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) {
|
||||
cache := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
|
||||
if key < 0 {
|
||||
return 0, false
|
||||
return 0, ErrNoSuchItem
|
||||
}
|
||||
return key, true
|
||||
return key, nil
|
||||
}).Build()
|
||||
missing, err := cache.Do(0, func(v int) error {
|
||||
missing, err := cache.Do(context.Background(), 0, func(v int) error {
|
||||
return nil
|
||||
})
|
||||
assert.True(t, missing)
|
||||
assert.NoError(t, err)
|
||||
missing, err = cache.Do(-1, func(v int) error {
|
||||
missing, err = cache.Do(context.Background(), -1, func(v int) error {
|
||||
return nil
|
||||
})
|
||||
assert.True(t, missing)
|
||||
assert.Equal(t, ErrNoSuchItem, err)
|
||||
})
|
||||
|
||||
t.Run("test reloader", func(t *testing.T) {
|
||||
cache := cacheBuilder.WithReloader(func(ctx context.Context, key int) (int, error) {
|
||||
return -key, nil
|
||||
}).Build()
|
||||
_, err := cache.Do(context.Background(), 1, func(i int) error { return nil })
|
||||
assert.NoError(t, err)
|
||||
exist := cache.MarkItemNeedReload(context.Background(), 1)
|
||||
assert.True(t, exist)
|
||||
cache.Do(context.Background(), 1, func(i int) error {
|
||||
assert.Equal(t, -1, i)
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("test mark", func(t *testing.T) {
|
||||
cache := cacheBuilder.WithCapacity(1).Build()
|
||||
exist := cache.MarkItemNeedReload(context.Background(), 1)
|
||||
assert.False(t, exist)
|
||||
_, err := cache.Do(context.Background(), 1, func(i int) error { return nil })
|
||||
assert.NoError(t, err)
|
||||
exist = cache.MarkItemNeedReload(context.Background(), 1)
|
||||
assert.True(t, exist)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStats(t *testing.T) {
|
||||
cacheBuilder := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) {
|
||||
return key, true
|
||||
cacheBuilder := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
|
||||
return key, nil
|
||||
})
|
||||
|
||||
t.Run("test loader", func(t *testing.T) {
|
||||
|
@ -155,7 +179,7 @@ func TestStats(t *testing.T) {
|
|||
assert.Equal(t, uint64(0), stats.LoadFailCount.Load())
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
_, err := cache.Do(i, func(v int) error {
|
||||
_, err := cache.Do(context.Background(), i, func(v int) error {
|
||||
assert.Equal(t, i, v)
|
||||
return nil
|
||||
})
|
||||
|
@ -170,7 +194,7 @@ func TestStats(t *testing.T) {
|
|||
assert.Equal(t, uint64(0), stats.LoadFailCount.Load())
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
_, err := cache.Do(i, func(v int) error {
|
||||
_, err := cache.Do(context.Background(), i, func(v int) error {
|
||||
assert.Equal(t, i, v)
|
||||
return nil
|
||||
})
|
||||
|
@ -184,7 +208,7 @@ func TestStats(t *testing.T) {
|
|||
assert.Equal(t, uint64(0), stats.LoadFailCount.Load())
|
||||
|
||||
for i := size; i < size*2; i++ {
|
||||
_, err := cache.Do(i, func(v int) error {
|
||||
_, err := cache.Do(context.Background(), i, func(v int) error {
|
||||
assert.Equal(t, i, v)
|
||||
return nil
|
||||
})
|
||||
|
@ -202,9 +226,9 @@ func TestStats(t *testing.T) {
|
|||
func TestLRUCacheConcurrency(t *testing.T) {
|
||||
t.Run("test race condition", func(t *testing.T) {
|
||||
numEvict := new(atomic.Int32)
|
||||
cache := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) {
|
||||
return key, true
|
||||
}).WithCapacity(10).WithFinalizer(func(key, value int) error {
|
||||
cache := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
|
||||
return key, nil
|
||||
}).WithCapacity(10).WithFinalizer(func(ctx context.Context, key, value int) error {
|
||||
numEvict.Add(1)
|
||||
return nil
|
||||
}).Build()
|
||||
|
@ -215,7 +239,7 @@ func TestLRUCacheConcurrency(t *testing.T) {
|
|||
go func(i int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 100; j++ {
|
||||
_, err := cache.Do(j, func(v int) error {
|
||||
_, err := cache.Do(context.Background(), j, func(v int) error {
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
@ -226,9 +250,9 @@ func TestLRUCacheConcurrency(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("test not enough space", func(t *testing.T) {
|
||||
cache := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) {
|
||||
return key, true
|
||||
}).WithCapacity(1).WithFinalizer(func(key, value int) error {
|
||||
cache := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
|
||||
return key, nil
|
||||
}).WithCapacity(1).WithFinalizer(func(ctx context.Context, key, value int) error {
|
||||
return nil
|
||||
}).Build()
|
||||
|
||||
|
@ -236,13 +260,13 @@ func TestLRUCacheConcurrency(t *testing.T) {
|
|||
var wg1 sync.WaitGroup // Make sure goroutine is started
|
||||
wg.Add(1)
|
||||
wg1.Add(1)
|
||||
go cache.Do(1000, func(v int) error {
|
||||
go cache.Do(context.Background(), 1000, func(v int) error {
|
||||
wg1.Done()
|
||||
wg.Wait()
|
||||
return nil
|
||||
})
|
||||
wg1.Wait()
|
||||
_, err := cache.Do(1001, func(v int) error {
|
||||
_, err := cache.Do(context.Background(), 1001, func(v int) error {
|
||||
return nil
|
||||
})
|
||||
wg.Done()
|
||||
|
@ -250,9 +274,9 @@ func TestLRUCacheConcurrency(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("test time out", func(t *testing.T) {
|
||||
cache := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) {
|
||||
return key, true
|
||||
}).WithCapacity(1).WithFinalizer(func(key, value int) error {
|
||||
cache := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
|
||||
return key, nil
|
||||
}).WithCapacity(1).WithFinalizer(func(ctx context.Context, key, value int) error {
|
||||
return nil
|
||||
}).Build()
|
||||
|
||||
|
@ -260,13 +284,13 @@ func TestLRUCacheConcurrency(t *testing.T) {
|
|||
var wg1 sync.WaitGroup // Make sure goroutine is started
|
||||
wg.Add(1)
|
||||
wg1.Add(1)
|
||||
go cache.Do(1000, func(v int) error {
|
||||
go cache.Do(context.Background(), 1000, func(v int) error {
|
||||
wg1.Done()
|
||||
wg.Wait()
|
||||
return nil
|
||||
})
|
||||
wg1.Wait()
|
||||
missing, err := cache.DoWait(1001, time.Nanosecond, func(v int) error {
|
||||
missing, err := cache.DoWait(context.Background(), 1001, time.Nanosecond, func(ctx context.Context, v int) error {
|
||||
return nil
|
||||
})
|
||||
wg.Done()
|
||||
|
@ -275,22 +299,22 @@ func TestLRUCacheConcurrency(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("test wait", func(t *testing.T) {
|
||||
cache := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) {
|
||||
return key, true
|
||||
}).WithCapacity(1).WithFinalizer(func(key, value int) error {
|
||||
cache := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
|
||||
return key, nil
|
||||
}).WithCapacity(1).WithFinalizer(func(ctx context.Context, key, value int) error {
|
||||
return nil
|
||||
}).Build()
|
||||
|
||||
var wg1 sync.WaitGroup // Make sure goroutine is started
|
||||
|
||||
wg1.Add(1)
|
||||
go cache.Do(1000, func(v int) error {
|
||||
go cache.Do(context.Background(), 1000, func(v int) error {
|
||||
wg1.Done()
|
||||
time.Sleep(time.Second)
|
||||
return nil
|
||||
})
|
||||
wg1.Wait()
|
||||
missing, err := cache.DoWait(1001, time.Second*2, func(v int) error {
|
||||
missing, err := cache.DoWait(context.Background(), 1001, time.Second*2, func(ctx context.Context, v int) error {
|
||||
return nil
|
||||
})
|
||||
assert.True(t, missing)
|
||||
|
@ -299,9 +323,9 @@ func TestLRUCacheConcurrency(t *testing.T) {
|
|||
|
||||
t.Run("test wait race condition", func(t *testing.T) {
|
||||
numEvict := new(atomic.Int32)
|
||||
cache := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) {
|
||||
return key, true
|
||||
}).WithCapacity(5).WithFinalizer(func(key, value int) error {
|
||||
cache := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
|
||||
return key, nil
|
||||
}).WithCapacity(5).WithFinalizer(func(ctx context.Context, key, value int) error {
|
||||
numEvict.Add(1)
|
||||
return nil
|
||||
}).Build()
|
||||
|
@ -311,9 +335,8 @@ func TestLRUCacheConcurrency(t *testing.T) {
|
|||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 20; j++ {
|
||||
_, err := cache.DoWait(j, time.Second, func(v int) error {
|
||||
time.Sleep(time.Duration(rand.Intn(3)))
|
||||
for j := 0; j < 100; j++ {
|
||||
_, err := cache.DoWait(context.Background(), j, 2*time.Second, func(ctx context.Context, v int) error {
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
@ -322,4 +345,93 @@ func TestLRUCacheConcurrency(t *testing.T) {
|
|||
}
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
t.Run("test concurrent reload and mark", func(t *testing.T) {
|
||||
cache := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
|
||||
return key, nil
|
||||
}).WithCapacity(5).WithFinalizer(func(ctx context.Context, key, value int) error {
|
||||
return nil
|
||||
}).WithReloader(func(ctx context.Context, key int) (int, error) {
|
||||
return key, nil
|
||||
}).Build()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
cache.DoWait(context.Background(), i, 2*time.Second, func(ctx context.Context, v int) error { return nil })
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 10; i++ {
|
||||
for j := 0; j < 100; j++ {
|
||||
cache.MarkItemNeedReload(context.Background(), j)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 10; i++ {
|
||||
for j := 0; j < 100; j++ {
|
||||
cache.DoWait(context.Background(), j, 2*time.Second, func(ctx context.Context, v int) error { return nil })
|
||||
}
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
t.Run("test expire", func(t *testing.T) {
|
||||
cache := NewCacheBuilder[int, int]().WithLoader(func(ctx context.Context, key int) (int, error) {
|
||||
return key, nil
|
||||
}).WithCapacity(5).WithFinalizer(func(ctx context.Context, key, value int) error {
|
||||
return nil
|
||||
}).WithReloader(func(ctx context.Context, key int) (int, error) {
|
||||
return key, nil
|
||||
}).Build()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
cache.DoWait(context.Background(), i, 2*time.Second, func(ctx context.Context, v int) error { return nil })
|
||||
}
|
||||
|
||||
evicted := 0
|
||||
for i := 0; i < 100; i++ {
|
||||
if cache.Expire(context.Background(), i) {
|
||||
evicted++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 100, evicted)
|
||||
|
||||
// no item should be evicted while it is still in use.
|
||||
for i := 0; i < 5; i++ {
|
||||
cache.DoWait(context.Background(), i, 2*time.Second, func(ctx context.Context, v int) error { return nil })
|
||||
}
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(5)
|
||||
for i := 0; i < 5; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
cache.DoWait(context.Background(), i, 2*time.Second, func(ctx context.Context, v int) error {
|
||||
time.Sleep(2 * time.Second)
|
||||
return nil
|
||||
})
|
||||
}(i)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
evicted = 0
|
||||
for i := 0; i < 5; i++ {
|
||||
if cache.Expire(context.Background(), i) {
|
||||
evicted++
|
||||
}
|
||||
}
|
||||
assert.Zero(t, evicted)
|
||||
wg.Wait()
|
||||
|
||||
evicted = 0
|
||||
for i := 0; i < 5; i++ {
|
||||
if cache.Expire(context.Background(), i) {
|
||||
evicted++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 5, evicted)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -43,6 +43,7 @@ var (
|
|||
ErrServiceQuotaExceeded = newMilvusError("quota exceeded", 9, false)
|
||||
ErrServiceUnimplemented = newMilvusError("service unimplemented", 10, false)
|
||||
ErrServiceTimeTickLongDelay = newMilvusError("time tick long delay", 11, false)
|
||||
ErrServiceResourceInsufficient = newMilvusError("service resource insufficient", 12, true)
|
||||
|
||||
// Collection related
|
||||
ErrCollectionNotFound = newMilvusError("collection not found", 100, false)
|
||||
|
@ -85,6 +86,7 @@ var (
|
|||
ErrSegmentNotLoaded = newMilvusError("segment not loaded", 601, false)
|
||||
ErrSegmentLack = newMilvusError("segment lacks", 602, false)
|
||||
ErrSegmentReduplicate = newMilvusError("segment reduplicates", 603, false)
|
||||
ErrSegmentLoadFailed = newMilvusError("segment load failed", 604, false)
|
||||
|
||||
// Index related
|
||||
ErrIndexNotFound = newMilvusError("index not found", 700, false)
|
||||
|
@ -167,6 +169,15 @@ var (
|
|||
|
||||
// Search/Query related
|
||||
ErrInconsistentRequery = newMilvusError("inconsistent requery result", 2200, true)
|
||||
|
||||
// Compaction
|
||||
ErrCompactionReadDeltaLogErr = newMilvusError("fail to read delta log", 2300, false)
|
||||
ErrClusteringCompactionClusterNotSupport = newMilvusError("milvus cluster not support clustering compaction", 2301, false)
|
||||
ErrClusteringCompactionCollectionNotSupport = newMilvusError("collection not support clustering compaction", 2302, false)
|
||||
ErrClusteringCompactionCollectionIsCompacting = newMilvusError("collection is compacting", 2303, false)
|
||||
|
||||
// General
|
||||
ErrOperationNotSupported = newMilvusError("unsupported operation", 3000, false)
|
||||
)
|
||||
|
||||
type milvusError struct {
|
||||
|
|
|
@ -678,6 +678,14 @@ func WrapErrSegmentsNotFound(ids []int64, msg ...string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func WrapErrSegmentLoadFailed(id int64, msg ...string) error {
|
||||
err := wrapFields(ErrSegmentLoadFailed, value("segment", id))
|
||||
if len(msg) > 0 {
|
||||
err = errors.Wrap(err, strings.Join(msg, "->"))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func WrapErrSegmentNotLoaded(id int64, msg ...string) error {
|
||||
err := wrapFields(ErrSegmentNotLoaded, value("segment", id))
|
||||
if len(msg) > 0 {
|
||||
|
|
|
@ -1994,7 +1994,8 @@ type queryNodeConfig struct {
|
|||
MmapDirPath ParamItem `refreshable:"false"`
|
||||
MmapEnabled ParamItem `refreshable:"false"`
|
||||
|
||||
LazyLoadEnabled ParamItem `refreshable:"false"`
|
||||
LazyLoadEnabled ParamItem `refreshable:"false"`
|
||||
LazyLoadWaitTimeout ParamItem `refreshable:"false"`
|
||||
|
||||
// chunk cache
|
||||
ReadAheadPolicy ParamItem `refreshable:"false"`
|
||||
|
@ -2045,6 +2046,7 @@ type queryNodeConfig struct {
|
|||
MemoryIndexLoadPredictMemoryUsageFactor ParamItem `refreshable:"true"`
|
||||
EnableSegmentPrune ParamItem `refreshable:"false"`
|
||||
DefaultSegmentFilterRatio ParamItem `refreshable:"false"`
|
||||
UseStreamComputing ParamItem `refreshable:"false"`
|
||||
}
|
||||
|
||||
func (p *queryNodeConfig) init(base *BaseTable) {
|
||||
|
@ -2237,6 +2239,13 @@ func (p *queryNodeConfig) init(base *BaseTable) {
|
|||
Export: true,
|
||||
}
|
||||
p.LazyLoadEnabled.Init(base.mgr)
|
||||
p.LazyLoadWaitTimeout = ParamItem{
|
||||
Key: "queryNode.lazyloadWaitTimeout",
|
||||
Version: "2.4.0",
|
||||
DefaultValue: "30000",
|
||||
Doc: "max wait timeout duration in milliseconds before start to do lazyload search and retrieve",
|
||||
}
|
||||
p.LazyLoadWaitTimeout.Init(base.mgr)
|
||||
|
||||
p.ReadAheadPolicy = ParamItem{
|
||||
Key: "queryNode.cache.readAheadPolicy",
|
||||
|
@ -2569,6 +2578,13 @@ user-task-polling:
|
|||
Doc: "filter ratio used for pruning segments when searching",
|
||||
}
|
||||
p.DefaultSegmentFilterRatio.Init(base.mgr)
|
||||
p.UseStreamComputing = ParamItem{
|
||||
Key: "queryNode.useStreamComputing",
|
||||
Version: "2.4.0",
|
||||
DefaultValue: "false",
|
||||
Doc: "use stream search mode when searching or querying",
|
||||
}
|
||||
p.UseStreamComputing.Init(base.mgr)
|
||||
}
|
||||
|
||||
// /////////////////////////////////////////////////////////////////////////////
|
||||
|
|
Loading…
Reference in New Issue