Add default retrieve limit (#24782)

Signed-off-by: xige-16 <xi.ge@zilliz.com>
xige-16 2023-08-10 14:11:15 +08:00 committed by GitHub
parent 2b367b6bb0
commit 1055c90456
23 changed files with 506 additions and 49 deletions

@ -50,3 +50,5 @@ const int DEFAULT_CPU_NUM = 1;
constexpr const char* RADIUS = knowhere::meta::RADIUS;
constexpr const char* RANGE_FILTER = knowhere::meta::RANGE_FILTER;
const int64_t DEFAULT_MAX_OUTPUT_SIZE = 67108864; // bytes, 64MB
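The new constant is 64 MB expressed in bytes; later in this diff the test helpers (CRetrieve, RetrieveUsingDefaultOutputSize) pass it as the limit_size argument of the extended Retrieve call. A trivial standalone check of the arithmetic, with a hypothetical name standing in for the real header:

#include <cstdint>

// Hypothetical mirror of DEFAULT_MAX_OUTPUT_SIZE above: 64 MB in bytes.
constexpr int64_t kDefaultMaxOutputSize = 64LL * 1024 * 1024;
static_assert(kDefaultMaxOutputSize == 67108864, "64 * 1024 * 1024 = 67108864 bytes");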

@ -73,11 +73,11 @@ SegmentGrowingImpl::try_remove_chunks(FieldId fieldId) {
void
SegmentGrowingImpl::Insert(int64_t reserved_offset,
int64_t size,
int64_t num_rows,
const int64_t* row_ids,
const Timestamp* timestamps_raw,
const InsertData* insert_data) {
AssertInfo(insert_data->num_rows() == size,
AssertInfo(insert_data->num_rows() == num_rows,
"Entities_raw count not equal to insert size");
// AssertInfo(insert_data->fields_data_size() == schema_->size(),
// "num fields of insert data not equal to num of schema fields");
@ -95,15 +95,15 @@ SegmentGrowingImpl::Insert(int64_t reserved_offset,
// step 3: fill into Segment.ConcurrentVector
insert_record_.timestamps_.set_data_raw(
reserved_offset, timestamps_raw, size);
insert_record_.row_ids_.set_data_raw(reserved_offset, row_ids, size);
reserved_offset, timestamps_raw, num_rows);
insert_record_.row_ids_.set_data_raw(reserved_offset, row_ids, num_rows);
for (auto [field_id, field_meta] : schema_->get_fields()) {
AssertInfo(field_id_to_offset.count(field_id), "Cannot find field_id");
auto data_offset = field_id_to_offset[field_id];
if (!indexing_record_.SyncDataWithIndex(field_id)) {
insert_record_.get_field_data_base(field_id)->set_data_raw(
reserved_offset,
size,
num_rows,
&insert_data->fields_data(data_offset),
field_meta);
}
@ -111,27 +111,36 @@ SegmentGrowingImpl::Insert(int64_t reserved_offset,
if (segcore_config_.get_enable_growing_segment_index()) {
indexing_record_.AppendingIndex(
reserved_offset,
size,
num_rows,
field_id,
&insert_data->fields_data(data_offset),
insert_record_);
}
// update average row data size
if (datatype_is_variable(field_meta.get_data_type())) {
auto field_data_size = GetRawDataSizeOfDataArray(
&insert_data->fields_data(data_offset), field_meta, num_rows);
SegmentInternalInterface::set_field_avg_size(
field_id, num_rows, field_data_size);
}
try_remove_chunks(field_id);
}
// step 4: set pks to offset
auto field_id = schema_->get_primary_field_id().value_or(FieldId(-1));
AssertInfo(field_id.get() != INVALID_FIELD_ID, "Primary key is -1");
std::vector<PkType> pks(size);
std::vector<PkType> pks(num_rows);
ParsePksFromFieldData(
pks, insert_data->fields_data(field_id_to_offset[field_id]));
for (int i = 0; i < size; ++i) {
for (int i = 0; i < num_rows; ++i) {
insert_record_.insert_pk(pks[i], reserved_offset + i);
}
// step 5: update small indexes
insert_record_.ack_responder_.AddSegment(reserved_offset,
reserved_offset + size);
reserved_offset + num_rows);
}
void
@ -196,6 +205,15 @@ SegmentGrowingImpl::LoadFieldData(const LoadFieldDataInfo& infos) {
if (field_id == primary_field_id) {
insert_record_.insert_pks(field_data);
}
// update average row data size
auto field_meta = (*schema_)[field_id];
if (datatype_is_variable(field_meta.get_data_type())) {
SegmentInternalInterface::set_field_avg_size(
field_id,
num_rows,
storage::GetByteSizeOfFieldDatas(field_data));
}
}
// step 5: update small indexes

@ -77,13 +77,24 @@ SegmentInternalInterface::Search(
std::unique_ptr<proto::segcore::RetrieveResults>
SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan,
Timestamp timestamp) const {
Timestamp timestamp,
int64_t limit_size) const {
std::shared_lock lck(mutex_);
auto results = std::make_unique<proto::segcore::RetrieveResults>();
query::ExecPlanNodeVisitor visitor(*this, timestamp);
auto retrieve_results = visitor.get_retrieve_result(*plan->plan_node_);
retrieve_results.segment_ = (void*)this;
auto result_rows = retrieve_results.result_offsets_.size();
int64_t output_data_size = 0;
for (auto field_id : plan->field_ids_) {
output_data_size += get_field_avg_size(field_id) * result_rows;
}
if (output_data_size > limit_size) {
throw std::runtime_error("query results exceed the limit size " +
std::to_string(limit_size));
}
if (plan->plan_node_->is_count_) {
AssertInfo(retrieve_results.field_data_.size() == 1,
"count result should only have one column");
@ -166,7 +177,7 @@ SegmentInternalInterface::get_real_count() const {
auto plan = std::make_unique<query::RetrievePlan>(get_schema());
plan->plan_node_ = std::make_unique<query::RetrievePlanNode>();
plan->plan_node_->is_count_ = true;
auto res = Retrieve(plan.get(), MAX_TIMESTAMP);
auto res = Retrieve(plan.get(), MAX_TIMESTAMP, INT64_MAX);
AssertInfo(res->fields_data().size() == 1,
"count result should only have one column");
AssertInfo(res->fields_data()[0].has_scalars(),
@ -178,6 +189,61 @@ SegmentInternalInterface::get_real_count() const {
return res->fields_data()[0].scalars().long_data().data(0);
}
int64_t
SegmentInternalInterface::get_field_avg_size(FieldId field_id) const {
AssertInfo(field_id.get() >= 0,
"invalid field id, should be greater than or equal to 0");
if (SystemProperty::Instance().IsSystem(field_id)) {
if (field_id == TimestampFieldID || field_id == RowFieldID) {
return sizeof(int64_t);
}
throw std::runtime_error("unsupported system field id");
}
auto schema = get_schema();
auto& field_meta = schema[field_id];
auto data_type = field_meta.get_data_type();
std::shared_lock lck(mutex_);
if (datatype_is_variable(data_type)) {
if (variable_fields_avg_size_.find(field_id) ==
variable_fields_avg_size_.end()) {
return 0;
}
return variable_fields_avg_size_.at(field_id).second;
} else {
return field_meta.get_sizeof();
}
}
void
SegmentInternalInterface::set_field_avg_size(FieldId field_id,
int64_t num_rows,
int64_t field_size) {
AssertInfo(field_id.get() >= 0,
"invalid field id, should be greater than or equal to 0");
auto schema = get_schema();
auto& field_meta = schema[field_id];
auto data_type = field_meta.get_data_type();
std::unique_lock lck(mutex_);
if (datatype_is_variable(data_type)) {
AssertInfo(num_rows > 0,
"The num rows of field data should be greater than 0");
if (variable_fields_avg_size_.find(field_id) ==
variable_fields_avg_size_.end()) {
variable_fields_avg_size_.emplace(field_id, std::make_pair(0, 0));
}
auto& field_info = variable_fields_avg_size_.at(field_id);
auto size = field_info.first * field_info.second + field_size;
field_info.first = field_info.first + num_rows;
field_info.second = size / field_info.first;
}
}
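set_field_avg_size above maintains, per variable-length field, a pair of (accumulated row count, average row size) and folds each new batch into a weighted average using integer division. A minimal standalone sketch of that update, with hypothetical names rather than the project's members:

#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <utility>

// field id -> {num_rows, avg_size}, mirroring variable_fields_avg_size_ above.
std::unordered_map<int64_t, std::pair<int64_t, int64_t>> field_avg_sizes;

void
UpdateFieldAvgSize(int64_t field_id, int64_t num_rows, int64_t field_size) {
    assert(num_rows > 0);
    auto& info = field_avg_sizes[field_id];                       // value-initialized to {0, 0}
    int64_t total_bytes = info.first * info.second + field_size;  // old rows * old avg + new bytes
    info.first += num_rows;                                       // accumulated row count
    info.second = total_bytes / info.first;                       // new integer average
}

The SealedSegment_Update_Field_Size test later in this diff exercises exactly this arithmetic: after seeding N rows at an average of 10 bytes and then loading N strings totaling total_size bytes, it expects (10 * N + total_size) / (2 * N).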
void
SegmentInternalInterface::timestamp_filter(BitsetType& bitset,
Timestamp timestamp) const {

@ -53,7 +53,9 @@ class SegmentInterface {
Timestamp timestamp) const = 0;
virtual std::unique_ptr<proto::segcore::RetrieveResults>
Retrieve(const query::RetrievePlan* Plan, Timestamp timestamp) const = 0;
Retrieve(const query::RetrievePlan* Plan,
Timestamp timestamp,
int64_t limit_size) const = 0;
// TODO: memory use is not correct when load string or load string index
virtual int64_t
@ -71,6 +73,14 @@ class SegmentInterface {
virtual int64_t
get_real_count() const = 0;
virtual int64_t
get_field_avg_size(FieldId field_id) const = 0;
virtual void
set_field_avg_size(FieldId field_id,
int64_t num_rows,
int64_t field_size) = 0;
// virtual int64_t
// PreDelete(int64_t size) = 0;
@ -131,8 +141,9 @@ class SegmentInternalInterface : public SegmentInterface {
SearchResult& results) const override;
std::unique_ptr<proto::segcore::RetrieveResults>
Retrieve(const query::RetrievePlan* plan,
Timestamp timestamp) const override;
Retrieve(const query::RetrievePlan* Plan,
Timestamp timestamp,
int64_t limit_size) const override;
virtual bool
HasIndex(FieldId field_id) const = 0;
@ -146,6 +157,14 @@ class SegmentInternalInterface : public SegmentInterface {
int64_t
get_real_count() const override;
int64_t
get_field_avg_size(FieldId field_id) const override;
void
set_field_avg_size(FieldId field_id,
int64_t num_rows,
int64_t field_size) override;
public:
virtual void
vector_search(SearchInfo& search_info,
@ -258,6 +277,9 @@ class SegmentInternalInterface : public SegmentInterface {
protected:
mutable std::shared_mutex mutex_;
// fieldID -> std::pair<num_rows, avg_size>
std::unordered_map<FieldId, std::pair<int64_t, int64_t>>
variable_fields_avg_size_; // bytes
};
} // namespace milvus::segcore

@ -259,6 +259,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
std::shared_ptr<ColumnBase> column{};
if (datatype_is_variable(data_type)) {
int64_t field_data_size = 0;
switch (data_type) {
case milvus::DataType::STRING:
case milvus::DataType::VARCHAR: {
@ -270,7 +271,9 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
for (auto i = 0; i < field_data->get_num_rows(); i++) {
auto str = static_cast<const std::string*>(
field_data->RawValue(i));
var_column->Append(str->data(), str->size());
auto str_size = str->size();
var_column->Append(str->data(), str_size);
field_data_size += str_size;
}
}
var_column->Seal();
@ -288,8 +291,10 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
static_cast<const milvus::Json*>(
field_data->RawValue(i))
->data();
auto padded_string_size = padded_string.size();
var_column->Append(padded_string.data(),
padded_string.size());
padded_string_size);
field_data_size += padded_string_size;
}
}
var_column->Seal();
@ -299,6 +304,10 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
default: {
}
}
// update average row data size
SegmentInternalInterface::set_field_avg_size(
field_id, num_rows, field_data_size);
} else {
column = std::make_shared<Column>(num_rows, field_meta);
storage::FieldDataPtr field_data;
@ -468,7 +477,6 @@ SegmentSealedImpl::chunk_data_impl(FieldId field_id, int64_t chunk_id) const {
AssertInfo(get_bit(field_data_ready_bitset_, field_id),
"Can't get bitset element at " + std::to_string(field_id.get()));
auto& field_meta = schema_->operator[](field_id);
auto element_sizeof = field_meta.get_sizeof();
if (auto it = fields_.find(field_id); it != fields_.end()) {
auto& field_data = it->second;
return field_data->Span();

@ -110,6 +110,41 @@ GetSizeOfIdArray(const IdArray& data) {
PanicInfo("unsupported id type");
}
int64_t
GetRawDataSizeOfDataArray(const DataArray* data,
const FieldMeta& field_meta,
int64_t num_rows) {
int64_t result = 0;
auto data_type = field_meta.get_data_type();
if (!datatype_is_variable(data_type)) {
result = field_meta.get_sizeof() * num_rows;
} else {
switch (data_type) {
case DataType::STRING:
case DataType::VARCHAR: {
auto& string_data = FIELD_DATA(data, string);
for (auto& str : string_data) {
result += str.size();
}
break;
}
case DataType::JSON: {
auto& json_data = FIELD_DATA(data, json);
for (auto& json_bytes : json_data) {
result += json_bytes.size();
}
break;
}
default: {
PanicInfo(
fmt::format("unsupported variable datatype {}", data_type));
}
}
}
return result;
}
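The helper above distinguishes fixed-width fields, whose raw size is simply element size times row count, from variable-width fields (VARCHAR/JSON), where it sums the actual byte length of every row. A standalone sketch of the two cases, assuming plain std::string rows instead of the DataArray/FIELD_DATA accessors:

#include <cstdint>
#include <numeric>
#include <string>
#include <vector>

// Fixed-width types: every row occupies the same number of bytes.
int64_t
RawSizeOfFixedField(int64_t element_sizeof, int64_t num_rows) {
    return element_sizeof * num_rows;
}

// Variable-width types (strings, JSON): sum the per-row payload sizes.
int64_t
RawSizeOfVariableField(const std::vector<std::string>& rows) {
    return std::accumulate(rows.begin(), rows.end(), int64_t{0},
                           [](int64_t acc, const std::string& row) {
                               return acc + static_cast<int64_t>(row.size());
                           });
}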
// Note: this is temporary solution.
// modify bulk script implement to make process more clear

@ -45,6 +45,11 @@ ParsePksFromIDs(std::vector<PkType>& pks,
int64_t
GetSizeOfIdArray(const IdArray& data);
int64_t
GetRawDataSizeOfDataArray(const DataArray* data,
const FieldMeta& field_meta,
int64_t num_rows);
// Note: this is temporary solution.
// modify bulk script implement to make process more clear
std::unique_ptr<DataArray>

@ -108,17 +108,18 @@ Retrieve(CSegmentInterface c_segment,
CRetrievePlan c_plan,
CTraceContext c_trace,
uint64_t timestamp,
CRetrieveResult* result) {
CRetrieveResult* result,
int64_t limit_size) {
try {
auto segment =
static_cast<const milvus::segcore::SegmentInterface*>(c_segment);
static_cast<milvus::segcore::SegmentInterface*>(c_segment);
auto plan = static_cast<const milvus::query::RetrievePlan*>(c_plan);
auto ctx = milvus::tracer::TraceContext{
c_trace.traceID, c_trace.spanID, c_trace.flag};
auto span = milvus::tracer::StartSpan("SegcoreRetrieve", &ctx);
auto retrieve_result = segment->Retrieve(plan, timestamp);
auto retrieve_result = segment->Retrieve(plan, timestamp, limit_size);
auto size = retrieve_result->ByteSizeLong();
void* buffer = malloc(size);
@ -322,6 +323,23 @@ UpdateSealedSegmentIndex(CSegmentInterface c_segment,
}
}
CStatus
UpdateFieldRawDataSize(CSegmentInterface c_segment,
int64_t field_id,
int64_t num_rows,
int64_t field_data_size) {
try {
auto segment_interface =
reinterpret_cast<milvus::segcore::SegmentInterface*>(c_segment);
AssertInfo(segment_interface != nullptr, "segment conversion failed");
segment_interface->set_field_avg_size(
milvus::FieldId(field_id), num_rows, field_data_size);
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(UnexpectedError, e.what());
}
}
CStatus
DropFieldData(CSegmentInterface c_segment, int64_t field_id) {
try {

@ -54,7 +54,8 @@ Retrieve(CSegmentInterface c_segment,
CRetrievePlan c_plan,
CTraceContext c_trace,
uint64_t timestamp,
CRetrieveResult* result);
CRetrieveResult* result,
int64_t limit_size);
int64_t
GetMemoryUsageInBytes(CSegmentInterface c_segment);
@ -103,6 +104,12 @@ CStatus
UpdateSealedSegmentIndex(CSegmentInterface c_segment,
CLoadIndexInfo c_load_index_info);
CStatus
UpdateFieldRawDataSize(CSegmentInterface c_segment,
int64_t field_id,
int64_t num_rows,
int64_t field_data_size);
CStatus
DropFieldData(CSegmentInterface c_segment, int64_t field_id);

@ -566,6 +566,16 @@ CreateFieldData(const DataType& type, int64_t dim, int64_t total_num_rows) {
}
}
int64_t
GetByteSizeOfFieldDatas(const std::vector<FieldDataPtr>& field_datas) {
int64_t result = 0;
for (auto& data : field_datas) {
result += data->Size();
}
return result;
}
std::vector<storage::FieldDataPtr>
CollectFieldDataChannel(storage::FieldDataChannelPtr& channel) {
std::vector<storage::FieldDataPtr> result;

@ -132,6 +132,9 @@ CreateFieldData(const DataType& type,
int64_t dim = 1,
int64_t total_num_rows = 0);
int64_t
GetByteSizeOfFieldDatas(const std::vector<FieldDataPtr>& field_datas);
std::vector<storage::FieldDataPtr>
CollectFieldDataChannel(storage::FieldDataChannelPtr& channel);

@ -47,6 +47,16 @@ namespace {
const int64_t ROW_COUNT = 10 * 1000;
const int64_t BIAS = 4200;
CStatus
CRetrieve(CSegmentInterface c_segment,
CRetrievePlan c_plan,
CTraceContext c_trace,
uint64_t timestamp,
CRetrieveResult* result) {
return Retrieve(
c_segment, c_plan, c_trace, timestamp, result, DEFAULT_MAX_OUTPUT_SIZE);
}
const char*
get_default_schema_config() {
static std::string conf = R"(name: "default-collection"
@ -434,7 +444,7 @@ TEST(CApiTest, MultiDeleteGrowingSegment) {
auto max_ts = dataset.timestamps_[N - 1] + 10;
CRetrieveResult retrieve_result;
res = Retrieve(segment, plan.get(), {}, max_ts, &retrieve_result);
res = CRetrieve(segment, plan.get(), {}, max_ts, &retrieve_result);
ASSERT_EQ(res.error_code, Success);
auto query_result = std::make_unique<proto::segcore::RetrieveResults>();
auto suc = query_result->ParseFromArray(retrieve_result.proto_blob,
@ -451,7 +461,7 @@ TEST(CApiTest, MultiDeleteGrowingSegment) {
retrive_pks,
proto::plan::GenericValue::kInt64Val);
plan->plan_node_->predicate_ = std::move(term_expr);
res = Retrieve(segment, plan.get(), {}, max_ts, &retrieve_result);
res = CRetrieve(segment, plan.get(), {}, max_ts, &retrieve_result);
ASSERT_EQ(res.error_code, Success);
suc = query_result->ParseFromArray(retrieve_result.proto_blob,
retrieve_result.proto_size);
@ -476,7 +486,7 @@ TEST(CApiTest, MultiDeleteGrowingSegment) {
ASSERT_EQ(del_res.error_code, Success);
// retrieve pks in {2}
res = Retrieve(segment, plan.get(), {}, max_ts, &retrieve_result);
res = CRetrieve(segment, plan.get(), {}, max_ts, &retrieve_result);
ASSERT_EQ(res.error_code, Success);
suc = query_result->ParseFromArray(retrieve_result.proto_blob,
retrieve_result.proto_size);
@ -534,7 +544,7 @@ TEST(CApiTest, MultiDeleteSealedSegment) {
auto max_ts = dataset.timestamps_[N - 1] + 10;
CRetrieveResult retrieve_result;
auto res = Retrieve(segment, plan.get(), {}, max_ts, &retrieve_result);
auto res = CRetrieve(segment, plan.get(), {}, max_ts, &retrieve_result);
ASSERT_EQ(res.error_code, Success);
auto query_result = std::make_unique<proto::segcore::RetrieveResults>();
auto suc = query_result->ParseFromArray(retrieve_result.proto_blob,
@ -551,7 +561,7 @@ TEST(CApiTest, MultiDeleteSealedSegment) {
retrive_pks,
proto::plan::GenericValue::kInt64Val);
plan->plan_node_->predicate_ = std::move(term_expr);
res = Retrieve(segment, plan.get(), {}, max_ts, &retrieve_result);
res = CRetrieve(segment, plan.get(), {}, max_ts, &retrieve_result);
ASSERT_EQ(res.error_code, Success);
suc = query_result->ParseFromArray(retrieve_result.proto_blob,
retrieve_result.proto_size);
@ -576,7 +586,7 @@ TEST(CApiTest, MultiDeleteSealedSegment) {
ASSERT_EQ(del_res.error_code, Success);
// retrieve pks in {2}
res = Retrieve(segment, plan.get(), {}, max_ts, &retrieve_result);
res = CRetrieve(segment, plan.get(), {}, max_ts, &retrieve_result);
ASSERT_EQ(res.error_code, Success);
suc = query_result->ParseFromArray(retrieve_result.proto_blob,
retrieve_result.proto_size);
@ -638,7 +648,7 @@ TEST(CApiTest, DeleteRepeatedPksFromGrowingSegment) {
plan->field_ids_ = target_field_ids;
CRetrieveResult retrieve_result;
res = Retrieve(
res = CRetrieve(
segment, plan.get(), {}, dataset.timestamps_[N - 1], &retrieve_result);
ASSERT_EQ(res.error_code, Success);
auto query_result = std::make_unique<proto::segcore::RetrieveResults>();
@ -666,7 +676,7 @@ TEST(CApiTest, DeleteRepeatedPksFromGrowingSegment) {
ASSERT_EQ(del_res.error_code, Success);
// retrieve pks in {1, 2, 3}
res = Retrieve(
res = CRetrieve(
segment, plan.get(), {}, dataset.timestamps_[N - 1], &retrieve_result);
ASSERT_EQ(res.error_code, Success);
@ -710,7 +720,7 @@ TEST(CApiTest, DeleteRepeatedPksFromSealedSegment) {
plan->field_ids_ = target_field_ids;
CRetrieveResult retrieve_result;
auto res = Retrieve(
auto res = CRetrieve(
segment, plan.get(), {}, dataset.timestamps_[N - 1], &retrieve_result);
ASSERT_EQ(res.error_code, Success);
auto query_result = std::make_unique<proto::segcore::RetrieveResults>();
@ -739,7 +749,7 @@ TEST(CApiTest, DeleteRepeatedPksFromSealedSegment) {
ASSERT_EQ(del_res.error_code, Success);
// retrieve pks in {1, 2, 3}
res = Retrieve(
res = CRetrieve(
segment, plan.get(), {}, dataset.timestamps_[N - 1], &retrieve_result);
ASSERT_EQ(res.error_code, Success);
@ -811,7 +821,7 @@ TEST(CApiTest, InsertSamePkAfterDeleteOnGrowingSegment) {
plan->field_ids_ = target_field_ids;
CRetrieveResult retrieve_result;
res = Retrieve(
res = CRetrieve(
segment, plan.get(), {}, dataset.timestamps_[N - 1], &retrieve_result);
ASSERT_EQ(res.error_code, Success);
auto query_result = std::make_unique<proto::segcore::RetrieveResults>();
@ -836,7 +846,7 @@ TEST(CApiTest, InsertSamePkAfterDeleteOnGrowingSegment) {
ASSERT_EQ(res.error_code, Success);
// retrieve pks in {1, 2, 3}, timestamp = 19
res = Retrieve(
res = CRetrieve(
segment, plan.get(), {}, dataset.timestamps_[N - 1], &retrieve_result);
ASSERT_EQ(res.error_code, Success);
@ -899,7 +909,7 @@ TEST(CApiTest, InsertSamePkAfterDeleteOnSealedSegment) {
plan->field_ids_ = target_field_ids;
CRetrieveResult retrieve_result;
auto res = Retrieve(
auto res = CRetrieve(
segment, plan.get(), {}, dataset.timestamps_[N - 1], &retrieve_result);
ASSERT_EQ(res.error_code, Success);
auto query_result = std::make_unique<proto::segcore::RetrieveResults>();
@ -1102,7 +1112,7 @@ TEST(CApiTest, RetrieveTestWithExpr) {
plan->field_ids_ = target_field_ids;
CRetrieveResult retrieve_result;
auto res = Retrieve(
auto res = CRetrieve(
segment, plan.get(), {}, dataset.timestamps_[0], &retrieve_result);
ASSERT_EQ(res.error_code, Success);
@ -4041,6 +4051,39 @@ TEST(CApiTest, SealedSegment_search_float_With_Expr_Predicate_Range) {
DeleteSegment(segment);
}
TEST(CApiTest, SealedSegment_Update_Field_Size) {
auto schema = std::make_shared<Schema>();
auto str_fid = schema->AddDebugField("string", DataType::VARCHAR);
auto vec_fid = schema->AddDebugField(
"vector_float", DataType::VECTOR_FLOAT, DIM, "L2");
schema->set_primary_field_id(str_fid);
auto segment = CreateSealedSegment(schema).release();
auto N = ROW_COUNT;
int row_size = 10;
// update row_size =10 with n rows
auto status = UpdateFieldRawDataSize(segment, str_fid.get(), N, N * row_size);
ASSERT_EQ(status.error_code, Success);
ASSERT_EQ(segment->get_field_avg_size(str_fid), row_size);
// load data and update avg field size
std::vector<std::string> str_datas;
int64_t total_size = 0;
for (int i = 0; i < N; ++i) {
auto str = "string_data_" + std::to_string(i);
total_size += str.size();
str_datas.emplace_back(str);
}
auto res = LoadFieldRawData(segment, str_fid.get(), str_datas.data(), N);
ASSERT_EQ(res.error_code, Success);
ASSERT_EQ(segment->get_field_avg_size(str_fid),
(row_size * N + total_size) / (2 * N));
DeleteSegment(segment);
}
TEST(CApiTest, RetriveScalarFieldFromSealedSegmentWithIndex) {
auto schema = std::make_shared<Schema>();
auto i8_fid = schema->AddDebugField("age8", DataType::INT8);
@ -4149,7 +4192,7 @@ TEST(CApiTest, RetriveScalarFieldFromSealedSegmentWithIndex) {
plan->field_ids_ = target_field_ids;
CRetrieveResult retrieve_result;
res = Retrieve(
res = CRetrieve(
segment, plan.get(), {}, raw_data.timestamps_[N - 1], &retrieve_result);
ASSERT_EQ(res.error_code, Success);
auto query_result = std::make_unique<proto::segcore::RetrieveResults>();

@ -19,6 +19,13 @@
using namespace milvus;
using namespace milvus::segcore;
std::unique_ptr<proto::segcore::RetrieveResults>
RetrieveUsingDefaultOutputSize(SegmentInterface* segment,
const query::RetrievePlan* plan,
Timestamp timestamp) {
return segment->Retrieve(plan, timestamp, DEFAULT_MAX_OUTPUT_SIZE);
}
TEST(Retrieve, ScalarIndex) {
SUCCEED();
auto index = std::make_unique<ScalarIndexVector>();
@ -77,7 +84,8 @@ TEST(Retrieve, AutoID) {
std::vector<FieldId> target_fields_id{fid_64, fid_vec};
plan->field_ids_ = target_fields_id;
auto retrieve_results = segment->Retrieve(plan.get(), 100);
auto retrieve_results =
RetrieveUsingDefaultOutputSize(segment.get(), plan.get(), 100);
Assert(retrieve_results->fields_data_size() == target_fields_id.size());
auto field0 = retrieve_results->fields_data(0);
Assert(field0.has_scalars());
@ -132,7 +140,8 @@ TEST(Retrieve, AutoID2) {
std::vector<FieldId> target_offsets{fid_64, fid_vec};
plan->field_ids_ = target_offsets;
auto retrieve_results = segment->Retrieve(plan.get(), 100);
auto retrieve_results =
RetrieveUsingDefaultOutputSize(segment.get(), plan.get(), 100);
Assert(retrieve_results->fields_data_size() == target_offsets.size());
auto field0 = retrieve_results->fields_data(0);
Assert(field0.has_scalars());
@ -185,7 +194,8 @@ TEST(Retrieve, NotExist) {
std::vector<FieldId> target_offsets{fid_64, fid_vec};
plan->field_ids_ = target_offsets;
auto retrieve_results = segment->Retrieve(plan.get(), 100);
auto retrieve_results =
RetrieveUsingDefaultOutputSize(segment.get(), plan.get(), 100);
Assert(retrieve_results->fields_data_size() == target_offsets.size());
auto field0 = retrieve_results->fields_data(0);
Assert(field0.has_scalars());
@ -232,7 +242,8 @@ TEST(Retrieve, Empty) {
std::vector<FieldId> target_offsets{fid_64, fid_vec};
plan->field_ids_ = target_offsets;
auto retrieve_results = segment->Retrieve(plan.get(), 100);
auto retrieve_results =
RetrieveUsingDefaultOutputSize(segment.get(), plan.get(), 100);
Assert(retrieve_results->fields_data_size() == target_offsets.size());
auto field0 = retrieve_results->fields_data(0);
@ -243,6 +254,43 @@ TEST(Retrieve, Empty) {
Assert(field1.vectors().float_vector().data_size() == 0);
}
TEST(Retrieve, Limit) {
auto schema = std::make_shared<Schema>();
auto fid_64 = schema->AddDebugField("i64", DataType::INT64);
auto DIM = 16;
auto fid_vec = schema->AddDebugField(
"vector_64", DataType::VECTOR_FLOAT, DIM, knowhere::metric::L2);
schema->set_primary_field_id(fid_64);
int64_t N = 101;
auto dataset = DataGen(schema, N, 42);
auto segment = CreateSealedSegment(schema);
SealedLoadFieldData(dataset, *segment);
auto plan = std::make_unique<query::RetrievePlan>(*schema);
auto term_expr = std::make_unique<query::UnaryRangeExprImpl<int64_t>>(
milvus::query::ColumnInfo(
fid_64, DataType::INT64, std::vector<std::string>()),
OpType::GreaterEqual,
0,
proto::plan::GenericValue::kInt64Val);
plan->plan_node_ = std::make_unique<query::RetrievePlanNode>();
plan->plan_node_->predicate_ = std::move(term_expr);
// test query results exceed the limit size
std::vector<FieldId> target_fields{TimestampFieldID, fid_64, fid_vec};
plan->field_ids_ = target_fields;
EXPECT_THROW(segment->Retrieve(plan.get(), N, 1), std::runtime_error);
auto retrieve_results =
segment->Retrieve(plan.get(), N, DEFAULT_MAX_OUTPUT_SIZE);
Assert(retrieve_results->fields_data_size() == target_fields.size());
auto field0 = retrieve_results->fields_data(0);
auto field2 = retrieve_results->fields_data(2);
Assert(field0.scalars().long_data().data_size() == N);
Assert(field2.vectors().float_vector().data_size() == N * DIM);
}
TEST(Retrieve, LargeTimestamp) {
auto schema = std::make_shared<Schema>();
auto fid_64 = schema->AddDebugField("i64", DataType::INT64);
@ -279,8 +327,8 @@ TEST(Retrieve, LargeTimestamp) {
std::vector<int> filter_timestamps{-1, 0, 1, 10, 20};
filter_timestamps.push_back(N / 2);
for (const auto& f_ts : filter_timestamps) {
auto retrieve_results =
segment->Retrieve(plan.get(), ts_offset + 1 + f_ts);
auto retrieve_results = RetrieveUsingDefaultOutputSize(
segment.get(), plan.get(), ts_offset + 1 + f_ts);
Assert(retrieve_results->fields_data_size() == 2);
int target_num = (f_ts + choose_sep) / choose_sep;
@ -341,7 +389,8 @@ TEST(Retrieve, Delete) {
plan->field_ids_ = target_offsets;
{
auto retrieve_results = segment->Retrieve(plan.get(), 100);
auto retrieve_results =
RetrieveUsingDefaultOutputSize(segment.get(), plan.get(), 100);
ASSERT_EQ(retrieve_results->fields_data_size(), target_offsets.size());
auto field0 = retrieve_results->fields_data(0);
Assert(field0.has_scalars());
@ -397,7 +446,8 @@ TEST(Retrieve, Delete) {
reinterpret_cast<const Timestamp*>(new_timestamps.data()));
{
auto retrieve_results = segment->Retrieve(plan.get(), 100);
auto retrieve_results =
RetrieveUsingDefaultOutputSize(segment.get(), plan.get(), 100);
Assert(retrieve_results->fields_data_size() == target_offsets.size());
auto field1 = retrieve_results->fields_data(1);
Assert(field1.has_scalars());

@ -710,7 +710,7 @@ TEST(AlwaysTrueStringPlan, QueryWithOutputFields) {
const auto& fvec_meta = schema->operator[](FieldName("fvec"));
const auto& str_meta = schema->operator[](FieldName("str"));
auto N = 100000;
auto N = 10000;
auto dataset = DataGen(schema, N);
auto vec_col = dataset.get_col<float>(fvec_meta.get_id());
auto str_col =
@ -731,7 +731,8 @@ TEST(AlwaysTrueStringPlan, QueryWithOutputFields) {
Timestamp time = MAX_TIMESTAMP;
auto retrieved = segment->Retrieve(plan.get(), time);
auto retrieved =
segment->Retrieve(plan.get(), time, DEFAULT_MAX_OUTPUT_SIZE);
ASSERT_EQ(retrieved->ids().str_id().data().size(), N);
ASSERT_EQ(retrieved->offset().size(), N);
ASSERT_EQ(retrieved->fields_data().size(), 1);

@ -581,6 +581,12 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
// primary keys duplicate
skipDupCnt++
}
// limit retrieve result to avoid oom
if int64(proto.Size(ret)) > paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() {
return nil, fmt.Errorf("query results exceed the maxOutputSize Limit %d", paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64())
}
cursors[sel]++
}

@ -551,6 +551,33 @@ func TestTaskQuery_functions(t *testing.T) {
})
t.Run("test unLimited and maxOutputSize", func(t *testing.T) {
paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1")
ids := make([]int64, 100)
offsets := make([]int64, 100)
for i := range ids {
ids[i] = int64(i)
offsets[i] = int64(i)
}
fieldData := getFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, ids, 1)
result := &internalpb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
},
FieldsData: []*schemapb.FieldData{fieldData},
}
_, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{result}, &queryParams{limit: typeutil.Unlimited})
assert.Error(t, err)
paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1104857600")
})
t.Run("test offset", func(t *testing.T) {
tests := []struct {
description string

@ -889,6 +889,11 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
}
realTopK = j
ret.Results.Topks = append(ret.Results.Topks, realTopK)
// limit search result to avoid oom
if int64(proto.Size(ret)) > paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() {
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64())
}
}
log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))

@ -157,6 +157,11 @@ func ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
// // return nil, errors.New("the length (topk) between all result of query is different")
// }
ret.Topks = append(ret.Topks, j)
// limit search result to avoid oom
if int64(proto.Size(ret)) > paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() {
return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64())
}
}
log.Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
return ret, nil
@ -290,6 +295,12 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna
typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel])
}
}
// limit retrieve result to avoid oom
if int64(proto.Size(ret)) > paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() {
return nil, fmt.Errorf("query results exceed the maxOutputSize Limit %d", paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64())
}
cursors[sel]++
}
@ -378,6 +389,12 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
// primary keys duplicate
skipDupCnt++
}
// limit retrieve result to avoid oom
if int64(proto.Size(ret)) > paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() {
return nil, fmt.Errorf("query results exceed the maxOutputSize Limit %d", paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64())
}
cursors[sel]++
}

@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -167,6 +168,35 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
}
})
suite.Run("test unLimited and maxOutputSize", func() {
reqLimit := typeutil.Unlimited
paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1")
ids := make([]int64, 100)
offsets := make([]int64, 100)
for i := range ids {
ids[i] = int64(i)
offsets[i] = int64(i)
}
fieldData := genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, ids, 1)
result := &segcorepb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
},
Offset: offsets,
FieldsData: []*schemapb.FieldData{fieldData},
}
_, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{result}, reqLimit)
suite.Error(err)
paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1104857600")
})
suite.Run("test int ID", func() {
result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, typeutil.Unlimited)
suite.Equal(2, len(result.GetFieldsData()))
@ -346,6 +376,33 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
}
})
suite.Run("test unLimited and maxOutputSize", func() {
paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1")
ids := make([]int64, 100)
offsets := make([]int64, 100)
for i := range ids {
ids[i] = int64(i)
offsets[i] = int64(i)
}
fieldData := genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, ids, 1)
result := &internalpb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
},
FieldsData: []*schemapb.FieldData{fieldData},
}
_, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result}, typeutil.Unlimited)
suite.Error(err)
paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1104857600")
})
suite.Run("test int ID", func() {
result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, typeutil.Unlimited)
suite.Equal(2, len(result.GetFieldsData()))

@ -428,6 +428,7 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco
flag: C.uchar(span.SpanContext().TraceFlags()),
}
maxLimitSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
var retrieveResult RetrieveResult
var status C.CStatus
GetSQPool().Submit(func() (any, error) {
@ -438,7 +439,8 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco
traceCtx,
ts,
&retrieveResult.cRetrieveResult,
)
C.int64_t(maxLimitSize))
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("cgo retrieve done", zap.Duration("timeTaken", tr.ElapseSpan()))
@ -882,3 +884,24 @@ func (s *LocalSegment) LoadIndexInfo(indexInfo *querypb.FieldIndexInfo, info *Lo
return nil
}
func (s *LocalSegment) UpdateFieldRawDataSize(numRows int64, fieldBinlog *datapb.FieldBinlog) error {
var status C.CStatus
fieldID := fieldBinlog.FieldID
fieldDataSize := int64(0)
for _, binlog := range fieldBinlog.GetBinlogs() {
fieldDataSize += binlog.LogSize
}
GetDynamicPool().Submit(func() (any, error) {
status = C.UpdateFieldRawDataSize(s.ptr, C.int64_t(fieldID), C.int64_t(numRows), C.int64_t(fieldDataSize))
return nil, nil
}).Await()
if err := HandleCStatus(&status, "updateFieldRawDataSize failed"); err != nil {
return err
}
log.Info("updateFieldRawDataSize done", zap.Int64("segmentID", s.ID()))
return nil
}

@ -496,7 +496,7 @@ func (loader *segmentLoader) loadSegment(ctx context.Context,
log.Info("load fields...",
zap.Int64s("indexedFields", lo.Keys(indexedFieldInfos)),
)
if err := loader.loadFieldsIndex(ctx, segment, indexedFieldInfos); err != nil {
if err := loader.loadFieldsIndex(ctx, collection.Schema(), segment, loadInfo.GetNumOfRows(), indexedFieldInfos); err != nil {
return err
}
if err := loader.loadSealedSegmentFields(ctx, segment, fieldBinlogs, loadInfo.GetNumOfRows()); err != nil {
@ -570,7 +570,13 @@ func (loader *segmentLoader) loadSealedSegmentFields(ctx context.Context, segmen
return nil
}
func (loader *segmentLoader) loadFieldsIndex(ctx context.Context, segment *LocalSegment, vecFieldInfos map[int64]*IndexedFieldInfo) error {
func (loader *segmentLoader) loadFieldsIndex(ctx context.Context,
schema *schemapb.CollectionSchema,
segment *LocalSegment,
numRows int64,
vecFieldInfos map[int64]*IndexedFieldInfo) error {
schemaHelper, _ := typeutil.CreateSchemaHelper(schema)
for fieldID, fieldInfo := range vecFieldInfos {
indexInfo := fieldInfo.IndexInfo
err := loader.loadFieldIndex(ctx, segment, indexInfo)
@ -586,6 +592,18 @@ func (loader *segmentLoader) loadFieldsIndex(ctx context.Context, segment *Local
)
segment.AddIndex(fieldID, fieldInfo)
// set average row data size of variable field
field, err := schemaHelper.GetFieldFromID(fieldID)
if err != nil {
return err
}
if typeutil.IsVariableDataType(field.GetDataType()) {
err = segment.UpdateFieldRawDataSize(numRows, fieldInfo.FieldBinlog)
if err != nil {
return err
}
}
}
return nil

@ -96,6 +96,7 @@ type quotaConfig struct {
TopKLimit ParamItem `refreshable:"true"`
NQLimit ParamItem `refreshable:"true"`
MaxQueryResultWindow ParamItem `refreshable:"true"`
MaxOutputSize ParamItem `refreshable:"true"`
// limit writing
ForceDenyWriting ParamItem `refreshable:"true"`
@ -865,6 +866,13 @@ Check https://milvus.io/docs/limitations.md for more details.`,
}
p.MaxQueryResultWindow.Init(base.mgr)
p.MaxOutputSize = ParamItem{
Key: "quotaAndLimits.limits.maxOutputSize",
Version: "2.3.0",
DefaultValue: "104857600", // 100 MB, 100 * 1024 * 1024
}
p.MaxOutputSize.Init(base.mgr)
// limit writing
p.ForceDenyWriting = ParamItem{
Key: "quotaAndLimits.limitWriting.forceDeny",

@ -331,6 +331,10 @@ func IsJSONType(dataType schemapb.DataType) bool {
return dataType == schemapb.DataType_JSON
}
func IsArrayType(dataType schemapb.DataType) bool {
return dataType == schemapb.DataType_Array
}
// IsFloatingType returns true if input is a floating type, otherwise false
func IsFloatingType(dataType schemapb.DataType) bool {
switch dataType {
@ -366,6 +370,10 @@ func IsStringType(dataType schemapb.DataType) bool {
}
}
func IsVariableDataType(dataType schemapb.DataType) bool {
return IsStringType(dataType) || IsArrayType(dataType) || IsJSONType(dataType)
}
// AppendFieldData appends fields data of specified index from src to dst
func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx int64) {
for i, fieldData := range src {