Re-format cpp code (#22513)

Signed-off-by: yah01 <yang.cen@zilliz.com>
pull/22531/head
yah01 2023-03-02 15:55:49 +08:00 committed by GitHub
parent fa86de530d
commit bdd6bc7695
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
145 changed files with 3291 additions and 1446 deletions

View File

@ -18,7 +18,7 @@
# Below is copied from milvus project
BasedOnStyle: Google
DerivePointerAlignment: false
ColumnLimit: 120
ColumnLimit: 80
IndentWidth: 4
AccessModifierOffset: -3
AlwaysBreakAfterReturnType: All
@ -28,7 +28,9 @@ AllowShortIfStatementsOnASingleLine: false
AlignTrailingComments: true
# Appended Options
SortIncludes: false
SortIncludes: false
Standard: Latest
AlignAfterOpenBracket: Align
BinPackParameters: false
BinPackArguments: false
ReflowComments: false

View File

@ -32,14 +32,17 @@ class BitsetView : public knowhere::BitsetView {
BitsetView() = default;
~BitsetView() = default;
BitsetView(const std::nullptr_t value) : knowhere::BitsetView(value) { // NOLINT
BitsetView(const std::nullptr_t value)
: knowhere::BitsetView(value) { // NOLINT
}
BitsetView(const uint8_t* data, size_t num_bits) : knowhere::BitsetView(data, num_bits) { // NOLINT
BitsetView(const uint8_t* data, size_t num_bits)
: knowhere::BitsetView(data, num_bits) { // NOLINT
}
BitsetView(const BitsetType& bitset) // NOLINT
: BitsetView((uint8_t*)boost_ext::get_data(bitset), size_t(bitset.size())) {
: BitsetView((uint8_t*)boost_ext::get_data(bitset),
size_t(bitset.size())) {
}
BitsetView(const BitsetTypePtr& bitset_ptr) { // NOLINT
@ -56,7 +59,11 @@ class BitsetView : public knowhere::BitsetView {
AssertInfo((offset & 0x7) == 0, "offset is not divisible by 8");
AssertInfo(offset + size <= this->size(),
fmt::format("index out of range, offset={}, size={}, bitset.size={}", offset, size, this->size()));
fmt::format(
"index out of range, offset={}, size={}, bitset.size={}",
offset,
size,
this->size()));
return {data() + (offset >> 3), size};
}
};

View File

@ -16,7 +16,9 @@
namespace milvus {
template <typename T, typename = std::enable_if_t<std::is_fundamental_v<T> || std::is_same_v<T, std::string>>>
template <typename T,
typename = std::enable_if_t<std::is_fundamental_v<T> ||
std::is_same_v<T, std::string>>>
inline CDataType
GetDType() {
return None;

View File

@ -26,13 +26,15 @@ int cpu_num = DEFAULT_CPU_NUM;
void
SetIndexSliceSize(const int64_t size) {
index_file_slice_size = size;
LOG_SEGCORE_DEBUG_ << "set config index slice size: " << index_file_slice_size;
LOG_SEGCORE_DEBUG_ << "set config index slice size: "
<< index_file_slice_size;
}
void
SetThreadCoreCoefficient(const int64_t coefficient) {
thread_core_coefficient = coefficient;
LOG_SEGCORE_DEBUG_ << "set thread pool core coefficient: " << thread_core_coefficient;
LOG_SEGCORE_DEBUG_ << "set thread pool core coefficient: "
<< thread_core_coefficient;
}
void

View File

@ -81,7 +81,8 @@ datatype_name(DataType data_type) {
return "vector_binary";
}
default: {
auto err_msg = "Unsupported DataType(" + std::to_string((int)data_type) + ")";
auto err_msg =
"Unsupported DataType(" + std::to_string((int)data_type) + ")";
PanicInfo(err_msg);
}
}
@ -89,7 +90,8 @@ datatype_name(DataType data_type) {
inline bool
datatype_is_vector(DataType datatype) {
return datatype == DataType::VECTOR_BINARY || datatype == DataType::VECTOR_FLOAT;
return datatype == DataType::VECTOR_BINARY ||
datatype == DataType::VECTOR_FLOAT;
}
inline bool
@ -148,25 +150,39 @@ class FieldMeta {
FieldMeta&
operator=(FieldMeta&&) = default;
FieldMeta(const FieldName& name, FieldId id, DataType type) : name_(name), id_(id), type_(type) {
FieldMeta(const FieldName& name, FieldId id, DataType type)
: name_(name), id_(id), type_(type) {
Assert(!is_vector());
}
FieldMeta(const FieldName& name, FieldId id, DataType type, int64_t max_length)
: name_(name), id_(id), type_(type), string_info_(StringInfo{max_length}) {
FieldMeta(const FieldName& name,
FieldId id,
DataType type,
int64_t max_length)
: name_(name),
id_(id),
type_(type),
string_info_(StringInfo{max_length}) {
Assert(is_string());
}
FieldMeta(
const FieldName& name, FieldId id, DataType type, int64_t dim, std::optional<knowhere::MetricType> metric_type)
: name_(name), id_(id), type_(type), vector_info_(VectorInfo{dim, metric_type}) {
FieldMeta(const FieldName& name,
FieldId id,
DataType type,
int64_t dim,
std::optional<knowhere::MetricType> metric_type)
: name_(name),
id_(id),
type_(type),
vector_info_(VectorInfo{dim, metric_type}) {
Assert(is_vector());
}
bool
is_vector() const {
Assert(type_ != DataType::NONE);
return type_ == DataType::VECTOR_BINARY || type_ == DataType::VECTOR_FLOAT;
return type_ == DataType::VECTOR_BINARY ||
type_ == DataType::VECTOR_FLOAT;
}
bool

View File

@ -38,7 +38,8 @@ struct SearchResult {
if (topk_per_nq_prefix_sum_.empty()) {
return 0;
}
AssertInfo(topk_per_nq_prefix_sum_.size() == total_nq_ + 1, "wrong topk_per_nq_prefix_sum_ size");
AssertInfo(topk_per_nq_prefix_sum_.size() == total_nq_ + 1,
"wrong topk_per_nq_prefix_sum_ size");
return topk_per_nq_prefix_sum_[total_nq_];
}

View File

@ -20,11 +20,14 @@ namespace {
using ResultPair = std::pair<float, int64_t>;
}
DatasetPtr
SortRangeSearchResult(DatasetPtr data_set, int64_t topk, int64_t nq, std::string metric_type) {
SortRangeSearchResult(DatasetPtr data_set,
int64_t topk,
int64_t nq,
std::string metric_type) {
/**
* nq: number of querys;
* lims: the size of lims is nq + 1, lims[i+1] - lims[i] refers to the size of RangeSearch result querys[i]
* for example, the nq is 5. In the seleted range,
* nq: number of queries;
* lims: the size of lims is nq + 1, lims[i+1] - lims[i] refers to the size of RangeSearch result queries[i]
* for example, the nq is 5. In the selected range,
* the size of RangeSearch result for each nq is [1, 2, 3, 4, 5],
* the lims will be [0, 1, 3, 6, 10, 15];
* ids: the size of ids is lim[nq],
@ -32,13 +35,13 @@ SortRangeSearchResult(DatasetPtr data_set, int64_t topk, int64_t nq, std::string
* i(1,0), i(1,1), , i(1,k1-1),
* ,
* i(n-1,0), i(n-1,1), , i(n-1,kn-1)},
* i(0,0), i(0,1), , i(0,k0-1) means the ids of RangeSearch result querys[0], k0 equals lim[1] - lim[0];
* i(0,0), i(0,1), , i(0,k0-1) means the ids of RangeSearch result queries[0], k0 equals lim[1] - lim[0];
* dist: the size of ids is lim[nq],
* { d(0,0), d(0,1), , d(0,k0-1),
* d(1,0), d(1,1), , d(1,k1-1),
* ,
* d(n-1,0), d(n-1,1), , d(n-1,kn-1)},
* d(0,0), d(0,1), , d(0,k0-1) means the distances of RangeSearch result querys[0], k0 equals lim[1] - lim[0];
* d(0,0), d(0,1), , d(0,k0-1) means the distances of RangeSearch result queries[0], k0 equals lim[1] - lim[0];
*/
auto lims = GetDatasetLims(data_set);
auto id = GetDatasetIDs(data_set);
@ -65,11 +68,14 @@ SortRangeSearchResult(DatasetPtr data_set, int64_t topk, int64_t nq, std::string
* |------------+---------------| max_heap ascending_order
*
*/
std::function<bool(const ResultPair&, const ResultPair&)> cmp = std::less<std::pair<float, int64_t>>();
std::function<bool(const ResultPair&, const ResultPair&)> cmp =
std::less<std::pair<float, int64_t>>();
if (IsMetricType(metric_type, knowhere::metric::IP)) {
cmp = std::greater<std::pair<float, int64_t>>();
}
std::priority_queue<std::pair<float, int64_t>, std::vector<std::pair<float, int64_t>>, decltype(cmp)>
std::priority_queue<std::pair<float, int64_t>,
std::vector<std::pair<float, int64_t>>,
decltype(cmp)>
sub_result(cmp);
for (int j = lims[i]; j < lims[i + 1]; j++) {
@ -94,7 +100,9 @@ SortRangeSearchResult(DatasetPtr data_set, int64_t topk, int64_t nq, std::string
}
void
CheckRangeSearchParam(float radius, float range_filter, std::string metric_type) {
CheckRangeSearchParam(float radius,
float range_filter,
std::string metric_type) {
/*
* IP: 1.0 range_filter radius
* |------------+---------------| min_heap descending_order

View File

@ -17,8 +17,13 @@
namespace milvus {
DatasetPtr
SortRangeSearchResult(DatasetPtr data_set, int64_t topk, int64_t nq, std::string metric_type);
SortRangeSearchResult(DatasetPtr data_set,
int64_t topk,
int64_t nq,
std::string metric_type);
void
CheckRangeSearchParam(float radius, float range_filter, std::string metric_type);
CheckRangeSearchParam(float radius,
float range_filter,
std::string metric_type);
} // namespace milvus

View File

@ -26,10 +26,13 @@ namespace milvus {
using std::string;
static std::map<string, string>
RepeatedKeyValToMap(const google::protobuf::RepeatedPtrField<proto::common::KeyValuePair>& kvs) {
RepeatedKeyValToMap(
const google::protobuf::RepeatedPtrField<proto::common::KeyValuePair>&
kvs) {
std::map<string, string> mapping;
for (auto& kv : kvs) {
AssertInfo(!mapping.count(kv.key()), "repeat key(" + kv.key() + ") in protobuf");
AssertInfo(!mapping.count(kv.key()),
"repeat key(" + kv.key() + ") in protobuf");
mapping.emplace(kv.key(), kv.value());
}
return mapping;
@ -42,15 +45,18 @@ Schema::ParseFrom(const milvus::proto::schema::CollectionSchema& schema_proto) {
// NOTE: only two system
for (const milvus::proto::schema::FieldSchema& child : schema_proto.fields()) {
for (const milvus::proto::schema::FieldSchema& child :
schema_proto.fields()) {
auto field_id = FieldId(child.fieldid());
auto name = FieldName(child.name());
if (field_id.get() < 100) {
// system field id
auto is_system = SystemProperty::Instance().SystemFieldVerify(name, field_id);
auto is_system =
SystemProperty::Instance().SystemFieldVerify(name, field_id);
AssertInfo(is_system,
"invalid system type: name(" + name.get() + "), id(" + std::to_string(field_id.get()) + ")");
"invalid system type: name(" + name.get() + "), id(" +
std::to_string(field_id.get()) + ")");
continue;
}
@ -71,23 +77,28 @@ Schema::ParseFrom(const milvus::proto::schema::CollectionSchema& schema_proto) {
} else if (datatype_is_string(data_type)) {
auto type_map = RepeatedKeyValToMap(child.type_params());
AssertInfo(type_map.count(MAX_LENGTH), "max_length not found");
auto max_len = boost::lexical_cast<int64_t>(type_map.at(MAX_LENGTH));
auto max_len =
boost::lexical_cast<int64_t>(type_map.at(MAX_LENGTH));
schema->AddField(name, field_id, data_type, max_len);
} else {
schema->AddField(name, field_id, data_type);
}
if (child.is_primary_key()) {
AssertInfo(!schema->get_primary_field_id().has_value(), "repetitive primary key");
AssertInfo(!schema->get_primary_field_id().has_value(),
"repetitive primary key");
schema->set_primary_field_id(field_id);
}
}
AssertInfo(schema->get_primary_field_id().has_value(), "primary key should be specified");
AssertInfo(schema->get_primary_field_id().has_value(),
"primary key should be specified");
return schema;
}
const FieldMeta FieldMeta::RowIdMeta(FieldName("RowID"), RowFieldID, DataType::INT64);
const FieldMeta FieldMeta::RowIdMeta(FieldName("RowID"),
RowFieldID,
DataType::INT64);
} // namespace milvus

View File

@ -49,7 +49,8 @@ class Schema {
std::optional<knowhere::MetricType> metric_type) {
auto field_id = FieldId(debug_id);
debug_id++;
auto field_meta = FieldMeta(FieldName(name), field_id, data_type, dim, metric_type);
auto field_meta =
FieldMeta(FieldName(name), field_id, data_type, dim, metric_type);
this->AddField(std::move(field_meta));
return field_id;
}
@ -63,7 +64,10 @@ class Schema {
// string type
void
AddField(const FieldName& name, const FieldId id, DataType data_type, int64_t max_length) {
AddField(const FieldName& name,
const FieldId id,
DataType data_type,
int64_t max_length) {
auto field_meta = FieldMeta(name, id, data_type, max_length);
this->AddField(std::move(field_meta));
}
@ -103,7 +107,8 @@ class Schema {
operator[](FieldId field_id) const {
Assert(field_id.get() >= 0);
AssertInfo(fields_.find(field_id) != fields_.end(),
"Cannot find field with field_id: " + std::to_string(field_id.get()));
"Cannot find field with field_id: " +
std::to_string(field_id.get()));
return fields_.at(field_id);
}
@ -131,7 +136,8 @@ class Schema {
const FieldMeta&
operator[](const FieldName& field_name) const {
auto id_iter = name_ids_.find(field_name);
AssertInfo(id_iter != name_ids_.end(), "Cannot find field with field_name: " + field_name.get());
AssertInfo(id_iter != name_ids_.end(),
"Cannot find field with field_name: " + field_name.get());
return fields_.at(id_iter->second);
}

View File

@ -27,8 +27,11 @@ static const char* SLICE_NUM = "slice_num";
static const char* TOTAL_LEN = "total_len";
void
Slice(
const std::string& prefix, const BinaryPtr& data_src, const int64_t slice_len, BinarySet& binarySet, Config& ret) {
Slice(const std::string& prefix,
const BinaryPtr& data_src,
const int64_t slice_len,
BinarySet& binarySet,
Config& ret) {
if (!data_src) {
return;
}
@ -39,7 +42,8 @@ Slice(
auto size = static_cast<size_t>(ri - i);
auto slice_i = std::shared_ptr<uint8_t[]>(new uint8_t[size]);
memcpy(slice_i.get(), data_src->data.get() + i, size);
binarySet.Append(prefix + "_" + std::to_string(slice_num), slice_i, ri - i);
binarySet.Append(
prefix + "_" + std::to_string(slice_num), slice_i, ri - i);
i = ri;
}
ret[NAME] = prefix;
@ -54,7 +58,8 @@ Assemble(BinarySet& binarySet) {
return;
}
Config meta_data = Config::parse(std::string(reinterpret_cast<char*>(slice_meta->data.get()), slice_meta->size));
Config meta_data = Config::parse(std::string(
reinterpret_cast<char*>(slice_meta->data.get()), slice_meta->size));
for (auto& item : meta_data[META]) {
std::string prefix = item[NAME];
@ -64,7 +69,9 @@ Assemble(BinarySet& binarySet) {
int64_t pos = 0;
for (auto i = 0; i < slice_num; ++i) {
auto slice_i_sp = binarySet.Erase(prefix + "_" + std::to_string(i));
memcpy(p_data.get() + pos, slice_i_sp->data.get(), static_cast<size_t>(slice_i_sp->size));
memcpy(p_data.get() + pos,
slice_i_sp->data.get(),
static_cast<size_t>(slice_i_sp->size));
pos += slice_i_sp->size;
}
binarySet.Append(prefix, p_data, total_len);
@ -76,8 +83,8 @@ Disassemble(BinarySet& binarySet) {
Config meta_info;
auto slice_meta = EraseSliceMeta(binarySet);
if (slice_meta != nullptr) {
Config last_meta_data =
Config::parse(std::string(reinterpret_cast<char*>(slice_meta->data.get()), slice_meta->size));
Config last_meta_data = Config::parse(std::string(
reinterpret_cast<char*>(slice_meta->data.get()), slice_meta->size));
for (auto& item : last_meta_data[META]) {
meta_info[META].emplace_back(item);
}
@ -92,7 +99,8 @@ Disassemble(BinarySet& binarySet) {
}
for (auto& key : slice_key_list) {
Config slice_i;
Slice(key, binarySet.Erase(key), slice_size_in_byte, binarySet, slice_i);
Slice(
key, binarySet.Erase(key), slice_size_in_byte, binarySet, slice_i);
meta_info[META].emplace_back(slice_i);
}
if (!slice_key_list.empty()) {

View File

@ -27,7 +27,9 @@ namespace milvus {
// type erasure to work around virtual restriction
class SpanBase {
public:
explicit SpanBase(const void* data, int64_t row_count, int64_t element_sizeof)
explicit SpanBase(const void* data,
int64_t row_count,
int64_t element_sizeof)
: data_(data), row_count_(row_count), element_sizeof_(element_sizeof) {
}
@ -57,17 +59,21 @@ class Span;
// TODO: refine Span to support T=FloatVector
template <typename T>
class Span<T, typename std::enable_if_t<IsScalar<T> || std::is_same_v<T, PkType>>> {
class Span<
T,
typename std::enable_if_t<IsScalar<T> || std::is_same_v<T, PkType>>> {
public:
using embeded_type = T;
explicit Span(const T* data, int64_t row_count) : data_(data), row_count_(row_count) {
using embedded_type = T;
explicit Span(const T* data, int64_t row_count)
: data_(data), row_count_(row_count) {
}
operator SpanBase() const {
return SpanBase(data_, row_count_, sizeof(T));
}
explicit Span(const SpanBase& base) : Span(reinterpret_cast<const T*>(base.data()), base.row_count()) {
explicit Span(const SpanBase& base)
: Span(reinterpret_cast<const T*>(base.data()), base.row_count()) {
assert(base.element_sizeof() == sizeof(T));
}
@ -97,7 +103,9 @@ class Span<T, typename std::enable_if_t<IsScalar<T> || std::is_same_v<T, PkType>
};
template <typename VectorType>
class Span<VectorType, typename std::enable_if_t<std::is_base_of_v<VectorTrait, VectorType>>> {
class Span<
VectorType,
typename std::enable_if_t<std::is_base_of_v<VectorTrait, VectorType>>> {
public:
using embedded_type = typename VectorType::embedded_type;

View File

@ -24,7 +24,8 @@ namespace milvus {
class SystemPropertyImpl : public SystemProperty {
public:
bool
SystemFieldVerify(const FieldName& field_name, FieldId field_id) const override {
SystemFieldVerify(const FieldName& field_name,
FieldId field_id) const override {
if (!IsSystem(field_name)) {
return false;
}

View File

@ -90,7 +90,8 @@ template <class...>
constexpr std::false_type always_false{};
template <typename T>
using aligned_vector = std::vector<T, boost::alignment::aligned_allocator<T, 64>>;
using aligned_vector =
std::vector<T, boost::alignment::aligned_allocator<T, 64>>;
namespace impl {
// hide identifier name to make auto-completion happy
@ -100,10 +101,15 @@ struct FieldOffsetTag;
struct SegOffsetTag;
}; // namespace impl
using FieldId = fluent::NamedType<int64_t, impl::FieldIdTag, fluent::Comparable, fluent::Hashable>;
using FieldName = fluent::NamedType<std::string, impl::FieldNameTag, fluent::Comparable, fluent::Hashable>;
using FieldId = fluent::
NamedType<int64_t, impl::FieldIdTag, fluent::Comparable, fluent::Hashable>;
using FieldName = fluent::NamedType<std::string,
impl::FieldNameTag,
fluent::Comparable,
fluent::Hashable>;
// using FieldOffset = fluent::NamedType<int64_t, impl::FieldOffsetTag, fluent::Comparable, fluent::Hashable>;
using SegOffset = fluent::NamedType<int64_t, impl::SegOffsetTag, fluent::Arithmetic>;
using SegOffset =
fluent::NamedType<int64_t, impl::SegOffsetTag, fluent::Arithmetic>;
using BitsetType = boost::dynamic_bitset<>;
using BitsetTypePtr = std::shared_ptr<boost::dynamic_bitset<>>;

View File

@ -75,7 +75,10 @@ PrefixMatch(const std::string_view str, const std::string_view prefix) {
}
inline DatasetPtr
GenResultDataset(const int64_t nq, const int64_t topk, const int64_t* ids, const float* distance) {
GenResultDataset(const int64_t nq,
const int64_t topk,
const int64_t* ids,
const float* distance) {
auto ret_ds = std::make_shared<Dataset>();
ret_ds->SetRows(nq);
ret_ds->SetDim(topk);
@ -192,7 +195,8 @@ GetDataSize(const FieldMeta& field, size_t row_count, const DataArray* data) {
}
default:
PanicInfo(fmt::format("not supported data type {}", datatype_name(data_type)));
PanicInfo(fmt::format("not supported data type {}",
datatype_name(data_type)));
}
}
@ -200,7 +204,10 @@ GetDataSize(const FieldMeta& field, size_t row_count, const DataArray* data) {
}
inline void*
FillField(DataType data_type, size_t size, const LoadFieldDataInfo& info, void* dst) {
FillField(DataType data_type,
size_t size,
const LoadFieldDataInfo& info,
void* dst) {
auto data = info.field_data;
switch (data_type) {
case DataType::BOOL: {
@ -225,10 +232,12 @@ FillField(DataType data_type, size_t size, const LoadFieldDataInfo& info, void*
return memcpy(dst, data->scalars().long_data().data().data(), size);
}
case DataType::FLOAT: {
return memcpy(dst, data->scalars().float_data().data().data(), size);
return memcpy(
dst, data->scalars().float_data().data().data(), size);
}
case DataType::DOUBLE: {
return memcpy(dst, data->scalars().double_data().data().data(), size);
return memcpy(
dst, data->scalars().double_data().data().data(), size);
}
case DataType::VARCHAR: {
char* dest = reinterpret_cast<char*>(dst);
@ -243,7 +252,8 @@ FillField(DataType data_type, size_t size, const LoadFieldDataInfo& info, void*
return dst;
}
case DataType::VECTOR_FLOAT:
return memcpy(dst, data->vectors().float_vector().data().data(), size);
return memcpy(
dst, data->vectors().float_vector().data().data(), size);
case DataType::VECTOR_BINARY:
return memcpy(dst, data->vectors().binary_vector().data(), size);
@ -300,7 +310,8 @@ WriteFieldData(int fd, DataType data_type, const DataArray* data, size_t size) {
return total_written;
}
case DataType::VECTOR_FLOAT:
return write(fd, data->vectors().float_vector().data().data(), size);
return write(
fd, data->vectors().float_vector().data().data(), size);
case DataType::VECTOR_BINARY:
return write(fd, data->vectors().binary_vector().data(), size);
@ -315,7 +326,9 @@ WriteFieldData(int fd, DataType data_type, const DataArray* data, size_t size) {
// if mmap enabled, this writes field data to disk and create a map to the file,
// otherwise this just alloc memory
inline void*
CreateMap(int64_t segment_id, const FieldMeta& field_meta, const LoadFieldDataInfo& info) {
CreateMap(int64_t segment_id,
const FieldMeta& field_meta,
const LoadFieldDataInfo& info) {
static int mmap_flags = MAP_PRIVATE;
#ifdef MAP_POPULATE
// macOS doesn't support MAP_POPULATE
@ -324,33 +337,52 @@ CreateMap(int64_t segment_id, const FieldMeta& field_meta, const LoadFieldDataIn
// Allocate memory
if (info.mmap_dir_path == nullptr) {
auto data_type = field_meta.get_data_type();
auto data_size = GetDataSize(field_meta, info.row_count, info.field_data);
auto data_size =
GetDataSize(field_meta, info.row_count, info.field_data);
if (data_size == 0)
return nullptr;
// Use anon mapping so we are able to free these memory with munmap only
void* map = mmap(NULL, data_size, PROT_READ | PROT_WRITE, mmap_flags | MAP_ANON, -1, 0);
AssertInfo(map != MAP_FAILED, fmt::format("failed to create anon map, err: {}", strerror(errno)));
void* map = mmap(NULL,
data_size,
PROT_READ | PROT_WRITE,
mmap_flags | MAP_ANON,
-1,
0);
AssertInfo(
map != MAP_FAILED,
fmt::format("failed to create anon map, err: {}", strerror(errno)));
FillField(data_type, data_size, info, map);
return map;
}
auto filepath =
std::filesystem::path(info.mmap_dir_path) / std::to_string(segment_id) / std::to_string(info.field_id);
auto filepath = std::filesystem::path(info.mmap_dir_path) /
std::to_string(segment_id) / std::to_string(info.field_id);
auto dir = filepath.parent_path();
std::filesystem::create_directories(dir);
int fd = open(filepath.c_str(), O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR);
AssertInfo(fd != -1, fmt::format("failed to create mmap file {}", filepath.c_str()));
int fd =
open(filepath.c_str(), O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR);
AssertInfo(fd != -1,
fmt::format("failed to create mmap file {}", filepath.c_str()));
auto data_type = field_meta.get_data_type();
size_t size = field_meta.get_sizeof() * info.row_count;
auto written = WriteFieldData(fd, data_type, info.field_data, size);
AssertInfo(written == size || written != -1 && datatype_is_variable(field_meta.get_data_type()),
fmt::format("failed to write data file {}, written {} but total {}, err: {}", filepath.c_str(), written,
size, strerror(errno)));
AssertInfo(
written == size ||
written != -1 && datatype_is_variable(field_meta.get_data_type()),
fmt::format(
"failed to write data file {}, written {} but total {}, err: {}",
filepath.c_str(),
written,
size,
strerror(errno)));
int ok = fsync(fd);
AssertInfo(ok == 0, fmt::format("failed to fsync mmap data file {}, err: {}", filepath.c_str(), strerror(errno)));
AssertInfo(ok == 0,
fmt::format("failed to fsync mmap data file {}, err: {}",
filepath.c_str(),
strerror(errno)));
// Empty field
if (written == 0) {
@ -359,7 +391,9 @@ CreateMap(int64_t segment_id, const FieldMeta& field_meta, const LoadFieldDataIn
auto map = mmap(NULL, written, PROT_READ, mmap_flags, fd, 0);
AssertInfo(map != MAP_FAILED,
fmt::format("failed to create map for data file {}, err: {}", filepath.c_str(), strerror(errno)));
fmt::format("failed to create map for data file {}, err: {}",
filepath.c_str(),
strerror(errno)));
#ifndef MAP_POPULATE
// Manually access the mapping to populate it
@ -373,9 +407,15 @@ CreateMap(int64_t segment_id, const FieldMeta& field_meta, const LoadFieldDataIn
// unlink this data file so
// then it will be auto removed after we don't need it again
ok = unlink(filepath.c_str());
AssertInfo(ok == 0, fmt::format("failed to unlink mmap data file {}, err: {}", filepath.c_str(), strerror(errno)));
AssertInfo(ok == 0,
fmt::format("failed to unlink mmap data file {}, err: {}",
filepath.c_str(),
strerror(errno)));
ok = close(fd);
AssertInfo(ok == 0, fmt::format("failed to close data file {}, err: {}", filepath.c_str(), strerror(errno)));
AssertInfo(ok == 0,
fmt::format("failed to close data file {}, err: {}",
filepath.c_str(),
strerror(errno)));
return map;
}

View File

@ -50,7 +50,8 @@ constexpr bool IsVector = std::is_base_of_v<VectorTrait, T>;
template <typename T>
constexpr bool IsScalar =
std::is_fundamental_v<T> || std::is_same_v<T, std::string> || std::is_same_v<T, std::string_view>;
std::is_fundamental_v<T> || std::is_same_v<T, std::string> ||
std::is_same_v<T, std::string_view>;
template <typename T, typename Enabled = void>
struct EmbeddedTypeImpl;
@ -62,7 +63,8 @@ struct EmbeddedTypeImpl<T, std::enable_if_t<IsScalar<T>>> {
template <typename T>
struct EmbeddedTypeImpl<T, std::enable_if_t<IsVector<T>>> {
using type = std::conditional_t<std::is_same_v<T, FloatVector>, float, uint8_t>;
using type =
std::conditional_t<std::is_same_v<T, FloatVector>, float, uint8_t>;
};
template <typename T>

View File

@ -41,7 +41,10 @@ DeleteBinarySet(CBinarySet c_binary_set) {
}
CStatus
AppendIndexBinary(CBinarySet c_binary_set, void* index_binary, int64_t index_size, const char* c_index_key) {
AppendIndexBinary(CBinarySet c_binary_set,
void* index_binary,
int64_t index_size,
const char* c_index_key) {
auto status = CStatus();
try {
auto binary_set = (knowhere::BinarySet*)c_binary_set;
@ -68,13 +71,13 @@ GetBinarySetSize(CBinarySet c_binary_set) {
}
void
GetBinarySetKeys(CBinarySet c_binary_set, void* datas) {
GetBinarySetKeys(CBinarySet c_binary_set, void* data) {
auto binary_set = (knowhere::BinarySet*)c_binary_set;
auto& map_ = binary_set->binary_map_;
const char** datas_ = (const char**)datas;
const char** data_ = (const char**)data;
std::size_t i = 0;
for (auto it = map_.begin(); it != map_.end(); ++it, i++) {
datas_[i] = it->first.c_str();
data_[i] = it->first.c_str();
}
}

View File

@ -32,13 +32,16 @@ void
DeleteBinarySet(CBinarySet c_binary_set);
CStatus
AppendIndexBinary(CBinarySet c_binary_set, void* index_binary, int64_t index_size, const char* c_index_key);
AppendIndexBinary(CBinarySet c_binary_set,
void* index_binary,
int64_t index_size,
const char* c_index_key);
int
GetBinarySetSize(CBinarySet c_binary_set);
void
GetBinarySetKeys(CBinarySet c_binary_set, void* datas);
GetBinarySetKeys(CBinarySet c_binary_set, void* data);
int
GetBinarySetValueSize(CBinarySet c_set, const char* key);

View File

@ -29,7 +29,11 @@ void
InitLocalRootPath(const char* root_path) {
std::string local_path_root(root_path);
std::call_once(
flag1, [](std::string path) { milvus::ChunkMangerConfig::SetLocalRootPath(path); }, local_path_root);
flag1,
[](std::string path) {
milvus::ChunkMangerConfig::SetLocalRootPath(path);
},
local_path_root);
}
void
@ -41,7 +45,9 @@ InitIndexSliceSize(const int64_t size) {
void
InitThreadCoreCoefficient(const int64_t value) {
std::call_once(
flag3, [](int64_t value) { milvus::SetThreadCoreCoefficient(value); }, value);
flag3,
[](int64_t value) { milvus::SetThreadCoreCoefficient(value); },
value);
}
void

View File

@ -45,7 +45,8 @@ KnowhereInitImpl(const char* conf_file) {
if (conf_file != nullptr) {
el::Configurations el_conf(conf_file);
el::Loggers::reconfigureAllLoggers(el_conf);
LOG_SERVER_DEBUG_ << "Config easylogging with yaml file: " << conf_file;
LOG_SERVER_DEBUG_ << "Config easylogging with yaml file: "
<< conf_file;
}
LOG_SERVER_DEBUG_ << "Knowhere init successfully";
#endif

View File

@ -45,7 +45,8 @@ EasyAssertInfo(bool value,
if (!value) {
std::string info;
info += "Assert \"" + std::string(expr_str) + "\"";
info += " at " + std::string(filename) + ":" + std::to_string(lineno) + "\n";
info += " at " + std::string(filename) + ":" + std::to_string(lineno) +
"\n";
if (!extra_info.empty()) {
info += " => " + std::string(extra_info);
}

View File

@ -54,13 +54,14 @@ class SegcoreError : public std::runtime_error {
} // namespace milvus
#define AssertInfo(expr, info) \
do { \
auto _expr_res = bool(expr); \
/* call func only when needed */ \
if (!_expr_res) { \
milvus::impl::EasyAssertInfo(_expr_res, #expr, __FILE__, __LINE__, (info)); \
} \
#define AssertInfo(expr, info) \
do { \
auto _expr_res = bool(expr); \
/* call func only when needed */ \
if (!_expr_res) { \
milvus::impl::EasyAssertInfo( \
_expr_res, #expr, __FILE__, __LINE__, (info)); \
} \
} while (0)
#define Assert(expr) AssertInfo((expr), "")
@ -70,8 +71,9 @@ class SegcoreError : public std::runtime_error {
__builtin_unreachable(); \
} while (0)
#define PanicCodeInfo(errcode, info) \
do { \
milvus::impl::EasyAssertInfo(false, (info), __FILE__, __LINE__, "", errcode); \
__builtin_unreachable(); \
#define PanicCodeInfo(errcode, info) \
do { \
milvus::impl::EasyAssertInfo( \
false, (info), __FILE__, __LINE__, "", errcode); \
__builtin_unreachable(); \
} while (0)

View File

@ -34,7 +34,9 @@ class IndexBase {
Load(const BinarySet& binary_set, const Config& config = {}) = 0;
virtual void
BuildWithRawData(size_t n, const void* values, const Config& config = {}) = 0;
BuildWithRawData(size_t n,
const void* values,
const Config& config = {}) = 0;
virtual void
BuildWithDataset(const DatasetPtr& dataset, const Config& config = {}) = 0;

View File

@ -27,7 +27,8 @@
namespace milvus::index {
IndexBasePtr
IndexFactory::CreateIndex(const CreateIndexInfo& create_index_info, storage::FileManagerImplPtr file_manager) {
IndexFactory::CreateIndex(const CreateIndexInfo& create_index_info,
storage::FileManagerImplPtr file_manager) {
if (datatype_is_vector(create_index_info.field_type)) {
return CreateVectorIndex(create_index_info, file_manager);
}
@ -62,13 +63,15 @@ IndexFactory::CreateScalarIndex(const CreateIndexInfo& create_index_info) {
case DataType::VARCHAR:
return CreateScalarIndex<std::string>(index_type);
default:
throw std::invalid_argument(std::string("invalid data type to build index: ") +
std::to_string(int(data_type)));
throw std::invalid_argument(
std::string("invalid data type to build index: ") +
std::to_string(int(data_type)));
}
}
IndexBasePtr
IndexFactory::CreateVectorIndex(const CreateIndexInfo& create_index_info, storage::FileManagerImplPtr file_manager) {
IndexFactory::CreateVectorIndex(const CreateIndexInfo& create_index_info,
storage::FileManagerImplPtr file_manager) {
auto data_type = create_index_info.field_type;
auto index_type = create_index_info.index_type;
auto metric_type = create_index_info.metric_type;
@ -79,20 +82,24 @@ IndexFactory::CreateVectorIndex(const CreateIndexInfo& create_index_info, storag
if (is_in_disk_list(index_type)) {
switch (data_type) {
case DataType::VECTOR_FLOAT: {
return std::make_unique<VectorDiskAnnIndex<float>>(index_type, metric_type, index_mode, file_manager);
return std::make_unique<VectorDiskAnnIndex<float>>(
index_type, metric_type, index_mode, file_manager);
}
default:
throw std::invalid_argument(std::string("invalid data type to build disk index: ") +
std::to_string(int(data_type)));
throw std::invalid_argument(
std::string("invalid data type to build disk index: ") +
std::to_string(int(data_type)));
}
}
#endif
if (is_in_nm_list(index_type)) {
return std::make_unique<VectorMemNMIndex>(index_type, metric_type, index_mode);
return std::make_unique<VectorMemNMIndex>(
index_type, metric_type, index_mode);
}
// create mem index
return std::make_unique<VectorMemIndex>(index_type, metric_type, index_mode);
return std::make_unique<VectorMemIndex>(
index_type, metric_type, index_mode);
}
} // namespace milvus::index

View File

@ -53,10 +53,12 @@ class IndexFactory {
}
IndexBasePtr
CreateIndex(const CreateIndexInfo& create_index_info, storage::FileManagerImplPtr file_manager);
CreateIndex(const CreateIndexInfo& create_index_info,
storage::FileManagerImplPtr file_manager);
IndexBasePtr
CreateVectorIndex(const CreateIndexInfo& create_index_info, storage::FileManagerImplPtr file_manager);
CreateVectorIndex(const CreateIndexInfo& create_index_info,
storage::FileManagerImplPtr file_manager);
IndexBasePtr
CreateScalarIndex(const CreateIndexInfo& create_index_info);

View File

@ -38,9 +38,14 @@ ScalarIndex<T>::Query(const DatasetPtr& dataset) {
case OpType::Range: {
auto lower_bound_value = dataset->Get<T>(LOWER_BOUND_VALUE);
auto upper_bound_value = dataset->Get<T>(UPPER_BOUND_VALUE);
auto lower_bound_inclusive = dataset->Get<bool>(LOWER_BOUND_INCLUSIVE);
auto upper_bound_inclusive = dataset->Get<bool>(UPPER_BOUND_INCLUSIVE);
return Range(lower_bound_value, lower_bound_inclusive, upper_bound_value, upper_bound_inclusive);
auto lower_bound_inclusive =
dataset->Get<bool>(LOWER_BOUND_INCLUSIVE);
auto upper_bound_inclusive =
dataset->Get<bool>(UPPER_BOUND_INCLUSIVE);
return Range(lower_bound_value,
lower_bound_inclusive,
upper_bound_value,
upper_bound_inclusive);
}
case OpType::In: {
@ -58,13 +63,16 @@ ScalarIndex<T>::Query(const DatasetPtr& dataset) {
case OpType::PrefixMatch:
case OpType::PostfixMatch:
default:
throw std::invalid_argument(std::string("unsupported operator type: " + std::to_string(op)));
throw std::invalid_argument(std::string(
"unsupported operator type: " + std::to_string(op)));
}
}
template <>
inline void
ScalarIndex<std::string>::BuildWithRawData(size_t n, const void* values, const Config& config) {
ScalarIndex<std::string>::BuildWithRawData(size_t n,
const void* values,
const Config& config) {
// TODO :: use arrow
proto::schema::StringArray arr;
auto ok = arr.ParseFromArray(values, n);
@ -77,7 +85,9 @@ ScalarIndex<std::string>::BuildWithRawData(size_t n, const void* values, const C
template <>
inline void
ScalarIndex<bool>::BuildWithRawData(size_t n, const void* values, const Config& config) {
ScalarIndex<bool>::BuildWithRawData(size_t n,
const void* values,
const Config& config) {
proto::schema::BoolArray arr;
auto ok = arr.ParseFromArray(values, n);
Assert(ok);
@ -86,42 +96,54 @@ ScalarIndex<bool>::BuildWithRawData(size_t n, const void* values, const Config&
template <>
inline void
ScalarIndex<int8_t>::BuildWithRawData(size_t n, const void* values, const Config& config) {
ScalarIndex<int8_t>::BuildWithRawData(size_t n,
const void* values,
const Config& config) {
auto data = reinterpret_cast<int8_t*>(const_cast<void*>(values));
Build(n, data);
}
template <>
inline void
ScalarIndex<int16_t>::BuildWithRawData(size_t n, const void* values, const Config& config) {
ScalarIndex<int16_t>::BuildWithRawData(size_t n,
const void* values,
const Config& config) {
auto data = reinterpret_cast<int16_t*>(const_cast<void*>(values));
Build(n, data);
}
template <>
inline void
ScalarIndex<int32_t>::BuildWithRawData(size_t n, const void* values, const Config& config) {
ScalarIndex<int32_t>::BuildWithRawData(size_t n,
const void* values,
const Config& config) {
auto data = reinterpret_cast<int32_t*>(const_cast<void*>(values));
Build(n, data);
}
template <>
inline void
ScalarIndex<int64_t>::BuildWithRawData(size_t n, const void* values, const Config& config) {
ScalarIndex<int64_t>::BuildWithRawData(size_t n,
const void* values,
const Config& config) {
auto data = reinterpret_cast<int64_t*>(const_cast<void*>(values));
Build(n, data);
}
template <>
inline void
ScalarIndex<float>::BuildWithRawData(size_t n, const void* values, const Config& config) {
ScalarIndex<float>::BuildWithRawData(size_t n,
const void* values,
const Config& config) {
auto data = reinterpret_cast<float*>(const_cast<void*>(values));
Build(n, data);
}
template <>
inline void
ScalarIndex<double>::BuildWithRawData(size_t n, const void* values, const Config& config) {
ScalarIndex<double>::BuildWithRawData(size_t n,
const void* values,
const Config& config) {
auto data = reinterpret_cast<double*>(const_cast<void*>(values));
Build(n, data);
}

View File

@ -31,10 +31,13 @@ template <typename T>
class ScalarIndex : public IndexBase {
public:
void
BuildWithRawData(size_t n, const void* values, const Config& config = {}) override;
BuildWithRawData(size_t n,
const void* values,
const Config& config = {}) override;
void
BuildWithDataset(const DatasetPtr& dataset, const Config& config = {}) override {
BuildWithDataset(const DatasetPtr& dataset,
const Config& config = {}) override {
PanicInfo("scalar index don't support build index with dataset");
};
@ -52,7 +55,10 @@ class ScalarIndex : public IndexBase {
Range(T value, OpType op) = 0;
virtual const TargetBitmapPtr
Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) = 0;
Range(T lower_bound_value,
bool lb_inclusive,
T upper_bound_value,
bool ub_inclusive) = 0;
virtual T
Reverse_Lookup(size_t offset) const = 0;

View File

@ -32,7 +32,8 @@ inline ScalarIndexSort<T>::ScalarIndexSort() : is_built_(false), data_() {
}
template <typename T>
inline ScalarIndexSort<T>::ScalarIndexSort(const size_t n, const T* values) : is_built_(false) {
inline ScalarIndexSort<T>::ScalarIndexSort(const size_t n, const T* values)
: is_built_(false) {
ScalarIndexSort<T>::BuildWithDataset(n, values);
}
@ -43,7 +44,8 @@ ScalarIndexSort<T>::Build(const size_t n, const T* values) {
return;
if (n == 0) {
// todo: throw an exception
throw std::invalid_argument("ScalarIndexSort cannot build null values!");
throw std::invalid_argument(
"ScalarIndexSort cannot build null values!");
}
data_.reserve(n);
idx_to_offsets_.resize(n);
@ -104,12 +106,15 @@ ScalarIndexSort<T>::In(const size_t n, const T* values) {
AssertInfo(is_built_, "index has not been built");
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
for (size_t i = 0; i < n; ++i) {
auto lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
auto ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
auto lb = std::lower_bound(
data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
auto ub = std::upper_bound(
data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
for (; lb < ub; ++lb) {
if (lb->a_ != *(values + i)) {
std::cout << "error happens in ScalarIndexSort<T>::In, experted value is: " << *(values + i)
<< ", but real value is: " << lb->a_;
std::cout << "error happens in ScalarIndexSort<T>::In, "
"experted value is: "
<< *(values + i) << ", but real value is: " << lb->a_;
}
bitset->set(lb->idx_);
}
@ -124,12 +129,15 @@ ScalarIndexSort<T>::NotIn(const size_t n, const T* values) {
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
bitset->set();
for (size_t i = 0; i < n; ++i) {
auto lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
auto ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
auto lb = std::lower_bound(
data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
auto ub = std::upper_bound(
data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
for (; lb < ub; ++lb) {
if (lb->a_ != *(values + i)) {
std::cout << "error happens in ScalarIndexSort<T>::NotIn, experted value is: " << *(values + i)
<< ", but real value is: " << lb->a_;
std::cout << "error happens in ScalarIndexSort<T>::NotIn, "
"experted value is: "
<< *(values + i) << ", but real value is: " << lb->a_;
}
bitset->reset(lb->idx_);
}
@ -146,19 +154,24 @@ ScalarIndexSort<T>::Range(const T value, const OpType op) {
auto ub = data_.end();
switch (op) {
case OpType::LessThan:
ub = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
ub = std::lower_bound(
data_.begin(), data_.end(), IndexStructure<T>(value));
break;
case OpType::LessEqual:
ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
ub = std::upper_bound(
data_.begin(), data_.end(), IndexStructure<T>(value));
break;
case OpType::GreaterThan:
lb = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
lb = std::upper_bound(
data_.begin(), data_.end(), IndexStructure<T>(value));
break;
case OpType::GreaterEqual:
lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
lb = std::lower_bound(
data_.begin(), data_.end(), IndexStructure<T>(value));
break;
default:
throw std::invalid_argument(std::string("Invalid OperatorType: ") + std::to_string((int)op) + "!");
throw std::invalid_argument(std::string("Invalid OperatorType: ") +
std::to_string((int)op) + "!");
}
for (; lb < ub; ++lb) {
bitset->set(lb->idx_);
@ -168,24 +181,32 @@ ScalarIndexSort<T>::Range(const T value, const OpType op) {
template <typename T>
inline const TargetBitmapPtr
ScalarIndexSort<T>::Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) {
ScalarIndexSort<T>::Range(T lower_bound_value,
bool lb_inclusive,
T upper_bound_value,
bool ub_inclusive) {
AssertInfo(is_built_, "index has not been built");
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
if (lower_bound_value > upper_bound_value ||
(lower_bound_value == upper_bound_value && !(lb_inclusive && ub_inclusive))) {
(lower_bound_value == upper_bound_value &&
!(lb_inclusive && ub_inclusive))) {
return bitset;
}
auto lb = data_.begin();
auto ub = data_.end();
if (lb_inclusive) {
lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(lower_bound_value));
lb = std::lower_bound(
data_.begin(), data_.end(), IndexStructure<T>(lower_bound_value));
} else {
lb = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(lower_bound_value));
lb = std::upper_bound(
data_.begin(), data_.end(), IndexStructure<T>(lower_bound_value));
}
if (ub_inclusive) {
ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(upper_bound_value));
ub = std::upper_bound(
data_.begin(), data_.end(), IndexStructure<T>(upper_bound_value));
} else {
ub = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(upper_bound_value));
ub = std::lower_bound(
data_.begin(), data_.end(), IndexStructure<T>(upper_bound_value));
}
for (; lb < ub; ++lb) {
bitset->set(lb->idx_);

View File

@ -56,7 +56,10 @@ class ScalarIndexSort : public ScalarIndex<T> {
Range(T value, OpType op) override;
const TargetBitmapPtr
Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) override;
Range(T lower_bound_value,
bool lb_inclusive,
T upper_bound_value,
bool ub_inclusive) override;
T
Reverse_Lookup(size_t offset) const override;

View File

@ -63,7 +63,8 @@ StringIndexMarisa::Serialize(const Config& config) {
auto uuid_string = boost::uuids::to_string(uuid);
auto file = std::string("/tmp/") + uuid_string;
auto fd = open(file.c_str(), O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR | S_IXUSR);
auto fd = open(
file.c_str(), O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR | S_IXUSR);
trie_.write(fd);
auto size = get_file_size(fd);
@ -101,7 +102,8 @@ StringIndexMarisa::Load(const BinarySet& set, const Config& config) {
auto index = set.GetByName(MARISA_TRIE_INDEX);
auto len = index->size;
auto fd = open(file.c_str(), O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR | S_IXUSR);
auto fd = open(
file.c_str(), O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR | S_IXUSR);
lseek(fd, 0, SEEK_SET);
while (write(fd, index->data.get(), len) != len) {
lseek(fd, 0, SEEK_SET);
@ -182,7 +184,9 @@ StringIndexMarisa::Range(std::string value, OpType op) {
set = raw_data.compare(value) >= 0;
break;
default:
throw std::invalid_argument(std::string("Invalid OperatorType: ") + std::to_string((int)op) + "!");
throw std::invalid_argument(
std::string("Invalid OperatorType: ") +
std::to_string((int)op) + "!");
}
if (set) {
bitset->set(offset);
@ -199,7 +203,8 @@ StringIndexMarisa::Range(std::string lower_bound_value,
auto count = Count();
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(count);
if (lower_bound_value.compare(upper_bound_value) > 0 ||
(lower_bound_value.compare(upper_bound_value) == 0 && !(lb_inclusive && ub_inclusive))) {
(lower_bound_value.compare(upper_bound_value) == 0 &&
!(lb_inclusive && ub_inclusive))) {
return bitset;
}
marisa::Agent agent;

View File

@ -58,7 +58,10 @@ class StringIndexMarisa : public StringIndex {
Range(std::string value, OpType op) override;
const TargetBitmapPtr
Range(std::string lower_bound_value, bool lb_inclusive, std::string upper_bound_value, bool ub_inclusive) override;
Range(std::string lower_bound_value,
bool lb_inclusive,
std::string upper_bound_value,
bool ub_inclusive) override;
const TargetBitmapPtr
PrefixMatch(std::string prefix) override;

View File

@ -62,7 +62,8 @@ DISK_LIST() {
std::vector<std::tuple<IndexType, MetricType>>
unsupported_index_combinations() {
static std::vector<std::tuple<IndexType, MetricType>> ret{
std::make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, knowhere::metric::L2),
std::make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT,
knowhere::metric::L2),
};
return ret;
}
@ -84,8 +85,9 @@ is_in_disk_list(const IndexType& index_type) {
bool
is_unsupported(const IndexType& index_type, const MetricType& metric_type) {
return is_in_list<std::tuple<IndexType, MetricType>>(std::make_tuple(index_type, metric_type),
unsupported_index_combinations);
return is_in_list<std::tuple<IndexType, MetricType>>(
std::make_tuple(index_type, metric_type),
unsupported_index_combinations);
}
bool
@ -123,7 +125,8 @@ GetIndexTypeFromConfig(const Config& config) {
IndexMode
GetIndexModeFromConfig(const Config& config) {
auto mode = GetValueFromConfig<std::string>(config, INDEX_MODE);
return mode.has_value() ? GetIndexMode(mode.value()) : knowhere::IndexMode::MODE_CPU;
return mode.has_value() ? GetIndexMode(mode.value())
: knowhere::IndexMode::MODE_CPU;
}
IndexMode
@ -144,22 +147,28 @@ storage::FieldDataMeta
GetFieldDataMetaFromConfig(const Config& config) {
storage::FieldDataMeta field_data_meta;
// set collection id
auto collection_id = index::GetValueFromConfig<std::string>(config, index::COLLECTION_ID);
AssertInfo(collection_id.has_value(), "collection id not exist in index config");
auto collection_id =
index::GetValueFromConfig<std::string>(config, index::COLLECTION_ID);
AssertInfo(collection_id.has_value(),
"collection id not exist in index config");
field_data_meta.collection_id = std::stol(collection_id.value());
// set partition id
auto partition_id = index::GetValueFromConfig<std::string>(config, index::PARTITION_ID);
AssertInfo(partition_id.has_value(), "partition id not exist in index config");
auto partition_id =
index::GetValueFromConfig<std::string>(config, index::PARTITION_ID);
AssertInfo(partition_id.has_value(),
"partition id not exist in index config");
field_data_meta.partition_id = std::stol(partition_id.value());
// set segment id
auto segment_id = index::GetValueFromConfig<std::string>(config, index::SEGMENT_ID);
auto segment_id =
index::GetValueFromConfig<std::string>(config, index::SEGMENT_ID);
AssertInfo(segment_id.has_value(), "segment id not exist in index config");
field_data_meta.segment_id = std::stol(segment_id.value());
// set field id
auto field_id = index::GetValueFromConfig<std::string>(config, index::FIELD_ID);
auto field_id =
index::GetValueFromConfig<std::string>(config, index::FIELD_ID);
AssertInfo(field_id.has_value(), "field id not exist in index config");
field_data_meta.field_id = std::stol(field_id.value());
@ -170,22 +179,27 @@ storage::IndexMeta
GetIndexMetaFromConfig(const Config& config) {
storage::IndexMeta index_meta;
// set segment id
auto segment_id = index::GetValueFromConfig<std::string>(config, index::SEGMENT_ID);
auto segment_id =
index::GetValueFromConfig<std::string>(config, index::SEGMENT_ID);
AssertInfo(segment_id.has_value(), "segment id not exist in index config");
index_meta.segment_id = std::stol(segment_id.value());
// set field id
auto field_id = index::GetValueFromConfig<std::string>(config, index::FIELD_ID);
auto field_id =
index::GetValueFromConfig<std::string>(config, index::FIELD_ID);
AssertInfo(field_id.has_value(), "field id not exist in index config");
index_meta.field_id = std::stol(field_id.value());
// set index version
auto index_version = index::GetValueFromConfig<std::string>(config, index::INDEX_VERSION);
AssertInfo(index_version.has_value(), "index_version id not exist in index config");
auto index_version =
index::GetValueFromConfig<std::string>(config, index::INDEX_VERSION);
AssertInfo(index_version.has_value(),
"index_version id not exist in index config");
index_meta.index_version = std::stol(index_version.value());
// set index id
auto build_id = index::GetValueFromConfig<std::string>(config, index::INDEX_BUILD_ID);
auto build_id =
index::GetValueFromConfig<std::string>(config, index::INDEX_BUILD_ID);
AssertInfo(build_id.has_value(), "build id not exist in index config");
index_meta.build_id = std::stol(build_id.value());
@ -193,7 +207,8 @@ GetIndexMetaFromConfig(const Config& config) {
}
Config
ParseConfigFromIndexParams(const std::map<std::string, std::string>& index_params) {
ParseConfigFromIndexParams(
const std::map<std::string, std::string>& index_params) {
Config config;
for (auto& p : index_params) {
config[p.first] = p.second;

View File

@ -121,6 +121,7 @@ storage::IndexMeta
GetIndexMetaFromConfig(const Config& config);
Config
ParseConfigFromIndexParams(const std::map<std::string, std::string>& index_params);
ParseConfigFromIndexParams(
const std::map<std::string, std::string>& index_params);
} // namespace milvus::index

View File

@ -35,12 +35,14 @@ namespace milvus::index {
#define kPrepareRows 1
template <typename T>
VectorDiskAnnIndex<T>::VectorDiskAnnIndex(const IndexType& index_type,
const MetricType& metric_type,
const IndexMode& index_mode,
storage::FileManagerImplPtr file_manager)
VectorDiskAnnIndex<T>::VectorDiskAnnIndex(
const IndexType& index_type,
const MetricType& metric_type,
const IndexMode& index_mode,
storage::FileManagerImplPtr file_manager)
: VectorIndex(index_type, index_mode, metric_type) {
file_manager_ = std::dynamic_pointer_cast<storage::DiskFileManagerImpl>(file_manager);
file_manager_ =
std::dynamic_pointer_cast<storage::DiskFileManagerImpl>(file_manager);
auto& local_chunk_manager = storage::LocalChunkManager::GetInstance();
auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix();
@ -52,17 +54,22 @@ VectorDiskAnnIndex<T>::VectorDiskAnnIndex(const IndexType& index_type,
}
local_chunk_manager.CreateDir(local_index_path_prefix);
auto diskann_index_pack = knowhere::Pack(std::shared_ptr<knowhere::FileManager>(file_manager));
index_ = knowhere::IndexFactory::Instance().Create(GetIndexType(), diskann_index_pack);
auto diskann_index_pack =
knowhere::Pack(std::shared_ptr<knowhere::FileManager>(file_manager));
index_ = knowhere::IndexFactory::Instance().Create(GetIndexType(),
diskann_index_pack);
}
template <typename T>
void
VectorDiskAnnIndex<T>::Load(const BinarySet& binary_set /* not used */, const Config& config) {
VectorDiskAnnIndex<T>::Load(const BinarySet& binary_set /* not used */,
const Config& config) {
knowhere::Json load_config = update_load_json(config);
auto index_files = GetValueFromConfig<std::vector<std::string>>(config, "index_files");
AssertInfo(index_files.has_value(), "index file paths is empty when load disk ann index data");
auto index_files =
GetValueFromConfig<std::vector<std::string>>(config, "index_files");
AssertInfo(index_files.has_value(),
"index file paths is empty when load disk ann index data");
file_manager_->CacheIndexToDisk(index_files.value());
// todo : replace by index::load function later
@ -79,21 +86,25 @@ VectorDiskAnnIndex<T>::Load(const BinarySet& binary_set /* not used */, const Co
template <typename T>
void
VectorDiskAnnIndex<T>::BuildWithDataset(const DatasetPtr& dataset, const Config& config) {
VectorDiskAnnIndex<T>::BuildWithDataset(const DatasetPtr& dataset,
const Config& config) {
auto& local_chunk_manager = storage::LocalChunkManager::GetInstance();
knowhere::Json build_config;
build_config.update(config);
// set data path
auto segment_id = file_manager_->GetFileDataMeta().segment_id;
auto field_id = file_manager_->GetFileDataMeta().field_id;
auto local_data_path = storage::GenFieldRawDataPathPrefix(segment_id, field_id) + "raw_data";
auto local_data_path =
storage::GenFieldRawDataPathPrefix(segment_id, field_id) + "raw_data";
build_config[DISK_ANN_RAW_DATA_PATH] = local_data_path;
auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix();
build_config[DISK_ANN_PREFIX_PATH] = local_index_path_prefix;
auto num_threads = GetValueFromConfig<std::string>(build_config, DISK_ANN_BUILD_THREAD_NUM);
AssertInfo(num_threads.has_value(), "param " + std::string(DISK_ANN_BUILD_THREAD_NUM) + "is empty");
auto num_threads = GetValueFromConfig<std::string>(
build_config, DISK_ANN_BUILD_THREAD_NUM);
AssertInfo(num_threads.has_value(),
"param " + std::string(DISK_ANN_BUILD_THREAD_NUM) + "is empty");
build_config[DISK_ANN_THREADS_NUM] = std::atoi(num_threads.value().c_str());
if (!local_chunk_manager.Exist(local_data_path)) {
@ -116,14 +127,17 @@ VectorDiskAnnIndex<T>::BuildWithDataset(const DatasetPtr& dataset, const Config&
knowhere::DataSet* ds_ptr = nullptr;
index_.Build(*ds_ptr, build_config);
local_chunk_manager.RemoveDir(storage::GetSegmentRawDataPathPrefix(segment_id));
local_chunk_manager.RemoveDir(
storage::GetSegmentRawDataPathPrefix(segment_id));
// TODO ::
// SetDim(index_->Dim());
}
template <typename T>
std::unique_ptr<SearchResult>
VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) {
VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
const SearchInfo& search_info,
const BitsetView& bitset) {
AssertInfo(GetMetricType() == search_info.metric_type_,
"Metric type of field index isn't the same with search info");
auto num_queries = dataset->GetRows();
@ -135,12 +149,18 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset, const SearchInfo& search_
search_config[knowhere::meta::METRIC_TYPE] = GetMetricType();
// set search list size
auto search_list_size = GetValueFromConfig<uint32_t>(search_info.search_params_, DISK_ANN_QUERY_LIST);
AssertInfo(search_list_size.has_value(), "param " + std::string(DISK_ANN_QUERY_LIST) + "is empty");
AssertInfo(search_list_size.value() >= topk, "search_list should be greater than or equal to topk");
AssertInfo(search_list_size.value() <= std::max(uint32_t(topk * 10), uint32_t(kSearchListMaxValue1)) &&
search_list_size.value() <= uint32_t(kSearchListMaxValue2),
"search_list should be less than max(topk*10, 200) and less than 65535");
auto search_list_size = GetValueFromConfig<uint32_t>(
search_info.search_params_, DISK_ANN_QUERY_LIST);
AssertInfo(search_list_size.has_value(),
"param " + std::string(DISK_ANN_QUERY_LIST) + "is empty");
AssertInfo(search_list_size.value() >= topk,
"search_list should be greater than or equal to topk");
AssertInfo(
search_list_size.value() <=
std::max(uint32_t(topk * 10), uint32_t(kSearchListMaxValue1)) &&
search_list_size.value() <= uint32_t(kSearchListMaxValue2),
"search_list should be less than max(topk*10, 200) and less than "
"65535");
search_config[DISK_ANN_SEARCH_LIST_SIZE] = search_list_size.value();
// set beamwidth
@ -154,25 +174,33 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset, const SearchInfo& search_
search_config[DISK_ANN_PQ_CODE_BUDGET] = 0.0;
auto final = [&] {
auto radius = GetValueFromConfig<float>(search_info.search_params_, RADIUS);
auto radius =
GetValueFromConfig<float>(search_info.search_params_, RADIUS);
if (radius.has_value()) {
search_config[RADIUS] = radius.value();
auto range_filter = GetValueFromConfig<float>(search_info.search_params_, RANGE_FILTER);
auto range_filter = GetValueFromConfig<float>(
search_info.search_params_, RANGE_FILTER);
if (range_filter.has_value()) {
search_config[RANGE_FILTER] = range_filter.value();
CheckRangeSearchParam(search_config[RADIUS], search_config[RANGE_FILTER], GetMetricType());
CheckRangeSearchParam(search_config[RADIUS],
search_config[RANGE_FILTER],
GetMetricType());
}
auto res = index_.RangeSearch(*dataset, search_config, bitset);
if (!res.has_value()) {
PanicCodeInfo(ErrorCodeEnum::UnexpectedError,
"failed to range search, " + MatchKnowhereError(res.error()));
"failed to range search, " +
MatchKnowhereError(res.error()));
}
return SortRangeSearchResult(res.value(), topk, num_queries, GetMetricType());
return SortRangeSearchResult(
res.value(), topk, num_queries, GetMetricType());
} else {
auto res = index_.Search(*dataset, search_config, bitset);
if (!res.has_value()) {
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to search, " + MatchKnowhereError(res.error()));
PanicCodeInfo(
ErrorCodeEnum::UnexpectedError,
"failed to search, " + MatchKnowhereError(res.error()));
}
return res.value();
}
@ -226,12 +254,15 @@ VectorDiskAnnIndex<T>::update_load_json(const Config& config) {
load_config[DISK_ANN_PREPARE_USE_BFS_CACHE] = false;
// set threads number
auto num_threads = GetValueFromConfig<std::string>(load_config, DISK_ANN_LOAD_THREAD_NUM);
AssertInfo(num_threads.has_value(), "param " + std::string(DISK_ANN_LOAD_THREAD_NUM) + "is empty");
auto num_threads =
GetValueFromConfig<std::string>(load_config, DISK_ANN_LOAD_THREAD_NUM);
AssertInfo(num_threads.has_value(),
"param " + std::string(DISK_ANN_LOAD_THREAD_NUM) + "is empty");
load_config[DISK_ANN_THREADS_NUM] = std::atoi(num_threads.value().c_str());
// update search_beamwidth
auto beamwidth = GetValueFromConfig<std::string>(load_config, DISK_ANN_QUERY_BEAMWIDTH);
auto beamwidth =
GetValueFromConfig<std::string>(load_config, DISK_ANN_QUERY_BEAMWIDTH);
if (beamwidth.has_value()) {
search_beamwidth_ = std::atoi(beamwidth.value().c_str());
}

View File

@ -49,13 +49,17 @@ class VectorDiskAnnIndex : public VectorIndex {
}
void
Load(const BinarySet& binary_set /* not used */, const Config& config = {}) override;
Load(const BinarySet& binary_set /* not used */,
const Config& config = {}) override;
void
BuildWithDataset(const DatasetPtr& dataset, const Config& config = {}) override;
BuildWithDataset(const DatasetPtr& dataset,
const Config& config = {}) override;
std::unique_ptr<SearchResult>
Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) override;
Query(const DatasetPtr dataset,
const SearchInfo& search_info,
const BitsetView& bitset) override;
void
CleanLocalData() override;

View File

@ -32,18 +32,26 @@ namespace milvus::index {
class VectorIndex : public IndexBase {
public:
explicit VectorIndex(const IndexType& index_type, const IndexMode& index_mode, const MetricType& metric_type)
: index_type_(index_type), index_mode_(index_mode), metric_type_(metric_type) {
explicit VectorIndex(const IndexType& index_type,
const IndexMode& index_mode,
const MetricType& metric_type)
: index_type_(index_type),
index_mode_(index_mode),
metric_type_(metric_type) {
}
public:
void
BuildWithRawData(size_t n, const void* values, const Config& config = {}) override {
BuildWithRawData(size_t n,
const void* values,
const Config& config = {}) override {
PanicInfo("vector index don't support build index with raw data");
};
virtual std::unique_ptr<SearchResult>
Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) = 0;
Query(const DatasetPtr dataset,
const SearchInfo& search_info,
const BitsetView& bitset) = 0;
IndexType
GetIndexType() const {

View File

@ -30,9 +30,12 @@
namespace milvus::index {
VectorMemIndex::VectorMemIndex(const IndexType& index_type, const MetricType& metric_type, const IndexMode& index_mode)
VectorMemIndex::VectorMemIndex(const IndexType& index_type,
const MetricType& metric_type,
const IndexMode& index_mode)
: VectorIndex(index_type, index_mode, metric_type) {
AssertInfo(!is_unsupported(index_type, metric_type), index_type + " doesn't support metric: " + metric_type);
AssertInfo(!is_unsupported(index_type, metric_type),
index_type + " doesn't support metric: " + metric_type);
index_ = knowhere::IndexFactory::Instance().Create(GetIndexType());
}
@ -42,7 +45,8 @@ VectorMemIndex::Serialize(const Config& config) {
knowhere::BinarySet ret;
auto stat = index_.Serialize(ret);
if (stat != knowhere::Status::success)
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to serialize index, " + MatchKnowhereError(stat));
PanicCodeInfo(ErrorCodeEnum::UnexpectedError,
"failed to serialize index, " + MatchKnowhereError(stat));
milvus::Disassemble(ret);
return ret;
@ -53,12 +57,15 @@ VectorMemIndex::Load(const BinarySet& binary_set, const Config& config) {
milvus::Assemble(const_cast<BinarySet&>(binary_set));
auto stat = index_.Deserialize(binary_set);
if (stat != knowhere::Status::success)
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to Deserialize index, " + MatchKnowhereError(stat));
PanicCodeInfo(
ErrorCodeEnum::UnexpectedError,
"failed to Deserialize index, " + MatchKnowhereError(stat));
SetDim(index_.Dim());
}
void
VectorMemIndex::BuildWithDataset(const DatasetPtr& dataset, const Config& config) {
VectorMemIndex::BuildWithDataset(const DatasetPtr& dataset,
const Config& config) {
knowhere::Json index_config;
index_config.update(config);
@ -67,13 +74,16 @@ VectorMemIndex::BuildWithDataset(const DatasetPtr& dataset, const Config& config
knowhere::TimeRecorder rc("BuildWithoutIds", 1);
auto stat = index_.Build(*dataset, index_config);
if (stat != knowhere::Status::success)
PanicCodeInfo(ErrorCodeEnum::BuildIndexError, "failed to build index, " + MatchKnowhereError(stat));
PanicCodeInfo(ErrorCodeEnum::BuildIndexError,
"failed to build index, " + MatchKnowhereError(stat));
rc.ElapseFromBegin("Done");
SetDim(index_.Dim());
}
std::unique_ptr<SearchResult>
VectorMemIndex::Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) {
VectorMemIndex::Query(const DatasetPtr dataset,
const SearchInfo& search_info,
const BitsetView& bitset) {
// AssertInfo(GetMetricType() == search_info.metric_type_,
// "Metric type of field index isn't the same with search info");
@ -87,18 +97,24 @@ VectorMemIndex::Query(const DatasetPtr dataset, const SearchInfo& search_info, c
auto index_type = GetIndexType();
if (CheckKeyInConfig(search_conf, RADIUS)) {
if (CheckKeyInConfig(search_conf, RANGE_FILTER)) {
CheckRangeSearchParam(search_conf[RADIUS], search_conf[RANGE_FILTER], GetMetricType());
CheckRangeSearchParam(search_conf[RADIUS],
search_conf[RANGE_FILTER],
GetMetricType());
}
auto res = index_.RangeSearch(*dataset, search_conf, bitset);
if (!res.has_value()) {
PanicCodeInfo(ErrorCodeEnum::UnexpectedError,
"failed to range search, " + MatchKnowhereError(res.error()));
"failed to range search, " +
MatchKnowhereError(res.error()));
}
return SortRangeSearchResult(res.value(), topk, num_queries, GetMetricType());
return SortRangeSearchResult(
res.value(), topk, num_queries, GetMetricType());
} else {
auto res = index_.Search(*dataset, search_conf, bitset);
if (!res.has_value()) {
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to search, " + MatchKnowhereError(res.error()));
PanicCodeInfo(
ErrorCodeEnum::UnexpectedError,
"failed to search, " + MatchKnowhereError(res.error()));
}
return res.value();
}

View File

@ -28,7 +28,9 @@ namespace milvus::index {
class VectorMemIndex : public VectorIndex {
public:
explicit VectorMemIndex(const IndexType& index_type, const MetricType& metric_type, const IndexMode& index_mode);
explicit VectorMemIndex(const IndexType& index_type,
const MetricType& metric_type,
const IndexMode& index_mode);
BinarySet
Serialize(const Config& config) override;
@ -37,7 +39,8 @@ class VectorMemIndex : public VectorIndex {
Load(const BinarySet& binary_set, const Config& config = {}) override;
void
BuildWithDataset(const DatasetPtr& dataset, const Config& config = {}) override;
BuildWithDataset(const DatasetPtr& dataset,
const Config& config = {}) override;
int64_t
Count() override {
@ -45,7 +48,9 @@ class VectorMemIndex : public VectorIndex {
}
std::unique_ptr<SearchResult>
Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) override;
Query(const DatasetPtr dataset,
const SearchInfo& search_info,
const BitsetView& bitset) override;
protected:
Config config_;

View File

@ -31,10 +31,12 @@ VectorMemNMIndex::Serialize(const Config& config) {
knowhere::BinarySet ret;
auto stat = index_.Serialize(ret);
if (stat != knowhere::Status::success)
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to serialize index, " + MatchKnowhereError(stat));
PanicCodeInfo(ErrorCodeEnum::UnexpectedError,
"failed to serialize index, " + MatchKnowhereError(stat));
auto deleter = [&](uint8_t*) {}; // avoid repeated deconstruction
auto raw_data = std::shared_ptr<uint8_t[]>(static_cast<uint8_t*>(raw_data_.data()), deleter);
auto raw_data = std::shared_ptr<uint8_t[]>(
static_cast<uint8_t*>(raw_data_.data()), deleter);
ret.Append(RAW_DATA, raw_data, raw_data_.size());
milvus::Disassemble(ret);
@ -42,7 +44,8 @@ VectorMemNMIndex::Serialize(const Config& config) {
}
void
VectorMemNMIndex::BuildWithDataset(const DatasetPtr& dataset, const Config& config) {
VectorMemNMIndex::BuildWithDataset(const DatasetPtr& dataset,
const Config& config) {
VectorMemIndex::BuildWithDataset(dataset, config);
knowhere::TimeRecorder rc("store_raw_data", 1);
store_raw_data(dataset);
@ -53,12 +56,16 @@ void
VectorMemNMIndex::Load(const BinarySet& binary_set, const Config& config) {
VectorMemIndex::Load(binary_set, config);
if (binary_set.Contains(RAW_DATA)) {
std::call_once(raw_data_loaded_, [&]() { LOG_SEGCORE_INFO_C << "NM index load raw data done!"; });
std::call_once(raw_data_loaded_, [&]() {
LOG_SEGCORE_INFO_C << "NM index load raw data done!";
});
}
}
std::unique_ptr<SearchResult>
VectorMemNMIndex::Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) {
VectorMemNMIndex::Query(const DatasetPtr dataset,
const SearchInfo& search_info,
const BitsetView& bitset) {
auto load_raw_data_closure = [&]() { LoadRawData(); }; // hide this pointer
// load -> query, raw data has been loaded
// build -> query, this case just for test, should load raw data before query
@ -88,16 +95,20 @@ VectorMemNMIndex::LoadRawData() {
knowhere::BinarySet bs;
auto stat = index_.Serialize(bs);
if (stat != knowhere::Status::success)
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to Serialize index, " + MatchKnowhereError(stat));
PanicCodeInfo(ErrorCodeEnum::UnexpectedError,
"failed to Serialize index, " + MatchKnowhereError(stat));
auto bptr = std::make_shared<knowhere::Binary>();
auto deleter = [&](uint8_t*) {}; // avoid repeated deconstruction
bptr->data = std::shared_ptr<uint8_t[]>(static_cast<uint8_t*>(raw_data_.data()), deleter);
bptr->data = std::shared_ptr<uint8_t[]>(
static_cast<uint8_t*>(raw_data_.data()), deleter);
bptr->size = raw_data_.size();
bs.Append(RAW_DATA, bptr);
stat = index_.Deserialize(bs);
if (stat != knowhere::Status::success)
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to Deserialize index, " + MatchKnowhereError(stat));
PanicCodeInfo(
ErrorCodeEnum::UnexpectedError,
"failed to Deserialize index, " + MatchKnowhereError(stat));
}
} // namespace milvus::index

View File

@ -28,7 +28,9 @@ namespace milvus::index {
class VectorMemNMIndex : public VectorMemIndex {
public:
explicit VectorMemNMIndex(const IndexType& index_type, const MetricType& metric_type, const IndexMode& index_mode)
explicit VectorMemNMIndex(const IndexType& index_type,
const MetricType& metric_type,
const IndexMode& index_mode)
: VectorMemIndex(index_type, metric_type, index_mode) {
AssertInfo(is_in_nm_list(index_type), "not valid nm index type");
}
@ -37,13 +39,16 @@ class VectorMemNMIndex : public VectorMemIndex {
Serialize(const Config& config) override;
void
BuildWithDataset(const DatasetPtr& dataset, const Config& config = {}) override;
BuildWithDataset(const DatasetPtr& dataset,
const Config& config = {}) override;
void
Load(const BinarySet& binary_set, const Config& config = {}) override;
std::unique_ptr<SearchResult>
Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) override;
Query(const DatasetPtr dataset,
const SearchInfo& search_info,
const BitsetView& bitset) override;
private:
void

View File

@ -45,7 +45,8 @@ class IndexFactory {
const char* index_params,
const storage::StorageConfig& storage_config) {
auto real_dtype = DataType(dtype);
auto invalid_dtype_msg = std::string("invalid data type: ") + std::to_string(int(real_dtype));
auto invalid_dtype_msg = std::string("invalid data type: ") +
std::to_string(int(real_dtype));
switch (real_dtype) {
case DataType::BOOL:
@ -61,7 +62,8 @@ class IndexFactory {
case DataType::VECTOR_FLOAT:
case DataType::VECTOR_BINARY:
return std::make_unique<VecIndexCreator>(real_dtype, type_params, index_params, storage_config);
return std::make_unique<VecIndexCreator>(
real_dtype, type_params, index_params, storage_config);
default:
throw std::invalid_argument(invalid_dtype_msg);
}

View File

@ -20,7 +20,9 @@
namespace milvus::indexbuilder {
ScalarIndexCreator::ScalarIndexCreator(DataType dtype, const char* type_params, const char* index_params)
ScalarIndexCreator::ScalarIndexCreator(DataType dtype,
const char* type_params,
const char* index_params)
: dtype_(dtype) {
// TODO: move parse-related logic to a common interface.
proto::indexcgo::TypeParams type_params_;
@ -42,7 +44,8 @@ ScalarIndexCreator::ScalarIndexCreator(DataType dtype, const char* type_params,
index_info.field_type = dtype_;
index_info.index_type = index_type();
index_info.index_mode = IndexMode::MODE_CPU;
index_ = index::IndexFactory::GetInstance().CreateIndex(index_info, nullptr);
index_ =
index::IndexFactory::GetInstance().CreateIndex(index_info, nullptr);
}
void

View File

@ -22,7 +22,9 @@ namespace milvus::indexbuilder {
class ScalarIndexCreator : public IndexCreatorBase {
public:
ScalarIndexCreator(DataType data_type, const char* type_params, const char* index_params);
ScalarIndexCreator(DataType data_type,
const char* type_params,
const char* index_params);
void
Build(const milvus::DatasetPtr& dataset) override;
@ -46,8 +48,11 @@ class ScalarIndexCreator : public IndexCreatorBase {
using ScalarIndexCreatorPtr = std::unique_ptr<ScalarIndexCreator>;
inline ScalarIndexCreatorPtr
CreateScalarIndex(DataType dtype, const char* type_params, const char* index_params) {
return std::make_unique<ScalarIndexCreator>(dtype, type_params, index_params);
CreateScalarIndex(DataType dtype,
const char* type_params,
const char* index_params) {
return std::make_unique<ScalarIndexCreator>(
dtype, type_params, index_params);
}
} // namespace milvus::indexbuilder

View File

@ -30,8 +30,10 @@ VecIndexCreator::VecIndexCreator(DataType data_type,
: data_type_(data_type) {
proto::indexcgo::TypeParams type_params_;
proto::indexcgo::IndexParams index_params_;
milvus::index::ParseFromString(type_params_, std::string(serialized_type_params));
milvus::index::ParseFromString(index_params_, std::string(serialized_index_params));
milvus::index::ParseFromString(type_params_,
std::string(serialized_type_params));
milvus::index::ParseFromString(index_params_,
std::string(serialized_index_params));
for (auto i = 0; i < type_params_.params_size(); ++i) {
const auto& param = type_params_.params(i);
@ -54,12 +56,16 @@ VecIndexCreator::VecIndexCreator(DataType data_type,
if (index::is_in_disk_list(index_info.index_type)) {
// For now, only support diskann index
file_manager = std::make_shared<storage::DiskFileManagerImpl>(
index::GetFieldDataMetaFromConfig(config_), index::GetIndexMetaFromConfig(config_), storage_config);
index::GetFieldDataMetaFromConfig(config_),
index::GetIndexMetaFromConfig(config_),
storage_config);
}
#endif
index_ = index::IndexFactory::GetInstance().CreateIndex(index_info, file_manager);
AssertInfo(index_ != nullptr, "[VecIndexCreator]Index is null after create index");
index_ = index::IndexFactory::GetInstance().CreateIndex(index_info,
file_manager);
AssertInfo(index_ != nullptr,
"[VecIndexCreator]Index is null after create index");
}
int64_t
@ -83,7 +89,9 @@ VecIndexCreator::Load(const milvus::BinarySet& binary_set) {
}
std::unique_ptr<SearchResult>
VecIndexCreator::Query(const milvus::DatasetPtr& dataset, const SearchInfo& search_info, const BitsetView& bitset) {
VecIndexCreator::Query(const milvus::DatasetPtr& dataset,
const SearchInfo& search_info,
const BitsetView& bitset) {
auto vector_index = dynamic_cast<index::VectorIndex*>(index_.get());
return vector_index->Query(dataset, search_info, bitset);
}

View File

@ -44,7 +44,9 @@ class VecIndexCreator : public IndexCreatorBase {
dim();
std::unique_ptr<SearchResult>
Query(const milvus::DatasetPtr& dataset, const SearchInfo& search_info, const BitsetView& bitset);
Query(const milvus::DatasetPtr& dataset,
const SearchInfo& search_info,
const BitsetView& bitset);
public:
void

View File

@ -39,18 +39,23 @@ CreateIndex(enum CDataType dtype,
std::string remote_root_path(c_storage_config.remote_root_path);
std::string storage_type(c_storage_config.storage_type);
std::string iam_endpoint(c_storage_config.iam_endpoint);
auto storage_config = milvus::storage::StorageConfig{address,
bucket_name,
access_key,
access_value,
remote_root_path,
storage_type,
iam_endpoint,
c_storage_config.useSSL,
c_storage_config.useIAM};
auto storage_config =
milvus::storage::StorageConfig{address,
bucket_name,
access_key,
access_value,
remote_root_path,
storage_type,
iam_endpoint,
c_storage_config.useSSL,
c_storage_config.useIAM};
auto index = milvus::indexbuilder::IndexFactory::GetInstance().CreateIndex(
dtype, serialized_type_params, serialized_index_params, storage_config);
auto index =
milvus::indexbuilder::IndexFactory::GetInstance().CreateIndex(
dtype,
serialized_type_params,
serialized_index_params,
storage_config);
*res_index = index.release();
status.error_code = Success;
status.error_msg = "";
@ -66,7 +71,8 @@ DeleteIndex(CIndex index) {
auto status = CStatus();
try {
AssertInfo(index, "failed to delete index, passed index was null");
auto cIndex = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
auto cIndex =
reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
delete cIndex;
status.error_code = Success;
status.error_msg = "";
@ -78,12 +84,17 @@ DeleteIndex(CIndex index) {
}
CStatus
BuildFloatVecIndex(CIndex index, int64_t float_value_num, const float* vectors) {
BuildFloatVecIndex(CIndex index,
int64_t float_value_num,
const float* vectors) {
auto status = CStatus();
try {
AssertInfo(index, "failed to build float vector index, passed index was null");
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
auto cIndex = dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
AssertInfo(index,
"failed to build float vector index, passed index was null");
auto real_index =
reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
auto cIndex =
dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
auto dim = cIndex->dim();
auto row_nums = float_value_num / dim;
auto ds = knowhere::GenDataSet(row_nums, dim, vectors);
@ -101,9 +112,13 @@ CStatus
BuildBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors) {
auto status = CStatus();
try {
AssertInfo(index, "failed to build binary vector index, passed index was null");
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
auto cIndex = dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
AssertInfo(
index,
"failed to build binary vector index, passed index was null");
auto real_index =
reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
auto cIndex =
dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
auto dim = cIndex->dim();
auto row_nums = (data_size * 8) / dim;
auto ds = knowhere::GenDataSet(row_nums, dim, vectors);
@ -126,9 +141,11 @@ CStatus
BuildScalarIndex(CIndex c_index, int64_t size, const void* field_data) {
auto status = CStatus();
try {
AssertInfo(c_index, "failed to build scalar index, passed index was null");
AssertInfo(c_index,
"failed to build scalar index, passed index was null");
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(c_index);
auto real_index =
reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(c_index);
const int64_t dim = 8; // not important here
auto dataset = knowhere::GenDataSet(size, dim, field_data);
real_index->Build(dataset);
@ -146,9 +163,13 @@ CStatus
SerializeIndexToBinarySet(CIndex index, CBinarySet* c_binary_set) {
auto status = CStatus();
try {
AssertInfo(index, "failed to serialize index to binary set, passed index was null");
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
auto binary = std::make_unique<knowhere::BinarySet>(real_index->Serialize());
AssertInfo(
index,
"failed to serialize index to binary set, passed index was null");
auto real_index =
reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
auto binary =
std::make_unique<knowhere::BinarySet>(real_index->Serialize());
*c_binary_set = binary.release();
status.error_code = Success;
status.error_msg = "";
@ -163,8 +184,11 @@ CStatus
LoadIndexFromBinarySet(CIndex index, CBinarySet c_binary_set) {
auto status = CStatus();
try {
AssertInfo(index, "failed to load index from binary set, passed index was null");
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
AssertInfo(
index,
"failed to load index from binary set, passed index was null");
auto real_index =
reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
auto binary_set = reinterpret_cast<knowhere::BinarySet*>(c_binary_set);
real_index->Load(*binary_set);
status.error_code = Success;
@ -180,9 +204,12 @@ CStatus
CleanLocalData(CIndex index) {
auto status = CStatus();
try {
AssertInfo(index, "failed to build float vector index, passed index was null");
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
auto cIndex = dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
AssertInfo(index,
"failed to build float vector index, passed index was null");
auto real_index =
reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
auto cIndex =
dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
cIndex->CleanLocalData();
status.error_code = Success;
status.error_msg = "";

View File

@ -64,7 +64,7 @@ SetThreadName(const std::string& name) {
std::string
GetThreadName() {
std::string thread_name = "unamed";
std::string thread_name = "unnamed";
char name[16];
size_t len = 16;
auto err = pthread_getname_np(pthread_self(), name, len);
@ -108,12 +108,17 @@ get_thread_starttime() {
int64_t pid = getpid();
char filename[256];
snprintf(filename, sizeof(filename), "/proc/%lld/task/%lld/stat", (long long)pid, (long long)tid); // NOLINT
snprintf(filename,
sizeof(filename),
"/proc/%lld/task/%lld/stat",
(long long)pid,
(long long)tid); // NOLINT
int64_t val = 0;
char comm[16], state;
FILE* thread_stat = fopen(filename, "r");
auto ret = fscanf(thread_stat, "%lld %s %s ", (long long*)&val, comm, &state); // NOLINT
auto ret = fscanf(
thread_stat, "%lld %s %s ", (long long*)&val, comm, &state); // NOLINT
for (auto i = 4; i < 23; i++) {
ret = fscanf(thread_stat, "%lld ", (long long*)&val); // NOLINT
@ -131,7 +136,8 @@ get_thread_starttime() {
int64_t
get_thread_start_timestamp() {
try {
return get_now_timestamp() - get_system_boottime() + get_thread_starttime();
return get_now_timestamp() - get_system_boottime() +
get_thread_starttime();
} catch (...) {
return 0;
}

View File

@ -42,33 +42,44 @@
// Use this macro whenever possible
// Depends variables: context Context
#define MLOG(level, module, error_code) \
LOG(level) << " | " << VAR_REQUEST_ID << " | " << #level << " | " << VAR_COLLECTION_NAME << " | " << VAR_CLIENT_ID \
<< " | " << VAR_CLIENT_TAG << " | " << VAR_CLIENT_IPPORT << " | " << VAR_THREAD_ID << " | " \
<< VAR_THREAD_START_TIMESTAMP << " | " << VAR_COMMAND_TAG << " | " << #module << " | " << error_code \
<< " | "
#define MLOG(level, module, error_code) \
LOG(level) << " | " << VAR_REQUEST_ID << " | " << #level << " | " \
<< VAR_COLLECTION_NAME << " | " << VAR_CLIENT_ID << " | " \
<< VAR_CLIENT_TAG << " | " << VAR_CLIENT_IPPORT << " | " \
<< VAR_THREAD_ID << " | " << VAR_THREAD_START_TIMESTAMP \
<< " | " << VAR_COMMAND_TAG << " | " << #module << " | " \
<< error_code << " | "
// Use in some background process only
#define MLOG_(level, module, error_code) \
LOG(level) << " | " \
<< "" \
<< " | " << #level << " | " \
<< "" \
<< " | " \
<< "" \
<< " | " \
<< "" \
<< " | " \
<< "" \
<< " | " << VAR_THREAD_ID << " | " << VAR_THREAD_START_TIMESTAMP << " | " \
<< "" \
#define MLOG_(level, module, error_code) \
LOG(level) << " | " \
<< "" \
<< " | " << #level << " | " \
<< "" \
<< " | " \
<< "" \
<< " | " \
<< "" \
<< " | " \
<< "" \
<< " | " << VAR_THREAD_ID << " | " \
<< VAR_THREAD_START_TIMESTAMP << " | " \
<< "" \
<< " | " << #module << " | " << error_code << " | "
/////////////////////////////////////////////////////////////////////////////////////////////////
#define SEGCORE_MODULE_NAME "SEGCORE"
#define SEGCORE_MODULE_CLASS_FUNCTION \
LogOut("[%s][%s::%s][%s] ", SEGCORE_MODULE_NAME, (typeid(*this).name()), __FUNCTION__, GetThreadName().c_str())
#define SEGCORE_MODULE_FUNCTION LogOut("[%s][%s][%s] ", SEGCORE_MODULE_NAME, __FUNCTION__, GetThreadName().c_str())
LogOut("[%s][%s::%s][%s] ", \
SEGCORE_MODULE_NAME, \
(typeid(*this).name()), \
__FUNCTION__, \
GetThreadName().c_str())
#define SEGCORE_MODULE_FUNCTION \
LogOut("[%s][%s][%s] ", \
SEGCORE_MODULE_NAME, \
__FUNCTION__, \
GetThreadName().c_str())
#define LOG_SEGCORE_TRACE_C LOG(TRACE) << SEGCORE_MODULE_CLASS_FUNCTION
#define LOG_SEGCORE_DEBUG_C LOG(DEBUG) << SEGCORE_MODULE_CLASS_FUNCTION
@ -87,8 +98,16 @@
/////////////////////////////////////////////////////////////////////////////////////////////////
#define SERVER_MODULE_NAME "SERVER"
#define SERVER_MODULE_CLASS_FUNCTION \
LogOut("[%s][%s::%s][%s] ", SERVER_MODULE_NAME, (typeid(*this).name()), __FUNCTION__, GetThreadName().c_str())
#define SERVER_MODULE_FUNCTION LogOut("[%s][%s][%s] ", SERVER_MODULE_NAME, __FUNCTION__, GetThreadName().c_str())
LogOut("[%s][%s::%s][%s] ", \
SERVER_MODULE_NAME, \
(typeid(*this).name()), \
__FUNCTION__, \
GetThreadName().c_str())
#define SERVER_MODULE_FUNCTION \
LogOut("[%s][%s][%s] ", \
SERVER_MODULE_NAME, \
__FUNCTION__, \
GetThreadName().c_str())
#define LOG_SERVER_TRACE_C LOG(TRACE) << SERVER_MODULE_CLASS_FUNCTION
#define LOG_SERVER_DEBUG_C LOG(DEBUG) << SERVER_MODULE_CLASS_FUNCTION

View File

@ -49,7 +49,8 @@ struct BinaryExprBase : Expr {
BinaryExprBase() = delete;
BinaryExprBase(ExprPtr& left, ExprPtr& right) : left_(std::move(left)), right_(std::move(right)) {
BinaryExprBase(ExprPtr& left, ExprPtr& right)
: left_(std::move(left)), right_(std::move(right)) {
}
};
@ -66,7 +67,8 @@ struct LogicalUnaryExpr : UnaryExprBase {
enum class OpType { Invalid = 0, LogicalNot = 1 };
const OpType op_type_;
LogicalUnaryExpr(const OpType op_type, ExprPtr& child) : UnaryExprBase(child), op_type_(op_type) {
LogicalUnaryExpr(const OpType op_type, ExprPtr& child)
: UnaryExprBase(child), op_type_(op_type) {
}
public:
@ -76,7 +78,13 @@ struct LogicalUnaryExpr : UnaryExprBase {
struct LogicalBinaryExpr : BinaryExprBase {
// Note: bitA - bitB == bitA & ~bitB, alias to LogicalMinus
enum class OpType { Invalid = 0, LogicalAnd = 1, LogicalOr = 2, LogicalXor = 3, LogicalMinus = 4 };
enum class OpType {
Invalid = 0,
LogicalAnd = 1,
LogicalOr = 2,
LogicalXor = 3,
LogicalMinus = 4
};
const OpType op_type_;
LogicalBinaryExpr(const OpType op_type, ExprPtr& left, ExprPtr& right)
@ -93,10 +101,11 @@ struct TermExpr : Expr {
const DataType data_type_;
protected:
// prevent accidential instantiation
// prevent accidental instantiation
TermExpr() = delete;
TermExpr(const FieldId field_id, const DataType data_type) : field_id_(field_id), data_type_(data_type) {
TermExpr(const FieldId field_id, const DataType data_type)
: field_id_(field_id), data_type_(data_type) {
}
public:
@ -106,14 +115,20 @@ struct TermExpr : Expr {
static const std::map<std::string, ArithOpType> arith_op_mapping_ = {
// arith_op_name -> arith_op
{"add", ArithOpType::Add}, {"sub", ArithOpType::Sub}, {"mul", ArithOpType::Mul},
{"div", ArithOpType::Div}, {"mod", ArithOpType::Mod},
{"add", ArithOpType::Add},
{"sub", ArithOpType::Sub},
{"mul", ArithOpType::Mul},
{"div", ArithOpType::Div},
{"mod", ArithOpType::Mod},
};
static const std::map<ArithOpType, std::string> mapping_arith_op_ = {
// arith_op_name -> arith_op
{ArithOpType::Add, "add"}, {ArithOpType::Sub, "sub"}, {ArithOpType::Mul, "mul"},
{ArithOpType::Div, "div"}, {ArithOpType::Mod, "mod"},
{ArithOpType::Add, "add"},
{ArithOpType::Sub, "sub"},
{ArithOpType::Mul, "mul"},
{ArithOpType::Div, "div"},
{ArithOpType::Mod, "mod"},
};
struct BinaryArithOpEvalRangeExpr : Expr {
@ -123,14 +138,17 @@ struct BinaryArithOpEvalRangeExpr : Expr {
const ArithOpType arith_op_;
protected:
// prevent accidential instantiation
// prevent accidental instantiation
BinaryArithOpEvalRangeExpr() = delete;
BinaryArithOpEvalRangeExpr(const FieldId field_id,
const DataType data_type,
const OpType op_type,
const ArithOpType arith_op)
: field_id_(field_id), data_type_(data_type), op_type_(op_type), arith_op_(arith_op) {
: field_id_(field_id),
data_type_(data_type),
op_type_(op_type),
arith_op_(arith_op) {
}
public:
@ -140,9 +158,14 @@ struct BinaryArithOpEvalRangeExpr : Expr {
static const std::map<std::string, OpType> mapping_ = {
// op_name -> op
{"lt", OpType::LessThan}, {"le", OpType::LessEqual}, {"lte", OpType::LessEqual},
{"gt", OpType::GreaterThan}, {"ge", OpType::GreaterEqual}, {"gte", OpType::GreaterEqual},
{"eq", OpType::Equal}, {"ne", OpType::NotEqual},
{"lt", OpType::LessThan},
{"le", OpType::LessEqual},
{"lte", OpType::LessEqual},
{"gt", OpType::GreaterThan},
{"ge", OpType::GreaterEqual},
{"gte", OpType::GreaterEqual},
{"eq", OpType::Equal},
{"ne", OpType::NotEqual},
};
struct UnaryRangeExpr : Expr {
@ -151,10 +174,12 @@ struct UnaryRangeExpr : Expr {
const OpType op_type_;
protected:
// prevent accidential instantiation
// prevent accidental instantiation
UnaryRangeExpr() = delete;
UnaryRangeExpr(const FieldId field_id, const DataType data_type, const OpType op_type)
UnaryRangeExpr(const FieldId field_id,
const DataType data_type,
const OpType op_type)
: field_id_(field_id), data_type_(data_type), op_type_(op_type) {
}
@ -170,7 +195,7 @@ struct BinaryRangeExpr : Expr {
const bool upper_inclusive_;
protected:
// prevent accidential instantiation
// prevent accidental instantiation
BinaryRangeExpr() = delete;
BinaryRangeExpr(const FieldId field_id,

View File

@ -28,7 +28,9 @@ template <typename T>
struct TermExprImpl : TermExpr {
const std::vector<T> terms_;
TermExprImpl(const FieldId field_id, const DataType data_type, const std::vector<T>& terms)
TermExprImpl(const FieldId field_id,
const DataType data_type,
const std::vector<T>& terms)
: TermExpr(field_id, data_type), terms_(terms) {
}
};
@ -54,7 +56,10 @@ template <typename T>
struct UnaryRangeExprImpl : UnaryRangeExpr {
const T value_;
UnaryRangeExprImpl(const FieldId field_id, const DataType data_type, const OpType op_type, const T value)
UnaryRangeExprImpl(const FieldId field_id,
const DataType data_type,
const OpType op_type,
const T value)
: UnaryRangeExpr(field_id, data_type, op_type), value_(value) {
}
};
@ -70,7 +75,8 @@ struct BinaryRangeExprImpl : BinaryRangeExpr {
const bool upper_inclusive,
const T lower_value,
const T upper_value)
: BinaryRangeExpr(field_id, data_type, lower_inclusive, upper_inclusive),
: BinaryRangeExpr(
field_id, data_type, lower_inclusive, upper_inclusive),
lower_value_(lower_value),
upper_value_(upper_value) {
}

View File

@ -232,7 +232,8 @@ Parser::ParseTermNodeImpl(const FieldName& field_name, const Json& body) {
terms[i] = value;
}
std::sort(terms.begin(), terms.end());
return std::make_unique<TermExprImpl<T>>(schema.get_field_id(field_name), schema[field_name].get_data_type(),
return std::make_unique<TermExprImpl<T>>(schema.get_field_id(field_name),
schema[field_name].get_data_type(),
terms);
}
@ -277,8 +278,10 @@ Parser::ParseRangeNodeImpl(const FieldName& field_name, const Json& body) {
auto arith = item.value();
auto arith_body = arith.begin();
auto arith_op_name = boost::algorithm::to_lower_copy(std::string(arith_body.key()));
AssertInfo(arith_op_mapping_.count(arith_op_name), "arith op(" + arith_op_name + ") not found");
auto arith_op_name =
boost::algorithm::to_lower_copy(std::string(arith_body.key()));
AssertInfo(arith_op_mapping_.count(arith_op_name),
"arith op(" + arith_op_name + ") not found");
auto& arith_op_body = arith_body.value();
Assert(arith_op_body.is_object());
@ -299,8 +302,12 @@ Parser::ParseRangeNodeImpl(const FieldName& field_name, const Json& body) {
}
return std::make_unique<BinaryArithOpEvalRangeExprImpl<T>>(
schema.get_field_id(field_name), schema[field_name].get_data_type(),
arith_op_mapping_.at(arith_op_name), right_operand, mapping_.at(op_name), value);
schema.get_field_id(field_name),
schema[field_name].get_data_type(),
arith_op_mapping_.at(arith_op_name),
right_operand,
mapping_.at(op_name),
value);
}
if constexpr (std::is_same_v<T, bool>) {
@ -313,7 +320,10 @@ Parser::ParseRangeNodeImpl(const FieldName& field_name, const Json& body) {
static_assert(always_false<T>, "unsupported type");
}
return std::make_unique<UnaryRangeExprImpl<T>>(
schema.get_field_id(field_name), schema[field_name].get_data_type(), mapping_.at(op_name), item.value());
schema.get_field_id(field_name),
schema[field_name].get_data_type(),
mapping_.at(op_name),
item.value());
} else if (body.size() == 2) {
bool has_lower_value = false;
bool has_upper_value = false;
@ -322,8 +332,10 @@ Parser::ParseRangeNodeImpl(const FieldName& field_name, const Json& body) {
T lower_value;
T upper_value;
for (auto& item : body.items()) {
auto op_name = boost::algorithm::to_lower_copy(std::string(item.key()));
AssertInfo(mapping_.count(op_name), "op(" + op_name + ") not found");
auto op_name =
boost::algorithm::to_lower_copy(std::string(item.key()));
AssertInfo(mapping_.count(op_name),
"op(" + op_name + ") not found");
if constexpr (std::is_same_v<T, bool>) {
Assert(item.value().is_boolean());
} else if constexpr (std::is_integral_v<T>) {
@ -351,10 +363,15 @@ Parser::ParseRangeNodeImpl(const FieldName& field_name, const Json& body) {
PanicInfo("unsupported operator in binary-range node");
}
}
AssertInfo(has_lower_value && has_upper_value, "illegal binary-range node");
return std::make_unique<BinaryRangeExprImpl<T>>(schema.get_field_id(field_name),
schema[field_name].get_data_type(), lower_inclusive,
upper_inclusive, lower_value, upper_value);
AssertInfo(has_lower_value && has_upper_value,
"illegal binary-range node");
return std::make_unique<BinaryRangeExprImpl<T>>(
schema.get_field_id(field_name),
schema[field_name].get_data_type(),
lower_inclusive,
upper_inclusive,
lower_value,
upper_value);
} else {
PanicInfo("illegal range node, too more or too few ops");
}
@ -377,7 +394,10 @@ Parser::ParseItemList(const Json& body) {
}
auto old_size = results.size();
auto new_end = std::remove_if(results.begin(), results.end(), [](const ExprPtr& x) { return x == nullptr; });
auto new_end =
std::remove_if(results.begin(), results.end(), [](const ExprPtr& x) {
return x == nullptr;
});
results.resize(new_end - results.begin());
@ -421,7 +441,8 @@ Parser::ParseMustNode(const Json& body) {
auto item_list = ParseItemList(body);
auto merger = [](ExprPtr left, ExprPtr right) {
using OpType = LogicalBinaryExpr::OpType;
return std::make_unique<LogicalBinaryExpr>(OpType::LogicalAnd, left, right);
return std::make_unique<LogicalBinaryExpr>(
OpType::LogicalAnd, left, right);
};
return ConstructTree(merger, std::move(item_list));
}
@ -432,7 +453,8 @@ Parser::ParseShouldNode(const Json& body) {
Assert(item_list.size() >= 1);
auto merger = [](ExprPtr left, ExprPtr right) {
using OpType = LogicalBinaryExpr::OpType;
return std::make_unique<LogicalBinaryExpr>(OpType::LogicalOr, left, right);
return std::make_unique<LogicalBinaryExpr>(
OpType::LogicalOr, left, right);
};
return ConstructTree(merger, std::move(item_list));
}
@ -443,7 +465,8 @@ Parser::ParseMustNotNode(const Json& body) {
Assert(item_list.size() >= 1);
auto merger = [](ExprPtr left, ExprPtr right) {
using OpType = LogicalBinaryExpr::OpType;
return std::make_unique<LogicalBinaryExpr>(OpType::LogicalAnd, left, right);
return std::make_unique<LogicalBinaryExpr>(
OpType::LogicalAnd, left, right);
};
auto subtree = ConstructTree(merger, std::move(item_list));

View File

@ -23,16 +23,21 @@ namespace milvus::query {
// deprecated
std::unique_ptr<PlaceholderGroup>
ParsePlaceholderGroup(const Plan* plan, const std::string& placeholder_group_blob) {
return ParsePlaceholderGroup(plan, reinterpret_cast<const uint8_t*>(placeholder_group_blob.c_str()),
placeholder_group_blob.size());
ParsePlaceholderGroup(const Plan* plan,
const std::string& placeholder_group_blob) {
return ParsePlaceholderGroup(
plan,
reinterpret_cast<const uint8_t*>(placeholder_group_blob.c_str()),
placeholder_group_blob.size());
}
std::unique_ptr<PlaceholderGroup>
ParsePlaceholderGroup(const Plan* plan, const uint8_t* blob, const int64_t blob_len) {
namespace ser = milvus::proto::common;
ParsePlaceholderGroup(const Plan* plan,
const uint8_t* blob,
const int64_t blob_len) {
namespace set = milvus::proto::common;
auto result = std::make_unique<PlaceholderGroup>();
ser::PlaceholderGroup ph_group;
set::PlaceholderGroup ph_group;
auto ok = ph_group.ParseFromArray(blob, blob_len);
Assert(ok);
for (auto& info : ph_group.placeholders()) {
@ -45,7 +50,8 @@ ParsePlaceholderGroup(const Plan* plan, const uint8_t* blob, const int64_t blob_
AssertInfo(element.num_of_queries_, "must have queries");
Assert(element.num_of_queries_ > 0);
element.line_sizeof_ = info.values().Get(0).size();
AssertInfo(field_meta.get_sizeof() == element.line_sizeof_, "vector dimension mismatch");
AssertInfo(field_meta.get_sizeof() == element.line_sizeof_,
"vector dimension mismatch");
auto& target = element.blob_;
target.reserve(element.line_sizeof_ * element.num_of_queries_);
for (auto& line : info.values()) {
@ -66,7 +72,9 @@ CreatePlan(const Schema& schema, const std::string& dsl_str) {
}
std::unique_ptr<Plan>
CreateSearchPlanByExpr(const Schema& schema, const void* serialized_expr_plan, const int64_t size) {
CreateSearchPlanByExpr(const Schema& schema,
const void* serialized_expr_plan,
const int64_t size) {
// Note: serialized_expr_plan is of binary format
proto::plan::PlanNode plan_node;
plan_node.ParseFromArray(serialized_expr_plan, size);
@ -74,7 +82,9 @@ CreateSearchPlanByExpr(const Schema& schema, const void* serialized_expr_plan, c
}
std::unique_ptr<RetrievePlan>
CreateRetrievePlanByExpr(const Schema& schema, const void* serialized_expr_plan, const int64_t size) {
CreateRetrievePlanByExpr(const Schema& schema,
const void* serialized_expr_plan,
const int64_t size) {
proto::plan::PlanNode plan_node;
plan_node.ParseFromArray(serialized_expr_plan, size);
return ProtoParser(schema).CreateRetrievePlan(plan_node);
@ -111,9 +121,11 @@ Plan::check_identical(Plan& other) {
auto json = ShowPlanNodeVisitor().call_child(*this->plan_node_);
auto other_json = ShowPlanNodeVisitor().call_child(*other.plan_node_);
Assert(json.dump(2) == other_json.dump(2));
Assert(this->extra_info_opt_.has_value() == other.extra_info_opt_.has_value());
Assert(this->extra_info_opt_.has_value() ==
other.extra_info_opt_.has_value());
if (this->extra_info_opt_.has_value()) {
Assert(this->extra_info_opt_->involved_fields_ == other.extra_info_opt_->involved_fields_);
Assert(this->extra_info_opt_->involved_fields_ ==
other.extra_info_opt_->involved_fields_);
}
Assert(this->tag2field_ == other.tag2field_);
Assert(this->target_entries_ == other.target_entries_);

View File

@ -31,20 +31,27 @@ CreatePlan(const Schema& schema, const std::string& dsl);
// Note: serialized_expr_plan is of binary format
std::unique_ptr<Plan>
CreateSearchPlanByExpr(const Schema& schema, const void* serialized_expr_plan, const int64_t size);
CreateSearchPlanByExpr(const Schema& schema,
const void* serialized_expr_plan,
const int64_t size);
std::unique_ptr<PlaceholderGroup>
ParsePlaceholderGroup(const Plan* plan, const uint8_t* blob, const int64_t blob_len);
ParsePlaceholderGroup(const Plan* plan,
const uint8_t* blob,
const int64_t blob_len);
// deprecated
std::unique_ptr<PlaceholderGroup>
ParsePlaceholderGroup(const Plan* plan, const std::string& placeholder_group_blob);
ParsePlaceholderGroup(const Plan* plan,
const std::string& placeholder_group_blob);
int64_t
GetNumOfQueries(const PlaceholderGroup*);
std::unique_ptr<RetrievePlan>
CreateRetrievePlanByExpr(const Schema& schema, const void* serialized_expr_plan, const int64_t size);
CreateRetrievePlanByExpr(const Schema& schema,
const void* serialized_expr_plan,
const int64_t size);
// Query Overall TopK from Plan
// Used to alloc result memory at Go side

View File

@ -25,7 +25,9 @@ namespace planpb = milvus::proto::plan;
template <typename T>
std::unique_ptr<TermExprImpl<T>>
ExtractTermExprImpl(FieldId field_id, DataType data_type, const planpb::TermExpr& expr_proto) {
ExtractTermExprImpl(FieldId field_id,
DataType data_type,
const planpb::TermExpr& expr_proto) {
static_assert(IsScalar<T>);
auto size = expr_proto.values_size();
std::vector<T> terms(size);
@ -53,7 +55,9 @@ ExtractTermExprImpl(FieldId field_id, DataType data_type, const planpb::TermExpr
template <typename T>
std::unique_ptr<UnaryRangeExprImpl<T>>
ExtractUnaryRangeExprImpl(FieldId field_id, DataType data_type, const planpb::UnaryRangeExpr& expr_proto) {
ExtractUnaryRangeExprImpl(FieldId field_id,
DataType data_type,
const planpb::UnaryRangeExpr& expr_proto) {
static_assert(IsScalar<T>);
auto getValue = [&](const auto& value_proto) -> T {
if constexpr (std::is_same_v<T, bool>) {
@ -72,13 +76,18 @@ ExtractUnaryRangeExprImpl(FieldId field_id, DataType data_type, const planpb::Un
static_assert(always_false<T>);
}
};
return std::make_unique<UnaryRangeExprImpl<T>>(field_id, data_type, static_cast<OpType>(expr_proto.op()),
getValue(expr_proto.value()));
return std::make_unique<UnaryRangeExprImpl<T>>(
field_id,
data_type,
static_cast<OpType>(expr_proto.op()),
getValue(expr_proto.value()));
}
template <typename T>
std::unique_ptr<BinaryRangeExprImpl<T>>
ExtractBinaryRangeExprImpl(FieldId field_id, DataType data_type, const planpb::BinaryRangeExpr& expr_proto) {
ExtractBinaryRangeExprImpl(FieldId field_id,
DataType data_type,
const planpb::BinaryRangeExpr& expr_proto) {
static_assert(IsScalar<T>);
auto getValue = [&](const auto& value_proto) -> T {
if constexpr (std::is_same_v<T, bool>) {
@ -97,16 +106,21 @@ ExtractBinaryRangeExprImpl(FieldId field_id, DataType data_type, const planpb::B
static_assert(always_false<T>);
}
};
return std::make_unique<BinaryRangeExprImpl<T>>(field_id, data_type, expr_proto.lower_inclusive(),
expr_proto.upper_inclusive(), getValue(expr_proto.lower_value()),
getValue(expr_proto.upper_value()));
return std::make_unique<BinaryRangeExprImpl<T>>(
field_id,
data_type,
expr_proto.lower_inclusive(),
expr_proto.upper_inclusive(),
getValue(expr_proto.lower_value()),
getValue(expr_proto.upper_value()));
}
template <typename T>
std::unique_ptr<BinaryArithOpEvalRangeExprImpl<T>>
ExtractBinaryArithOpEvalRangeExprImpl(FieldId field_id,
DataType data_type,
const planpb::BinaryArithOpEvalRangeExpr& expr_proto) {
ExtractBinaryArithOpEvalRangeExprImpl(
FieldId field_id,
DataType data_type,
const planpb::BinaryArithOpEvalRangeExpr& expr_proto) {
static_assert(std::is_fundamental_v<T>);
auto getValue = [&](const auto& value_proto) -> T {
if constexpr (std::is_same_v<T, bool>) {
@ -123,8 +137,12 @@ ExtractBinaryArithOpEvalRangeExprImpl(FieldId field_id,
}
};
return std::make_unique<BinaryArithOpEvalRangeExprImpl<T>>(
field_id, data_type, static_cast<ArithOpType>(expr_proto.arith_op()), getValue(expr_proto.right_operand()),
static_cast<OpType>(expr_proto.op()), getValue(expr_proto.value()));
field_id,
data_type,
static_cast<ArithOpType>(expr_proto.arith_op()),
getValue(expr_proto.right_operand()),
static_cast<OpType>(expr_proto.op()),
getValue(expr_proto.value()));
}
std::unique_ptr<VectorPlanNode>
@ -165,12 +183,15 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
}
std::unique_ptr<RetrievePlanNode>
ProtoParser::RetrievePlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
ProtoParser::RetrievePlanNodeFromProto(
const planpb::PlanNode& plan_node_proto) {
Assert(plan_node_proto.has_predicates());
auto& predicate_proto = plan_node_proto.predicates();
auto expr_opt = [&]() -> ExprPtr { return ParseExpr(predicate_proto); }();
auto plan_node = [&]() -> std::unique_ptr<RetrievePlanNode> { return std::make_unique<RetrievePlanNode>(); }();
auto plan_node = [&]() -> std::unique_ptr<RetrievePlanNode> {
return std::make_unique<RetrievePlanNode>();
}();
plan_node->predicate_ = std::move(expr_opt);
return plan_node;
}
@ -224,28 +245,36 @@ ProtoParser::ParseUnaryRangeExpr(const proto::plan::UnaryRangeExpr& expr_pb) {
auto result = [&]() -> ExprPtr {
switch (data_type) {
case DataType::BOOL: {
return ExtractUnaryRangeExprImpl<bool>(field_id, data_type, expr_pb);
return ExtractUnaryRangeExprImpl<bool>(
field_id, data_type, expr_pb);
}
case DataType::INT8: {
return ExtractUnaryRangeExprImpl<int8_t>(field_id, data_type, expr_pb);
return ExtractUnaryRangeExprImpl<int8_t>(
field_id, data_type, expr_pb);
}
case DataType::INT16: {
return ExtractUnaryRangeExprImpl<int16_t>(field_id, data_type, expr_pb);
return ExtractUnaryRangeExprImpl<int16_t>(
field_id, data_type, expr_pb);
}
case DataType::INT32: {
return ExtractUnaryRangeExprImpl<int32_t>(field_id, data_type, expr_pb);
return ExtractUnaryRangeExprImpl<int32_t>(
field_id, data_type, expr_pb);
}
case DataType::INT64: {
return ExtractUnaryRangeExprImpl<int64_t>(field_id, data_type, expr_pb);
return ExtractUnaryRangeExprImpl<int64_t>(
field_id, data_type, expr_pb);
}
case DataType::FLOAT: {
return ExtractUnaryRangeExprImpl<float>(field_id, data_type, expr_pb);
return ExtractUnaryRangeExprImpl<float>(
field_id, data_type, expr_pb);
}
case DataType::DOUBLE: {
return ExtractUnaryRangeExprImpl<double>(field_id, data_type, expr_pb);
return ExtractUnaryRangeExprImpl<double>(
field_id, data_type, expr_pb);
}
case DataType::VARCHAR: {
return ExtractUnaryRangeExprImpl<std::string>(field_id, data_type, expr_pb);
return ExtractUnaryRangeExprImpl<std::string>(
field_id, data_type, expr_pb);
}
default: {
PanicInfo("unsupported data type");
@ -265,28 +294,36 @@ ProtoParser::ParseBinaryRangeExpr(const proto::plan::BinaryRangeExpr& expr_pb) {
auto result = [&]() -> ExprPtr {
switch (data_type) {
case DataType::BOOL: {
return ExtractBinaryRangeExprImpl<bool>(field_id, data_type, expr_pb);
return ExtractBinaryRangeExprImpl<bool>(
field_id, data_type, expr_pb);
}
case DataType::INT8: {
return ExtractBinaryRangeExprImpl<int8_t>(field_id, data_type, expr_pb);
return ExtractBinaryRangeExprImpl<int8_t>(
field_id, data_type, expr_pb);
}
case DataType::INT16: {
return ExtractBinaryRangeExprImpl<int16_t>(field_id, data_type, expr_pb);
return ExtractBinaryRangeExprImpl<int16_t>(
field_id, data_type, expr_pb);
}
case DataType::INT32: {
return ExtractBinaryRangeExprImpl<int32_t>(field_id, data_type, expr_pb);
return ExtractBinaryRangeExprImpl<int32_t>(
field_id, data_type, expr_pb);
}
case DataType::INT64: {
return ExtractBinaryRangeExprImpl<int64_t>(field_id, data_type, expr_pb);
return ExtractBinaryRangeExprImpl<int64_t>(
field_id, data_type, expr_pb);
}
case DataType::FLOAT: {
return ExtractBinaryRangeExprImpl<float>(field_id, data_type, expr_pb);
return ExtractBinaryRangeExprImpl<float>(
field_id, data_type, expr_pb);
}
case DataType::DOUBLE: {
return ExtractBinaryRangeExprImpl<double>(field_id, data_type, expr_pb);
return ExtractBinaryRangeExprImpl<double>(
field_id, data_type, expr_pb);
}
case DataType::VARCHAR: {
return ExtractBinaryRangeExprImpl<std::string>(field_id, data_type, expr_pb);
return ExtractBinaryRangeExprImpl<std::string>(
field_id, data_type, expr_pb);
}
default: {
PanicInfo("unsupported data type");
@ -301,12 +338,14 @@ ProtoParser::ParseCompareExpr(const proto::plan::CompareExpr& expr_pb) {
auto& left_column_info = expr_pb.left_column_info();
auto left_field_id = FieldId(left_column_info.field_id());
auto left_data_type = schema[left_field_id].get_data_type();
Assert(left_data_type == static_cast<DataType>(left_column_info.data_type()));
Assert(left_data_type ==
static_cast<DataType>(left_column_info.data_type()));
auto& right_column_info = expr_pb.right_column_info();
auto right_field_id = FieldId(right_column_info.field_id());
auto right_data_type = schema[right_field_id].get_data_type();
Assert(right_data_type == static_cast<DataType>(right_column_info.data_type()));
Assert(right_data_type ==
static_cast<DataType>(right_column_info.data_type()));
return [&]() -> ExprPtr {
auto result = std::make_unique<CompareExpr>();
@ -333,25 +372,31 @@ ProtoParser::ParseTermExpr(const proto::plan::TermExpr& expr_pb) {
return ExtractTermExprImpl<bool>(field_id, data_type, expr_pb);
}
case DataType::INT8: {
return ExtractTermExprImpl<int8_t>(field_id, data_type, expr_pb);
return ExtractTermExprImpl<int8_t>(
field_id, data_type, expr_pb);
}
case DataType::INT16: {
return ExtractTermExprImpl<int16_t>(field_id, data_type, expr_pb);
return ExtractTermExprImpl<int16_t>(
field_id, data_type, expr_pb);
}
case DataType::INT32: {
return ExtractTermExprImpl<int32_t>(field_id, data_type, expr_pb);
return ExtractTermExprImpl<int32_t>(
field_id, data_type, expr_pb);
}
case DataType::INT64: {
return ExtractTermExprImpl<int64_t>(field_id, data_type, expr_pb);
return ExtractTermExprImpl<int64_t>(
field_id, data_type, expr_pb);
}
case DataType::FLOAT: {
return ExtractTermExprImpl<float>(field_id, data_type, expr_pb);
}
case DataType::DOUBLE: {
return ExtractTermExprImpl<double>(field_id, data_type, expr_pb);
return ExtractTermExprImpl<double>(
field_id, data_type, expr_pb);
}
case DataType::VARCHAR: {
return ExtractTermExprImpl<std::string>(field_id, data_type, expr_pb);
return ExtractTermExprImpl<std::string>(
field_id, data_type, expr_pb);
}
default: {
PanicInfo("unsupported data type");
@ -378,7 +423,8 @@ ProtoParser::ParseBinaryExpr(const proto::plan::BinaryExpr& expr_pb) {
}
ExprPtr
ProtoParser::ParseBinaryArithOpEvalRangeExpr(const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) {
ProtoParser::ParseBinaryArithOpEvalRangeExpr(
const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) {
auto& column_info = expr_pb.column_info();
auto field_id = FieldId(column_info.field_id());
auto data_type = schema[field_id].get_data_type();
@ -387,22 +433,28 @@ ProtoParser::ParseBinaryArithOpEvalRangeExpr(const proto::plan::BinaryArithOpEva
auto result = [&]() -> ExprPtr {
switch (data_type) {
case DataType::INT8: {
return ExtractBinaryArithOpEvalRangeExprImpl<int8_t>(field_id, data_type, expr_pb);
return ExtractBinaryArithOpEvalRangeExprImpl<int8_t>(
field_id, data_type, expr_pb);
}
case DataType::INT16: {
return ExtractBinaryArithOpEvalRangeExprImpl<int16_t>(field_id, data_type, expr_pb);
return ExtractBinaryArithOpEvalRangeExprImpl<int16_t>(
field_id, data_type, expr_pb);
}
case DataType::INT32: {
return ExtractBinaryArithOpEvalRangeExprImpl<int32_t>(field_id, data_type, expr_pb);
return ExtractBinaryArithOpEvalRangeExprImpl<int32_t>(
field_id, data_type, expr_pb);
}
case DataType::INT64: {
return ExtractBinaryArithOpEvalRangeExprImpl<int64_t>(field_id, data_type, expr_pb);
return ExtractBinaryArithOpEvalRangeExprImpl<int64_t>(
field_id, data_type, expr_pb);
}
case DataType::FLOAT: {
return ExtractBinaryArithOpEvalRangeExprImpl<float>(field_id, data_type, expr_pb);
return ExtractBinaryArithOpEvalRangeExprImpl<float>(
field_id, data_type, expr_pb);
}
case DataType::DOUBLE: {
return ExtractBinaryArithOpEvalRangeExprImpl<double>(field_id, data_type, expr_pb);
return ExtractBinaryArithOpEvalRangeExprImpl<double>(
field_id, data_type, expr_pb);
}
default: {
PanicInfo("unsupported data type");
@ -435,7 +487,8 @@ ProtoParser::ParseExpr(const proto::plan::Expr& expr_pb) {
return ParseCompareExpr(expr_pb.compare_expr());
}
case ppe::kBinaryArithOpEvalRangeExpr: {
return ParseBinaryArithOpEvalRangeExpr(expr_pb.binary_arith_op_eval_range_expr());
return ParseBinaryArithOpEvalRangeExpr(
expr_pb.binary_arith_op_eval_range_expr());
}
default:
PanicInfo("unsupported expr proto node");

View File

@ -30,7 +30,8 @@ class ProtoParser {
// ExprFromProto(const proto::plan::Expr& expr_proto);
ExprPtr
ParseBinaryArithOpEvalRangeExpr(const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb);
ParseBinaryArithOpEvalRangeExpr(
const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb);
ExprPtr
ParseUnaryRangeExpr(const proto::plan::UnaryRangeExpr& expr_pb);

View File

@ -50,7 +50,10 @@ struct Relational {
template <typename T, typename U>
bool
operator()(const T& t, const U& u) const {
return RelationalImpl<Op, T, U>(t, u, typename TagDispatchTrait<T>::Tag{}, typename TagDispatchTrait<U>::Tag{});
return RelationalImpl<Op, T, U>(t,
u,
typename TagDispatchTrait<T>::Tag{},
typename TagDispatchTrait<U>::Tag{});
}
template <typename... T>

View File

@ -22,15 +22,19 @@
namespace milvus::query {
void
CheckBruteForceSearchParam(const FieldMeta& field, const SearchInfo& search_info) {
CheckBruteForceSearchParam(const FieldMeta& field,
const SearchInfo& search_info) {
auto data_type = field.get_data_type();
auto& metric_type = search_info.metric_type_;
AssertInfo(datatype_is_vector(data_type), "[BruteForceSearch] Data type isn't vector type");
AssertInfo(datatype_is_vector(data_type),
"[BruteForceSearch] Data type isn't vector type");
bool is_float_data_type = (data_type == DataType::VECTOR_FLOAT);
bool is_float_metric_type =
IsMetricType(metric_type, knowhere::metric::IP) || IsMetricType(metric_type, knowhere::metric::L2);
AssertInfo(is_float_data_type == is_float_metric_type, "[BruteForceSearch] Data type and metric type mis-match");
IsMetricType(metric_type, knowhere::metric::IP) ||
IsMetricType(metric_type, knowhere::metric::L2);
AssertInfo(is_float_data_type == is_float_metric_type,
"[BruteForceSearch] Data type and metric type miss-match");
}
SubSearchResult
@ -39,13 +43,17 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
int64_t chunk_rows,
const knowhere::Json& conf,
const BitsetView& bitset) {
SubSearchResult sub_result(dataset.num_queries, dataset.topk, dataset.metric_type, dataset.round_decimal);
SubSearchResult sub_result(dataset.num_queries,
dataset.topk,
dataset.metric_type,
dataset.round_decimal);
try {
auto nq = dataset.num_queries;
auto dim = dataset.dim;
auto topk = dataset.topk;
auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw);
auto base_dataset =
knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw);
auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data);
auto config = knowhere::Json{
{knowhere::meta::METRIC_TYPE, dataset.metric_type},
@ -60,24 +68,36 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
config[RADIUS] = conf[RADIUS].get<float>();
if (conf.contains(RANGE_FILTER)) {
config[RANGE_FILTER] = conf[RANGE_FILTER].get<float>();
CheckRangeSearchParam(config[RADIUS], config[RANGE_FILTER], dataset.metric_type);
CheckRangeSearchParam(
config[RADIUS], config[RANGE_FILTER], dataset.metric_type);
}
auto res = knowhere::BruteForce::RangeSearch(base_dataset, query_dataset, config, bitset);
auto res = knowhere::BruteForce::RangeSearch(
base_dataset, query_dataset, config, bitset);
if (!res.has_value()) {
PanicCodeInfo(ErrorCodeEnum::UnexpectedError,
"failed to range search, " + MatchKnowhereError(res.error()));
"failed to range search, " +
MatchKnowhereError(res.error()));
}
auto result = SortRangeSearchResult(res.value(), topk, nq, dataset.metric_type);
std::copy_n(GetDatasetIDs(result), nq * topk, sub_result.get_seg_offsets());
std::copy_n(GetDatasetDistance(result), nq * topk, sub_result.get_distances());
auto result = SortRangeSearchResult(
res.value(), topk, nq, dataset.metric_type);
std::copy_n(
GetDatasetIDs(result), nq * topk, sub_result.get_seg_offsets());
std::copy_n(GetDatasetDistance(result),
nq * topk,
sub_result.get_distances());
} else {
auto stat = knowhere::BruteForce::SearchWithBuf(base_dataset, query_dataset,
sub_result.mutable_seg_offsets().data(),
sub_result.mutable_distances().data(), config, bitset);
auto stat = knowhere::BruteForce::SearchWithBuf(
base_dataset,
query_dataset,
sub_result.mutable_seg_offsets().data(),
sub_result.mutable_distances().data(),
config,
bitset);
if (stat != knowhere::Status::success) {
throw std::invalid_argument("invalid metric type, " + MatchKnowhereError(stat));
throw std::invalid_argument("invalid metric type, " +
MatchKnowhereError(stat));
}
}
} catch (std::exception& e) {

View File

@ -20,7 +20,8 @@
namespace milvus::query {
void
CheckBruteForceSearchParam(const FieldMeta& field, const SearchInfo& search_info);
CheckBruteForceSearchParam(const FieldMeta& field,
const SearchInfo& search_info);
SubSearchResult
BruteForceSearch(const dataset::SearchDataset& dataset,

View File

@ -37,31 +37,42 @@ FloatIndexSearch(const segcore::SegmentGrowingImpl& segment,
auto vecfield_id = info.field_id_;
auto& field = schema[vecfield_id];
AssertInfo(field.get_data_type() == DataType::VECTOR_FLOAT, "[FloatSearch]Field data type isn't VECTOR_FLOAT");
dataset::SearchDataset search_dataset{info.metric_type_, num_queries, info.topk_,
info.round_decimal_, field.get_dim(), query_data};
AssertInfo(field.get_data_type() == DataType::VECTOR_FLOAT,
"[FloatSearch]Field data type isn't VECTOR_FLOAT");
dataset::SearchDataset search_dataset{info.metric_type_,
num_queries,
info.topk_,
info.round_decimal_,
field.get_dim(),
query_data};
auto vec_ptr = record.get_field_data<FloatVector>(vecfield_id);
int current_chunk_id = 0;
if (indexing_record.is_in(vecfield_id)) {
auto max_indexed_id = indexing_record.get_finished_ack();
const auto& field_indexing = indexing_record.get_vec_field_indexing(vecfield_id);
const auto& field_indexing =
indexing_record.get_vec_field_indexing(vecfield_id);
auto search_params = field_indexing.get_search_params(info.topk_);
SearchInfo search_conf(info);
search_conf.search_params_ = search_params;
AssertInfo(vec_ptr->get_size_per_chunk() == field_indexing.get_size_per_chunk(),
"[FloatSearch]Chunk size of vector not equal to chunk size of field index");
AssertInfo(vec_ptr->get_size_per_chunk() ==
field_indexing.get_size_per_chunk(),
"[FloatSearch]Chunk size of vector not equal to chunk size "
"of field index");
auto size_per_chunk = field_indexing.get_size_per_chunk();
for (int chunk_id = current_chunk_id; chunk_id < max_indexed_id; ++chunk_id) {
for (int chunk_id = current_chunk_id; chunk_id < max_indexed_id;
++chunk_id) {
if ((chunk_id + 1) * size_per_chunk > ins_barrier) {
break;
}
auto indexing = field_indexing.get_chunk_indexing(chunk_id);
auto sub_view = bitset.subview(chunk_id * size_per_chunk, size_per_chunk);
auto sub_view =
bitset.subview(chunk_id * size_per_chunk, size_per_chunk);
auto vec_index = (index::VectorIndex*)(indexing);
auto sub_qr = SearchOnIndex(search_dataset, *vec_index, search_conf, sub_view);
auto sub_qr = SearchOnIndex(
search_dataset, *vec_index, search_conf, sub_view);
// convert chunk uid to segment uid
for (auto& x : sub_qr.mutable_seg_offsets()) {
@ -87,7 +98,8 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
SearchResult& results) {
auto& schema = segment.get_schema();
auto& record = segment.get_insert_record();
auto active_count = std::min(int64_t(bitset.size()), segment.get_active_count(timestamp));
auto active_count =
std::min(int64_t(bitset.size()), segment.get_active_count(timestamp));
// step 1.1: get meta
// step 1.2: get which vector field to search
@ -96,7 +108,8 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
CheckBruteForceSearchParam(field, info);
auto data_type = field.get_data_type();
AssertInfo(datatype_is_vector(data_type), "[SearchOnGrowing]Data type isn't vector type");
AssertInfo(datatype_is_vector(data_type),
"[SearchOnGrowing]Data type isn't vector type");
auto dim = field.get_dim();
auto topk = info.topk_;
@ -105,11 +118,18 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
// step 2: small indexing search
SubSearchResult final_qr(num_queries, topk, metric_type, round_decimal);
dataset::SearchDataset search_dataset{metric_type, num_queries, topk, round_decimal, dim, query_data};
dataset::SearchDataset search_dataset{
metric_type, num_queries, topk, round_decimal, dim, query_data};
int32_t current_chunk_id = 0;
if (field.get_data_type() == DataType::VECTOR_FLOAT) {
current_chunk_id = FloatIndexSearch(segment, info, query_data, num_queries, active_count, bitset, final_qr);
current_chunk_id = FloatIndexSearch(segment,
info,
query_data,
num_queries,
active_count,
bitset,
final_qr);
}
// step 3: brute force search where small indexing is unavailable
@ -121,11 +141,16 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
auto chunk_data = vec_ptr->get_chunk_data(chunk_id);
auto element_begin = chunk_id * vec_size_per_chunk;
auto element_end = std::min(active_count, (chunk_id + 1) * vec_size_per_chunk);
auto element_end =
std::min(active_count, (chunk_id + 1) * vec_size_per_chunk);
auto size_per_chunk = element_end - element_begin;
auto sub_view = bitset.subview(element_begin, size_per_chunk);
auto sub_qr = BruteForceSearch(search_dataset, chunk_data, size_per_chunk, info.search_params_, sub_view);
auto sub_qr = BruteForceSearch(search_dataset,
chunk_data,
size_per_chunk,
info.search_params_,
sub_view);
// convert chunk uid to segment uid
for (auto& x : sub_qr.mutable_seg_offsets()) {

View File

@ -22,7 +22,8 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset,
auto dim = search_dataset.dim;
auto metric_type = search_dataset.metric_type;
auto round_decimal = search_dataset.round_decimal;
auto dataset = knowhere::GenDataSet(num_queries, dim, search_dataset.query_data);
auto dataset =
knowhere::GenDataSet(num_queries, dim, search_dataset.query_data);
// NOTE: VecIndex Query API forget to add const qualifier
// NOTE: use const_cast as a workaround
@ -30,8 +31,10 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset,
auto ans = indexing_nonconst.Query(dataset, search_conf, bitset);
SubSearchResult sub_qr(num_queries, topK, metric_type, round_decimal);
std::copy_n(ans->distances_.data(), num_queries * topK, sub_qr.get_distances());
std::copy_n(ans->seg_offsets_.data(), num_queries * topK, sub_qr.get_seg_offsets());
std::copy_n(
ans->distances_.data(), num_queries * topK, sub_qr.get_distances());
std::copy_n(
ans->seg_offsets_.data(), num_queries * topK, sub_qr.get_seg_offsets());
sub_qr.round_values();
return sub_qr;
}

View File

@ -45,7 +45,8 @@ SearchOnSealedIndex(const Schema& schema,
auto conf = search_info.search_params_;
conf[knowhere::meta::TOPK] = search_info.topk_;
conf[knowhere::meta::METRIC_TYPE] = field_indexing->metric_type_;
auto vec_index = dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get());
auto vec_index =
dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get());
auto index_type = vec_index->GetIndexType();
return vec_index->Query(ds, search_info, bitset);
}();
@ -81,11 +82,16 @@ SearchOnSealed(const Schema& schema,
auto field_id = search_info.field_id_;
auto& field = schema[field_id];
query::dataset::SearchDataset dataset{search_info.metric_type_, num_queries, search_info.topk_,
search_info.round_decimal_, field.get_dim(), query_data};
query::dataset::SearchDataset dataset{search_info.metric_type_,
num_queries,
search_info.topk_,
search_info.round_decimal_,
field.get_dim(),
query_data};
CheckBruteForceSearchParam(field, search_info);
auto sub_qr = BruteForceSearch(dataset, vec_data, row_count, search_info.search_params_, bitset);
auto sub_qr = BruteForceSearch(
dataset, vec_data, row_count, search_info.search_params_, bitset);
result.distances_ = std::move(sub_qr.mutable_distances());
result.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());

View File

@ -19,10 +19,13 @@ namespace milvus::query {
template <bool is_desc>
void
SubSearchResult::merge_impl(const SubSearchResult& right) {
AssertInfo(num_queries_ == right.num_queries_, "[SubSearchResult]Nq check failed");
AssertInfo(num_queries_ == right.num_queries_,
"[SubSearchResult]Nq check failed");
AssertInfo(topk_ == right.topk_, "[SubSearchResult]Topk check failed");
AssertInfo(metric_type_ == right.metric_type_, "[SubSearchResult]Metric type check failed");
AssertInfo(is_desc == PositivelyRelated(metric_type_), "[SubSearchResult]Metric type isn't desc");
AssertInfo(metric_type_ == right.metric_type_,
"[SubSearchResult]Metric type check failed");
AssertInfo(is_desc == PositivelyRelated(metric_type_),
"[SubSearchResult]Metric type isn't desc");
for (int64_t qn = 0; qn < num_queries_; ++qn) {
auto offset = qn * topk_;
@ -72,7 +75,8 @@ SubSearchResult::merge_impl(const SubSearchResult& right) {
void
SubSearchResult::merge(const SubSearchResult& sub_result) {
AssertInfo(metric_type_ == sub_result.metric_type_, "[SubSearchResult]Metric type check failed when merge");
AssertInfo(metric_type_ == sub_result.metric_type_,
"[SubSearchResult]Metric type check failed when merge");
if (PositivelyRelated(metric_type_)) {
this->merge_impl<true>(sub_result);
} else {
@ -85,7 +89,8 @@ SubSearchResult::round_values() {
if (round_decimal_ == -1)
return;
const float multiplier = pow(10.0, round_decimal_);
for (auto it = this->distances_.begin(); it != this->distances_.end(); it++) {
for (auto it = this->distances_.begin(); it != this->distances_.end();
it++) {
*it = round(*it * multiplier) / multiplier;
}
}

View File

@ -22,7 +22,10 @@ namespace milvus::query {
class SubSearchResult {
public:
SubSearchResult(int64_t num_queries, int64_t topk, const MetricType& metric_type, int64_t round_decimal)
SubSearchResult(int64_t num_queries,
int64_t topk,
const MetricType& metric_type,
int64_t round_decimal)
: num_queries_(num_queries),
topk_(topk),
round_decimal_(round_decimal),
@ -43,7 +46,8 @@ class SubSearchResult {
public:
static float
init_value(const MetricType& metric_type) {
return (PositivelyRelated(metric_type) ? -1 : 1) * std::numeric_limits<float>::max();
return (PositivelyRelated(metric_type) ? -1 : 1) *
std::numeric_limits<float>::max();
}
public:

View File

@ -38,7 +38,9 @@ Match<std::string>(const std::string& str, const std::string& val, OpType op) {
template <>
inline bool
Match<std::string_view>(const std::string_view& str, const std::string& val, OpType op) {
Match<std::string_view>(const std::string_view& str,
const std::string& val,
OpType op) {
switch (op) {
case OpType::PrefixMatch:
return PrefixMatch(str, val);

View File

@ -22,7 +22,9 @@ namespace milvus {
namespace query_old {
BinaryQueryPtr
ConstructBinTree(std::vector<BooleanQueryPtr> queries, QueryRelation relation, uint64_t idx) {
ConstructBinTree(std::vector<BooleanQueryPtr> queries,
QueryRelation relation,
uint64_t idx) {
if (idx == queries.size()) {
return nullptr;
} else if (idx == queries.size() - 1) {
@ -40,7 +42,9 @@ ConstructBinTree(std::vector<BooleanQueryPtr> queries, QueryRelation relation, u
}
Status
ConstructLeafBinTree(std::vector<LeafQueryPtr> leaf_queries, BinaryQueryPtr binary_query, uint64_t idx) {
ConstructLeafBinTree(std::vector<LeafQueryPtr> leaf_queries,
BinaryQueryPtr binary_query,
uint64_t idx) {
if (idx == leaf_queries.size()) {
return Status::OK();
}
@ -58,23 +62,27 @@ ConstructLeafBinTree(std::vector<LeafQueryPtr> leaf_queries, BinaryQueryPtr bina
binary_query->left_query->bin->relation = binary_query->relation;
binary_query->right_query->leaf = leaf_queries[idx];
++idx;
return ConstructLeafBinTree(leaf_queries, binary_query->left_query->bin, idx);
return ConstructLeafBinTree(
leaf_queries, binary_query->left_query->bin, idx);
}
}
Status
GenBinaryQuery(BooleanQueryPtr query, BinaryQueryPtr& binary_query) {
if (query->getBooleanQueries().size() == 0) {
if (binary_query->relation == QueryRelation::AND || binary_query->relation == QueryRelation::OR) {
if (binary_query->relation == QueryRelation::AND ||
binary_query->relation == QueryRelation::OR) {
// Put VectorQuery to the end of leaf queries
auto query_size = query->getLeafQueries().size();
for (uint64_t i = 0; i < query_size; ++i) {
if (query->getLeafQueries()[i]->vector_placeholder.size() > 0) {
std::swap(query->getLeafQueries()[i], query->getLeafQueries()[0]);
std::swap(query->getLeafQueries()[i],
query->getLeafQueries()[0]);
break;
}
}
return ConstructLeafBinTree(query->getLeafQueries(), binary_query, 0);
return ConstructLeafBinTree(
query->getLeafQueries(), binary_query, 0);
} else {
switch (query->getOccur()) {
case Occur::MUST: {
@ -154,7 +162,8 @@ GenBinaryQuery(BooleanQueryPtr query, BinaryQueryPtr& binary_query) {
if (must_not_queries.size() > 1) {
// Construct a must_not binary tree
must_not_bquery = ConstructBinTree(must_not_queries, QueryRelation::R1, 0);
must_not_bquery =
ConstructBinTree(must_not_queries, QueryRelation::R1, 0);
++bquery_num;
} else if (must_not_queries.size() == 1) {
must_not_bquery = must_not_queries[0]->getBinaryQuery();
@ -228,7 +237,8 @@ BinaryQueryHeight(BinaryQueryPtr& binary_query) {
*/
Status
rule_1(BooleanQueryPtr& boolean_query, std::stack<BooleanQueryPtr>& path_stack) {
rule_1(BooleanQueryPtr& boolean_query,
std::stack<BooleanQueryPtr>& path_stack) {
auto status = Status::OK();
if (boolean_query != nullptr) {
path_stack.push(boolean_query);
@ -236,9 +246,11 @@ rule_1(BooleanQueryPtr& boolean_query, std::stack<BooleanQueryPtr>& path_stack)
if (!leaf_query->vector_placeholder.empty()) {
while (!path_stack.empty()) {
auto query = path_stack.top();
if (query->getOccur() == Occur::SHOULD || query->getOccur() == Occur::MUST_NOT) {
if (query->getOccur() == Occur::SHOULD ||
query->getOccur() == Occur::MUST_NOT) {
std::string msg =
"The child node of 'should' and 'must_not' can only be 'term query' and 'range query'.";
"The child node of 'should' and 'must_not' can "
"only be 'term query' and 'range query'.";
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
}
path_stack.pop();
@ -259,8 +271,10 @@ Status
rule_2(BooleanQueryPtr& boolean_query) {
auto status = Status::OK();
if (boolean_query != nullptr) {
if (!boolean_query->getBooleanQueries().empty() && !boolean_query->getLeafQueries().empty()) {
std::string msg = "One layer cannot include bool query and leaf query.";
if (!boolean_query->getBooleanQueries().empty() &&
!boolean_query->getLeafQueries().empty()) {
std::string msg =
"One layer cannot include bool query and leaf query.";
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
} else {
for (auto query : boolean_query->getBooleanQueries()) {
@ -279,8 +293,11 @@ ValidateBooleanQuery(BooleanQueryPtr& boolean_query) {
auto status = Status::OK();
if (boolean_query != nullptr) {
for (auto& query : boolean_query->getBooleanQueries()) {
if (query->getOccur() == Occur::SHOULD || query->getOccur() == Occur::MUST_NOT) {
std::string msg = "The direct child node of 'bool' node cannot be 'should' node or 'must_not' node.";
if (query->getOccur() == Occur::SHOULD ||
query->getOccur() == Occur::MUST_NOT) {
std::string msg =
"The direct child node of 'bool' node cannot be 'should' "
"node or 'must_not' node.";
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
}
}

View File

@ -20,10 +20,14 @@ namespace milvus {
namespace query_old {
BinaryQueryPtr
ConstructBinTree(std::vector<BooleanQueryPtr> clauses, QueryRelation relation, uint64_t idx);
ConstructBinTree(std::vector<BooleanQueryPtr> clauses,
QueryRelation relation,
uint64_t idx);
Status
ConstructLeafBinTree(std::vector<LeafQueryPtr> leaf_clauses, BinaryQueryPtr binary_query, uint64_t idx);
ConstructLeafBinTree(std::vector<LeafQueryPtr> leaf_clauses,
BinaryQueryPtr binary_query,
uint64_t idx);
Status
GenBinaryQuery(BooleanQueryPtr clause, BinaryQueryPtr& binary_query);

View File

@ -131,9 +131,10 @@ using QueryPtr = std::shared_ptr<Query>;
namespace query {
struct QueryDeprecated {
int64_t num_queries; //
int topK; // topK of queries
std::string field_name; // must be fakevec, whose data_type must be VEC_FLOAT(DIM)
int64_t num_queries; //
int topK; // topK of queries
std::string
field_name; // must be fakevec, whose data_type must be VEC_FLOAT(DIM)
std::vector<float> query_raw_data; // must be size of num_queries * DIM
};

View File

@ -53,7 +53,7 @@ CopyRowRecords(const RepeatedPtrField<proto::service::PlaceholderValue>& grpc_re
memcpy(id_array.data(), grpc_id_array.data(), grpc_id_array.size() * sizeof(int64_t));
}
// step 3: contruct vectors
// step 3: construct vectors
vectors.vector_count_ = grpc_records.size();
vectors.float_data_.swap(float_array);
vectors.binary_data_.swap(binary_array);
@ -62,7 +62,9 @@ CopyRowRecords(const RepeatedPtrField<proto::service::PlaceholderValue>& grpc_re
#endif
Status
ProcessLeafQueryJson(const milvus::json& query_json, query_old::BooleanQueryPtr& query, std::string& field_name) {
ProcessLeafQueryJson(const milvus::json& query_json,
query_old::BooleanQueryPtr& query,
std::string& field_name) {
#if 1
if (query_json.contains("term")) {
auto leaf_query = std::make_shared<query_old::LeafQuery>();
@ -120,12 +122,15 @@ ProcessBooleanQueryJson(const milvus::json& query_json,
for (auto& json : must_json) {
auto must_query = std::make_shared<query_old::BooleanQuery>();
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
STATUS_CHECK(ProcessBooleanQueryJson(json, must_query, query_ptr));
if (json.contains("must") || json.contains("should") ||
json.contains("must_not")) {
STATUS_CHECK(
ProcessBooleanQueryJson(json, must_query, query_ptr));
boolean_query->AddBooleanQuery(must_query);
} else {
std::string field_name;
STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name));
STATUS_CHECK(
ProcessLeafQueryJson(json, boolean_query, field_name));
if (!field_name.empty()) {
query_ptr->index_fields.insert(field_name);
}
@ -141,12 +146,15 @@ ProcessBooleanQueryJson(const milvus::json& query_json,
for (auto& json : should_json) {
auto should_query = std::make_shared<query_old::BooleanQuery>();
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
STATUS_CHECK(ProcessBooleanQueryJson(json, should_query, query_ptr));
if (json.contains("must") || json.contains("should") ||
json.contains("must_not")) {
STATUS_CHECK(
ProcessBooleanQueryJson(json, should_query, query_ptr));
boolean_query->AddBooleanQuery(should_query);
} else {
std::string field_name;
STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name));
STATUS_CHECK(
ProcessLeafQueryJson(json, boolean_query, field_name));
if (!field_name.empty()) {
query_ptr->index_fields.insert(field_name);
}
@ -161,20 +169,25 @@ ProcessBooleanQueryJson(const milvus::json& query_json,
}
for (auto& json : should_json) {
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
auto must_not_query = std::make_shared<query_old::BooleanQuery>();
STATUS_CHECK(ProcessBooleanQueryJson(json, must_not_query, query_ptr));
if (json.contains("must") || json.contains("should") ||
json.contains("must_not")) {
auto must_not_query =
std::make_shared<query_old::BooleanQuery>();
STATUS_CHECK(ProcessBooleanQueryJson(
json, must_not_query, query_ptr));
boolean_query->AddBooleanQuery(must_not_query);
} else {
std::string field_name;
STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name));
STATUS_CHECK(
ProcessLeafQueryJson(json, boolean_query, field_name));
if (!field_name.empty()) {
query_ptr->index_fields.insert(field_name);
}
}
}
} else {
std::string msg = "BoolQuery json string does not include bool query";
std::string msg =
"BoolQuery json string does not include bool query";
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
}
}
@ -183,7 +196,8 @@ ProcessBooleanQueryJson(const milvus::json& query_json,
}
Status
DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorParam>& vector_params,
DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<
::milvus::grpc::VectorParam>& vector_params,
const std::string& dsl_string,
query_old::BooleanQueryPtr& boolean_query,
query_old::QueryPtr& query_ptr) {
@ -196,7 +210,8 @@ DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<::milvus::gr
}
auto status = Status::OK();
if (vector_params.empty()) {
return Status(SERVER_INVALID_DSL_PARAMETER, "DSL must include vector query");
return Status(SERVER_INVALID_DSL_PARAMETER,
"DSL must include vector query");
}
for (const auto& vector_param : vector_params) {
const std::string& vector_string = vector_param.json();
@ -216,32 +231,41 @@ DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<::milvus::gr
if (param_json.contains("metric_type")) {
std::string metric_type = param_json["metric_type"];
vector_query->metric_type = metric_type;
query_ptr->metric_types.insert({field_name, param_json["metric_type"]});
query_ptr->metric_types.insert(
{field_name, param_json["metric_type"]});
}
if (!vector_param_it.value()["params"].empty()) {
vector_query->extra_params = vector_param_it.value()["params"];
vector_query->extra_params =
vector_param_it.value()["params"];
}
query_ptr->index_fields.insert(field_name);
}
engine::VectorsData vector_data;
CopyRowRecords(vector_param.row_record().records(),
google::protobuf::RepeatedField<google::protobuf::int64>(), vector_data);
CopyRowRecords(
vector_param.row_record().records(),
google::protobuf::RepeatedField<google::protobuf::int64>(),
vector_data);
vector_query->query_vector.vector_count = vector_data.vector_count_;
vector_query->query_vector.binary_data.swap(vector_data.binary_data_);
vector_query->query_vector.binary_data.swap(
vector_data.binary_data_);
vector_query->query_vector.float_data.swap(vector_data.float_data_);
query_ptr->vectors.insert(std::make_pair(placeholder, vector_query));
query_ptr->vectors.insert(
std::make_pair(placeholder, vector_query));
}
if (dsl_json.contains("bool")) {
auto boolean_query_json = dsl_json["bool"];
JSON_NULL_CHECK(boolean_query_json);
status = ProcessBooleanQueryJson(boolean_query_json, boolean_query, query_ptr);
status = ProcessBooleanQueryJson(
boolean_query_json, boolean_query, query_ptr);
if (!status.ok()) {
return Status(SERVER_INVALID_DSL_PARAMETER, "DSL does not include bool");
return Status(SERVER_INVALID_DSL_PARAMETER,
"DSL does not include bool");
}
} else {
return Status(SERVER_INVALID_DSL_PARAMETER, "DSL does not include bool query");
return Status(SERVER_INVALID_DSL_PARAMETER,
"DSL does not include bool query");
}
return Status::OK();
} catch (std::exception& e) {
@ -254,7 +278,8 @@ DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<::milvus::gr
#endif
query_old::QueryPtr
Transformer(proto::service::Query* request) {
query_old::BooleanQueryPtr boolean_query = std::make_shared<query_old::BooleanQuery>();
query_old::BooleanQueryPtr boolean_query =
std::make_shared<query_old::BooleanQuery>();
query_old::QueryPtr query_ptr = std::make_shared<query_old::Query>();
#if 0
query_ptr->collection_id = request->collection_name();

View File

@ -43,8 +43,10 @@ CheckParameterRange(const milvus::json& json_params,
bool min_err = min_close ? value < min : value <= min;
bool max_err = max_closed ? value > max : value >= max;
if (min_err || max_err) {
std::string msg = "Invalid " + param_name + " value: " + std::to_string(value) + ". Valid range is " +
(min_close ? "[" : "(") + std::to_string(min) + ", " + std::to_string(max) +
std::string msg = "Invalid " + param_name +
" value: " + std::to_string(value) +
". Valid range is " + (min_close ? "[" : "(") +
std::to_string(min) + ", " + std::to_string(max) +
(max_closed ? "]" : ")");
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_ARGUMENT, msg);
@ -60,7 +62,8 @@ CheckParameterRange(const milvus::json& json_params,
}
Status
CheckParameterExistence(const milvus::json& json_params, const std::string& param_name) {
CheckParameterExistence(const milvus::json& json_params,
const std::string& param_name) {
if (json_params.find(param_name) == json_params.end()) {
std::string msg = "Parameter list must contain: ";
msg += param_name;
@ -71,7 +74,8 @@ CheckParameterExistence(const milvus::json& json_params, const std::string& para
try {
int64_t value = json_params[param_name];
if (value < 0) {
std::string msg = "Invalid " + param_name + " value: " + std::to_string(value);
std::string msg =
"Invalid " + param_name + " value: " + std::to_string(value);
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_ARGUMENT, msg);
}
@ -96,11 +100,13 @@ ValidateCollectionName(const std::string& collection_name) {
return Status(SERVER_INVALID_COLLECTION_NAME, msg);
}
std::string invalid_msg = "Invalid collection name: " + collection_name + ". ";
std::string invalid_msg =
"Invalid collection name: " + collection_name + ". ";
// Collection name size shouldn't exceed engine::MAX_NAME_LENGTH.
if (collection_name.size() > engine::MAX_NAME_LENGTH) {
std::string msg = invalid_msg + "The length of a collection name must be less than " +
std::to_string(engine::MAX_NAME_LENGTH) + " characters.";
std::string msg =
invalid_msg + "The length of a collection name must be less than " +
std::to_string(engine::MAX_NAME_LENGTH) + " characters.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_COLLECTION_NAME, msg);
}
@ -108,7 +114,9 @@ ValidateCollectionName(const std::string& collection_name) {
// Collection name first character should be underscore or character.
char first_char = collection_name[0];
if (first_char != '_' && std::isalpha(first_char) == 0) {
std::string msg = invalid_msg + "The first character of a collection name must be an underscore or letter.";
std::string msg = invalid_msg +
"The first character of a collection name must be an "
"underscore or letter.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_COLLECTION_NAME, msg);
}
@ -116,8 +124,11 @@ ValidateCollectionName(const std::string& collection_name) {
int64_t table_name_size = collection_name.size();
for (int64_t i = 1; i < table_name_size; ++i) {
char name_char = collection_name[i];
if (name_char != '_' && name_char != '$' && std::isalnum(name_char) == 0) {
std::string msg = invalid_msg + "Collection name can only contain numbers, letters, and underscores.";
if (name_char != '_' && name_char != '$' &&
std::isalnum(name_char) == 0) {
std::string msg = invalid_msg +
"Collection name can only contain numbers, "
"letters, and underscores.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_COLLECTION_NAME, msg);
}
@ -138,8 +149,9 @@ ValidateFieldName(const std::string& field_name) {
std::string invalid_msg = "Invalid field name: " + field_name + ". ";
// Field name size shouldn't exceed engine::MAX_NAME_LENGTH.
if (field_name.size() > engine::MAX_NAME_LENGTH) {
std::string msg = invalid_msg + "The length of a field name must be less than " +
std::to_string(engine::MAX_NAME_LENGTH) + " characters.";
std::string msg =
invalid_msg + "The length of a field name must be less than " +
std::to_string(engine::MAX_NAME_LENGTH) + " characters.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_FIELD_NAME, msg);
}
@ -147,7 +159,9 @@ ValidateFieldName(const std::string& field_name) {
// Field name first character should be underscore or character.
char first_char = field_name[0];
if (first_char != '_' && std::isalpha(first_char) == 0) {
std::string msg = invalid_msg + "The first character of a field name must be an underscore or letter.";
std::string msg = invalid_msg +
"The first character of a field name must be an "
"underscore or letter.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_FIELD_NAME, msg);
}
@ -156,7 +170,9 @@ ValidateFieldName(const std::string& field_name) {
for (int64_t i = 1; i < field_name_size; ++i) {
char name_char = field_name[i];
if (name_char != '_' && std::isalnum(name_char) == 0) {
std::string msg = invalid_msg + "Field name cannot only contain numbers, letters, and underscores.";
std::string msg = invalid_msg +
"Field name cannot only contain numbers, "
"letters, and underscores.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_FIELD_NAME, msg);
}
@ -175,7 +191,8 @@ ValidateVectorIndexType(std::string& index_type, bool is_binary) {
}
// string case insensitive
std::transform(index_type.begin(), index_type.end(), index_type.begin(), ::toupper);
std::transform(
index_type.begin(), index_type.end(), index_type.begin(), ::toupper);
static std::set<std::string> s_vector_index_type = {
knowhere::IndexEnum::INVALID,
@ -192,7 +209,8 @@ ValidateVectorIndexType(std::string& index_type, bool is_binary) {
knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT,
};
std::set<std::string>& index_types = is_binary ? s_binary_index_types : s_vector_index_type;
std::set<std::string>& index_types =
is_binary ? s_binary_index_types : s_vector_index_type;
if (index_types.find(index_type) == index_types.end()) {
std::string msg = "Invalid index type: " + index_type;
LOG_SERVER_ERROR_ << msg;
@ -212,7 +230,8 @@ ValidateStructuredIndexType(std::string& index_type) {
}
// string case insensitive
std::transform(index_type.begin(), index_type.end(), index_type.begin(), ::toupper);
std::transform(
index_type.begin(), index_type.end(), index_type.begin(), ::toupper);
static std::set<std::string> s_index_types = {
engine::DEFAULT_STRUCTURED_INDEX,
@ -230,14 +249,16 @@ ValidateStructuredIndexType(std::string& index_type) {
Status
ValidateDimension(int64_t dim, bool is_binary) {
if (dim <= 0 || dim > engine::MAX_DIMENSION) {
std::string msg = "Invalid dimension: " + std::to_string(dim) + ". Should be in range 1 ~ " +
std::string msg = "Invalid dimension: " + std::to_string(dim) +
". Should be in range 1 ~ " +
std::to_string(engine::MAX_DIMENSION) + ".";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_VECTOR_DIMENSION, msg);
}
if (is_binary && (dim % 8) != 0) {
std::string msg = "Invalid dimension: " + std::to_string(dim) + ". Should be multiple of 8.";
std::string msg = "Invalid dimension: " + std::to_string(dim) +
". Should be multiple of 8.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_VECTOR_DIMENSION, msg);
}
@ -246,30 +267,36 @@ ValidateDimension(int64_t dim, bool is_binary) {
}
Status
ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const std::string& index_type) {
ValidateIndexParams(const milvus::json& index_params,
int64_t dimension,
const std::string& index_type) {
if (engine::utils::IsFlatIndexType(index_type)) {
return Status::OK();
} else if (index_type == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT ||
index_type == knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 ||
index_type == knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT) {
auto status = CheckParameterRange(index_params, knowhere::IndexParams::nlist, 1, 65536);
auto status = CheckParameterRange(
index_params, knowhere::IndexParams::nlist, 1, 65536);
if (!status.ok()) {
return status;
}
} else if (index_type == knowhere::IndexEnum::INDEX_FAISS_IVFPQ) {
auto status = CheckParameterRange(index_params, knowhere::IndexParams::nlist, 1, 65536);
auto status = CheckParameterRange(
index_params, knowhere::IndexParams::nlist, 1, 65536);
if (!status.ok()) {
return status;
}
status = CheckParameterExistence(index_params, knowhere::IndexParams::m);
status =
CheckParameterExistence(index_params, knowhere::IndexParams::m);
if (!status.ok()) {
return status;
}
// special check for 'm' parameter
int64_t m_value = index_params[knowhere::IndexParams::m];
if (!milvus::knowhere::IVFPQConfAdapter::GetValidCPUM(dimension, m_value)) {
if (!milvus::knowhere::IVFPQConfAdapter::GetValidCPUM(dimension,
m_value)) {
std::string msg = "Invalid m, dimension can't not be divided by m ";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_ARGUMENT, msg);
@ -298,16 +325,19 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s
return Status(SERVER_INVALID_ARGUMENT, msg);
}*/
} else if (index_type == knowhere::IndexEnum::INDEX_HNSW) {
auto status = CheckParameterRange(index_params, knowhere::IndexParams::M, 4, 64);
auto status =
CheckParameterRange(index_params, knowhere::IndexParams::M, 4, 64);
if (!status.ok()) {
return status;
}
status = CheckParameterRange(index_params, knowhere::IndexParams::efConstruction, 8, 512);
status = CheckParameterRange(
index_params, knowhere::IndexParams::efConstruction, 8, 512);
if (!status.ok()) {
return status;
}
} else if (index_type == knowhere::IndexEnum::INDEX_ANNOY) {
auto status = CheckParameterRange(index_params, knowhere::IndexParams::n_trees, 1, 1024);
auto status = CheckParameterRange(
index_params, knowhere::IndexParams::n_trees, 1, 1024);
if (!status.ok()) {
return status;
}
@ -321,8 +351,10 @@ ValidateSegmentRowCount(int64_t segment_row_count) {
int64_t min = config.engine.build_index_threshold();
int max = engine::MAX_SEGMENT_ROW_COUNT;
if (segment_row_count < min || segment_row_count > max) {
std::string msg = "Invalid segment row count: " + std::to_string(segment_row_count) + ". " +
"Should be in range " + std::to_string(min) + " ~ " + std::to_string(max) + ".";
std::string msg =
"Invalid segment row count: " + std::to_string(segment_row_count) +
". " + "Should be in range " + std::to_string(min) + " ~ " +
std::to_string(max) + ".";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_SEGMENT_ROW_COUNT, msg);
}
@ -330,21 +362,26 @@ ValidateSegmentRowCount(int64_t segment_row_count) {
}
Status
ValidateIndexMetricType(const std::string& metric_type, const std::string& index_type) {
ValidateIndexMetricType(const std::string& metric_type,
const std::string& index_type) {
if (engine::utils::IsFlatIndexType(index_type)) {
// pass
} else if (index_type == knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT) {
// binary
if (metric_type != knowhere::Metric::HAMMING && metric_type != knowhere::Metric::JACCARD &&
if (metric_type != knowhere::Metric::HAMMING &&
metric_type != knowhere::Metric::JACCARD &&
metric_type != knowhere::Metric::TANIMOTO) {
std::string msg = "Index metric type " + metric_type + " does not match index type " + index_type;
std::string msg = "Index metric type " + metric_type +
" does not match index type " + index_type;
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_ARGUMENT, msg);
}
} else {
// float
if (metric_type != knowhere::Metric::L2 && metric_type != knowhere::Metric::IP) {
std::string msg = "Index metric type " + metric_type + " does not match index type " + index_type;
if (metric_type != knowhere::Metric::L2 &&
metric_type != knowhere::Metric::IP) {
std::string msg = "Index metric type " + metric_type +
" does not match index type " + index_type;
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_ARGUMENT, msg);
}
@ -357,16 +394,22 @@ Status
ValidateSearchMetricType(const std::string& metric_type, bool is_binary) {
if (is_binary) {
// binary
if (metric_type == knowhere::Metric::L2 || metric_type == knowhere::Metric::IP) {
std::string msg = "Cannot search binary entities with index metric type " + metric_type;
if (metric_type == knowhere::Metric::L2 ||
metric_type == knowhere::Metric::IP) {
std::string msg =
"Cannot search binary entities with index metric type " +
metric_type;
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_ARGUMENT, msg);
}
} else {
// float
if (metric_type == knowhere::Metric::HAMMING || metric_type == knowhere::Metric::JACCARD ||
if (metric_type == knowhere::Metric::HAMMING ||
metric_type == knowhere::Metric::JACCARD ||
metric_type == knowhere::Metric::TANIMOTO) {
std::string msg = "Cannot search float entities with index metric type " + metric_type;
std::string msg =
"Cannot search float entities with index metric type " +
metric_type;
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_ARGUMENT, msg);
}
@ -378,8 +421,8 @@ ValidateSearchMetricType(const std::string& metric_type, bool is_binary) {
Status
ValidateSearchTopk(int64_t top_k) {
if (top_k <= 0 || top_k > QUERY_MAX_TOPK) {
std::string msg =
"Invalid topk: " + std::to_string(top_k) + ". " + "The topk must be within the range of 1 ~ 16384.";
std::string msg = "Invalid topk: " + std::to_string(top_k) + ". " +
"The topk must be within the range of 1 ~ 16384.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_TOPK, msg);
}
@ -400,7 +443,9 @@ ValidatePartitionTags(const std::vector<std::string>& partition_tags) {
std::string invalid_msg = "Invalid partition tag: " + tag + ". ";
// Partition tag size shouldn't exceed 255.
if (tag.size() > engine::MAX_NAME_LENGTH) {
std::string msg = invalid_msg + "The length of a partition tag must be less than 255 characters.";
std::string msg = invalid_msg +
"The length of a partition tag must be less than "
"255 characters.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_PARTITION_TAG, msg);
}
@ -408,7 +453,9 @@ ValidatePartitionTags(const std::vector<std::string>& partition_tags) {
// Partition tag first character should be underscore or character.
char first_char = tag[0];
if (first_char != '_' && std::isalnum(first_char) == 0) {
std::string msg = invalid_msg + "The first character of a partition tag must be an underscore or letter.";
std::string msg = invalid_msg +
"The first character of a partition tag must be "
"an underscore or letter.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_PARTITION_TAG, msg);
}
@ -416,8 +463,11 @@ ValidatePartitionTags(const std::vector<std::string>& partition_tags) {
int64_t tag_size = tag.size();
for (int64_t i = 1; i < tag_size; ++i) {
char name_char = tag[i];
if (name_char != '_' && name_char != '$' && std::isalnum(name_char) == 0) {
std::string msg = invalid_msg + "Partition tag can only contain numbers, letters, and underscores.";
if (name_char != '_' && name_char != '$' &&
std::isalnum(name_char) == 0) {
std::string msg = invalid_msg +
"Partition tag can only contain numbers, "
"letters, and underscores.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_PARTITION_TAG, msg);
}
@ -457,8 +507,9 @@ ValidateInsertDataSize(const InsertParam& insert_param) {
}
if (chunk_size > engine::MAX_INSERT_DATA_SIZE) {
std::string msg = "The amount of data inserted each time cannot exceed " +
std::to_string(engine::MAX_INSERT_DATA_SIZE / engine::MB) + " MB";
std::string msg =
"The amount of data inserted each time cannot exceed " +
std::to_string(engine::MAX_INSERT_DATA_SIZE / engine::MB) + " MB";
return Status(SERVER_INVALID_ROWRECORD_ARRAY, msg);
}
@ -468,7 +519,9 @@ ValidateInsertDataSize(const InsertParam& insert_param) {
Status
ValidateCompactThreshold(double threshold) {
if (threshold > 1.0 || threshold < 0.0) {
std::string msg = "Invalid compact threshold: " + std::to_string(threshold) + ". Should be in range [0.0, 1.0]";
std::string msg =
"Invalid compact threshold: " + std::to_string(threshold) +
". Should be in range [0.0, 1.0]";
return Status(SERVER_INVALID_ROWRECORD_ARRAY, msg);
}

View File

@ -42,13 +42,16 @@ extern Status
ValidateStructuredIndexType(std::string& index_type);
extern Status
ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const std::string& index_type);
ValidateIndexParams(const milvus::json& index_params,
int64_t dimension,
const std::string& index_type);
extern Status
ValidateSegmentRowCount(int64_t segment_row_count);
extern Status
ValidateIndexMetricType(const std::string& metric_type, const std::string& index_type);
ValidateIndexMetricType(const std::string& metric_type,
const std::string& index_type);
extern Status
ValidateSearchMetricType(const std::string& metric_type, bool is_binary);

View File

@ -45,7 +45,9 @@ class ExecExprVisitor : public ExprVisitor {
visit(CompareExpr& expr) override;
public:
ExecExprVisitor(const segcore::SegmentInternalInterface& segment, int64_t row_count, Timestamp timestamp)
ExecExprVisitor(const segcore::SegmentInternalInterface& segment,
int64_t row_count,
Timestamp timestamp)
: segment_(segment), row_count_(row_count), timestamp_(timestamp) {
}
@ -62,11 +64,15 @@ class ExecExprVisitor : public ExprVisitor {
public:
template <typename T, typename IndexFunc, typename ElementFunc>
auto
ExecRangeVisitorImpl(FieldId field_id, IndexFunc func, ElementFunc element_func) -> BitsetType;
ExecRangeVisitorImpl(FieldId field_id,
IndexFunc func,
ElementFunc element_func) -> BitsetType;
template <typename T, typename IndexFunc, typename ElementFunc>
auto
ExecDataRangeVisitorImpl(FieldId field_id, IndexFunc index_func, ElementFunc element_func) -> BitsetType;
ExecDataRangeVisitorImpl(FieldId field_id,
IndexFunc index_func,
ElementFunc element_func) -> BitsetType;
template <typename T>
auto
@ -74,7 +80,8 @@ class ExecExprVisitor : public ExprVisitor {
template <typename T>
auto
ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType;
ExecBinaryArithOpEvalRangeVisitorDispatcher(
BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType;
template <typename T>
auto
@ -90,7 +97,8 @@ class ExecExprVisitor : public ExprVisitor {
template <typename CmpFunc>
auto
ExecCompareExprDispatcher(CompareExpr& expr, CmpFunc cmp_func) -> BitsetType;
ExecCompareExprDispatcher(CompareExpr& expr, CmpFunc cmp_func)
-> BitsetType;
private:
const segcore::SegmentInternalInterface& segment_;

View File

@ -34,10 +34,13 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor {
ExecPlanNodeVisitor(const segcore::SegmentInterface& segment,
Timestamp timestamp,
const PlaceholderGroup* placeholder_group)
: segment_(segment), timestamp_(timestamp), placeholder_group_(placeholder_group) {
: segment_(segment),
timestamp_(timestamp),
placeholder_group_(placeholder_group) {
}
ExecPlanNodeVisitor(const segcore::SegmentInterface& segment, Timestamp timestamp)
ExecPlanNodeVisitor(const segcore::SegmentInterface& segment,
Timestamp timestamp)
: segment_(segment), timestamp_(timestamp) {
placeholder_group_ = nullptr;
}

View File

@ -40,7 +40,8 @@ class ExtractInfoExprVisitor : public ExprVisitor {
visit(CompareExpr& expr) override;
public:
explicit ExtractInfoExprVisitor(ExtractedPlanInfo& plan_info) : plan_info_(plan_info) {
explicit ExtractInfoExprVisitor(ExtractedPlanInfo& plan_info)
: plan_info_(plan_info) {
}
private:

View File

@ -28,7 +28,8 @@ class ExtractInfoPlanNodeVisitor : public PlanNodeVisitor {
visit(RetrievePlanNode& node) override;
public:
explicit ExtractInfoPlanNodeVisitor(ExtractedPlanInfo& plan_info) : plan_info_(plan_info) {
explicit ExtractInfoPlanNodeVisitor(ExtractedPlanInfo& plan_info)
: plan_info_(plan_info) {
}
private:

View File

@ -28,15 +28,19 @@ namespace milvus::query {
namespace impl {
class ExecExprVisitor : ExprVisitor {
public:
ExecExprVisitor(const segcore::SegmentInternalInterface& segment, int64_t row_count, Timestamp timestamp)
ExecExprVisitor(const segcore::SegmentInternalInterface& segment,
int64_t row_count,
Timestamp timestamp)
: segment_(segment), row_count_(row_count), timestamp_(timestamp) {
}
BitsetType
call_child(Expr& expr) {
AssertInfo(!bitset_opt_.has_value(), "[ExecExprVisitor]Bitset already has value before accept");
AssertInfo(!bitset_opt_.has_value(),
"[ExecExprVisitor]Bitset already has value before accept");
expr.accept(*this);
AssertInfo(bitset_opt_.has_value(), "[ExecExprVisitor]Bitset doesn't have value after accept");
AssertInfo(bitset_opt_.has_value(),
"[ExecExprVisitor]Bitset doesn't have value after accept");
auto res = std::move(bitset_opt_);
bitset_opt_ = std::nullopt;
return std::move(res.value());
@ -45,7 +49,9 @@ class ExecExprVisitor : ExprVisitor {
public:
template <typename T, typename IndexFunc, typename ElementFunc>
auto
ExecRangeVisitorImpl(FieldId field_id, IndexFunc func, ElementFunc element_func) -> BitsetType;
ExecRangeVisitorImpl(FieldId field_id,
IndexFunc func,
ElementFunc element_func) -> BitsetType;
template <typename T>
auto
@ -53,7 +59,8 @@ class ExecExprVisitor : ExprVisitor {
template <typename T>
auto
ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType;
ExecBinaryArithOpEvalRangeVisitorDispatcher(
BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType;
template <typename T>
auto
@ -69,7 +76,8 @@ class ExecExprVisitor : ExprVisitor {
template <typename CmpFunc>
auto
ExecCompareExprDispatcher(CompareExpr& expr, CmpFunc cmp_func) -> BitsetType;
ExecCompareExprDispatcher(CompareExpr& expr, CmpFunc cmp_func)
-> BitsetType;
private:
const segcore::SegmentInternalInterface& segment_;
@ -93,7 +101,8 @@ ExecExprVisitor::visit(LogicalUnaryExpr& expr) {
PanicInfo("Invalid Unary Op");
}
}
AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
AssertInfo(res.size() == row_count_,
"[ExecExprVisitor]Size of results not equal row count");
bitset_opt_ = std::move(res);
}
@ -102,7 +111,8 @@ ExecExprVisitor::visit(LogicalBinaryExpr& expr) {
using OpType = LogicalBinaryExpr::OpType;
auto left = call_child(*expr.left_);
auto right = call_child(*expr.right_);
AssertInfo(left.size() == right.size(), "[ExecExprVisitor]Left size not equal to right size");
AssertInfo(left.size() == right.size(),
"[ExecExprVisitor]Left size not equal to right size");
auto res = std::move(left);
switch (expr.op_type_) {
case OpType::LogicalAnd: {
@ -125,7 +135,8 @@ ExecExprVisitor::visit(LogicalBinaryExpr& expr) {
PanicInfo("Invalid Binary Op");
}
}
AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
AssertInfo(res.size() == row_count_,
"[ExecExprVisitor]Size of results not equal row count");
bitset_opt_ = std::move(res);
}
@ -151,7 +162,9 @@ Assemble(const std::deque<BitsetType>& srcs) -> BitsetType {
template <typename T, typename IndexFunc, typename ElementFunc>
auto
ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id, IndexFunc index_func, ElementFunc element_func) -> BitsetType {
ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id,
IndexFunc index_func,
ElementFunc element_func) -> BitsetType {
auto& schema = segment_.get_schema();
auto& field_meta = schema[field_id];
auto indexing_barrier = segment_.num_chunk_index(field_id);
@ -159,18 +172,24 @@ ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id, IndexFunc index_func, El
auto num_chunk = upper_div(row_count_, size_per_chunk);
std::deque<BitsetType> results;
typedef std::conditional_t<std::is_same_v<T, std::string_view>, std::string, T> IndexInnerType;
typedef std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
IndexInnerType;
using Index = index::ScalarIndex<IndexInnerType>;
for (auto chunk_id = 0; chunk_id < indexing_barrier; ++chunk_id) {
const Index& indexing = segment_.chunk_scalar_index<IndexInnerType>(field_id, chunk_id);
const Index& indexing =
segment_.chunk_scalar_index<IndexInnerType>(field_id, chunk_id);
// NOTE: knowhere is not const-ready
// This is a dirty workaround
auto data = index_func(const_cast<Index*>(&indexing));
AssertInfo(data->size() == size_per_chunk, "[ExecExprVisitor]Data size not equal to size_per_chunk");
AssertInfo(data->size() == size_per_chunk,
"[ExecExprVisitor]Data size not equal to size_per_chunk");
results.emplace_back(std::move(*data));
}
for (auto chunk_id = indexing_barrier; chunk_id < num_chunk; ++chunk_id) {
auto this_size = chunk_id == num_chunk - 1 ? row_count_ - chunk_id * size_per_chunk : size_per_chunk;
auto this_size = chunk_id == num_chunk - 1
? row_count_ - chunk_id * size_per_chunk
: size_per_chunk;
BitsetType result(this_size);
auto chunk = segment_.chunk_data<T>(field_id, chunk_id);
const T* data = chunk.data();
@ -182,13 +201,16 @@ ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id, IndexFunc index_func, El
results.emplace_back(std::move(result));
}
auto final_result = Assemble(results);
AssertInfo(final_result.size() == row_count_, "[ExecExprVisitor]Final result size not equal to row count");
AssertInfo(final_result.size() == row_count_,
"[ExecExprVisitor]Final result size not equal to row count");
return final_result;
}
template <typename T, typename IndexFunc, typename ElementFunc>
auto
ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id, IndexFunc index_func, ElementFunc element_func)
ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id,
IndexFunc index_func,
ElementFunc element_func)
-> BitsetType {
auto& schema = segment_.get_schema();
auto& field_meta = schema[field_id];
@ -205,23 +227,31 @@ ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id, IndexFunc index_func
// if sealed segment has loaded raw data on this field, then index_barrier = 0 and data_barrier = 1
// in this case, sealed segment execute expr plan using raw data
for (auto chunk_id = 0; chunk_id < data_barrier; ++chunk_id) {
auto this_size = chunk_id == num_chunk - 1 ? row_count_ - chunk_id * size_per_chunk : size_per_chunk;
auto this_size = chunk_id == num_chunk - 1
? row_count_ - chunk_id * size_per_chunk
: size_per_chunk;
BitsetType result(this_size);
auto chunk = segment_.chunk_data<T>(field_id, chunk_id);
const T* data = chunk.data();
for (int index = 0; index < this_size; ++index) {
result[index] = element_func(data[index]);
}
AssertInfo(result.size() == this_size, "[ExecExprVisitor]Chunk result size not equal to expected size");
AssertInfo(
result.size() == this_size,
"[ExecExprVisitor]Chunk result size not equal to expected size");
results.emplace_back(std::move(result));
}
// if sealed segment has loaded scalar index for this field, then index_barrier = 1 and data_barrier = 0
// in this case, sealed segment execute expr plan using scalar index
typedef std::conditional_t<std::is_same_v<T, std::string_view>, std::string, T> IndexInnerType;
typedef std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
IndexInnerType;
using Index = index::ScalarIndex<IndexInnerType>;
for (auto chunk_id = data_barrier; chunk_id < indexing_barrier; ++chunk_id) {
auto& indexing = segment_.chunk_scalar_index<IndexInnerType>(field_id, chunk_id);
for (auto chunk_id = data_barrier; chunk_id < indexing_barrier;
++chunk_id) {
auto& indexing =
segment_.chunk_scalar_index<IndexInnerType>(field_id, chunk_id);
auto this_size = const_cast<Index*>(&indexing)->Count();
BitsetType result(this_size);
for (int offset = 0; offset < this_size; ++offset) {
@ -231,7 +261,8 @@ ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id, IndexFunc index_func
}
auto final_result = Assemble(results);
AssertInfo(final_result.size() == row_count_, "[ExecExprVisitor]Final result size not equal to row count");
AssertInfo(final_result.size() == row_count_,
"[ExecExprVisitor]Final result size not equal to row count");
return final_result;
}
@ -239,8 +270,11 @@ ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id, IndexFunc index_func
#pragma ide diagnostic ignored "Simplify"
template <typename T>
auto
ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> BitsetType {
typedef std::conditional_t<std::is_same_v<T, std::string_view>, std::string, T> IndexInnerType;
ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw)
-> BitsetType {
typedef std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
IndexInnerType;
using Index = index::ScalarIndex<IndexInnerType>;
auto& expr = static_cast<UnaryRangeExprImpl<IndexInnerType>&>(expr_raw);
@ -248,34 +282,52 @@ ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> Bi
auto val = IndexInnerType(expr.value_);
switch (op) {
case OpType::Equal: {
auto index_func = [val](Index* index) { return index->In(1, &val); };
auto index_func = [val](Index* index) {
return index->In(1, &val);
};
auto elem_func = [val](T x) { return (x == val); };
return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
return ExecRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
case OpType::NotEqual: {
auto index_func = [val](Index* index) { return index->NotIn(1, &val); };
auto index_func = [val](Index* index) {
return index->NotIn(1, &val);
};
auto elem_func = [val](T x) { return (x != val); };
return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
return ExecRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
case OpType::GreaterEqual: {
auto index_func = [val](Index* index) { return index->Range(val, OpType::GreaterEqual); };
auto index_func = [val](Index* index) {
return index->Range(val, OpType::GreaterEqual);
};
auto elem_func = [val](T x) { return (x >= val); };
return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
return ExecRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
case OpType::GreaterThan: {
auto index_func = [val](Index* index) { return index->Range(val, OpType::GreaterThan); };
auto index_func = [val](Index* index) {
return index->Range(val, OpType::GreaterThan);
};
auto elem_func = [val](T x) { return (x > val); };
return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
return ExecRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
case OpType::LessEqual: {
auto index_func = [val](Index* index) { return index->Range(val, OpType::LessEqual); };
auto index_func = [val](Index* index) {
return index->Range(val, OpType::LessEqual);
};
auto elem_func = [val](T x) { return (x <= val); };
return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
return ExecRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
case OpType::LessThan: {
auto index_func = [val](Index* index) { return index->Range(val, OpType::LessThan); };
auto index_func = [val](Index* index) {
return index->Range(val, OpType::LessThan);
};
auto elem_func = [val](T x) { return (x < val); };
return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
return ExecRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
case OpType::PrefixMatch: {
auto index_func = [val](Index* index) {
@ -285,7 +337,8 @@ ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> Bi
return index->Query(std::move(dataset));
};
auto elem_func = [val, op](T x) { return Match(x, val, op); };
return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
return ExecRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
// TODO: PostfixMatch
default: {
@ -299,7 +352,8 @@ ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> Bi
#pragma ide diagnostic ignored "Simplify"
template <typename T>
auto
ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType {
ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher(
BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType {
auto& expr = static_cast<BinaryArithOpEvalRangeExprImpl<T>&>(expr_raw);
using Index = index::ScalarIndex<T>;
auto arith_op = expr.arith_op_;
@ -311,46 +365,64 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRa
case OpType::Equal: {
switch (arith_op) {
case ArithOpType::Add: {
auto index_func = [val, right_operand](Index* index, size_t offset) {
auto index_func = [val, right_operand](Index* index,
size_t offset) {
auto x = index->Reverse_Lookup(offset);
return (x + right_operand) == val;
};
auto elem_func = [val, right_operand](T x) { return ((x + right_operand) == val); };
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
auto elem_func = [val, right_operand](T x) {
return ((x + right_operand) == val);
};
return ExecDataRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
case ArithOpType::Sub: {
auto index_func = [val, right_operand](Index* index, size_t offset) {
auto index_func = [val, right_operand](Index* index,
size_t offset) {
auto x = index->Reverse_Lookup(offset);
return (x - right_operand) == val;
};
auto elem_func = [val, right_operand](T x) { return ((x - right_operand) == val); };
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
auto elem_func = [val, right_operand](T x) {
return ((x - right_operand) == val);
};
return ExecDataRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
case ArithOpType::Mul: {
auto index_func = [val, right_operand](Index* index, size_t offset) {
auto index_func = [val, right_operand](Index* index,
size_t offset) {
auto x = index->Reverse_Lookup(offset);
return (x * right_operand) == val;
};
auto elem_func = [val, right_operand](T x) { return ((x * right_operand) == val); };
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
auto elem_func = [val, right_operand](T x) {
return ((x * right_operand) == val);
};
return ExecDataRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
case ArithOpType::Div: {
auto index_func = [val, right_operand](Index* index, size_t offset) {
auto index_func = [val, right_operand](Index* index,
size_t offset) {
auto x = index->Reverse_Lookup(offset);
return (x / right_operand) == val;
};
auto elem_func = [val, right_operand](T x) { return ((x / right_operand) == val); };
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
auto elem_func = [val, right_operand](T x) {
return ((x / right_operand) == val);
};
return ExecDataRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
case ArithOpType::Mod: {
auto index_func = [val, right_operand](Index* index, size_t offset) {
auto index_func = [val, right_operand](Index* index,
size_t offset) {
auto x = index->Reverse_Lookup(offset);
return static_cast<T>(fmod(x, right_operand)) == val;
};
auto elem_func = [val, right_operand](T x) {
return (static_cast<T>(fmod(x, right_operand)) == val);
};
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
return ExecDataRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
default: {
PanicInfo("unsupported arithmetic operation");
@ -360,46 +432,64 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRa
case OpType::NotEqual: {
switch (arith_op) {
case ArithOpType::Add: {
auto index_func = [val, right_operand](Index* index, size_t offset) {
auto index_func = [val, right_operand](Index* index,
size_t offset) {
auto x = index->Reverse_Lookup(offset);
return (x + right_operand) != val;
};
auto elem_func = [val, right_operand](T x) { return ((x + right_operand) != val); };
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
auto elem_func = [val, right_operand](T x) {
return ((x + right_operand) != val);
};
return ExecDataRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
case ArithOpType::Sub: {
auto index_func = [val, right_operand](Index* index, size_t offset) {
auto index_func = [val, right_operand](Index* index,
size_t offset) {
auto x = index->Reverse_Lookup(offset);
return (x - right_operand) != val;
};
auto elem_func = [val, right_operand](T x) { return ((x - right_operand) != val); };
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
auto elem_func = [val, right_operand](T x) {
return ((x - right_operand) != val);
};
return ExecDataRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
case ArithOpType::Mul: {
auto index_func = [val, right_operand](Index* index, size_t offset) {
auto index_func = [val, right_operand](Index* index,
size_t offset) {
auto x = index->Reverse_Lookup(offset);
return (x * right_operand) != val;
};
auto elem_func = [val, right_operand](T x) { return ((x * right_operand) != val); };
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
auto elem_func = [val, right_operand](T x) {
return ((x * right_operand) != val);
};
return ExecDataRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
case ArithOpType::Div: {
auto index_func = [val, right_operand](Index* index, size_t offset) {
auto index_func = [val, right_operand](Index* index,
size_t offset) {
auto x = index->Reverse_Lookup(offset);
return (x / right_operand) != val;
};
auto elem_func = [val, right_operand](T x) { return ((x / right_operand) != val); };
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
auto elem_func = [val, right_operand](T x) {
return ((x / right_operand) != val);
};
return ExecDataRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
case ArithOpType::Mod: {
auto index_func = [val, right_operand](Index* index, size_t offset) {
auto index_func = [val, right_operand](Index* index,
size_t offset) {
auto x = index->Reverse_Lookup(offset);
return static_cast<T>(fmod(x, right_operand)) != val;
};
auto elem_func = [val, right_operand](T x) {
return (static_cast<T>(fmod(x, right_operand)) != val);
};
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
return ExecDataRangeVisitorImpl<T>(
expr.field_id_, index_func, elem_func);
}
default: {
PanicInfo("unsupported arithmetic operation");
@ -417,8 +507,11 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRa
#pragma ide diagnostic ignored "Simplify"
template <typename T>
auto
ExecExprVisitor::ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw) -> BitsetType {
typedef std::conditional_t<std::is_same_v<T, std::string_view>, std::string, T> IndexInnerType;
ExecExprVisitor::ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw)
-> BitsetType {
typedef std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
IndexInnerType;
using Index = index::ScalarIndex<IndexInnerType>;
auto& expr = static_cast<BinaryRangeExprImpl<IndexInnerType>&>(expr_raw);
@ -427,7 +520,9 @@ ExecExprVisitor::ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw) ->
IndexInnerType val1 = IndexInnerType(expr.lower_value_);
IndexInnerType val2 = IndexInnerType(expr.upper_value_);
auto index_func = [=](Index* index) { return index->Range(val1, lower_inclusive, val2, upper_inclusive); };
auto index_func = [=](Index* index) {
return index->Range(val1, lower_inclusive, val2, upper_inclusive);
};
if (lower_inclusive && upper_inclusive) {
auto elem_func = [val1, val2](T x) { return (val1 <= x && x <= val2); };
return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
@ -490,7 +585,8 @@ ExecExprVisitor::visit(UnaryRangeExpr& expr) {
default:
PanicInfo("unsupported");
}
AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
AssertInfo(res.size() == row_count_,
"[ExecExprVisitor]Size of results not equal row count");
bitset_opt_ = std::move(res);
}
@ -528,7 +624,8 @@ ExecExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) {
default:
PanicInfo("unsupported");
}
AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
AssertInfo(res.size() == row_count_,
"[ExecExprVisitor]Size of results not equal row count");
bitset_opt_ = std::move(res);
}
@ -578,7 +675,8 @@ ExecExprVisitor::visit(BinaryRangeExpr& expr) {
default:
PanicInfo("unsupported");
}
AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
AssertInfo(res.size() == row_count_,
"[ExecExprVisitor]Size of results not equal row count");
bitset_opt_ = std::move(res);
}
@ -598,8 +696,16 @@ struct relational {
template <typename Op>
auto
ExecExprVisitor::ExecCompareExprDispatcher(CompareExpr& expr, Op op) -> BitsetType {
using number = boost::variant<bool, int8_t, int16_t, int32_t, int64_t, float, double, std::string>;
ExecExprVisitor::ExecCompareExprDispatcher(CompareExpr& expr, Op op)
-> BitsetType {
using number = boost::variant<bool,
int8_t,
int16_t,
int32_t,
int64_t,
float,
double,
std::string>;
auto size_per_chunk = segment_.size_per_chunk();
auto num_chunk = upper_div(row_count_, size_per_chunk);
std::deque<BitsetType> bitsets;
@ -607,120 +713,194 @@ ExecExprVisitor::ExecCompareExprDispatcher(CompareExpr& expr, Op op) -> BitsetTy
// check for sealed segment, load either raw field data or index
auto left_indexing_barrier = segment_.num_chunk_index(expr.left_field_id_);
auto left_data_barrier = segment_.num_chunk_data(expr.left_field_id_);
AssertInfo(std::max(left_data_barrier, left_indexing_barrier) == num_chunk,
"max(left_data_barrier, left_indexing_barrier) not equal to num_chunk");
AssertInfo(
std::max(left_data_barrier, left_indexing_barrier) == num_chunk,
"max(left_data_barrier, left_indexing_barrier) not equal to num_chunk");
auto right_indexing_barrier = segment_.num_chunk_index(expr.right_field_id_);
auto right_indexing_barrier =
segment_.num_chunk_index(expr.right_field_id_);
auto right_data_barrier = segment_.num_chunk_data(expr.right_field_id_);
AssertInfo(std::max(right_data_barrier, right_indexing_barrier) == num_chunk,
"max(right_data_barrier, right_indexing_barrier) not equal to num_chunk");
AssertInfo(
std::max(right_data_barrier, right_indexing_barrier) == num_chunk,
"max(right_data_barrier, right_indexing_barrier) not equal to "
"num_chunk");
for (int64_t chunk_id = 0; chunk_id < num_chunk; ++chunk_id) {
auto size = chunk_id == num_chunk - 1 ? row_count_ - chunk_id * size_per_chunk : size_per_chunk;
auto getChunkData = [&, chunk_id](DataType type, FieldId field_id,
int64_t data_barrier) -> std::function<const number(int)> {
auto size = chunk_id == num_chunk - 1
? row_count_ - chunk_id * size_per_chunk
: size_per_chunk;
auto getChunkData =
[&, chunk_id](DataType type, FieldId field_id, int64_t data_barrier)
-> std::function<const number(int)> {
switch (type) {
case DataType::BOOL: {
if (chunk_id < data_barrier) {
auto chunk_data = segment_.chunk_data<bool>(field_id, chunk_id).data();
return [chunk_data](int i) -> const number { return chunk_data[i]; };
auto chunk_data =
segment_.chunk_data<bool>(field_id, chunk_id)
.data();
return [chunk_data](int i) -> const number {
return chunk_data[i];
};
} else {
// for case, sealed segment has loaded index for scalar field instead of raw data
auto& indexing = segment_.chunk_scalar_index<bool>(field_id, chunk_id);
return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
auto& indexing = segment_.chunk_scalar_index<bool>(
field_id, chunk_id);
return [&indexing](int i) -> const number {
return indexing.Reverse_Lookup(i);
};
}
}
case DataType::INT8: {
if (chunk_id < data_barrier) {
auto chunk_data = segment_.chunk_data<int8_t>(field_id, chunk_id).data();
return [chunk_data](int i) -> const number { return chunk_data[i]; };
auto chunk_data =
segment_.chunk_data<int8_t>(field_id, chunk_id)
.data();
return [chunk_data](int i) -> const number {
return chunk_data[i];
};
} else {
// for case, sealed segment has loaded index for scalar field instead of raw data
auto& indexing = segment_.chunk_scalar_index<int8_t>(field_id, chunk_id);
return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
auto& indexing = segment_.chunk_scalar_index<int8_t>(
field_id, chunk_id);
return [&indexing](int i) -> const number {
return indexing.Reverse_Lookup(i);
};
}
}
case DataType::INT16: {
if (chunk_id < data_barrier) {
auto chunk_data = segment_.chunk_data<int16_t>(field_id, chunk_id).data();
return [chunk_data](int i) -> const number { return chunk_data[i]; };
auto chunk_data =
segment_.chunk_data<int16_t>(field_id, chunk_id)
.data();
return [chunk_data](int i) -> const number {
return chunk_data[i];
};
} else {
// for case, sealed segment has loaded index for scalar field instead of raw data
auto& indexing = segment_.chunk_scalar_index<int16_t>(field_id, chunk_id);
return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
auto& indexing = segment_.chunk_scalar_index<int16_t>(
field_id, chunk_id);
return [&indexing](int i) -> const number {
return indexing.Reverse_Lookup(i);
};
}
}
case DataType::INT32: {
if (chunk_id < data_barrier) {
auto chunk_data = segment_.chunk_data<int32_t>(field_id, chunk_id).data();
return [chunk_data](int i) -> const number { return chunk_data[i]; };
auto chunk_data =
segment_.chunk_data<int32_t>(field_id, chunk_id)
.data();
return [chunk_data](int i) -> const number {
return chunk_data[i];
};
} else {
// for case, sealed segment has loaded index for scalar field instead of raw data
auto& indexing = segment_.chunk_scalar_index<int32_t>(field_id, chunk_id);
return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
auto& indexing = segment_.chunk_scalar_index<int32_t>(
field_id, chunk_id);
return [&indexing](int i) -> const number {
return indexing.Reverse_Lookup(i);
};
}
}
case DataType::INT64: {
if (chunk_id < data_barrier) {
auto chunk_data = segment_.chunk_data<int64_t>(field_id, chunk_id).data();
return [chunk_data](int i) -> const number { return chunk_data[i]; };
auto chunk_data =
segment_.chunk_data<int64_t>(field_id, chunk_id)
.data();
return [chunk_data](int i) -> const number {
return chunk_data[i];
};
} else {
// for case, sealed segment has loaded index for scalar field instead of raw data
auto& indexing = segment_.chunk_scalar_index<int64_t>(field_id, chunk_id);
return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
auto& indexing = segment_.chunk_scalar_index<int64_t>(
field_id, chunk_id);
return [&indexing](int i) -> const number {
return indexing.Reverse_Lookup(i);
};
}
}
case DataType::FLOAT: {
if (chunk_id < data_barrier) {
auto chunk_data = segment_.chunk_data<float>(field_id, chunk_id).data();
return [chunk_data](int i) -> const number { return chunk_data[i]; };
auto chunk_data =
segment_.chunk_data<float>(field_id, chunk_id)
.data();
return [chunk_data](int i) -> const number {
return chunk_data[i];
};
} else {
// for case, sealed segment has loaded index for scalar field instead of raw data
auto& indexing = segment_.chunk_scalar_index<float>(field_id, chunk_id);
return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
auto& indexing = segment_.chunk_scalar_index<float>(
field_id, chunk_id);
return [&indexing](int i) -> const number {
return indexing.Reverse_Lookup(i);
};
}
}
case DataType::DOUBLE: {
if (chunk_id < data_barrier) {
auto chunk_data = segment_.chunk_data<double>(field_id, chunk_id).data();
return [chunk_data](int i) -> const number { return chunk_data[i]; };
auto chunk_data =
segment_.chunk_data<double>(field_id, chunk_id)
.data();
return [chunk_data](int i) -> const number {
return chunk_data[i];
};
} else {
// for case, sealed segment has loaded index for scalar field instead of raw data
auto& indexing = segment_.chunk_scalar_index<double>(field_id, chunk_id);
return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
auto& indexing = segment_.chunk_scalar_index<double>(
field_id, chunk_id);
return [&indexing](int i) -> const number {
return indexing.Reverse_Lookup(i);
};
}
}
case DataType::VARCHAR: {
if (chunk_id < data_barrier) {
if (segment_.type() == SegmentType::Growing) {
auto chunk_data = segment_.chunk_data<std::string>(field_id, chunk_id).data();
return [chunk_data](int i) -> const number { return chunk_data[i]; };
auto chunk_data =
segment_
.chunk_data<std::string>(field_id, chunk_id)
.data();
return [chunk_data](int i) -> const number {
return chunk_data[i];
};
} else {
auto chunk_data = segment_.chunk_data<std::string_view>(field_id, chunk_id).data();
return [chunk_data](int i) -> const number { return std::string(chunk_data[i]); };
auto chunk_data = segment_
.chunk_data<std::string_view>(
field_id, chunk_id)
.data();
return [chunk_data](int i) -> const number {
return std::string(chunk_data[i]);
};
}
} else {
// for case, sealed segment has loaded index for scalar field instead of raw data
auto& indexing = segment_.chunk_scalar_index<std::string>(field_id, chunk_id);
return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
auto& indexing =
segment_.chunk_scalar_index<std::string>(field_id,
chunk_id);
return [&indexing](int i) -> const number {
return indexing.Reverse_Lookup(i);
};
}
}
default:
PanicInfo("unsupported datatype");
}
};
auto left = getChunkData(expr.left_data_type_, expr.left_field_id_, left_data_barrier);
auto right = getChunkData(expr.right_data_type_, expr.right_field_id_, right_data_barrier);
auto left = getChunkData(
expr.left_data_type_, expr.left_field_id_, left_data_barrier);
auto right = getChunkData(
expr.right_data_type_, expr.right_field_id_, right_data_barrier);
BitsetType bitset(size);
for (int i = 0; i < size; ++i) {
bool is_in = boost::apply_visitor(Relational<decltype(op)>{}, left(i), right(i));
bool is_in = boost::apply_visitor(
Relational<decltype(op)>{}, left(i), right(i));
bitset[i] = is_in;
}
bitsets.emplace_back(std::move(bitset));
}
auto final_result = Assemble(bitsets);
AssertInfo(final_result.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
AssertInfo(final_result.size() == row_count_,
"[ExecExprVisitor]Size of results not equal row count");
return final_result;
}
@ -729,10 +909,12 @@ ExecExprVisitor::visit(CompareExpr& expr) {
auto& schema = segment_.get_schema();
auto& left_field_meta = schema[expr.left_field_id_];
auto& right_field_meta = schema[expr.right_field_id_];
AssertInfo(expr.left_data_type_ == left_field_meta.get_data_type(),
"[ExecExprVisitor]Left data type not equal to left field mata type");
AssertInfo(expr.right_data_type_ == right_field_meta.get_data_type(),
"[ExecExprVisitor]right data type not equal to right field mata type");
AssertInfo(
expr.left_data_type_ == left_field_meta.get_data_type(),
"[ExecExprVisitor]Left data type not equal to left field meta type");
AssertInfo(
expr.right_data_type_ == right_field_meta.get_data_type(),
"[ExecExprVisitor]right data type not equal to right field meta type");
BitsetType res;
switch (expr.op_type_) {
@ -761,7 +943,8 @@ ExecExprVisitor::visit(CompareExpr& expr) {
break;
}
case OpType::PrefixMatch: {
res = ExecCompareExprDispatcher(expr, MatchOp<OpType::PrefixMatch>{});
res =
ExecCompareExprDispatcher(expr, MatchOp<OpType::PrefixMatch>{});
break;
}
// case OpType::PostfixMatch: {
@ -770,7 +953,8 @@ ExecExprVisitor::visit(CompareExpr& expr) {
PanicInfo("unsupported optype");
}
}
AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
AssertInfo(res.size() == row_count_,
"[ExecExprVisitor]Size of results not equal row count");
bitset_opt_ = std::move(res);
}
@ -785,7 +969,8 @@ ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType {
bool use_pk_index = false;
if (primary_filed_id.has_value()) {
use_pk_index = primary_filed_id.value() == field_id && IsPrimaryKeyDataType(field_meta.get_data_type());
use_pk_index = primary_filed_id.value() == field_id &&
IsPrimaryKeyDataType(field_meta.get_data_type());
}
if (use_pk_index) {
@ -816,7 +1001,8 @@ ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType {
auto _offset = (int64_t)offset.get();
bitset[_offset] = true;
}
AssertInfo(bitset.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
AssertInfo(bitset.size() == row_count_,
"[ExecExprVisitor]Size of results not equal row count");
return bitset;
}
@ -825,27 +1011,34 @@ ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType {
template <>
auto
ExecExprVisitor::ExecTermVisitorImpl<std::string>(TermExpr& expr_raw) -> BitsetType {
ExecExprVisitor::ExecTermVisitorImpl<std::string>(TermExpr& expr_raw)
-> BitsetType {
return ExecTermVisitorImplTemplate<std::string>(expr_raw);
}
template <>
auto
ExecExprVisitor::ExecTermVisitorImpl<std::string_view>(TermExpr& expr_raw) -> BitsetType {
ExecExprVisitor::ExecTermVisitorImpl<std::string_view>(TermExpr& expr_raw)
-> BitsetType {
return ExecTermVisitorImplTemplate<std::string_view>(expr_raw);
}
template <typename T>
auto
ExecExprVisitor::ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType {
typedef std::conditional_t<std::is_same_v<T, std::string_view>, std::string, T> IndexInnerType;
typedef std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
IndexInnerType;
using Index = index::ScalarIndex<IndexInnerType>;
auto& expr = static_cast<TermExprImpl<IndexInnerType>&>(expr_raw);
const std::vector<IndexInnerType> terms(expr.terms_.begin(), expr.terms_.end());
const std::vector<IndexInnerType> terms(expr.terms_.begin(),
expr.terms_.end());
auto n = terms.size();
std::unordered_set<T> term_set(expr.terms_.begin(), expr.terms_.end());
auto index_func = [&terms, n](Index* index) { return index->In(n, terms.data()); };
auto index_func = [&terms, n](Index* index) {
return index->In(n, terms.data());
};
auto elem_func = [&terms, &term_set](T x) {
//// terms has already been sorted.
// return std::binary_search(terms.begin(), terms.end(), x);
@ -858,7 +1051,8 @@ ExecExprVisitor::ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType {
// TODO: bool is so ugly here.
template <>
auto
ExecExprVisitor::ExecTermVisitorImplTemplate<bool>(TermExpr& expr_raw) -> BitsetType {
ExecExprVisitor::ExecTermVisitorImplTemplate<bool>(TermExpr& expr_raw)
-> BitsetType {
using T = bool;
auto& expr = static_cast<TermExprImpl<T>&>(expr_raw);
using Index = index::ScalarIndex<T>;
@ -932,7 +1126,8 @@ ExecExprVisitor::visit(TermExpr& expr) {
default:
PanicInfo("unsupported");
}
AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
AssertInfo(res.size() == row_count_,
"[ExecExprVisitor]Size of results not equal row count");
bitset_opt_ = std::move(res);
}
} // namespace milvus::query

View File

@ -29,7 +29,9 @@ class ExecPlanNodeVisitor : PlanNodeVisitor {
ExecPlanNodeVisitor(const segcore::SegmentInterface& segment,
Timestamp timestamp,
const PlaceholderGroup& placeholder_group)
: segment_(segment), timestamp_(timestamp), placeholder_group_(placeholder_group) {
: segment_(segment),
timestamp_(timestamp),
placeholder_group_(placeholder_group) {
}
SearchResult
@ -59,7 +61,10 @@ class ExecPlanNodeVisitor : PlanNodeVisitor {
static SearchResult
empty_search_result(int64_t num_queries, SearchInfo& search_info) {
SearchResult final_result;
SubSearchResult result(num_queries, search_info.topk_, search_info.metric_type_, search_info.round_decimal_);
SubSearchResult result(num_queries,
search_info.topk_,
search_info.metric_type_,
search_info.round_decimal_);
final_result.total_nq_ = num_queries;
final_result.unity_topK_ = search_info.topk_;
final_result.seg_offsets_ = std::move(result.mutable_seg_offsets());
@ -72,7 +77,8 @@ void
ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
// TODO: optimize here, remove the dynamic cast
assert(!search_result_opt_.has_value());
auto segment = dynamic_cast<const segcore::SegmentInternalInterface*>(&segment_);
auto segment =
dynamic_cast<const segcore::SegmentInternalInterface*>(&segment_);
AssertInfo(segment, "support SegmentSmallIndex Only");
SearchResult search_result;
auto& ph = placeholder_group_->at(0);
@ -85,14 +91,16 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
// skip all calculation
if (active_count == 0) {
search_result_opt_ = empty_search_result(num_queries, node.search_info_);
search_result_opt_ =
empty_search_result(num_queries, node.search_info_);
return;
}
std::unique_ptr<BitsetType> bitset_holder;
if (node.predicate_.has_value()) {
bitset_holder = std::make_unique<BitsetType>(
ExecExprVisitor(*segment, active_count, timestamp_).call_child(*node.predicate_.value()));
ExecExprVisitor(*segment, active_count, timestamp_)
.call_child(*node.predicate_.value()));
bitset_holder->flip();
} else {
bitset_holder = std::make_unique<BitsetType>(active_count, false);
@ -102,11 +110,17 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
segment->mask_with_delete(*bitset_holder, active_count, timestamp_);
// if bitset_holder is all 1's, we got empty result
if (bitset_holder->all()) {
search_result_opt_ = empty_search_result(num_queries, node.search_info_);
search_result_opt_ =
empty_search_result(num_queries, node.search_info_);
return;
}
BitsetView final_view = *bitset_holder;
segment->vector_search(node.search_info_, src_data, num_queries, timestamp_, final_view, search_result);
segment->vector_search(node.search_info_,
src_data,
num_queries,
timestamp_,
final_view,
search_result);
search_result_opt_ = std::move(search_result);
}
@ -114,7 +128,8 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
void
ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
assert(!retrieve_result_opt_.has_value());
auto segment = dynamic_cast<const segcore::SegmentInternalInterface*>(&segment_);
auto segment =
dynamic_cast<const segcore::SegmentInternalInterface*>(&segment_);
AssertInfo(segment, "Support SegmentSmallIndex Only");
RetrieveResult retrieve_result;
@ -127,7 +142,8 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
BitsetType bitset_holder;
if (node.predicate_ != nullptr) {
bitset_holder = ExecExprVisitor(*segment, active_count, timestamp_).call_child(*(node.predicate_));
bitset_holder = ExecExprVisitor(*segment, active_count, timestamp_)
.call_child(*(node.predicate_));
bitset_holder.flip();
}
@ -142,8 +158,9 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
BitsetView final_view = bitset_holder;
auto seg_offsets = segment->search_ids(final_view, timestamp_);
retrieve_result.result_offsets_.assign((int64_t*)seg_offsets.data(),
(int64_t*)seg_offsets.data() + seg_offsets.size());
retrieve_result.result_offsets_.assign(
(int64_t*)seg_offsets.data(),
(int64_t*)seg_offsets.data() + seg_offsets.size());
retrieve_result_opt_ = std::move(retrieve_result);
}

View File

@ -19,7 +19,8 @@ namespace impl {
// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/
class ExtractInfoExprVisitor : ExprVisitor {
public:
explicit ExtractInfoExprVisitor(ExtractedPlanInfo& plan_info) : plan_info_(plan_info) {
explicit ExtractInfoExprVisitor(ExtractedPlanInfo& plan_info)
: plan_info_(plan_info) {
}
private:

View File

@ -20,7 +20,8 @@ namespace impl {
// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/
class ExtractInfoPlanNodeVisitor : PlanNodeVisitor {
public:
explicit ExtractInfoPlanNodeVisitor(ExtractedPlanInfo& plan_info) : plan_info_(plan_info) {
explicit ExtractInfoPlanNodeVisitor(ExtractedPlanInfo& plan_info)
: plan_info_(plan_info) {
}
private:

View File

@ -58,11 +58,13 @@ class ShowExprNodeVisitor : ExprVisitor {
void
ShowExprVisitor::visit(LogicalUnaryExpr& expr) {
AssertInfo(!json_opt_.has_value(), "[ShowExprVisitor]Ret json already has value before visit");
AssertInfo(!json_opt_.has_value(),
"[ShowExprVisitor]Ret json already has value before visit");
using OpType = LogicalUnaryExpr::OpType;
// TODO: use magic_enum if available
AssertInfo(expr.op_type_ == OpType::LogicalNot, "[ShowExprVisitor]Expr op type isn't LogicNot");
AssertInfo(expr.op_type_ == OpType::LogicalNot,
"[ShowExprVisitor]Expr op type isn't LogicNot");
auto op_name = "LogicalNot";
Json extra{
@ -74,7 +76,8 @@ ShowExprVisitor::visit(LogicalUnaryExpr& expr) {
void
ShowExprVisitor::visit(LogicalBinaryExpr& expr) {
AssertInfo(!json_opt_.has_value(), "[ShowExprVisitor]Ret json already has value before visit");
AssertInfo(!json_opt_.has_value(),
"[ShowExprVisitor]Ret json already has value before visit");
using OpType = LogicalBinaryExpr::OpType;
// TODO: use magic_enum if available
@ -108,8 +111,10 @@ TermExtract(const TermExpr& expr_raw) {
void
ShowExprVisitor::visit(TermExpr& expr) {
AssertInfo(!json_opt_.has_value(), "[ShowExprVisitor]Ret json already has value before visit");
AssertInfo(datatype_is_vector(expr.data_type_) == false, "[ShowExprVisitor]Data type of expr isn't vector type");
AssertInfo(!json_opt_.has_value(),
"[ShowExprVisitor]Ret json already has value before visit");
AssertInfo(datatype_is_vector(expr.data_type_) == false,
"[ShowExprVisitor]Data type of expr isn't vector type");
auto terms = [&] {
switch (expr.data_type_) {
case DataType::BOOL:
@ -145,7 +150,9 @@ UnaryRangeExtract(const UnaryRangeExpr& expr_raw) {
using proto::plan::OpType;
using proto::plan::OpType_Name;
auto expr = dynamic_cast<const UnaryRangeExprImpl<T>*>(&expr_raw);
AssertInfo(expr, "[ShowExprVisitor]UnaryRangeExpr cast to UnaryRangeExprImpl failed");
AssertInfo(
expr,
"[ShowExprVisitor]UnaryRangeExpr cast to UnaryRangeExprImpl failed");
Json res{{"expr_type", "UnaryRange"},
{"field_id", expr->field_id_.get()},
{"data_type", datatype_name(expr->data_type_)},
@ -156,8 +163,10 @@ UnaryRangeExtract(const UnaryRangeExpr& expr_raw) {
void
ShowExprVisitor::visit(UnaryRangeExpr& expr) {
AssertInfo(!json_opt_.has_value(), "[ShowExprVisitor]Ret json already has value before visit");
AssertInfo(datatype_is_vector(expr.data_type_) == false, "[ShowExprVisitor]Data type of expr isn't vector type");
AssertInfo(!json_opt_.has_value(),
"[ShowExprVisitor]Ret json already has value before visit");
AssertInfo(datatype_is_vector(expr.data_type_) == false,
"[ShowExprVisitor]Data type of expr isn't vector type");
switch (expr.data_type_) {
case DataType::BOOL:
json_opt_ = UnaryRangeExtract<bool>(expr);
@ -191,7 +200,9 @@ BinaryRangeExtract(const BinaryRangeExpr& expr_raw) {
using proto::plan::OpType;
using proto::plan::OpType_Name;
auto expr = dynamic_cast<const BinaryRangeExprImpl<T>*>(&expr_raw);
AssertInfo(expr, "[ShowExprVisitor]BinaryRangeExpr cast to BinaryRangeExprImpl failed");
AssertInfo(
expr,
"[ShowExprVisitor]BinaryRangeExpr cast to BinaryRangeExprImpl failed");
Json res{{"expr_type", "BinaryRange"},
{"field_id", expr->field_id_.get()},
{"data_type", datatype_name(expr->data_type_)},
@ -204,8 +215,10 @@ BinaryRangeExtract(const BinaryRangeExpr& expr_raw) {
void
ShowExprVisitor::visit(BinaryRangeExpr& expr) {
AssertInfo(!json_opt_.has_value(), "[ShowExprVisitor]Ret json already has value before visit");
AssertInfo(datatype_is_vector(expr.data_type_) == false, "[ShowExprVisitor]Data type of expr isn't vector type");
AssertInfo(!json_opt_.has_value(),
"[ShowExprVisitor]Ret json already has value before visit");
AssertInfo(datatype_is_vector(expr.data_type_) == false,
"[ShowExprVisitor]Data type of expr isn't vector type");
switch (expr.data_type_) {
case DataType::BOOL:
json_opt_ = BinaryRangeExtract<bool>(expr);
@ -237,7 +250,8 @@ void
ShowExprVisitor::visit(CompareExpr& expr) {
using proto::plan::OpType;
using proto::plan::OpType_Name;
AssertInfo(!json_opt_.has_value(), "[ShowExprVisitor]Ret json already has value before visit");
AssertInfo(!json_opt_.has_value(),
"[ShowExprVisitor]Ret json already has value before visit");
Json res{{"expr_type", "Compare"},
{"left_field_id", expr.left_field_id_.get()},
@ -256,13 +270,17 @@ BinaryArithOpEvalRangeExtract(const BinaryArithOpEvalRangeExpr& expr_raw) {
using proto::plan::OpType;
using proto::plan::OpType_Name;
auto expr = dynamic_cast<const BinaryArithOpEvalRangeExprImpl<T>*>(&expr_raw);
AssertInfo(expr, "[ShowExprVisitor]BinaryArithOpEvalRangeExpr cast to BinaryArithOpEvalRangeExprImpl failed");
auto expr =
dynamic_cast<const BinaryArithOpEvalRangeExprImpl<T>*>(&expr_raw);
AssertInfo(expr,
"[ShowExprVisitor]BinaryArithOpEvalRangeExpr cast to "
"BinaryArithOpEvalRangeExprImpl failed");
Json res{{"expr_type", "BinaryArithOpEvalRange"},
{"field_offset", expr->field_id_.get()},
{"data_type", datatype_name(expr->data_type_)},
{"arith_op", ArithOpType_Name(static_cast<ArithOpType>(expr->arith_op_))},
{"arith_op",
ArithOpType_Name(static_cast<ArithOpType>(expr->arith_op_))},
{"right_operand", expr->right_operand_},
{"op", OpType_Name(static_cast<OpType>(expr->op_type_))},
{"value", expr->value_}};
@ -271,8 +289,10 @@ BinaryArithOpEvalRangeExtract(const BinaryArithOpEvalRangeExpr& expr_raw) {
void
ShowExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) {
AssertInfo(!json_opt_.has_value(), "[ShowExprVisitor]Ret json already has value before visit");
AssertInfo(datatype_is_vector(expr.data_type_) == false, "[ShowExprVisitor]Data type of expr isn't vector type");
AssertInfo(!json_opt_.has_value(),
"[ShowExprVisitor]Ret json already has value before visit");
AssertInfo(datatype_is_vector(expr.data_type_) == false,
"[ShowExprVisitor]Data type of expr isn't vector type");
switch (expr.data_type_) {
case DataType::INT8:
json_opt_ = BinaryArithOpEvalRangeExtract<int8_t>(expr);

View File

@ -62,8 +62,10 @@ ShowPlanNodeVisitor::visit(FloatVectorANNS& node) {
};
if (node.predicate_.has_value()) {
ShowExprVisitor expr_show;
AssertInfo(node.predicate_.value(), "[ShowPlanNodeVisitor]Can't get value from node predict");
json_body["predicate"] = expr_show.call_child(node.predicate_->operator*());
AssertInfo(node.predicate_.value(),
"[ShowPlanNodeVisitor]Can't get value from node predict");
json_body["predicate"] =
expr_show.call_child(node.predicate_->operator*());
} else {
json_body["predicate"] = "None";
}
@ -84,8 +86,10 @@ ShowPlanNodeVisitor::visit(BinaryVectorANNS& node) {
};
if (node.predicate_.has_value()) {
ShowExprVisitor expr_show;
AssertInfo(node.predicate_.value(), "[ShowPlanNodeVisitor]Can't get value from node predict");
json_body["predicate"] = expr_show.call_child(node.predicate_->operator*());
AssertInfo(node.predicate_.value(),
"[ShowPlanNodeVisitor]Can't get value from node predict");
json_body["predicate"] =
expr_show.call_child(node.predicate_->operator*());
} else {
json_body["predicate"] = "None";
}

View File

@ -16,7 +16,8 @@
namespace milvus::segcore {
Collection::Collection(const std::string& collection_proto) : schema_proto_(collection_proto) {
Collection::Collection(const std::string& collection_proto)
: schema_proto_(collection_proto) {
parse();
}
@ -35,7 +36,8 @@ Collection::parse() {
Assert(!schema_proto_.empty());
milvus::proto::schema::CollectionSchema collection_schema;
auto suc = google::protobuf::TextFormat::ParseFromString(schema_proto_, &collection_schema);
auto suc = google::protobuf::TextFormat::ParseFromString(
schema_proto_, &collection_schema);
if (!suc) {
std::cerr << "unmarshal schema string failed" << std::endl;

View File

@ -20,9 +20,13 @@ VectorBase::set_data_raw(ssize_t element_offset,
const FieldMeta& field_meta) {
if (field_meta.is_vector()) {
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
return set_data_raw(element_offset, data->vectors().float_vector().data().data(), element_count);
return set_data_raw(element_offset,
data->vectors().float_vector().data().data(),
element_count);
} else if (field_meta.get_data_type() == DataType::VECTOR_BINARY) {
return set_data_raw(element_offset, data->vectors().binary_vector().data(), element_count);
return set_data_raw(element_offset,
data->vectors().binary_vector().data(),
element_count);
} else {
PanicInfo("unsupported");
}
@ -30,7 +34,9 @@ VectorBase::set_data_raw(ssize_t element_offset,
switch (field_meta.get_data_type()) {
case DataType::BOOL: {
return set_data_raw(element_offset, data->scalars().bool_data().data().data(), element_count);
return set_data_raw(element_offset,
data->scalars().bool_data().data().data(),
element_count);
}
case DataType::INT8: {
auto src_data = data->scalars().int_data().data();
@ -45,16 +51,24 @@ VectorBase::set_data_raw(ssize_t element_offset,
return set_data_raw(element_offset, data_raw.data(), element_count);
}
case DataType::INT32: {
return set_data_raw(element_offset, data->scalars().int_data().data().data(), element_count);
return set_data_raw(element_offset,
data->scalars().int_data().data().data(),
element_count);
}
case DataType::INT64: {
return set_data_raw(element_offset, data->scalars().long_data().data().data(), element_count);
return set_data_raw(element_offset,
data->scalars().long_data().data().data(),
element_count);
}
case DataType::FLOAT: {
return set_data_raw(element_offset, data->scalars().float_data().data().data(), element_count);
return set_data_raw(element_offset,
data->scalars().float_data().data().data(),
element_count);
}
case DataType::DOUBLE: {
return set_data_raw(element_offset, data->scalars().double_data().data().data(), element_count);
return set_data_raw(element_offset,
data->scalars().double_data().data().data(),
element_count);
}
case DataType::VARCHAR: {
auto begin = data->scalars().string_data().data().begin();
@ -69,12 +83,16 @@ VectorBase::set_data_raw(ssize_t element_offset,
}
void
VectorBase::fill_chunk_data(ssize_t element_count, const DataArray* data, const FieldMeta& field_meta) {
VectorBase::fill_chunk_data(ssize_t element_count,
const DataArray* data,
const FieldMeta& field_meta) {
if (field_meta.is_vector()) {
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
return fill_chunk_data(data->vectors().float_vector().data().data(), element_count);
return fill_chunk_data(data->vectors().float_vector().data().data(),
element_count);
} else if (field_meta.get_data_type() == DataType::VECTOR_BINARY) {
return fill_chunk_data(data->vectors().binary_vector().data(), element_count);
return fill_chunk_data(data->vectors().binary_vector().data(),
element_count);
} else {
PanicInfo("unsupported");
}
@ -82,7 +100,8 @@ VectorBase::fill_chunk_data(ssize_t element_count, const DataArray* data, const
switch (field_meta.get_data_type()) {
case DataType::BOOL: {
return fill_chunk_data(data->scalars().bool_data().data().data(), element_count);
return fill_chunk_data(data->scalars().bool_data().data().data(),
element_count);
}
case DataType::INT8: {
auto src_data = data->scalars().int_data().data();
@ -97,16 +116,20 @@ VectorBase::fill_chunk_data(ssize_t element_count, const DataArray* data, const
return fill_chunk_data(data_raw.data(), element_count);
}
case DataType::INT32: {
return fill_chunk_data(data->scalars().int_data().data().data(), element_count);
return fill_chunk_data(data->scalars().int_data().data().data(),
element_count);
}
case DataType::INT64: {
return fill_chunk_data(data->scalars().long_data().data().data(), element_count);
return fill_chunk_data(data->scalars().long_data().data().data(),
element_count);
}
case DataType::FLOAT: {
return fill_chunk_data(data->scalars().float_data().data().data(), element_count);
return fill_chunk_data(data->scalars().float_data().data().data(),
element_count);
}
case DataType::DOUBLE: {
return fill_chunk_data(data->scalars().double_data().data().data(), element_count);
return fill_chunk_data(data->scalars().double_data().data().data(),
element_count);
}
case DataType::VARCHAR: {
auto vec = static_cast<ConcurrentVector<std::string>*>(this);

View File

@ -81,7 +81,8 @@ class ThreadSafeVector {
class VectorBase {
public:
explicit VectorBase(int64_t size_per_chunk) : size_per_chunk_(size_per_chunk) {
explicit VectorBase(int64_t size_per_chunk)
: size_per_chunk_(size_per_chunk) {
}
virtual ~VectorBase() = default;
@ -89,16 +90,23 @@ class VectorBase {
grow_to_at_least(int64_t element_count) = 0;
virtual void
set_data_raw(ssize_t element_offset, const void* source, ssize_t element_count) = 0;
set_data_raw(ssize_t element_offset,
const void* source,
ssize_t element_count) = 0;
void
set_data_raw(ssize_t element_offset, ssize_t element_count, const DataArray* data, const FieldMeta& field_meta);
set_data_raw(ssize_t element_offset,
ssize_t element_count,
const DataArray* data,
const FieldMeta& field_meta);
virtual void
fill_chunk_data(const void* source, ssize_t element_count) = 0;
void
fill_chunk_data(ssize_t element_count, const DataArray* data, const FieldMeta& field_meta);
fill_chunk_data(ssize_t element_count,
const DataArray* data,
const FieldMeta& field_meta);
virtual SpanBase
get_span_base(int64_t chunk_id) const = 0;
@ -135,7 +143,11 @@ class ConcurrentVectorImpl : public VectorBase {
operator=(const ConcurrentVectorImpl&) = delete;
using TraitType =
std::conditional_t<is_scalar, Type, std::conditional_t<std::is_same_v<Type, float>, FloatVector, BinaryVector>>;
std::conditional_t<is_scalar,
Type,
std::conditional_t<std::is_same_v<Type, float>,
FloatVector,
BinaryVector>>;
public:
explicit ConcurrentVectorImpl(ssize_t dim, int64_t size_per_chunk)
@ -160,11 +172,13 @@ class ConcurrentVectorImpl : public VectorBase {
auto& chunk = get_chunk(chunk_id);
if constexpr (is_scalar) {
return Span<TraitType>(chunk.data(), chunk.size());
} else if constexpr (std::is_same_v<Type, int64_t> || std::is_same_v<Type, int>) {
} else if constexpr (std::is_same_v<Type, int64_t> ||
std::is_same_v<Type, int>) {
// only for testing
PanicInfo("unimplemented");
} else {
static_assert(std::is_same_v<typename TraitType::embedded_type, Type>);
static_assert(
std::is_same_v<typename TraitType::embedded_type, Type>);
return Span<TraitType>(chunk.data(), chunk.size(), Dim);
}
}
@ -185,23 +199,29 @@ class ConcurrentVectorImpl : public VectorBase {
}
void
set_data_raw(ssize_t element_offset, const void* source, ssize_t element_count) override {
set_data_raw(ssize_t element_offset,
const void* source,
ssize_t element_count) override {
if (element_count == 0) {
return;
}
this->grow_to_at_least(element_offset + element_count);
set_data(element_offset, static_cast<const Type*>(source), element_count);
set_data(
element_offset, static_cast<const Type*>(source), element_count);
}
void
set_data(ssize_t element_offset, const Type* source, ssize_t element_count) {
set_data(ssize_t element_offset,
const Type* source,
ssize_t element_count) {
auto chunk_id = element_offset / size_per_chunk_;
auto chunk_offset = element_offset % size_per_chunk_;
ssize_t source_offset = 0;
// first partition:
if (chunk_offset + element_count <= size_per_chunk_) {
// only first
fill_chunk(chunk_id, chunk_offset, element_count, source, source_offset);
fill_chunk(
chunk_id, chunk_offset, element_count, source, source_offset);
return;
}
@ -280,17 +300,25 @@ class ConcurrentVectorImpl : public VectorBase {
private:
void
fill_chunk(
ssize_t chunk_id, ssize_t chunk_offset, ssize_t element_count, const Type* source, ssize_t source_offset) {
fill_chunk(ssize_t chunk_id,
ssize_t chunk_offset,
ssize_t element_count,
const Type* source,
ssize_t source_offset) {
if (element_count <= 0) {
return;
}
auto chunk_num = chunks_.size();
AssertInfo(chunk_id < chunk_num,
fmt::format("chunk_id out of chunk num, chunk_id={}, chunk_num={}", chunk_id, chunk_num));
AssertInfo(
chunk_id < chunk_num,
fmt::format("chunk_id out of chunk num, chunk_id={}, chunk_num={}",
chunk_id,
chunk_num));
Chunk& chunk = chunks_[chunk_id];
auto ptr = chunk.data();
std::copy_n(source + source_offset * Dim, element_count * Dim, ptr + chunk_offset * Dim);
std::copy_n(source + source_offset * Dim,
element_count * Dim,
ptr + chunk_offset * Dim);
}
const ssize_t Dim;
@ -304,20 +332,24 @@ class ConcurrentVector : public ConcurrentVectorImpl<Type, true> {
public:
static_assert(IsScalar<Type> || std::is_same_v<Type, PkType>);
explicit ConcurrentVector(int64_t size_per_chunk)
: ConcurrentVectorImpl<Type, true>::ConcurrentVectorImpl(1, size_per_chunk) {
: ConcurrentVectorImpl<Type, true>::ConcurrentVectorImpl(
1, size_per_chunk) {
}
};
template <>
class ConcurrentVector<FloatVector> : public ConcurrentVectorImpl<float, false> {
class ConcurrentVector<FloatVector>
: public ConcurrentVectorImpl<float, false> {
public:
ConcurrentVector(int64_t dim, int64_t size_per_chunk)
: ConcurrentVectorImpl<float, false>::ConcurrentVectorImpl(dim, size_per_chunk) {
: ConcurrentVectorImpl<float, false>::ConcurrentVectorImpl(
dim, size_per_chunk) {
}
};
template <>
class ConcurrentVector<BinaryVector> : public ConcurrentVectorImpl<uint8_t, false> {
class ConcurrentVector<BinaryVector>
: public ConcurrentVectorImpl<uint8_t, false> {
public:
explicit ConcurrentVector(int64_t dim, int64_t size_per_chunk)
: binary_dim_(dim), ConcurrentVectorImpl(dim / 8, size_per_chunk) {

View File

@ -32,7 +32,9 @@ struct DeletedRecord {
};
static constexpr int64_t deprecated_size_per_chunk = 32 * 1024;
DeletedRecord()
: lru_(std::make_shared<TmpBitmap>()), timestamps_(deprecated_size_per_chunk), pks_(deprecated_size_per_chunk) {
: lru_(std::make_shared<TmpBitmap>()),
timestamps_(deprecated_size_per_chunk),
pks_(deprecated_size_per_chunk) {
lru_->bitmap_ptr = std::make_shared<BitsetType>();
}
@ -43,12 +45,16 @@ struct DeletedRecord {
}
std::shared_ptr<TmpBitmap>
clone_lru_entry(int64_t insert_barrier, int64_t del_barrier, int64_t& old_del_barrier, bool& hit_cache) {
clone_lru_entry(int64_t insert_barrier,
int64_t del_barrier,
int64_t& old_del_barrier,
bool& hit_cache) {
std::shared_lock lck(shared_mutex_);
auto res = lru_->clone(insert_barrier);
old_del_barrier = lru_->del_barrier;
if (lru_->bitmap_ptr->size() == insert_barrier && lru_->del_barrier == del_barrier) {
if (lru_->bitmap_ptr->size() == insert_barrier &&
lru_->del_barrier == del_barrier) {
hit_cache = true;
} else {
res->del_barrier = del_barrier;
@ -61,7 +67,8 @@ struct DeletedRecord {
insert_lru_entry(std::shared_ptr<TmpBitmap> new_entry, bool force = false) {
std::lock_guard lck(shared_mutex_);
if (new_entry->del_barrier <= lru_->del_barrier) {
if (!force || new_entry->bitmap_ptr->size() <= lru_->bitmap_ptr->size()) {
if (!force ||
new_entry->bitmap_ptr->size() <= lru_->bitmap_ptr->size()) {
// DO NOTHING
return;
}
@ -81,7 +88,8 @@ struct DeletedRecord {
};
inline auto
DeletedRecord::TmpBitmap::clone(int64_t capacity) -> std::shared_ptr<TmpBitmap> {
DeletedRecord::TmpBitmap::clone(int64_t capacity)
-> std::shared_ptr<TmpBitmap> {
auto res = std::make_shared<TmpBitmap>();
res->del_barrier = this->del_barrier;
res->bitmap_ptr = std::make_shared<BitsetType>();

View File

@ -21,8 +21,11 @@
namespace milvus::segcore {
void
VectorFieldIndexing::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) {
AssertInfo(field_meta_.get_data_type() == DataType::VECTOR_FLOAT, "Data type of vector field is not VECTOR_FLOAT");
VectorFieldIndexing::BuildIndexRange(int64_t ack_beg,
int64_t ack_end,
const VectorBase* vec_base) {
AssertInfo(field_meta_.get_data_type() == DataType::VECTOR_FLOAT,
"Data type of vector field is not VECTOR_FLOAT");
auto dim = field_meta_.get_dim();
auto source = dynamic_cast<const ConcurrentVector<FloatVector>*>(vec_base);
@ -33,9 +36,12 @@ VectorFieldIndexing::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const Vec
data_.grow_to_at_least(ack_end);
for (int chunk_id = ack_beg; chunk_id < ack_end; chunk_id++) {
const auto& chunk = source->get_chunk(chunk_id);
auto indexing = std::make_unique<index::VectorMemNMIndex>(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT,
knowhere::metric::L2, IndexMode::MODE_CPU);
auto dataset = knowhere::GenDataSet(source->get_size_per_chunk(), dim, chunk.data());
auto indexing = std::make_unique<index::VectorMemNMIndex>(
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT,
knowhere::metric::L2,
IndexMode::MODE_CPU);
auto dataset = knowhere::GenDataSet(
source->get_size_per_chunk(), dim, chunk.data());
indexing->BuildWithDataset(dataset, conf);
data_[chunk_id] = std::move(indexing);
}
@ -45,7 +51,8 @@ knowhere::Json
VectorFieldIndexing::get_build_params() const {
// TODO
auto type_opt = field_meta_.get_metric_type();
AssertInfo(type_opt.has_value(), "Metric type of field meta doesn't have value");
AssertInfo(type_opt.has_value(),
"Metric type of field meta doesn't have value");
auto& metric_type = type_opt.value();
auto& config = segcore_config_.at(metric_type);
auto base_params = config.build_params;
@ -61,12 +68,14 @@ knowhere::Json
VectorFieldIndexing::get_search_params(int top_K) const {
// TODO
auto type_opt = field_meta_.get_metric_type();
AssertInfo(type_opt.has_value(), "Metric type of field meta doesn't have value");
AssertInfo(type_opt.has_value(),
"Metric type of field meta doesn't have value");
auto& metric_type = type_opt.value();
auto& config = segcore_config_.at(metric_type);
auto base_params = config.search_params;
AssertInfo(base_params.count("nprobe"), "Can't get nprobe from base params");
AssertInfo(base_params.count("nprobe"),
"Can't get nprobe from base params");
base_params[knowhere::meta::TOPK] = top_K;
base_params[knowhere::meta::METRIC_TYPE] = metric_type;
@ -75,7 +84,9 @@ VectorFieldIndexing::get_search_params(int top_K) const {
template <typename T>
void
ScalarFieldIndexing<T>::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) {
ScalarFieldIndexing<T>::BuildIndexRange(int64_t ack_beg,
int64_t ack_end,
const VectorBase* vec_base) {
auto source = dynamic_cast<const ConcurrentVector<T>*>(vec_base);
AssertInfo(source, "vec_base can't cast to ConcurrentVector type");
auto num_chunk = source->num_chunk();
@ -101,7 +112,8 @@ std::unique_ptr<FieldIndexing>
CreateIndex(const FieldMeta& field_meta, const SegcoreConfig& segcore_config) {
if (field_meta.is_vector()) {
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
return std::make_unique<VectorFieldIndexing>(field_meta, segcore_config);
return std::make_unique<VectorFieldIndexing>(field_meta,
segcore_config);
} else {
// TODO
PanicInfo("unsupported");
@ -109,21 +121,29 @@ CreateIndex(const FieldMeta& field_meta, const SegcoreConfig& segcore_config) {
}
switch (field_meta.get_data_type()) {
case DataType::BOOL:
return std::make_unique<ScalarFieldIndexing<bool>>(field_meta, segcore_config);
return std::make_unique<ScalarFieldIndexing<bool>>(field_meta,
segcore_config);
case DataType::INT8:
return std::make_unique<ScalarFieldIndexing<int8_t>>(field_meta, segcore_config);
return std::make_unique<ScalarFieldIndexing<int8_t>>(
field_meta, segcore_config);
case DataType::INT16:
return std::make_unique<ScalarFieldIndexing<int16_t>>(field_meta, segcore_config);
return std::make_unique<ScalarFieldIndexing<int16_t>>(
field_meta, segcore_config);
case DataType::INT32:
return std::make_unique<ScalarFieldIndexing<int32_t>>(field_meta, segcore_config);
return std::make_unique<ScalarFieldIndexing<int32_t>>(
field_meta, segcore_config);
case DataType::INT64:
return std::make_unique<ScalarFieldIndexing<int64_t>>(field_meta, segcore_config);
return std::make_unique<ScalarFieldIndexing<int64_t>>(
field_meta, segcore_config);
case DataType::FLOAT:
return std::make_unique<ScalarFieldIndexing<float>>(field_meta, segcore_config);
return std::make_unique<ScalarFieldIndexing<float>>(field_meta,
segcore_config);
case DataType::DOUBLE:
return std::make_unique<ScalarFieldIndexing<double>>(field_meta, segcore_config);
return std::make_unique<ScalarFieldIndexing<double>>(
field_meta, segcore_config);
case DataType::VARCHAR:
return std::make_unique<ScalarFieldIndexing<std::string>>(field_meta, segcore_config);
return std::make_unique<ScalarFieldIndexing<std::string>>(
field_meta, segcore_config);
default:
PanicInfo("unsupported");
}

View File

@ -31,7 +31,8 @@ namespace milvus::segcore {
// All concurrent
class FieldIndexing {
public:
explicit FieldIndexing(const FieldMeta& field_meta, const SegcoreConfig& segcore_config)
explicit FieldIndexing(const FieldMeta& field_meta,
const SegcoreConfig& segcore_config)
: field_meta_(field_meta), segcore_config_(segcore_config) {
}
FieldIndexing(const FieldIndexing&) = delete;
@ -41,7 +42,9 @@ class FieldIndexing {
// Do this in parallel
virtual void
BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) = 0;
BuildIndexRange(int64_t ack_beg,
int64_t ack_end,
const VectorBase* vec_base) = 0;
const FieldMeta&
get_field_meta() {
@ -68,7 +71,9 @@ class ScalarFieldIndexing : public FieldIndexing {
using FieldIndexing::FieldIndexing;
void
BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) override;
BuildIndexRange(int64_t ack_beg,
int64_t ack_end,
const VectorBase* vec_base) override;
// concurrent
index::ScalarIndex<T>*
@ -86,7 +91,9 @@ class VectorFieldIndexing : public FieldIndexing {
using FieldIndexing::FieldIndexing;
void
BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) override;
BuildIndexRange(int64_t ack_beg,
int64_t ack_end,
const VectorBase* vec_base) override;
// concurrent
index::IndexBase*
@ -110,7 +117,8 @@ CreateIndex(const FieldMeta& field_meta, const SegcoreConfig& segcore_config);
class IndexingRecord {
public:
explicit IndexingRecord(const Schema& schema, const SegcoreConfig& segcore_config)
explicit IndexingRecord(const Schema& schema,
const SegcoreConfig& segcore_config)
: schema_(schema), segcore_config_(segcore_config) {
Initialize();
}
@ -132,7 +140,8 @@ class IndexingRecord {
}
}
field_indexings_.try_emplace(field_id, CreateIndex(field_meta, segcore_config_));
field_indexings_.try_emplace(
field_id, CreateIndex(field_meta, segcore_config_));
}
assert(offset_id == schema_.size());
}
@ -140,7 +149,8 @@ class IndexingRecord {
// concurrent, reentrant
template <bool is_sealed>
void
UpdateResourceAck(int64_t chunk_ack, const InsertRecord<is_sealed>& record) {
UpdateResourceAck(int64_t chunk_ack,
const InsertRecord<is_sealed>& record) {
if (resource_ack_ >= chunk_ack) {
return;
}
@ -189,7 +199,8 @@ class IndexingRecord {
template <typename T>
auto
get_scalar_field_indexing(FieldId field_id) const -> const ScalarFieldIndexing<T>& {
get_scalar_field_indexing(FieldId field_id) const
-> const ScalarFieldIndexing<T>& {
auto& entry = get_field_indexing(field_id);
auto ptr = dynamic_cast<const ScalarFieldIndexing<T>*>(&entry);
AssertInfo(ptr, "invalid indexing");

View File

@ -50,7 +50,8 @@ class OffsetHashMap : public OffsetMap {
std::vector<int64_t>
find(const PkType pk) const {
auto offset_vector = map_.find(std::get<T>(pk));
return offset_vector != map_.end() ? offset_vector->second : std::vector<int64_t>();
return offset_vector != map_.end() ? offset_vector->second
: std::vector<int64_t>();
}
void
@ -60,7 +61,8 @@ class OffsetHashMap : public OffsetMap {
void
seal() {
PanicInfo("OffsetHashMap used for growing segment could not be sealed.");
PanicInfo(
"OffsetHashMap used for growing segment could not be sealed.");
}
bool
@ -138,26 +140,32 @@ struct InsertRecord {
// pks to row offset
std::unique_ptr<OffsetMap> pk2offset_;
InsertRecord(const Schema& schema, int64_t size_per_chunk) : row_ids_(size_per_chunk), timestamps_(size_per_chunk) {
InsertRecord(const Schema& schema, int64_t size_per_chunk)
: row_ids_(size_per_chunk), timestamps_(size_per_chunk) {
std::optional<FieldId> pk_field_id = schema.get_primary_field_id();
for (auto& field : schema) {
auto field_id = field.first;
auto& field_meta = field.second;
if (pk2offset_ == nullptr && pk_field_id.has_value() && pk_field_id.value() == field_id) {
if (pk2offset_ == nullptr && pk_field_id.has_value() &&
pk_field_id.value() == field_id) {
switch (field_meta.get_data_type()) {
case DataType::INT64: {
if (is_sealed)
pk2offset_ = std::make_unique<OffsetOrderedArray<int64_t>>();
pk2offset_ =
std::make_unique<OffsetOrderedArray<int64_t>>();
else
pk2offset_ = std::make_unique<OffsetHashMap<int64_t>>();
pk2offset_ =
std::make_unique<OffsetHashMap<int64_t>>();
break;
}
case DataType::VARCHAR: {
if (is_sealed)
pk2offset_ = std::make_unique<OffsetOrderedArray<std::string>>();
pk2offset_ = std::make_unique<
OffsetOrderedArray<std::string>>();
else
pk2offset_ = std::make_unique<OffsetHashMap<std::string>>();
pk2offset_ =
std::make_unique<OffsetHashMap<std::string>>();
break;
}
default: {
@ -167,10 +175,13 @@ struct InsertRecord {
}
if (field_meta.is_vector()) {
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
this->append_field_data<FloatVector>(field_id, field_meta.get_dim(), size_per_chunk);
this->append_field_data<FloatVector>(
field_id, field_meta.get_dim(), size_per_chunk);
continue;
} else if (field_meta.get_data_type() == DataType::VECTOR_BINARY) {
this->append_field_data<BinaryVector>(field_id, field_meta.get_dim(), size_per_chunk);
} else if (field_meta.get_data_type() ==
DataType::VECTOR_BINARY) {
this->append_field_data<BinaryVector>(
field_id, field_meta.get_dim(), size_per_chunk);
continue;
} else {
PanicInfo("unsupported");
@ -206,7 +217,8 @@ struct InsertRecord {
break;
}
case DataType::VARCHAR: {
this->append_field_data<std::string>(field_id, size_per_chunk);
this->append_field_data<std::string>(field_id,
size_per_chunk);
break;
}
default: {
@ -263,7 +275,8 @@ struct InsertRecord {
VectorBase*
get_field_data_base(FieldId field_id) const {
AssertInfo(fields_data_.find(field_id) != fields_data_.end(),
"Cannot find field_data with field_id: " + std::to_string(field_id.get()));
"Cannot find field_data with field_id: " +
std::to_string(field_id.get()));
auto ptr = fields_data_.at(field_id).get();
return ptr;
}
@ -293,7 +306,8 @@ struct InsertRecord {
void
append_field_data(FieldId field_id, int64_t size_per_chunk) {
static_assert(IsScalar<Type>);
fields_data_.emplace(field_id, std::make_unique<ConcurrentVector<Type>>(size_per_chunk));
fields_data_.emplace(
field_id, std::make_unique<ConcurrentVector<Type>>(size_per_chunk));
}
// append a column of vector type
@ -301,7 +315,9 @@ struct InsertRecord {
void
append_field_data(FieldId field_id, int64_t dim, int64_t size_per_chunk) {
static_assert(std::is_base_of_v<VectorTrait, VectorType>);
fields_data_.emplace(field_id, std::make_unique<ConcurrentVector<VectorType>>(dim, size_per_chunk));
fields_data_.emplace(field_id,
std::make_unique<ConcurrentVector<VectorType>>(
dim, size_per_chunk));
}
void

View File

@ -27,7 +27,8 @@ void
ReduceHelper::Initialize() {
AssertInfo(search_results_.size() > 0, "empty search result");
AssertInfo(slice_nqs_.size() > 0, "empty slice_nqs");
AssertInfo(slice_nqs_.size() == slice_topKs_.size(), "unaligned slice_nqs and slice_topKs");
AssertInfo(slice_nqs_.size() == slice_topKs_.size(),
"unaligned slice_nqs and slice_topKs");
total_nq_ = search_results_[0]->total_nq_;
num_segments_ = search_results_.size();
@ -36,10 +37,13 @@ ReduceHelper::Initialize() {
// prefix sum, get slices offsets
AssertInfo(num_slices_ > 0, "empty slice_nqs is not allowed");
slice_nqs_prefix_sum_.resize(num_slices_ + 1);
std::partial_sum(slice_nqs_.begin(), slice_nqs_.end(), slice_nqs_prefix_sum_.begin() + 1);
AssertInfo(slice_nqs_prefix_sum_[num_slices_] == total_nq_, "illegal req sizes, slice_nqs_prefix_sum_[last] = " +
std::to_string(slice_nqs_prefix_sum_[num_slices_]) +
", total_nq = " + std::to_string(total_nq_));
std::partial_sum(slice_nqs_.begin(),
slice_nqs_.end(),
slice_nqs_prefix_sum_.begin() + 1);
AssertInfo(slice_nqs_prefix_sum_[num_slices_] == total_nq_,
"illegal req sizes, slice_nqs_prefix_sum_[last] = " +
std::to_string(slice_nqs_prefix_sum_[num_slices_]) +
", total_nq = " + std::to_string(total_nq_));
// init final_search_records and final_read_topKs
final_search_records_.resize(num_segments_);
@ -59,7 +63,8 @@ ReduceHelper::Reduce() {
void
ReduceHelper::Marshal() {
// get search result data blobs of slices
search_result_data_blobs_ = std::make_unique<milvus::segcore::SearchResultDataBlobs>();
search_result_data_blobs_ =
std::make_unique<milvus::segcore::SearchResultDataBlobs>();
search_result_data_blobs_->blobs.resize(num_slices_);
for (int i = 0; i < num_slices_; i++) {
auto proto = GetSearchResultDataSlice(i);
@ -72,10 +77,12 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
auto nq = search_result->total_nq_;
auto topK = search_result->unity_topK_;
AssertInfo(search_result->seg_offsets_.size() == nq * topK,
"wrong seg offsets size, size = " + std::to_string(search_result->seg_offsets_.size()) +
"wrong seg offsets size, size = " +
std::to_string(search_result->seg_offsets_.size()) +
", expected size = " + std::to_string(nq * topK));
AssertInfo(search_result->distances_.size() == nq * topK,
"wrong distances size, size = " + std::to_string(search_result->distances_.size()) +
"wrong distances size, size = " +
std::to_string(search_result->distances_.size()) +
", expected size = " + std::to_string(nq * topK));
std::vector<int64_t> real_topks(nq, 0);
uint32_t valid_index = 0;
@ -96,7 +103,9 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
distances.resize(valid_index);
search_result->topk_per_nq_prefix_sum_.resize(nq + 1);
std::partial_sum(real_topks.begin(), real_topks.end(), search_result->topk_per_nq_prefix_sum_.begin() + 1);
std::partial_sum(real_topks.begin(),
real_topks.end(),
search_result->topk_per_nq_prefix_sum_.begin() + 1);
}
void
@ -107,7 +116,8 @@ ReduceHelper::FillPrimaryKey() {
for (auto& search_result : search_results_) {
FilterInvalidSearchResult(search_result);
if (search_result->get_total_result_count() > 0) {
auto segment = static_cast<SegmentInterface*>(search_result->segment_);
auto segment =
static_cast<SegmentInterface*>(search_result->segment_);
segment->FillPrimaryKeys(plan_, *search_result);
search_results_[valid_index++] = search_result;
}
@ -144,20 +154,25 @@ ReduceHelper::RefreshSearchResult() {
search_result->distances_.swap(distances);
search_result->seg_offsets_.swap(seg_offsets);
}
std::partial_sum(real_topks.begin(), real_topks.end(), search_result->topk_per_nq_prefix_sum_.begin() + 1);
std::partial_sum(real_topks.begin(),
real_topks.end(),
search_result->topk_per_nq_prefix_sum_.begin() + 1);
}
}
void
ReduceHelper::FillEntryData() {
for (auto search_result : search_results_) {
auto segment = static_cast<milvus::segcore::SegmentInterface*>(search_result->segment_);
auto segment = static_cast<milvus::segcore::SegmentInterface*>(
search_result->segment_);
segment->FillTargetEntry(plan_, *search_result);
}
}
int64_t
ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& offset) {
ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi,
int64_t topk,
int64_t& offset) {
while (!heap_.empty()) {
heap_.pop();
}
@ -175,7 +190,8 @@ ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& offs
auto primary_key = search_result->primary_keys_[offset_beg];
auto distance = search_result->distances_[offset_beg];
pairs_.emplace_back(primary_key, distance, search_result, i, offset_beg, offset_end);
pairs_.emplace_back(
primary_key, distance, search_result, i, offset_beg, offset_end);
heap_.push(&pairs_.back());
}
@ -218,10 +234,14 @@ ReduceHelper::ReduceResultData() {
for (int i = 0; i < num_segments_; i++) {
auto search_result = search_results_[i];
auto result_count = search_result->get_total_result_count();
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
AssertInfo(search_result->distances_.size() == result_count, "incorrect search result distance size");
AssertInfo(search_result->seg_offsets_.size() == result_count, "incorrect search result seg offset size");
AssertInfo(search_result->primary_keys_.size() == result_count, "incorrect search result primary key size");
AssertInfo(search_result != nullptr,
"search result must not equal to nullptr");
AssertInfo(search_result->distances_.size() == result_count,
"incorrect search result distance size");
AssertInfo(search_result->seg_offsets_.size() == result_count,
"incorrect search result seg offset size");
AssertInfo(search_result->primary_keys_.size() == result_count,
"incorrect search result primary key size");
}
int64_t skip_dup_cnt = 0;
@ -232,11 +252,13 @@ ReduceHelper::ReduceResultData() {
// reduce search results
int64_t offset = 0;
for (int64_t qi = nq_begin; qi < nq_end; qi++) {
skip_dup_cnt += ReduceSearchResultForOneNQ(qi, slice_topKs_[slice_index], offset);
skip_dup_cnt += ReduceSearchResultForOneNQ(
qi, slice_topKs_[slice_index], offset);
}
}
if (skip_dup_cnt > 0) {
LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = " << skip_dup_cnt;
LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = "
<< skip_dup_cnt;
}
}
@ -247,13 +269,15 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
int64_t result_count = 0;
for (auto search_result : search_results_) {
AssertInfo(search_result->topk_per_nq_prefix_sum_.size() == search_result->total_nq_ + 1,
AssertInfo(search_result->topk_per_nq_prefix_sum_.size() ==
search_result->total_nq_ + 1,
"incorrect topk_per_nq_prefix_sum_ size in search result");
result_count +=
search_result->topk_per_nq_prefix_sum_[nq_end] - search_result->topk_per_nq_prefix_sum_[nq_begin];
result_count += search_result->topk_per_nq_prefix_sum_[nq_end] -
search_result->topk_per_nq_prefix_sum_[nq_begin];
}
auto search_result_data = std::make_unique<milvus::proto::schema::SearchResultData>();
auto search_result_data =
std::make_unique<milvus::proto::schema::SearchResultData>();
// set unify_topK and total_nq
search_result_data->set_top_k(slice_topKs_[slice_index]);
search_result_data->set_num_queries(nq_end - nq_begin);
@ -263,14 +287,16 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
std::vector<std::pair<SearchResult*, int64_t>> result_pairs(result_count);
// reserve space for pks
auto primary_field_id = plan_->schema_.get_primary_field_id().value_or(milvus::FieldId(-1));
auto primary_field_id =
plan_->schema_.get_primary_field_id().value_or(milvus::FieldId(-1));
AssertInfo(primary_field_id.get() != INVALID_FIELD_ID, "Primary key is -1");
auto pk_type = plan_->schema_[primary_field_id].get_data_type();
switch (pk_type) {
case milvus::DataType::INT64: {
auto ids = std::make_unique<milvus::proto::schema::LongArray>();
ids->mutable_data()->Resize(result_count, 0);
search_result_data->mutable_ids()->set_allocated_int_id(ids.release());
search_result_data->mutable_ids()->set_allocated_int_id(
ids.release());
break;
}
case milvus::DataType::VARCHAR: {
@ -278,7 +304,8 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
std::vector<std::string> string_pks(result_count);
// TODO: prevent mem copy
*ids->mutable_data() = {string_pks.begin(), string_pks.end()};
search_result_data->mutable_ids()->set_allocated_str_id(ids.release());
search_result_data->mutable_ids()->set_allocated_str_id(
ids.release());
break;
}
default: {
@ -293,7 +320,8 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
for (auto qi = nq_begin; qi < nq_end; qi++) {
int64_t topk_count = 0;
for (auto search_result : search_results_) {
AssertInfo(search_result != nullptr, "null search result when reorganize");
AssertInfo(search_result != nullptr,
"null search result when reorganize");
if (search_result->result_offsets_.size() == 0) {
continue;
}
@ -305,18 +333,26 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
for (auto ki = topk_start; ki < topk_end; ki++) {
auto loc = search_result->result_offsets_[ki];
AssertInfo(loc < result_count && loc >= 0,
"invalid loc when GetSearchResultDataSlice, loc = " + std::to_string(loc) +
", result_count = " + std::to_string(result_count));
"invalid loc when GetSearchResultDataSlice, loc = " +
std::to_string(loc) + ", result_count = " +
std::to_string(result_count));
// set result pks
switch (pk_type) {
case milvus::DataType::INT64: {
search_result_data->mutable_ids()->mutable_int_id()->mutable_data()->Set(
loc, std::visit(Int64PKVisitor{}, search_result->primary_keys_[ki]));
search_result_data->mutable_ids()
->mutable_int_id()
->mutable_data()
->Set(loc,
std::visit(Int64PKVisitor{},
search_result->primary_keys_[ki]));
break;
}
case milvus::DataType::VARCHAR: {
*search_result_data->mutable_ids()->mutable_str_id()->mutable_data()->Mutable(loc) =
std::visit(StrPKVisitor{}, search_result->primary_keys_[ki]);
*search_result_data->mutable_ids()
->mutable_str_id()
->mutable_data()
->Mutable(loc) = std::visit(
StrPKVisitor{}, search_result->primary_keys_[ki]);
break;
}
default: {
@ -325,7 +361,8 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
}
// set result distances
search_result_data->mutable_scores()->Set(loc, search_result->distances_[ki]);
search_result_data->mutable_scores()->Set(
loc, search_result->distances_[ki]);
// set result offset to fill output fields data
result_pairs[loc] = std::make_pair(search_result, ki);
}
@ -336,14 +373,17 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
}
AssertInfo(search_result_data->scores_size() == result_count,
"wrong scores size, size = " + std::to_string(search_result_data->scores_size()) +
"wrong scores size, size = " +
std::to_string(search_result_data->scores_size()) +
", expected size = " + std::to_string(result_count));
// set output fields
for (auto field_id : plan_->target_entries_) {
auto& field_meta = plan_->schema_[field_id];
auto field_data = milvus::segcore::MergeDataArray(result_pairs, field_meta);
search_result_data->mutable_fields_data()->AddAllocated(field_data.release());
auto field_data =
milvus::segcore::MergeDataArray(result_pairs, field_meta);
search_result_data->mutable_fields_data()->AddAllocated(
field_data.release());
}
// SearchResultData to blob

View File

@ -73,7 +73,9 @@ class ReduceHelper {
FillEntryData();
int64_t
ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& result_offset);
ReduceSearchResultForOneNQ(int64_t qi,
int64_t topk,
int64_t& result_offset);
void
ReduceResultData();
@ -102,7 +104,10 @@ class ReduceHelper {
// Used for merge results,
// define these here to avoid allocating them for each query
std::vector<SearchResultPair> pairs_;
std::priority_queue<SearchResultPair*, std::vector<SearchResultPair*>, SearchResultPairComparator> heap_;
std::priority_queue<SearchResultPair*,
std::vector<SearchResultPair*>,
SearchResultPairComparator>
heap_;
std::unordered_set<milvus::PkType> pk_set_;
};

View File

@ -27,8 +27,12 @@ struct SearchResultPair {
int64_t offset_;
int64_t offset_rb_; // right bound
SearchResultPair(
milvus::PkType primary_key, float distance, SearchResult* result, int64_t index, int64_t lb, int64_t rb)
SearchResultPair(milvus::PkType primary_key,
float distance,
SearchResult* result,
int64_t index,
int64_t lb,
int64_t rb)
: primary_key_(primary_key),
distance_(distance),
search_result_(result),

View File

@ -31,8 +31,12 @@ ScalarIndexVector::do_search_ids(const IdArray& ids) const {
for (auto id : src_ids.data()) {
using Pair = std::pair<T, SegOffset>;
auto [iter_beg, iter_end] =
std::equal_range(mapping_.begin(), mapping_.end(), std::make_pair(id, SegOffset(0)),
[](const Pair& left, const Pair& right) { return left.first < right.first; });
std::equal_range(mapping_.begin(),
mapping_.end(),
std::make_pair(id, SegOffset(0)),
[](const Pair& left, const Pair& right) {
return left.first < right.first;
});
for (auto& iter = iter_beg; iter != iter_end; iter++) {
auto [entry_id, entry_offset] = *iter;
@ -51,8 +55,12 @@ ScalarIndexVector::do_search_ids(const std::vector<idx_t>& ids) const {
for (auto id : ids) {
using Pair = std::pair<T, SegOffset>;
auto [iter_beg, iter_end] =
std::equal_range(mapping_.begin(), mapping_.end(), std::make_pair(id, SegOffset(0)),
[](const Pair& left, const Pair& right) { return left.first < right.first; });
std::equal_range(mapping_.begin(),
mapping_.end(),
std::make_pair(id, SegOffset(0)),
[](const Pair& left, const Pair& right) {
return left.first < right.first;
});
for (auto& iter = iter_beg; iter != iter_end; iter++) {
auto [entry_id, entry_offset] = *iter_beg;
@ -64,7 +72,9 @@ ScalarIndexVector::do_search_ids(const std::vector<idx_t>& ids) const {
}
void
ScalarIndexVector::append_data(const ScalarIndexVector::T* ids, int64_t count, SegOffset base) {
ScalarIndexVector::append_data(const ScalarIndexVector::T* ids,
int64_t count,
SegOffset base) {
for (int64_t i = 0; i < count; ++i) {
auto offset = base + SegOffset(i);
mapping_.emplace_back(ids[i], offset);

View File

@ -53,7 +53,8 @@ class ScalarIndexVector : public ScalarIndexBase {
debug() const override {
std::string dbg_str;
for (auto pr : mapping_) {
dbg_str += "<" + std::to_string(pr.first) + "->" + std::to_string(pr.second.get()) + ">";
dbg_str += "<" + std::to_string(pr.first) + "->" +
std::to_string(pr.second.get()) + ">";
}
return dbg_str;
}

View File

@ -33,7 +33,9 @@ using SealedIndexingEntryPtr = std::unique_ptr<SealedIndexingEntry>;
struct SealedIndexingRecord {
void
append_field_indexing(FieldId field_id, const MetricType& metric_type, index::IndexBasePtr indexing) {
append_field_indexing(FieldId field_id,
const MetricType& metric_type,
index::IndexBasePtr indexing) {
auto ptr = std::make_unique<SealedIndexingEntry>();
ptr->indexing_ = std::move(indexing);
ptr->metric_type_ = metric_type;

View File

@ -100,7 +100,8 @@ SegcoreConfig::parse_from(const std::string& config_path) {
} catch (const SegcoreError& e) {
throw e;
} catch (const std::exception& e) {
std::string str = std::string("Invalid Yaml: ") + config_path + ", err: " + e.what();
std::string str =
std::string("Invalid Yaml: ") + config_path + ", err: " + e.what();
PanicInfo(str);
}
}

View File

@ -76,7 +76,8 @@ class SegcoreConfig {
}
void
set_small_index_config(const MetricType& metric_type, const SmallIndexConf& small_index_conf) {
set_small_index_config(const MetricType& metric_type,
const SmallIndexConf& small_index_conf) {
table_[metric_type] = small_index_conf;
}

View File

@ -36,17 +36,21 @@ SegmentGrowingImpl::PreDelete(int64_t size) {
}
void
SegmentGrowingImpl::mask_with_delete(BitsetType& bitset, int64_t ins_barrier, Timestamp timestamp) const {
SegmentGrowingImpl::mask_with_delete(BitsetType& bitset,
int64_t ins_barrier,
Timestamp timestamp) const {
auto del_barrier = get_barrier(get_deleted_record(), timestamp);
if (del_barrier == 0) {
return;
}
auto bitmap_holder = get_deleted_bitmap(del_barrier, ins_barrier, deleted_record_, insert_record_, timestamp);
auto bitmap_holder = get_deleted_bitmap(
del_barrier, ins_barrier, deleted_record_, insert_record_, timestamp);
if (!bitmap_holder || !bitmap_holder->bitmap_ptr) {
return;
}
auto& delete_bitset = *bitmap_holder->bitmap_ptr;
AssertInfo(delete_bitset.size() == bitset.size(), "Deleted bitmap size not equal to filtered bitmap size");
AssertInfo(delete_bitset.size() == bitset.size(),
"Deleted bitmap size not equal to filtered bitmap size");
bitset |= delete_bitset;
}
@ -56,7 +60,8 @@ SegmentGrowingImpl::Insert(int64_t reserved_offset,
const int64_t* row_ids,
const Timestamp* timestamps_raw,
const InsertData* insert_data) {
AssertInfo(insert_data->num_rows() == size, "Entities_raw count not equal to insert size");
AssertInfo(insert_data->num_rows() == size,
"Entities_raw count not equal to insert size");
// AssertInfo(insert_data->fields_data_size() == schema_->size(),
// "num fields of insert data not equal to num of schema fields");
// step 1: check insert data if valid
@ -72,34 +77,45 @@ SegmentGrowingImpl::Insert(int64_t reserved_offset,
// query node already guarantees that the timestamp is ordered, avoid field data copy in c++
// step 3: fill into Segment.ConcurrentVector
insert_record_.timestamps_.set_data_raw(reserved_offset, timestamps_raw, size);
insert_record_.timestamps_.set_data_raw(
reserved_offset, timestamps_raw, size);
insert_record_.row_ids_.set_data_raw(reserved_offset, row_ids, size);
for (auto [field_id, field_meta] : schema_->get_fields()) {
AssertInfo(field_id_to_offset.count(field_id), "Cannot find field_id");
auto data_offset = field_id_to_offset[field_id];
insert_record_.get_field_data_base(field_id)->set_data_raw(reserved_offset, size,
&insert_data->fields_data(data_offset), field_meta);
insert_record_.get_field_data_base(field_id)->set_data_raw(
reserved_offset,
size,
&insert_data->fields_data(data_offset),
field_meta);
}
// step 4: set pks to offset
auto field_id = schema_->get_primary_field_id().value_or(FieldId(-1));
AssertInfo(field_id.get() != INVALID_FIELD_ID, "Primary key is -1");
std::vector<PkType> pks(size);
ParsePksFromFieldData(pks, insert_data->fields_data(field_id_to_offset[field_id]));
ParsePksFromFieldData(
pks, insert_data->fields_data(field_id_to_offset[field_id]));
for (int i = 0; i < size; ++i) {
insert_record_.insert_pk(pks[i], reserved_offset + i);
}
// step 5: update small indexes
insert_record_.ack_responder_.AddSegment(reserved_offset, reserved_offset + size);
insert_record_.ack_responder_.AddSegment(reserved_offset,
reserved_offset + size);
if (enable_small_index_) {
int64_t chunk_rows = segcore_config_.get_chunk_rows();
indexing_record_.UpdateResourceAck(insert_record_.ack_responder_.GetAck() / chunk_rows, insert_record_);
indexing_record_.UpdateResourceAck(
insert_record_.ack_responder_.GetAck() / chunk_rows,
insert_record_);
}
}
Status
SegmentGrowingImpl::Delete(int64_t reserved_begin, int64_t size, const IdArray* ids, const Timestamp* timestamps_raw) {
SegmentGrowingImpl::Delete(int64_t reserved_begin,
int64_t size,
const IdArray* ids,
const Timestamp* timestamps_raw) {
auto field_id = schema_->get_primary_field_id().value_or(FieldId(-1));
AssertInfo(field_id.get() != -1, "Primary key is -1");
auto& field_meta = schema_->operator[](field_id);
@ -122,9 +138,11 @@ SegmentGrowingImpl::Delete(int64_t reserved_begin, int64_t size, const IdArray*
}
// step 2: fill delete record
deleted_record_.timestamps_.set_data_raw(reserved_begin, sort_timestamps.data(), size);
deleted_record_.timestamps_.set_data_raw(
reserved_begin, sort_timestamps.data(), size);
deleted_record_.pks_.set_data_raw(reserved_begin, sort_pks.data(), size);
deleted_record_.ack_responder_.AddSegment(reserved_begin, reserved_begin + size);
deleted_record_.ack_responder_.AddSegment(reserved_begin,
reserved_begin + size);
return Status::OK();
}
@ -145,8 +163,10 @@ SegmentGrowingImpl::LoadDeletedRecord(const LoadDeletedRecordInfo& info) {
AssertInfo(info.primary_keys, "Deleted primary keys is null");
AssertInfo(info.timestamps, "Deleted timestamps is null");
// step 1: get pks and timestamps
auto field_id = schema_->get_primary_field_id().value_or(FieldId(INVALID_FIELD_ID));
AssertInfo(field_id.get() != INVALID_FIELD_ID, "Primary key has invalid field id");
auto field_id =
schema_->get_primary_field_id().value_or(FieldId(INVALID_FIELD_ID));
AssertInfo(field_id.get() != INVALID_FIELD_ID,
"Primary key has invalid field id");
auto& field_meta = schema_->operator[](field_id);
int64_t size = info.row_count;
std::vector<PkType> pks(size);
@ -157,7 +177,8 @@ SegmentGrowingImpl::LoadDeletedRecord(const LoadDeletedRecordInfo& info) {
auto reserved_begin = deleted_record_.reserved.fetch_add(size);
deleted_record_.pks_.set_data_raw(reserved_begin, pks.data(), size);
deleted_record_.timestamps_.set_data_raw(reserved_begin, timestamps, size);
deleted_record_.ack_responder_.AddSegment(reserved_begin, reserved_begin + size);
deleted_record_.ack_responder_.AddSegment(reserved_begin,
reserved_begin + size);
}
SpanBase
@ -181,70 +202,100 @@ SegmentGrowingImpl::vector_search(SearchInfo& search_info,
SearchResult& output) const {
auto& sealed_indexing = this->get_sealed_indexing_record();
if (sealed_indexing.is_ready(search_info.field_id_)) {
query::SearchOnSealedIndex(this->get_schema(), sealed_indexing, search_info, query_data, query_count, bitset,
query::SearchOnSealedIndex(this->get_schema(),
sealed_indexing,
search_info,
query_data,
query_count,
bitset,
output);
} else {
query::SearchOnGrowing(*this, search_info, query_data, query_count, timestamp, bitset, output);
query::SearchOnGrowing(*this,
search_info,
query_data,
query_count,
timestamp,
bitset,
output);
}
}
std::unique_ptr<DataArray>
SegmentGrowingImpl::bulk_subscript(FieldId field_id, const int64_t* seg_offsets, int64_t count) const {
SegmentGrowingImpl::bulk_subscript(FieldId field_id,
const int64_t* seg_offsets,
int64_t count) const {
// TODO: support more types
auto vec_ptr = insert_record_.get_field_data_base(field_id);
auto& field_meta = schema_->operator[](field_id);
if (field_meta.is_vector()) {
aligned_vector<char> output(field_meta.get_sizeof() * count);
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
bulk_subscript_impl<FloatVector>(field_meta.get_sizeof(), *vec_ptr, seg_offsets, count, output.data());
bulk_subscript_impl<FloatVector>(field_meta.get_sizeof(),
*vec_ptr,
seg_offsets,
count,
output.data());
} else if (field_meta.get_data_type() == DataType::VECTOR_BINARY) {
bulk_subscript_impl<BinaryVector>(field_meta.get_sizeof(), *vec_ptr, seg_offsets, count, output.data());
bulk_subscript_impl<BinaryVector>(field_meta.get_sizeof(),
*vec_ptr,
seg_offsets,
count,
output.data());
} else {
PanicInfo("logical error");
}
return CreateVectorDataArrayFrom(output.data(), count, field_meta);
}
AssertInfo(!field_meta.is_vector(), "Scalar field meta type is vector type");
AssertInfo(!field_meta.is_vector(),
"Scalar field meta type is vector type");
switch (field_meta.get_data_type()) {
case DataType::BOOL: {
FixedVector<bool> output(count);
bulk_subscript_impl<bool>(*vec_ptr, seg_offsets, count, output.data());
bulk_subscript_impl<bool>(
*vec_ptr, seg_offsets, count, output.data());
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
}
case DataType::INT8: {
FixedVector<bool> output(count);
bulk_subscript_impl<int8_t>(*vec_ptr, seg_offsets, count, output.data());
bulk_subscript_impl<int8_t>(
*vec_ptr, seg_offsets, count, output.data());
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
}
case DataType::INT16: {
FixedVector<int16_t> output(count);
bulk_subscript_impl<int16_t>(*vec_ptr, seg_offsets, count, output.data());
bulk_subscript_impl<int16_t>(
*vec_ptr, seg_offsets, count, output.data());
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
}
case DataType::INT32: {
FixedVector<int32_t> output(count);
bulk_subscript_impl<int32_t>(*vec_ptr, seg_offsets, count, output.data());
bulk_subscript_impl<int32_t>(
*vec_ptr, seg_offsets, count, output.data());
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
}
case DataType::INT64: {
FixedVector<int64_t> output(count);
bulk_subscript_impl<int64_t>(*vec_ptr, seg_offsets, count, output.data());
bulk_subscript_impl<int64_t>(
*vec_ptr, seg_offsets, count, output.data());
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
}
case DataType::FLOAT: {
FixedVector<float> output(count);
bulk_subscript_impl<float>(*vec_ptr, seg_offsets, count, output.data());
bulk_subscript_impl<float>(
*vec_ptr, seg_offsets, count, output.data());
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
}
case DataType::DOUBLE: {
FixedVector<double> output(count);
bulk_subscript_impl<double>(*vec_ptr, seg_offsets, count, output.data());
bulk_subscript_impl<double>(
*vec_ptr, seg_offsets, count, output.data());
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
}
case DataType::VARCHAR: {
FixedVector<std::string> output(count);
bulk_subscript_impl<std::string>(*vec_ptr, seg_offsets, count, output.data());
bulk_subscript_impl<std::string>(
*vec_ptr, seg_offsets, count, output.data());
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
}
default: {
@ -269,7 +320,9 @@ SegmentGrowingImpl::bulk_subscript_impl(int64_t element_sizeof,
for (int i = 0; i < count; ++i) {
auto dst = output_base + i * element_sizeof;
auto offset = seg_offsets[i];
const uint8_t* src = (offset == INVALID_SEG_OFFSET ? empty.data() : (const uint8_t*)vec.get_element(offset));
const uint8_t* src = (offset == INVALID_SEG_OFFSET
? empty.data()
: (const uint8_t*)vec.get_element(offset));
memcpy(dst, src, element_sizeof);
}
}
@ -300,10 +353,12 @@ SegmentGrowingImpl::bulk_subscript(SystemFieldType system_type,
void* output) const {
switch (system_type) {
case SystemFieldType::Timestamp:
bulk_subscript_impl<Timestamp>(this->insert_record_.timestamps_, seg_offsets, count, output);
bulk_subscript_impl<Timestamp>(
this->insert_record_.timestamps_, seg_offsets, count, output);
break;
case SystemFieldType::RowId:
bulk_subscript_impl<int64_t>(this->insert_record_.row_ids_, seg_offsets, count, output);
bulk_subscript_impl<int64_t>(
this->insert_record_.row_ids_, seg_offsets, count, output);
break;
default:
PanicInfo("unknown subscript fields");
@ -311,7 +366,8 @@ SegmentGrowingImpl::bulk_subscript(SystemFieldType system_type,
}
std::vector<SegOffset>
SegmentGrowingImpl::search_ids(const BitsetType& bitset, Timestamp timestamp) const {
SegmentGrowingImpl::search_ids(const BitsetType& bitset,
Timestamp timestamp) const {
std::vector<SegOffset> res_offsets;
for (int i = 0; i < bitset.size(); i++) {
@ -326,7 +382,8 @@ SegmentGrowingImpl::search_ids(const BitsetType& bitset, Timestamp timestamp) co
}
std::vector<SegOffset>
SegmentGrowingImpl::search_ids(const BitsetView& bitset, Timestamp timestamp) const {
SegmentGrowingImpl::search_ids(const BitsetView& bitset,
Timestamp timestamp) const {
std::vector<SegOffset> res_offsets;
for (int i = 0; i < bitset.size(); ++i) {
@ -341,7 +398,8 @@ SegmentGrowingImpl::search_ids(const BitsetView& bitset, Timestamp timestamp) co
}
std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
SegmentGrowingImpl::search_ids(const IdArray& id_array, Timestamp timestamp) const {
SegmentGrowingImpl::search_ids(const IdArray& id_array,
Timestamp timestamp) const {
AssertInfo(id_array.has_int_id(), "Id array doesn't have int_id element");
auto field_id = schema_->get_primary_field_id().value_or(FieldId(-1));
AssertInfo(field_id.get() != -1, "Primary key is -1");
@ -358,11 +416,13 @@ SegmentGrowingImpl::search_ids(const IdArray& id_array, Timestamp timestamp) con
for (auto offset : segOffsets) {
switch (data_type) {
case DataType::INT64: {
res_id_arr->mutable_int_id()->add_data(std::get<int64_t>(pk));
res_id_arr->mutable_int_id()->add_data(
std::get<int64_t>(pk));
break;
}
case DataType::VARCHAR: {
res_id_arr->mutable_str_id()->add_data(std::get<std::string>(pk));
res_id_arr->mutable_str_id()->add_data(
std::get<std::string>(pk));
break;
}
default: {
@ -384,13 +444,17 @@ int64_t
SegmentGrowingImpl::get_active_count(Timestamp ts) const {
auto row_count = this->get_row_count();
auto& ts_vec = this->get_insert_record().timestamps_;
auto iter = std::upper_bound(boost::make_counting_iterator((int64_t)0), boost::make_counting_iterator(row_count),
ts, [&](Timestamp ts, int64_t index) { return ts < ts_vec[index]; });
auto iter = std::upper_bound(
boost::make_counting_iterator((int64_t)0),
boost::make_counting_iterator(row_count),
ts,
[&](Timestamp ts, int64_t index) { return ts < ts_vec[index]; });
return *iter;
}
void
SegmentGrowingImpl::mask_with_timestamps(BitsetType& bitset_chunk, Timestamp timestamp) const {
SegmentGrowingImpl::mask_with_timestamps(BitsetType& bitset_chunk,
Timestamp timestamp) const {
// DO NOTHING
}

View File

@ -53,7 +53,10 @@ class SegmentGrowingImpl : public SegmentGrowing {
// TODO: add id into delete log, possibly bitmap
Status
Delete(int64_t reserverd_offset, int64_t size, const IdArray* pks, const Timestamp* timestamps) override;
Delete(int64_t reserved_offset,
int64_t size,
const IdArray* pks,
const Timestamp* timestamps) override;
int64_t
GetMemoryUsageInBytes() const override;
@ -111,7 +114,8 @@ class SegmentGrowingImpl : public SegmentGrowing {
// deprecated
const index::IndexBase*
chunk_index_impl(FieldId field_id, int64_t chunk_id) const final {
return indexing_record_.get_field_indexing(field_id).get_chunk_indexing(chunk_id);
return indexing_record_.get_field_indexing(field_id).get_chunk_indexing(
chunk_id);
}
int64_t
@ -142,7 +146,10 @@ class SegmentGrowingImpl : public SegmentGrowing {
// for scalar vectors
template <typename T>
void
bulk_subscript_impl(const VectorBase& vec_raw, const int64_t* seg_offsets, int64_t count, void* output_raw) const;
bulk_subscript_impl(const VectorBase& vec_raw,
const int64_t* seg_offsets,
int64_t count,
void* output_raw) const;
template <typename T>
void
@ -153,16 +160,25 @@ class SegmentGrowingImpl : public SegmentGrowing {
void* output_raw) const;
void
bulk_subscript(SystemFieldType system_type, const int64_t* seg_offsets, int64_t count, void* output) const override;
bulk_subscript(SystemFieldType system_type,
const int64_t* seg_offsets,
int64_t count,
void* output) const override;
std::unique_ptr<DataArray>
bulk_subscript(FieldId field_id, const int64_t* seg_offsets, int64_t count) const override;
bulk_subscript(FieldId field_id,
const int64_t* seg_offsets,
int64_t count) const override;
public:
friend std::unique_ptr<SegmentGrowing>
CreateGrowingSegment(SchemaPtr schema, const SegcoreConfig& segcore_config, int64_t segment_id);
CreateGrowingSegment(SchemaPtr schema,
const SegcoreConfig& segcore_config,
int64_t segment_id);
explicit SegmentGrowingImpl(SchemaPtr schema, const SegcoreConfig& segcore_config, int64_t segment_id)
explicit SegmentGrowingImpl(SchemaPtr schema,
const SegcoreConfig& segcore_config,
int64_t segment_id)
: segcore_config_(segcore_config),
schema_(std::move(schema)),
insert_record_(*schema_, segcore_config.get_chunk_rows()),
@ -171,7 +187,8 @@ class SegmentGrowingImpl : public SegmentGrowing {
}
void
mask_with_timestamps(BitsetType& bitset_chunk, Timestamp timestamp) const override;
mask_with_timestamps(BitsetType& bitset_chunk,
Timestamp timestamp) const override;
void
vector_search(SearchInfo& search_info,
@ -183,7 +200,9 @@ class SegmentGrowingImpl : public SegmentGrowing {
public:
void
mask_with_delete(BitsetType& bitset, int64_t ins_barrier, Timestamp timestamp) const override;
mask_with_delete(BitsetType& bitset,
int64_t ins_barrier,
Timestamp timestamp) const override;
std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
search_ids(const IdArray& id_array, Timestamp timestamp) const override;
@ -237,9 +256,10 @@ class SegmentGrowingImpl : public SegmentGrowing {
};
inline SegmentGrowingPtr
CreateGrowingSegment(SchemaPtr schema,
int64_t segment_id = -1,
const SegcoreConfig& conf = SegcoreConfig::default_config()) {
CreateGrowingSegment(
SchemaPtr schema,
int64_t segment_id = -1,
const SegcoreConfig& conf = SegcoreConfig::default_config()) {
return std::make_unique<SegmentGrowingImpl>(schema, conf, segment_id);
}

View File

@ -21,43 +21,51 @@
namespace milvus::segcore {
void
SegmentInternalInterface::FillPrimaryKeys(const query::Plan* plan, SearchResult& results) const {
SegmentInternalInterface::FillPrimaryKeys(const query::Plan* plan,
SearchResult& results) const {
std::shared_lock lck(mutex_);
AssertInfo(plan, "empty plan");
auto size = results.distances_.size();
AssertInfo(results.seg_offsets_.size() == size, "Size of result distances is not equal to size of ids");
AssertInfo(results.seg_offsets_.size() == size,
"Size of result distances is not equal to size of ids");
Assert(results.primary_keys_.size() == 0);
results.primary_keys_.resize(size);
auto pk_field_id_opt = get_schema().get_primary_field_id();
AssertInfo(pk_field_id_opt.has_value(), "Cannot get primary key offset from schema");
AssertInfo(pk_field_id_opt.has_value(),
"Cannot get primary key offset from schema");
auto pk_field_id = pk_field_id_opt.value();
AssertInfo(IsPrimaryKeyDataType(get_schema()[pk_field_id].get_data_type()),
"Primary key field is not INT64 or VARCHAR type");
auto field_data = bulk_subscript(pk_field_id, results.seg_offsets_.data(), size);
auto field_data =
bulk_subscript(pk_field_id, results.seg_offsets_.data(), size);
results.pk_type_ = DataType(field_data->type());
ParsePksFromFieldData(results.primary_keys_, *field_data.get());
}
void
SegmentInternalInterface::FillTargetEntry(const query::Plan* plan, SearchResult& results) const {
SegmentInternalInterface::FillTargetEntry(const query::Plan* plan,
SearchResult& results) const {
std::shared_lock lck(mutex_);
AssertInfo(plan, "empty plan");
auto size = results.distances_.size();
AssertInfo(results.seg_offsets_.size() == size, "Size of result distances is not equal to size of ids");
AssertInfo(results.seg_offsets_.size() == size,
"Size of result distances is not equal to size of ids");
// fill other entries except primary key by result_offset
for (auto field_id : plan->target_entries_) {
auto field_data = bulk_subscript(field_id, results.seg_offsets_.data(), size);
auto field_data =
bulk_subscript(field_id, results.seg_offsets_.data(), size);
results.output_fields_data_[field_id] = std::move(field_data);
}
}
std::unique_ptr<SearchResult>
SegmentInternalInterface::Search(const query::Plan* plan,
const query::PlaceholderGroup* placeholder_group,
Timestamp timestamp) const {
SegmentInternalInterface::Search(
const query::Plan* plan,
const query::PlaceholderGroup* placeholder_group,
Timestamp timestamp) const {
std::shared_lock lck(mutex_);
check_search(plan);
query::ExecPlanNodeVisitor visitor(*this, timestamp, placeholder_group);
@ -68,24 +76,30 @@ SegmentInternalInterface::Search(const query::Plan* plan,
}
std::unique_ptr<proto::segcore::RetrieveResults>
SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan, Timestamp timestamp) const {
SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan,
Timestamp timestamp) const {
std::shared_lock lck(mutex_);
auto results = std::make_unique<proto::segcore::RetrieveResults>();
query::ExecPlanNodeVisitor visitor(*this, timestamp);
auto retrieve_results = visitor.get_retrieve_result(*plan->plan_node_);
retrieve_results.segment_ = (void*)this;
results->mutable_offset()->Add(retrieve_results.result_offsets_.begin(), retrieve_results.result_offsets_.end());
results->mutable_offset()->Add(retrieve_results.result_offsets_.begin(),
retrieve_results.result_offsets_.end());
auto fields_data = results->mutable_fields_data();
auto ids = results->mutable_ids();
auto pk_field_id = plan->schema_.get_primary_field_id();
for (auto field_id : plan->field_ids_) {
if (SystemProperty::Instance().IsSystem(field_id)) {
auto system_type = SystemProperty::Instance().GetSystemFieldType(field_id);
auto system_type =
SystemProperty::Instance().GetSystemFieldType(field_id);
auto size = retrieve_results.result_offsets_.size();
FixedVector<int64_t> output(size);
bulk_subscript(system_type, retrieve_results.result_offsets_.data(), size, output.data());
bulk_subscript(system_type,
retrieve_results.result_offsets_.data(),
size,
output.data());
auto data_array = std::make_unique<DataArray>();
data_array->set_field_id(field_id.get());
@ -102,8 +116,9 @@ SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan, Timestamp ti
auto& field_meta = plan->schema_[field_id];
auto col =
bulk_subscript(field_id, retrieve_results.result_offsets_.data(), retrieve_results.result_offsets_.size());
auto col = bulk_subscript(field_id,
retrieve_results.result_offsets_.data(),
retrieve_results.result_offsets_.size());
auto col_data = col.release();
fields_data->AddAllocated(col_data);
if (pk_field_id.has_value() && pk_field_id.value() == field_id) {
@ -111,7 +126,8 @@ SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan, Timestamp ti
case DataType::INT64: {
auto int_ids = ids->mutable_int_id();
auto src_data = col_data->scalars().long_data();
int_ids->mutable_data()->Add(src_data.data().begin(), src_data.data().end());
int_ids->mutable_data()->Add(src_data.data().begin(),
src_data.data().end());
break;
}
case DataType::VARCHAR: {

View File

@ -48,7 +48,9 @@ class SegmentInterface {
FillTargetEntry(const query::Plan* plan, SearchResult& results) const = 0;
virtual std::unique_ptr<SearchResult>
Search(const query::Plan* Plan, const query::PlaceholderGroup* placeholder_group, Timestamp timestamp) const = 0;
Search(const query::Plan* Plan,
const query::PlaceholderGroup* placeholder_group,
Timestamp timestamp) const = 0;
virtual std::unique_ptr<proto::segcore::RetrieveResults>
Retrieve(const query::RetrievePlan* Plan, Timestamp timestamp) const = 0;
@ -73,7 +75,10 @@ class SegmentInterface {
PreDelete(int64_t size) = 0;
virtual Status
Delete(int64_t reserved_offset, int64_t size, const IdArray* pks, const Timestamp* timestamps) = 0;
Delete(int64_t reserved_offset,
int64_t size,
const IdArray* pks,
const Timestamp* timestamps) = 0;
virtual void
LoadDeletedRecord(const LoadDeletedRecordInfo& info) = 0;
@ -112,13 +117,16 @@ class SegmentInternalInterface : public SegmentInterface {
Timestamp timestamp) const override;
void
FillPrimaryKeys(const query::Plan* plan, SearchResult& results) const override;
FillPrimaryKeys(const query::Plan* plan,
SearchResult& results) const override;
void
FillTargetEntry(const query::Plan* plan, SearchResult& results) const override;
FillTargetEntry(const query::Plan* plan,
SearchResult& results) const override;
std::unique_ptr<proto::segcore::RetrieveResults>
Retrieve(const query::RetrievePlan* plan, Timestamp timestamp) const override;
Retrieve(const query::RetrievePlan* plan,
Timestamp timestamp) const override;
virtual bool
HasIndex(FieldId field_id) const = 0;
@ -142,7 +150,9 @@ class SegmentInternalInterface : public SegmentInterface {
SearchResult& output) const = 0;
virtual void
mask_with_delete(BitsetType& bitset, int64_t ins_barrier, Timestamp timestamp) const = 0;
mask_with_delete(BitsetType& bitset,
int64_t ins_barrier,
Timestamp timestamp) const = 0;
// count of chunk that has index available
virtual int64_t
@ -153,7 +163,8 @@ class SegmentInternalInterface : public SegmentInterface {
num_chunk_data(FieldId field_id) const = 0;
virtual void
mask_with_timestamps(BitsetType& bitset_chunk, Timestamp timestamp) const = 0;
mask_with_timestamps(BitsetType& bitset_chunk,
Timestamp timestamp) const = 0;
// count of chunks
virtual int64_t
@ -186,11 +197,16 @@ class SegmentInternalInterface : public SegmentInterface {
// calculate output[i] = Vec[seg_offsets[i]}, where Vec binds to system_type
virtual void
bulk_subscript(SystemFieldType system_type, const int64_t* seg_offsets, int64_t count, void* output) const = 0;
bulk_subscript(SystemFieldType system_type,
const int64_t* seg_offsets,
int64_t count,
void* output) const = 0;
// calculate output[i] = Vec[seg_offsets[i]}, where Vec binds to field_offset
virtual std::unique_ptr<DataArray>
bulk_subscript(FieldId field_id, const int64_t* seg_offsets, int64_t count) const = 0;
bulk_subscript(FieldId field_id,
const int64_t* seg_offsets,
int64_t count) const = 0;
virtual void
check_search(const query::Plan* plan) const = 0;

View File

@ -66,7 +66,8 @@ SegmentSealedImpl::LoadVecIndex(const LoadIndexInfo& info) {
auto field_id = FieldId(info.field_id);
auto& field_meta = schema_->operator[](field_id);
AssertInfo(info.index_params.count("metric_type"), "Can't get metric_type in index_params");
AssertInfo(info.index_params.count("metric_type"),
"Can't get metric_type in index_params");
auto metric_type = info.index_params.at("metric_type");
auto row_count = info.index->Count();
AssertInfo(row_count > 0, "Index count is 0");
@ -74,17 +75,24 @@ SegmentSealedImpl::LoadVecIndex(const LoadIndexInfo& info) {
std::unique_lock lck(mutex_);
// Don't allow vector raw data and index exist at the same time
AssertInfo(!get_bit(field_data_ready_bitset_, field_id),
"vector index can't be loaded when raw data exists at field " + std::to_string(field_id.get()));
AssertInfo(!get_bit(index_ready_bitset_, field_id),
"vector index has been exist at " + std::to_string(field_id.get()));
"vector index can't be loaded when raw data exists at field " +
std::to_string(field_id.get()));
AssertInfo(
!get_bit(index_ready_bitset_, field_id),
"vector index has been exist at " + std::to_string(field_id.get()));
if (row_count_opt_.has_value()) {
AssertInfo(row_count_opt_.value() == row_count,
"field (" + std::to_string(field_id.get()) + ") data has different row count (" +
std::to_string(row_count) + ") than other column's row count (" +
"field (" + std::to_string(field_id.get()) +
") data has different row count (" +
std::to_string(row_count) +
") than other column's row count (" +
std::to_string(row_count_opt_.value()) + ")");
}
AssertInfo(!vector_indexings_.is_ready(field_id), "vec index is not ready");
vector_indexings_.append_field_indexing(field_id, metric_type, std::move(const_cast<LoadIndexInfo&>(info).index));
vector_indexings_.append_field_indexing(
field_id,
metric_type,
std::move(const_cast<LoadIndexInfo&>(info).index));
set_bit(index_ready_bitset_, field_id, true);
update_row_count(row_count);
@ -103,24 +111,30 @@ SegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) {
std::unique_lock lck(mutex_);
// Don't allow scalar raw data and index exist at the same time
AssertInfo(!get_bit(field_data_ready_bitset_, field_id),
"scalar index can't be loaded when raw data exists at field " + std::to_string(field_id.get()));
AssertInfo(!get_bit(index_ready_bitset_, field_id),
"scalar index has been exist at " + std::to_string(field_id.get()));
"scalar index can't be loaded when raw data exists at field " +
std::to_string(field_id.get()));
AssertInfo(
!get_bit(index_ready_bitset_, field_id),
"scalar index has been exist at " + std::to_string(field_id.get()));
if (row_count_opt_.has_value()) {
AssertInfo(row_count_opt_.value() == row_count,
"field (" + std::to_string(field_id.get()) + ") data has different row count (" +
std::to_string(row_count) + ") than other column's row count (" +
"field (" + std::to_string(field_id.get()) +
") data has different row count (" +
std::to_string(row_count) +
") than other column's row count (" +
std::to_string(row_count_opt_.value()) + ")");
}
scalar_indexings_[field_id] = std::move(const_cast<LoadIndexInfo&>(info).index);
scalar_indexings_[field_id] =
std::move(const_cast<LoadIndexInfo&>(info).index);
// reverse pk from scalar index and set pks to offset
if (schema_->get_primary_field_id() == field_id) {
AssertInfo(field_id.get() != -1, "Primary key is -1");
AssertInfo(insert_record_.empty_pks(), "already exists");
switch (field_meta.get_data_type()) {
case DataType::INT64: {
auto int64_index = dynamic_cast<index::ScalarIndex<int64_t>*>(scalar_indexings_[field_id].get());
auto int64_index = dynamic_cast<index::ScalarIndex<int64_t>*>(
scalar_indexings_[field_id].get());
for (int i = 0; i < row_count; ++i) {
insert_record_.insert_pk(int64_index->Reverse_Lookup(i), i);
}
@ -128,9 +142,12 @@ SegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) {
break;
}
case DataType::VARCHAR: {
auto string_index = dynamic_cast<index::ScalarIndex<std::string>*>(scalar_indexings_[field_id].get());
auto string_index =
dynamic_cast<index::ScalarIndex<std::string>*>(
scalar_indexings_[field_id].get());
for (int i = 0; i < row_count; ++i) {
insert_record_.insert_pk(string_index->Reverse_Lookup(i), i);
insert_record_.insert_pk(string_index->Reverse_Lookup(i),
i);
}
insert_record_.seal_pks();
break;
@ -155,15 +172,21 @@ SegmentSealedImpl::LoadFieldData(const LoadFieldDataInfo& info) {
AssertInfo(info.field_data != nullptr, "Field info blob is null");
auto size = info.row_count;
if (row_count_opt_.has_value()) {
AssertInfo(row_count_opt_.value() == size,
fmt::format("field {} has different row count {} to other column's {}", field_id.get(), size,
row_count_opt_.value()));
AssertInfo(
row_count_opt_.value() == size,
fmt::format(
"field {} has different row count {} to other column's {}",
field_id.get(),
size,
row_count_opt_.value()));
}
if (SystemProperty::Instance().IsSystem(field_id)) {
auto system_field_type = SystemProperty::Instance().GetSystemFieldType(field_id);
auto system_field_type =
SystemProperty::Instance().GetSystemFieldType(field_id);
if (system_field_type == SystemFieldType::Timestamp) {
auto timestamps = reinterpret_cast<const Timestamp*>(info.field_data->scalars().long_data().data().data());
auto timestamps = reinterpret_cast<const Timestamp*>(
info.field_data->scalars().long_data().data().data());
TimestampIndex index;
auto min_slice_length = size < 4096 ? 1 : 4096;
@ -176,15 +199,19 @@ SegmentSealedImpl::LoadFieldData(const LoadFieldDataInfo& info) {
AssertInfo(insert_record_.timestamps_.empty(), "already exists");
insert_record_.timestamps_.fill_chunk_data(timestamps, size);
insert_record_.timestamp_index_ = std::move(index);
AssertInfo(insert_record_.timestamps_.num_chunk() == 1, "num chunk not equal to 1 for sealed segment");
AssertInfo(insert_record_.timestamps_.num_chunk() == 1,
"num chunk not equal to 1 for sealed segment");
} else {
AssertInfo(system_field_type == SystemFieldType::RowId, "System field type of id column is not RowId");
auto row_ids = reinterpret_cast<const idx_t*>(info.field_data->scalars().long_data().data().data());
AssertInfo(system_field_type == SystemFieldType::RowId,
"System field type of id column is not RowId");
auto row_ids = reinterpret_cast<const idx_t*>(
info.field_data->scalars().long_data().data().data());
// write data under lock
std::unique_lock lck(mutex_);
AssertInfo(insert_record_.row_ids_.empty(), "already exists");
insert_record_.row_ids_.fill_chunk_data(row_ids, size);
AssertInfo(insert_record_.row_ids_.num_chunk() == 1, "num chunk not equal to 1 for sealed segment");
AssertInfo(insert_record_.row_ids_.num_chunk() == 1,
"num chunk not equal to 1 for sealed segment");
}
++system_ready_count_;
} else {
@ -198,7 +225,8 @@ SegmentSealedImpl::LoadFieldData(const LoadFieldDataInfo& info) {
std::unique_lock lck(mutex_);
// Don't allow raw data and index exist at the same time
AssertInfo(!get_bit(index_ready_bitset_, field_id), "field data can't be loaded when indexing exists");
AssertInfo(!get_bit(index_ready_bitset_, field_id),
"field data can't be loaded when indexing exists");
void* field_data = nullptr;
size_t size = 0;
@ -250,7 +278,8 @@ SegmentSealedImpl::LoadDeletedRecord(const LoadDeletedRecordInfo& info) {
auto reserved_begin = deleted_record_.reserved.fetch_add(size);
deleted_record_.pks_.set_data_raw(reserved_begin, pks.data(), size);
deleted_record_.timestamps_.set_data_raw(reserved_begin, timestamps, size);
deleted_record_.ack_responder_.AddSegment(reserved_begin, reserved_begin + size);
deleted_record_.ack_responder_.AddSegment(reserved_begin,
reserved_begin + size);
}
// internal API: support scalar index only
@ -290,19 +319,24 @@ SegmentSealedImpl::chunk_data_impl(FieldId field_id, int64_t chunk_id) const {
auto field_data = it->second;
return SpanBase(field_data, get_row_count(), element_sizeof);
}
if (auto it = variable_fields_.find(field_id); it != variable_fields_.end()) {
if (auto it = variable_fields_.find(field_id);
it != variable_fields_.end()) {
auto& field = it->second;
return SpanBase(field.views().data(), field.views().size(), sizeof(std::string_view));
return SpanBase(field.views().data(),
field.views().size(),
sizeof(std::string_view));
}
auto field_data = insert_record_.get_field_data_base(field_id);
AssertInfo(field_data->num_chunk() == 1, "num chunk not equal to 1 for sealed segment");
AssertInfo(field_data->num_chunk() == 1,
"num chunk not equal to 1 for sealed segment");
return field_data->get_span_base(0);
}
const index::IndexBase*
SegmentSealedImpl::chunk_index_impl(FieldId field_id, int64_t chunk_id) const {
AssertInfo(scalar_indexings_.find(field_id) != scalar_indexings_.end(),
"Cannot find scalar_indexing with field_id: " + std::to_string(field_id.get()));
"Cannot find scalar_indexing with field_id: " +
std::to_string(field_id.get()));
auto ptr = scalar_indexings_.at(field_id).get();
return ptr;
}
@ -333,17 +367,21 @@ SegmentSealedImpl::get_schema() const {
}
void
SegmentSealedImpl::mask_with_delete(BitsetType& bitset, int64_t ins_barrier, Timestamp timestamp) const {
SegmentSealedImpl::mask_with_delete(BitsetType& bitset,
int64_t ins_barrier,
Timestamp timestamp) const {
auto del_barrier = get_barrier(get_deleted_record(), timestamp);
if (del_barrier == 0) {
return;
}
auto bitmap_holder = get_deleted_bitmap(del_barrier, ins_barrier, deleted_record_, insert_record_, timestamp);
auto bitmap_holder = get_deleted_bitmap(
del_barrier, ins_barrier, deleted_record_, insert_record_, timestamp);
if (!bitmap_holder || !bitmap_holder->bitmap_ptr) {
return;
}
auto& delete_bitset = *bitmap_holder->bitmap_ptr;
AssertInfo(delete_bitset.size() == bitset.size(), "Deleted bitmap size not equal to filtered bitmap size");
AssertInfo(delete_bitset.size() == bitset.size(),
"Deleted bitmap size not equal to filtered bitmap size");
bitset |= delete_bitset;
}
@ -358,25 +396,42 @@ SegmentSealedImpl::vector_search(SearchInfo& search_info,
auto field_id = search_info.field_id_;
auto& field_meta = schema_->operator[](field_id);
AssertInfo(field_meta.is_vector(), "The meta type of vector field is not vector type");
AssertInfo(field_meta.is_vector(),
"The meta type of vector field is not vector type");
if (get_bit(index_ready_bitset_, field_id)) {
AssertInfo(vector_indexings_.is_ready(field_id),
"vector indexes isn't ready for field " + std::to_string(field_id.get()));
query::SearchOnSealedIndex(*schema_, vector_indexings_, search_info, query_data, query_count, bitset, output);
"vector indexes isn't ready for field " +
std::to_string(field_id.get()));
query::SearchOnSealedIndex(*schema_,
vector_indexings_,
search_info,
query_data,
query_count,
bitset,
output);
} else {
AssertInfo(get_bit(field_data_ready_bitset_, field_id),
"Field Data is not loaded: " + std::to_string(field_id.get()));
AssertInfo(
get_bit(field_data_ready_bitset_, field_id),
"Field Data is not loaded: " + std::to_string(field_id.get()));
AssertInfo(row_count_opt_.has_value(), "Can't get row count value");
auto row_count = row_count_opt_.value();
auto vec_data = fixed_fields_.at(field_id);
query::SearchOnSealed(*schema_, vec_data, search_info, query_data, query_count, row_count, bitset, output);
query::SearchOnSealed(*schema_,
vec_data,
search_info,
query_data,
query_count,
row_count,
bitset,
output);
}
}
void
SegmentSealedImpl::DropFieldData(const FieldId field_id) {
if (SystemProperty::Instance().IsSystem(field_id)) {
auto system_field_type = SystemProperty::Instance().GetSystemFieldType(field_id);
auto system_field_type =
SystemProperty::Instance().GetSystemFieldType(field_id);
std::unique_lock lck(mutex_);
--system_ready_count_;
@ -398,10 +453,12 @@ SegmentSealedImpl::DropFieldData(const FieldId field_id) {
void
SegmentSealedImpl::DropIndex(const FieldId field_id) {
AssertInfo(!SystemProperty::Instance().IsSystem(field_id),
"Field id:" + std::to_string(field_id.get()) + " isn't one of system type when drop index");
"Field id:" + std::to_string(field_id.get()) +
" isn't one of system type when drop index");
auto& field_meta = schema_->operator[](field_id);
AssertInfo(field_meta.is_vector(),
"Field meta of offset:" + std::to_string(field_id.get()) + " is not vector type");
"Field meta of offset:" + std::to_string(field_id.get()) +
" is not vector type");
std::unique_lock lck(mutex_);
vector_indexings_.drop_field_indexing(field_id);
@ -411,23 +468,29 @@ SegmentSealedImpl::DropIndex(const FieldId field_id) {
void
SegmentSealedImpl::check_search(const query::Plan* plan) const {
AssertInfo(plan, "Search plan is null");
AssertInfo(plan->extra_info_opt_.has_value(), "Extra info of search plan doesn't have value");
AssertInfo(plan->extra_info_opt_.has_value(),
"Extra info of search plan doesn't have value");
if (!is_system_field_ready()) {
PanicInfo("failed to load row ID or timestamp, potential missing bin logs or empty segments. Segment ID = " +
std::to_string(this->id_));
PanicInfo(
"failed to load row ID or timestamp, potential missing bin logs or "
"empty segments. Segment ID = " +
std::to_string(this->id_));
}
auto& request_fields = plan->extra_info_opt_.value().involved_fields_;
auto field_ready_bitset = field_data_ready_bitset_ | index_ready_bitset_;
AssertInfo(request_fields.size() == field_ready_bitset.size(),
"Request fields size not equal to field ready bitset size when check search");
"Request fields size not equal to field ready bitset size when "
"check search");
auto absent_fields = request_fields - field_ready_bitset;
if (absent_fields.any()) {
auto field_id = FieldId(absent_fields.find_first() + START_USER_FIELDID);
auto field_id =
FieldId(absent_fields.find_first() + START_USER_FIELDID);
auto& field_meta = schema_->operator[](field_id);
PanicInfo("User Field(" + field_meta.get_name().get() + ") is not loaded");
PanicInfo("User Field(" + field_meta.get_name().get() +
") is not loaded");
}
}
@ -445,18 +508,27 @@ SegmentSealedImpl::bulk_subscript(SystemFieldType system_type,
const int64_t* seg_offsets,
int64_t count,
void* output) const {
AssertInfo(is_system_field_ready(), "System field isn't ready when do bulk_insert");
AssertInfo(is_system_field_ready(),
"System field isn't ready when do bulk_insert");
switch (system_type) {
case SystemFieldType::Timestamp:
AssertInfo(insert_record_.timestamps_.num_chunk() == 1,
"num chunk of timestamp not equal to 1 for sealed segment");
bulk_subscript_impl<Timestamp>(this->insert_record_.timestamps_.get_chunk_data(0), seg_offsets, count,
output);
AssertInfo(
insert_record_.timestamps_.num_chunk() == 1,
"num chunk of timestamp not equal to 1 for sealed segment");
bulk_subscript_impl<Timestamp>(
this->insert_record_.timestamps_.get_chunk_data(0),
seg_offsets,
count,
output);
break;
case SystemFieldType::RowId:
AssertInfo(insert_record_.row_ids_.num_chunk() == 1,
"num chunk of rowID not equal to 1 for sealed segment");
bulk_subscript_impl<int64_t>(this->insert_record_.row_ids_.get_chunk_data(0), seg_offsets, count, output);
bulk_subscript_impl<int64_t>(
this->insert_record_.row_ids_.get_chunk_data(0),
seg_offsets,
count,
output);
break;
default:
PanicInfo("unknown subscript fields");
@ -465,7 +537,10 @@ SegmentSealedImpl::bulk_subscript(SystemFieldType system_type,
template <typename T>
void
SegmentSealedImpl::bulk_subscript_impl(const void* src_raw, const int64_t* seg_offsets, int64_t count, void* dst_raw) {
SegmentSealedImpl::bulk_subscript_impl(const void* src_raw,
const int64_t* seg_offsets,
int64_t count,
void* dst_raw) {
static_assert(IsScalar<T>);
auto src = reinterpret_cast<const T*>(src_raw);
auto dst = reinterpret_cast<T*>(dst_raw);
@ -495,14 +570,19 @@ SegmentSealedImpl::bulk_subscript_impl(const VariableField& field,
// for vector
void
SegmentSealedImpl::bulk_subscript_impl(
int64_t element_sizeof, const void* src_raw, const int64_t* seg_offsets, int64_t count, void* dst_raw) {
SegmentSealedImpl::bulk_subscript_impl(int64_t element_sizeof,
const void* src_raw,
const int64_t* seg_offsets,
int64_t count,
void* dst_raw) {
auto src_vec = reinterpret_cast<const char*>(src_raw);
auto dst_vec = reinterpret_cast<char*>(dst_raw);
for (int64_t i = 0; i < count; ++i) {
auto offset = seg_offsets[i];
auto dst = dst_vec + i * element_sizeof;
const char* src = (offset == INVALID_SEG_OFFSET ? nullptr : (src_vec + element_sizeof * offset));
const char* src = (offset == INVALID_SEG_OFFSET
? nullptr
: (src_vec + element_sizeof * offset));
if (!src) {
continue;
}
@ -520,7 +600,9 @@ SegmentSealedImpl::fill_with_empty(FieldId field_id, int64_t count) const {
}
std::unique_ptr<DataArray>
SegmentSealedImpl::bulk_subscript(FieldId field_id, const int64_t* seg_offsets, int64_t count) const {
SegmentSealedImpl::bulk_subscript(FieldId field_id,
const int64_t* seg_offsets,
int64_t count) const {
auto& field_meta = schema_->operator[](field_id);
// if count == 0, return empty data array
if (count == 0) {
@ -530,7 +612,8 @@ SegmentSealedImpl::bulk_subscript(FieldId field_id, const int64_t* seg_offsets,
if (HasIndex(field_id)) {
// if field has load scalar index, reverse raw data from index
if (!datatype_is_vector(field_meta.get_data_type())) {
AssertInfo(num_chunk() == 1, "num chunk not equal to 1 for sealed segment");
AssertInfo(num_chunk() == 1,
"num chunk not equal to 1 for sealed segment");
auto index = chunk_index_impl(field_id, 0);
return ReverseDataFromIndex(index, seg_offsets, count, field_meta);
}
@ -547,12 +630,18 @@ SegmentSealedImpl::bulk_subscript(FieldId field_id, const int64_t* seg_offsets,
case DataType::VARCHAR:
case DataType::STRING: {
FixedVector<std::string> output(count);
bulk_subscript_impl<std::string>(variable_fields_.at(field_id), seg_offsets, count, output.data());
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
bulk_subscript_impl<std::string>(variable_fields_.at(field_id),
seg_offsets,
count,
output.data());
return CreateScalarDataArrayFrom(
output.data(), count, field_meta);
}
default:
PanicInfo(fmt::format("unsupported data type: {}", datatype_name(field_meta.get_data_type())));
PanicInfo(
fmt::format("unsupported data type: {}",
datatype_name(field_meta.get_data_type())));
}
}
@ -560,44 +649,55 @@ SegmentSealedImpl::bulk_subscript(FieldId field_id, const int64_t* seg_offsets,
switch (field_meta.get_data_type()) {
case DataType::BOOL: {
FixedVector<bool> output(count);
bulk_subscript_impl<bool>(src_vec, seg_offsets, count, output.data());
bulk_subscript_impl<bool>(
src_vec, seg_offsets, count, output.data());
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
}
case DataType::INT8: {
FixedVector<int8_t> output(count);
bulk_subscript_impl<int8_t>(src_vec, seg_offsets, count, output.data());
bulk_subscript_impl<int8_t>(
src_vec, seg_offsets, count, output.data());
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
}
case DataType::INT16: {
FixedVector<int16_t> output(count);
bulk_subscript_impl<int16_t>(src_vec, seg_offsets, count, output.data());
bulk_subscript_impl<int16_t>(
src_vec, seg_offsets, count, output.data());
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
}
case DataType::INT32: {
FixedVector<int32_t> output(count);
bulk_subscript_impl<int32_t>(src_vec, seg_offsets, count, output.data());
bulk_subscript_impl<int32_t>(
src_vec, seg_offsets, count, output.data());
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
}
case DataType::INT64: {
FixedVector<int64_t> output(count);
bulk_subscript_impl<int64_t>(src_vec, seg_offsets, count, output.data());
bulk_subscript_impl<int64_t>(
src_vec, seg_offsets, count, output.data());
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
}
case DataType::FLOAT: {
FixedVector<float> output(count);
bulk_subscript_impl<float>(src_vec, seg_offsets, count, output.data());
bulk_subscript_impl<float>(
src_vec, seg_offsets, count, output.data());
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
}
case DataType::DOUBLE: {
FixedVector<double> output(count);
bulk_subscript_impl<double>(src_vec, seg_offsets, count, output.data());
bulk_subscript_impl<double>(
src_vec, seg_offsets, count, output.data());
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
}
case DataType::VECTOR_FLOAT:
case DataType::VECTOR_BINARY: {
aligned_vector<char> output(field_meta.get_sizeof() * count);
bulk_subscript_impl(field_meta.get_sizeof(), src_vec, seg_offsets, count, output.data());
bulk_subscript_impl(field_meta.get_sizeof(),
src_vec,
seg_offsets,
count,
output.data());
return CreateVectorDataArrayFrom(output.data(), count, field_meta);
}
@ -624,7 +724,8 @@ SegmentSealedImpl::HasFieldData(FieldId field_id) const {
}
std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
SegmentSealedImpl::search_ids(const IdArray& id_array, Timestamp timestamp) const {
SegmentSealedImpl::search_ids(const IdArray& id_array,
Timestamp timestamp) const {
AssertInfo(id_array.has_int_id(), "Id array doesn't have int_id element");
auto field_id = schema_->get_primary_field_id().value_or(FieldId(-1));
AssertInfo(field_id.get() != -1, "Primary key is -1");
@ -641,11 +742,13 @@ SegmentSealedImpl::search_ids(const IdArray& id_array, Timestamp timestamp) cons
for (auto offset : segOffsets) {
switch (data_type) {
case DataType::INT64: {
res_id_arr->mutable_int_id()->add_data(std::get<int64_t>(pk));
res_id_arr->mutable_int_id()->add_data(
std::get<int64_t>(pk));
break;
}
case DataType::VARCHAR: {
res_id_arr->mutable_str_id()->add_data(std::get<std::string>(pk));
res_id_arr->mutable_str_id()->add_data(
std::get<std::string>(pk));
break;
}
default: {
@ -659,7 +762,10 @@ SegmentSealedImpl::search_ids(const IdArray& id_array, Timestamp timestamp) cons
}
Status
SegmentSealedImpl::Delete(int64_t reserved_offset, int64_t size, const IdArray* ids, const Timestamp* timestamps_raw) {
SegmentSealedImpl::Delete(int64_t reserved_offset,
int64_t size,
const IdArray* ids,
const Timestamp* timestamps_raw) {
auto field_id = schema_->get_primary_field_id().value_or(FieldId(-1));
AssertInfo(field_id.get() != -1, "Primary key is -1");
auto& field_meta = schema_->operator[](field_id);
@ -680,14 +786,17 @@ SegmentSealedImpl::Delete(int64_t reserved_offset, int64_t size, const IdArray*
sort_timestamps[i] = t;
sort_pks[i] = pk;
}
deleted_record_.timestamps_.set_data_raw(reserved_offset, sort_timestamps.data(), size);
deleted_record_.timestamps_.set_data_raw(
reserved_offset, sort_timestamps.data(), size);
deleted_record_.pks_.set_data_raw(reserved_offset, sort_pks.data(), size);
deleted_record_.ack_responder_.AddSegment(reserved_offset, reserved_offset + size);
deleted_record_.ack_responder_.AddSegment(reserved_offset,
reserved_offset + size);
return Status::OK();
}
std::vector<SegOffset>
SegmentSealedImpl::search_ids(const BitsetType& bitset, Timestamp timestamp) const {
SegmentSealedImpl::search_ids(const BitsetType& bitset,
Timestamp timestamp) const {
std::vector<SegOffset> dst_offset;
for (int i = 0; i < bitset.size(); i++) {
if (bitset[i]) {
@ -701,7 +810,8 @@ SegmentSealedImpl::search_ids(const BitsetType& bitset, Timestamp timestamp) con
}
std::vector<SegOffset>
SegmentSealedImpl::search_ids(const BitsetView& bitset, Timestamp timestamp) const {
SegmentSealedImpl::search_ids(const BitsetView& bitset,
Timestamp timestamp) const {
std::vector<SegOffset> dst_offset;
for (int i = 0; i < bitset.size(); i++) {
if (!bitset.test(i)) {
@ -723,7 +833,8 @@ SegmentSealedImpl::debug() const {
}
void
SegmentSealedImpl::LoadSegmentMeta(const proto::segcore::LoadSegmentMeta& segment_meta) {
SegmentSealedImpl::LoadSegmentMeta(
const proto::segcore::LoadSegmentMeta& segment_meta) {
std::unique_lock lck(mutex_);
std::vector<int64_t> slice_lengths;
for (auto& info : segment_meta.metas()) {
@ -740,11 +851,14 @@ SegmentSealedImpl::get_active_count(Timestamp ts) const {
}
void
SegmentSealedImpl::mask_with_timestamps(BitsetType& bitset_chunk, Timestamp timestamp) const {
SegmentSealedImpl::mask_with_timestamps(BitsetType& bitset_chunk,
Timestamp timestamp) const {
// TODO change the
AssertInfo(insert_record_.timestamps_.num_chunk() == 1, "num chunk not equal to 1 for sealed segment");
AssertInfo(insert_record_.timestamps_.num_chunk() == 1,
"num chunk not equal to 1 for sealed segment");
const auto& timestamps_data = insert_record_.timestamps_.get_chunk(0);
AssertInfo(timestamps_data.size() == get_row_count(), "Timestamp size not equal to row count");
AssertInfo(timestamps_data.size() == get_row_count(),
"Timestamp size not equal to row count");
auto range = insert_record_.timestamp_index_.get_active_range(timestamp);
// range == (size_, size_) and size_ is this->timestamps_.size().
@ -760,7 +874,8 @@ SegmentSealedImpl::mask_with_timestamps(BitsetType& bitset_chunk, Timestamp time
bitset_chunk.set();
return;
}
auto mask = TimestampIndex::GenerateBitset(timestamp, range, timestamps_data.data(), timestamps_data.size());
auto mask = TimestampIndex::GenerateBitset(
timestamp, range, timestamps_data.data(), timestamps_data.size());
bitset_chunk |= mask;
}

Some files were not shown because too many files have changed in this diff Show More