mirror of https://github.com/milvus-io/milvus.git
parent
fa86de530d
commit
bdd6bc7695
|
@ -18,7 +18,7 @@
|
|||
# Below is copied from milvus project
|
||||
BasedOnStyle: Google
|
||||
DerivePointerAlignment: false
|
||||
ColumnLimit: 120
|
||||
ColumnLimit: 80
|
||||
IndentWidth: 4
|
||||
AccessModifierOffset: -3
|
||||
AlwaysBreakAfterReturnType: All
|
||||
|
@ -28,7 +28,9 @@ AllowShortIfStatementsOnASingleLine: false
|
|||
AlignTrailingComments: true
|
||||
|
||||
# Appended Options
|
||||
SortIncludes: false
|
||||
SortIncludes: false
|
||||
Standard: Latest
|
||||
AlignAfterOpenBracket: Align
|
||||
BinPackParameters: false
|
||||
BinPackArguments: false
|
||||
ReflowComments: false
|
|
@ -32,14 +32,17 @@ class BitsetView : public knowhere::BitsetView {
|
|||
BitsetView() = default;
|
||||
~BitsetView() = default;
|
||||
|
||||
BitsetView(const std::nullptr_t value) : knowhere::BitsetView(value) { // NOLINT
|
||||
BitsetView(const std::nullptr_t value)
|
||||
: knowhere::BitsetView(value) { // NOLINT
|
||||
}
|
||||
|
||||
BitsetView(const uint8_t* data, size_t num_bits) : knowhere::BitsetView(data, num_bits) { // NOLINT
|
||||
BitsetView(const uint8_t* data, size_t num_bits)
|
||||
: knowhere::BitsetView(data, num_bits) { // NOLINT
|
||||
}
|
||||
|
||||
BitsetView(const BitsetType& bitset) // NOLINT
|
||||
: BitsetView((uint8_t*)boost_ext::get_data(bitset), size_t(bitset.size())) {
|
||||
: BitsetView((uint8_t*)boost_ext::get_data(bitset),
|
||||
size_t(bitset.size())) {
|
||||
}
|
||||
|
||||
BitsetView(const BitsetTypePtr& bitset_ptr) { // NOLINT
|
||||
|
@ -56,7 +59,11 @@ class BitsetView : public knowhere::BitsetView {
|
|||
|
||||
AssertInfo((offset & 0x7) == 0, "offset is not divisible by 8");
|
||||
AssertInfo(offset + size <= this->size(),
|
||||
fmt::format("index out of range, offset={}, size={}, bitset.size={}", offset, size, this->size()));
|
||||
fmt::format(
|
||||
"index out of range, offset={}, size={}, bitset.size={}",
|
||||
offset,
|
||||
size,
|
||||
this->size()));
|
||||
return {data() + (offset >> 3), size};
|
||||
}
|
||||
};
|
||||
|
|
|
@ -16,7 +16,9 @@
|
|||
|
||||
namespace milvus {
|
||||
|
||||
template <typename T, typename = std::enable_if_t<std::is_fundamental_v<T> || std::is_same_v<T, std::string>>>
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_fundamental_v<T> ||
|
||||
std::is_same_v<T, std::string>>>
|
||||
inline CDataType
|
||||
GetDType() {
|
||||
return None;
|
||||
|
|
|
@ -26,13 +26,15 @@ int cpu_num = DEFAULT_CPU_NUM;
|
|||
void
|
||||
SetIndexSliceSize(const int64_t size) {
|
||||
index_file_slice_size = size;
|
||||
LOG_SEGCORE_DEBUG_ << "set config index slice size: " << index_file_slice_size;
|
||||
LOG_SEGCORE_DEBUG_ << "set config index slice size: "
|
||||
<< index_file_slice_size;
|
||||
}
|
||||
|
||||
void
|
||||
SetThreadCoreCoefficient(const int64_t coefficient) {
|
||||
thread_core_coefficient = coefficient;
|
||||
LOG_SEGCORE_DEBUG_ << "set thread pool core coefficient: " << thread_core_coefficient;
|
||||
LOG_SEGCORE_DEBUG_ << "set thread pool core coefficient: "
|
||||
<< thread_core_coefficient;
|
||||
}
|
||||
|
||||
void
|
||||
|
|
|
@ -81,7 +81,8 @@ datatype_name(DataType data_type) {
|
|||
return "vector_binary";
|
||||
}
|
||||
default: {
|
||||
auto err_msg = "Unsupported DataType(" + std::to_string((int)data_type) + ")";
|
||||
auto err_msg =
|
||||
"Unsupported DataType(" + std::to_string((int)data_type) + ")";
|
||||
PanicInfo(err_msg);
|
||||
}
|
||||
}
|
||||
|
@ -89,7 +90,8 @@ datatype_name(DataType data_type) {
|
|||
|
||||
inline bool
|
||||
datatype_is_vector(DataType datatype) {
|
||||
return datatype == DataType::VECTOR_BINARY || datatype == DataType::VECTOR_FLOAT;
|
||||
return datatype == DataType::VECTOR_BINARY ||
|
||||
datatype == DataType::VECTOR_FLOAT;
|
||||
}
|
||||
|
||||
inline bool
|
||||
|
@ -148,25 +150,39 @@ class FieldMeta {
|
|||
FieldMeta&
|
||||
operator=(FieldMeta&&) = default;
|
||||
|
||||
FieldMeta(const FieldName& name, FieldId id, DataType type) : name_(name), id_(id), type_(type) {
|
||||
FieldMeta(const FieldName& name, FieldId id, DataType type)
|
||||
: name_(name), id_(id), type_(type) {
|
||||
Assert(!is_vector());
|
||||
}
|
||||
|
||||
FieldMeta(const FieldName& name, FieldId id, DataType type, int64_t max_length)
|
||||
: name_(name), id_(id), type_(type), string_info_(StringInfo{max_length}) {
|
||||
FieldMeta(const FieldName& name,
|
||||
FieldId id,
|
||||
DataType type,
|
||||
int64_t max_length)
|
||||
: name_(name),
|
||||
id_(id),
|
||||
type_(type),
|
||||
string_info_(StringInfo{max_length}) {
|
||||
Assert(is_string());
|
||||
}
|
||||
|
||||
FieldMeta(
|
||||
const FieldName& name, FieldId id, DataType type, int64_t dim, std::optional<knowhere::MetricType> metric_type)
|
||||
: name_(name), id_(id), type_(type), vector_info_(VectorInfo{dim, metric_type}) {
|
||||
FieldMeta(const FieldName& name,
|
||||
FieldId id,
|
||||
DataType type,
|
||||
int64_t dim,
|
||||
std::optional<knowhere::MetricType> metric_type)
|
||||
: name_(name),
|
||||
id_(id),
|
||||
type_(type),
|
||||
vector_info_(VectorInfo{dim, metric_type}) {
|
||||
Assert(is_vector());
|
||||
}
|
||||
|
||||
bool
|
||||
is_vector() const {
|
||||
Assert(type_ != DataType::NONE);
|
||||
return type_ == DataType::VECTOR_BINARY || type_ == DataType::VECTOR_FLOAT;
|
||||
return type_ == DataType::VECTOR_BINARY ||
|
||||
type_ == DataType::VECTOR_FLOAT;
|
||||
}
|
||||
|
||||
bool
|
||||
|
|
|
@ -38,7 +38,8 @@ struct SearchResult {
|
|||
if (topk_per_nq_prefix_sum_.empty()) {
|
||||
return 0;
|
||||
}
|
||||
AssertInfo(topk_per_nq_prefix_sum_.size() == total_nq_ + 1, "wrong topk_per_nq_prefix_sum_ size");
|
||||
AssertInfo(topk_per_nq_prefix_sum_.size() == total_nq_ + 1,
|
||||
"wrong topk_per_nq_prefix_sum_ size");
|
||||
return topk_per_nq_prefix_sum_[total_nq_];
|
||||
}
|
||||
|
||||
|
|
|
@ -20,11 +20,14 @@ namespace {
|
|||
using ResultPair = std::pair<float, int64_t>;
|
||||
}
|
||||
DatasetPtr
|
||||
SortRangeSearchResult(DatasetPtr data_set, int64_t topk, int64_t nq, std::string metric_type) {
|
||||
SortRangeSearchResult(DatasetPtr data_set,
|
||||
int64_t topk,
|
||||
int64_t nq,
|
||||
std::string metric_type) {
|
||||
/**
|
||||
* nq: number of querys;
|
||||
* lims: the size of lims is nq + 1, lims[i+1] - lims[i] refers to the size of RangeSearch result querys[i]
|
||||
* for example, the nq is 5. In the seleted range,
|
||||
* nq: number of queries;
|
||||
* lims: the size of lims is nq + 1, lims[i+1] - lims[i] refers to the size of RangeSearch result queries[i]
|
||||
* for example, the nq is 5. In the selected range,
|
||||
* the size of RangeSearch result for each nq is [1, 2, 3, 4, 5],
|
||||
* the lims will be [0, 1, 3, 6, 10, 15];
|
||||
* ids: the size of ids is lim[nq],
|
||||
|
@ -32,13 +35,13 @@ SortRangeSearchResult(DatasetPtr data_set, int64_t topk, int64_t nq, std::string
|
|||
* i(1,0), i(1,1), …, i(1,k1-1),
|
||||
* …,
|
||||
* i(n-1,0), i(n-1,1), …, i(n-1,kn-1)},
|
||||
* i(0,0), i(0,1), …, i(0,k0-1) means the ids of RangeSearch result querys[0], k0 equals lim[1] - lim[0];
|
||||
* i(0,0), i(0,1), …, i(0,k0-1) means the ids of RangeSearch result queries[0], k0 equals lim[1] - lim[0];
|
||||
* dist: the size of ids is lim[nq],
|
||||
* { d(0,0), d(0,1), …, d(0,k0-1),
|
||||
* d(1,0), d(1,1), …, d(1,k1-1),
|
||||
* …,
|
||||
* d(n-1,0), d(n-1,1), …, d(n-1,kn-1)},
|
||||
* d(0,0), d(0,1), …, d(0,k0-1) means the distances of RangeSearch result querys[0], k0 equals lim[1] - lim[0];
|
||||
* d(0,0), d(0,1), …, d(0,k0-1) means the distances of RangeSearch result queries[0], k0 equals lim[1] - lim[0];
|
||||
*/
|
||||
auto lims = GetDatasetLims(data_set);
|
||||
auto id = GetDatasetIDs(data_set);
|
||||
|
@ -65,11 +68,14 @@ SortRangeSearchResult(DatasetPtr data_set, int64_t topk, int64_t nq, std::string
|
|||
* |------------+---------------| max_heap ascending_order
|
||||
*
|
||||
*/
|
||||
std::function<bool(const ResultPair&, const ResultPair&)> cmp = std::less<std::pair<float, int64_t>>();
|
||||
std::function<bool(const ResultPair&, const ResultPair&)> cmp =
|
||||
std::less<std::pair<float, int64_t>>();
|
||||
if (IsMetricType(metric_type, knowhere::metric::IP)) {
|
||||
cmp = std::greater<std::pair<float, int64_t>>();
|
||||
}
|
||||
std::priority_queue<std::pair<float, int64_t>, std::vector<std::pair<float, int64_t>>, decltype(cmp)>
|
||||
std::priority_queue<std::pair<float, int64_t>,
|
||||
std::vector<std::pair<float, int64_t>>,
|
||||
decltype(cmp)>
|
||||
sub_result(cmp);
|
||||
|
||||
for (int j = lims[i]; j < lims[i + 1]; j++) {
|
||||
|
@ -94,7 +100,9 @@ SortRangeSearchResult(DatasetPtr data_set, int64_t topk, int64_t nq, std::string
|
|||
}
|
||||
|
||||
void
|
||||
CheckRangeSearchParam(float radius, float range_filter, std::string metric_type) {
|
||||
CheckRangeSearchParam(float radius,
|
||||
float range_filter,
|
||||
std::string metric_type) {
|
||||
/*
|
||||
* IP: 1.0 range_filter radius
|
||||
* |------------+---------------| min_heap descending_order
|
||||
|
|
|
@ -17,8 +17,13 @@
|
|||
namespace milvus {
|
||||
|
||||
DatasetPtr
|
||||
SortRangeSearchResult(DatasetPtr data_set, int64_t topk, int64_t nq, std::string metric_type);
|
||||
SortRangeSearchResult(DatasetPtr data_set,
|
||||
int64_t topk,
|
||||
int64_t nq,
|
||||
std::string metric_type);
|
||||
|
||||
void
|
||||
CheckRangeSearchParam(float radius, float range_filter, std::string metric_type);
|
||||
CheckRangeSearchParam(float radius,
|
||||
float range_filter,
|
||||
std::string metric_type);
|
||||
} // namespace milvus
|
||||
|
|
|
@ -26,10 +26,13 @@ namespace milvus {
|
|||
|
||||
using std::string;
|
||||
static std::map<string, string>
|
||||
RepeatedKeyValToMap(const google::protobuf::RepeatedPtrField<proto::common::KeyValuePair>& kvs) {
|
||||
RepeatedKeyValToMap(
|
||||
const google::protobuf::RepeatedPtrField<proto::common::KeyValuePair>&
|
||||
kvs) {
|
||||
std::map<string, string> mapping;
|
||||
for (auto& kv : kvs) {
|
||||
AssertInfo(!mapping.count(kv.key()), "repeat key(" + kv.key() + ") in protobuf");
|
||||
AssertInfo(!mapping.count(kv.key()),
|
||||
"repeat key(" + kv.key() + ") in protobuf");
|
||||
mapping.emplace(kv.key(), kv.value());
|
||||
}
|
||||
return mapping;
|
||||
|
@ -42,15 +45,18 @@ Schema::ParseFrom(const milvus::proto::schema::CollectionSchema& schema_proto) {
|
|||
|
||||
// NOTE: only two system
|
||||
|
||||
for (const milvus::proto::schema::FieldSchema& child : schema_proto.fields()) {
|
||||
for (const milvus::proto::schema::FieldSchema& child :
|
||||
schema_proto.fields()) {
|
||||
auto field_id = FieldId(child.fieldid());
|
||||
auto name = FieldName(child.name());
|
||||
|
||||
if (field_id.get() < 100) {
|
||||
// system field id
|
||||
auto is_system = SystemProperty::Instance().SystemFieldVerify(name, field_id);
|
||||
auto is_system =
|
||||
SystemProperty::Instance().SystemFieldVerify(name, field_id);
|
||||
AssertInfo(is_system,
|
||||
"invalid system type: name(" + name.get() + "), id(" + std::to_string(field_id.get()) + ")");
|
||||
"invalid system type: name(" + name.get() + "), id(" +
|
||||
std::to_string(field_id.get()) + ")");
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -71,23 +77,28 @@ Schema::ParseFrom(const milvus::proto::schema::CollectionSchema& schema_proto) {
|
|||
} else if (datatype_is_string(data_type)) {
|
||||
auto type_map = RepeatedKeyValToMap(child.type_params());
|
||||
AssertInfo(type_map.count(MAX_LENGTH), "max_length not found");
|
||||
auto max_len = boost::lexical_cast<int64_t>(type_map.at(MAX_LENGTH));
|
||||
auto max_len =
|
||||
boost::lexical_cast<int64_t>(type_map.at(MAX_LENGTH));
|
||||
schema->AddField(name, field_id, data_type, max_len);
|
||||
} else {
|
||||
schema->AddField(name, field_id, data_type);
|
||||
}
|
||||
|
||||
if (child.is_primary_key()) {
|
||||
AssertInfo(!schema->get_primary_field_id().has_value(), "repetitive primary key");
|
||||
AssertInfo(!schema->get_primary_field_id().has_value(),
|
||||
"repetitive primary key");
|
||||
schema->set_primary_field_id(field_id);
|
||||
}
|
||||
}
|
||||
|
||||
AssertInfo(schema->get_primary_field_id().has_value(), "primary key should be specified");
|
||||
AssertInfo(schema->get_primary_field_id().has_value(),
|
||||
"primary key should be specified");
|
||||
|
||||
return schema;
|
||||
}
|
||||
|
||||
const FieldMeta FieldMeta::RowIdMeta(FieldName("RowID"), RowFieldID, DataType::INT64);
|
||||
const FieldMeta FieldMeta::RowIdMeta(FieldName("RowID"),
|
||||
RowFieldID,
|
||||
DataType::INT64);
|
||||
|
||||
} // namespace milvus
|
||||
|
|
|
@ -49,7 +49,8 @@ class Schema {
|
|||
std::optional<knowhere::MetricType> metric_type) {
|
||||
auto field_id = FieldId(debug_id);
|
||||
debug_id++;
|
||||
auto field_meta = FieldMeta(FieldName(name), field_id, data_type, dim, metric_type);
|
||||
auto field_meta =
|
||||
FieldMeta(FieldName(name), field_id, data_type, dim, metric_type);
|
||||
this->AddField(std::move(field_meta));
|
||||
return field_id;
|
||||
}
|
||||
|
@ -63,7 +64,10 @@ class Schema {
|
|||
|
||||
// string type
|
||||
void
|
||||
AddField(const FieldName& name, const FieldId id, DataType data_type, int64_t max_length) {
|
||||
AddField(const FieldName& name,
|
||||
const FieldId id,
|
||||
DataType data_type,
|
||||
int64_t max_length) {
|
||||
auto field_meta = FieldMeta(name, id, data_type, max_length);
|
||||
this->AddField(std::move(field_meta));
|
||||
}
|
||||
|
@ -103,7 +107,8 @@ class Schema {
|
|||
operator[](FieldId field_id) const {
|
||||
Assert(field_id.get() >= 0);
|
||||
AssertInfo(fields_.find(field_id) != fields_.end(),
|
||||
"Cannot find field with field_id: " + std::to_string(field_id.get()));
|
||||
"Cannot find field with field_id: " +
|
||||
std::to_string(field_id.get()));
|
||||
return fields_.at(field_id);
|
||||
}
|
||||
|
||||
|
@ -131,7 +136,8 @@ class Schema {
|
|||
const FieldMeta&
|
||||
operator[](const FieldName& field_name) const {
|
||||
auto id_iter = name_ids_.find(field_name);
|
||||
AssertInfo(id_iter != name_ids_.end(), "Cannot find field with field_name: " + field_name.get());
|
||||
AssertInfo(id_iter != name_ids_.end(),
|
||||
"Cannot find field with field_name: " + field_name.get());
|
||||
return fields_.at(id_iter->second);
|
||||
}
|
||||
|
||||
|
|
|
@ -27,8 +27,11 @@ static const char* SLICE_NUM = "slice_num";
|
|||
static const char* TOTAL_LEN = "total_len";
|
||||
|
||||
void
|
||||
Slice(
|
||||
const std::string& prefix, const BinaryPtr& data_src, const int64_t slice_len, BinarySet& binarySet, Config& ret) {
|
||||
Slice(const std::string& prefix,
|
||||
const BinaryPtr& data_src,
|
||||
const int64_t slice_len,
|
||||
BinarySet& binarySet,
|
||||
Config& ret) {
|
||||
if (!data_src) {
|
||||
return;
|
||||
}
|
||||
|
@ -39,7 +42,8 @@ Slice(
|
|||
auto size = static_cast<size_t>(ri - i);
|
||||
auto slice_i = std::shared_ptr<uint8_t[]>(new uint8_t[size]);
|
||||
memcpy(slice_i.get(), data_src->data.get() + i, size);
|
||||
binarySet.Append(prefix + "_" + std::to_string(slice_num), slice_i, ri - i);
|
||||
binarySet.Append(
|
||||
prefix + "_" + std::to_string(slice_num), slice_i, ri - i);
|
||||
i = ri;
|
||||
}
|
||||
ret[NAME] = prefix;
|
||||
|
@ -54,7 +58,8 @@ Assemble(BinarySet& binarySet) {
|
|||
return;
|
||||
}
|
||||
|
||||
Config meta_data = Config::parse(std::string(reinterpret_cast<char*>(slice_meta->data.get()), slice_meta->size));
|
||||
Config meta_data = Config::parse(std::string(
|
||||
reinterpret_cast<char*>(slice_meta->data.get()), slice_meta->size));
|
||||
|
||||
for (auto& item : meta_data[META]) {
|
||||
std::string prefix = item[NAME];
|
||||
|
@ -64,7 +69,9 @@ Assemble(BinarySet& binarySet) {
|
|||
int64_t pos = 0;
|
||||
for (auto i = 0; i < slice_num; ++i) {
|
||||
auto slice_i_sp = binarySet.Erase(prefix + "_" + std::to_string(i));
|
||||
memcpy(p_data.get() + pos, slice_i_sp->data.get(), static_cast<size_t>(slice_i_sp->size));
|
||||
memcpy(p_data.get() + pos,
|
||||
slice_i_sp->data.get(),
|
||||
static_cast<size_t>(slice_i_sp->size));
|
||||
pos += slice_i_sp->size;
|
||||
}
|
||||
binarySet.Append(prefix, p_data, total_len);
|
||||
|
@ -76,8 +83,8 @@ Disassemble(BinarySet& binarySet) {
|
|||
Config meta_info;
|
||||
auto slice_meta = EraseSliceMeta(binarySet);
|
||||
if (slice_meta != nullptr) {
|
||||
Config last_meta_data =
|
||||
Config::parse(std::string(reinterpret_cast<char*>(slice_meta->data.get()), slice_meta->size));
|
||||
Config last_meta_data = Config::parse(std::string(
|
||||
reinterpret_cast<char*>(slice_meta->data.get()), slice_meta->size));
|
||||
for (auto& item : last_meta_data[META]) {
|
||||
meta_info[META].emplace_back(item);
|
||||
}
|
||||
|
@ -92,7 +99,8 @@ Disassemble(BinarySet& binarySet) {
|
|||
}
|
||||
for (auto& key : slice_key_list) {
|
||||
Config slice_i;
|
||||
Slice(key, binarySet.Erase(key), slice_size_in_byte, binarySet, slice_i);
|
||||
Slice(
|
||||
key, binarySet.Erase(key), slice_size_in_byte, binarySet, slice_i);
|
||||
meta_info[META].emplace_back(slice_i);
|
||||
}
|
||||
if (!slice_key_list.empty()) {
|
||||
|
|
|
@ -27,7 +27,9 @@ namespace milvus {
|
|||
// type erasure to work around virtual restriction
|
||||
class SpanBase {
|
||||
public:
|
||||
explicit SpanBase(const void* data, int64_t row_count, int64_t element_sizeof)
|
||||
explicit SpanBase(const void* data,
|
||||
int64_t row_count,
|
||||
int64_t element_sizeof)
|
||||
: data_(data), row_count_(row_count), element_sizeof_(element_sizeof) {
|
||||
}
|
||||
|
||||
|
@ -57,17 +59,21 @@ class Span;
|
|||
|
||||
// TODO: refine Span to support T=FloatVector
|
||||
template <typename T>
|
||||
class Span<T, typename std::enable_if_t<IsScalar<T> || std::is_same_v<T, PkType>>> {
|
||||
class Span<
|
||||
T,
|
||||
typename std::enable_if_t<IsScalar<T> || std::is_same_v<T, PkType>>> {
|
||||
public:
|
||||
using embeded_type = T;
|
||||
explicit Span(const T* data, int64_t row_count) : data_(data), row_count_(row_count) {
|
||||
using embedded_type = T;
|
||||
explicit Span(const T* data, int64_t row_count)
|
||||
: data_(data), row_count_(row_count) {
|
||||
}
|
||||
|
||||
operator SpanBase() const {
|
||||
return SpanBase(data_, row_count_, sizeof(T));
|
||||
}
|
||||
|
||||
explicit Span(const SpanBase& base) : Span(reinterpret_cast<const T*>(base.data()), base.row_count()) {
|
||||
explicit Span(const SpanBase& base)
|
||||
: Span(reinterpret_cast<const T*>(base.data()), base.row_count()) {
|
||||
assert(base.element_sizeof() == sizeof(T));
|
||||
}
|
||||
|
||||
|
@ -97,7 +103,9 @@ class Span<T, typename std::enable_if_t<IsScalar<T> || std::is_same_v<T, PkType>
|
|||
};
|
||||
|
||||
template <typename VectorType>
|
||||
class Span<VectorType, typename std::enable_if_t<std::is_base_of_v<VectorTrait, VectorType>>> {
|
||||
class Span<
|
||||
VectorType,
|
||||
typename std::enable_if_t<std::is_base_of_v<VectorTrait, VectorType>>> {
|
||||
public:
|
||||
using embedded_type = typename VectorType::embedded_type;
|
||||
|
||||
|
|
|
@ -24,7 +24,8 @@ namespace milvus {
|
|||
class SystemPropertyImpl : public SystemProperty {
|
||||
public:
|
||||
bool
|
||||
SystemFieldVerify(const FieldName& field_name, FieldId field_id) const override {
|
||||
SystemFieldVerify(const FieldName& field_name,
|
||||
FieldId field_id) const override {
|
||||
if (!IsSystem(field_name)) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -90,7 +90,8 @@ template <class...>
|
|||
constexpr std::false_type always_false{};
|
||||
|
||||
template <typename T>
|
||||
using aligned_vector = std::vector<T, boost::alignment::aligned_allocator<T, 64>>;
|
||||
using aligned_vector =
|
||||
std::vector<T, boost::alignment::aligned_allocator<T, 64>>;
|
||||
|
||||
namespace impl {
|
||||
// hide identifier name to make auto-completion happy
|
||||
|
@ -100,10 +101,15 @@ struct FieldOffsetTag;
|
|||
struct SegOffsetTag;
|
||||
}; // namespace impl
|
||||
|
||||
using FieldId = fluent::NamedType<int64_t, impl::FieldIdTag, fluent::Comparable, fluent::Hashable>;
|
||||
using FieldName = fluent::NamedType<std::string, impl::FieldNameTag, fluent::Comparable, fluent::Hashable>;
|
||||
using FieldId = fluent::
|
||||
NamedType<int64_t, impl::FieldIdTag, fluent::Comparable, fluent::Hashable>;
|
||||
using FieldName = fluent::NamedType<std::string,
|
||||
impl::FieldNameTag,
|
||||
fluent::Comparable,
|
||||
fluent::Hashable>;
|
||||
// using FieldOffset = fluent::NamedType<int64_t, impl::FieldOffsetTag, fluent::Comparable, fluent::Hashable>;
|
||||
using SegOffset = fluent::NamedType<int64_t, impl::SegOffsetTag, fluent::Arithmetic>;
|
||||
using SegOffset =
|
||||
fluent::NamedType<int64_t, impl::SegOffsetTag, fluent::Arithmetic>;
|
||||
|
||||
using BitsetType = boost::dynamic_bitset<>;
|
||||
using BitsetTypePtr = std::shared_ptr<boost::dynamic_bitset<>>;
|
||||
|
|
|
@ -75,7 +75,10 @@ PrefixMatch(const std::string_view str, const std::string_view prefix) {
|
|||
}
|
||||
|
||||
inline DatasetPtr
|
||||
GenResultDataset(const int64_t nq, const int64_t topk, const int64_t* ids, const float* distance) {
|
||||
GenResultDataset(const int64_t nq,
|
||||
const int64_t topk,
|
||||
const int64_t* ids,
|
||||
const float* distance) {
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->SetRows(nq);
|
||||
ret_ds->SetDim(topk);
|
||||
|
@ -192,7 +195,8 @@ GetDataSize(const FieldMeta& field, size_t row_count, const DataArray* data) {
|
|||
}
|
||||
|
||||
default:
|
||||
PanicInfo(fmt::format("not supported data type {}", datatype_name(data_type)));
|
||||
PanicInfo(fmt::format("not supported data type {}",
|
||||
datatype_name(data_type)));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -200,7 +204,10 @@ GetDataSize(const FieldMeta& field, size_t row_count, const DataArray* data) {
|
|||
}
|
||||
|
||||
inline void*
|
||||
FillField(DataType data_type, size_t size, const LoadFieldDataInfo& info, void* dst) {
|
||||
FillField(DataType data_type,
|
||||
size_t size,
|
||||
const LoadFieldDataInfo& info,
|
||||
void* dst) {
|
||||
auto data = info.field_data;
|
||||
switch (data_type) {
|
||||
case DataType::BOOL: {
|
||||
|
@ -225,10 +232,12 @@ FillField(DataType data_type, size_t size, const LoadFieldDataInfo& info, void*
|
|||
return memcpy(dst, data->scalars().long_data().data().data(), size);
|
||||
}
|
||||
case DataType::FLOAT: {
|
||||
return memcpy(dst, data->scalars().float_data().data().data(), size);
|
||||
return memcpy(
|
||||
dst, data->scalars().float_data().data().data(), size);
|
||||
}
|
||||
case DataType::DOUBLE: {
|
||||
return memcpy(dst, data->scalars().double_data().data().data(), size);
|
||||
return memcpy(
|
||||
dst, data->scalars().double_data().data().data(), size);
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
char* dest = reinterpret_cast<char*>(dst);
|
||||
|
@ -243,7 +252,8 @@ FillField(DataType data_type, size_t size, const LoadFieldDataInfo& info, void*
|
|||
return dst;
|
||||
}
|
||||
case DataType::VECTOR_FLOAT:
|
||||
return memcpy(dst, data->vectors().float_vector().data().data(), size);
|
||||
return memcpy(
|
||||
dst, data->vectors().float_vector().data().data(), size);
|
||||
|
||||
case DataType::VECTOR_BINARY:
|
||||
return memcpy(dst, data->vectors().binary_vector().data(), size);
|
||||
|
@ -300,7 +310,8 @@ WriteFieldData(int fd, DataType data_type, const DataArray* data, size_t size) {
|
|||
return total_written;
|
||||
}
|
||||
case DataType::VECTOR_FLOAT:
|
||||
return write(fd, data->vectors().float_vector().data().data(), size);
|
||||
return write(
|
||||
fd, data->vectors().float_vector().data().data(), size);
|
||||
|
||||
case DataType::VECTOR_BINARY:
|
||||
return write(fd, data->vectors().binary_vector().data(), size);
|
||||
|
@ -315,7 +326,9 @@ WriteFieldData(int fd, DataType data_type, const DataArray* data, size_t size) {
|
|||
// if mmap enabled, this writes field data to disk and create a map to the file,
|
||||
// otherwise this just alloc memory
|
||||
inline void*
|
||||
CreateMap(int64_t segment_id, const FieldMeta& field_meta, const LoadFieldDataInfo& info) {
|
||||
CreateMap(int64_t segment_id,
|
||||
const FieldMeta& field_meta,
|
||||
const LoadFieldDataInfo& info) {
|
||||
static int mmap_flags = MAP_PRIVATE;
|
||||
#ifdef MAP_POPULATE
|
||||
// macOS doesn't support MAP_POPULATE
|
||||
|
@ -324,33 +337,52 @@ CreateMap(int64_t segment_id, const FieldMeta& field_meta, const LoadFieldDataIn
|
|||
// Allocate memory
|
||||
if (info.mmap_dir_path == nullptr) {
|
||||
auto data_type = field_meta.get_data_type();
|
||||
auto data_size = GetDataSize(field_meta, info.row_count, info.field_data);
|
||||
auto data_size =
|
||||
GetDataSize(field_meta, info.row_count, info.field_data);
|
||||
if (data_size == 0)
|
||||
return nullptr;
|
||||
|
||||
// Use anon mapping so we are able to free these memory with munmap only
|
||||
void* map = mmap(NULL, data_size, PROT_READ | PROT_WRITE, mmap_flags | MAP_ANON, -1, 0);
|
||||
AssertInfo(map != MAP_FAILED, fmt::format("failed to create anon map, err: {}", strerror(errno)));
|
||||
void* map = mmap(NULL,
|
||||
data_size,
|
||||
PROT_READ | PROT_WRITE,
|
||||
mmap_flags | MAP_ANON,
|
||||
-1,
|
||||
0);
|
||||
AssertInfo(
|
||||
map != MAP_FAILED,
|
||||
fmt::format("failed to create anon map, err: {}", strerror(errno)));
|
||||
FillField(data_type, data_size, info, map);
|
||||
return map;
|
||||
}
|
||||
|
||||
auto filepath =
|
||||
std::filesystem::path(info.mmap_dir_path) / std::to_string(segment_id) / std::to_string(info.field_id);
|
||||
auto filepath = std::filesystem::path(info.mmap_dir_path) /
|
||||
std::to_string(segment_id) / std::to_string(info.field_id);
|
||||
auto dir = filepath.parent_path();
|
||||
std::filesystem::create_directories(dir);
|
||||
|
||||
int fd = open(filepath.c_str(), O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR);
|
||||
AssertInfo(fd != -1, fmt::format("failed to create mmap file {}", filepath.c_str()));
|
||||
int fd =
|
||||
open(filepath.c_str(), O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR);
|
||||
AssertInfo(fd != -1,
|
||||
fmt::format("failed to create mmap file {}", filepath.c_str()));
|
||||
|
||||
auto data_type = field_meta.get_data_type();
|
||||
size_t size = field_meta.get_sizeof() * info.row_count;
|
||||
auto written = WriteFieldData(fd, data_type, info.field_data, size);
|
||||
AssertInfo(written == size || written != -1 && datatype_is_variable(field_meta.get_data_type()),
|
||||
fmt::format("failed to write data file {}, written {} but total {}, err: {}", filepath.c_str(), written,
|
||||
size, strerror(errno)));
|
||||
AssertInfo(
|
||||
written == size ||
|
||||
written != -1 && datatype_is_variable(field_meta.get_data_type()),
|
||||
fmt::format(
|
||||
"failed to write data file {}, written {} but total {}, err: {}",
|
||||
filepath.c_str(),
|
||||
written,
|
||||
size,
|
||||
strerror(errno)));
|
||||
int ok = fsync(fd);
|
||||
AssertInfo(ok == 0, fmt::format("failed to fsync mmap data file {}, err: {}", filepath.c_str(), strerror(errno)));
|
||||
AssertInfo(ok == 0,
|
||||
fmt::format("failed to fsync mmap data file {}, err: {}",
|
||||
filepath.c_str(),
|
||||
strerror(errno)));
|
||||
|
||||
// Empty field
|
||||
if (written == 0) {
|
||||
|
@ -359,7 +391,9 @@ CreateMap(int64_t segment_id, const FieldMeta& field_meta, const LoadFieldDataIn
|
|||
|
||||
auto map = mmap(NULL, written, PROT_READ, mmap_flags, fd, 0);
|
||||
AssertInfo(map != MAP_FAILED,
|
||||
fmt::format("failed to create map for data file {}, err: {}", filepath.c_str(), strerror(errno)));
|
||||
fmt::format("failed to create map for data file {}, err: {}",
|
||||
filepath.c_str(),
|
||||
strerror(errno)));
|
||||
|
||||
#ifndef MAP_POPULATE
|
||||
// Manually access the mapping to populate it
|
||||
|
@ -373,9 +407,15 @@ CreateMap(int64_t segment_id, const FieldMeta& field_meta, const LoadFieldDataIn
|
|||
// unlink this data file so
|
||||
// then it will be auto removed after we don't need it again
|
||||
ok = unlink(filepath.c_str());
|
||||
AssertInfo(ok == 0, fmt::format("failed to unlink mmap data file {}, err: {}", filepath.c_str(), strerror(errno)));
|
||||
AssertInfo(ok == 0,
|
||||
fmt::format("failed to unlink mmap data file {}, err: {}",
|
||||
filepath.c_str(),
|
||||
strerror(errno)));
|
||||
ok = close(fd);
|
||||
AssertInfo(ok == 0, fmt::format("failed to close data file {}, err: {}", filepath.c_str(), strerror(errno)));
|
||||
AssertInfo(ok == 0,
|
||||
fmt::format("failed to close data file {}, err: {}",
|
||||
filepath.c_str(),
|
||||
strerror(errno)));
|
||||
return map;
|
||||
}
|
||||
|
||||
|
|
|
@ -50,7 +50,8 @@ constexpr bool IsVector = std::is_base_of_v<VectorTrait, T>;
|
|||
|
||||
template <typename T>
|
||||
constexpr bool IsScalar =
|
||||
std::is_fundamental_v<T> || std::is_same_v<T, std::string> || std::is_same_v<T, std::string_view>;
|
||||
std::is_fundamental_v<T> || std::is_same_v<T, std::string> ||
|
||||
std::is_same_v<T, std::string_view>;
|
||||
|
||||
template <typename T, typename Enabled = void>
|
||||
struct EmbeddedTypeImpl;
|
||||
|
@ -62,7 +63,8 @@ struct EmbeddedTypeImpl<T, std::enable_if_t<IsScalar<T>>> {
|
|||
|
||||
template <typename T>
|
||||
struct EmbeddedTypeImpl<T, std::enable_if_t<IsVector<T>>> {
|
||||
using type = std::conditional_t<std::is_same_v<T, FloatVector>, float, uint8_t>;
|
||||
using type =
|
||||
std::conditional_t<std::is_same_v<T, FloatVector>, float, uint8_t>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -41,7 +41,10 @@ DeleteBinarySet(CBinarySet c_binary_set) {
|
|||
}
|
||||
|
||||
CStatus
|
||||
AppendIndexBinary(CBinarySet c_binary_set, void* index_binary, int64_t index_size, const char* c_index_key) {
|
||||
AppendIndexBinary(CBinarySet c_binary_set,
|
||||
void* index_binary,
|
||||
int64_t index_size,
|
||||
const char* c_index_key) {
|
||||
auto status = CStatus();
|
||||
try {
|
||||
auto binary_set = (knowhere::BinarySet*)c_binary_set;
|
||||
|
@ -68,13 +71,13 @@ GetBinarySetSize(CBinarySet c_binary_set) {
|
|||
}
|
||||
|
||||
void
|
||||
GetBinarySetKeys(CBinarySet c_binary_set, void* datas) {
|
||||
GetBinarySetKeys(CBinarySet c_binary_set, void* data) {
|
||||
auto binary_set = (knowhere::BinarySet*)c_binary_set;
|
||||
auto& map_ = binary_set->binary_map_;
|
||||
const char** datas_ = (const char**)datas;
|
||||
const char** data_ = (const char**)data;
|
||||
std::size_t i = 0;
|
||||
for (auto it = map_.begin(); it != map_.end(); ++it, i++) {
|
||||
datas_[i] = it->first.c_str();
|
||||
data_[i] = it->first.c_str();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -32,13 +32,16 @@ void
|
|||
DeleteBinarySet(CBinarySet c_binary_set);
|
||||
|
||||
CStatus
|
||||
AppendIndexBinary(CBinarySet c_binary_set, void* index_binary, int64_t index_size, const char* c_index_key);
|
||||
AppendIndexBinary(CBinarySet c_binary_set,
|
||||
void* index_binary,
|
||||
int64_t index_size,
|
||||
const char* c_index_key);
|
||||
|
||||
int
|
||||
GetBinarySetSize(CBinarySet c_binary_set);
|
||||
|
||||
void
|
||||
GetBinarySetKeys(CBinarySet c_binary_set, void* datas);
|
||||
GetBinarySetKeys(CBinarySet c_binary_set, void* data);
|
||||
|
||||
int
|
||||
GetBinarySetValueSize(CBinarySet c_set, const char* key);
|
||||
|
|
|
@ -29,7 +29,11 @@ void
|
|||
InitLocalRootPath(const char* root_path) {
|
||||
std::string local_path_root(root_path);
|
||||
std::call_once(
|
||||
flag1, [](std::string path) { milvus::ChunkMangerConfig::SetLocalRootPath(path); }, local_path_root);
|
||||
flag1,
|
||||
[](std::string path) {
|
||||
milvus::ChunkMangerConfig::SetLocalRootPath(path);
|
||||
},
|
||||
local_path_root);
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -41,7 +45,9 @@ InitIndexSliceSize(const int64_t size) {
|
|||
void
|
||||
InitThreadCoreCoefficient(const int64_t value) {
|
||||
std::call_once(
|
||||
flag3, [](int64_t value) { milvus::SetThreadCoreCoefficient(value); }, value);
|
||||
flag3,
|
||||
[](int64_t value) { milvus::SetThreadCoreCoefficient(value); },
|
||||
value);
|
||||
}
|
||||
|
||||
void
|
||||
|
|
|
@ -45,7 +45,8 @@ KnowhereInitImpl(const char* conf_file) {
|
|||
if (conf_file != nullptr) {
|
||||
el::Configurations el_conf(conf_file);
|
||||
el::Loggers::reconfigureAllLoggers(el_conf);
|
||||
LOG_SERVER_DEBUG_ << "Config easylogging with yaml file: " << conf_file;
|
||||
LOG_SERVER_DEBUG_ << "Config easylogging with yaml file: "
|
||||
<< conf_file;
|
||||
}
|
||||
LOG_SERVER_DEBUG_ << "Knowhere init successfully";
|
||||
#endif
|
||||
|
|
|
@ -45,7 +45,8 @@ EasyAssertInfo(bool value,
|
|||
if (!value) {
|
||||
std::string info;
|
||||
info += "Assert \"" + std::string(expr_str) + "\"";
|
||||
info += " at " + std::string(filename) + ":" + std::to_string(lineno) + "\n";
|
||||
info += " at " + std::string(filename) + ":" + std::to_string(lineno) +
|
||||
"\n";
|
||||
if (!extra_info.empty()) {
|
||||
info += " => " + std::string(extra_info);
|
||||
}
|
||||
|
|
|
@ -54,13 +54,14 @@ class SegcoreError : public std::runtime_error {
|
|||
|
||||
} // namespace milvus
|
||||
|
||||
#define AssertInfo(expr, info) \
|
||||
do { \
|
||||
auto _expr_res = bool(expr); \
|
||||
/* call func only when needed */ \
|
||||
if (!_expr_res) { \
|
||||
milvus::impl::EasyAssertInfo(_expr_res, #expr, __FILE__, __LINE__, (info)); \
|
||||
} \
|
||||
#define AssertInfo(expr, info) \
|
||||
do { \
|
||||
auto _expr_res = bool(expr); \
|
||||
/* call func only when needed */ \
|
||||
if (!_expr_res) { \
|
||||
milvus::impl::EasyAssertInfo( \
|
||||
_expr_res, #expr, __FILE__, __LINE__, (info)); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define Assert(expr) AssertInfo((expr), "")
|
||||
|
@ -70,8 +71,9 @@ class SegcoreError : public std::runtime_error {
|
|||
__builtin_unreachable(); \
|
||||
} while (0)
|
||||
|
||||
#define PanicCodeInfo(errcode, info) \
|
||||
do { \
|
||||
milvus::impl::EasyAssertInfo(false, (info), __FILE__, __LINE__, "", errcode); \
|
||||
__builtin_unreachable(); \
|
||||
#define PanicCodeInfo(errcode, info) \
|
||||
do { \
|
||||
milvus::impl::EasyAssertInfo( \
|
||||
false, (info), __FILE__, __LINE__, "", errcode); \
|
||||
__builtin_unreachable(); \
|
||||
} while (0)
|
||||
|
|
|
@ -34,7 +34,9 @@ class IndexBase {
|
|||
Load(const BinarySet& binary_set, const Config& config = {}) = 0;
|
||||
|
||||
virtual void
|
||||
BuildWithRawData(size_t n, const void* values, const Config& config = {}) = 0;
|
||||
BuildWithRawData(size_t n,
|
||||
const void* values,
|
||||
const Config& config = {}) = 0;
|
||||
|
||||
virtual void
|
||||
BuildWithDataset(const DatasetPtr& dataset, const Config& config = {}) = 0;
|
||||
|
|
|
@ -27,7 +27,8 @@
|
|||
namespace milvus::index {
|
||||
|
||||
IndexBasePtr
|
||||
IndexFactory::CreateIndex(const CreateIndexInfo& create_index_info, storage::FileManagerImplPtr file_manager) {
|
||||
IndexFactory::CreateIndex(const CreateIndexInfo& create_index_info,
|
||||
storage::FileManagerImplPtr file_manager) {
|
||||
if (datatype_is_vector(create_index_info.field_type)) {
|
||||
return CreateVectorIndex(create_index_info, file_manager);
|
||||
}
|
||||
|
@ -62,13 +63,15 @@ IndexFactory::CreateScalarIndex(const CreateIndexInfo& create_index_info) {
|
|||
case DataType::VARCHAR:
|
||||
return CreateScalarIndex<std::string>(index_type);
|
||||
default:
|
||||
throw std::invalid_argument(std::string("invalid data type to build index: ") +
|
||||
std::to_string(int(data_type)));
|
||||
throw std::invalid_argument(
|
||||
std::string("invalid data type to build index: ") +
|
||||
std::to_string(int(data_type)));
|
||||
}
|
||||
}
|
||||
|
||||
IndexBasePtr
|
||||
IndexFactory::CreateVectorIndex(const CreateIndexInfo& create_index_info, storage::FileManagerImplPtr file_manager) {
|
||||
IndexFactory::CreateVectorIndex(const CreateIndexInfo& create_index_info,
|
||||
storage::FileManagerImplPtr file_manager) {
|
||||
auto data_type = create_index_info.field_type;
|
||||
auto index_type = create_index_info.index_type;
|
||||
auto metric_type = create_index_info.metric_type;
|
||||
|
@ -79,20 +82,24 @@ IndexFactory::CreateVectorIndex(const CreateIndexInfo& create_index_info, storag
|
|||
if (is_in_disk_list(index_type)) {
|
||||
switch (data_type) {
|
||||
case DataType::VECTOR_FLOAT: {
|
||||
return std::make_unique<VectorDiskAnnIndex<float>>(index_type, metric_type, index_mode, file_manager);
|
||||
return std::make_unique<VectorDiskAnnIndex<float>>(
|
||||
index_type, metric_type, index_mode, file_manager);
|
||||
}
|
||||
default:
|
||||
throw std::invalid_argument(std::string("invalid data type to build disk index: ") +
|
||||
std::to_string(int(data_type)));
|
||||
throw std::invalid_argument(
|
||||
std::string("invalid data type to build disk index: ") +
|
||||
std::to_string(int(data_type)));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if (is_in_nm_list(index_type)) {
|
||||
return std::make_unique<VectorMemNMIndex>(index_type, metric_type, index_mode);
|
||||
return std::make_unique<VectorMemNMIndex>(
|
||||
index_type, metric_type, index_mode);
|
||||
}
|
||||
// create mem index
|
||||
return std::make_unique<VectorMemIndex>(index_type, metric_type, index_mode);
|
||||
return std::make_unique<VectorMemIndex>(
|
||||
index_type, metric_type, index_mode);
|
||||
}
|
||||
|
||||
} // namespace milvus::index
|
||||
|
|
|
@ -53,10 +53,12 @@ class IndexFactory {
|
|||
}
|
||||
|
||||
IndexBasePtr
|
||||
CreateIndex(const CreateIndexInfo& create_index_info, storage::FileManagerImplPtr file_manager);
|
||||
CreateIndex(const CreateIndexInfo& create_index_info,
|
||||
storage::FileManagerImplPtr file_manager);
|
||||
|
||||
IndexBasePtr
|
||||
CreateVectorIndex(const CreateIndexInfo& create_index_info, storage::FileManagerImplPtr file_manager);
|
||||
CreateVectorIndex(const CreateIndexInfo& create_index_info,
|
||||
storage::FileManagerImplPtr file_manager);
|
||||
|
||||
IndexBasePtr
|
||||
CreateScalarIndex(const CreateIndexInfo& create_index_info);
|
||||
|
|
|
@ -38,9 +38,14 @@ ScalarIndex<T>::Query(const DatasetPtr& dataset) {
|
|||
case OpType::Range: {
|
||||
auto lower_bound_value = dataset->Get<T>(LOWER_BOUND_VALUE);
|
||||
auto upper_bound_value = dataset->Get<T>(UPPER_BOUND_VALUE);
|
||||
auto lower_bound_inclusive = dataset->Get<bool>(LOWER_BOUND_INCLUSIVE);
|
||||
auto upper_bound_inclusive = dataset->Get<bool>(UPPER_BOUND_INCLUSIVE);
|
||||
return Range(lower_bound_value, lower_bound_inclusive, upper_bound_value, upper_bound_inclusive);
|
||||
auto lower_bound_inclusive =
|
||||
dataset->Get<bool>(LOWER_BOUND_INCLUSIVE);
|
||||
auto upper_bound_inclusive =
|
||||
dataset->Get<bool>(UPPER_BOUND_INCLUSIVE);
|
||||
return Range(lower_bound_value,
|
||||
lower_bound_inclusive,
|
||||
upper_bound_value,
|
||||
upper_bound_inclusive);
|
||||
}
|
||||
|
||||
case OpType::In: {
|
||||
|
@ -58,13 +63,16 @@ ScalarIndex<T>::Query(const DatasetPtr& dataset) {
|
|||
case OpType::PrefixMatch:
|
||||
case OpType::PostfixMatch:
|
||||
default:
|
||||
throw std::invalid_argument(std::string("unsupported operator type: " + std::to_string(op)));
|
||||
throw std::invalid_argument(std::string(
|
||||
"unsupported operator type: " + std::to_string(op)));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
ScalarIndex<std::string>::BuildWithRawData(size_t n, const void* values, const Config& config) {
|
||||
ScalarIndex<std::string>::BuildWithRawData(size_t n,
|
||||
const void* values,
|
||||
const Config& config) {
|
||||
// TODO :: use arrow
|
||||
proto::schema::StringArray arr;
|
||||
auto ok = arr.ParseFromArray(values, n);
|
||||
|
@ -77,7 +85,9 @@ ScalarIndex<std::string>::BuildWithRawData(size_t n, const void* values, const C
|
|||
|
||||
template <>
|
||||
inline void
|
||||
ScalarIndex<bool>::BuildWithRawData(size_t n, const void* values, const Config& config) {
|
||||
ScalarIndex<bool>::BuildWithRawData(size_t n,
|
||||
const void* values,
|
||||
const Config& config) {
|
||||
proto::schema::BoolArray arr;
|
||||
auto ok = arr.ParseFromArray(values, n);
|
||||
Assert(ok);
|
||||
|
@ -86,42 +96,54 @@ ScalarIndex<bool>::BuildWithRawData(size_t n, const void* values, const Config&
|
|||
|
||||
template <>
|
||||
inline void
|
||||
ScalarIndex<int8_t>::BuildWithRawData(size_t n, const void* values, const Config& config) {
|
||||
ScalarIndex<int8_t>::BuildWithRawData(size_t n,
|
||||
const void* values,
|
||||
const Config& config) {
|
||||
auto data = reinterpret_cast<int8_t*>(const_cast<void*>(values));
|
||||
Build(n, data);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
ScalarIndex<int16_t>::BuildWithRawData(size_t n, const void* values, const Config& config) {
|
||||
ScalarIndex<int16_t>::BuildWithRawData(size_t n,
|
||||
const void* values,
|
||||
const Config& config) {
|
||||
auto data = reinterpret_cast<int16_t*>(const_cast<void*>(values));
|
||||
Build(n, data);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
ScalarIndex<int32_t>::BuildWithRawData(size_t n, const void* values, const Config& config) {
|
||||
ScalarIndex<int32_t>::BuildWithRawData(size_t n,
|
||||
const void* values,
|
||||
const Config& config) {
|
||||
auto data = reinterpret_cast<int32_t*>(const_cast<void*>(values));
|
||||
Build(n, data);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
ScalarIndex<int64_t>::BuildWithRawData(size_t n, const void* values, const Config& config) {
|
||||
ScalarIndex<int64_t>::BuildWithRawData(size_t n,
|
||||
const void* values,
|
||||
const Config& config) {
|
||||
auto data = reinterpret_cast<int64_t*>(const_cast<void*>(values));
|
||||
Build(n, data);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
ScalarIndex<float>::BuildWithRawData(size_t n, const void* values, const Config& config) {
|
||||
ScalarIndex<float>::BuildWithRawData(size_t n,
|
||||
const void* values,
|
||||
const Config& config) {
|
||||
auto data = reinterpret_cast<float*>(const_cast<void*>(values));
|
||||
Build(n, data);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void
|
||||
ScalarIndex<double>::BuildWithRawData(size_t n, const void* values, const Config& config) {
|
||||
ScalarIndex<double>::BuildWithRawData(size_t n,
|
||||
const void* values,
|
||||
const Config& config) {
|
||||
auto data = reinterpret_cast<double*>(const_cast<void*>(values));
|
||||
Build(n, data);
|
||||
}
|
||||
|
|
|
@ -31,10 +31,13 @@ template <typename T>
|
|||
class ScalarIndex : public IndexBase {
|
||||
public:
|
||||
void
|
||||
BuildWithRawData(size_t n, const void* values, const Config& config = {}) override;
|
||||
BuildWithRawData(size_t n,
|
||||
const void* values,
|
||||
const Config& config = {}) override;
|
||||
|
||||
void
|
||||
BuildWithDataset(const DatasetPtr& dataset, const Config& config = {}) override {
|
||||
BuildWithDataset(const DatasetPtr& dataset,
|
||||
const Config& config = {}) override {
|
||||
PanicInfo("scalar index don't support build index with dataset");
|
||||
};
|
||||
|
||||
|
@ -52,7 +55,10 @@ class ScalarIndex : public IndexBase {
|
|||
Range(T value, OpType op) = 0;
|
||||
|
||||
virtual const TargetBitmapPtr
|
||||
Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) = 0;
|
||||
Range(T lower_bound_value,
|
||||
bool lb_inclusive,
|
||||
T upper_bound_value,
|
||||
bool ub_inclusive) = 0;
|
||||
|
||||
virtual T
|
||||
Reverse_Lookup(size_t offset) const = 0;
|
||||
|
|
|
@ -32,7 +32,8 @@ inline ScalarIndexSort<T>::ScalarIndexSort() : is_built_(false), data_() {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
inline ScalarIndexSort<T>::ScalarIndexSort(const size_t n, const T* values) : is_built_(false) {
|
||||
inline ScalarIndexSort<T>::ScalarIndexSort(const size_t n, const T* values)
|
||||
: is_built_(false) {
|
||||
ScalarIndexSort<T>::BuildWithDataset(n, values);
|
||||
}
|
||||
|
||||
|
@ -43,7 +44,8 @@ ScalarIndexSort<T>::Build(const size_t n, const T* values) {
|
|||
return;
|
||||
if (n == 0) {
|
||||
// todo: throw an exception
|
||||
throw std::invalid_argument("ScalarIndexSort cannot build null values!");
|
||||
throw std::invalid_argument(
|
||||
"ScalarIndexSort cannot build null values!");
|
||||
}
|
||||
data_.reserve(n);
|
||||
idx_to_offsets_.resize(n);
|
||||
|
@ -104,12 +106,15 @@ ScalarIndexSort<T>::In(const size_t n, const T* values) {
|
|||
AssertInfo(is_built_, "index has not been built");
|
||||
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
auto lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
|
||||
auto ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
|
||||
auto lb = std::lower_bound(
|
||||
data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
|
||||
auto ub = std::upper_bound(
|
||||
data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
|
||||
for (; lb < ub; ++lb) {
|
||||
if (lb->a_ != *(values + i)) {
|
||||
std::cout << "error happens in ScalarIndexSort<T>::In, experted value is: " << *(values + i)
|
||||
<< ", but real value is: " << lb->a_;
|
||||
std::cout << "error happens in ScalarIndexSort<T>::In, "
|
||||
"experted value is: "
|
||||
<< *(values + i) << ", but real value is: " << lb->a_;
|
||||
}
|
||||
bitset->set(lb->idx_);
|
||||
}
|
||||
|
@ -124,12 +129,15 @@ ScalarIndexSort<T>::NotIn(const size_t n, const T* values) {
|
|||
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
|
||||
bitset->set();
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
auto lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
|
||||
auto ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
|
||||
auto lb = std::lower_bound(
|
||||
data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
|
||||
auto ub = std::upper_bound(
|
||||
data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
|
||||
for (; lb < ub; ++lb) {
|
||||
if (lb->a_ != *(values + i)) {
|
||||
std::cout << "error happens in ScalarIndexSort<T>::NotIn, experted value is: " << *(values + i)
|
||||
<< ", but real value is: " << lb->a_;
|
||||
std::cout << "error happens in ScalarIndexSort<T>::NotIn, "
|
||||
"experted value is: "
|
||||
<< *(values + i) << ", but real value is: " << lb->a_;
|
||||
}
|
||||
bitset->reset(lb->idx_);
|
||||
}
|
||||
|
@ -146,19 +154,24 @@ ScalarIndexSort<T>::Range(const T value, const OpType op) {
|
|||
auto ub = data_.end();
|
||||
switch (op) {
|
||||
case OpType::LessThan:
|
||||
ub = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
|
||||
ub = std::lower_bound(
|
||||
data_.begin(), data_.end(), IndexStructure<T>(value));
|
||||
break;
|
||||
case OpType::LessEqual:
|
||||
ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
|
||||
ub = std::upper_bound(
|
||||
data_.begin(), data_.end(), IndexStructure<T>(value));
|
||||
break;
|
||||
case OpType::GreaterThan:
|
||||
lb = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
|
||||
lb = std::upper_bound(
|
||||
data_.begin(), data_.end(), IndexStructure<T>(value));
|
||||
break;
|
||||
case OpType::GreaterEqual:
|
||||
lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
|
||||
lb = std::lower_bound(
|
||||
data_.begin(), data_.end(), IndexStructure<T>(value));
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(std::string("Invalid OperatorType: ") + std::to_string((int)op) + "!");
|
||||
throw std::invalid_argument(std::string("Invalid OperatorType: ") +
|
||||
std::to_string((int)op) + "!");
|
||||
}
|
||||
for (; lb < ub; ++lb) {
|
||||
bitset->set(lb->idx_);
|
||||
|
@ -168,24 +181,32 @@ ScalarIndexSort<T>::Range(const T value, const OpType op) {
|
|||
|
||||
template <typename T>
|
||||
inline const TargetBitmapPtr
|
||||
ScalarIndexSort<T>::Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) {
|
||||
ScalarIndexSort<T>::Range(T lower_bound_value,
|
||||
bool lb_inclusive,
|
||||
T upper_bound_value,
|
||||
bool ub_inclusive) {
|
||||
AssertInfo(is_built_, "index has not been built");
|
||||
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
|
||||
if (lower_bound_value > upper_bound_value ||
|
||||
(lower_bound_value == upper_bound_value && !(lb_inclusive && ub_inclusive))) {
|
||||
(lower_bound_value == upper_bound_value &&
|
||||
!(lb_inclusive && ub_inclusive))) {
|
||||
return bitset;
|
||||
}
|
||||
auto lb = data_.begin();
|
||||
auto ub = data_.end();
|
||||
if (lb_inclusive) {
|
||||
lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(lower_bound_value));
|
||||
lb = std::lower_bound(
|
||||
data_.begin(), data_.end(), IndexStructure<T>(lower_bound_value));
|
||||
} else {
|
||||
lb = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(lower_bound_value));
|
||||
lb = std::upper_bound(
|
||||
data_.begin(), data_.end(), IndexStructure<T>(lower_bound_value));
|
||||
}
|
||||
if (ub_inclusive) {
|
||||
ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(upper_bound_value));
|
||||
ub = std::upper_bound(
|
||||
data_.begin(), data_.end(), IndexStructure<T>(upper_bound_value));
|
||||
} else {
|
||||
ub = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(upper_bound_value));
|
||||
ub = std::lower_bound(
|
||||
data_.begin(), data_.end(), IndexStructure<T>(upper_bound_value));
|
||||
}
|
||||
for (; lb < ub; ++lb) {
|
||||
bitset->set(lb->idx_);
|
||||
|
|
|
@ -56,7 +56,10 @@ class ScalarIndexSort : public ScalarIndex<T> {
|
|||
Range(T value, OpType op) override;
|
||||
|
||||
const TargetBitmapPtr
|
||||
Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) override;
|
||||
Range(T lower_bound_value,
|
||||
bool lb_inclusive,
|
||||
T upper_bound_value,
|
||||
bool ub_inclusive) override;
|
||||
|
||||
T
|
||||
Reverse_Lookup(size_t offset) const override;
|
||||
|
|
|
@ -63,7 +63,8 @@ StringIndexMarisa::Serialize(const Config& config) {
|
|||
auto uuid_string = boost::uuids::to_string(uuid);
|
||||
auto file = std::string("/tmp/") + uuid_string;
|
||||
|
||||
auto fd = open(file.c_str(), O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR | S_IXUSR);
|
||||
auto fd = open(
|
||||
file.c_str(), O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR | S_IXUSR);
|
||||
trie_.write(fd);
|
||||
|
||||
auto size = get_file_size(fd);
|
||||
|
@ -101,7 +102,8 @@ StringIndexMarisa::Load(const BinarySet& set, const Config& config) {
|
|||
auto index = set.GetByName(MARISA_TRIE_INDEX);
|
||||
auto len = index->size;
|
||||
|
||||
auto fd = open(file.c_str(), O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR | S_IXUSR);
|
||||
auto fd = open(
|
||||
file.c_str(), O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR | S_IXUSR);
|
||||
lseek(fd, 0, SEEK_SET);
|
||||
while (write(fd, index->data.get(), len) != len) {
|
||||
lseek(fd, 0, SEEK_SET);
|
||||
|
@ -182,7 +184,9 @@ StringIndexMarisa::Range(std::string value, OpType op) {
|
|||
set = raw_data.compare(value) >= 0;
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(std::string("Invalid OperatorType: ") + std::to_string((int)op) + "!");
|
||||
throw std::invalid_argument(
|
||||
std::string("Invalid OperatorType: ") +
|
||||
std::to_string((int)op) + "!");
|
||||
}
|
||||
if (set) {
|
||||
bitset->set(offset);
|
||||
|
@ -199,7 +203,8 @@ StringIndexMarisa::Range(std::string lower_bound_value,
|
|||
auto count = Count();
|
||||
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(count);
|
||||
if (lower_bound_value.compare(upper_bound_value) > 0 ||
|
||||
(lower_bound_value.compare(upper_bound_value) == 0 && !(lb_inclusive && ub_inclusive))) {
|
||||
(lower_bound_value.compare(upper_bound_value) == 0 &&
|
||||
!(lb_inclusive && ub_inclusive))) {
|
||||
return bitset;
|
||||
}
|
||||
marisa::Agent agent;
|
||||
|
|
|
@ -58,7 +58,10 @@ class StringIndexMarisa : public StringIndex {
|
|||
Range(std::string value, OpType op) override;
|
||||
|
||||
const TargetBitmapPtr
|
||||
Range(std::string lower_bound_value, bool lb_inclusive, std::string upper_bound_value, bool ub_inclusive) override;
|
||||
Range(std::string lower_bound_value,
|
||||
bool lb_inclusive,
|
||||
std::string upper_bound_value,
|
||||
bool ub_inclusive) override;
|
||||
|
||||
const TargetBitmapPtr
|
||||
PrefixMatch(std::string prefix) override;
|
||||
|
|
|
@ -62,7 +62,8 @@ DISK_LIST() {
|
|||
std::vector<std::tuple<IndexType, MetricType>>
|
||||
unsupported_index_combinations() {
|
||||
static std::vector<std::tuple<IndexType, MetricType>> ret{
|
||||
std::make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, knowhere::metric::L2),
|
||||
std::make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT,
|
||||
knowhere::metric::L2),
|
||||
};
|
||||
return ret;
|
||||
}
|
||||
|
@ -84,8 +85,9 @@ is_in_disk_list(const IndexType& index_type) {
|
|||
|
||||
bool
|
||||
is_unsupported(const IndexType& index_type, const MetricType& metric_type) {
|
||||
return is_in_list<std::tuple<IndexType, MetricType>>(std::make_tuple(index_type, metric_type),
|
||||
unsupported_index_combinations);
|
||||
return is_in_list<std::tuple<IndexType, MetricType>>(
|
||||
std::make_tuple(index_type, metric_type),
|
||||
unsupported_index_combinations);
|
||||
}
|
||||
|
||||
bool
|
||||
|
@ -123,7 +125,8 @@ GetIndexTypeFromConfig(const Config& config) {
|
|||
IndexMode
|
||||
GetIndexModeFromConfig(const Config& config) {
|
||||
auto mode = GetValueFromConfig<std::string>(config, INDEX_MODE);
|
||||
return mode.has_value() ? GetIndexMode(mode.value()) : knowhere::IndexMode::MODE_CPU;
|
||||
return mode.has_value() ? GetIndexMode(mode.value())
|
||||
: knowhere::IndexMode::MODE_CPU;
|
||||
}
|
||||
|
||||
IndexMode
|
||||
|
@ -144,22 +147,28 @@ storage::FieldDataMeta
|
|||
GetFieldDataMetaFromConfig(const Config& config) {
|
||||
storage::FieldDataMeta field_data_meta;
|
||||
// set collection id
|
||||
auto collection_id = index::GetValueFromConfig<std::string>(config, index::COLLECTION_ID);
|
||||
AssertInfo(collection_id.has_value(), "collection id not exist in index config");
|
||||
auto collection_id =
|
||||
index::GetValueFromConfig<std::string>(config, index::COLLECTION_ID);
|
||||
AssertInfo(collection_id.has_value(),
|
||||
"collection id not exist in index config");
|
||||
field_data_meta.collection_id = std::stol(collection_id.value());
|
||||
|
||||
// set partition id
|
||||
auto partition_id = index::GetValueFromConfig<std::string>(config, index::PARTITION_ID);
|
||||
AssertInfo(partition_id.has_value(), "partition id not exist in index config");
|
||||
auto partition_id =
|
||||
index::GetValueFromConfig<std::string>(config, index::PARTITION_ID);
|
||||
AssertInfo(partition_id.has_value(),
|
||||
"partition id not exist in index config");
|
||||
field_data_meta.partition_id = std::stol(partition_id.value());
|
||||
|
||||
// set segment id
|
||||
auto segment_id = index::GetValueFromConfig<std::string>(config, index::SEGMENT_ID);
|
||||
auto segment_id =
|
||||
index::GetValueFromConfig<std::string>(config, index::SEGMENT_ID);
|
||||
AssertInfo(segment_id.has_value(), "segment id not exist in index config");
|
||||
field_data_meta.segment_id = std::stol(segment_id.value());
|
||||
|
||||
// set field id
|
||||
auto field_id = index::GetValueFromConfig<std::string>(config, index::FIELD_ID);
|
||||
auto field_id =
|
||||
index::GetValueFromConfig<std::string>(config, index::FIELD_ID);
|
||||
AssertInfo(field_id.has_value(), "field id not exist in index config");
|
||||
field_data_meta.field_id = std::stol(field_id.value());
|
||||
|
||||
|
@ -170,22 +179,27 @@ storage::IndexMeta
|
|||
GetIndexMetaFromConfig(const Config& config) {
|
||||
storage::IndexMeta index_meta;
|
||||
// set segment id
|
||||
auto segment_id = index::GetValueFromConfig<std::string>(config, index::SEGMENT_ID);
|
||||
auto segment_id =
|
||||
index::GetValueFromConfig<std::string>(config, index::SEGMENT_ID);
|
||||
AssertInfo(segment_id.has_value(), "segment id not exist in index config");
|
||||
index_meta.segment_id = std::stol(segment_id.value());
|
||||
|
||||
// set field id
|
||||
auto field_id = index::GetValueFromConfig<std::string>(config, index::FIELD_ID);
|
||||
auto field_id =
|
||||
index::GetValueFromConfig<std::string>(config, index::FIELD_ID);
|
||||
AssertInfo(field_id.has_value(), "field id not exist in index config");
|
||||
index_meta.field_id = std::stol(field_id.value());
|
||||
|
||||
// set index version
|
||||
auto index_version = index::GetValueFromConfig<std::string>(config, index::INDEX_VERSION);
|
||||
AssertInfo(index_version.has_value(), "index_version id not exist in index config");
|
||||
auto index_version =
|
||||
index::GetValueFromConfig<std::string>(config, index::INDEX_VERSION);
|
||||
AssertInfo(index_version.has_value(),
|
||||
"index_version id not exist in index config");
|
||||
index_meta.index_version = std::stol(index_version.value());
|
||||
|
||||
// set index id
|
||||
auto build_id = index::GetValueFromConfig<std::string>(config, index::INDEX_BUILD_ID);
|
||||
auto build_id =
|
||||
index::GetValueFromConfig<std::string>(config, index::INDEX_BUILD_ID);
|
||||
AssertInfo(build_id.has_value(), "build id not exist in index config");
|
||||
index_meta.build_id = std::stol(build_id.value());
|
||||
|
||||
|
@ -193,7 +207,8 @@ GetIndexMetaFromConfig(const Config& config) {
|
|||
}
|
||||
|
||||
Config
|
||||
ParseConfigFromIndexParams(const std::map<std::string, std::string>& index_params) {
|
||||
ParseConfigFromIndexParams(
|
||||
const std::map<std::string, std::string>& index_params) {
|
||||
Config config;
|
||||
for (auto& p : index_params) {
|
||||
config[p.first] = p.second;
|
||||
|
|
|
@ -121,6 +121,7 @@ storage::IndexMeta
|
|||
GetIndexMetaFromConfig(const Config& config);
|
||||
|
||||
Config
|
||||
ParseConfigFromIndexParams(const std::map<std::string, std::string>& index_params);
|
||||
ParseConfigFromIndexParams(
|
||||
const std::map<std::string, std::string>& index_params);
|
||||
|
||||
} // namespace milvus::index
|
||||
|
|
|
@ -35,12 +35,14 @@ namespace milvus::index {
|
|||
#define kPrepareRows 1
|
||||
|
||||
template <typename T>
|
||||
VectorDiskAnnIndex<T>::VectorDiskAnnIndex(const IndexType& index_type,
|
||||
const MetricType& metric_type,
|
||||
const IndexMode& index_mode,
|
||||
storage::FileManagerImplPtr file_manager)
|
||||
VectorDiskAnnIndex<T>::VectorDiskAnnIndex(
|
||||
const IndexType& index_type,
|
||||
const MetricType& metric_type,
|
||||
const IndexMode& index_mode,
|
||||
storage::FileManagerImplPtr file_manager)
|
||||
: VectorIndex(index_type, index_mode, metric_type) {
|
||||
file_manager_ = std::dynamic_pointer_cast<storage::DiskFileManagerImpl>(file_manager);
|
||||
file_manager_ =
|
||||
std::dynamic_pointer_cast<storage::DiskFileManagerImpl>(file_manager);
|
||||
auto& local_chunk_manager = storage::LocalChunkManager::GetInstance();
|
||||
auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix();
|
||||
|
||||
|
@ -52,17 +54,22 @@ VectorDiskAnnIndex<T>::VectorDiskAnnIndex(const IndexType& index_type,
|
|||
}
|
||||
|
||||
local_chunk_manager.CreateDir(local_index_path_prefix);
|
||||
auto diskann_index_pack = knowhere::Pack(std::shared_ptr<knowhere::FileManager>(file_manager));
|
||||
index_ = knowhere::IndexFactory::Instance().Create(GetIndexType(), diskann_index_pack);
|
||||
auto diskann_index_pack =
|
||||
knowhere::Pack(std::shared_ptr<knowhere::FileManager>(file_manager));
|
||||
index_ = knowhere::IndexFactory::Instance().Create(GetIndexType(),
|
||||
diskann_index_pack);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
VectorDiskAnnIndex<T>::Load(const BinarySet& binary_set /* not used */, const Config& config) {
|
||||
VectorDiskAnnIndex<T>::Load(const BinarySet& binary_set /* not used */,
|
||||
const Config& config) {
|
||||
knowhere::Json load_config = update_load_json(config);
|
||||
|
||||
auto index_files = GetValueFromConfig<std::vector<std::string>>(config, "index_files");
|
||||
AssertInfo(index_files.has_value(), "index file paths is empty when load disk ann index data");
|
||||
auto index_files =
|
||||
GetValueFromConfig<std::vector<std::string>>(config, "index_files");
|
||||
AssertInfo(index_files.has_value(),
|
||||
"index file paths is empty when load disk ann index data");
|
||||
file_manager_->CacheIndexToDisk(index_files.value());
|
||||
|
||||
// todo : replace by index::load function later
|
||||
|
@ -79,21 +86,25 @@ VectorDiskAnnIndex<T>::Load(const BinarySet& binary_set /* not used */, const Co
|
|||
|
||||
template <typename T>
|
||||
void
|
||||
VectorDiskAnnIndex<T>::BuildWithDataset(const DatasetPtr& dataset, const Config& config) {
|
||||
VectorDiskAnnIndex<T>::BuildWithDataset(const DatasetPtr& dataset,
|
||||
const Config& config) {
|
||||
auto& local_chunk_manager = storage::LocalChunkManager::GetInstance();
|
||||
knowhere::Json build_config;
|
||||
build_config.update(config);
|
||||
// set data path
|
||||
auto segment_id = file_manager_->GetFileDataMeta().segment_id;
|
||||
auto field_id = file_manager_->GetFileDataMeta().field_id;
|
||||
auto local_data_path = storage::GenFieldRawDataPathPrefix(segment_id, field_id) + "raw_data";
|
||||
auto local_data_path =
|
||||
storage::GenFieldRawDataPathPrefix(segment_id, field_id) + "raw_data";
|
||||
build_config[DISK_ANN_RAW_DATA_PATH] = local_data_path;
|
||||
|
||||
auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix();
|
||||
build_config[DISK_ANN_PREFIX_PATH] = local_index_path_prefix;
|
||||
|
||||
auto num_threads = GetValueFromConfig<std::string>(build_config, DISK_ANN_BUILD_THREAD_NUM);
|
||||
AssertInfo(num_threads.has_value(), "param " + std::string(DISK_ANN_BUILD_THREAD_NUM) + "is empty");
|
||||
auto num_threads = GetValueFromConfig<std::string>(
|
||||
build_config, DISK_ANN_BUILD_THREAD_NUM);
|
||||
AssertInfo(num_threads.has_value(),
|
||||
"param " + std::string(DISK_ANN_BUILD_THREAD_NUM) + "is empty");
|
||||
build_config[DISK_ANN_THREADS_NUM] = std::atoi(num_threads.value().c_str());
|
||||
|
||||
if (!local_chunk_manager.Exist(local_data_path)) {
|
||||
|
@ -116,14 +127,17 @@ VectorDiskAnnIndex<T>::BuildWithDataset(const DatasetPtr& dataset, const Config&
|
|||
knowhere::DataSet* ds_ptr = nullptr;
|
||||
index_.Build(*ds_ptr, build_config);
|
||||
|
||||
local_chunk_manager.RemoveDir(storage::GetSegmentRawDataPathPrefix(segment_id));
|
||||
local_chunk_manager.RemoveDir(
|
||||
storage::GetSegmentRawDataPathPrefix(segment_id));
|
||||
// TODO ::
|
||||
// SetDim(index_->Dim());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::unique_ptr<SearchResult>
|
||||
VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) {
|
||||
VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
|
||||
const SearchInfo& search_info,
|
||||
const BitsetView& bitset) {
|
||||
AssertInfo(GetMetricType() == search_info.metric_type_,
|
||||
"Metric type of field index isn't the same with search info");
|
||||
auto num_queries = dataset->GetRows();
|
||||
|
@ -135,12 +149,18 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset, const SearchInfo& search_
|
|||
search_config[knowhere::meta::METRIC_TYPE] = GetMetricType();
|
||||
|
||||
// set search list size
|
||||
auto search_list_size = GetValueFromConfig<uint32_t>(search_info.search_params_, DISK_ANN_QUERY_LIST);
|
||||
AssertInfo(search_list_size.has_value(), "param " + std::string(DISK_ANN_QUERY_LIST) + "is empty");
|
||||
AssertInfo(search_list_size.value() >= topk, "search_list should be greater than or equal to topk");
|
||||
AssertInfo(search_list_size.value() <= std::max(uint32_t(topk * 10), uint32_t(kSearchListMaxValue1)) &&
|
||||
search_list_size.value() <= uint32_t(kSearchListMaxValue2),
|
||||
"search_list should be less than max(topk*10, 200) and less than 65535");
|
||||
auto search_list_size = GetValueFromConfig<uint32_t>(
|
||||
search_info.search_params_, DISK_ANN_QUERY_LIST);
|
||||
AssertInfo(search_list_size.has_value(),
|
||||
"param " + std::string(DISK_ANN_QUERY_LIST) + "is empty");
|
||||
AssertInfo(search_list_size.value() >= topk,
|
||||
"search_list should be greater than or equal to topk");
|
||||
AssertInfo(
|
||||
search_list_size.value() <=
|
||||
std::max(uint32_t(topk * 10), uint32_t(kSearchListMaxValue1)) &&
|
||||
search_list_size.value() <= uint32_t(kSearchListMaxValue2),
|
||||
"search_list should be less than max(topk*10, 200) and less than "
|
||||
"65535");
|
||||
search_config[DISK_ANN_SEARCH_LIST_SIZE] = search_list_size.value();
|
||||
|
||||
// set beamwidth
|
||||
|
@ -154,25 +174,33 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset, const SearchInfo& search_
|
|||
search_config[DISK_ANN_PQ_CODE_BUDGET] = 0.0;
|
||||
|
||||
auto final = [&] {
|
||||
auto radius = GetValueFromConfig<float>(search_info.search_params_, RADIUS);
|
||||
auto radius =
|
||||
GetValueFromConfig<float>(search_info.search_params_, RADIUS);
|
||||
if (radius.has_value()) {
|
||||
search_config[RADIUS] = radius.value();
|
||||
auto range_filter = GetValueFromConfig<float>(search_info.search_params_, RANGE_FILTER);
|
||||
auto range_filter = GetValueFromConfig<float>(
|
||||
search_info.search_params_, RANGE_FILTER);
|
||||
if (range_filter.has_value()) {
|
||||
search_config[RANGE_FILTER] = range_filter.value();
|
||||
CheckRangeSearchParam(search_config[RADIUS], search_config[RANGE_FILTER], GetMetricType());
|
||||
CheckRangeSearchParam(search_config[RADIUS],
|
||||
search_config[RANGE_FILTER],
|
||||
GetMetricType());
|
||||
}
|
||||
auto res = index_.RangeSearch(*dataset, search_config, bitset);
|
||||
|
||||
if (!res.has_value()) {
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError,
|
||||
"failed to range search, " + MatchKnowhereError(res.error()));
|
||||
"failed to range search, " +
|
||||
MatchKnowhereError(res.error()));
|
||||
}
|
||||
return SortRangeSearchResult(res.value(), topk, num_queries, GetMetricType());
|
||||
return SortRangeSearchResult(
|
||||
res.value(), topk, num_queries, GetMetricType());
|
||||
} else {
|
||||
auto res = index_.Search(*dataset, search_config, bitset);
|
||||
if (!res.has_value()) {
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to search, " + MatchKnowhereError(res.error()));
|
||||
PanicCodeInfo(
|
||||
ErrorCodeEnum::UnexpectedError,
|
||||
"failed to search, " + MatchKnowhereError(res.error()));
|
||||
}
|
||||
return res.value();
|
||||
}
|
||||
|
@ -226,12 +254,15 @@ VectorDiskAnnIndex<T>::update_load_json(const Config& config) {
|
|||
load_config[DISK_ANN_PREPARE_USE_BFS_CACHE] = false;
|
||||
|
||||
// set threads number
|
||||
auto num_threads = GetValueFromConfig<std::string>(load_config, DISK_ANN_LOAD_THREAD_NUM);
|
||||
AssertInfo(num_threads.has_value(), "param " + std::string(DISK_ANN_LOAD_THREAD_NUM) + "is empty");
|
||||
auto num_threads =
|
||||
GetValueFromConfig<std::string>(load_config, DISK_ANN_LOAD_THREAD_NUM);
|
||||
AssertInfo(num_threads.has_value(),
|
||||
"param " + std::string(DISK_ANN_LOAD_THREAD_NUM) + "is empty");
|
||||
load_config[DISK_ANN_THREADS_NUM] = std::atoi(num_threads.value().c_str());
|
||||
|
||||
// update search_beamwidth
|
||||
auto beamwidth = GetValueFromConfig<std::string>(load_config, DISK_ANN_QUERY_BEAMWIDTH);
|
||||
auto beamwidth =
|
||||
GetValueFromConfig<std::string>(load_config, DISK_ANN_QUERY_BEAMWIDTH);
|
||||
if (beamwidth.has_value()) {
|
||||
search_beamwidth_ = std::atoi(beamwidth.value().c_str());
|
||||
}
|
||||
|
|
|
@ -49,13 +49,17 @@ class VectorDiskAnnIndex : public VectorIndex {
|
|||
}
|
||||
|
||||
void
|
||||
Load(const BinarySet& binary_set /* not used */, const Config& config = {}) override;
|
||||
Load(const BinarySet& binary_set /* not used */,
|
||||
const Config& config = {}) override;
|
||||
|
||||
void
|
||||
BuildWithDataset(const DatasetPtr& dataset, const Config& config = {}) override;
|
||||
BuildWithDataset(const DatasetPtr& dataset,
|
||||
const Config& config = {}) override;
|
||||
|
||||
std::unique_ptr<SearchResult>
|
||||
Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) override;
|
||||
Query(const DatasetPtr dataset,
|
||||
const SearchInfo& search_info,
|
||||
const BitsetView& bitset) override;
|
||||
|
||||
void
|
||||
CleanLocalData() override;
|
||||
|
|
|
@ -32,18 +32,26 @@ namespace milvus::index {
|
|||
|
||||
class VectorIndex : public IndexBase {
|
||||
public:
|
||||
explicit VectorIndex(const IndexType& index_type, const IndexMode& index_mode, const MetricType& metric_type)
|
||||
: index_type_(index_type), index_mode_(index_mode), metric_type_(metric_type) {
|
||||
explicit VectorIndex(const IndexType& index_type,
|
||||
const IndexMode& index_mode,
|
||||
const MetricType& metric_type)
|
||||
: index_type_(index_type),
|
||||
index_mode_(index_mode),
|
||||
metric_type_(metric_type) {
|
||||
}
|
||||
|
||||
public:
|
||||
void
|
||||
BuildWithRawData(size_t n, const void* values, const Config& config = {}) override {
|
||||
BuildWithRawData(size_t n,
|
||||
const void* values,
|
||||
const Config& config = {}) override {
|
||||
PanicInfo("vector index don't support build index with raw data");
|
||||
};
|
||||
|
||||
virtual std::unique_ptr<SearchResult>
|
||||
Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) = 0;
|
||||
Query(const DatasetPtr dataset,
|
||||
const SearchInfo& search_info,
|
||||
const BitsetView& bitset) = 0;
|
||||
|
||||
IndexType
|
||||
GetIndexType() const {
|
||||
|
|
|
@ -30,9 +30,12 @@
|
|||
|
||||
namespace milvus::index {
|
||||
|
||||
VectorMemIndex::VectorMemIndex(const IndexType& index_type, const MetricType& metric_type, const IndexMode& index_mode)
|
||||
VectorMemIndex::VectorMemIndex(const IndexType& index_type,
|
||||
const MetricType& metric_type,
|
||||
const IndexMode& index_mode)
|
||||
: VectorIndex(index_type, index_mode, metric_type) {
|
||||
AssertInfo(!is_unsupported(index_type, metric_type), index_type + " doesn't support metric: " + metric_type);
|
||||
AssertInfo(!is_unsupported(index_type, metric_type),
|
||||
index_type + " doesn't support metric: " + metric_type);
|
||||
|
||||
index_ = knowhere::IndexFactory::Instance().Create(GetIndexType());
|
||||
}
|
||||
|
@ -42,7 +45,8 @@ VectorMemIndex::Serialize(const Config& config) {
|
|||
knowhere::BinarySet ret;
|
||||
auto stat = index_.Serialize(ret);
|
||||
if (stat != knowhere::Status::success)
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to serialize index, " + MatchKnowhereError(stat));
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError,
|
||||
"failed to serialize index, " + MatchKnowhereError(stat));
|
||||
milvus::Disassemble(ret);
|
||||
|
||||
return ret;
|
||||
|
@ -53,12 +57,15 @@ VectorMemIndex::Load(const BinarySet& binary_set, const Config& config) {
|
|||
milvus::Assemble(const_cast<BinarySet&>(binary_set));
|
||||
auto stat = index_.Deserialize(binary_set);
|
||||
if (stat != knowhere::Status::success)
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to Deserialize index, " + MatchKnowhereError(stat));
|
||||
PanicCodeInfo(
|
||||
ErrorCodeEnum::UnexpectedError,
|
||||
"failed to Deserialize index, " + MatchKnowhereError(stat));
|
||||
SetDim(index_.Dim());
|
||||
}
|
||||
|
||||
void
|
||||
VectorMemIndex::BuildWithDataset(const DatasetPtr& dataset, const Config& config) {
|
||||
VectorMemIndex::BuildWithDataset(const DatasetPtr& dataset,
|
||||
const Config& config) {
|
||||
knowhere::Json index_config;
|
||||
index_config.update(config);
|
||||
|
||||
|
@ -67,13 +74,16 @@ VectorMemIndex::BuildWithDataset(const DatasetPtr& dataset, const Config& config
|
|||
knowhere::TimeRecorder rc("BuildWithoutIds", 1);
|
||||
auto stat = index_.Build(*dataset, index_config);
|
||||
if (stat != knowhere::Status::success)
|
||||
PanicCodeInfo(ErrorCodeEnum::BuildIndexError, "failed to build index, " + MatchKnowhereError(stat));
|
||||
PanicCodeInfo(ErrorCodeEnum::BuildIndexError,
|
||||
"failed to build index, " + MatchKnowhereError(stat));
|
||||
rc.ElapseFromBegin("Done");
|
||||
SetDim(index_.Dim());
|
||||
}
|
||||
|
||||
std::unique_ptr<SearchResult>
|
||||
VectorMemIndex::Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) {
|
||||
VectorMemIndex::Query(const DatasetPtr dataset,
|
||||
const SearchInfo& search_info,
|
||||
const BitsetView& bitset) {
|
||||
// AssertInfo(GetMetricType() == search_info.metric_type_,
|
||||
// "Metric type of field index isn't the same with search info");
|
||||
|
||||
|
@ -87,18 +97,24 @@ VectorMemIndex::Query(const DatasetPtr dataset, const SearchInfo& search_info, c
|
|||
auto index_type = GetIndexType();
|
||||
if (CheckKeyInConfig(search_conf, RADIUS)) {
|
||||
if (CheckKeyInConfig(search_conf, RANGE_FILTER)) {
|
||||
CheckRangeSearchParam(search_conf[RADIUS], search_conf[RANGE_FILTER], GetMetricType());
|
||||
CheckRangeSearchParam(search_conf[RADIUS],
|
||||
search_conf[RANGE_FILTER],
|
||||
GetMetricType());
|
||||
}
|
||||
auto res = index_.RangeSearch(*dataset, search_conf, bitset);
|
||||
if (!res.has_value()) {
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError,
|
||||
"failed to range search, " + MatchKnowhereError(res.error()));
|
||||
"failed to range search, " +
|
||||
MatchKnowhereError(res.error()));
|
||||
}
|
||||
return SortRangeSearchResult(res.value(), topk, num_queries, GetMetricType());
|
||||
return SortRangeSearchResult(
|
||||
res.value(), topk, num_queries, GetMetricType());
|
||||
} else {
|
||||
auto res = index_.Search(*dataset, search_conf, bitset);
|
||||
if (!res.has_value()) {
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to search, " + MatchKnowhereError(res.error()));
|
||||
PanicCodeInfo(
|
||||
ErrorCodeEnum::UnexpectedError,
|
||||
"failed to search, " + MatchKnowhereError(res.error()));
|
||||
}
|
||||
return res.value();
|
||||
}
|
||||
|
|
|
@ -28,7 +28,9 @@ namespace milvus::index {
|
|||
|
||||
class VectorMemIndex : public VectorIndex {
|
||||
public:
|
||||
explicit VectorMemIndex(const IndexType& index_type, const MetricType& metric_type, const IndexMode& index_mode);
|
||||
explicit VectorMemIndex(const IndexType& index_type,
|
||||
const MetricType& metric_type,
|
||||
const IndexMode& index_mode);
|
||||
|
||||
BinarySet
|
||||
Serialize(const Config& config) override;
|
||||
|
@ -37,7 +39,8 @@ class VectorMemIndex : public VectorIndex {
|
|||
Load(const BinarySet& binary_set, const Config& config = {}) override;
|
||||
|
||||
void
|
||||
BuildWithDataset(const DatasetPtr& dataset, const Config& config = {}) override;
|
||||
BuildWithDataset(const DatasetPtr& dataset,
|
||||
const Config& config = {}) override;
|
||||
|
||||
int64_t
|
||||
Count() override {
|
||||
|
@ -45,7 +48,9 @@ class VectorMemIndex : public VectorIndex {
|
|||
}
|
||||
|
||||
std::unique_ptr<SearchResult>
|
||||
Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) override;
|
||||
Query(const DatasetPtr dataset,
|
||||
const SearchInfo& search_info,
|
||||
const BitsetView& bitset) override;
|
||||
|
||||
protected:
|
||||
Config config_;
|
||||
|
|
|
@ -31,10 +31,12 @@ VectorMemNMIndex::Serialize(const Config& config) {
|
|||
knowhere::BinarySet ret;
|
||||
auto stat = index_.Serialize(ret);
|
||||
if (stat != knowhere::Status::success)
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to serialize index, " + MatchKnowhereError(stat));
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError,
|
||||
"failed to serialize index, " + MatchKnowhereError(stat));
|
||||
|
||||
auto deleter = [&](uint8_t*) {}; // avoid repeated deconstruction
|
||||
auto raw_data = std::shared_ptr<uint8_t[]>(static_cast<uint8_t*>(raw_data_.data()), deleter);
|
||||
auto raw_data = std::shared_ptr<uint8_t[]>(
|
||||
static_cast<uint8_t*>(raw_data_.data()), deleter);
|
||||
ret.Append(RAW_DATA, raw_data, raw_data_.size());
|
||||
milvus::Disassemble(ret);
|
||||
|
||||
|
@ -42,7 +44,8 @@ VectorMemNMIndex::Serialize(const Config& config) {
|
|||
}
|
||||
|
||||
void
|
||||
VectorMemNMIndex::BuildWithDataset(const DatasetPtr& dataset, const Config& config) {
|
||||
VectorMemNMIndex::BuildWithDataset(const DatasetPtr& dataset,
|
||||
const Config& config) {
|
||||
VectorMemIndex::BuildWithDataset(dataset, config);
|
||||
knowhere::TimeRecorder rc("store_raw_data", 1);
|
||||
store_raw_data(dataset);
|
||||
|
@ -53,12 +56,16 @@ void
|
|||
VectorMemNMIndex::Load(const BinarySet& binary_set, const Config& config) {
|
||||
VectorMemIndex::Load(binary_set, config);
|
||||
if (binary_set.Contains(RAW_DATA)) {
|
||||
std::call_once(raw_data_loaded_, [&]() { LOG_SEGCORE_INFO_C << "NM index load raw data done!"; });
|
||||
std::call_once(raw_data_loaded_, [&]() {
|
||||
LOG_SEGCORE_INFO_C << "NM index load raw data done!";
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<SearchResult>
|
||||
VectorMemNMIndex::Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) {
|
||||
VectorMemNMIndex::Query(const DatasetPtr dataset,
|
||||
const SearchInfo& search_info,
|
||||
const BitsetView& bitset) {
|
||||
auto load_raw_data_closure = [&]() { LoadRawData(); }; // hide this pointer
|
||||
// load -> query, raw data has been loaded
|
||||
// build -> query, this case just for test, should load raw data before query
|
||||
|
@ -88,16 +95,20 @@ VectorMemNMIndex::LoadRawData() {
|
|||
knowhere::BinarySet bs;
|
||||
auto stat = index_.Serialize(bs);
|
||||
if (stat != knowhere::Status::success)
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to Serialize index, " + MatchKnowhereError(stat));
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError,
|
||||
"failed to Serialize index, " + MatchKnowhereError(stat));
|
||||
|
||||
auto bptr = std::make_shared<knowhere::Binary>();
|
||||
auto deleter = [&](uint8_t*) {}; // avoid repeated deconstruction
|
||||
bptr->data = std::shared_ptr<uint8_t[]>(static_cast<uint8_t*>(raw_data_.data()), deleter);
|
||||
bptr->data = std::shared_ptr<uint8_t[]>(
|
||||
static_cast<uint8_t*>(raw_data_.data()), deleter);
|
||||
bptr->size = raw_data_.size();
|
||||
bs.Append(RAW_DATA, bptr);
|
||||
stat = index_.Deserialize(bs);
|
||||
if (stat != knowhere::Status::success)
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to Deserialize index, " + MatchKnowhereError(stat));
|
||||
PanicCodeInfo(
|
||||
ErrorCodeEnum::UnexpectedError,
|
||||
"failed to Deserialize index, " + MatchKnowhereError(stat));
|
||||
}
|
||||
|
||||
} // namespace milvus::index
|
||||
|
|
|
@ -28,7 +28,9 @@ namespace milvus::index {
|
|||
|
||||
class VectorMemNMIndex : public VectorMemIndex {
|
||||
public:
|
||||
explicit VectorMemNMIndex(const IndexType& index_type, const MetricType& metric_type, const IndexMode& index_mode)
|
||||
explicit VectorMemNMIndex(const IndexType& index_type,
|
||||
const MetricType& metric_type,
|
||||
const IndexMode& index_mode)
|
||||
: VectorMemIndex(index_type, metric_type, index_mode) {
|
||||
AssertInfo(is_in_nm_list(index_type), "not valid nm index type");
|
||||
}
|
||||
|
@ -37,13 +39,16 @@ class VectorMemNMIndex : public VectorMemIndex {
|
|||
Serialize(const Config& config) override;
|
||||
|
||||
void
|
||||
BuildWithDataset(const DatasetPtr& dataset, const Config& config = {}) override;
|
||||
BuildWithDataset(const DatasetPtr& dataset,
|
||||
const Config& config = {}) override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& binary_set, const Config& config = {}) override;
|
||||
|
||||
std::unique_ptr<SearchResult>
|
||||
Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) override;
|
||||
Query(const DatasetPtr dataset,
|
||||
const SearchInfo& search_info,
|
||||
const BitsetView& bitset) override;
|
||||
|
||||
private:
|
||||
void
|
||||
|
|
|
@ -45,7 +45,8 @@ class IndexFactory {
|
|||
const char* index_params,
|
||||
const storage::StorageConfig& storage_config) {
|
||||
auto real_dtype = DataType(dtype);
|
||||
auto invalid_dtype_msg = std::string("invalid data type: ") + std::to_string(int(real_dtype));
|
||||
auto invalid_dtype_msg = std::string("invalid data type: ") +
|
||||
std::to_string(int(real_dtype));
|
||||
|
||||
switch (real_dtype) {
|
||||
case DataType::BOOL:
|
||||
|
@ -61,7 +62,8 @@ class IndexFactory {
|
|||
|
||||
case DataType::VECTOR_FLOAT:
|
||||
case DataType::VECTOR_BINARY:
|
||||
return std::make_unique<VecIndexCreator>(real_dtype, type_params, index_params, storage_config);
|
||||
return std::make_unique<VecIndexCreator>(
|
||||
real_dtype, type_params, index_params, storage_config);
|
||||
default:
|
||||
throw std::invalid_argument(invalid_dtype_msg);
|
||||
}
|
||||
|
|
|
@ -20,7 +20,9 @@
|
|||
|
||||
namespace milvus::indexbuilder {
|
||||
|
||||
ScalarIndexCreator::ScalarIndexCreator(DataType dtype, const char* type_params, const char* index_params)
|
||||
ScalarIndexCreator::ScalarIndexCreator(DataType dtype,
|
||||
const char* type_params,
|
||||
const char* index_params)
|
||||
: dtype_(dtype) {
|
||||
// TODO: move parse-related logic to a common interface.
|
||||
proto::indexcgo::TypeParams type_params_;
|
||||
|
@ -42,7 +44,8 @@ ScalarIndexCreator::ScalarIndexCreator(DataType dtype, const char* type_params,
|
|||
index_info.field_type = dtype_;
|
||||
index_info.index_type = index_type();
|
||||
index_info.index_mode = IndexMode::MODE_CPU;
|
||||
index_ = index::IndexFactory::GetInstance().CreateIndex(index_info, nullptr);
|
||||
index_ =
|
||||
index::IndexFactory::GetInstance().CreateIndex(index_info, nullptr);
|
||||
}
|
||||
|
||||
void
|
||||
|
|
|
@ -22,7 +22,9 @@ namespace milvus::indexbuilder {
|
|||
|
||||
class ScalarIndexCreator : public IndexCreatorBase {
|
||||
public:
|
||||
ScalarIndexCreator(DataType data_type, const char* type_params, const char* index_params);
|
||||
ScalarIndexCreator(DataType data_type,
|
||||
const char* type_params,
|
||||
const char* index_params);
|
||||
|
||||
void
|
||||
Build(const milvus::DatasetPtr& dataset) override;
|
||||
|
@ -46,8 +48,11 @@ class ScalarIndexCreator : public IndexCreatorBase {
|
|||
using ScalarIndexCreatorPtr = std::unique_ptr<ScalarIndexCreator>;
|
||||
|
||||
inline ScalarIndexCreatorPtr
|
||||
CreateScalarIndex(DataType dtype, const char* type_params, const char* index_params) {
|
||||
return std::make_unique<ScalarIndexCreator>(dtype, type_params, index_params);
|
||||
CreateScalarIndex(DataType dtype,
|
||||
const char* type_params,
|
||||
const char* index_params) {
|
||||
return std::make_unique<ScalarIndexCreator>(
|
||||
dtype, type_params, index_params);
|
||||
}
|
||||
|
||||
} // namespace milvus::indexbuilder
|
||||
|
|
|
@ -30,8 +30,10 @@ VecIndexCreator::VecIndexCreator(DataType data_type,
|
|||
: data_type_(data_type) {
|
||||
proto::indexcgo::TypeParams type_params_;
|
||||
proto::indexcgo::IndexParams index_params_;
|
||||
milvus::index::ParseFromString(type_params_, std::string(serialized_type_params));
|
||||
milvus::index::ParseFromString(index_params_, std::string(serialized_index_params));
|
||||
milvus::index::ParseFromString(type_params_,
|
||||
std::string(serialized_type_params));
|
||||
milvus::index::ParseFromString(index_params_,
|
||||
std::string(serialized_index_params));
|
||||
|
||||
for (auto i = 0; i < type_params_.params_size(); ++i) {
|
||||
const auto& param = type_params_.params(i);
|
||||
|
@ -54,12 +56,16 @@ VecIndexCreator::VecIndexCreator(DataType data_type,
|
|||
if (index::is_in_disk_list(index_info.index_type)) {
|
||||
// For now, only support diskann index
|
||||
file_manager = std::make_shared<storage::DiskFileManagerImpl>(
|
||||
index::GetFieldDataMetaFromConfig(config_), index::GetIndexMetaFromConfig(config_), storage_config);
|
||||
index::GetFieldDataMetaFromConfig(config_),
|
||||
index::GetIndexMetaFromConfig(config_),
|
||||
storage_config);
|
||||
}
|
||||
#endif
|
||||
|
||||
index_ = index::IndexFactory::GetInstance().CreateIndex(index_info, file_manager);
|
||||
AssertInfo(index_ != nullptr, "[VecIndexCreator]Index is null after create index");
|
||||
index_ = index::IndexFactory::GetInstance().CreateIndex(index_info,
|
||||
file_manager);
|
||||
AssertInfo(index_ != nullptr,
|
||||
"[VecIndexCreator]Index is null after create index");
|
||||
}
|
||||
|
||||
int64_t
|
||||
|
@ -83,7 +89,9 @@ VecIndexCreator::Load(const milvus::BinarySet& binary_set) {
|
|||
}
|
||||
|
||||
std::unique_ptr<SearchResult>
|
||||
VecIndexCreator::Query(const milvus::DatasetPtr& dataset, const SearchInfo& search_info, const BitsetView& bitset) {
|
||||
VecIndexCreator::Query(const milvus::DatasetPtr& dataset,
|
||||
const SearchInfo& search_info,
|
||||
const BitsetView& bitset) {
|
||||
auto vector_index = dynamic_cast<index::VectorIndex*>(index_.get());
|
||||
return vector_index->Query(dataset, search_info, bitset);
|
||||
}
|
||||
|
|
|
@ -44,7 +44,9 @@ class VecIndexCreator : public IndexCreatorBase {
|
|||
dim();
|
||||
|
||||
std::unique_ptr<SearchResult>
|
||||
Query(const milvus::DatasetPtr& dataset, const SearchInfo& search_info, const BitsetView& bitset);
|
||||
Query(const milvus::DatasetPtr& dataset,
|
||||
const SearchInfo& search_info,
|
||||
const BitsetView& bitset);
|
||||
|
||||
public:
|
||||
void
|
||||
|
|
|
@ -39,18 +39,23 @@ CreateIndex(enum CDataType dtype,
|
|||
std::string remote_root_path(c_storage_config.remote_root_path);
|
||||
std::string storage_type(c_storage_config.storage_type);
|
||||
std::string iam_endpoint(c_storage_config.iam_endpoint);
|
||||
auto storage_config = milvus::storage::StorageConfig{address,
|
||||
bucket_name,
|
||||
access_key,
|
||||
access_value,
|
||||
remote_root_path,
|
||||
storage_type,
|
||||
iam_endpoint,
|
||||
c_storage_config.useSSL,
|
||||
c_storage_config.useIAM};
|
||||
auto storage_config =
|
||||
milvus::storage::StorageConfig{address,
|
||||
bucket_name,
|
||||
access_key,
|
||||
access_value,
|
||||
remote_root_path,
|
||||
storage_type,
|
||||
iam_endpoint,
|
||||
c_storage_config.useSSL,
|
||||
c_storage_config.useIAM};
|
||||
|
||||
auto index = milvus::indexbuilder::IndexFactory::GetInstance().CreateIndex(
|
||||
dtype, serialized_type_params, serialized_index_params, storage_config);
|
||||
auto index =
|
||||
milvus::indexbuilder::IndexFactory::GetInstance().CreateIndex(
|
||||
dtype,
|
||||
serialized_type_params,
|
||||
serialized_index_params,
|
||||
storage_config);
|
||||
*res_index = index.release();
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
|
@ -66,7 +71,8 @@ DeleteIndex(CIndex index) {
|
|||
auto status = CStatus();
|
||||
try {
|
||||
AssertInfo(index, "failed to delete index, passed index was null");
|
||||
auto cIndex = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
|
||||
auto cIndex =
|
||||
reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
|
||||
delete cIndex;
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
|
@ -78,12 +84,17 @@ DeleteIndex(CIndex index) {
|
|||
}
|
||||
|
||||
CStatus
|
||||
BuildFloatVecIndex(CIndex index, int64_t float_value_num, const float* vectors) {
|
||||
BuildFloatVecIndex(CIndex index,
|
||||
int64_t float_value_num,
|
||||
const float* vectors) {
|
||||
auto status = CStatus();
|
||||
try {
|
||||
AssertInfo(index, "failed to build float vector index, passed index was null");
|
||||
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
|
||||
auto cIndex = dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
|
||||
AssertInfo(index,
|
||||
"failed to build float vector index, passed index was null");
|
||||
auto real_index =
|
||||
reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
|
||||
auto cIndex =
|
||||
dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
|
||||
auto dim = cIndex->dim();
|
||||
auto row_nums = float_value_num / dim;
|
||||
auto ds = knowhere::GenDataSet(row_nums, dim, vectors);
|
||||
|
@ -101,9 +112,13 @@ CStatus
|
|||
BuildBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors) {
|
||||
auto status = CStatus();
|
||||
try {
|
||||
AssertInfo(index, "failed to build binary vector index, passed index was null");
|
||||
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
|
||||
auto cIndex = dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
|
||||
AssertInfo(
|
||||
index,
|
||||
"failed to build binary vector index, passed index was null");
|
||||
auto real_index =
|
||||
reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
|
||||
auto cIndex =
|
||||
dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
|
||||
auto dim = cIndex->dim();
|
||||
auto row_nums = (data_size * 8) / dim;
|
||||
auto ds = knowhere::GenDataSet(row_nums, dim, vectors);
|
||||
|
@ -126,9 +141,11 @@ CStatus
|
|||
BuildScalarIndex(CIndex c_index, int64_t size, const void* field_data) {
|
||||
auto status = CStatus();
|
||||
try {
|
||||
AssertInfo(c_index, "failed to build scalar index, passed index was null");
|
||||
AssertInfo(c_index,
|
||||
"failed to build scalar index, passed index was null");
|
||||
|
||||
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(c_index);
|
||||
auto real_index =
|
||||
reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(c_index);
|
||||
const int64_t dim = 8; // not important here
|
||||
auto dataset = knowhere::GenDataSet(size, dim, field_data);
|
||||
real_index->Build(dataset);
|
||||
|
@ -146,9 +163,13 @@ CStatus
|
|||
SerializeIndexToBinarySet(CIndex index, CBinarySet* c_binary_set) {
|
||||
auto status = CStatus();
|
||||
try {
|
||||
AssertInfo(index, "failed to serialize index to binary set, passed index was null");
|
||||
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
|
||||
auto binary = std::make_unique<knowhere::BinarySet>(real_index->Serialize());
|
||||
AssertInfo(
|
||||
index,
|
||||
"failed to serialize index to binary set, passed index was null");
|
||||
auto real_index =
|
||||
reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
|
||||
auto binary =
|
||||
std::make_unique<knowhere::BinarySet>(real_index->Serialize());
|
||||
*c_binary_set = binary.release();
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
|
@ -163,8 +184,11 @@ CStatus
|
|||
LoadIndexFromBinarySet(CIndex index, CBinarySet c_binary_set) {
|
||||
auto status = CStatus();
|
||||
try {
|
||||
AssertInfo(index, "failed to load index from binary set, passed index was null");
|
||||
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
|
||||
AssertInfo(
|
||||
index,
|
||||
"failed to load index from binary set, passed index was null");
|
||||
auto real_index =
|
||||
reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
|
||||
auto binary_set = reinterpret_cast<knowhere::BinarySet*>(c_binary_set);
|
||||
real_index->Load(*binary_set);
|
||||
status.error_code = Success;
|
||||
|
@ -180,9 +204,12 @@ CStatus
|
|||
CleanLocalData(CIndex index) {
|
||||
auto status = CStatus();
|
||||
try {
|
||||
AssertInfo(index, "failed to build float vector index, passed index was null");
|
||||
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
|
||||
auto cIndex = dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
|
||||
AssertInfo(index,
|
||||
"failed to build float vector index, passed index was null");
|
||||
auto real_index =
|
||||
reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
|
||||
auto cIndex =
|
||||
dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
|
||||
cIndex->CleanLocalData();
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
|
|
|
@ -64,7 +64,7 @@ SetThreadName(const std::string& name) {
|
|||
|
||||
std::string
|
||||
GetThreadName() {
|
||||
std::string thread_name = "unamed";
|
||||
std::string thread_name = "unnamed";
|
||||
char name[16];
|
||||
size_t len = 16;
|
||||
auto err = pthread_getname_np(pthread_self(), name, len);
|
||||
|
@ -108,12 +108,17 @@ get_thread_starttime() {
|
|||
|
||||
int64_t pid = getpid();
|
||||
char filename[256];
|
||||
snprintf(filename, sizeof(filename), "/proc/%lld/task/%lld/stat", (long long)pid, (long long)tid); // NOLINT
|
||||
snprintf(filename,
|
||||
sizeof(filename),
|
||||
"/proc/%lld/task/%lld/stat",
|
||||
(long long)pid,
|
||||
(long long)tid); // NOLINT
|
||||
|
||||
int64_t val = 0;
|
||||
char comm[16], state;
|
||||
FILE* thread_stat = fopen(filename, "r");
|
||||
auto ret = fscanf(thread_stat, "%lld %s %s ", (long long*)&val, comm, &state); // NOLINT
|
||||
auto ret = fscanf(
|
||||
thread_stat, "%lld %s %s ", (long long*)&val, comm, &state); // NOLINT
|
||||
|
||||
for (auto i = 4; i < 23; i++) {
|
||||
ret = fscanf(thread_stat, "%lld ", (long long*)&val); // NOLINT
|
||||
|
@ -131,7 +136,8 @@ get_thread_starttime() {
|
|||
int64_t
|
||||
get_thread_start_timestamp() {
|
||||
try {
|
||||
return get_now_timestamp() - get_system_boottime() + get_thread_starttime();
|
||||
return get_now_timestamp() - get_system_boottime() +
|
||||
get_thread_starttime();
|
||||
} catch (...) {
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -42,33 +42,44 @@
|
|||
|
||||
// Use this macro whenever possible
|
||||
// Depends variables: context Context
|
||||
#define MLOG(level, module, error_code) \
|
||||
LOG(level) << " | " << VAR_REQUEST_ID << " | " << #level << " | " << VAR_COLLECTION_NAME << " | " << VAR_CLIENT_ID \
|
||||
<< " | " << VAR_CLIENT_TAG << " | " << VAR_CLIENT_IPPORT << " | " << VAR_THREAD_ID << " | " \
|
||||
<< VAR_THREAD_START_TIMESTAMP << " | " << VAR_COMMAND_TAG << " | " << #module << " | " << error_code \
|
||||
<< " | "
|
||||
#define MLOG(level, module, error_code) \
|
||||
LOG(level) << " | " << VAR_REQUEST_ID << " | " << #level << " | " \
|
||||
<< VAR_COLLECTION_NAME << " | " << VAR_CLIENT_ID << " | " \
|
||||
<< VAR_CLIENT_TAG << " | " << VAR_CLIENT_IPPORT << " | " \
|
||||
<< VAR_THREAD_ID << " | " << VAR_THREAD_START_TIMESTAMP \
|
||||
<< " | " << VAR_COMMAND_TAG << " | " << #module << " | " \
|
||||
<< error_code << " | "
|
||||
|
||||
// Use in some background process only
|
||||
#define MLOG_(level, module, error_code) \
|
||||
LOG(level) << " | " \
|
||||
<< "" \
|
||||
<< " | " << #level << " | " \
|
||||
<< "" \
|
||||
<< " | " \
|
||||
<< "" \
|
||||
<< " | " \
|
||||
<< "" \
|
||||
<< " | " \
|
||||
<< "" \
|
||||
<< " | " << VAR_THREAD_ID << " | " << VAR_THREAD_START_TIMESTAMP << " | " \
|
||||
<< "" \
|
||||
#define MLOG_(level, module, error_code) \
|
||||
LOG(level) << " | " \
|
||||
<< "" \
|
||||
<< " | " << #level << " | " \
|
||||
<< "" \
|
||||
<< " | " \
|
||||
<< "" \
|
||||
<< " | " \
|
||||
<< "" \
|
||||
<< " | " \
|
||||
<< "" \
|
||||
<< " | " << VAR_THREAD_ID << " | " \
|
||||
<< VAR_THREAD_START_TIMESTAMP << " | " \
|
||||
<< "" \
|
||||
<< " | " << #module << " | " << error_code << " | "
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#define SEGCORE_MODULE_NAME "SEGCORE"
|
||||
#define SEGCORE_MODULE_CLASS_FUNCTION \
|
||||
LogOut("[%s][%s::%s][%s] ", SEGCORE_MODULE_NAME, (typeid(*this).name()), __FUNCTION__, GetThreadName().c_str())
|
||||
#define SEGCORE_MODULE_FUNCTION LogOut("[%s][%s][%s] ", SEGCORE_MODULE_NAME, __FUNCTION__, GetThreadName().c_str())
|
||||
LogOut("[%s][%s::%s][%s] ", \
|
||||
SEGCORE_MODULE_NAME, \
|
||||
(typeid(*this).name()), \
|
||||
__FUNCTION__, \
|
||||
GetThreadName().c_str())
|
||||
#define SEGCORE_MODULE_FUNCTION \
|
||||
LogOut("[%s][%s][%s] ", \
|
||||
SEGCORE_MODULE_NAME, \
|
||||
__FUNCTION__, \
|
||||
GetThreadName().c_str())
|
||||
|
||||
#define LOG_SEGCORE_TRACE_C LOG(TRACE) << SEGCORE_MODULE_CLASS_FUNCTION
|
||||
#define LOG_SEGCORE_DEBUG_C LOG(DEBUG) << SEGCORE_MODULE_CLASS_FUNCTION
|
||||
|
@ -87,8 +98,16 @@
|
|||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#define SERVER_MODULE_NAME "SERVER"
|
||||
#define SERVER_MODULE_CLASS_FUNCTION \
|
||||
LogOut("[%s][%s::%s][%s] ", SERVER_MODULE_NAME, (typeid(*this).name()), __FUNCTION__, GetThreadName().c_str())
|
||||
#define SERVER_MODULE_FUNCTION LogOut("[%s][%s][%s] ", SERVER_MODULE_NAME, __FUNCTION__, GetThreadName().c_str())
|
||||
LogOut("[%s][%s::%s][%s] ", \
|
||||
SERVER_MODULE_NAME, \
|
||||
(typeid(*this).name()), \
|
||||
__FUNCTION__, \
|
||||
GetThreadName().c_str())
|
||||
#define SERVER_MODULE_FUNCTION \
|
||||
LogOut("[%s][%s][%s] ", \
|
||||
SERVER_MODULE_NAME, \
|
||||
__FUNCTION__, \
|
||||
GetThreadName().c_str())
|
||||
|
||||
#define LOG_SERVER_TRACE_C LOG(TRACE) << SERVER_MODULE_CLASS_FUNCTION
|
||||
#define LOG_SERVER_DEBUG_C LOG(DEBUG) << SERVER_MODULE_CLASS_FUNCTION
|
||||
|
|
|
@ -49,7 +49,8 @@ struct BinaryExprBase : Expr {
|
|||
|
||||
BinaryExprBase() = delete;
|
||||
|
||||
BinaryExprBase(ExprPtr& left, ExprPtr& right) : left_(std::move(left)), right_(std::move(right)) {
|
||||
BinaryExprBase(ExprPtr& left, ExprPtr& right)
|
||||
: left_(std::move(left)), right_(std::move(right)) {
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -66,7 +67,8 @@ struct LogicalUnaryExpr : UnaryExprBase {
|
|||
enum class OpType { Invalid = 0, LogicalNot = 1 };
|
||||
const OpType op_type_;
|
||||
|
||||
LogicalUnaryExpr(const OpType op_type, ExprPtr& child) : UnaryExprBase(child), op_type_(op_type) {
|
||||
LogicalUnaryExpr(const OpType op_type, ExprPtr& child)
|
||||
: UnaryExprBase(child), op_type_(op_type) {
|
||||
}
|
||||
|
||||
public:
|
||||
|
@ -76,7 +78,13 @@ struct LogicalUnaryExpr : UnaryExprBase {
|
|||
|
||||
struct LogicalBinaryExpr : BinaryExprBase {
|
||||
// Note: bitA - bitB == bitA & ~bitB, alias to LogicalMinus
|
||||
enum class OpType { Invalid = 0, LogicalAnd = 1, LogicalOr = 2, LogicalXor = 3, LogicalMinus = 4 };
|
||||
enum class OpType {
|
||||
Invalid = 0,
|
||||
LogicalAnd = 1,
|
||||
LogicalOr = 2,
|
||||
LogicalXor = 3,
|
||||
LogicalMinus = 4
|
||||
};
|
||||
const OpType op_type_;
|
||||
|
||||
LogicalBinaryExpr(const OpType op_type, ExprPtr& left, ExprPtr& right)
|
||||
|
@ -93,10 +101,11 @@ struct TermExpr : Expr {
|
|||
const DataType data_type_;
|
||||
|
||||
protected:
|
||||
// prevent accidential instantiation
|
||||
// prevent accidental instantiation
|
||||
TermExpr() = delete;
|
||||
|
||||
TermExpr(const FieldId field_id, const DataType data_type) : field_id_(field_id), data_type_(data_type) {
|
||||
TermExpr(const FieldId field_id, const DataType data_type)
|
||||
: field_id_(field_id), data_type_(data_type) {
|
||||
}
|
||||
|
||||
public:
|
||||
|
@ -106,14 +115,20 @@ struct TermExpr : Expr {
|
|||
|
||||
static const std::map<std::string, ArithOpType> arith_op_mapping_ = {
|
||||
// arith_op_name -> arith_op
|
||||
{"add", ArithOpType::Add}, {"sub", ArithOpType::Sub}, {"mul", ArithOpType::Mul},
|
||||
{"div", ArithOpType::Div}, {"mod", ArithOpType::Mod},
|
||||
{"add", ArithOpType::Add},
|
||||
{"sub", ArithOpType::Sub},
|
||||
{"mul", ArithOpType::Mul},
|
||||
{"div", ArithOpType::Div},
|
||||
{"mod", ArithOpType::Mod},
|
||||
};
|
||||
|
||||
static const std::map<ArithOpType, std::string> mapping_arith_op_ = {
|
||||
// arith_op_name -> arith_op
|
||||
{ArithOpType::Add, "add"}, {ArithOpType::Sub, "sub"}, {ArithOpType::Mul, "mul"},
|
||||
{ArithOpType::Div, "div"}, {ArithOpType::Mod, "mod"},
|
||||
{ArithOpType::Add, "add"},
|
||||
{ArithOpType::Sub, "sub"},
|
||||
{ArithOpType::Mul, "mul"},
|
||||
{ArithOpType::Div, "div"},
|
||||
{ArithOpType::Mod, "mod"},
|
||||
};
|
||||
|
||||
struct BinaryArithOpEvalRangeExpr : Expr {
|
||||
|
@ -123,14 +138,17 @@ struct BinaryArithOpEvalRangeExpr : Expr {
|
|||
const ArithOpType arith_op_;
|
||||
|
||||
protected:
|
||||
// prevent accidential instantiation
|
||||
// prevent accidental instantiation
|
||||
BinaryArithOpEvalRangeExpr() = delete;
|
||||
|
||||
BinaryArithOpEvalRangeExpr(const FieldId field_id,
|
||||
const DataType data_type,
|
||||
const OpType op_type,
|
||||
const ArithOpType arith_op)
|
||||
: field_id_(field_id), data_type_(data_type), op_type_(op_type), arith_op_(arith_op) {
|
||||
: field_id_(field_id),
|
||||
data_type_(data_type),
|
||||
op_type_(op_type),
|
||||
arith_op_(arith_op) {
|
||||
}
|
||||
|
||||
public:
|
||||
|
@ -140,9 +158,14 @@ struct BinaryArithOpEvalRangeExpr : Expr {
|
|||
|
||||
static const std::map<std::string, OpType> mapping_ = {
|
||||
// op_name -> op
|
||||
{"lt", OpType::LessThan}, {"le", OpType::LessEqual}, {"lte", OpType::LessEqual},
|
||||
{"gt", OpType::GreaterThan}, {"ge", OpType::GreaterEqual}, {"gte", OpType::GreaterEqual},
|
||||
{"eq", OpType::Equal}, {"ne", OpType::NotEqual},
|
||||
{"lt", OpType::LessThan},
|
||||
{"le", OpType::LessEqual},
|
||||
{"lte", OpType::LessEqual},
|
||||
{"gt", OpType::GreaterThan},
|
||||
{"ge", OpType::GreaterEqual},
|
||||
{"gte", OpType::GreaterEqual},
|
||||
{"eq", OpType::Equal},
|
||||
{"ne", OpType::NotEqual},
|
||||
};
|
||||
|
||||
struct UnaryRangeExpr : Expr {
|
||||
|
@ -151,10 +174,12 @@ struct UnaryRangeExpr : Expr {
|
|||
const OpType op_type_;
|
||||
|
||||
protected:
|
||||
// prevent accidential instantiation
|
||||
// prevent accidental instantiation
|
||||
UnaryRangeExpr() = delete;
|
||||
|
||||
UnaryRangeExpr(const FieldId field_id, const DataType data_type, const OpType op_type)
|
||||
UnaryRangeExpr(const FieldId field_id,
|
||||
const DataType data_type,
|
||||
const OpType op_type)
|
||||
: field_id_(field_id), data_type_(data_type), op_type_(op_type) {
|
||||
}
|
||||
|
||||
|
@ -170,7 +195,7 @@ struct BinaryRangeExpr : Expr {
|
|||
const bool upper_inclusive_;
|
||||
|
||||
protected:
|
||||
// prevent accidential instantiation
|
||||
// prevent accidental instantiation
|
||||
BinaryRangeExpr() = delete;
|
||||
|
||||
BinaryRangeExpr(const FieldId field_id,
|
||||
|
|
|
@ -28,7 +28,9 @@ template <typename T>
|
|||
struct TermExprImpl : TermExpr {
|
||||
const std::vector<T> terms_;
|
||||
|
||||
TermExprImpl(const FieldId field_id, const DataType data_type, const std::vector<T>& terms)
|
||||
TermExprImpl(const FieldId field_id,
|
||||
const DataType data_type,
|
||||
const std::vector<T>& terms)
|
||||
: TermExpr(field_id, data_type), terms_(terms) {
|
||||
}
|
||||
};
|
||||
|
@ -54,7 +56,10 @@ template <typename T>
|
|||
struct UnaryRangeExprImpl : UnaryRangeExpr {
|
||||
const T value_;
|
||||
|
||||
UnaryRangeExprImpl(const FieldId field_id, const DataType data_type, const OpType op_type, const T value)
|
||||
UnaryRangeExprImpl(const FieldId field_id,
|
||||
const DataType data_type,
|
||||
const OpType op_type,
|
||||
const T value)
|
||||
: UnaryRangeExpr(field_id, data_type, op_type), value_(value) {
|
||||
}
|
||||
};
|
||||
|
@ -70,7 +75,8 @@ struct BinaryRangeExprImpl : BinaryRangeExpr {
|
|||
const bool upper_inclusive,
|
||||
const T lower_value,
|
||||
const T upper_value)
|
||||
: BinaryRangeExpr(field_id, data_type, lower_inclusive, upper_inclusive),
|
||||
: BinaryRangeExpr(
|
||||
field_id, data_type, lower_inclusive, upper_inclusive),
|
||||
lower_value_(lower_value),
|
||||
upper_value_(upper_value) {
|
||||
}
|
||||
|
|
|
@ -232,7 +232,8 @@ Parser::ParseTermNodeImpl(const FieldName& field_name, const Json& body) {
|
|||
terms[i] = value;
|
||||
}
|
||||
std::sort(terms.begin(), terms.end());
|
||||
return std::make_unique<TermExprImpl<T>>(schema.get_field_id(field_name), schema[field_name].get_data_type(),
|
||||
return std::make_unique<TermExprImpl<T>>(schema.get_field_id(field_name),
|
||||
schema[field_name].get_data_type(),
|
||||
terms);
|
||||
}
|
||||
|
||||
|
@ -277,8 +278,10 @@ Parser::ParseRangeNodeImpl(const FieldName& field_name, const Json& body) {
|
|||
auto arith = item.value();
|
||||
auto arith_body = arith.begin();
|
||||
|
||||
auto arith_op_name = boost::algorithm::to_lower_copy(std::string(arith_body.key()));
|
||||
AssertInfo(arith_op_mapping_.count(arith_op_name), "arith op(" + arith_op_name + ") not found");
|
||||
auto arith_op_name =
|
||||
boost::algorithm::to_lower_copy(std::string(arith_body.key()));
|
||||
AssertInfo(arith_op_mapping_.count(arith_op_name),
|
||||
"arith op(" + arith_op_name + ") not found");
|
||||
|
||||
auto& arith_op_body = arith_body.value();
|
||||
Assert(arith_op_body.is_object());
|
||||
|
@ -299,8 +302,12 @@ Parser::ParseRangeNodeImpl(const FieldName& field_name, const Json& body) {
|
|||
}
|
||||
|
||||
return std::make_unique<BinaryArithOpEvalRangeExprImpl<T>>(
|
||||
schema.get_field_id(field_name), schema[field_name].get_data_type(),
|
||||
arith_op_mapping_.at(arith_op_name), right_operand, mapping_.at(op_name), value);
|
||||
schema.get_field_id(field_name),
|
||||
schema[field_name].get_data_type(),
|
||||
arith_op_mapping_.at(arith_op_name),
|
||||
right_operand,
|
||||
mapping_.at(op_name),
|
||||
value);
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
|
@ -313,7 +320,10 @@ Parser::ParseRangeNodeImpl(const FieldName& field_name, const Json& body) {
|
|||
static_assert(always_false<T>, "unsupported type");
|
||||
}
|
||||
return std::make_unique<UnaryRangeExprImpl<T>>(
|
||||
schema.get_field_id(field_name), schema[field_name].get_data_type(), mapping_.at(op_name), item.value());
|
||||
schema.get_field_id(field_name),
|
||||
schema[field_name].get_data_type(),
|
||||
mapping_.at(op_name),
|
||||
item.value());
|
||||
} else if (body.size() == 2) {
|
||||
bool has_lower_value = false;
|
||||
bool has_upper_value = false;
|
||||
|
@ -322,8 +332,10 @@ Parser::ParseRangeNodeImpl(const FieldName& field_name, const Json& body) {
|
|||
T lower_value;
|
||||
T upper_value;
|
||||
for (auto& item : body.items()) {
|
||||
auto op_name = boost::algorithm::to_lower_copy(std::string(item.key()));
|
||||
AssertInfo(mapping_.count(op_name), "op(" + op_name + ") not found");
|
||||
auto op_name =
|
||||
boost::algorithm::to_lower_copy(std::string(item.key()));
|
||||
AssertInfo(mapping_.count(op_name),
|
||||
"op(" + op_name + ") not found");
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
Assert(item.value().is_boolean());
|
||||
} else if constexpr (std::is_integral_v<T>) {
|
||||
|
@ -351,10 +363,15 @@ Parser::ParseRangeNodeImpl(const FieldName& field_name, const Json& body) {
|
|||
PanicInfo("unsupported operator in binary-range node");
|
||||
}
|
||||
}
|
||||
AssertInfo(has_lower_value && has_upper_value, "illegal binary-range node");
|
||||
return std::make_unique<BinaryRangeExprImpl<T>>(schema.get_field_id(field_name),
|
||||
schema[field_name].get_data_type(), lower_inclusive,
|
||||
upper_inclusive, lower_value, upper_value);
|
||||
AssertInfo(has_lower_value && has_upper_value,
|
||||
"illegal binary-range node");
|
||||
return std::make_unique<BinaryRangeExprImpl<T>>(
|
||||
schema.get_field_id(field_name),
|
||||
schema[field_name].get_data_type(),
|
||||
lower_inclusive,
|
||||
upper_inclusive,
|
||||
lower_value,
|
||||
upper_value);
|
||||
} else {
|
||||
PanicInfo("illegal range node, too more or too few ops");
|
||||
}
|
||||
|
@ -377,7 +394,10 @@ Parser::ParseItemList(const Json& body) {
|
|||
}
|
||||
auto old_size = results.size();
|
||||
|
||||
auto new_end = std::remove_if(results.begin(), results.end(), [](const ExprPtr& x) { return x == nullptr; });
|
||||
auto new_end =
|
||||
std::remove_if(results.begin(), results.end(), [](const ExprPtr& x) {
|
||||
return x == nullptr;
|
||||
});
|
||||
|
||||
results.resize(new_end - results.begin());
|
||||
|
||||
|
@ -421,7 +441,8 @@ Parser::ParseMustNode(const Json& body) {
|
|||
auto item_list = ParseItemList(body);
|
||||
auto merger = [](ExprPtr left, ExprPtr right) {
|
||||
using OpType = LogicalBinaryExpr::OpType;
|
||||
return std::make_unique<LogicalBinaryExpr>(OpType::LogicalAnd, left, right);
|
||||
return std::make_unique<LogicalBinaryExpr>(
|
||||
OpType::LogicalAnd, left, right);
|
||||
};
|
||||
return ConstructTree(merger, std::move(item_list));
|
||||
}
|
||||
|
@ -432,7 +453,8 @@ Parser::ParseShouldNode(const Json& body) {
|
|||
Assert(item_list.size() >= 1);
|
||||
auto merger = [](ExprPtr left, ExprPtr right) {
|
||||
using OpType = LogicalBinaryExpr::OpType;
|
||||
return std::make_unique<LogicalBinaryExpr>(OpType::LogicalOr, left, right);
|
||||
return std::make_unique<LogicalBinaryExpr>(
|
||||
OpType::LogicalOr, left, right);
|
||||
};
|
||||
return ConstructTree(merger, std::move(item_list));
|
||||
}
|
||||
|
@ -443,7 +465,8 @@ Parser::ParseMustNotNode(const Json& body) {
|
|||
Assert(item_list.size() >= 1);
|
||||
auto merger = [](ExprPtr left, ExprPtr right) {
|
||||
using OpType = LogicalBinaryExpr::OpType;
|
||||
return std::make_unique<LogicalBinaryExpr>(OpType::LogicalAnd, left, right);
|
||||
return std::make_unique<LogicalBinaryExpr>(
|
||||
OpType::LogicalAnd, left, right);
|
||||
};
|
||||
auto subtree = ConstructTree(merger, std::move(item_list));
|
||||
|
||||
|
|
|
@ -23,16 +23,21 @@ namespace milvus::query {
|
|||
|
||||
// deprecated
|
||||
std::unique_ptr<PlaceholderGroup>
|
||||
ParsePlaceholderGroup(const Plan* plan, const std::string& placeholder_group_blob) {
|
||||
return ParsePlaceholderGroup(plan, reinterpret_cast<const uint8_t*>(placeholder_group_blob.c_str()),
|
||||
placeholder_group_blob.size());
|
||||
ParsePlaceholderGroup(const Plan* plan,
|
||||
const std::string& placeholder_group_blob) {
|
||||
return ParsePlaceholderGroup(
|
||||
plan,
|
||||
reinterpret_cast<const uint8_t*>(placeholder_group_blob.c_str()),
|
||||
placeholder_group_blob.size());
|
||||
}
|
||||
|
||||
std::unique_ptr<PlaceholderGroup>
|
||||
ParsePlaceholderGroup(const Plan* plan, const uint8_t* blob, const int64_t blob_len) {
|
||||
namespace ser = milvus::proto::common;
|
||||
ParsePlaceholderGroup(const Plan* plan,
|
||||
const uint8_t* blob,
|
||||
const int64_t blob_len) {
|
||||
namespace set = milvus::proto::common;
|
||||
auto result = std::make_unique<PlaceholderGroup>();
|
||||
ser::PlaceholderGroup ph_group;
|
||||
set::PlaceholderGroup ph_group;
|
||||
auto ok = ph_group.ParseFromArray(blob, blob_len);
|
||||
Assert(ok);
|
||||
for (auto& info : ph_group.placeholders()) {
|
||||
|
@ -45,7 +50,8 @@ ParsePlaceholderGroup(const Plan* plan, const uint8_t* blob, const int64_t blob_
|
|||
AssertInfo(element.num_of_queries_, "must have queries");
|
||||
Assert(element.num_of_queries_ > 0);
|
||||
element.line_sizeof_ = info.values().Get(0).size();
|
||||
AssertInfo(field_meta.get_sizeof() == element.line_sizeof_, "vector dimension mismatch");
|
||||
AssertInfo(field_meta.get_sizeof() == element.line_sizeof_,
|
||||
"vector dimension mismatch");
|
||||
auto& target = element.blob_;
|
||||
target.reserve(element.line_sizeof_ * element.num_of_queries_);
|
||||
for (auto& line : info.values()) {
|
||||
|
@ -66,7 +72,9 @@ CreatePlan(const Schema& schema, const std::string& dsl_str) {
|
|||
}
|
||||
|
||||
std::unique_ptr<Plan>
|
||||
CreateSearchPlanByExpr(const Schema& schema, const void* serialized_expr_plan, const int64_t size) {
|
||||
CreateSearchPlanByExpr(const Schema& schema,
|
||||
const void* serialized_expr_plan,
|
||||
const int64_t size) {
|
||||
// Note: serialized_expr_plan is of binary format
|
||||
proto::plan::PlanNode plan_node;
|
||||
plan_node.ParseFromArray(serialized_expr_plan, size);
|
||||
|
@ -74,7 +82,9 @@ CreateSearchPlanByExpr(const Schema& schema, const void* serialized_expr_plan, c
|
|||
}
|
||||
|
||||
std::unique_ptr<RetrievePlan>
|
||||
CreateRetrievePlanByExpr(const Schema& schema, const void* serialized_expr_plan, const int64_t size) {
|
||||
CreateRetrievePlanByExpr(const Schema& schema,
|
||||
const void* serialized_expr_plan,
|
||||
const int64_t size) {
|
||||
proto::plan::PlanNode plan_node;
|
||||
plan_node.ParseFromArray(serialized_expr_plan, size);
|
||||
return ProtoParser(schema).CreateRetrievePlan(plan_node);
|
||||
|
@ -111,9 +121,11 @@ Plan::check_identical(Plan& other) {
|
|||
auto json = ShowPlanNodeVisitor().call_child(*this->plan_node_);
|
||||
auto other_json = ShowPlanNodeVisitor().call_child(*other.plan_node_);
|
||||
Assert(json.dump(2) == other_json.dump(2));
|
||||
Assert(this->extra_info_opt_.has_value() == other.extra_info_opt_.has_value());
|
||||
Assert(this->extra_info_opt_.has_value() ==
|
||||
other.extra_info_opt_.has_value());
|
||||
if (this->extra_info_opt_.has_value()) {
|
||||
Assert(this->extra_info_opt_->involved_fields_ == other.extra_info_opt_->involved_fields_);
|
||||
Assert(this->extra_info_opt_->involved_fields_ ==
|
||||
other.extra_info_opt_->involved_fields_);
|
||||
}
|
||||
Assert(this->tag2field_ == other.tag2field_);
|
||||
Assert(this->target_entries_ == other.target_entries_);
|
||||
|
|
|
@ -31,20 +31,27 @@ CreatePlan(const Schema& schema, const std::string& dsl);
|
|||
|
||||
// Note: serialized_expr_plan is of binary format
|
||||
std::unique_ptr<Plan>
|
||||
CreateSearchPlanByExpr(const Schema& schema, const void* serialized_expr_plan, const int64_t size);
|
||||
CreateSearchPlanByExpr(const Schema& schema,
|
||||
const void* serialized_expr_plan,
|
||||
const int64_t size);
|
||||
|
||||
std::unique_ptr<PlaceholderGroup>
|
||||
ParsePlaceholderGroup(const Plan* plan, const uint8_t* blob, const int64_t blob_len);
|
||||
ParsePlaceholderGroup(const Plan* plan,
|
||||
const uint8_t* blob,
|
||||
const int64_t blob_len);
|
||||
|
||||
// deprecated
|
||||
std::unique_ptr<PlaceholderGroup>
|
||||
ParsePlaceholderGroup(const Plan* plan, const std::string& placeholder_group_blob);
|
||||
ParsePlaceholderGroup(const Plan* plan,
|
||||
const std::string& placeholder_group_blob);
|
||||
|
||||
int64_t
|
||||
GetNumOfQueries(const PlaceholderGroup*);
|
||||
|
||||
std::unique_ptr<RetrievePlan>
|
||||
CreateRetrievePlanByExpr(const Schema& schema, const void* serialized_expr_plan, const int64_t size);
|
||||
CreateRetrievePlanByExpr(const Schema& schema,
|
||||
const void* serialized_expr_plan,
|
||||
const int64_t size);
|
||||
|
||||
// Query Overall TopK from Plan
|
||||
// Used to alloc result memory at Go side
|
||||
|
|
|
@ -25,7 +25,9 @@ namespace planpb = milvus::proto::plan;
|
|||
|
||||
template <typename T>
|
||||
std::unique_ptr<TermExprImpl<T>>
|
||||
ExtractTermExprImpl(FieldId field_id, DataType data_type, const planpb::TermExpr& expr_proto) {
|
||||
ExtractTermExprImpl(FieldId field_id,
|
||||
DataType data_type,
|
||||
const planpb::TermExpr& expr_proto) {
|
||||
static_assert(IsScalar<T>);
|
||||
auto size = expr_proto.values_size();
|
||||
std::vector<T> terms(size);
|
||||
|
@ -53,7 +55,9 @@ ExtractTermExprImpl(FieldId field_id, DataType data_type, const planpb::TermExpr
|
|||
|
||||
template <typename T>
|
||||
std::unique_ptr<UnaryRangeExprImpl<T>>
|
||||
ExtractUnaryRangeExprImpl(FieldId field_id, DataType data_type, const planpb::UnaryRangeExpr& expr_proto) {
|
||||
ExtractUnaryRangeExprImpl(FieldId field_id,
|
||||
DataType data_type,
|
||||
const planpb::UnaryRangeExpr& expr_proto) {
|
||||
static_assert(IsScalar<T>);
|
||||
auto getValue = [&](const auto& value_proto) -> T {
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
|
@ -72,13 +76,18 @@ ExtractUnaryRangeExprImpl(FieldId field_id, DataType data_type, const planpb::Un
|
|||
static_assert(always_false<T>);
|
||||
}
|
||||
};
|
||||
return std::make_unique<UnaryRangeExprImpl<T>>(field_id, data_type, static_cast<OpType>(expr_proto.op()),
|
||||
getValue(expr_proto.value()));
|
||||
return std::make_unique<UnaryRangeExprImpl<T>>(
|
||||
field_id,
|
||||
data_type,
|
||||
static_cast<OpType>(expr_proto.op()),
|
||||
getValue(expr_proto.value()));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::unique_ptr<BinaryRangeExprImpl<T>>
|
||||
ExtractBinaryRangeExprImpl(FieldId field_id, DataType data_type, const planpb::BinaryRangeExpr& expr_proto) {
|
||||
ExtractBinaryRangeExprImpl(FieldId field_id,
|
||||
DataType data_type,
|
||||
const planpb::BinaryRangeExpr& expr_proto) {
|
||||
static_assert(IsScalar<T>);
|
||||
auto getValue = [&](const auto& value_proto) -> T {
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
|
@ -97,16 +106,21 @@ ExtractBinaryRangeExprImpl(FieldId field_id, DataType data_type, const planpb::B
|
|||
static_assert(always_false<T>);
|
||||
}
|
||||
};
|
||||
return std::make_unique<BinaryRangeExprImpl<T>>(field_id, data_type, expr_proto.lower_inclusive(),
|
||||
expr_proto.upper_inclusive(), getValue(expr_proto.lower_value()),
|
||||
getValue(expr_proto.upper_value()));
|
||||
return std::make_unique<BinaryRangeExprImpl<T>>(
|
||||
field_id,
|
||||
data_type,
|
||||
expr_proto.lower_inclusive(),
|
||||
expr_proto.upper_inclusive(),
|
||||
getValue(expr_proto.lower_value()),
|
||||
getValue(expr_proto.upper_value()));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::unique_ptr<BinaryArithOpEvalRangeExprImpl<T>>
|
||||
ExtractBinaryArithOpEvalRangeExprImpl(FieldId field_id,
|
||||
DataType data_type,
|
||||
const planpb::BinaryArithOpEvalRangeExpr& expr_proto) {
|
||||
ExtractBinaryArithOpEvalRangeExprImpl(
|
||||
FieldId field_id,
|
||||
DataType data_type,
|
||||
const planpb::BinaryArithOpEvalRangeExpr& expr_proto) {
|
||||
static_assert(std::is_fundamental_v<T>);
|
||||
auto getValue = [&](const auto& value_proto) -> T {
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
|
@ -123,8 +137,12 @@ ExtractBinaryArithOpEvalRangeExprImpl(FieldId field_id,
|
|||
}
|
||||
};
|
||||
return std::make_unique<BinaryArithOpEvalRangeExprImpl<T>>(
|
||||
field_id, data_type, static_cast<ArithOpType>(expr_proto.arith_op()), getValue(expr_proto.right_operand()),
|
||||
static_cast<OpType>(expr_proto.op()), getValue(expr_proto.value()));
|
||||
field_id,
|
||||
data_type,
|
||||
static_cast<ArithOpType>(expr_proto.arith_op()),
|
||||
getValue(expr_proto.right_operand()),
|
||||
static_cast<OpType>(expr_proto.op()),
|
||||
getValue(expr_proto.value()));
|
||||
}
|
||||
|
||||
std::unique_ptr<VectorPlanNode>
|
||||
|
@ -165,12 +183,15 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
|
|||
}
|
||||
|
||||
std::unique_ptr<RetrievePlanNode>
|
||||
ProtoParser::RetrievePlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
|
||||
ProtoParser::RetrievePlanNodeFromProto(
|
||||
const planpb::PlanNode& plan_node_proto) {
|
||||
Assert(plan_node_proto.has_predicates());
|
||||
auto& predicate_proto = plan_node_proto.predicates();
|
||||
auto expr_opt = [&]() -> ExprPtr { return ParseExpr(predicate_proto); }();
|
||||
|
||||
auto plan_node = [&]() -> std::unique_ptr<RetrievePlanNode> { return std::make_unique<RetrievePlanNode>(); }();
|
||||
auto plan_node = [&]() -> std::unique_ptr<RetrievePlanNode> {
|
||||
return std::make_unique<RetrievePlanNode>();
|
||||
}();
|
||||
plan_node->predicate_ = std::move(expr_opt);
|
||||
return plan_node;
|
||||
}
|
||||
|
@ -224,28 +245,36 @@ ProtoParser::ParseUnaryRangeExpr(const proto::plan::UnaryRangeExpr& expr_pb) {
|
|||
auto result = [&]() -> ExprPtr {
|
||||
switch (data_type) {
|
||||
case DataType::BOOL: {
|
||||
return ExtractUnaryRangeExprImpl<bool>(field_id, data_type, expr_pb);
|
||||
return ExtractUnaryRangeExprImpl<bool>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::INT8: {
|
||||
return ExtractUnaryRangeExprImpl<int8_t>(field_id, data_type, expr_pb);
|
||||
return ExtractUnaryRangeExprImpl<int8_t>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::INT16: {
|
||||
return ExtractUnaryRangeExprImpl<int16_t>(field_id, data_type, expr_pb);
|
||||
return ExtractUnaryRangeExprImpl<int16_t>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::INT32: {
|
||||
return ExtractUnaryRangeExprImpl<int32_t>(field_id, data_type, expr_pb);
|
||||
return ExtractUnaryRangeExprImpl<int32_t>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::INT64: {
|
||||
return ExtractUnaryRangeExprImpl<int64_t>(field_id, data_type, expr_pb);
|
||||
return ExtractUnaryRangeExprImpl<int64_t>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::FLOAT: {
|
||||
return ExtractUnaryRangeExprImpl<float>(field_id, data_type, expr_pb);
|
||||
return ExtractUnaryRangeExprImpl<float>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::DOUBLE: {
|
||||
return ExtractUnaryRangeExprImpl<double>(field_id, data_type, expr_pb);
|
||||
return ExtractUnaryRangeExprImpl<double>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
return ExtractUnaryRangeExprImpl<std::string>(field_id, data_type, expr_pb);
|
||||
return ExtractUnaryRangeExprImpl<std::string>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
default: {
|
||||
PanicInfo("unsupported data type");
|
||||
|
@ -265,28 +294,36 @@ ProtoParser::ParseBinaryRangeExpr(const proto::plan::BinaryRangeExpr& expr_pb) {
|
|||
auto result = [&]() -> ExprPtr {
|
||||
switch (data_type) {
|
||||
case DataType::BOOL: {
|
||||
return ExtractBinaryRangeExprImpl<bool>(field_id, data_type, expr_pb);
|
||||
return ExtractBinaryRangeExprImpl<bool>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::INT8: {
|
||||
return ExtractBinaryRangeExprImpl<int8_t>(field_id, data_type, expr_pb);
|
||||
return ExtractBinaryRangeExprImpl<int8_t>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::INT16: {
|
||||
return ExtractBinaryRangeExprImpl<int16_t>(field_id, data_type, expr_pb);
|
||||
return ExtractBinaryRangeExprImpl<int16_t>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::INT32: {
|
||||
return ExtractBinaryRangeExprImpl<int32_t>(field_id, data_type, expr_pb);
|
||||
return ExtractBinaryRangeExprImpl<int32_t>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::INT64: {
|
||||
return ExtractBinaryRangeExprImpl<int64_t>(field_id, data_type, expr_pb);
|
||||
return ExtractBinaryRangeExprImpl<int64_t>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::FLOAT: {
|
||||
return ExtractBinaryRangeExprImpl<float>(field_id, data_type, expr_pb);
|
||||
return ExtractBinaryRangeExprImpl<float>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::DOUBLE: {
|
||||
return ExtractBinaryRangeExprImpl<double>(field_id, data_type, expr_pb);
|
||||
return ExtractBinaryRangeExprImpl<double>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
return ExtractBinaryRangeExprImpl<std::string>(field_id, data_type, expr_pb);
|
||||
return ExtractBinaryRangeExprImpl<std::string>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
default: {
|
||||
PanicInfo("unsupported data type");
|
||||
|
@ -301,12 +338,14 @@ ProtoParser::ParseCompareExpr(const proto::plan::CompareExpr& expr_pb) {
|
|||
auto& left_column_info = expr_pb.left_column_info();
|
||||
auto left_field_id = FieldId(left_column_info.field_id());
|
||||
auto left_data_type = schema[left_field_id].get_data_type();
|
||||
Assert(left_data_type == static_cast<DataType>(left_column_info.data_type()));
|
||||
Assert(left_data_type ==
|
||||
static_cast<DataType>(left_column_info.data_type()));
|
||||
|
||||
auto& right_column_info = expr_pb.right_column_info();
|
||||
auto right_field_id = FieldId(right_column_info.field_id());
|
||||
auto right_data_type = schema[right_field_id].get_data_type();
|
||||
Assert(right_data_type == static_cast<DataType>(right_column_info.data_type()));
|
||||
Assert(right_data_type ==
|
||||
static_cast<DataType>(right_column_info.data_type()));
|
||||
|
||||
return [&]() -> ExprPtr {
|
||||
auto result = std::make_unique<CompareExpr>();
|
||||
|
@ -333,25 +372,31 @@ ProtoParser::ParseTermExpr(const proto::plan::TermExpr& expr_pb) {
|
|||
return ExtractTermExprImpl<bool>(field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::INT8: {
|
||||
return ExtractTermExprImpl<int8_t>(field_id, data_type, expr_pb);
|
||||
return ExtractTermExprImpl<int8_t>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::INT16: {
|
||||
return ExtractTermExprImpl<int16_t>(field_id, data_type, expr_pb);
|
||||
return ExtractTermExprImpl<int16_t>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::INT32: {
|
||||
return ExtractTermExprImpl<int32_t>(field_id, data_type, expr_pb);
|
||||
return ExtractTermExprImpl<int32_t>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::INT64: {
|
||||
return ExtractTermExprImpl<int64_t>(field_id, data_type, expr_pb);
|
||||
return ExtractTermExprImpl<int64_t>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::FLOAT: {
|
||||
return ExtractTermExprImpl<float>(field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::DOUBLE: {
|
||||
return ExtractTermExprImpl<double>(field_id, data_type, expr_pb);
|
||||
return ExtractTermExprImpl<double>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
return ExtractTermExprImpl<std::string>(field_id, data_type, expr_pb);
|
||||
return ExtractTermExprImpl<std::string>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
default: {
|
||||
PanicInfo("unsupported data type");
|
||||
|
@ -378,7 +423,8 @@ ProtoParser::ParseBinaryExpr(const proto::plan::BinaryExpr& expr_pb) {
|
|||
}
|
||||
|
||||
ExprPtr
|
||||
ProtoParser::ParseBinaryArithOpEvalRangeExpr(const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) {
|
||||
ProtoParser::ParseBinaryArithOpEvalRangeExpr(
|
||||
const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) {
|
||||
auto& column_info = expr_pb.column_info();
|
||||
auto field_id = FieldId(column_info.field_id());
|
||||
auto data_type = schema[field_id].get_data_type();
|
||||
|
@ -387,22 +433,28 @@ ProtoParser::ParseBinaryArithOpEvalRangeExpr(const proto::plan::BinaryArithOpEva
|
|||
auto result = [&]() -> ExprPtr {
|
||||
switch (data_type) {
|
||||
case DataType::INT8: {
|
||||
return ExtractBinaryArithOpEvalRangeExprImpl<int8_t>(field_id, data_type, expr_pb);
|
||||
return ExtractBinaryArithOpEvalRangeExprImpl<int8_t>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::INT16: {
|
||||
return ExtractBinaryArithOpEvalRangeExprImpl<int16_t>(field_id, data_type, expr_pb);
|
||||
return ExtractBinaryArithOpEvalRangeExprImpl<int16_t>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::INT32: {
|
||||
return ExtractBinaryArithOpEvalRangeExprImpl<int32_t>(field_id, data_type, expr_pb);
|
||||
return ExtractBinaryArithOpEvalRangeExprImpl<int32_t>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::INT64: {
|
||||
return ExtractBinaryArithOpEvalRangeExprImpl<int64_t>(field_id, data_type, expr_pb);
|
||||
return ExtractBinaryArithOpEvalRangeExprImpl<int64_t>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::FLOAT: {
|
||||
return ExtractBinaryArithOpEvalRangeExprImpl<float>(field_id, data_type, expr_pb);
|
||||
return ExtractBinaryArithOpEvalRangeExprImpl<float>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
case DataType::DOUBLE: {
|
||||
return ExtractBinaryArithOpEvalRangeExprImpl<double>(field_id, data_type, expr_pb);
|
||||
return ExtractBinaryArithOpEvalRangeExprImpl<double>(
|
||||
field_id, data_type, expr_pb);
|
||||
}
|
||||
default: {
|
||||
PanicInfo("unsupported data type");
|
||||
|
@ -435,7 +487,8 @@ ProtoParser::ParseExpr(const proto::plan::Expr& expr_pb) {
|
|||
return ParseCompareExpr(expr_pb.compare_expr());
|
||||
}
|
||||
case ppe::kBinaryArithOpEvalRangeExpr: {
|
||||
return ParseBinaryArithOpEvalRangeExpr(expr_pb.binary_arith_op_eval_range_expr());
|
||||
return ParseBinaryArithOpEvalRangeExpr(
|
||||
expr_pb.binary_arith_op_eval_range_expr());
|
||||
}
|
||||
default:
|
||||
PanicInfo("unsupported expr proto node");
|
||||
|
|
|
@ -30,7 +30,8 @@ class ProtoParser {
|
|||
// ExprFromProto(const proto::plan::Expr& expr_proto);
|
||||
|
||||
ExprPtr
|
||||
ParseBinaryArithOpEvalRangeExpr(const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb);
|
||||
ParseBinaryArithOpEvalRangeExpr(
|
||||
const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb);
|
||||
|
||||
ExprPtr
|
||||
ParseUnaryRangeExpr(const proto::plan::UnaryRangeExpr& expr_pb);
|
||||
|
|
|
@ -50,7 +50,10 @@ struct Relational {
|
|||
template <typename T, typename U>
|
||||
bool
|
||||
operator()(const T& t, const U& u) const {
|
||||
return RelationalImpl<Op, T, U>(t, u, typename TagDispatchTrait<T>::Tag{}, typename TagDispatchTrait<U>::Tag{});
|
||||
return RelationalImpl<Op, T, U>(t,
|
||||
u,
|
||||
typename TagDispatchTrait<T>::Tag{},
|
||||
typename TagDispatchTrait<U>::Tag{});
|
||||
}
|
||||
|
||||
template <typename... T>
|
||||
|
|
|
@ -22,15 +22,19 @@
|
|||
namespace milvus::query {
|
||||
|
||||
void
|
||||
CheckBruteForceSearchParam(const FieldMeta& field, const SearchInfo& search_info) {
|
||||
CheckBruteForceSearchParam(const FieldMeta& field,
|
||||
const SearchInfo& search_info) {
|
||||
auto data_type = field.get_data_type();
|
||||
auto& metric_type = search_info.metric_type_;
|
||||
|
||||
AssertInfo(datatype_is_vector(data_type), "[BruteForceSearch] Data type isn't vector type");
|
||||
AssertInfo(datatype_is_vector(data_type),
|
||||
"[BruteForceSearch] Data type isn't vector type");
|
||||
bool is_float_data_type = (data_type == DataType::VECTOR_FLOAT);
|
||||
bool is_float_metric_type =
|
||||
IsMetricType(metric_type, knowhere::metric::IP) || IsMetricType(metric_type, knowhere::metric::L2);
|
||||
AssertInfo(is_float_data_type == is_float_metric_type, "[BruteForceSearch] Data type and metric type mis-match");
|
||||
IsMetricType(metric_type, knowhere::metric::IP) ||
|
||||
IsMetricType(metric_type, knowhere::metric::L2);
|
||||
AssertInfo(is_float_data_type == is_float_metric_type,
|
||||
"[BruteForceSearch] Data type and metric type miss-match");
|
||||
}
|
||||
|
||||
SubSearchResult
|
||||
|
@ -39,13 +43,17 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
|
|||
int64_t chunk_rows,
|
||||
const knowhere::Json& conf,
|
||||
const BitsetView& bitset) {
|
||||
SubSearchResult sub_result(dataset.num_queries, dataset.topk, dataset.metric_type, dataset.round_decimal);
|
||||
SubSearchResult sub_result(dataset.num_queries,
|
||||
dataset.topk,
|
||||
dataset.metric_type,
|
||||
dataset.round_decimal);
|
||||
try {
|
||||
auto nq = dataset.num_queries;
|
||||
auto dim = dataset.dim;
|
||||
auto topk = dataset.topk;
|
||||
|
||||
auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw);
|
||||
auto base_dataset =
|
||||
knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw);
|
||||
auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data);
|
||||
auto config = knowhere::Json{
|
||||
{knowhere::meta::METRIC_TYPE, dataset.metric_type},
|
||||
|
@ -60,24 +68,36 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
|
|||
config[RADIUS] = conf[RADIUS].get<float>();
|
||||
if (conf.contains(RANGE_FILTER)) {
|
||||
config[RANGE_FILTER] = conf[RANGE_FILTER].get<float>();
|
||||
CheckRangeSearchParam(config[RADIUS], config[RANGE_FILTER], dataset.metric_type);
|
||||
CheckRangeSearchParam(
|
||||
config[RADIUS], config[RANGE_FILTER], dataset.metric_type);
|
||||
}
|
||||
auto res = knowhere::BruteForce::RangeSearch(base_dataset, query_dataset, config, bitset);
|
||||
auto res = knowhere::BruteForce::RangeSearch(
|
||||
base_dataset, query_dataset, config, bitset);
|
||||
|
||||
if (!res.has_value()) {
|
||||
PanicCodeInfo(ErrorCodeEnum::UnexpectedError,
|
||||
"failed to range search, " + MatchKnowhereError(res.error()));
|
||||
"failed to range search, " +
|
||||
MatchKnowhereError(res.error()));
|
||||
}
|
||||
auto result = SortRangeSearchResult(res.value(), topk, nq, dataset.metric_type);
|
||||
std::copy_n(GetDatasetIDs(result), nq * topk, sub_result.get_seg_offsets());
|
||||
std::copy_n(GetDatasetDistance(result), nq * topk, sub_result.get_distances());
|
||||
auto result = SortRangeSearchResult(
|
||||
res.value(), topk, nq, dataset.metric_type);
|
||||
std::copy_n(
|
||||
GetDatasetIDs(result), nq * topk, sub_result.get_seg_offsets());
|
||||
std::copy_n(GetDatasetDistance(result),
|
||||
nq * topk,
|
||||
sub_result.get_distances());
|
||||
} else {
|
||||
auto stat = knowhere::BruteForce::SearchWithBuf(base_dataset, query_dataset,
|
||||
sub_result.mutable_seg_offsets().data(),
|
||||
sub_result.mutable_distances().data(), config, bitset);
|
||||
auto stat = knowhere::BruteForce::SearchWithBuf(
|
||||
base_dataset,
|
||||
query_dataset,
|
||||
sub_result.mutable_seg_offsets().data(),
|
||||
sub_result.mutable_distances().data(),
|
||||
config,
|
||||
bitset);
|
||||
|
||||
if (stat != knowhere::Status::success) {
|
||||
throw std::invalid_argument("invalid metric type, " + MatchKnowhereError(stat));
|
||||
throw std::invalid_argument("invalid metric type, " +
|
||||
MatchKnowhereError(stat));
|
||||
}
|
||||
}
|
||||
} catch (std::exception& e) {
|
||||
|
|
|
@ -20,7 +20,8 @@
|
|||
namespace milvus::query {
|
||||
|
||||
void
|
||||
CheckBruteForceSearchParam(const FieldMeta& field, const SearchInfo& search_info);
|
||||
CheckBruteForceSearchParam(const FieldMeta& field,
|
||||
const SearchInfo& search_info);
|
||||
|
||||
SubSearchResult
|
||||
BruteForceSearch(const dataset::SearchDataset& dataset,
|
||||
|
|
|
@ -37,31 +37,42 @@ FloatIndexSearch(const segcore::SegmentGrowingImpl& segment,
|
|||
auto vecfield_id = info.field_id_;
|
||||
auto& field = schema[vecfield_id];
|
||||
|
||||
AssertInfo(field.get_data_type() == DataType::VECTOR_FLOAT, "[FloatSearch]Field data type isn't VECTOR_FLOAT");
|
||||
dataset::SearchDataset search_dataset{info.metric_type_, num_queries, info.topk_,
|
||||
info.round_decimal_, field.get_dim(), query_data};
|
||||
AssertInfo(field.get_data_type() == DataType::VECTOR_FLOAT,
|
||||
"[FloatSearch]Field data type isn't VECTOR_FLOAT");
|
||||
dataset::SearchDataset search_dataset{info.metric_type_,
|
||||
num_queries,
|
||||
info.topk_,
|
||||
info.round_decimal_,
|
||||
field.get_dim(),
|
||||
query_data};
|
||||
auto vec_ptr = record.get_field_data<FloatVector>(vecfield_id);
|
||||
|
||||
int current_chunk_id = 0;
|
||||
if (indexing_record.is_in(vecfield_id)) {
|
||||
auto max_indexed_id = indexing_record.get_finished_ack();
|
||||
const auto& field_indexing = indexing_record.get_vec_field_indexing(vecfield_id);
|
||||
const auto& field_indexing =
|
||||
indexing_record.get_vec_field_indexing(vecfield_id);
|
||||
auto search_params = field_indexing.get_search_params(info.topk_);
|
||||
SearchInfo search_conf(info);
|
||||
search_conf.search_params_ = search_params;
|
||||
AssertInfo(vec_ptr->get_size_per_chunk() == field_indexing.get_size_per_chunk(),
|
||||
"[FloatSearch]Chunk size of vector not equal to chunk size of field index");
|
||||
AssertInfo(vec_ptr->get_size_per_chunk() ==
|
||||
field_indexing.get_size_per_chunk(),
|
||||
"[FloatSearch]Chunk size of vector not equal to chunk size "
|
||||
"of field index");
|
||||
|
||||
auto size_per_chunk = field_indexing.get_size_per_chunk();
|
||||
for (int chunk_id = current_chunk_id; chunk_id < max_indexed_id; ++chunk_id) {
|
||||
for (int chunk_id = current_chunk_id; chunk_id < max_indexed_id;
|
||||
++chunk_id) {
|
||||
if ((chunk_id + 1) * size_per_chunk > ins_barrier) {
|
||||
break;
|
||||
}
|
||||
|
||||
auto indexing = field_indexing.get_chunk_indexing(chunk_id);
|
||||
auto sub_view = bitset.subview(chunk_id * size_per_chunk, size_per_chunk);
|
||||
auto sub_view =
|
||||
bitset.subview(chunk_id * size_per_chunk, size_per_chunk);
|
||||
auto vec_index = (index::VectorIndex*)(indexing);
|
||||
auto sub_qr = SearchOnIndex(search_dataset, *vec_index, search_conf, sub_view);
|
||||
auto sub_qr = SearchOnIndex(
|
||||
search_dataset, *vec_index, search_conf, sub_view);
|
||||
|
||||
// convert chunk uid to segment uid
|
||||
for (auto& x : sub_qr.mutable_seg_offsets()) {
|
||||
|
@ -87,7 +98,8 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
|||
SearchResult& results) {
|
||||
auto& schema = segment.get_schema();
|
||||
auto& record = segment.get_insert_record();
|
||||
auto active_count = std::min(int64_t(bitset.size()), segment.get_active_count(timestamp));
|
||||
auto active_count =
|
||||
std::min(int64_t(bitset.size()), segment.get_active_count(timestamp));
|
||||
|
||||
// step 1.1: get meta
|
||||
// step 1.2: get which vector field to search
|
||||
|
@ -96,7 +108,8 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
|||
CheckBruteForceSearchParam(field, info);
|
||||
|
||||
auto data_type = field.get_data_type();
|
||||
AssertInfo(datatype_is_vector(data_type), "[SearchOnGrowing]Data type isn't vector type");
|
||||
AssertInfo(datatype_is_vector(data_type),
|
||||
"[SearchOnGrowing]Data type isn't vector type");
|
||||
|
||||
auto dim = field.get_dim();
|
||||
auto topk = info.topk_;
|
||||
|
@ -105,11 +118,18 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
|||
|
||||
// step 2: small indexing search
|
||||
SubSearchResult final_qr(num_queries, topk, metric_type, round_decimal);
|
||||
dataset::SearchDataset search_dataset{metric_type, num_queries, topk, round_decimal, dim, query_data};
|
||||
dataset::SearchDataset search_dataset{
|
||||
metric_type, num_queries, topk, round_decimal, dim, query_data};
|
||||
|
||||
int32_t current_chunk_id = 0;
|
||||
if (field.get_data_type() == DataType::VECTOR_FLOAT) {
|
||||
current_chunk_id = FloatIndexSearch(segment, info, query_data, num_queries, active_count, bitset, final_qr);
|
||||
current_chunk_id = FloatIndexSearch(segment,
|
||||
info,
|
||||
query_data,
|
||||
num_queries,
|
||||
active_count,
|
||||
bitset,
|
||||
final_qr);
|
||||
}
|
||||
|
||||
// step 3: brute force search where small indexing is unavailable
|
||||
|
@ -121,11 +141,16 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
|||
auto chunk_data = vec_ptr->get_chunk_data(chunk_id);
|
||||
|
||||
auto element_begin = chunk_id * vec_size_per_chunk;
|
||||
auto element_end = std::min(active_count, (chunk_id + 1) * vec_size_per_chunk);
|
||||
auto element_end =
|
||||
std::min(active_count, (chunk_id + 1) * vec_size_per_chunk);
|
||||
auto size_per_chunk = element_end - element_begin;
|
||||
|
||||
auto sub_view = bitset.subview(element_begin, size_per_chunk);
|
||||
auto sub_qr = BruteForceSearch(search_dataset, chunk_data, size_per_chunk, info.search_params_, sub_view);
|
||||
auto sub_qr = BruteForceSearch(search_dataset,
|
||||
chunk_data,
|
||||
size_per_chunk,
|
||||
info.search_params_,
|
||||
sub_view);
|
||||
|
||||
// convert chunk uid to segment uid
|
||||
for (auto& x : sub_qr.mutable_seg_offsets()) {
|
||||
|
|
|
@ -22,7 +22,8 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset,
|
|||
auto dim = search_dataset.dim;
|
||||
auto metric_type = search_dataset.metric_type;
|
||||
auto round_decimal = search_dataset.round_decimal;
|
||||
auto dataset = knowhere::GenDataSet(num_queries, dim, search_dataset.query_data);
|
||||
auto dataset =
|
||||
knowhere::GenDataSet(num_queries, dim, search_dataset.query_data);
|
||||
|
||||
// NOTE: VecIndex Query API forget to add const qualifier
|
||||
// NOTE: use const_cast as a workaround
|
||||
|
@ -30,8 +31,10 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset,
|
|||
auto ans = indexing_nonconst.Query(dataset, search_conf, bitset);
|
||||
|
||||
SubSearchResult sub_qr(num_queries, topK, metric_type, round_decimal);
|
||||
std::copy_n(ans->distances_.data(), num_queries * topK, sub_qr.get_distances());
|
||||
std::copy_n(ans->seg_offsets_.data(), num_queries * topK, sub_qr.get_seg_offsets());
|
||||
std::copy_n(
|
||||
ans->distances_.data(), num_queries * topK, sub_qr.get_distances());
|
||||
std::copy_n(
|
||||
ans->seg_offsets_.data(), num_queries * topK, sub_qr.get_seg_offsets());
|
||||
sub_qr.round_values();
|
||||
return sub_qr;
|
||||
}
|
||||
|
|
|
@ -45,7 +45,8 @@ SearchOnSealedIndex(const Schema& schema,
|
|||
auto conf = search_info.search_params_;
|
||||
conf[knowhere::meta::TOPK] = search_info.topk_;
|
||||
conf[knowhere::meta::METRIC_TYPE] = field_indexing->metric_type_;
|
||||
auto vec_index = dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get());
|
||||
auto vec_index =
|
||||
dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get());
|
||||
auto index_type = vec_index->GetIndexType();
|
||||
return vec_index->Query(ds, search_info, bitset);
|
||||
}();
|
||||
|
@ -81,11 +82,16 @@ SearchOnSealed(const Schema& schema,
|
|||
auto field_id = search_info.field_id_;
|
||||
auto& field = schema[field_id];
|
||||
|
||||
query::dataset::SearchDataset dataset{search_info.metric_type_, num_queries, search_info.topk_,
|
||||
search_info.round_decimal_, field.get_dim(), query_data};
|
||||
query::dataset::SearchDataset dataset{search_info.metric_type_,
|
||||
num_queries,
|
||||
search_info.topk_,
|
||||
search_info.round_decimal_,
|
||||
field.get_dim(),
|
||||
query_data};
|
||||
|
||||
CheckBruteForceSearchParam(field, search_info);
|
||||
auto sub_qr = BruteForceSearch(dataset, vec_data, row_count, search_info.search_params_, bitset);
|
||||
auto sub_qr = BruteForceSearch(
|
||||
dataset, vec_data, row_count, search_info.search_params_, bitset);
|
||||
|
||||
result.distances_ = std::move(sub_qr.mutable_distances());
|
||||
result.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());
|
||||
|
|
|
@ -19,10 +19,13 @@ namespace milvus::query {
|
|||
template <bool is_desc>
|
||||
void
|
||||
SubSearchResult::merge_impl(const SubSearchResult& right) {
|
||||
AssertInfo(num_queries_ == right.num_queries_, "[SubSearchResult]Nq check failed");
|
||||
AssertInfo(num_queries_ == right.num_queries_,
|
||||
"[SubSearchResult]Nq check failed");
|
||||
AssertInfo(topk_ == right.topk_, "[SubSearchResult]Topk check failed");
|
||||
AssertInfo(metric_type_ == right.metric_type_, "[SubSearchResult]Metric type check failed");
|
||||
AssertInfo(is_desc == PositivelyRelated(metric_type_), "[SubSearchResult]Metric type isn't desc");
|
||||
AssertInfo(metric_type_ == right.metric_type_,
|
||||
"[SubSearchResult]Metric type check failed");
|
||||
AssertInfo(is_desc == PositivelyRelated(metric_type_),
|
||||
"[SubSearchResult]Metric type isn't desc");
|
||||
|
||||
for (int64_t qn = 0; qn < num_queries_; ++qn) {
|
||||
auto offset = qn * topk_;
|
||||
|
@ -72,7 +75,8 @@ SubSearchResult::merge_impl(const SubSearchResult& right) {
|
|||
|
||||
void
|
||||
SubSearchResult::merge(const SubSearchResult& sub_result) {
|
||||
AssertInfo(metric_type_ == sub_result.metric_type_, "[SubSearchResult]Metric type check failed when merge");
|
||||
AssertInfo(metric_type_ == sub_result.metric_type_,
|
||||
"[SubSearchResult]Metric type check failed when merge");
|
||||
if (PositivelyRelated(metric_type_)) {
|
||||
this->merge_impl<true>(sub_result);
|
||||
} else {
|
||||
|
@ -85,7 +89,8 @@ SubSearchResult::round_values() {
|
|||
if (round_decimal_ == -1)
|
||||
return;
|
||||
const float multiplier = pow(10.0, round_decimal_);
|
||||
for (auto it = this->distances_.begin(); it != this->distances_.end(); it++) {
|
||||
for (auto it = this->distances_.begin(); it != this->distances_.end();
|
||||
it++) {
|
||||
*it = round(*it * multiplier) / multiplier;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,7 +22,10 @@ namespace milvus::query {
|
|||
|
||||
class SubSearchResult {
|
||||
public:
|
||||
SubSearchResult(int64_t num_queries, int64_t topk, const MetricType& metric_type, int64_t round_decimal)
|
||||
SubSearchResult(int64_t num_queries,
|
||||
int64_t topk,
|
||||
const MetricType& metric_type,
|
||||
int64_t round_decimal)
|
||||
: num_queries_(num_queries),
|
||||
topk_(topk),
|
||||
round_decimal_(round_decimal),
|
||||
|
@ -43,7 +46,8 @@ class SubSearchResult {
|
|||
public:
|
||||
static float
|
||||
init_value(const MetricType& metric_type) {
|
||||
return (PositivelyRelated(metric_type) ? -1 : 1) * std::numeric_limits<float>::max();
|
||||
return (PositivelyRelated(metric_type) ? -1 : 1) *
|
||||
std::numeric_limits<float>::max();
|
||||
}
|
||||
|
||||
public:
|
||||
|
|
|
@ -38,7 +38,9 @@ Match<std::string>(const std::string& str, const std::string& val, OpType op) {
|
|||
|
||||
template <>
|
||||
inline bool
|
||||
Match<std::string_view>(const std::string_view& str, const std::string& val, OpType op) {
|
||||
Match<std::string_view>(const std::string_view& str,
|
||||
const std::string& val,
|
||||
OpType op) {
|
||||
switch (op) {
|
||||
case OpType::PrefixMatch:
|
||||
return PrefixMatch(str, val);
|
||||
|
|
|
@ -22,7 +22,9 @@ namespace milvus {
|
|||
namespace query_old {
|
||||
|
||||
BinaryQueryPtr
|
||||
ConstructBinTree(std::vector<BooleanQueryPtr> queries, QueryRelation relation, uint64_t idx) {
|
||||
ConstructBinTree(std::vector<BooleanQueryPtr> queries,
|
||||
QueryRelation relation,
|
||||
uint64_t idx) {
|
||||
if (idx == queries.size()) {
|
||||
return nullptr;
|
||||
} else if (idx == queries.size() - 1) {
|
||||
|
@ -40,7 +42,9 @@ ConstructBinTree(std::vector<BooleanQueryPtr> queries, QueryRelation relation, u
|
|||
}
|
||||
|
||||
Status
|
||||
ConstructLeafBinTree(std::vector<LeafQueryPtr> leaf_queries, BinaryQueryPtr binary_query, uint64_t idx) {
|
||||
ConstructLeafBinTree(std::vector<LeafQueryPtr> leaf_queries,
|
||||
BinaryQueryPtr binary_query,
|
||||
uint64_t idx) {
|
||||
if (idx == leaf_queries.size()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -58,23 +62,27 @@ ConstructLeafBinTree(std::vector<LeafQueryPtr> leaf_queries, BinaryQueryPtr bina
|
|||
binary_query->left_query->bin->relation = binary_query->relation;
|
||||
binary_query->right_query->leaf = leaf_queries[idx];
|
||||
++idx;
|
||||
return ConstructLeafBinTree(leaf_queries, binary_query->left_query->bin, idx);
|
||||
return ConstructLeafBinTree(
|
||||
leaf_queries, binary_query->left_query->bin, idx);
|
||||
}
|
||||
}
|
||||
|
||||
Status
|
||||
GenBinaryQuery(BooleanQueryPtr query, BinaryQueryPtr& binary_query) {
|
||||
if (query->getBooleanQueries().size() == 0) {
|
||||
if (binary_query->relation == QueryRelation::AND || binary_query->relation == QueryRelation::OR) {
|
||||
if (binary_query->relation == QueryRelation::AND ||
|
||||
binary_query->relation == QueryRelation::OR) {
|
||||
// Put VectorQuery to the end of leaf queries
|
||||
auto query_size = query->getLeafQueries().size();
|
||||
for (uint64_t i = 0; i < query_size; ++i) {
|
||||
if (query->getLeafQueries()[i]->vector_placeholder.size() > 0) {
|
||||
std::swap(query->getLeafQueries()[i], query->getLeafQueries()[0]);
|
||||
std::swap(query->getLeafQueries()[i],
|
||||
query->getLeafQueries()[0]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return ConstructLeafBinTree(query->getLeafQueries(), binary_query, 0);
|
||||
return ConstructLeafBinTree(
|
||||
query->getLeafQueries(), binary_query, 0);
|
||||
} else {
|
||||
switch (query->getOccur()) {
|
||||
case Occur::MUST: {
|
||||
|
@ -154,7 +162,8 @@ GenBinaryQuery(BooleanQueryPtr query, BinaryQueryPtr& binary_query) {
|
|||
|
||||
if (must_not_queries.size() > 1) {
|
||||
// Construct a must_not binary tree
|
||||
must_not_bquery = ConstructBinTree(must_not_queries, QueryRelation::R1, 0);
|
||||
must_not_bquery =
|
||||
ConstructBinTree(must_not_queries, QueryRelation::R1, 0);
|
||||
++bquery_num;
|
||||
} else if (must_not_queries.size() == 1) {
|
||||
must_not_bquery = must_not_queries[0]->getBinaryQuery();
|
||||
|
@ -228,7 +237,8 @@ BinaryQueryHeight(BinaryQueryPtr& binary_query) {
|
|||
*/
|
||||
|
||||
Status
|
||||
rule_1(BooleanQueryPtr& boolean_query, std::stack<BooleanQueryPtr>& path_stack) {
|
||||
rule_1(BooleanQueryPtr& boolean_query,
|
||||
std::stack<BooleanQueryPtr>& path_stack) {
|
||||
auto status = Status::OK();
|
||||
if (boolean_query != nullptr) {
|
||||
path_stack.push(boolean_query);
|
||||
|
@ -236,9 +246,11 @@ rule_1(BooleanQueryPtr& boolean_query, std::stack<BooleanQueryPtr>& path_stack)
|
|||
if (!leaf_query->vector_placeholder.empty()) {
|
||||
while (!path_stack.empty()) {
|
||||
auto query = path_stack.top();
|
||||
if (query->getOccur() == Occur::SHOULD || query->getOccur() == Occur::MUST_NOT) {
|
||||
if (query->getOccur() == Occur::SHOULD ||
|
||||
query->getOccur() == Occur::MUST_NOT) {
|
||||
std::string msg =
|
||||
"The child node of 'should' and 'must_not' can only be 'term query' and 'range query'.";
|
||||
"The child node of 'should' and 'must_not' can "
|
||||
"only be 'term query' and 'range query'.";
|
||||
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
|
||||
}
|
||||
path_stack.pop();
|
||||
|
@ -259,8 +271,10 @@ Status
|
|||
rule_2(BooleanQueryPtr& boolean_query) {
|
||||
auto status = Status::OK();
|
||||
if (boolean_query != nullptr) {
|
||||
if (!boolean_query->getBooleanQueries().empty() && !boolean_query->getLeafQueries().empty()) {
|
||||
std::string msg = "One layer cannot include bool query and leaf query.";
|
||||
if (!boolean_query->getBooleanQueries().empty() &&
|
||||
!boolean_query->getLeafQueries().empty()) {
|
||||
std::string msg =
|
||||
"One layer cannot include bool query and leaf query.";
|
||||
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
|
||||
} else {
|
||||
for (auto query : boolean_query->getBooleanQueries()) {
|
||||
|
@ -279,8 +293,11 @@ ValidateBooleanQuery(BooleanQueryPtr& boolean_query) {
|
|||
auto status = Status::OK();
|
||||
if (boolean_query != nullptr) {
|
||||
for (auto& query : boolean_query->getBooleanQueries()) {
|
||||
if (query->getOccur() == Occur::SHOULD || query->getOccur() == Occur::MUST_NOT) {
|
||||
std::string msg = "The direct child node of 'bool' node cannot be 'should' node or 'must_not' node.";
|
||||
if (query->getOccur() == Occur::SHOULD ||
|
||||
query->getOccur() == Occur::MUST_NOT) {
|
||||
std::string msg =
|
||||
"The direct child node of 'bool' node cannot be 'should' "
|
||||
"node or 'must_not' node.";
|
||||
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,10 +20,14 @@ namespace milvus {
|
|||
namespace query_old {
|
||||
|
||||
BinaryQueryPtr
|
||||
ConstructBinTree(std::vector<BooleanQueryPtr> clauses, QueryRelation relation, uint64_t idx);
|
||||
ConstructBinTree(std::vector<BooleanQueryPtr> clauses,
|
||||
QueryRelation relation,
|
||||
uint64_t idx);
|
||||
|
||||
Status
|
||||
ConstructLeafBinTree(std::vector<LeafQueryPtr> leaf_clauses, BinaryQueryPtr binary_query, uint64_t idx);
|
||||
ConstructLeafBinTree(std::vector<LeafQueryPtr> leaf_clauses,
|
||||
BinaryQueryPtr binary_query,
|
||||
uint64_t idx);
|
||||
|
||||
Status
|
||||
GenBinaryQuery(BooleanQueryPtr clause, BinaryQueryPtr& binary_query);
|
||||
|
|
|
@ -131,9 +131,10 @@ using QueryPtr = std::shared_ptr<Query>;
|
|||
|
||||
namespace query {
|
||||
struct QueryDeprecated {
|
||||
int64_t num_queries; //
|
||||
int topK; // topK of queries
|
||||
std::string field_name; // must be fakevec, whose data_type must be VEC_FLOAT(DIM)
|
||||
int64_t num_queries; //
|
||||
int topK; // topK of queries
|
||||
std::string
|
||||
field_name; // must be fakevec, whose data_type must be VEC_FLOAT(DIM)
|
||||
std::vector<float> query_raw_data; // must be size of num_queries * DIM
|
||||
};
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ CopyRowRecords(const RepeatedPtrField<proto::service::PlaceholderValue>& grpc_re
|
|||
memcpy(id_array.data(), grpc_id_array.data(), grpc_id_array.size() * sizeof(int64_t));
|
||||
}
|
||||
|
||||
// step 3: contruct vectors
|
||||
// step 3: construct vectors
|
||||
vectors.vector_count_ = grpc_records.size();
|
||||
vectors.float_data_.swap(float_array);
|
||||
vectors.binary_data_.swap(binary_array);
|
||||
|
@ -62,7 +62,9 @@ CopyRowRecords(const RepeatedPtrField<proto::service::PlaceholderValue>& grpc_re
|
|||
#endif
|
||||
|
||||
Status
|
||||
ProcessLeafQueryJson(const milvus::json& query_json, query_old::BooleanQueryPtr& query, std::string& field_name) {
|
||||
ProcessLeafQueryJson(const milvus::json& query_json,
|
||||
query_old::BooleanQueryPtr& query,
|
||||
std::string& field_name) {
|
||||
#if 1
|
||||
if (query_json.contains("term")) {
|
||||
auto leaf_query = std::make_shared<query_old::LeafQuery>();
|
||||
|
@ -120,12 +122,15 @@ ProcessBooleanQueryJson(const milvus::json& query_json,
|
|||
|
||||
for (auto& json : must_json) {
|
||||
auto must_query = std::make_shared<query_old::BooleanQuery>();
|
||||
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
|
||||
STATUS_CHECK(ProcessBooleanQueryJson(json, must_query, query_ptr));
|
||||
if (json.contains("must") || json.contains("should") ||
|
||||
json.contains("must_not")) {
|
||||
STATUS_CHECK(
|
||||
ProcessBooleanQueryJson(json, must_query, query_ptr));
|
||||
boolean_query->AddBooleanQuery(must_query);
|
||||
} else {
|
||||
std::string field_name;
|
||||
STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name));
|
||||
STATUS_CHECK(
|
||||
ProcessLeafQueryJson(json, boolean_query, field_name));
|
||||
if (!field_name.empty()) {
|
||||
query_ptr->index_fields.insert(field_name);
|
||||
}
|
||||
|
@ -141,12 +146,15 @@ ProcessBooleanQueryJson(const milvus::json& query_json,
|
|||
|
||||
for (auto& json : should_json) {
|
||||
auto should_query = std::make_shared<query_old::BooleanQuery>();
|
||||
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
|
||||
STATUS_CHECK(ProcessBooleanQueryJson(json, should_query, query_ptr));
|
||||
if (json.contains("must") || json.contains("should") ||
|
||||
json.contains("must_not")) {
|
||||
STATUS_CHECK(
|
||||
ProcessBooleanQueryJson(json, should_query, query_ptr));
|
||||
boolean_query->AddBooleanQuery(should_query);
|
||||
} else {
|
||||
std::string field_name;
|
||||
STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name));
|
||||
STATUS_CHECK(
|
||||
ProcessLeafQueryJson(json, boolean_query, field_name));
|
||||
if (!field_name.empty()) {
|
||||
query_ptr->index_fields.insert(field_name);
|
||||
}
|
||||
|
@ -161,20 +169,25 @@ ProcessBooleanQueryJson(const milvus::json& query_json,
|
|||
}
|
||||
|
||||
for (auto& json : should_json) {
|
||||
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
|
||||
auto must_not_query = std::make_shared<query_old::BooleanQuery>();
|
||||
STATUS_CHECK(ProcessBooleanQueryJson(json, must_not_query, query_ptr));
|
||||
if (json.contains("must") || json.contains("should") ||
|
||||
json.contains("must_not")) {
|
||||
auto must_not_query =
|
||||
std::make_shared<query_old::BooleanQuery>();
|
||||
STATUS_CHECK(ProcessBooleanQueryJson(
|
||||
json, must_not_query, query_ptr));
|
||||
boolean_query->AddBooleanQuery(must_not_query);
|
||||
} else {
|
||||
std::string field_name;
|
||||
STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name));
|
||||
STATUS_CHECK(
|
||||
ProcessLeafQueryJson(json, boolean_query, field_name));
|
||||
if (!field_name.empty()) {
|
||||
query_ptr->index_fields.insert(field_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::string msg = "BoolQuery json string does not include bool query";
|
||||
std::string msg =
|
||||
"BoolQuery json string does not include bool query";
|
||||
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
|
||||
}
|
||||
}
|
||||
|
@ -183,7 +196,8 @@ ProcessBooleanQueryJson(const milvus::json& query_json,
|
|||
}
|
||||
|
||||
Status
|
||||
DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorParam>& vector_params,
|
||||
DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<
|
||||
::milvus::grpc::VectorParam>& vector_params,
|
||||
const std::string& dsl_string,
|
||||
query_old::BooleanQueryPtr& boolean_query,
|
||||
query_old::QueryPtr& query_ptr) {
|
||||
|
@ -196,7 +210,8 @@ DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<::milvus::gr
|
|||
}
|
||||
auto status = Status::OK();
|
||||
if (vector_params.empty()) {
|
||||
return Status(SERVER_INVALID_DSL_PARAMETER, "DSL must include vector query");
|
||||
return Status(SERVER_INVALID_DSL_PARAMETER,
|
||||
"DSL must include vector query");
|
||||
}
|
||||
for (const auto& vector_param : vector_params) {
|
||||
const std::string& vector_string = vector_param.json();
|
||||
|
@ -216,32 +231,41 @@ DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<::milvus::gr
|
|||
if (param_json.contains("metric_type")) {
|
||||
std::string metric_type = param_json["metric_type"];
|
||||
vector_query->metric_type = metric_type;
|
||||
query_ptr->metric_types.insert({field_name, param_json["metric_type"]});
|
||||
query_ptr->metric_types.insert(
|
||||
{field_name, param_json["metric_type"]});
|
||||
}
|
||||
if (!vector_param_it.value()["params"].empty()) {
|
||||
vector_query->extra_params = vector_param_it.value()["params"];
|
||||
vector_query->extra_params =
|
||||
vector_param_it.value()["params"];
|
||||
}
|
||||
query_ptr->index_fields.insert(field_name);
|
||||
}
|
||||
|
||||
engine::VectorsData vector_data;
|
||||
CopyRowRecords(vector_param.row_record().records(),
|
||||
google::protobuf::RepeatedField<google::protobuf::int64>(), vector_data);
|
||||
CopyRowRecords(
|
||||
vector_param.row_record().records(),
|
||||
google::protobuf::RepeatedField<google::protobuf::int64>(),
|
||||
vector_data);
|
||||
vector_query->query_vector.vector_count = vector_data.vector_count_;
|
||||
vector_query->query_vector.binary_data.swap(vector_data.binary_data_);
|
||||
vector_query->query_vector.binary_data.swap(
|
||||
vector_data.binary_data_);
|
||||
vector_query->query_vector.float_data.swap(vector_data.float_data_);
|
||||
|
||||
query_ptr->vectors.insert(std::make_pair(placeholder, vector_query));
|
||||
query_ptr->vectors.insert(
|
||||
std::make_pair(placeholder, vector_query));
|
||||
}
|
||||
if (dsl_json.contains("bool")) {
|
||||
auto boolean_query_json = dsl_json["bool"];
|
||||
JSON_NULL_CHECK(boolean_query_json);
|
||||
status = ProcessBooleanQueryJson(boolean_query_json, boolean_query, query_ptr);
|
||||
status = ProcessBooleanQueryJson(
|
||||
boolean_query_json, boolean_query, query_ptr);
|
||||
if (!status.ok()) {
|
||||
return Status(SERVER_INVALID_DSL_PARAMETER, "DSL does not include bool");
|
||||
return Status(SERVER_INVALID_DSL_PARAMETER,
|
||||
"DSL does not include bool");
|
||||
}
|
||||
} else {
|
||||
return Status(SERVER_INVALID_DSL_PARAMETER, "DSL does not include bool query");
|
||||
return Status(SERVER_INVALID_DSL_PARAMETER,
|
||||
"DSL does not include bool query");
|
||||
}
|
||||
return Status::OK();
|
||||
} catch (std::exception& e) {
|
||||
|
@ -254,7 +278,8 @@ DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<::milvus::gr
|
|||
#endif
|
||||
query_old::QueryPtr
|
||||
Transformer(proto::service::Query* request) {
|
||||
query_old::BooleanQueryPtr boolean_query = std::make_shared<query_old::BooleanQuery>();
|
||||
query_old::BooleanQueryPtr boolean_query =
|
||||
std::make_shared<query_old::BooleanQuery>();
|
||||
query_old::QueryPtr query_ptr = std::make_shared<query_old::Query>();
|
||||
#if 0
|
||||
query_ptr->collection_id = request->collection_name();
|
||||
|
|
|
@ -43,8 +43,10 @@ CheckParameterRange(const milvus::json& json_params,
|
|||
bool min_err = min_close ? value < min : value <= min;
|
||||
bool max_err = max_closed ? value > max : value >= max;
|
||||
if (min_err || max_err) {
|
||||
std::string msg = "Invalid " + param_name + " value: " + std::to_string(value) + ". Valid range is " +
|
||||
(min_close ? "[" : "(") + std::to_string(min) + ", " + std::to_string(max) +
|
||||
std::string msg = "Invalid " + param_name +
|
||||
" value: " + std::to_string(value) +
|
||||
". Valid range is " + (min_close ? "[" : "(") +
|
||||
std::to_string(min) + ", " + std::to_string(max) +
|
||||
(max_closed ? "]" : ")");
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_ARGUMENT, msg);
|
||||
|
@ -60,7 +62,8 @@ CheckParameterRange(const milvus::json& json_params,
|
|||
}
|
||||
|
||||
Status
|
||||
CheckParameterExistence(const milvus::json& json_params, const std::string& param_name) {
|
||||
CheckParameterExistence(const milvus::json& json_params,
|
||||
const std::string& param_name) {
|
||||
if (json_params.find(param_name) == json_params.end()) {
|
||||
std::string msg = "Parameter list must contain: ";
|
||||
msg += param_name;
|
||||
|
@ -71,7 +74,8 @@ CheckParameterExistence(const milvus::json& json_params, const std::string& para
|
|||
try {
|
||||
int64_t value = json_params[param_name];
|
||||
if (value < 0) {
|
||||
std::string msg = "Invalid " + param_name + " value: " + std::to_string(value);
|
||||
std::string msg =
|
||||
"Invalid " + param_name + " value: " + std::to_string(value);
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_ARGUMENT, msg);
|
||||
}
|
||||
|
@ -96,11 +100,13 @@ ValidateCollectionName(const std::string& collection_name) {
|
|||
return Status(SERVER_INVALID_COLLECTION_NAME, msg);
|
||||
}
|
||||
|
||||
std::string invalid_msg = "Invalid collection name: " + collection_name + ". ";
|
||||
std::string invalid_msg =
|
||||
"Invalid collection name: " + collection_name + ". ";
|
||||
// Collection name size shouldn't exceed engine::MAX_NAME_LENGTH.
|
||||
if (collection_name.size() > engine::MAX_NAME_LENGTH) {
|
||||
std::string msg = invalid_msg + "The length of a collection name must be less than " +
|
||||
std::to_string(engine::MAX_NAME_LENGTH) + " characters.";
|
||||
std::string msg =
|
||||
invalid_msg + "The length of a collection name must be less than " +
|
||||
std::to_string(engine::MAX_NAME_LENGTH) + " characters.";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_COLLECTION_NAME, msg);
|
||||
}
|
||||
|
@ -108,7 +114,9 @@ ValidateCollectionName(const std::string& collection_name) {
|
|||
// Collection name first character should be underscore or character.
|
||||
char first_char = collection_name[0];
|
||||
if (first_char != '_' && std::isalpha(first_char) == 0) {
|
||||
std::string msg = invalid_msg + "The first character of a collection name must be an underscore or letter.";
|
||||
std::string msg = invalid_msg +
|
||||
"The first character of a collection name must be an "
|
||||
"underscore or letter.";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_COLLECTION_NAME, msg);
|
||||
}
|
||||
|
@ -116,8 +124,11 @@ ValidateCollectionName(const std::string& collection_name) {
|
|||
int64_t table_name_size = collection_name.size();
|
||||
for (int64_t i = 1; i < table_name_size; ++i) {
|
||||
char name_char = collection_name[i];
|
||||
if (name_char != '_' && name_char != '$' && std::isalnum(name_char) == 0) {
|
||||
std::string msg = invalid_msg + "Collection name can only contain numbers, letters, and underscores.";
|
||||
if (name_char != '_' && name_char != '$' &&
|
||||
std::isalnum(name_char) == 0) {
|
||||
std::string msg = invalid_msg +
|
||||
"Collection name can only contain numbers, "
|
||||
"letters, and underscores.";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_COLLECTION_NAME, msg);
|
||||
}
|
||||
|
@ -138,8 +149,9 @@ ValidateFieldName(const std::string& field_name) {
|
|||
std::string invalid_msg = "Invalid field name: " + field_name + ". ";
|
||||
// Field name size shouldn't exceed engine::MAX_NAME_LENGTH.
|
||||
if (field_name.size() > engine::MAX_NAME_LENGTH) {
|
||||
std::string msg = invalid_msg + "The length of a field name must be less than " +
|
||||
std::to_string(engine::MAX_NAME_LENGTH) + " characters.";
|
||||
std::string msg =
|
||||
invalid_msg + "The length of a field name must be less than " +
|
||||
std::to_string(engine::MAX_NAME_LENGTH) + " characters.";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_FIELD_NAME, msg);
|
||||
}
|
||||
|
@ -147,7 +159,9 @@ ValidateFieldName(const std::string& field_name) {
|
|||
// Field name first character should be underscore or character.
|
||||
char first_char = field_name[0];
|
||||
if (first_char != '_' && std::isalpha(first_char) == 0) {
|
||||
std::string msg = invalid_msg + "The first character of a field name must be an underscore or letter.";
|
||||
std::string msg = invalid_msg +
|
||||
"The first character of a field name must be an "
|
||||
"underscore or letter.";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_FIELD_NAME, msg);
|
||||
}
|
||||
|
@ -156,7 +170,9 @@ ValidateFieldName(const std::string& field_name) {
|
|||
for (int64_t i = 1; i < field_name_size; ++i) {
|
||||
char name_char = field_name[i];
|
||||
if (name_char != '_' && std::isalnum(name_char) == 0) {
|
||||
std::string msg = invalid_msg + "Field name cannot only contain numbers, letters, and underscores.";
|
||||
std::string msg = invalid_msg +
|
||||
"Field name cannot only contain numbers, "
|
||||
"letters, and underscores.";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_FIELD_NAME, msg);
|
||||
}
|
||||
|
@ -175,7 +191,8 @@ ValidateVectorIndexType(std::string& index_type, bool is_binary) {
|
|||
}
|
||||
|
||||
// string case insensitive
|
||||
std::transform(index_type.begin(), index_type.end(), index_type.begin(), ::toupper);
|
||||
std::transform(
|
||||
index_type.begin(), index_type.end(), index_type.begin(), ::toupper);
|
||||
|
||||
static std::set<std::string> s_vector_index_type = {
|
||||
knowhere::IndexEnum::INVALID,
|
||||
|
@ -192,7 +209,8 @@ ValidateVectorIndexType(std::string& index_type, bool is_binary) {
|
|||
knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT,
|
||||
};
|
||||
|
||||
std::set<std::string>& index_types = is_binary ? s_binary_index_types : s_vector_index_type;
|
||||
std::set<std::string>& index_types =
|
||||
is_binary ? s_binary_index_types : s_vector_index_type;
|
||||
if (index_types.find(index_type) == index_types.end()) {
|
||||
std::string msg = "Invalid index type: " + index_type;
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
|
@ -212,7 +230,8 @@ ValidateStructuredIndexType(std::string& index_type) {
|
|||
}
|
||||
|
||||
// string case insensitive
|
||||
std::transform(index_type.begin(), index_type.end(), index_type.begin(), ::toupper);
|
||||
std::transform(
|
||||
index_type.begin(), index_type.end(), index_type.begin(), ::toupper);
|
||||
|
||||
static std::set<std::string> s_index_types = {
|
||||
engine::DEFAULT_STRUCTURED_INDEX,
|
||||
|
@ -230,14 +249,16 @@ ValidateStructuredIndexType(std::string& index_type) {
|
|||
Status
|
||||
ValidateDimension(int64_t dim, bool is_binary) {
|
||||
if (dim <= 0 || dim > engine::MAX_DIMENSION) {
|
||||
std::string msg = "Invalid dimension: " + std::to_string(dim) + ". Should be in range 1 ~ " +
|
||||
std::string msg = "Invalid dimension: " + std::to_string(dim) +
|
||||
". Should be in range 1 ~ " +
|
||||
std::to_string(engine::MAX_DIMENSION) + ".";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_VECTOR_DIMENSION, msg);
|
||||
}
|
||||
|
||||
if (is_binary && (dim % 8) != 0) {
|
||||
std::string msg = "Invalid dimension: " + std::to_string(dim) + ". Should be multiple of 8.";
|
||||
std::string msg = "Invalid dimension: " + std::to_string(dim) +
|
||||
". Should be multiple of 8.";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_VECTOR_DIMENSION, msg);
|
||||
}
|
||||
|
@ -246,30 +267,36 @@ ValidateDimension(int64_t dim, bool is_binary) {
|
|||
}
|
||||
|
||||
Status
|
||||
ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const std::string& index_type) {
|
||||
ValidateIndexParams(const milvus::json& index_params,
|
||||
int64_t dimension,
|
||||
const std::string& index_type) {
|
||||
if (engine::utils::IsFlatIndexType(index_type)) {
|
||||
return Status::OK();
|
||||
} else if (index_type == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT ||
|
||||
index_type == knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 ||
|
||||
index_type == knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT) {
|
||||
auto status = CheckParameterRange(index_params, knowhere::IndexParams::nlist, 1, 65536);
|
||||
auto status = CheckParameterRange(
|
||||
index_params, knowhere::IndexParams::nlist, 1, 65536);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
} else if (index_type == knowhere::IndexEnum::INDEX_FAISS_IVFPQ) {
|
||||
auto status = CheckParameterRange(index_params, knowhere::IndexParams::nlist, 1, 65536);
|
||||
auto status = CheckParameterRange(
|
||||
index_params, knowhere::IndexParams::nlist, 1, 65536);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = CheckParameterExistence(index_params, knowhere::IndexParams::m);
|
||||
status =
|
||||
CheckParameterExistence(index_params, knowhere::IndexParams::m);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// special check for 'm' parameter
|
||||
int64_t m_value = index_params[knowhere::IndexParams::m];
|
||||
if (!milvus::knowhere::IVFPQConfAdapter::GetValidCPUM(dimension, m_value)) {
|
||||
if (!milvus::knowhere::IVFPQConfAdapter::GetValidCPUM(dimension,
|
||||
m_value)) {
|
||||
std::string msg = "Invalid m, dimension can't not be divided by m ";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_ARGUMENT, msg);
|
||||
|
@ -298,16 +325,19 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s
|
|||
return Status(SERVER_INVALID_ARGUMENT, msg);
|
||||
}*/
|
||||
} else if (index_type == knowhere::IndexEnum::INDEX_HNSW) {
|
||||
auto status = CheckParameterRange(index_params, knowhere::IndexParams::M, 4, 64);
|
||||
auto status =
|
||||
CheckParameterRange(index_params, knowhere::IndexParams::M, 4, 64);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
status = CheckParameterRange(index_params, knowhere::IndexParams::efConstruction, 8, 512);
|
||||
status = CheckParameterRange(
|
||||
index_params, knowhere::IndexParams::efConstruction, 8, 512);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
} else if (index_type == knowhere::IndexEnum::INDEX_ANNOY) {
|
||||
auto status = CheckParameterRange(index_params, knowhere::IndexParams::n_trees, 1, 1024);
|
||||
auto status = CheckParameterRange(
|
||||
index_params, knowhere::IndexParams::n_trees, 1, 1024);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
@ -321,8 +351,10 @@ ValidateSegmentRowCount(int64_t segment_row_count) {
|
|||
int64_t min = config.engine.build_index_threshold();
|
||||
int max = engine::MAX_SEGMENT_ROW_COUNT;
|
||||
if (segment_row_count < min || segment_row_count > max) {
|
||||
std::string msg = "Invalid segment row count: " + std::to_string(segment_row_count) + ". " +
|
||||
"Should be in range " + std::to_string(min) + " ~ " + std::to_string(max) + ".";
|
||||
std::string msg =
|
||||
"Invalid segment row count: " + std::to_string(segment_row_count) +
|
||||
". " + "Should be in range " + std::to_string(min) + " ~ " +
|
||||
std::to_string(max) + ".";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_SEGMENT_ROW_COUNT, msg);
|
||||
}
|
||||
|
@ -330,21 +362,26 @@ ValidateSegmentRowCount(int64_t segment_row_count) {
|
|||
}
|
||||
|
||||
Status
|
||||
ValidateIndexMetricType(const std::string& metric_type, const std::string& index_type) {
|
||||
ValidateIndexMetricType(const std::string& metric_type,
|
||||
const std::string& index_type) {
|
||||
if (engine::utils::IsFlatIndexType(index_type)) {
|
||||
// pass
|
||||
} else if (index_type == knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT) {
|
||||
// binary
|
||||
if (metric_type != knowhere::Metric::HAMMING && metric_type != knowhere::Metric::JACCARD &&
|
||||
if (metric_type != knowhere::Metric::HAMMING &&
|
||||
metric_type != knowhere::Metric::JACCARD &&
|
||||
metric_type != knowhere::Metric::TANIMOTO) {
|
||||
std::string msg = "Index metric type " + metric_type + " does not match index type " + index_type;
|
||||
std::string msg = "Index metric type " + metric_type +
|
||||
" does not match index type " + index_type;
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_ARGUMENT, msg);
|
||||
}
|
||||
} else {
|
||||
// float
|
||||
if (metric_type != knowhere::Metric::L2 && metric_type != knowhere::Metric::IP) {
|
||||
std::string msg = "Index metric type " + metric_type + " does not match index type " + index_type;
|
||||
if (metric_type != knowhere::Metric::L2 &&
|
||||
metric_type != knowhere::Metric::IP) {
|
||||
std::string msg = "Index metric type " + metric_type +
|
||||
" does not match index type " + index_type;
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_ARGUMENT, msg);
|
||||
}
|
||||
|
@ -357,16 +394,22 @@ Status
|
|||
ValidateSearchMetricType(const std::string& metric_type, bool is_binary) {
|
||||
if (is_binary) {
|
||||
// binary
|
||||
if (metric_type == knowhere::Metric::L2 || metric_type == knowhere::Metric::IP) {
|
||||
std::string msg = "Cannot search binary entities with index metric type " + metric_type;
|
||||
if (metric_type == knowhere::Metric::L2 ||
|
||||
metric_type == knowhere::Metric::IP) {
|
||||
std::string msg =
|
||||
"Cannot search binary entities with index metric type " +
|
||||
metric_type;
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_ARGUMENT, msg);
|
||||
}
|
||||
} else {
|
||||
// float
|
||||
if (metric_type == knowhere::Metric::HAMMING || metric_type == knowhere::Metric::JACCARD ||
|
||||
if (metric_type == knowhere::Metric::HAMMING ||
|
||||
metric_type == knowhere::Metric::JACCARD ||
|
||||
metric_type == knowhere::Metric::TANIMOTO) {
|
||||
std::string msg = "Cannot search float entities with index metric type " + metric_type;
|
||||
std::string msg =
|
||||
"Cannot search float entities with index metric type " +
|
||||
metric_type;
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_ARGUMENT, msg);
|
||||
}
|
||||
|
@ -378,8 +421,8 @@ ValidateSearchMetricType(const std::string& metric_type, bool is_binary) {
|
|||
Status
|
||||
ValidateSearchTopk(int64_t top_k) {
|
||||
if (top_k <= 0 || top_k > QUERY_MAX_TOPK) {
|
||||
std::string msg =
|
||||
"Invalid topk: " + std::to_string(top_k) + ". " + "The topk must be within the range of 1 ~ 16384.";
|
||||
std::string msg = "Invalid topk: " + std::to_string(top_k) + ". " +
|
||||
"The topk must be within the range of 1 ~ 16384.";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_TOPK, msg);
|
||||
}
|
||||
|
@ -400,7 +443,9 @@ ValidatePartitionTags(const std::vector<std::string>& partition_tags) {
|
|||
std::string invalid_msg = "Invalid partition tag: " + tag + ". ";
|
||||
// Partition tag size shouldn't exceed 255.
|
||||
if (tag.size() > engine::MAX_NAME_LENGTH) {
|
||||
std::string msg = invalid_msg + "The length of a partition tag must be less than 255 characters.";
|
||||
std::string msg = invalid_msg +
|
||||
"The length of a partition tag must be less than "
|
||||
"255 characters.";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_PARTITION_TAG, msg);
|
||||
}
|
||||
|
@ -408,7 +453,9 @@ ValidatePartitionTags(const std::vector<std::string>& partition_tags) {
|
|||
// Partition tag first character should be underscore or character.
|
||||
char first_char = tag[0];
|
||||
if (first_char != '_' && std::isalnum(first_char) == 0) {
|
||||
std::string msg = invalid_msg + "The first character of a partition tag must be an underscore or letter.";
|
||||
std::string msg = invalid_msg +
|
||||
"The first character of a partition tag must be "
|
||||
"an underscore or letter.";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_PARTITION_TAG, msg);
|
||||
}
|
||||
|
@ -416,8 +463,11 @@ ValidatePartitionTags(const std::vector<std::string>& partition_tags) {
|
|||
int64_t tag_size = tag.size();
|
||||
for (int64_t i = 1; i < tag_size; ++i) {
|
||||
char name_char = tag[i];
|
||||
if (name_char != '_' && name_char != '$' && std::isalnum(name_char) == 0) {
|
||||
std::string msg = invalid_msg + "Partition tag can only contain numbers, letters, and underscores.";
|
||||
if (name_char != '_' && name_char != '$' &&
|
||||
std::isalnum(name_char) == 0) {
|
||||
std::string msg = invalid_msg +
|
||||
"Partition tag can only contain numbers, "
|
||||
"letters, and underscores.";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_PARTITION_TAG, msg);
|
||||
}
|
||||
|
@ -457,8 +507,9 @@ ValidateInsertDataSize(const InsertParam& insert_param) {
|
|||
}
|
||||
|
||||
if (chunk_size > engine::MAX_INSERT_DATA_SIZE) {
|
||||
std::string msg = "The amount of data inserted each time cannot exceed " +
|
||||
std::to_string(engine::MAX_INSERT_DATA_SIZE / engine::MB) + " MB";
|
||||
std::string msg =
|
||||
"The amount of data inserted each time cannot exceed " +
|
||||
std::to_string(engine::MAX_INSERT_DATA_SIZE / engine::MB) + " MB";
|
||||
return Status(SERVER_INVALID_ROWRECORD_ARRAY, msg);
|
||||
}
|
||||
|
||||
|
@ -468,7 +519,9 @@ ValidateInsertDataSize(const InsertParam& insert_param) {
|
|||
Status
|
||||
ValidateCompactThreshold(double threshold) {
|
||||
if (threshold > 1.0 || threshold < 0.0) {
|
||||
std::string msg = "Invalid compact threshold: " + std::to_string(threshold) + ". Should be in range [0.0, 1.0]";
|
||||
std::string msg =
|
||||
"Invalid compact threshold: " + std::to_string(threshold) +
|
||||
". Should be in range [0.0, 1.0]";
|
||||
return Status(SERVER_INVALID_ROWRECORD_ARRAY, msg);
|
||||
}
|
||||
|
||||
|
|
|
@ -42,13 +42,16 @@ extern Status
|
|||
ValidateStructuredIndexType(std::string& index_type);
|
||||
|
||||
extern Status
|
||||
ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const std::string& index_type);
|
||||
ValidateIndexParams(const milvus::json& index_params,
|
||||
int64_t dimension,
|
||||
const std::string& index_type);
|
||||
|
||||
extern Status
|
||||
ValidateSegmentRowCount(int64_t segment_row_count);
|
||||
|
||||
extern Status
|
||||
ValidateIndexMetricType(const std::string& metric_type, const std::string& index_type);
|
||||
ValidateIndexMetricType(const std::string& metric_type,
|
||||
const std::string& index_type);
|
||||
|
||||
extern Status
|
||||
ValidateSearchMetricType(const std::string& metric_type, bool is_binary);
|
||||
|
|
|
@ -45,7 +45,9 @@ class ExecExprVisitor : public ExprVisitor {
|
|||
visit(CompareExpr& expr) override;
|
||||
|
||||
public:
|
||||
ExecExprVisitor(const segcore::SegmentInternalInterface& segment, int64_t row_count, Timestamp timestamp)
|
||||
ExecExprVisitor(const segcore::SegmentInternalInterface& segment,
|
||||
int64_t row_count,
|
||||
Timestamp timestamp)
|
||||
: segment_(segment), row_count_(row_count), timestamp_(timestamp) {
|
||||
}
|
||||
|
||||
|
@ -62,11 +64,15 @@ class ExecExprVisitor : public ExprVisitor {
|
|||
public:
|
||||
template <typename T, typename IndexFunc, typename ElementFunc>
|
||||
auto
|
||||
ExecRangeVisitorImpl(FieldId field_id, IndexFunc func, ElementFunc element_func) -> BitsetType;
|
||||
ExecRangeVisitorImpl(FieldId field_id,
|
||||
IndexFunc func,
|
||||
ElementFunc element_func) -> BitsetType;
|
||||
|
||||
template <typename T, typename IndexFunc, typename ElementFunc>
|
||||
auto
|
||||
ExecDataRangeVisitorImpl(FieldId field_id, IndexFunc index_func, ElementFunc element_func) -> BitsetType;
|
||||
ExecDataRangeVisitorImpl(FieldId field_id,
|
||||
IndexFunc index_func,
|
||||
ElementFunc element_func) -> BitsetType;
|
||||
|
||||
template <typename T>
|
||||
auto
|
||||
|
@ -74,7 +80,8 @@ class ExecExprVisitor : public ExprVisitor {
|
|||
|
||||
template <typename T>
|
||||
auto
|
||||
ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType;
|
||||
ExecBinaryArithOpEvalRangeVisitorDispatcher(
|
||||
BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType;
|
||||
|
||||
template <typename T>
|
||||
auto
|
||||
|
@ -90,7 +97,8 @@ class ExecExprVisitor : public ExprVisitor {
|
|||
|
||||
template <typename CmpFunc>
|
||||
auto
|
||||
ExecCompareExprDispatcher(CompareExpr& expr, CmpFunc cmp_func) -> BitsetType;
|
||||
ExecCompareExprDispatcher(CompareExpr& expr, CmpFunc cmp_func)
|
||||
-> BitsetType;
|
||||
|
||||
private:
|
||||
const segcore::SegmentInternalInterface& segment_;
|
||||
|
|
|
@ -34,10 +34,13 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor {
|
|||
ExecPlanNodeVisitor(const segcore::SegmentInterface& segment,
|
||||
Timestamp timestamp,
|
||||
const PlaceholderGroup* placeholder_group)
|
||||
: segment_(segment), timestamp_(timestamp), placeholder_group_(placeholder_group) {
|
||||
: segment_(segment),
|
||||
timestamp_(timestamp),
|
||||
placeholder_group_(placeholder_group) {
|
||||
}
|
||||
|
||||
ExecPlanNodeVisitor(const segcore::SegmentInterface& segment, Timestamp timestamp)
|
||||
ExecPlanNodeVisitor(const segcore::SegmentInterface& segment,
|
||||
Timestamp timestamp)
|
||||
: segment_(segment), timestamp_(timestamp) {
|
||||
placeholder_group_ = nullptr;
|
||||
}
|
||||
|
|
|
@ -40,7 +40,8 @@ class ExtractInfoExprVisitor : public ExprVisitor {
|
|||
visit(CompareExpr& expr) override;
|
||||
|
||||
public:
|
||||
explicit ExtractInfoExprVisitor(ExtractedPlanInfo& plan_info) : plan_info_(plan_info) {
|
||||
explicit ExtractInfoExprVisitor(ExtractedPlanInfo& plan_info)
|
||||
: plan_info_(plan_info) {
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -28,7 +28,8 @@ class ExtractInfoPlanNodeVisitor : public PlanNodeVisitor {
|
|||
visit(RetrievePlanNode& node) override;
|
||||
|
||||
public:
|
||||
explicit ExtractInfoPlanNodeVisitor(ExtractedPlanInfo& plan_info) : plan_info_(plan_info) {
|
||||
explicit ExtractInfoPlanNodeVisitor(ExtractedPlanInfo& plan_info)
|
||||
: plan_info_(plan_info) {
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -28,15 +28,19 @@ namespace milvus::query {
|
|||
namespace impl {
|
||||
class ExecExprVisitor : ExprVisitor {
|
||||
public:
|
||||
ExecExprVisitor(const segcore::SegmentInternalInterface& segment, int64_t row_count, Timestamp timestamp)
|
||||
ExecExprVisitor(const segcore::SegmentInternalInterface& segment,
|
||||
int64_t row_count,
|
||||
Timestamp timestamp)
|
||||
: segment_(segment), row_count_(row_count), timestamp_(timestamp) {
|
||||
}
|
||||
|
||||
BitsetType
|
||||
call_child(Expr& expr) {
|
||||
AssertInfo(!bitset_opt_.has_value(), "[ExecExprVisitor]Bitset already has value before accept");
|
||||
AssertInfo(!bitset_opt_.has_value(),
|
||||
"[ExecExprVisitor]Bitset already has value before accept");
|
||||
expr.accept(*this);
|
||||
AssertInfo(bitset_opt_.has_value(), "[ExecExprVisitor]Bitset doesn't have value after accept");
|
||||
AssertInfo(bitset_opt_.has_value(),
|
||||
"[ExecExprVisitor]Bitset doesn't have value after accept");
|
||||
auto res = std::move(bitset_opt_);
|
||||
bitset_opt_ = std::nullopt;
|
||||
return std::move(res.value());
|
||||
|
@ -45,7 +49,9 @@ class ExecExprVisitor : ExprVisitor {
|
|||
public:
|
||||
template <typename T, typename IndexFunc, typename ElementFunc>
|
||||
auto
|
||||
ExecRangeVisitorImpl(FieldId field_id, IndexFunc func, ElementFunc element_func) -> BitsetType;
|
||||
ExecRangeVisitorImpl(FieldId field_id,
|
||||
IndexFunc func,
|
||||
ElementFunc element_func) -> BitsetType;
|
||||
|
||||
template <typename T>
|
||||
auto
|
||||
|
@ -53,7 +59,8 @@ class ExecExprVisitor : ExprVisitor {
|
|||
|
||||
template <typename T>
|
||||
auto
|
||||
ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType;
|
||||
ExecBinaryArithOpEvalRangeVisitorDispatcher(
|
||||
BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType;
|
||||
|
||||
template <typename T>
|
||||
auto
|
||||
|
@ -69,7 +76,8 @@ class ExecExprVisitor : ExprVisitor {
|
|||
|
||||
template <typename CmpFunc>
|
||||
auto
|
||||
ExecCompareExprDispatcher(CompareExpr& expr, CmpFunc cmp_func) -> BitsetType;
|
||||
ExecCompareExprDispatcher(CompareExpr& expr, CmpFunc cmp_func)
|
||||
-> BitsetType;
|
||||
|
||||
private:
|
||||
const segcore::SegmentInternalInterface& segment_;
|
||||
|
@ -93,7 +101,8 @@ ExecExprVisitor::visit(LogicalUnaryExpr& expr) {
|
|||
PanicInfo("Invalid Unary Op");
|
||||
}
|
||||
}
|
||||
AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
|
||||
AssertInfo(res.size() == row_count_,
|
||||
"[ExecExprVisitor]Size of results not equal row count");
|
||||
bitset_opt_ = std::move(res);
|
||||
}
|
||||
|
||||
|
@ -102,7 +111,8 @@ ExecExprVisitor::visit(LogicalBinaryExpr& expr) {
|
|||
using OpType = LogicalBinaryExpr::OpType;
|
||||
auto left = call_child(*expr.left_);
|
||||
auto right = call_child(*expr.right_);
|
||||
AssertInfo(left.size() == right.size(), "[ExecExprVisitor]Left size not equal to right size");
|
||||
AssertInfo(left.size() == right.size(),
|
||||
"[ExecExprVisitor]Left size not equal to right size");
|
||||
auto res = std::move(left);
|
||||
switch (expr.op_type_) {
|
||||
case OpType::LogicalAnd: {
|
||||
|
@ -125,7 +135,8 @@ ExecExprVisitor::visit(LogicalBinaryExpr& expr) {
|
|||
PanicInfo("Invalid Binary Op");
|
||||
}
|
||||
}
|
||||
AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
|
||||
AssertInfo(res.size() == row_count_,
|
||||
"[ExecExprVisitor]Size of results not equal row count");
|
||||
bitset_opt_ = std::move(res);
|
||||
}
|
||||
|
||||
|
@ -151,7 +162,9 @@ Assemble(const std::deque<BitsetType>& srcs) -> BitsetType {
|
|||
|
||||
template <typename T, typename IndexFunc, typename ElementFunc>
|
||||
auto
|
||||
ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id, IndexFunc index_func, ElementFunc element_func) -> BitsetType {
|
||||
ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id,
|
||||
IndexFunc index_func,
|
||||
ElementFunc element_func) -> BitsetType {
|
||||
auto& schema = segment_.get_schema();
|
||||
auto& field_meta = schema[field_id];
|
||||
auto indexing_barrier = segment_.num_chunk_index(field_id);
|
||||
|
@ -159,18 +172,24 @@ ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id, IndexFunc index_func, El
|
|||
auto num_chunk = upper_div(row_count_, size_per_chunk);
|
||||
std::deque<BitsetType> results;
|
||||
|
||||
typedef std::conditional_t<std::is_same_v<T, std::string_view>, std::string, T> IndexInnerType;
|
||||
typedef std::
|
||||
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
|
||||
IndexInnerType;
|
||||
using Index = index::ScalarIndex<IndexInnerType>;
|
||||
for (auto chunk_id = 0; chunk_id < indexing_barrier; ++chunk_id) {
|
||||
const Index& indexing = segment_.chunk_scalar_index<IndexInnerType>(field_id, chunk_id);
|
||||
const Index& indexing =
|
||||
segment_.chunk_scalar_index<IndexInnerType>(field_id, chunk_id);
|
||||
// NOTE: knowhere is not const-ready
|
||||
// This is a dirty workaround
|
||||
auto data = index_func(const_cast<Index*>(&indexing));
|
||||
AssertInfo(data->size() == size_per_chunk, "[ExecExprVisitor]Data size not equal to size_per_chunk");
|
||||
AssertInfo(data->size() == size_per_chunk,
|
||||
"[ExecExprVisitor]Data size not equal to size_per_chunk");
|
||||
results.emplace_back(std::move(*data));
|
||||
}
|
||||
for (auto chunk_id = indexing_barrier; chunk_id < num_chunk; ++chunk_id) {
|
||||
auto this_size = chunk_id == num_chunk - 1 ? row_count_ - chunk_id * size_per_chunk : size_per_chunk;
|
||||
auto this_size = chunk_id == num_chunk - 1
|
||||
? row_count_ - chunk_id * size_per_chunk
|
||||
: size_per_chunk;
|
||||
BitsetType result(this_size);
|
||||
auto chunk = segment_.chunk_data<T>(field_id, chunk_id);
|
||||
const T* data = chunk.data();
|
||||
|
@ -182,13 +201,16 @@ ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id, IndexFunc index_func, El
|
|||
results.emplace_back(std::move(result));
|
||||
}
|
||||
auto final_result = Assemble(results);
|
||||
AssertInfo(final_result.size() == row_count_, "[ExecExprVisitor]Final result size not equal to row count");
|
||||
AssertInfo(final_result.size() == row_count_,
|
||||
"[ExecExprVisitor]Final result size not equal to row count");
|
||||
return final_result;
|
||||
}
|
||||
|
||||
template <typename T, typename IndexFunc, typename ElementFunc>
|
||||
auto
|
||||
ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id, IndexFunc index_func, ElementFunc element_func)
|
||||
ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id,
|
||||
IndexFunc index_func,
|
||||
ElementFunc element_func)
|
||||
-> BitsetType {
|
||||
auto& schema = segment_.get_schema();
|
||||
auto& field_meta = schema[field_id];
|
||||
|
@ -205,23 +227,31 @@ ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id, IndexFunc index_func
|
|||
// if sealed segment has loaded raw data on this field, then index_barrier = 0 and data_barrier = 1
|
||||
// in this case, sealed segment execute expr plan using raw data
|
||||
for (auto chunk_id = 0; chunk_id < data_barrier; ++chunk_id) {
|
||||
auto this_size = chunk_id == num_chunk - 1 ? row_count_ - chunk_id * size_per_chunk : size_per_chunk;
|
||||
auto this_size = chunk_id == num_chunk - 1
|
||||
? row_count_ - chunk_id * size_per_chunk
|
||||
: size_per_chunk;
|
||||
BitsetType result(this_size);
|
||||
auto chunk = segment_.chunk_data<T>(field_id, chunk_id);
|
||||
const T* data = chunk.data();
|
||||
for (int index = 0; index < this_size; ++index) {
|
||||
result[index] = element_func(data[index]);
|
||||
}
|
||||
AssertInfo(result.size() == this_size, "[ExecExprVisitor]Chunk result size not equal to expected size");
|
||||
AssertInfo(
|
||||
result.size() == this_size,
|
||||
"[ExecExprVisitor]Chunk result size not equal to expected size");
|
||||
results.emplace_back(std::move(result));
|
||||
}
|
||||
|
||||
// if sealed segment has loaded scalar index for this field, then index_barrier = 1 and data_barrier = 0
|
||||
// in this case, sealed segment execute expr plan using scalar index
|
||||
typedef std::conditional_t<std::is_same_v<T, std::string_view>, std::string, T> IndexInnerType;
|
||||
typedef std::
|
||||
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
|
||||
IndexInnerType;
|
||||
using Index = index::ScalarIndex<IndexInnerType>;
|
||||
for (auto chunk_id = data_barrier; chunk_id < indexing_barrier; ++chunk_id) {
|
||||
auto& indexing = segment_.chunk_scalar_index<IndexInnerType>(field_id, chunk_id);
|
||||
for (auto chunk_id = data_barrier; chunk_id < indexing_barrier;
|
||||
++chunk_id) {
|
||||
auto& indexing =
|
||||
segment_.chunk_scalar_index<IndexInnerType>(field_id, chunk_id);
|
||||
auto this_size = const_cast<Index*>(&indexing)->Count();
|
||||
BitsetType result(this_size);
|
||||
for (int offset = 0; offset < this_size; ++offset) {
|
||||
|
@ -231,7 +261,8 @@ ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id, IndexFunc index_func
|
|||
}
|
||||
|
||||
auto final_result = Assemble(results);
|
||||
AssertInfo(final_result.size() == row_count_, "[ExecExprVisitor]Final result size not equal to row count");
|
||||
AssertInfo(final_result.size() == row_count_,
|
||||
"[ExecExprVisitor]Final result size not equal to row count");
|
||||
return final_result;
|
||||
}
|
||||
|
||||
|
@ -239,8 +270,11 @@ ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id, IndexFunc index_func
|
|||
#pragma ide diagnostic ignored "Simplify"
|
||||
template <typename T>
|
||||
auto
|
||||
ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> BitsetType {
|
||||
typedef std::conditional_t<std::is_same_v<T, std::string_view>, std::string, T> IndexInnerType;
|
||||
ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw)
|
||||
-> BitsetType {
|
||||
typedef std::
|
||||
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
|
||||
IndexInnerType;
|
||||
using Index = index::ScalarIndex<IndexInnerType>;
|
||||
auto& expr = static_cast<UnaryRangeExprImpl<IndexInnerType>&>(expr_raw);
|
||||
|
||||
|
@ -248,34 +282,52 @@ ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> Bi
|
|||
auto val = IndexInnerType(expr.value_);
|
||||
switch (op) {
|
||||
case OpType::Equal: {
|
||||
auto index_func = [val](Index* index) { return index->In(1, &val); };
|
||||
auto index_func = [val](Index* index) {
|
||||
return index->In(1, &val);
|
||||
};
|
||||
auto elem_func = [val](T x) { return (x == val); };
|
||||
return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
return ExecRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
case OpType::NotEqual: {
|
||||
auto index_func = [val](Index* index) { return index->NotIn(1, &val); };
|
||||
auto index_func = [val](Index* index) {
|
||||
return index->NotIn(1, &val);
|
||||
};
|
||||
auto elem_func = [val](T x) { return (x != val); };
|
||||
return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
return ExecRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
case OpType::GreaterEqual: {
|
||||
auto index_func = [val](Index* index) { return index->Range(val, OpType::GreaterEqual); };
|
||||
auto index_func = [val](Index* index) {
|
||||
return index->Range(val, OpType::GreaterEqual);
|
||||
};
|
||||
auto elem_func = [val](T x) { return (x >= val); };
|
||||
return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
return ExecRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
case OpType::GreaterThan: {
|
||||
auto index_func = [val](Index* index) { return index->Range(val, OpType::GreaterThan); };
|
||||
auto index_func = [val](Index* index) {
|
||||
return index->Range(val, OpType::GreaterThan);
|
||||
};
|
||||
auto elem_func = [val](T x) { return (x > val); };
|
||||
return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
return ExecRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
case OpType::LessEqual: {
|
||||
auto index_func = [val](Index* index) { return index->Range(val, OpType::LessEqual); };
|
||||
auto index_func = [val](Index* index) {
|
||||
return index->Range(val, OpType::LessEqual);
|
||||
};
|
||||
auto elem_func = [val](T x) { return (x <= val); };
|
||||
return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
return ExecRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
case OpType::LessThan: {
|
||||
auto index_func = [val](Index* index) { return index->Range(val, OpType::LessThan); };
|
||||
auto index_func = [val](Index* index) {
|
||||
return index->Range(val, OpType::LessThan);
|
||||
};
|
||||
auto elem_func = [val](T x) { return (x < val); };
|
||||
return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
return ExecRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
case OpType::PrefixMatch: {
|
||||
auto index_func = [val](Index* index) {
|
||||
|
@ -285,7 +337,8 @@ ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> Bi
|
|||
return index->Query(std::move(dataset));
|
||||
};
|
||||
auto elem_func = [val, op](T x) { return Match(x, val, op); };
|
||||
return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
return ExecRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
// TODO: PostfixMatch
|
||||
default: {
|
||||
|
@ -299,7 +352,8 @@ ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> Bi
|
|||
#pragma ide diagnostic ignored "Simplify"
|
||||
template <typename T>
|
||||
auto
|
||||
ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType {
|
||||
ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher(
|
||||
BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType {
|
||||
auto& expr = static_cast<BinaryArithOpEvalRangeExprImpl<T>&>(expr_raw);
|
||||
using Index = index::ScalarIndex<T>;
|
||||
auto arith_op = expr.arith_op_;
|
||||
|
@ -311,46 +365,64 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRa
|
|||
case OpType::Equal: {
|
||||
switch (arith_op) {
|
||||
case ArithOpType::Add: {
|
||||
auto index_func = [val, right_operand](Index* index, size_t offset) {
|
||||
auto index_func = [val, right_operand](Index* index,
|
||||
size_t offset) {
|
||||
auto x = index->Reverse_Lookup(offset);
|
||||
return (x + right_operand) == val;
|
||||
};
|
||||
auto elem_func = [val, right_operand](T x) { return ((x + right_operand) == val); };
|
||||
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
auto elem_func = [val, right_operand](T x) {
|
||||
return ((x + right_operand) == val);
|
||||
};
|
||||
return ExecDataRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
case ArithOpType::Sub: {
|
||||
auto index_func = [val, right_operand](Index* index, size_t offset) {
|
||||
auto index_func = [val, right_operand](Index* index,
|
||||
size_t offset) {
|
||||
auto x = index->Reverse_Lookup(offset);
|
||||
return (x - right_operand) == val;
|
||||
};
|
||||
auto elem_func = [val, right_operand](T x) { return ((x - right_operand) == val); };
|
||||
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
auto elem_func = [val, right_operand](T x) {
|
||||
return ((x - right_operand) == val);
|
||||
};
|
||||
return ExecDataRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
case ArithOpType::Mul: {
|
||||
auto index_func = [val, right_operand](Index* index, size_t offset) {
|
||||
auto index_func = [val, right_operand](Index* index,
|
||||
size_t offset) {
|
||||
auto x = index->Reverse_Lookup(offset);
|
||||
return (x * right_operand) == val;
|
||||
};
|
||||
auto elem_func = [val, right_operand](T x) { return ((x * right_operand) == val); };
|
||||
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
auto elem_func = [val, right_operand](T x) {
|
||||
return ((x * right_operand) == val);
|
||||
};
|
||||
return ExecDataRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
case ArithOpType::Div: {
|
||||
auto index_func = [val, right_operand](Index* index, size_t offset) {
|
||||
auto index_func = [val, right_operand](Index* index,
|
||||
size_t offset) {
|
||||
auto x = index->Reverse_Lookup(offset);
|
||||
return (x / right_operand) == val;
|
||||
};
|
||||
auto elem_func = [val, right_operand](T x) { return ((x / right_operand) == val); };
|
||||
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
auto elem_func = [val, right_operand](T x) {
|
||||
return ((x / right_operand) == val);
|
||||
};
|
||||
return ExecDataRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
case ArithOpType::Mod: {
|
||||
auto index_func = [val, right_operand](Index* index, size_t offset) {
|
||||
auto index_func = [val, right_operand](Index* index,
|
||||
size_t offset) {
|
||||
auto x = index->Reverse_Lookup(offset);
|
||||
return static_cast<T>(fmod(x, right_operand)) == val;
|
||||
};
|
||||
auto elem_func = [val, right_operand](T x) {
|
||||
return (static_cast<T>(fmod(x, right_operand)) == val);
|
||||
};
|
||||
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
return ExecDataRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
default: {
|
||||
PanicInfo("unsupported arithmetic operation");
|
||||
|
@ -360,46 +432,64 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRa
|
|||
case OpType::NotEqual: {
|
||||
switch (arith_op) {
|
||||
case ArithOpType::Add: {
|
||||
auto index_func = [val, right_operand](Index* index, size_t offset) {
|
||||
auto index_func = [val, right_operand](Index* index,
|
||||
size_t offset) {
|
||||
auto x = index->Reverse_Lookup(offset);
|
||||
return (x + right_operand) != val;
|
||||
};
|
||||
auto elem_func = [val, right_operand](T x) { return ((x + right_operand) != val); };
|
||||
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
auto elem_func = [val, right_operand](T x) {
|
||||
return ((x + right_operand) != val);
|
||||
};
|
||||
return ExecDataRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
case ArithOpType::Sub: {
|
||||
auto index_func = [val, right_operand](Index* index, size_t offset) {
|
||||
auto index_func = [val, right_operand](Index* index,
|
||||
size_t offset) {
|
||||
auto x = index->Reverse_Lookup(offset);
|
||||
return (x - right_operand) != val;
|
||||
};
|
||||
auto elem_func = [val, right_operand](T x) { return ((x - right_operand) != val); };
|
||||
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
auto elem_func = [val, right_operand](T x) {
|
||||
return ((x - right_operand) != val);
|
||||
};
|
||||
return ExecDataRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
case ArithOpType::Mul: {
|
||||
auto index_func = [val, right_operand](Index* index, size_t offset) {
|
||||
auto index_func = [val, right_operand](Index* index,
|
||||
size_t offset) {
|
||||
auto x = index->Reverse_Lookup(offset);
|
||||
return (x * right_operand) != val;
|
||||
};
|
||||
auto elem_func = [val, right_operand](T x) { return ((x * right_operand) != val); };
|
||||
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
auto elem_func = [val, right_operand](T x) {
|
||||
return ((x * right_operand) != val);
|
||||
};
|
||||
return ExecDataRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
case ArithOpType::Div: {
|
||||
auto index_func = [val, right_operand](Index* index, size_t offset) {
|
||||
auto index_func = [val, right_operand](Index* index,
|
||||
size_t offset) {
|
||||
auto x = index->Reverse_Lookup(offset);
|
||||
return (x / right_operand) != val;
|
||||
};
|
||||
auto elem_func = [val, right_operand](T x) { return ((x / right_operand) != val); };
|
||||
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
auto elem_func = [val, right_operand](T x) {
|
||||
return ((x / right_operand) != val);
|
||||
};
|
||||
return ExecDataRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
case ArithOpType::Mod: {
|
||||
auto index_func = [val, right_operand](Index* index, size_t offset) {
|
||||
auto index_func = [val, right_operand](Index* index,
|
||||
size_t offset) {
|
||||
auto x = index->Reverse_Lookup(offset);
|
||||
return static_cast<T>(fmod(x, right_operand)) != val;
|
||||
};
|
||||
auto elem_func = [val, right_operand](T x) {
|
||||
return (static_cast<T>(fmod(x, right_operand)) != val);
|
||||
};
|
||||
return ExecDataRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
return ExecDataRangeVisitorImpl<T>(
|
||||
expr.field_id_, index_func, elem_func);
|
||||
}
|
||||
default: {
|
||||
PanicInfo("unsupported arithmetic operation");
|
||||
|
@ -417,8 +507,11 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRa
|
|||
#pragma ide diagnostic ignored "Simplify"
|
||||
template <typename T>
|
||||
auto
|
||||
ExecExprVisitor::ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw) -> BitsetType {
|
||||
typedef std::conditional_t<std::is_same_v<T, std::string_view>, std::string, T> IndexInnerType;
|
||||
ExecExprVisitor::ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw)
|
||||
-> BitsetType {
|
||||
typedef std::
|
||||
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
|
||||
IndexInnerType;
|
||||
using Index = index::ScalarIndex<IndexInnerType>;
|
||||
auto& expr = static_cast<BinaryRangeExprImpl<IndexInnerType>&>(expr_raw);
|
||||
|
||||
|
@ -427,7 +520,9 @@ ExecExprVisitor::ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw) ->
|
|||
IndexInnerType val1 = IndexInnerType(expr.lower_value_);
|
||||
IndexInnerType val2 = IndexInnerType(expr.upper_value_);
|
||||
|
||||
auto index_func = [=](Index* index) { return index->Range(val1, lower_inclusive, val2, upper_inclusive); };
|
||||
auto index_func = [=](Index* index) {
|
||||
return index->Range(val1, lower_inclusive, val2, upper_inclusive);
|
||||
};
|
||||
if (lower_inclusive && upper_inclusive) {
|
||||
auto elem_func = [val1, val2](T x) { return (val1 <= x && x <= val2); };
|
||||
return ExecRangeVisitorImpl<T>(expr.field_id_, index_func, elem_func);
|
||||
|
@ -490,7 +585,8 @@ ExecExprVisitor::visit(UnaryRangeExpr& expr) {
|
|||
default:
|
||||
PanicInfo("unsupported");
|
||||
}
|
||||
AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
|
||||
AssertInfo(res.size() == row_count_,
|
||||
"[ExecExprVisitor]Size of results not equal row count");
|
||||
bitset_opt_ = std::move(res);
|
||||
}
|
||||
|
||||
|
@ -528,7 +624,8 @@ ExecExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) {
|
|||
default:
|
||||
PanicInfo("unsupported");
|
||||
}
|
||||
AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
|
||||
AssertInfo(res.size() == row_count_,
|
||||
"[ExecExprVisitor]Size of results not equal row count");
|
||||
bitset_opt_ = std::move(res);
|
||||
}
|
||||
|
||||
|
@ -578,7 +675,8 @@ ExecExprVisitor::visit(BinaryRangeExpr& expr) {
|
|||
default:
|
||||
PanicInfo("unsupported");
|
||||
}
|
||||
AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
|
||||
AssertInfo(res.size() == row_count_,
|
||||
"[ExecExprVisitor]Size of results not equal row count");
|
||||
bitset_opt_ = std::move(res);
|
||||
}
|
||||
|
||||
|
@ -598,8 +696,16 @@ struct relational {
|
|||
|
||||
template <typename Op>
|
||||
auto
|
||||
ExecExprVisitor::ExecCompareExprDispatcher(CompareExpr& expr, Op op) -> BitsetType {
|
||||
using number = boost::variant<bool, int8_t, int16_t, int32_t, int64_t, float, double, std::string>;
|
||||
ExecExprVisitor::ExecCompareExprDispatcher(CompareExpr& expr, Op op)
|
||||
-> BitsetType {
|
||||
using number = boost::variant<bool,
|
||||
int8_t,
|
||||
int16_t,
|
||||
int32_t,
|
||||
int64_t,
|
||||
float,
|
||||
double,
|
||||
std::string>;
|
||||
auto size_per_chunk = segment_.size_per_chunk();
|
||||
auto num_chunk = upper_div(row_count_, size_per_chunk);
|
||||
std::deque<BitsetType> bitsets;
|
||||
|
@ -607,120 +713,194 @@ ExecExprVisitor::ExecCompareExprDispatcher(CompareExpr& expr, Op op) -> BitsetTy
|
|||
// check for sealed segment, load either raw field data or index
|
||||
auto left_indexing_barrier = segment_.num_chunk_index(expr.left_field_id_);
|
||||
auto left_data_barrier = segment_.num_chunk_data(expr.left_field_id_);
|
||||
AssertInfo(std::max(left_data_barrier, left_indexing_barrier) == num_chunk,
|
||||
"max(left_data_barrier, left_indexing_barrier) not equal to num_chunk");
|
||||
AssertInfo(
|
||||
std::max(left_data_barrier, left_indexing_barrier) == num_chunk,
|
||||
"max(left_data_barrier, left_indexing_barrier) not equal to num_chunk");
|
||||
|
||||
auto right_indexing_barrier = segment_.num_chunk_index(expr.right_field_id_);
|
||||
auto right_indexing_barrier =
|
||||
segment_.num_chunk_index(expr.right_field_id_);
|
||||
auto right_data_barrier = segment_.num_chunk_data(expr.right_field_id_);
|
||||
AssertInfo(std::max(right_data_barrier, right_indexing_barrier) == num_chunk,
|
||||
"max(right_data_barrier, right_indexing_barrier) not equal to num_chunk");
|
||||
AssertInfo(
|
||||
std::max(right_data_barrier, right_indexing_barrier) == num_chunk,
|
||||
"max(right_data_barrier, right_indexing_barrier) not equal to "
|
||||
"num_chunk");
|
||||
|
||||
for (int64_t chunk_id = 0; chunk_id < num_chunk; ++chunk_id) {
|
||||
auto size = chunk_id == num_chunk - 1 ? row_count_ - chunk_id * size_per_chunk : size_per_chunk;
|
||||
auto getChunkData = [&, chunk_id](DataType type, FieldId field_id,
|
||||
int64_t data_barrier) -> std::function<const number(int)> {
|
||||
auto size = chunk_id == num_chunk - 1
|
||||
? row_count_ - chunk_id * size_per_chunk
|
||||
: size_per_chunk;
|
||||
auto getChunkData =
|
||||
[&, chunk_id](DataType type, FieldId field_id, int64_t data_barrier)
|
||||
-> std::function<const number(int)> {
|
||||
switch (type) {
|
||||
case DataType::BOOL: {
|
||||
if (chunk_id < data_barrier) {
|
||||
auto chunk_data = segment_.chunk_data<bool>(field_id, chunk_id).data();
|
||||
return [chunk_data](int i) -> const number { return chunk_data[i]; };
|
||||
auto chunk_data =
|
||||
segment_.chunk_data<bool>(field_id, chunk_id)
|
||||
.data();
|
||||
return [chunk_data](int i) -> const number {
|
||||
return chunk_data[i];
|
||||
};
|
||||
} else {
|
||||
// for case, sealed segment has loaded index for scalar field instead of raw data
|
||||
auto& indexing = segment_.chunk_scalar_index<bool>(field_id, chunk_id);
|
||||
return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
|
||||
auto& indexing = segment_.chunk_scalar_index<bool>(
|
||||
field_id, chunk_id);
|
||||
return [&indexing](int i) -> const number {
|
||||
return indexing.Reverse_Lookup(i);
|
||||
};
|
||||
}
|
||||
}
|
||||
case DataType::INT8: {
|
||||
if (chunk_id < data_barrier) {
|
||||
auto chunk_data = segment_.chunk_data<int8_t>(field_id, chunk_id).data();
|
||||
return [chunk_data](int i) -> const number { return chunk_data[i]; };
|
||||
auto chunk_data =
|
||||
segment_.chunk_data<int8_t>(field_id, chunk_id)
|
||||
.data();
|
||||
return [chunk_data](int i) -> const number {
|
||||
return chunk_data[i];
|
||||
};
|
||||
} else {
|
||||
// for case, sealed segment has loaded index for scalar field instead of raw data
|
||||
auto& indexing = segment_.chunk_scalar_index<int8_t>(field_id, chunk_id);
|
||||
return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
|
||||
auto& indexing = segment_.chunk_scalar_index<int8_t>(
|
||||
field_id, chunk_id);
|
||||
return [&indexing](int i) -> const number {
|
||||
return indexing.Reverse_Lookup(i);
|
||||
};
|
||||
}
|
||||
}
|
||||
case DataType::INT16: {
|
||||
if (chunk_id < data_barrier) {
|
||||
auto chunk_data = segment_.chunk_data<int16_t>(field_id, chunk_id).data();
|
||||
return [chunk_data](int i) -> const number { return chunk_data[i]; };
|
||||
auto chunk_data =
|
||||
segment_.chunk_data<int16_t>(field_id, chunk_id)
|
||||
.data();
|
||||
return [chunk_data](int i) -> const number {
|
||||
return chunk_data[i];
|
||||
};
|
||||
} else {
|
||||
// for case, sealed segment has loaded index for scalar field instead of raw data
|
||||
auto& indexing = segment_.chunk_scalar_index<int16_t>(field_id, chunk_id);
|
||||
return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
|
||||
auto& indexing = segment_.chunk_scalar_index<int16_t>(
|
||||
field_id, chunk_id);
|
||||
return [&indexing](int i) -> const number {
|
||||
return indexing.Reverse_Lookup(i);
|
||||
};
|
||||
}
|
||||
}
|
||||
case DataType::INT32: {
|
||||
if (chunk_id < data_barrier) {
|
||||
auto chunk_data = segment_.chunk_data<int32_t>(field_id, chunk_id).data();
|
||||
return [chunk_data](int i) -> const number { return chunk_data[i]; };
|
||||
auto chunk_data =
|
||||
segment_.chunk_data<int32_t>(field_id, chunk_id)
|
||||
.data();
|
||||
return [chunk_data](int i) -> const number {
|
||||
return chunk_data[i];
|
||||
};
|
||||
} else {
|
||||
// for case, sealed segment has loaded index for scalar field instead of raw data
|
||||
auto& indexing = segment_.chunk_scalar_index<int32_t>(field_id, chunk_id);
|
||||
return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
|
||||
auto& indexing = segment_.chunk_scalar_index<int32_t>(
|
||||
field_id, chunk_id);
|
||||
return [&indexing](int i) -> const number {
|
||||
return indexing.Reverse_Lookup(i);
|
||||
};
|
||||
}
|
||||
}
|
||||
case DataType::INT64: {
|
||||
if (chunk_id < data_barrier) {
|
||||
auto chunk_data = segment_.chunk_data<int64_t>(field_id, chunk_id).data();
|
||||
return [chunk_data](int i) -> const number { return chunk_data[i]; };
|
||||
auto chunk_data =
|
||||
segment_.chunk_data<int64_t>(field_id, chunk_id)
|
||||
.data();
|
||||
return [chunk_data](int i) -> const number {
|
||||
return chunk_data[i];
|
||||
};
|
||||
} else {
|
||||
// for case, sealed segment has loaded index for scalar field instead of raw data
|
||||
auto& indexing = segment_.chunk_scalar_index<int64_t>(field_id, chunk_id);
|
||||
return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
|
||||
auto& indexing = segment_.chunk_scalar_index<int64_t>(
|
||||
field_id, chunk_id);
|
||||
return [&indexing](int i) -> const number {
|
||||
return indexing.Reverse_Lookup(i);
|
||||
};
|
||||
}
|
||||
}
|
||||
case DataType::FLOAT: {
|
||||
if (chunk_id < data_barrier) {
|
||||
auto chunk_data = segment_.chunk_data<float>(field_id, chunk_id).data();
|
||||
return [chunk_data](int i) -> const number { return chunk_data[i]; };
|
||||
auto chunk_data =
|
||||
segment_.chunk_data<float>(field_id, chunk_id)
|
||||
.data();
|
||||
return [chunk_data](int i) -> const number {
|
||||
return chunk_data[i];
|
||||
};
|
||||
} else {
|
||||
// for case, sealed segment has loaded index for scalar field instead of raw data
|
||||
auto& indexing = segment_.chunk_scalar_index<float>(field_id, chunk_id);
|
||||
return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
|
||||
auto& indexing = segment_.chunk_scalar_index<float>(
|
||||
field_id, chunk_id);
|
||||
return [&indexing](int i) -> const number {
|
||||
return indexing.Reverse_Lookup(i);
|
||||
};
|
||||
}
|
||||
}
|
||||
case DataType::DOUBLE: {
|
||||
if (chunk_id < data_barrier) {
|
||||
auto chunk_data = segment_.chunk_data<double>(field_id, chunk_id).data();
|
||||
return [chunk_data](int i) -> const number { return chunk_data[i]; };
|
||||
auto chunk_data =
|
||||
segment_.chunk_data<double>(field_id, chunk_id)
|
||||
.data();
|
||||
return [chunk_data](int i) -> const number {
|
||||
return chunk_data[i];
|
||||
};
|
||||
} else {
|
||||
// for case, sealed segment has loaded index for scalar field instead of raw data
|
||||
auto& indexing = segment_.chunk_scalar_index<double>(field_id, chunk_id);
|
||||
return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
|
||||
auto& indexing = segment_.chunk_scalar_index<double>(
|
||||
field_id, chunk_id);
|
||||
return [&indexing](int i) -> const number {
|
||||
return indexing.Reverse_Lookup(i);
|
||||
};
|
||||
}
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
if (chunk_id < data_barrier) {
|
||||
if (segment_.type() == SegmentType::Growing) {
|
||||
auto chunk_data = segment_.chunk_data<std::string>(field_id, chunk_id).data();
|
||||
return [chunk_data](int i) -> const number { return chunk_data[i]; };
|
||||
auto chunk_data =
|
||||
segment_
|
||||
.chunk_data<std::string>(field_id, chunk_id)
|
||||
.data();
|
||||
return [chunk_data](int i) -> const number {
|
||||
return chunk_data[i];
|
||||
};
|
||||
} else {
|
||||
auto chunk_data = segment_.chunk_data<std::string_view>(field_id, chunk_id).data();
|
||||
return [chunk_data](int i) -> const number { return std::string(chunk_data[i]); };
|
||||
auto chunk_data = segment_
|
||||
.chunk_data<std::string_view>(
|
||||
field_id, chunk_id)
|
||||
.data();
|
||||
return [chunk_data](int i) -> const number {
|
||||
return std::string(chunk_data[i]);
|
||||
};
|
||||
}
|
||||
} else {
|
||||
// for case, sealed segment has loaded index for scalar field instead of raw data
|
||||
auto& indexing = segment_.chunk_scalar_index<std::string>(field_id, chunk_id);
|
||||
return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); };
|
||||
auto& indexing =
|
||||
segment_.chunk_scalar_index<std::string>(field_id,
|
||||
chunk_id);
|
||||
return [&indexing](int i) -> const number {
|
||||
return indexing.Reverse_Lookup(i);
|
||||
};
|
||||
}
|
||||
}
|
||||
default:
|
||||
PanicInfo("unsupported datatype");
|
||||
}
|
||||
};
|
||||
auto left = getChunkData(expr.left_data_type_, expr.left_field_id_, left_data_barrier);
|
||||
auto right = getChunkData(expr.right_data_type_, expr.right_field_id_, right_data_barrier);
|
||||
auto left = getChunkData(
|
||||
expr.left_data_type_, expr.left_field_id_, left_data_barrier);
|
||||
auto right = getChunkData(
|
||||
expr.right_data_type_, expr.right_field_id_, right_data_barrier);
|
||||
|
||||
BitsetType bitset(size);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
bool is_in = boost::apply_visitor(Relational<decltype(op)>{}, left(i), right(i));
|
||||
bool is_in = boost::apply_visitor(
|
||||
Relational<decltype(op)>{}, left(i), right(i));
|
||||
bitset[i] = is_in;
|
||||
}
|
||||
bitsets.emplace_back(std::move(bitset));
|
||||
}
|
||||
auto final_result = Assemble(bitsets);
|
||||
AssertInfo(final_result.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
|
||||
AssertInfo(final_result.size() == row_count_,
|
||||
"[ExecExprVisitor]Size of results not equal row count");
|
||||
return final_result;
|
||||
}
|
||||
|
||||
|
@ -729,10 +909,12 @@ ExecExprVisitor::visit(CompareExpr& expr) {
|
|||
auto& schema = segment_.get_schema();
|
||||
auto& left_field_meta = schema[expr.left_field_id_];
|
||||
auto& right_field_meta = schema[expr.right_field_id_];
|
||||
AssertInfo(expr.left_data_type_ == left_field_meta.get_data_type(),
|
||||
"[ExecExprVisitor]Left data type not equal to left field mata type");
|
||||
AssertInfo(expr.right_data_type_ == right_field_meta.get_data_type(),
|
||||
"[ExecExprVisitor]right data type not equal to right field mata type");
|
||||
AssertInfo(
|
||||
expr.left_data_type_ == left_field_meta.get_data_type(),
|
||||
"[ExecExprVisitor]Left data type not equal to left field meta type");
|
||||
AssertInfo(
|
||||
expr.right_data_type_ == right_field_meta.get_data_type(),
|
||||
"[ExecExprVisitor]right data type not equal to right field meta type");
|
||||
|
||||
BitsetType res;
|
||||
switch (expr.op_type_) {
|
||||
|
@ -761,7 +943,8 @@ ExecExprVisitor::visit(CompareExpr& expr) {
|
|||
break;
|
||||
}
|
||||
case OpType::PrefixMatch: {
|
||||
res = ExecCompareExprDispatcher(expr, MatchOp<OpType::PrefixMatch>{});
|
||||
res =
|
||||
ExecCompareExprDispatcher(expr, MatchOp<OpType::PrefixMatch>{});
|
||||
break;
|
||||
}
|
||||
// case OpType::PostfixMatch: {
|
||||
|
@ -770,7 +953,8 @@ ExecExprVisitor::visit(CompareExpr& expr) {
|
|||
PanicInfo("unsupported optype");
|
||||
}
|
||||
}
|
||||
AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
|
||||
AssertInfo(res.size() == row_count_,
|
||||
"[ExecExprVisitor]Size of results not equal row count");
|
||||
bitset_opt_ = std::move(res);
|
||||
}
|
||||
|
||||
|
@ -785,7 +969,8 @@ ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType {
|
|||
|
||||
bool use_pk_index = false;
|
||||
if (primary_filed_id.has_value()) {
|
||||
use_pk_index = primary_filed_id.value() == field_id && IsPrimaryKeyDataType(field_meta.get_data_type());
|
||||
use_pk_index = primary_filed_id.value() == field_id &&
|
||||
IsPrimaryKeyDataType(field_meta.get_data_type());
|
||||
}
|
||||
|
||||
if (use_pk_index) {
|
||||
|
@ -816,7 +1001,8 @@ ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType {
|
|||
auto _offset = (int64_t)offset.get();
|
||||
bitset[_offset] = true;
|
||||
}
|
||||
AssertInfo(bitset.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
|
||||
AssertInfo(bitset.size() == row_count_,
|
||||
"[ExecExprVisitor]Size of results not equal row count");
|
||||
return bitset;
|
||||
}
|
||||
|
||||
|
@ -825,27 +1011,34 @@ ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType {
|
|||
|
||||
template <>
|
||||
auto
|
||||
ExecExprVisitor::ExecTermVisitorImpl<std::string>(TermExpr& expr_raw) -> BitsetType {
|
||||
ExecExprVisitor::ExecTermVisitorImpl<std::string>(TermExpr& expr_raw)
|
||||
-> BitsetType {
|
||||
return ExecTermVisitorImplTemplate<std::string>(expr_raw);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto
|
||||
ExecExprVisitor::ExecTermVisitorImpl<std::string_view>(TermExpr& expr_raw) -> BitsetType {
|
||||
ExecExprVisitor::ExecTermVisitorImpl<std::string_view>(TermExpr& expr_raw)
|
||||
-> BitsetType {
|
||||
return ExecTermVisitorImplTemplate<std::string_view>(expr_raw);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto
|
||||
ExecExprVisitor::ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType {
|
||||
typedef std::conditional_t<std::is_same_v<T, std::string_view>, std::string, T> IndexInnerType;
|
||||
typedef std::
|
||||
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
|
||||
IndexInnerType;
|
||||
using Index = index::ScalarIndex<IndexInnerType>;
|
||||
auto& expr = static_cast<TermExprImpl<IndexInnerType>&>(expr_raw);
|
||||
const std::vector<IndexInnerType> terms(expr.terms_.begin(), expr.terms_.end());
|
||||
const std::vector<IndexInnerType> terms(expr.terms_.begin(),
|
||||
expr.terms_.end());
|
||||
auto n = terms.size();
|
||||
std::unordered_set<T> term_set(expr.terms_.begin(), expr.terms_.end());
|
||||
|
||||
auto index_func = [&terms, n](Index* index) { return index->In(n, terms.data()); };
|
||||
auto index_func = [&terms, n](Index* index) {
|
||||
return index->In(n, terms.data());
|
||||
};
|
||||
auto elem_func = [&terms, &term_set](T x) {
|
||||
//// terms has already been sorted.
|
||||
// return std::binary_search(terms.begin(), terms.end(), x);
|
||||
|
@ -858,7 +1051,8 @@ ExecExprVisitor::ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType {
|
|||
// TODO: bool is so ugly here.
|
||||
template <>
|
||||
auto
|
||||
ExecExprVisitor::ExecTermVisitorImplTemplate<bool>(TermExpr& expr_raw) -> BitsetType {
|
||||
ExecExprVisitor::ExecTermVisitorImplTemplate<bool>(TermExpr& expr_raw)
|
||||
-> BitsetType {
|
||||
using T = bool;
|
||||
auto& expr = static_cast<TermExprImpl<T>&>(expr_raw);
|
||||
using Index = index::ScalarIndex<T>;
|
||||
|
@ -932,7 +1126,8 @@ ExecExprVisitor::visit(TermExpr& expr) {
|
|||
default:
|
||||
PanicInfo("unsupported");
|
||||
}
|
||||
AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
|
||||
AssertInfo(res.size() == row_count_,
|
||||
"[ExecExprVisitor]Size of results not equal row count");
|
||||
bitset_opt_ = std::move(res);
|
||||
}
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -29,7 +29,9 @@ class ExecPlanNodeVisitor : PlanNodeVisitor {
|
|||
ExecPlanNodeVisitor(const segcore::SegmentInterface& segment,
|
||||
Timestamp timestamp,
|
||||
const PlaceholderGroup& placeholder_group)
|
||||
: segment_(segment), timestamp_(timestamp), placeholder_group_(placeholder_group) {
|
||||
: segment_(segment),
|
||||
timestamp_(timestamp),
|
||||
placeholder_group_(placeholder_group) {
|
||||
}
|
||||
|
||||
SearchResult
|
||||
|
@ -59,7 +61,10 @@ class ExecPlanNodeVisitor : PlanNodeVisitor {
|
|||
static SearchResult
|
||||
empty_search_result(int64_t num_queries, SearchInfo& search_info) {
|
||||
SearchResult final_result;
|
||||
SubSearchResult result(num_queries, search_info.topk_, search_info.metric_type_, search_info.round_decimal_);
|
||||
SubSearchResult result(num_queries,
|
||||
search_info.topk_,
|
||||
search_info.metric_type_,
|
||||
search_info.round_decimal_);
|
||||
final_result.total_nq_ = num_queries;
|
||||
final_result.unity_topK_ = search_info.topk_;
|
||||
final_result.seg_offsets_ = std::move(result.mutable_seg_offsets());
|
||||
|
@ -72,7 +77,8 @@ void
|
|||
ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
|
||||
// TODO: optimize here, remove the dynamic cast
|
||||
assert(!search_result_opt_.has_value());
|
||||
auto segment = dynamic_cast<const segcore::SegmentInternalInterface*>(&segment_);
|
||||
auto segment =
|
||||
dynamic_cast<const segcore::SegmentInternalInterface*>(&segment_);
|
||||
AssertInfo(segment, "support SegmentSmallIndex Only");
|
||||
SearchResult search_result;
|
||||
auto& ph = placeholder_group_->at(0);
|
||||
|
@ -85,14 +91,16 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
|
|||
|
||||
// skip all calculation
|
||||
if (active_count == 0) {
|
||||
search_result_opt_ = empty_search_result(num_queries, node.search_info_);
|
||||
search_result_opt_ =
|
||||
empty_search_result(num_queries, node.search_info_);
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_ptr<BitsetType> bitset_holder;
|
||||
if (node.predicate_.has_value()) {
|
||||
bitset_holder = std::make_unique<BitsetType>(
|
||||
ExecExprVisitor(*segment, active_count, timestamp_).call_child(*node.predicate_.value()));
|
||||
ExecExprVisitor(*segment, active_count, timestamp_)
|
||||
.call_child(*node.predicate_.value()));
|
||||
bitset_holder->flip();
|
||||
} else {
|
||||
bitset_holder = std::make_unique<BitsetType>(active_count, false);
|
||||
|
@ -102,11 +110,17 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
|
|||
segment->mask_with_delete(*bitset_holder, active_count, timestamp_);
|
||||
// if bitset_holder is all 1's, we got empty result
|
||||
if (bitset_holder->all()) {
|
||||
search_result_opt_ = empty_search_result(num_queries, node.search_info_);
|
||||
search_result_opt_ =
|
||||
empty_search_result(num_queries, node.search_info_);
|
||||
return;
|
||||
}
|
||||
BitsetView final_view = *bitset_holder;
|
||||
segment->vector_search(node.search_info_, src_data, num_queries, timestamp_, final_view, search_result);
|
||||
segment->vector_search(node.search_info_,
|
||||
src_data,
|
||||
num_queries,
|
||||
timestamp_,
|
||||
final_view,
|
||||
search_result);
|
||||
|
||||
search_result_opt_ = std::move(search_result);
|
||||
}
|
||||
|
@ -114,7 +128,8 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
|
|||
void
|
||||
ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
|
||||
assert(!retrieve_result_opt_.has_value());
|
||||
auto segment = dynamic_cast<const segcore::SegmentInternalInterface*>(&segment_);
|
||||
auto segment =
|
||||
dynamic_cast<const segcore::SegmentInternalInterface*>(&segment_);
|
||||
AssertInfo(segment, "Support SegmentSmallIndex Only");
|
||||
RetrieveResult retrieve_result;
|
||||
|
||||
|
@ -127,7 +142,8 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
|
|||
|
||||
BitsetType bitset_holder;
|
||||
if (node.predicate_ != nullptr) {
|
||||
bitset_holder = ExecExprVisitor(*segment, active_count, timestamp_).call_child(*(node.predicate_));
|
||||
bitset_holder = ExecExprVisitor(*segment, active_count, timestamp_)
|
||||
.call_child(*(node.predicate_));
|
||||
bitset_holder.flip();
|
||||
}
|
||||
|
||||
|
@ -142,8 +158,9 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
|
|||
|
||||
BitsetView final_view = bitset_holder;
|
||||
auto seg_offsets = segment->search_ids(final_view, timestamp_);
|
||||
retrieve_result.result_offsets_.assign((int64_t*)seg_offsets.data(),
|
||||
(int64_t*)seg_offsets.data() + seg_offsets.size());
|
||||
retrieve_result.result_offsets_.assign(
|
||||
(int64_t*)seg_offsets.data(),
|
||||
(int64_t*)seg_offsets.data() + seg_offsets.size());
|
||||
retrieve_result_opt_ = std::move(retrieve_result);
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,8 @@ namespace impl {
|
|||
// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/
|
||||
class ExtractInfoExprVisitor : ExprVisitor {
|
||||
public:
|
||||
explicit ExtractInfoExprVisitor(ExtractedPlanInfo& plan_info) : plan_info_(plan_info) {
|
||||
explicit ExtractInfoExprVisitor(ExtractedPlanInfo& plan_info)
|
||||
: plan_info_(plan_info) {
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -20,7 +20,8 @@ namespace impl {
|
|||
// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/
|
||||
class ExtractInfoPlanNodeVisitor : PlanNodeVisitor {
|
||||
public:
|
||||
explicit ExtractInfoPlanNodeVisitor(ExtractedPlanInfo& plan_info) : plan_info_(plan_info) {
|
||||
explicit ExtractInfoPlanNodeVisitor(ExtractedPlanInfo& plan_info)
|
||||
: plan_info_(plan_info) {
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -58,11 +58,13 @@ class ShowExprNodeVisitor : ExprVisitor {
|
|||
|
||||
void
|
||||
ShowExprVisitor::visit(LogicalUnaryExpr& expr) {
|
||||
AssertInfo(!json_opt_.has_value(), "[ShowExprVisitor]Ret json already has value before visit");
|
||||
AssertInfo(!json_opt_.has_value(),
|
||||
"[ShowExprVisitor]Ret json already has value before visit");
|
||||
using OpType = LogicalUnaryExpr::OpType;
|
||||
|
||||
// TODO: use magic_enum if available
|
||||
AssertInfo(expr.op_type_ == OpType::LogicalNot, "[ShowExprVisitor]Expr op type isn't LogicNot");
|
||||
AssertInfo(expr.op_type_ == OpType::LogicalNot,
|
||||
"[ShowExprVisitor]Expr op type isn't LogicNot");
|
||||
auto op_name = "LogicalNot";
|
||||
|
||||
Json extra{
|
||||
|
@ -74,7 +76,8 @@ ShowExprVisitor::visit(LogicalUnaryExpr& expr) {
|
|||
|
||||
void
|
||||
ShowExprVisitor::visit(LogicalBinaryExpr& expr) {
|
||||
AssertInfo(!json_opt_.has_value(), "[ShowExprVisitor]Ret json already has value before visit");
|
||||
AssertInfo(!json_opt_.has_value(),
|
||||
"[ShowExprVisitor]Ret json already has value before visit");
|
||||
using OpType = LogicalBinaryExpr::OpType;
|
||||
|
||||
// TODO: use magic_enum if available
|
||||
|
@ -108,8 +111,10 @@ TermExtract(const TermExpr& expr_raw) {
|
|||
|
||||
void
|
||||
ShowExprVisitor::visit(TermExpr& expr) {
|
||||
AssertInfo(!json_opt_.has_value(), "[ShowExprVisitor]Ret json already has value before visit");
|
||||
AssertInfo(datatype_is_vector(expr.data_type_) == false, "[ShowExprVisitor]Data type of expr isn't vector type");
|
||||
AssertInfo(!json_opt_.has_value(),
|
||||
"[ShowExprVisitor]Ret json already has value before visit");
|
||||
AssertInfo(datatype_is_vector(expr.data_type_) == false,
|
||||
"[ShowExprVisitor]Data type of expr isn't vector type");
|
||||
auto terms = [&] {
|
||||
switch (expr.data_type_) {
|
||||
case DataType::BOOL:
|
||||
|
@ -145,7 +150,9 @@ UnaryRangeExtract(const UnaryRangeExpr& expr_raw) {
|
|||
using proto::plan::OpType;
|
||||
using proto::plan::OpType_Name;
|
||||
auto expr = dynamic_cast<const UnaryRangeExprImpl<T>*>(&expr_raw);
|
||||
AssertInfo(expr, "[ShowExprVisitor]UnaryRangeExpr cast to UnaryRangeExprImpl failed");
|
||||
AssertInfo(
|
||||
expr,
|
||||
"[ShowExprVisitor]UnaryRangeExpr cast to UnaryRangeExprImpl failed");
|
||||
Json res{{"expr_type", "UnaryRange"},
|
||||
{"field_id", expr->field_id_.get()},
|
||||
{"data_type", datatype_name(expr->data_type_)},
|
||||
|
@ -156,8 +163,10 @@ UnaryRangeExtract(const UnaryRangeExpr& expr_raw) {
|
|||
|
||||
void
|
||||
ShowExprVisitor::visit(UnaryRangeExpr& expr) {
|
||||
AssertInfo(!json_opt_.has_value(), "[ShowExprVisitor]Ret json already has value before visit");
|
||||
AssertInfo(datatype_is_vector(expr.data_type_) == false, "[ShowExprVisitor]Data type of expr isn't vector type");
|
||||
AssertInfo(!json_opt_.has_value(),
|
||||
"[ShowExprVisitor]Ret json already has value before visit");
|
||||
AssertInfo(datatype_is_vector(expr.data_type_) == false,
|
||||
"[ShowExprVisitor]Data type of expr isn't vector type");
|
||||
switch (expr.data_type_) {
|
||||
case DataType::BOOL:
|
||||
json_opt_ = UnaryRangeExtract<bool>(expr);
|
||||
|
@ -191,7 +200,9 @@ BinaryRangeExtract(const BinaryRangeExpr& expr_raw) {
|
|||
using proto::plan::OpType;
|
||||
using proto::plan::OpType_Name;
|
||||
auto expr = dynamic_cast<const BinaryRangeExprImpl<T>*>(&expr_raw);
|
||||
AssertInfo(expr, "[ShowExprVisitor]BinaryRangeExpr cast to BinaryRangeExprImpl failed");
|
||||
AssertInfo(
|
||||
expr,
|
||||
"[ShowExprVisitor]BinaryRangeExpr cast to BinaryRangeExprImpl failed");
|
||||
Json res{{"expr_type", "BinaryRange"},
|
||||
{"field_id", expr->field_id_.get()},
|
||||
{"data_type", datatype_name(expr->data_type_)},
|
||||
|
@ -204,8 +215,10 @@ BinaryRangeExtract(const BinaryRangeExpr& expr_raw) {
|
|||
|
||||
void
|
||||
ShowExprVisitor::visit(BinaryRangeExpr& expr) {
|
||||
AssertInfo(!json_opt_.has_value(), "[ShowExprVisitor]Ret json already has value before visit");
|
||||
AssertInfo(datatype_is_vector(expr.data_type_) == false, "[ShowExprVisitor]Data type of expr isn't vector type");
|
||||
AssertInfo(!json_opt_.has_value(),
|
||||
"[ShowExprVisitor]Ret json already has value before visit");
|
||||
AssertInfo(datatype_is_vector(expr.data_type_) == false,
|
||||
"[ShowExprVisitor]Data type of expr isn't vector type");
|
||||
switch (expr.data_type_) {
|
||||
case DataType::BOOL:
|
||||
json_opt_ = BinaryRangeExtract<bool>(expr);
|
||||
|
@ -237,7 +250,8 @@ void
|
|||
ShowExprVisitor::visit(CompareExpr& expr) {
|
||||
using proto::plan::OpType;
|
||||
using proto::plan::OpType_Name;
|
||||
AssertInfo(!json_opt_.has_value(), "[ShowExprVisitor]Ret json already has value before visit");
|
||||
AssertInfo(!json_opt_.has_value(),
|
||||
"[ShowExprVisitor]Ret json already has value before visit");
|
||||
|
||||
Json res{{"expr_type", "Compare"},
|
||||
{"left_field_id", expr.left_field_id_.get()},
|
||||
|
@ -256,13 +270,17 @@ BinaryArithOpEvalRangeExtract(const BinaryArithOpEvalRangeExpr& expr_raw) {
|
|||
using proto::plan::OpType;
|
||||
using proto::plan::OpType_Name;
|
||||
|
||||
auto expr = dynamic_cast<const BinaryArithOpEvalRangeExprImpl<T>*>(&expr_raw);
|
||||
AssertInfo(expr, "[ShowExprVisitor]BinaryArithOpEvalRangeExpr cast to BinaryArithOpEvalRangeExprImpl failed");
|
||||
auto expr =
|
||||
dynamic_cast<const BinaryArithOpEvalRangeExprImpl<T>*>(&expr_raw);
|
||||
AssertInfo(expr,
|
||||
"[ShowExprVisitor]BinaryArithOpEvalRangeExpr cast to "
|
||||
"BinaryArithOpEvalRangeExprImpl failed");
|
||||
|
||||
Json res{{"expr_type", "BinaryArithOpEvalRange"},
|
||||
{"field_offset", expr->field_id_.get()},
|
||||
{"data_type", datatype_name(expr->data_type_)},
|
||||
{"arith_op", ArithOpType_Name(static_cast<ArithOpType>(expr->arith_op_))},
|
||||
{"arith_op",
|
||||
ArithOpType_Name(static_cast<ArithOpType>(expr->arith_op_))},
|
||||
{"right_operand", expr->right_operand_},
|
||||
{"op", OpType_Name(static_cast<OpType>(expr->op_type_))},
|
||||
{"value", expr->value_}};
|
||||
|
@ -271,8 +289,10 @@ BinaryArithOpEvalRangeExtract(const BinaryArithOpEvalRangeExpr& expr_raw) {
|
|||
|
||||
void
|
||||
ShowExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) {
|
||||
AssertInfo(!json_opt_.has_value(), "[ShowExprVisitor]Ret json already has value before visit");
|
||||
AssertInfo(datatype_is_vector(expr.data_type_) == false, "[ShowExprVisitor]Data type of expr isn't vector type");
|
||||
AssertInfo(!json_opt_.has_value(),
|
||||
"[ShowExprVisitor]Ret json already has value before visit");
|
||||
AssertInfo(datatype_is_vector(expr.data_type_) == false,
|
||||
"[ShowExprVisitor]Data type of expr isn't vector type");
|
||||
switch (expr.data_type_) {
|
||||
case DataType::INT8:
|
||||
json_opt_ = BinaryArithOpEvalRangeExtract<int8_t>(expr);
|
||||
|
|
|
@ -62,8 +62,10 @@ ShowPlanNodeVisitor::visit(FloatVectorANNS& node) {
|
|||
};
|
||||
if (node.predicate_.has_value()) {
|
||||
ShowExprVisitor expr_show;
|
||||
AssertInfo(node.predicate_.value(), "[ShowPlanNodeVisitor]Can't get value from node predict");
|
||||
json_body["predicate"] = expr_show.call_child(node.predicate_->operator*());
|
||||
AssertInfo(node.predicate_.value(),
|
||||
"[ShowPlanNodeVisitor]Can't get value from node predict");
|
||||
json_body["predicate"] =
|
||||
expr_show.call_child(node.predicate_->operator*());
|
||||
} else {
|
||||
json_body["predicate"] = "None";
|
||||
}
|
||||
|
@ -84,8 +86,10 @@ ShowPlanNodeVisitor::visit(BinaryVectorANNS& node) {
|
|||
};
|
||||
if (node.predicate_.has_value()) {
|
||||
ShowExprVisitor expr_show;
|
||||
AssertInfo(node.predicate_.value(), "[ShowPlanNodeVisitor]Can't get value from node predict");
|
||||
json_body["predicate"] = expr_show.call_child(node.predicate_->operator*());
|
||||
AssertInfo(node.predicate_.value(),
|
||||
"[ShowPlanNodeVisitor]Can't get value from node predict");
|
||||
json_body["predicate"] =
|
||||
expr_show.call_child(node.predicate_->operator*());
|
||||
} else {
|
||||
json_body["predicate"] = "None";
|
||||
}
|
||||
|
|
|
@ -16,7 +16,8 @@
|
|||
|
||||
namespace milvus::segcore {
|
||||
|
||||
Collection::Collection(const std::string& collection_proto) : schema_proto_(collection_proto) {
|
||||
Collection::Collection(const std::string& collection_proto)
|
||||
: schema_proto_(collection_proto) {
|
||||
parse();
|
||||
}
|
||||
|
||||
|
@ -35,7 +36,8 @@ Collection::parse() {
|
|||
|
||||
Assert(!schema_proto_.empty());
|
||||
milvus::proto::schema::CollectionSchema collection_schema;
|
||||
auto suc = google::protobuf::TextFormat::ParseFromString(schema_proto_, &collection_schema);
|
||||
auto suc = google::protobuf::TextFormat::ParseFromString(
|
||||
schema_proto_, &collection_schema);
|
||||
|
||||
if (!suc) {
|
||||
std::cerr << "unmarshal schema string failed" << std::endl;
|
||||
|
|
|
@ -20,9 +20,13 @@ VectorBase::set_data_raw(ssize_t element_offset,
|
|||
const FieldMeta& field_meta) {
|
||||
if (field_meta.is_vector()) {
|
||||
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
|
||||
return set_data_raw(element_offset, data->vectors().float_vector().data().data(), element_count);
|
||||
return set_data_raw(element_offset,
|
||||
data->vectors().float_vector().data().data(),
|
||||
element_count);
|
||||
} else if (field_meta.get_data_type() == DataType::VECTOR_BINARY) {
|
||||
return set_data_raw(element_offset, data->vectors().binary_vector().data(), element_count);
|
||||
return set_data_raw(element_offset,
|
||||
data->vectors().binary_vector().data(),
|
||||
element_count);
|
||||
} else {
|
||||
PanicInfo("unsupported");
|
||||
}
|
||||
|
@ -30,7 +34,9 @@ VectorBase::set_data_raw(ssize_t element_offset,
|
|||
|
||||
switch (field_meta.get_data_type()) {
|
||||
case DataType::BOOL: {
|
||||
return set_data_raw(element_offset, data->scalars().bool_data().data().data(), element_count);
|
||||
return set_data_raw(element_offset,
|
||||
data->scalars().bool_data().data().data(),
|
||||
element_count);
|
||||
}
|
||||
case DataType::INT8: {
|
||||
auto src_data = data->scalars().int_data().data();
|
||||
|
@ -45,16 +51,24 @@ VectorBase::set_data_raw(ssize_t element_offset,
|
|||
return set_data_raw(element_offset, data_raw.data(), element_count);
|
||||
}
|
||||
case DataType::INT32: {
|
||||
return set_data_raw(element_offset, data->scalars().int_data().data().data(), element_count);
|
||||
return set_data_raw(element_offset,
|
||||
data->scalars().int_data().data().data(),
|
||||
element_count);
|
||||
}
|
||||
case DataType::INT64: {
|
||||
return set_data_raw(element_offset, data->scalars().long_data().data().data(), element_count);
|
||||
return set_data_raw(element_offset,
|
||||
data->scalars().long_data().data().data(),
|
||||
element_count);
|
||||
}
|
||||
case DataType::FLOAT: {
|
||||
return set_data_raw(element_offset, data->scalars().float_data().data().data(), element_count);
|
||||
return set_data_raw(element_offset,
|
||||
data->scalars().float_data().data().data(),
|
||||
element_count);
|
||||
}
|
||||
case DataType::DOUBLE: {
|
||||
return set_data_raw(element_offset, data->scalars().double_data().data().data(), element_count);
|
||||
return set_data_raw(element_offset,
|
||||
data->scalars().double_data().data().data(),
|
||||
element_count);
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
auto begin = data->scalars().string_data().data().begin();
|
||||
|
@ -69,12 +83,16 @@ VectorBase::set_data_raw(ssize_t element_offset,
|
|||
}
|
||||
|
||||
void
|
||||
VectorBase::fill_chunk_data(ssize_t element_count, const DataArray* data, const FieldMeta& field_meta) {
|
||||
VectorBase::fill_chunk_data(ssize_t element_count,
|
||||
const DataArray* data,
|
||||
const FieldMeta& field_meta) {
|
||||
if (field_meta.is_vector()) {
|
||||
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
|
||||
return fill_chunk_data(data->vectors().float_vector().data().data(), element_count);
|
||||
return fill_chunk_data(data->vectors().float_vector().data().data(),
|
||||
element_count);
|
||||
} else if (field_meta.get_data_type() == DataType::VECTOR_BINARY) {
|
||||
return fill_chunk_data(data->vectors().binary_vector().data(), element_count);
|
||||
return fill_chunk_data(data->vectors().binary_vector().data(),
|
||||
element_count);
|
||||
} else {
|
||||
PanicInfo("unsupported");
|
||||
}
|
||||
|
@ -82,7 +100,8 @@ VectorBase::fill_chunk_data(ssize_t element_count, const DataArray* data, const
|
|||
|
||||
switch (field_meta.get_data_type()) {
|
||||
case DataType::BOOL: {
|
||||
return fill_chunk_data(data->scalars().bool_data().data().data(), element_count);
|
||||
return fill_chunk_data(data->scalars().bool_data().data().data(),
|
||||
element_count);
|
||||
}
|
||||
case DataType::INT8: {
|
||||
auto src_data = data->scalars().int_data().data();
|
||||
|
@ -97,16 +116,20 @@ VectorBase::fill_chunk_data(ssize_t element_count, const DataArray* data, const
|
|||
return fill_chunk_data(data_raw.data(), element_count);
|
||||
}
|
||||
case DataType::INT32: {
|
||||
return fill_chunk_data(data->scalars().int_data().data().data(), element_count);
|
||||
return fill_chunk_data(data->scalars().int_data().data().data(),
|
||||
element_count);
|
||||
}
|
||||
case DataType::INT64: {
|
||||
return fill_chunk_data(data->scalars().long_data().data().data(), element_count);
|
||||
return fill_chunk_data(data->scalars().long_data().data().data(),
|
||||
element_count);
|
||||
}
|
||||
case DataType::FLOAT: {
|
||||
return fill_chunk_data(data->scalars().float_data().data().data(), element_count);
|
||||
return fill_chunk_data(data->scalars().float_data().data().data(),
|
||||
element_count);
|
||||
}
|
||||
case DataType::DOUBLE: {
|
||||
return fill_chunk_data(data->scalars().double_data().data().data(), element_count);
|
||||
return fill_chunk_data(data->scalars().double_data().data().data(),
|
||||
element_count);
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
auto vec = static_cast<ConcurrentVector<std::string>*>(this);
|
||||
|
|
|
@ -81,7 +81,8 @@ class ThreadSafeVector {
|
|||
|
||||
class VectorBase {
|
||||
public:
|
||||
explicit VectorBase(int64_t size_per_chunk) : size_per_chunk_(size_per_chunk) {
|
||||
explicit VectorBase(int64_t size_per_chunk)
|
||||
: size_per_chunk_(size_per_chunk) {
|
||||
}
|
||||
virtual ~VectorBase() = default;
|
||||
|
||||
|
@ -89,16 +90,23 @@ class VectorBase {
|
|||
grow_to_at_least(int64_t element_count) = 0;
|
||||
|
||||
virtual void
|
||||
set_data_raw(ssize_t element_offset, const void* source, ssize_t element_count) = 0;
|
||||
set_data_raw(ssize_t element_offset,
|
||||
const void* source,
|
||||
ssize_t element_count) = 0;
|
||||
|
||||
void
|
||||
set_data_raw(ssize_t element_offset, ssize_t element_count, const DataArray* data, const FieldMeta& field_meta);
|
||||
set_data_raw(ssize_t element_offset,
|
||||
ssize_t element_count,
|
||||
const DataArray* data,
|
||||
const FieldMeta& field_meta);
|
||||
|
||||
virtual void
|
||||
fill_chunk_data(const void* source, ssize_t element_count) = 0;
|
||||
|
||||
void
|
||||
fill_chunk_data(ssize_t element_count, const DataArray* data, const FieldMeta& field_meta);
|
||||
fill_chunk_data(ssize_t element_count,
|
||||
const DataArray* data,
|
||||
const FieldMeta& field_meta);
|
||||
|
||||
virtual SpanBase
|
||||
get_span_base(int64_t chunk_id) const = 0;
|
||||
|
@ -135,7 +143,11 @@ class ConcurrentVectorImpl : public VectorBase {
|
|||
operator=(const ConcurrentVectorImpl&) = delete;
|
||||
|
||||
using TraitType =
|
||||
std::conditional_t<is_scalar, Type, std::conditional_t<std::is_same_v<Type, float>, FloatVector, BinaryVector>>;
|
||||
std::conditional_t<is_scalar,
|
||||
Type,
|
||||
std::conditional_t<std::is_same_v<Type, float>,
|
||||
FloatVector,
|
||||
BinaryVector>>;
|
||||
|
||||
public:
|
||||
explicit ConcurrentVectorImpl(ssize_t dim, int64_t size_per_chunk)
|
||||
|
@ -160,11 +172,13 @@ class ConcurrentVectorImpl : public VectorBase {
|
|||
auto& chunk = get_chunk(chunk_id);
|
||||
if constexpr (is_scalar) {
|
||||
return Span<TraitType>(chunk.data(), chunk.size());
|
||||
} else if constexpr (std::is_same_v<Type, int64_t> || std::is_same_v<Type, int>) {
|
||||
} else if constexpr (std::is_same_v<Type, int64_t> ||
|
||||
std::is_same_v<Type, int>) {
|
||||
// only for testing
|
||||
PanicInfo("unimplemented");
|
||||
} else {
|
||||
static_assert(std::is_same_v<typename TraitType::embedded_type, Type>);
|
||||
static_assert(
|
||||
std::is_same_v<typename TraitType::embedded_type, Type>);
|
||||
return Span<TraitType>(chunk.data(), chunk.size(), Dim);
|
||||
}
|
||||
}
|
||||
|
@ -185,23 +199,29 @@ class ConcurrentVectorImpl : public VectorBase {
|
|||
}
|
||||
|
||||
void
|
||||
set_data_raw(ssize_t element_offset, const void* source, ssize_t element_count) override {
|
||||
set_data_raw(ssize_t element_offset,
|
||||
const void* source,
|
||||
ssize_t element_count) override {
|
||||
if (element_count == 0) {
|
||||
return;
|
||||
}
|
||||
this->grow_to_at_least(element_offset + element_count);
|
||||
set_data(element_offset, static_cast<const Type*>(source), element_count);
|
||||
set_data(
|
||||
element_offset, static_cast<const Type*>(source), element_count);
|
||||
}
|
||||
|
||||
void
|
||||
set_data(ssize_t element_offset, const Type* source, ssize_t element_count) {
|
||||
set_data(ssize_t element_offset,
|
||||
const Type* source,
|
||||
ssize_t element_count) {
|
||||
auto chunk_id = element_offset / size_per_chunk_;
|
||||
auto chunk_offset = element_offset % size_per_chunk_;
|
||||
ssize_t source_offset = 0;
|
||||
// first partition:
|
||||
if (chunk_offset + element_count <= size_per_chunk_) {
|
||||
// only first
|
||||
fill_chunk(chunk_id, chunk_offset, element_count, source, source_offset);
|
||||
fill_chunk(
|
||||
chunk_id, chunk_offset, element_count, source, source_offset);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -280,17 +300,25 @@ class ConcurrentVectorImpl : public VectorBase {
|
|||
|
||||
private:
|
||||
void
|
||||
fill_chunk(
|
||||
ssize_t chunk_id, ssize_t chunk_offset, ssize_t element_count, const Type* source, ssize_t source_offset) {
|
||||
fill_chunk(ssize_t chunk_id,
|
||||
ssize_t chunk_offset,
|
||||
ssize_t element_count,
|
||||
const Type* source,
|
||||
ssize_t source_offset) {
|
||||
if (element_count <= 0) {
|
||||
return;
|
||||
}
|
||||
auto chunk_num = chunks_.size();
|
||||
AssertInfo(chunk_id < chunk_num,
|
||||
fmt::format("chunk_id out of chunk num, chunk_id={}, chunk_num={}", chunk_id, chunk_num));
|
||||
AssertInfo(
|
||||
chunk_id < chunk_num,
|
||||
fmt::format("chunk_id out of chunk num, chunk_id={}, chunk_num={}",
|
||||
chunk_id,
|
||||
chunk_num));
|
||||
Chunk& chunk = chunks_[chunk_id];
|
||||
auto ptr = chunk.data();
|
||||
std::copy_n(source + source_offset * Dim, element_count * Dim, ptr + chunk_offset * Dim);
|
||||
std::copy_n(source + source_offset * Dim,
|
||||
element_count * Dim,
|
||||
ptr + chunk_offset * Dim);
|
||||
}
|
||||
|
||||
const ssize_t Dim;
|
||||
|
@ -304,20 +332,24 @@ class ConcurrentVector : public ConcurrentVectorImpl<Type, true> {
|
|||
public:
|
||||
static_assert(IsScalar<Type> || std::is_same_v<Type, PkType>);
|
||||
explicit ConcurrentVector(int64_t size_per_chunk)
|
||||
: ConcurrentVectorImpl<Type, true>::ConcurrentVectorImpl(1, size_per_chunk) {
|
||||
: ConcurrentVectorImpl<Type, true>::ConcurrentVectorImpl(
|
||||
1, size_per_chunk) {
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class ConcurrentVector<FloatVector> : public ConcurrentVectorImpl<float, false> {
|
||||
class ConcurrentVector<FloatVector>
|
||||
: public ConcurrentVectorImpl<float, false> {
|
||||
public:
|
||||
ConcurrentVector(int64_t dim, int64_t size_per_chunk)
|
||||
: ConcurrentVectorImpl<float, false>::ConcurrentVectorImpl(dim, size_per_chunk) {
|
||||
: ConcurrentVectorImpl<float, false>::ConcurrentVectorImpl(
|
||||
dim, size_per_chunk) {
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class ConcurrentVector<BinaryVector> : public ConcurrentVectorImpl<uint8_t, false> {
|
||||
class ConcurrentVector<BinaryVector>
|
||||
: public ConcurrentVectorImpl<uint8_t, false> {
|
||||
public:
|
||||
explicit ConcurrentVector(int64_t dim, int64_t size_per_chunk)
|
||||
: binary_dim_(dim), ConcurrentVectorImpl(dim / 8, size_per_chunk) {
|
||||
|
|
|
@ -32,7 +32,9 @@ struct DeletedRecord {
|
|||
};
|
||||
static constexpr int64_t deprecated_size_per_chunk = 32 * 1024;
|
||||
DeletedRecord()
|
||||
: lru_(std::make_shared<TmpBitmap>()), timestamps_(deprecated_size_per_chunk), pks_(deprecated_size_per_chunk) {
|
||||
: lru_(std::make_shared<TmpBitmap>()),
|
||||
timestamps_(deprecated_size_per_chunk),
|
||||
pks_(deprecated_size_per_chunk) {
|
||||
lru_->bitmap_ptr = std::make_shared<BitsetType>();
|
||||
}
|
||||
|
||||
|
@ -43,12 +45,16 @@ struct DeletedRecord {
|
|||
}
|
||||
|
||||
std::shared_ptr<TmpBitmap>
|
||||
clone_lru_entry(int64_t insert_barrier, int64_t del_barrier, int64_t& old_del_barrier, bool& hit_cache) {
|
||||
clone_lru_entry(int64_t insert_barrier,
|
||||
int64_t del_barrier,
|
||||
int64_t& old_del_barrier,
|
||||
bool& hit_cache) {
|
||||
std::shared_lock lck(shared_mutex_);
|
||||
auto res = lru_->clone(insert_barrier);
|
||||
old_del_barrier = lru_->del_barrier;
|
||||
|
||||
if (lru_->bitmap_ptr->size() == insert_barrier && lru_->del_barrier == del_barrier) {
|
||||
if (lru_->bitmap_ptr->size() == insert_barrier &&
|
||||
lru_->del_barrier == del_barrier) {
|
||||
hit_cache = true;
|
||||
} else {
|
||||
res->del_barrier = del_barrier;
|
||||
|
@ -61,7 +67,8 @@ struct DeletedRecord {
|
|||
insert_lru_entry(std::shared_ptr<TmpBitmap> new_entry, bool force = false) {
|
||||
std::lock_guard lck(shared_mutex_);
|
||||
if (new_entry->del_barrier <= lru_->del_barrier) {
|
||||
if (!force || new_entry->bitmap_ptr->size() <= lru_->bitmap_ptr->size()) {
|
||||
if (!force ||
|
||||
new_entry->bitmap_ptr->size() <= lru_->bitmap_ptr->size()) {
|
||||
// DO NOTHING
|
||||
return;
|
||||
}
|
||||
|
@ -81,7 +88,8 @@ struct DeletedRecord {
|
|||
};
|
||||
|
||||
inline auto
|
||||
DeletedRecord::TmpBitmap::clone(int64_t capacity) -> std::shared_ptr<TmpBitmap> {
|
||||
DeletedRecord::TmpBitmap::clone(int64_t capacity)
|
||||
-> std::shared_ptr<TmpBitmap> {
|
||||
auto res = std::make_shared<TmpBitmap>();
|
||||
res->del_barrier = this->del_barrier;
|
||||
res->bitmap_ptr = std::make_shared<BitsetType>();
|
||||
|
|
|
@ -21,8 +21,11 @@
|
|||
namespace milvus::segcore {
|
||||
|
||||
void
|
||||
VectorFieldIndexing::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) {
|
||||
AssertInfo(field_meta_.get_data_type() == DataType::VECTOR_FLOAT, "Data type of vector field is not VECTOR_FLOAT");
|
||||
VectorFieldIndexing::BuildIndexRange(int64_t ack_beg,
|
||||
int64_t ack_end,
|
||||
const VectorBase* vec_base) {
|
||||
AssertInfo(field_meta_.get_data_type() == DataType::VECTOR_FLOAT,
|
||||
"Data type of vector field is not VECTOR_FLOAT");
|
||||
auto dim = field_meta_.get_dim();
|
||||
|
||||
auto source = dynamic_cast<const ConcurrentVector<FloatVector>*>(vec_base);
|
||||
|
@ -33,9 +36,12 @@ VectorFieldIndexing::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const Vec
|
|||
data_.grow_to_at_least(ack_end);
|
||||
for (int chunk_id = ack_beg; chunk_id < ack_end; chunk_id++) {
|
||||
const auto& chunk = source->get_chunk(chunk_id);
|
||||
auto indexing = std::make_unique<index::VectorMemNMIndex>(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT,
|
||||
knowhere::metric::L2, IndexMode::MODE_CPU);
|
||||
auto dataset = knowhere::GenDataSet(source->get_size_per_chunk(), dim, chunk.data());
|
||||
auto indexing = std::make_unique<index::VectorMemNMIndex>(
|
||||
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT,
|
||||
knowhere::metric::L2,
|
||||
IndexMode::MODE_CPU);
|
||||
auto dataset = knowhere::GenDataSet(
|
||||
source->get_size_per_chunk(), dim, chunk.data());
|
||||
indexing->BuildWithDataset(dataset, conf);
|
||||
data_[chunk_id] = std::move(indexing);
|
||||
}
|
||||
|
@ -45,7 +51,8 @@ knowhere::Json
|
|||
VectorFieldIndexing::get_build_params() const {
|
||||
// TODO
|
||||
auto type_opt = field_meta_.get_metric_type();
|
||||
AssertInfo(type_opt.has_value(), "Metric type of field meta doesn't have value");
|
||||
AssertInfo(type_opt.has_value(),
|
||||
"Metric type of field meta doesn't have value");
|
||||
auto& metric_type = type_opt.value();
|
||||
auto& config = segcore_config_.at(metric_type);
|
||||
auto base_params = config.build_params;
|
||||
|
@ -61,12 +68,14 @@ knowhere::Json
|
|||
VectorFieldIndexing::get_search_params(int top_K) const {
|
||||
// TODO
|
||||
auto type_opt = field_meta_.get_metric_type();
|
||||
AssertInfo(type_opt.has_value(), "Metric type of field meta doesn't have value");
|
||||
AssertInfo(type_opt.has_value(),
|
||||
"Metric type of field meta doesn't have value");
|
||||
auto& metric_type = type_opt.value();
|
||||
auto& config = segcore_config_.at(metric_type);
|
||||
|
||||
auto base_params = config.search_params;
|
||||
AssertInfo(base_params.count("nprobe"), "Can't get nprobe from base params");
|
||||
AssertInfo(base_params.count("nprobe"),
|
||||
"Can't get nprobe from base params");
|
||||
base_params[knowhere::meta::TOPK] = top_K;
|
||||
base_params[knowhere::meta::METRIC_TYPE] = metric_type;
|
||||
|
||||
|
@ -75,7 +84,9 @@ VectorFieldIndexing::get_search_params(int top_K) const {
|
|||
|
||||
template <typename T>
|
||||
void
|
||||
ScalarFieldIndexing<T>::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) {
|
||||
ScalarFieldIndexing<T>::BuildIndexRange(int64_t ack_beg,
|
||||
int64_t ack_end,
|
||||
const VectorBase* vec_base) {
|
||||
auto source = dynamic_cast<const ConcurrentVector<T>*>(vec_base);
|
||||
AssertInfo(source, "vec_base can't cast to ConcurrentVector type");
|
||||
auto num_chunk = source->num_chunk();
|
||||
|
@ -101,7 +112,8 @@ std::unique_ptr<FieldIndexing>
|
|||
CreateIndex(const FieldMeta& field_meta, const SegcoreConfig& segcore_config) {
|
||||
if (field_meta.is_vector()) {
|
||||
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
|
||||
return std::make_unique<VectorFieldIndexing>(field_meta, segcore_config);
|
||||
return std::make_unique<VectorFieldIndexing>(field_meta,
|
||||
segcore_config);
|
||||
} else {
|
||||
// TODO
|
||||
PanicInfo("unsupported");
|
||||
|
@ -109,21 +121,29 @@ CreateIndex(const FieldMeta& field_meta, const SegcoreConfig& segcore_config) {
|
|||
}
|
||||
switch (field_meta.get_data_type()) {
|
||||
case DataType::BOOL:
|
||||
return std::make_unique<ScalarFieldIndexing<bool>>(field_meta, segcore_config);
|
||||
return std::make_unique<ScalarFieldIndexing<bool>>(field_meta,
|
||||
segcore_config);
|
||||
case DataType::INT8:
|
||||
return std::make_unique<ScalarFieldIndexing<int8_t>>(field_meta, segcore_config);
|
||||
return std::make_unique<ScalarFieldIndexing<int8_t>>(
|
||||
field_meta, segcore_config);
|
||||
case DataType::INT16:
|
||||
return std::make_unique<ScalarFieldIndexing<int16_t>>(field_meta, segcore_config);
|
||||
return std::make_unique<ScalarFieldIndexing<int16_t>>(
|
||||
field_meta, segcore_config);
|
||||
case DataType::INT32:
|
||||
return std::make_unique<ScalarFieldIndexing<int32_t>>(field_meta, segcore_config);
|
||||
return std::make_unique<ScalarFieldIndexing<int32_t>>(
|
||||
field_meta, segcore_config);
|
||||
case DataType::INT64:
|
||||
return std::make_unique<ScalarFieldIndexing<int64_t>>(field_meta, segcore_config);
|
||||
return std::make_unique<ScalarFieldIndexing<int64_t>>(
|
||||
field_meta, segcore_config);
|
||||
case DataType::FLOAT:
|
||||
return std::make_unique<ScalarFieldIndexing<float>>(field_meta, segcore_config);
|
||||
return std::make_unique<ScalarFieldIndexing<float>>(field_meta,
|
||||
segcore_config);
|
||||
case DataType::DOUBLE:
|
||||
return std::make_unique<ScalarFieldIndexing<double>>(field_meta, segcore_config);
|
||||
return std::make_unique<ScalarFieldIndexing<double>>(
|
||||
field_meta, segcore_config);
|
||||
case DataType::VARCHAR:
|
||||
return std::make_unique<ScalarFieldIndexing<std::string>>(field_meta, segcore_config);
|
||||
return std::make_unique<ScalarFieldIndexing<std::string>>(
|
||||
field_meta, segcore_config);
|
||||
default:
|
||||
PanicInfo("unsupported");
|
||||
}
|
||||
|
|
|
@ -31,7 +31,8 @@ namespace milvus::segcore {
|
|||
// All concurrent
|
||||
class FieldIndexing {
|
||||
public:
|
||||
explicit FieldIndexing(const FieldMeta& field_meta, const SegcoreConfig& segcore_config)
|
||||
explicit FieldIndexing(const FieldMeta& field_meta,
|
||||
const SegcoreConfig& segcore_config)
|
||||
: field_meta_(field_meta), segcore_config_(segcore_config) {
|
||||
}
|
||||
FieldIndexing(const FieldIndexing&) = delete;
|
||||
|
@ -41,7 +42,9 @@ class FieldIndexing {
|
|||
|
||||
// Do this in parallel
|
||||
virtual void
|
||||
BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) = 0;
|
||||
BuildIndexRange(int64_t ack_beg,
|
||||
int64_t ack_end,
|
||||
const VectorBase* vec_base) = 0;
|
||||
|
||||
const FieldMeta&
|
||||
get_field_meta() {
|
||||
|
@ -68,7 +71,9 @@ class ScalarFieldIndexing : public FieldIndexing {
|
|||
using FieldIndexing::FieldIndexing;
|
||||
|
||||
void
|
||||
BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) override;
|
||||
BuildIndexRange(int64_t ack_beg,
|
||||
int64_t ack_end,
|
||||
const VectorBase* vec_base) override;
|
||||
|
||||
// concurrent
|
||||
index::ScalarIndex<T>*
|
||||
|
@ -86,7 +91,9 @@ class VectorFieldIndexing : public FieldIndexing {
|
|||
using FieldIndexing::FieldIndexing;
|
||||
|
||||
void
|
||||
BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) override;
|
||||
BuildIndexRange(int64_t ack_beg,
|
||||
int64_t ack_end,
|
||||
const VectorBase* vec_base) override;
|
||||
|
||||
// concurrent
|
||||
index::IndexBase*
|
||||
|
@ -110,7 +117,8 @@ CreateIndex(const FieldMeta& field_meta, const SegcoreConfig& segcore_config);
|
|||
|
||||
class IndexingRecord {
|
||||
public:
|
||||
explicit IndexingRecord(const Schema& schema, const SegcoreConfig& segcore_config)
|
||||
explicit IndexingRecord(const Schema& schema,
|
||||
const SegcoreConfig& segcore_config)
|
||||
: schema_(schema), segcore_config_(segcore_config) {
|
||||
Initialize();
|
||||
}
|
||||
|
@ -132,7 +140,8 @@ class IndexingRecord {
|
|||
}
|
||||
}
|
||||
|
||||
field_indexings_.try_emplace(field_id, CreateIndex(field_meta, segcore_config_));
|
||||
field_indexings_.try_emplace(
|
||||
field_id, CreateIndex(field_meta, segcore_config_));
|
||||
}
|
||||
assert(offset_id == schema_.size());
|
||||
}
|
||||
|
@ -140,7 +149,8 @@ class IndexingRecord {
|
|||
// concurrent, reentrant
|
||||
template <bool is_sealed>
|
||||
void
|
||||
UpdateResourceAck(int64_t chunk_ack, const InsertRecord<is_sealed>& record) {
|
||||
UpdateResourceAck(int64_t chunk_ack,
|
||||
const InsertRecord<is_sealed>& record) {
|
||||
if (resource_ack_ >= chunk_ack) {
|
||||
return;
|
||||
}
|
||||
|
@ -189,7 +199,8 @@ class IndexingRecord {
|
|||
|
||||
template <typename T>
|
||||
auto
|
||||
get_scalar_field_indexing(FieldId field_id) const -> const ScalarFieldIndexing<T>& {
|
||||
get_scalar_field_indexing(FieldId field_id) const
|
||||
-> const ScalarFieldIndexing<T>& {
|
||||
auto& entry = get_field_indexing(field_id);
|
||||
auto ptr = dynamic_cast<const ScalarFieldIndexing<T>*>(&entry);
|
||||
AssertInfo(ptr, "invalid indexing");
|
||||
|
|
|
@ -50,7 +50,8 @@ class OffsetHashMap : public OffsetMap {
|
|||
std::vector<int64_t>
|
||||
find(const PkType pk) const {
|
||||
auto offset_vector = map_.find(std::get<T>(pk));
|
||||
return offset_vector != map_.end() ? offset_vector->second : std::vector<int64_t>();
|
||||
return offset_vector != map_.end() ? offset_vector->second
|
||||
: std::vector<int64_t>();
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -60,7 +61,8 @@ class OffsetHashMap : public OffsetMap {
|
|||
|
||||
void
|
||||
seal() {
|
||||
PanicInfo("OffsetHashMap used for growing segment could not be sealed.");
|
||||
PanicInfo(
|
||||
"OffsetHashMap used for growing segment could not be sealed.");
|
||||
}
|
||||
|
||||
bool
|
||||
|
@ -138,26 +140,32 @@ struct InsertRecord {
|
|||
// pks to row offset
|
||||
std::unique_ptr<OffsetMap> pk2offset_;
|
||||
|
||||
InsertRecord(const Schema& schema, int64_t size_per_chunk) : row_ids_(size_per_chunk), timestamps_(size_per_chunk) {
|
||||
InsertRecord(const Schema& schema, int64_t size_per_chunk)
|
||||
: row_ids_(size_per_chunk), timestamps_(size_per_chunk) {
|
||||
std::optional<FieldId> pk_field_id = schema.get_primary_field_id();
|
||||
|
||||
for (auto& field : schema) {
|
||||
auto field_id = field.first;
|
||||
auto& field_meta = field.second;
|
||||
if (pk2offset_ == nullptr && pk_field_id.has_value() && pk_field_id.value() == field_id) {
|
||||
if (pk2offset_ == nullptr && pk_field_id.has_value() &&
|
||||
pk_field_id.value() == field_id) {
|
||||
switch (field_meta.get_data_type()) {
|
||||
case DataType::INT64: {
|
||||
if (is_sealed)
|
||||
pk2offset_ = std::make_unique<OffsetOrderedArray<int64_t>>();
|
||||
pk2offset_ =
|
||||
std::make_unique<OffsetOrderedArray<int64_t>>();
|
||||
else
|
||||
pk2offset_ = std::make_unique<OffsetHashMap<int64_t>>();
|
||||
pk2offset_ =
|
||||
std::make_unique<OffsetHashMap<int64_t>>();
|
||||
break;
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
if (is_sealed)
|
||||
pk2offset_ = std::make_unique<OffsetOrderedArray<std::string>>();
|
||||
pk2offset_ = std::make_unique<
|
||||
OffsetOrderedArray<std::string>>();
|
||||
else
|
||||
pk2offset_ = std::make_unique<OffsetHashMap<std::string>>();
|
||||
pk2offset_ =
|
||||
std::make_unique<OffsetHashMap<std::string>>();
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
|
@ -167,10 +175,13 @@ struct InsertRecord {
|
|||
}
|
||||
if (field_meta.is_vector()) {
|
||||
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
|
||||
this->append_field_data<FloatVector>(field_id, field_meta.get_dim(), size_per_chunk);
|
||||
this->append_field_data<FloatVector>(
|
||||
field_id, field_meta.get_dim(), size_per_chunk);
|
||||
continue;
|
||||
} else if (field_meta.get_data_type() == DataType::VECTOR_BINARY) {
|
||||
this->append_field_data<BinaryVector>(field_id, field_meta.get_dim(), size_per_chunk);
|
||||
} else if (field_meta.get_data_type() ==
|
||||
DataType::VECTOR_BINARY) {
|
||||
this->append_field_data<BinaryVector>(
|
||||
field_id, field_meta.get_dim(), size_per_chunk);
|
||||
continue;
|
||||
} else {
|
||||
PanicInfo("unsupported");
|
||||
|
@ -206,7 +217,8 @@ struct InsertRecord {
|
|||
break;
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
this->append_field_data<std::string>(field_id, size_per_chunk);
|
||||
this->append_field_data<std::string>(field_id,
|
||||
size_per_chunk);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
|
@ -263,7 +275,8 @@ struct InsertRecord {
|
|||
VectorBase*
|
||||
get_field_data_base(FieldId field_id) const {
|
||||
AssertInfo(fields_data_.find(field_id) != fields_data_.end(),
|
||||
"Cannot find field_data with field_id: " + std::to_string(field_id.get()));
|
||||
"Cannot find field_data with field_id: " +
|
||||
std::to_string(field_id.get()));
|
||||
auto ptr = fields_data_.at(field_id).get();
|
||||
return ptr;
|
||||
}
|
||||
|
@ -293,7 +306,8 @@ struct InsertRecord {
|
|||
void
|
||||
append_field_data(FieldId field_id, int64_t size_per_chunk) {
|
||||
static_assert(IsScalar<Type>);
|
||||
fields_data_.emplace(field_id, std::make_unique<ConcurrentVector<Type>>(size_per_chunk));
|
||||
fields_data_.emplace(
|
||||
field_id, std::make_unique<ConcurrentVector<Type>>(size_per_chunk));
|
||||
}
|
||||
|
||||
// append a column of vector type
|
||||
|
@ -301,7 +315,9 @@ struct InsertRecord {
|
|||
void
|
||||
append_field_data(FieldId field_id, int64_t dim, int64_t size_per_chunk) {
|
||||
static_assert(std::is_base_of_v<VectorTrait, VectorType>);
|
||||
fields_data_.emplace(field_id, std::make_unique<ConcurrentVector<VectorType>>(dim, size_per_chunk));
|
||||
fields_data_.emplace(field_id,
|
||||
std::make_unique<ConcurrentVector<VectorType>>(
|
||||
dim, size_per_chunk));
|
||||
}
|
||||
|
||||
void
|
||||
|
|
|
@ -27,7 +27,8 @@ void
|
|||
ReduceHelper::Initialize() {
|
||||
AssertInfo(search_results_.size() > 0, "empty search result");
|
||||
AssertInfo(slice_nqs_.size() > 0, "empty slice_nqs");
|
||||
AssertInfo(slice_nqs_.size() == slice_topKs_.size(), "unaligned slice_nqs and slice_topKs");
|
||||
AssertInfo(slice_nqs_.size() == slice_topKs_.size(),
|
||||
"unaligned slice_nqs and slice_topKs");
|
||||
|
||||
total_nq_ = search_results_[0]->total_nq_;
|
||||
num_segments_ = search_results_.size();
|
||||
|
@ -36,10 +37,13 @@ ReduceHelper::Initialize() {
|
|||
// prefix sum, get slices offsets
|
||||
AssertInfo(num_slices_ > 0, "empty slice_nqs is not allowed");
|
||||
slice_nqs_prefix_sum_.resize(num_slices_ + 1);
|
||||
std::partial_sum(slice_nqs_.begin(), slice_nqs_.end(), slice_nqs_prefix_sum_.begin() + 1);
|
||||
AssertInfo(slice_nqs_prefix_sum_[num_slices_] == total_nq_, "illegal req sizes, slice_nqs_prefix_sum_[last] = " +
|
||||
std::to_string(slice_nqs_prefix_sum_[num_slices_]) +
|
||||
", total_nq = " + std::to_string(total_nq_));
|
||||
std::partial_sum(slice_nqs_.begin(),
|
||||
slice_nqs_.end(),
|
||||
slice_nqs_prefix_sum_.begin() + 1);
|
||||
AssertInfo(slice_nqs_prefix_sum_[num_slices_] == total_nq_,
|
||||
"illegal req sizes, slice_nqs_prefix_sum_[last] = " +
|
||||
std::to_string(slice_nqs_prefix_sum_[num_slices_]) +
|
||||
", total_nq = " + std::to_string(total_nq_));
|
||||
|
||||
// init final_search_records and final_read_topKs
|
||||
final_search_records_.resize(num_segments_);
|
||||
|
@ -59,7 +63,8 @@ ReduceHelper::Reduce() {
|
|||
void
|
||||
ReduceHelper::Marshal() {
|
||||
// get search result data blobs of slices
|
||||
search_result_data_blobs_ = std::make_unique<milvus::segcore::SearchResultDataBlobs>();
|
||||
search_result_data_blobs_ =
|
||||
std::make_unique<milvus::segcore::SearchResultDataBlobs>();
|
||||
search_result_data_blobs_->blobs.resize(num_slices_);
|
||||
for (int i = 0; i < num_slices_; i++) {
|
||||
auto proto = GetSearchResultDataSlice(i);
|
||||
|
@ -72,10 +77,12 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
|
|||
auto nq = search_result->total_nq_;
|
||||
auto topK = search_result->unity_topK_;
|
||||
AssertInfo(search_result->seg_offsets_.size() == nq * topK,
|
||||
"wrong seg offsets size, size = " + std::to_string(search_result->seg_offsets_.size()) +
|
||||
"wrong seg offsets size, size = " +
|
||||
std::to_string(search_result->seg_offsets_.size()) +
|
||||
", expected size = " + std::to_string(nq * topK));
|
||||
AssertInfo(search_result->distances_.size() == nq * topK,
|
||||
"wrong distances size, size = " + std::to_string(search_result->distances_.size()) +
|
||||
"wrong distances size, size = " +
|
||||
std::to_string(search_result->distances_.size()) +
|
||||
", expected size = " + std::to_string(nq * topK));
|
||||
std::vector<int64_t> real_topks(nq, 0);
|
||||
uint32_t valid_index = 0;
|
||||
|
@ -96,7 +103,9 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
|
|||
distances.resize(valid_index);
|
||||
|
||||
search_result->topk_per_nq_prefix_sum_.resize(nq + 1);
|
||||
std::partial_sum(real_topks.begin(), real_topks.end(), search_result->topk_per_nq_prefix_sum_.begin() + 1);
|
||||
std::partial_sum(real_topks.begin(),
|
||||
real_topks.end(),
|
||||
search_result->topk_per_nq_prefix_sum_.begin() + 1);
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -107,7 +116,8 @@ ReduceHelper::FillPrimaryKey() {
|
|||
for (auto& search_result : search_results_) {
|
||||
FilterInvalidSearchResult(search_result);
|
||||
if (search_result->get_total_result_count() > 0) {
|
||||
auto segment = static_cast<SegmentInterface*>(search_result->segment_);
|
||||
auto segment =
|
||||
static_cast<SegmentInterface*>(search_result->segment_);
|
||||
segment->FillPrimaryKeys(plan_, *search_result);
|
||||
search_results_[valid_index++] = search_result;
|
||||
}
|
||||
|
@ -144,20 +154,25 @@ ReduceHelper::RefreshSearchResult() {
|
|||
search_result->distances_.swap(distances);
|
||||
search_result->seg_offsets_.swap(seg_offsets);
|
||||
}
|
||||
std::partial_sum(real_topks.begin(), real_topks.end(), search_result->topk_per_nq_prefix_sum_.begin() + 1);
|
||||
std::partial_sum(real_topks.begin(),
|
||||
real_topks.end(),
|
||||
search_result->topk_per_nq_prefix_sum_.begin() + 1);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ReduceHelper::FillEntryData() {
|
||||
for (auto search_result : search_results_) {
|
||||
auto segment = static_cast<milvus::segcore::SegmentInterface*>(search_result->segment_);
|
||||
auto segment = static_cast<milvus::segcore::SegmentInterface*>(
|
||||
search_result->segment_);
|
||||
segment->FillTargetEntry(plan_, *search_result);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t
|
||||
ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& offset) {
|
||||
ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi,
|
||||
int64_t topk,
|
||||
int64_t& offset) {
|
||||
while (!heap_.empty()) {
|
||||
heap_.pop();
|
||||
}
|
||||
|
@ -175,7 +190,8 @@ ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& offs
|
|||
auto primary_key = search_result->primary_keys_[offset_beg];
|
||||
auto distance = search_result->distances_[offset_beg];
|
||||
|
||||
pairs_.emplace_back(primary_key, distance, search_result, i, offset_beg, offset_end);
|
||||
pairs_.emplace_back(
|
||||
primary_key, distance, search_result, i, offset_beg, offset_end);
|
||||
heap_.push(&pairs_.back());
|
||||
}
|
||||
|
||||
|
@ -218,10 +234,14 @@ ReduceHelper::ReduceResultData() {
|
|||
for (int i = 0; i < num_segments_; i++) {
|
||||
auto search_result = search_results_[i];
|
||||
auto result_count = search_result->get_total_result_count();
|
||||
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
|
||||
AssertInfo(search_result->distances_.size() == result_count, "incorrect search result distance size");
|
||||
AssertInfo(search_result->seg_offsets_.size() == result_count, "incorrect search result seg offset size");
|
||||
AssertInfo(search_result->primary_keys_.size() == result_count, "incorrect search result primary key size");
|
||||
AssertInfo(search_result != nullptr,
|
||||
"search result must not equal to nullptr");
|
||||
AssertInfo(search_result->distances_.size() == result_count,
|
||||
"incorrect search result distance size");
|
||||
AssertInfo(search_result->seg_offsets_.size() == result_count,
|
||||
"incorrect search result seg offset size");
|
||||
AssertInfo(search_result->primary_keys_.size() == result_count,
|
||||
"incorrect search result primary key size");
|
||||
}
|
||||
|
||||
int64_t skip_dup_cnt = 0;
|
||||
|
@ -232,11 +252,13 @@ ReduceHelper::ReduceResultData() {
|
|||
// reduce search results
|
||||
int64_t offset = 0;
|
||||
for (int64_t qi = nq_begin; qi < nq_end; qi++) {
|
||||
skip_dup_cnt += ReduceSearchResultForOneNQ(qi, slice_topKs_[slice_index], offset);
|
||||
skip_dup_cnt += ReduceSearchResultForOneNQ(
|
||||
qi, slice_topKs_[slice_index], offset);
|
||||
}
|
||||
}
|
||||
if (skip_dup_cnt > 0) {
|
||||
LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = " << skip_dup_cnt;
|
||||
LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = "
|
||||
<< skip_dup_cnt;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -247,13 +269,15 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
|
|||
|
||||
int64_t result_count = 0;
|
||||
for (auto search_result : search_results_) {
|
||||
AssertInfo(search_result->topk_per_nq_prefix_sum_.size() == search_result->total_nq_ + 1,
|
||||
AssertInfo(search_result->topk_per_nq_prefix_sum_.size() ==
|
||||
search_result->total_nq_ + 1,
|
||||
"incorrect topk_per_nq_prefix_sum_ size in search result");
|
||||
result_count +=
|
||||
search_result->topk_per_nq_prefix_sum_[nq_end] - search_result->topk_per_nq_prefix_sum_[nq_begin];
|
||||
result_count += search_result->topk_per_nq_prefix_sum_[nq_end] -
|
||||
search_result->topk_per_nq_prefix_sum_[nq_begin];
|
||||
}
|
||||
|
||||
auto search_result_data = std::make_unique<milvus::proto::schema::SearchResultData>();
|
||||
auto search_result_data =
|
||||
std::make_unique<milvus::proto::schema::SearchResultData>();
|
||||
// set unify_topK and total_nq
|
||||
search_result_data->set_top_k(slice_topKs_[slice_index]);
|
||||
search_result_data->set_num_queries(nq_end - nq_begin);
|
||||
|
@ -263,14 +287,16 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
|
|||
std::vector<std::pair<SearchResult*, int64_t>> result_pairs(result_count);
|
||||
|
||||
// reserve space for pks
|
||||
auto primary_field_id = plan_->schema_.get_primary_field_id().value_or(milvus::FieldId(-1));
|
||||
auto primary_field_id =
|
||||
plan_->schema_.get_primary_field_id().value_or(milvus::FieldId(-1));
|
||||
AssertInfo(primary_field_id.get() != INVALID_FIELD_ID, "Primary key is -1");
|
||||
auto pk_type = plan_->schema_[primary_field_id].get_data_type();
|
||||
switch (pk_type) {
|
||||
case milvus::DataType::INT64: {
|
||||
auto ids = std::make_unique<milvus::proto::schema::LongArray>();
|
||||
ids->mutable_data()->Resize(result_count, 0);
|
||||
search_result_data->mutable_ids()->set_allocated_int_id(ids.release());
|
||||
search_result_data->mutable_ids()->set_allocated_int_id(
|
||||
ids.release());
|
||||
break;
|
||||
}
|
||||
case milvus::DataType::VARCHAR: {
|
||||
|
@ -278,7 +304,8 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
|
|||
std::vector<std::string> string_pks(result_count);
|
||||
// TODO: prevent mem copy
|
||||
*ids->mutable_data() = {string_pks.begin(), string_pks.end()};
|
||||
search_result_data->mutable_ids()->set_allocated_str_id(ids.release());
|
||||
search_result_data->mutable_ids()->set_allocated_str_id(
|
||||
ids.release());
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
|
@ -293,7 +320,8 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
|
|||
for (auto qi = nq_begin; qi < nq_end; qi++) {
|
||||
int64_t topk_count = 0;
|
||||
for (auto search_result : search_results_) {
|
||||
AssertInfo(search_result != nullptr, "null search result when reorganize");
|
||||
AssertInfo(search_result != nullptr,
|
||||
"null search result when reorganize");
|
||||
if (search_result->result_offsets_.size() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
@ -305,18 +333,26 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
|
|||
for (auto ki = topk_start; ki < topk_end; ki++) {
|
||||
auto loc = search_result->result_offsets_[ki];
|
||||
AssertInfo(loc < result_count && loc >= 0,
|
||||
"invalid loc when GetSearchResultDataSlice, loc = " + std::to_string(loc) +
|
||||
", result_count = " + std::to_string(result_count));
|
||||
"invalid loc when GetSearchResultDataSlice, loc = " +
|
||||
std::to_string(loc) + ", result_count = " +
|
||||
std::to_string(result_count));
|
||||
// set result pks
|
||||
switch (pk_type) {
|
||||
case milvus::DataType::INT64: {
|
||||
search_result_data->mutable_ids()->mutable_int_id()->mutable_data()->Set(
|
||||
loc, std::visit(Int64PKVisitor{}, search_result->primary_keys_[ki]));
|
||||
search_result_data->mutable_ids()
|
||||
->mutable_int_id()
|
||||
->mutable_data()
|
||||
->Set(loc,
|
||||
std::visit(Int64PKVisitor{},
|
||||
search_result->primary_keys_[ki]));
|
||||
break;
|
||||
}
|
||||
case milvus::DataType::VARCHAR: {
|
||||
*search_result_data->mutable_ids()->mutable_str_id()->mutable_data()->Mutable(loc) =
|
||||
std::visit(StrPKVisitor{}, search_result->primary_keys_[ki]);
|
||||
*search_result_data->mutable_ids()
|
||||
->mutable_str_id()
|
||||
->mutable_data()
|
||||
->Mutable(loc) = std::visit(
|
||||
StrPKVisitor{}, search_result->primary_keys_[ki]);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
|
@ -325,7 +361,8 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
|
|||
}
|
||||
|
||||
// set result distances
|
||||
search_result_data->mutable_scores()->Set(loc, search_result->distances_[ki]);
|
||||
search_result_data->mutable_scores()->Set(
|
||||
loc, search_result->distances_[ki]);
|
||||
// set result offset to fill output fields data
|
||||
result_pairs[loc] = std::make_pair(search_result, ki);
|
||||
}
|
||||
|
@ -336,14 +373,17 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
|
|||
}
|
||||
|
||||
AssertInfo(search_result_data->scores_size() == result_count,
|
||||
"wrong scores size, size = " + std::to_string(search_result_data->scores_size()) +
|
||||
"wrong scores size, size = " +
|
||||
std::to_string(search_result_data->scores_size()) +
|
||||
", expected size = " + std::to_string(result_count));
|
||||
|
||||
// set output fields
|
||||
for (auto field_id : plan_->target_entries_) {
|
||||
auto& field_meta = plan_->schema_[field_id];
|
||||
auto field_data = milvus::segcore::MergeDataArray(result_pairs, field_meta);
|
||||
search_result_data->mutable_fields_data()->AddAllocated(field_data.release());
|
||||
auto field_data =
|
||||
milvus::segcore::MergeDataArray(result_pairs, field_meta);
|
||||
search_result_data->mutable_fields_data()->AddAllocated(
|
||||
field_data.release());
|
||||
}
|
||||
|
||||
// SearchResultData to blob
|
||||
|
|
|
@ -73,7 +73,9 @@ class ReduceHelper {
|
|||
FillEntryData();
|
||||
|
||||
int64_t
|
||||
ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& result_offset);
|
||||
ReduceSearchResultForOneNQ(int64_t qi,
|
||||
int64_t topk,
|
||||
int64_t& result_offset);
|
||||
|
||||
void
|
||||
ReduceResultData();
|
||||
|
@ -102,7 +104,10 @@ class ReduceHelper {
|
|||
// Used for merge results,
|
||||
// define these here to avoid allocating them for each query
|
||||
std::vector<SearchResultPair> pairs_;
|
||||
std::priority_queue<SearchResultPair*, std::vector<SearchResultPair*>, SearchResultPairComparator> heap_;
|
||||
std::priority_queue<SearchResultPair*,
|
||||
std::vector<SearchResultPair*>,
|
||||
SearchResultPairComparator>
|
||||
heap_;
|
||||
std::unordered_set<milvus::PkType> pk_set_;
|
||||
};
|
||||
|
||||
|
|
|
@ -27,8 +27,12 @@ struct SearchResultPair {
|
|||
int64_t offset_;
|
||||
int64_t offset_rb_; // right bound
|
||||
|
||||
SearchResultPair(
|
||||
milvus::PkType primary_key, float distance, SearchResult* result, int64_t index, int64_t lb, int64_t rb)
|
||||
SearchResultPair(milvus::PkType primary_key,
|
||||
float distance,
|
||||
SearchResult* result,
|
||||
int64_t index,
|
||||
int64_t lb,
|
||||
int64_t rb)
|
||||
: primary_key_(primary_key),
|
||||
distance_(distance),
|
||||
search_result_(result),
|
||||
|
|
|
@ -31,8 +31,12 @@ ScalarIndexVector::do_search_ids(const IdArray& ids) const {
|
|||
for (auto id : src_ids.data()) {
|
||||
using Pair = std::pair<T, SegOffset>;
|
||||
auto [iter_beg, iter_end] =
|
||||
std::equal_range(mapping_.begin(), mapping_.end(), std::make_pair(id, SegOffset(0)),
|
||||
[](const Pair& left, const Pair& right) { return left.first < right.first; });
|
||||
std::equal_range(mapping_.begin(),
|
||||
mapping_.end(),
|
||||
std::make_pair(id, SegOffset(0)),
|
||||
[](const Pair& left, const Pair& right) {
|
||||
return left.first < right.first;
|
||||
});
|
||||
|
||||
for (auto& iter = iter_beg; iter != iter_end; iter++) {
|
||||
auto [entry_id, entry_offset] = *iter;
|
||||
|
@ -51,8 +55,12 @@ ScalarIndexVector::do_search_ids(const std::vector<idx_t>& ids) const {
|
|||
for (auto id : ids) {
|
||||
using Pair = std::pair<T, SegOffset>;
|
||||
auto [iter_beg, iter_end] =
|
||||
std::equal_range(mapping_.begin(), mapping_.end(), std::make_pair(id, SegOffset(0)),
|
||||
[](const Pair& left, const Pair& right) { return left.first < right.first; });
|
||||
std::equal_range(mapping_.begin(),
|
||||
mapping_.end(),
|
||||
std::make_pair(id, SegOffset(0)),
|
||||
[](const Pair& left, const Pair& right) {
|
||||
return left.first < right.first;
|
||||
});
|
||||
|
||||
for (auto& iter = iter_beg; iter != iter_end; iter++) {
|
||||
auto [entry_id, entry_offset] = *iter_beg;
|
||||
|
@ -64,7 +72,9 @@ ScalarIndexVector::do_search_ids(const std::vector<idx_t>& ids) const {
|
|||
}
|
||||
|
||||
void
|
||||
ScalarIndexVector::append_data(const ScalarIndexVector::T* ids, int64_t count, SegOffset base) {
|
||||
ScalarIndexVector::append_data(const ScalarIndexVector::T* ids,
|
||||
int64_t count,
|
||||
SegOffset base) {
|
||||
for (int64_t i = 0; i < count; ++i) {
|
||||
auto offset = base + SegOffset(i);
|
||||
mapping_.emplace_back(ids[i], offset);
|
||||
|
|
|
@ -53,7 +53,8 @@ class ScalarIndexVector : public ScalarIndexBase {
|
|||
debug() const override {
|
||||
std::string dbg_str;
|
||||
for (auto pr : mapping_) {
|
||||
dbg_str += "<" + std::to_string(pr.first) + "->" + std::to_string(pr.second.get()) + ">";
|
||||
dbg_str += "<" + std::to_string(pr.first) + "->" +
|
||||
std::to_string(pr.second.get()) + ">";
|
||||
}
|
||||
return dbg_str;
|
||||
}
|
||||
|
|
|
@ -33,7 +33,9 @@ using SealedIndexingEntryPtr = std::unique_ptr<SealedIndexingEntry>;
|
|||
|
||||
struct SealedIndexingRecord {
|
||||
void
|
||||
append_field_indexing(FieldId field_id, const MetricType& metric_type, index::IndexBasePtr indexing) {
|
||||
append_field_indexing(FieldId field_id,
|
||||
const MetricType& metric_type,
|
||||
index::IndexBasePtr indexing) {
|
||||
auto ptr = std::make_unique<SealedIndexingEntry>();
|
||||
ptr->indexing_ = std::move(indexing);
|
||||
ptr->metric_type_ = metric_type;
|
||||
|
|
|
@ -100,7 +100,8 @@ SegcoreConfig::parse_from(const std::string& config_path) {
|
|||
} catch (const SegcoreError& e) {
|
||||
throw e;
|
||||
} catch (const std::exception& e) {
|
||||
std::string str = std::string("Invalid Yaml: ") + config_path + ", err: " + e.what();
|
||||
std::string str =
|
||||
std::string("Invalid Yaml: ") + config_path + ", err: " + e.what();
|
||||
PanicInfo(str);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -76,7 +76,8 @@ class SegcoreConfig {
|
|||
}
|
||||
|
||||
void
|
||||
set_small_index_config(const MetricType& metric_type, const SmallIndexConf& small_index_conf) {
|
||||
set_small_index_config(const MetricType& metric_type,
|
||||
const SmallIndexConf& small_index_conf) {
|
||||
table_[metric_type] = small_index_conf;
|
||||
}
|
||||
|
||||
|
|
|
@ -36,17 +36,21 @@ SegmentGrowingImpl::PreDelete(int64_t size) {
|
|||
}
|
||||
|
||||
void
|
||||
SegmentGrowingImpl::mask_with_delete(BitsetType& bitset, int64_t ins_barrier, Timestamp timestamp) const {
|
||||
SegmentGrowingImpl::mask_with_delete(BitsetType& bitset,
|
||||
int64_t ins_barrier,
|
||||
Timestamp timestamp) const {
|
||||
auto del_barrier = get_barrier(get_deleted_record(), timestamp);
|
||||
if (del_barrier == 0) {
|
||||
return;
|
||||
}
|
||||
auto bitmap_holder = get_deleted_bitmap(del_barrier, ins_barrier, deleted_record_, insert_record_, timestamp);
|
||||
auto bitmap_holder = get_deleted_bitmap(
|
||||
del_barrier, ins_barrier, deleted_record_, insert_record_, timestamp);
|
||||
if (!bitmap_holder || !bitmap_holder->bitmap_ptr) {
|
||||
return;
|
||||
}
|
||||
auto& delete_bitset = *bitmap_holder->bitmap_ptr;
|
||||
AssertInfo(delete_bitset.size() == bitset.size(), "Deleted bitmap size not equal to filtered bitmap size");
|
||||
AssertInfo(delete_bitset.size() == bitset.size(),
|
||||
"Deleted bitmap size not equal to filtered bitmap size");
|
||||
bitset |= delete_bitset;
|
||||
}
|
||||
|
||||
|
@ -56,7 +60,8 @@ SegmentGrowingImpl::Insert(int64_t reserved_offset,
|
|||
const int64_t* row_ids,
|
||||
const Timestamp* timestamps_raw,
|
||||
const InsertData* insert_data) {
|
||||
AssertInfo(insert_data->num_rows() == size, "Entities_raw count not equal to insert size");
|
||||
AssertInfo(insert_data->num_rows() == size,
|
||||
"Entities_raw count not equal to insert size");
|
||||
// AssertInfo(insert_data->fields_data_size() == schema_->size(),
|
||||
// "num fields of insert data not equal to num of schema fields");
|
||||
// step 1: check insert data if valid
|
||||
|
@ -72,34 +77,45 @@ SegmentGrowingImpl::Insert(int64_t reserved_offset,
|
|||
// query node already guarantees that the timestamp is ordered, avoid field data copy in c++
|
||||
|
||||
// step 3: fill into Segment.ConcurrentVector
|
||||
insert_record_.timestamps_.set_data_raw(reserved_offset, timestamps_raw, size);
|
||||
insert_record_.timestamps_.set_data_raw(
|
||||
reserved_offset, timestamps_raw, size);
|
||||
insert_record_.row_ids_.set_data_raw(reserved_offset, row_ids, size);
|
||||
for (auto [field_id, field_meta] : schema_->get_fields()) {
|
||||
AssertInfo(field_id_to_offset.count(field_id), "Cannot find field_id");
|
||||
auto data_offset = field_id_to_offset[field_id];
|
||||
insert_record_.get_field_data_base(field_id)->set_data_raw(reserved_offset, size,
|
||||
&insert_data->fields_data(data_offset), field_meta);
|
||||
insert_record_.get_field_data_base(field_id)->set_data_raw(
|
||||
reserved_offset,
|
||||
size,
|
||||
&insert_data->fields_data(data_offset),
|
||||
field_meta);
|
||||
}
|
||||
|
||||
// step 4: set pks to offset
|
||||
auto field_id = schema_->get_primary_field_id().value_or(FieldId(-1));
|
||||
AssertInfo(field_id.get() != INVALID_FIELD_ID, "Primary key is -1");
|
||||
std::vector<PkType> pks(size);
|
||||
ParsePksFromFieldData(pks, insert_data->fields_data(field_id_to_offset[field_id]));
|
||||
ParsePksFromFieldData(
|
||||
pks, insert_data->fields_data(field_id_to_offset[field_id]));
|
||||
for (int i = 0; i < size; ++i) {
|
||||
insert_record_.insert_pk(pks[i], reserved_offset + i);
|
||||
}
|
||||
|
||||
// step 5: update small indexes
|
||||
insert_record_.ack_responder_.AddSegment(reserved_offset, reserved_offset + size);
|
||||
insert_record_.ack_responder_.AddSegment(reserved_offset,
|
||||
reserved_offset + size);
|
||||
if (enable_small_index_) {
|
||||
int64_t chunk_rows = segcore_config_.get_chunk_rows();
|
||||
indexing_record_.UpdateResourceAck(insert_record_.ack_responder_.GetAck() / chunk_rows, insert_record_);
|
||||
indexing_record_.UpdateResourceAck(
|
||||
insert_record_.ack_responder_.GetAck() / chunk_rows,
|
||||
insert_record_);
|
||||
}
|
||||
}
|
||||
|
||||
Status
|
||||
SegmentGrowingImpl::Delete(int64_t reserved_begin, int64_t size, const IdArray* ids, const Timestamp* timestamps_raw) {
|
||||
SegmentGrowingImpl::Delete(int64_t reserved_begin,
|
||||
int64_t size,
|
||||
const IdArray* ids,
|
||||
const Timestamp* timestamps_raw) {
|
||||
auto field_id = schema_->get_primary_field_id().value_or(FieldId(-1));
|
||||
AssertInfo(field_id.get() != -1, "Primary key is -1");
|
||||
auto& field_meta = schema_->operator[](field_id);
|
||||
|
@ -122,9 +138,11 @@ SegmentGrowingImpl::Delete(int64_t reserved_begin, int64_t size, const IdArray*
|
|||
}
|
||||
|
||||
// step 2: fill delete record
|
||||
deleted_record_.timestamps_.set_data_raw(reserved_begin, sort_timestamps.data(), size);
|
||||
deleted_record_.timestamps_.set_data_raw(
|
||||
reserved_begin, sort_timestamps.data(), size);
|
||||
deleted_record_.pks_.set_data_raw(reserved_begin, sort_pks.data(), size);
|
||||
deleted_record_.ack_responder_.AddSegment(reserved_begin, reserved_begin + size);
|
||||
deleted_record_.ack_responder_.AddSegment(reserved_begin,
|
||||
reserved_begin + size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -145,8 +163,10 @@ SegmentGrowingImpl::LoadDeletedRecord(const LoadDeletedRecordInfo& info) {
|
|||
AssertInfo(info.primary_keys, "Deleted primary keys is null");
|
||||
AssertInfo(info.timestamps, "Deleted timestamps is null");
|
||||
// step 1: get pks and timestamps
|
||||
auto field_id = schema_->get_primary_field_id().value_or(FieldId(INVALID_FIELD_ID));
|
||||
AssertInfo(field_id.get() != INVALID_FIELD_ID, "Primary key has invalid field id");
|
||||
auto field_id =
|
||||
schema_->get_primary_field_id().value_or(FieldId(INVALID_FIELD_ID));
|
||||
AssertInfo(field_id.get() != INVALID_FIELD_ID,
|
||||
"Primary key has invalid field id");
|
||||
auto& field_meta = schema_->operator[](field_id);
|
||||
int64_t size = info.row_count;
|
||||
std::vector<PkType> pks(size);
|
||||
|
@ -157,7 +177,8 @@ SegmentGrowingImpl::LoadDeletedRecord(const LoadDeletedRecordInfo& info) {
|
|||
auto reserved_begin = deleted_record_.reserved.fetch_add(size);
|
||||
deleted_record_.pks_.set_data_raw(reserved_begin, pks.data(), size);
|
||||
deleted_record_.timestamps_.set_data_raw(reserved_begin, timestamps, size);
|
||||
deleted_record_.ack_responder_.AddSegment(reserved_begin, reserved_begin + size);
|
||||
deleted_record_.ack_responder_.AddSegment(reserved_begin,
|
||||
reserved_begin + size);
|
||||
}
|
||||
|
||||
SpanBase
|
||||
|
@ -181,70 +202,100 @@ SegmentGrowingImpl::vector_search(SearchInfo& search_info,
|
|||
SearchResult& output) const {
|
||||
auto& sealed_indexing = this->get_sealed_indexing_record();
|
||||
if (sealed_indexing.is_ready(search_info.field_id_)) {
|
||||
query::SearchOnSealedIndex(this->get_schema(), sealed_indexing, search_info, query_data, query_count, bitset,
|
||||
query::SearchOnSealedIndex(this->get_schema(),
|
||||
sealed_indexing,
|
||||
search_info,
|
||||
query_data,
|
||||
query_count,
|
||||
bitset,
|
||||
output);
|
||||
} else {
|
||||
query::SearchOnGrowing(*this, search_info, query_data, query_count, timestamp, bitset, output);
|
||||
query::SearchOnGrowing(*this,
|
||||
search_info,
|
||||
query_data,
|
||||
query_count,
|
||||
timestamp,
|
||||
bitset,
|
||||
output);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
SegmentGrowingImpl::bulk_subscript(FieldId field_id, const int64_t* seg_offsets, int64_t count) const {
|
||||
SegmentGrowingImpl::bulk_subscript(FieldId field_id,
|
||||
const int64_t* seg_offsets,
|
||||
int64_t count) const {
|
||||
// TODO: support more types
|
||||
auto vec_ptr = insert_record_.get_field_data_base(field_id);
|
||||
auto& field_meta = schema_->operator[](field_id);
|
||||
if (field_meta.is_vector()) {
|
||||
aligned_vector<char> output(field_meta.get_sizeof() * count);
|
||||
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
|
||||
bulk_subscript_impl<FloatVector>(field_meta.get_sizeof(), *vec_ptr, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<FloatVector>(field_meta.get_sizeof(),
|
||||
*vec_ptr,
|
||||
seg_offsets,
|
||||
count,
|
||||
output.data());
|
||||
} else if (field_meta.get_data_type() == DataType::VECTOR_BINARY) {
|
||||
bulk_subscript_impl<BinaryVector>(field_meta.get_sizeof(), *vec_ptr, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<BinaryVector>(field_meta.get_sizeof(),
|
||||
*vec_ptr,
|
||||
seg_offsets,
|
||||
count,
|
||||
output.data());
|
||||
} else {
|
||||
PanicInfo("logical error");
|
||||
}
|
||||
return CreateVectorDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
|
||||
AssertInfo(!field_meta.is_vector(), "Scalar field meta type is vector type");
|
||||
AssertInfo(!field_meta.is_vector(),
|
||||
"Scalar field meta type is vector type");
|
||||
switch (field_meta.get_data_type()) {
|
||||
case DataType::BOOL: {
|
||||
FixedVector<bool> output(count);
|
||||
bulk_subscript_impl<bool>(*vec_ptr, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<bool>(
|
||||
*vec_ptr, seg_offsets, count, output.data());
|
||||
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
case DataType::INT8: {
|
||||
FixedVector<bool> output(count);
|
||||
bulk_subscript_impl<int8_t>(*vec_ptr, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<int8_t>(
|
||||
*vec_ptr, seg_offsets, count, output.data());
|
||||
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
case DataType::INT16: {
|
||||
FixedVector<int16_t> output(count);
|
||||
bulk_subscript_impl<int16_t>(*vec_ptr, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<int16_t>(
|
||||
*vec_ptr, seg_offsets, count, output.data());
|
||||
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
case DataType::INT32: {
|
||||
FixedVector<int32_t> output(count);
|
||||
bulk_subscript_impl<int32_t>(*vec_ptr, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<int32_t>(
|
||||
*vec_ptr, seg_offsets, count, output.data());
|
||||
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
case DataType::INT64: {
|
||||
FixedVector<int64_t> output(count);
|
||||
bulk_subscript_impl<int64_t>(*vec_ptr, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<int64_t>(
|
||||
*vec_ptr, seg_offsets, count, output.data());
|
||||
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
case DataType::FLOAT: {
|
||||
FixedVector<float> output(count);
|
||||
bulk_subscript_impl<float>(*vec_ptr, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<float>(
|
||||
*vec_ptr, seg_offsets, count, output.data());
|
||||
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
case DataType::DOUBLE: {
|
||||
FixedVector<double> output(count);
|
||||
bulk_subscript_impl<double>(*vec_ptr, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<double>(
|
||||
*vec_ptr, seg_offsets, count, output.data());
|
||||
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
FixedVector<std::string> output(count);
|
||||
bulk_subscript_impl<std::string>(*vec_ptr, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<std::string>(
|
||||
*vec_ptr, seg_offsets, count, output.data());
|
||||
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
default: {
|
||||
|
@ -269,7 +320,9 @@ SegmentGrowingImpl::bulk_subscript_impl(int64_t element_sizeof,
|
|||
for (int i = 0; i < count; ++i) {
|
||||
auto dst = output_base + i * element_sizeof;
|
||||
auto offset = seg_offsets[i];
|
||||
const uint8_t* src = (offset == INVALID_SEG_OFFSET ? empty.data() : (const uint8_t*)vec.get_element(offset));
|
||||
const uint8_t* src = (offset == INVALID_SEG_OFFSET
|
||||
? empty.data()
|
||||
: (const uint8_t*)vec.get_element(offset));
|
||||
memcpy(dst, src, element_sizeof);
|
||||
}
|
||||
}
|
||||
|
@ -300,10 +353,12 @@ SegmentGrowingImpl::bulk_subscript(SystemFieldType system_type,
|
|||
void* output) const {
|
||||
switch (system_type) {
|
||||
case SystemFieldType::Timestamp:
|
||||
bulk_subscript_impl<Timestamp>(this->insert_record_.timestamps_, seg_offsets, count, output);
|
||||
bulk_subscript_impl<Timestamp>(
|
||||
this->insert_record_.timestamps_, seg_offsets, count, output);
|
||||
break;
|
||||
case SystemFieldType::RowId:
|
||||
bulk_subscript_impl<int64_t>(this->insert_record_.row_ids_, seg_offsets, count, output);
|
||||
bulk_subscript_impl<int64_t>(
|
||||
this->insert_record_.row_ids_, seg_offsets, count, output);
|
||||
break;
|
||||
default:
|
||||
PanicInfo("unknown subscript fields");
|
||||
|
@ -311,7 +366,8 @@ SegmentGrowingImpl::bulk_subscript(SystemFieldType system_type,
|
|||
}
|
||||
|
||||
std::vector<SegOffset>
|
||||
SegmentGrowingImpl::search_ids(const BitsetType& bitset, Timestamp timestamp) const {
|
||||
SegmentGrowingImpl::search_ids(const BitsetType& bitset,
|
||||
Timestamp timestamp) const {
|
||||
std::vector<SegOffset> res_offsets;
|
||||
|
||||
for (int i = 0; i < bitset.size(); i++) {
|
||||
|
@ -326,7 +382,8 @@ SegmentGrowingImpl::search_ids(const BitsetType& bitset, Timestamp timestamp) co
|
|||
}
|
||||
|
||||
std::vector<SegOffset>
|
||||
SegmentGrowingImpl::search_ids(const BitsetView& bitset, Timestamp timestamp) const {
|
||||
SegmentGrowingImpl::search_ids(const BitsetView& bitset,
|
||||
Timestamp timestamp) const {
|
||||
std::vector<SegOffset> res_offsets;
|
||||
|
||||
for (int i = 0; i < bitset.size(); ++i) {
|
||||
|
@ -341,7 +398,8 @@ SegmentGrowingImpl::search_ids(const BitsetView& bitset, Timestamp timestamp) co
|
|||
}
|
||||
|
||||
std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
|
||||
SegmentGrowingImpl::search_ids(const IdArray& id_array, Timestamp timestamp) const {
|
||||
SegmentGrowingImpl::search_ids(const IdArray& id_array,
|
||||
Timestamp timestamp) const {
|
||||
AssertInfo(id_array.has_int_id(), "Id array doesn't have int_id element");
|
||||
auto field_id = schema_->get_primary_field_id().value_or(FieldId(-1));
|
||||
AssertInfo(field_id.get() != -1, "Primary key is -1");
|
||||
|
@ -358,11 +416,13 @@ SegmentGrowingImpl::search_ids(const IdArray& id_array, Timestamp timestamp) con
|
|||
for (auto offset : segOffsets) {
|
||||
switch (data_type) {
|
||||
case DataType::INT64: {
|
||||
res_id_arr->mutable_int_id()->add_data(std::get<int64_t>(pk));
|
||||
res_id_arr->mutable_int_id()->add_data(
|
||||
std::get<int64_t>(pk));
|
||||
break;
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
res_id_arr->mutable_str_id()->add_data(std::get<std::string>(pk));
|
||||
res_id_arr->mutable_str_id()->add_data(
|
||||
std::get<std::string>(pk));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
|
@ -384,13 +444,17 @@ int64_t
|
|||
SegmentGrowingImpl::get_active_count(Timestamp ts) const {
|
||||
auto row_count = this->get_row_count();
|
||||
auto& ts_vec = this->get_insert_record().timestamps_;
|
||||
auto iter = std::upper_bound(boost::make_counting_iterator((int64_t)0), boost::make_counting_iterator(row_count),
|
||||
ts, [&](Timestamp ts, int64_t index) { return ts < ts_vec[index]; });
|
||||
auto iter = std::upper_bound(
|
||||
boost::make_counting_iterator((int64_t)0),
|
||||
boost::make_counting_iterator(row_count),
|
||||
ts,
|
||||
[&](Timestamp ts, int64_t index) { return ts < ts_vec[index]; });
|
||||
return *iter;
|
||||
}
|
||||
|
||||
void
|
||||
SegmentGrowingImpl::mask_with_timestamps(BitsetType& bitset_chunk, Timestamp timestamp) const {
|
||||
SegmentGrowingImpl::mask_with_timestamps(BitsetType& bitset_chunk,
|
||||
Timestamp timestamp) const {
|
||||
// DO NOTHING
|
||||
}
|
||||
|
||||
|
|
|
@ -53,7 +53,10 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
|||
|
||||
// TODO: add id into delete log, possibly bitmap
|
||||
Status
|
||||
Delete(int64_t reserverd_offset, int64_t size, const IdArray* pks, const Timestamp* timestamps) override;
|
||||
Delete(int64_t reserved_offset,
|
||||
int64_t size,
|
||||
const IdArray* pks,
|
||||
const Timestamp* timestamps) override;
|
||||
|
||||
int64_t
|
||||
GetMemoryUsageInBytes() const override;
|
||||
|
@ -111,7 +114,8 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
|||
// deprecated
|
||||
const index::IndexBase*
|
||||
chunk_index_impl(FieldId field_id, int64_t chunk_id) const final {
|
||||
return indexing_record_.get_field_indexing(field_id).get_chunk_indexing(chunk_id);
|
||||
return indexing_record_.get_field_indexing(field_id).get_chunk_indexing(
|
||||
chunk_id);
|
||||
}
|
||||
|
||||
int64_t
|
||||
|
@ -142,7 +146,10 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
|||
// for scalar vectors
|
||||
template <typename T>
|
||||
void
|
||||
bulk_subscript_impl(const VectorBase& vec_raw, const int64_t* seg_offsets, int64_t count, void* output_raw) const;
|
||||
bulk_subscript_impl(const VectorBase& vec_raw,
|
||||
const int64_t* seg_offsets,
|
||||
int64_t count,
|
||||
void* output_raw) const;
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
|
@ -153,16 +160,25 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
|||
void* output_raw) const;
|
||||
|
||||
void
|
||||
bulk_subscript(SystemFieldType system_type, const int64_t* seg_offsets, int64_t count, void* output) const override;
|
||||
bulk_subscript(SystemFieldType system_type,
|
||||
const int64_t* seg_offsets,
|
||||
int64_t count,
|
||||
void* output) const override;
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
bulk_subscript(FieldId field_id, const int64_t* seg_offsets, int64_t count) const override;
|
||||
bulk_subscript(FieldId field_id,
|
||||
const int64_t* seg_offsets,
|
||||
int64_t count) const override;
|
||||
|
||||
public:
|
||||
friend std::unique_ptr<SegmentGrowing>
|
||||
CreateGrowingSegment(SchemaPtr schema, const SegcoreConfig& segcore_config, int64_t segment_id);
|
||||
CreateGrowingSegment(SchemaPtr schema,
|
||||
const SegcoreConfig& segcore_config,
|
||||
int64_t segment_id);
|
||||
|
||||
explicit SegmentGrowingImpl(SchemaPtr schema, const SegcoreConfig& segcore_config, int64_t segment_id)
|
||||
explicit SegmentGrowingImpl(SchemaPtr schema,
|
||||
const SegcoreConfig& segcore_config,
|
||||
int64_t segment_id)
|
||||
: segcore_config_(segcore_config),
|
||||
schema_(std::move(schema)),
|
||||
insert_record_(*schema_, segcore_config.get_chunk_rows()),
|
||||
|
@ -171,7 +187,8 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
|||
}
|
||||
|
||||
void
|
||||
mask_with_timestamps(BitsetType& bitset_chunk, Timestamp timestamp) const override;
|
||||
mask_with_timestamps(BitsetType& bitset_chunk,
|
||||
Timestamp timestamp) const override;
|
||||
|
||||
void
|
||||
vector_search(SearchInfo& search_info,
|
||||
|
@ -183,7 +200,9 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
|||
|
||||
public:
|
||||
void
|
||||
mask_with_delete(BitsetType& bitset, int64_t ins_barrier, Timestamp timestamp) const override;
|
||||
mask_with_delete(BitsetType& bitset,
|
||||
int64_t ins_barrier,
|
||||
Timestamp timestamp) const override;
|
||||
|
||||
std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
|
||||
search_ids(const IdArray& id_array, Timestamp timestamp) const override;
|
||||
|
@ -237,9 +256,10 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
|||
};
|
||||
|
||||
inline SegmentGrowingPtr
|
||||
CreateGrowingSegment(SchemaPtr schema,
|
||||
int64_t segment_id = -1,
|
||||
const SegcoreConfig& conf = SegcoreConfig::default_config()) {
|
||||
CreateGrowingSegment(
|
||||
SchemaPtr schema,
|
||||
int64_t segment_id = -1,
|
||||
const SegcoreConfig& conf = SegcoreConfig::default_config()) {
|
||||
return std::make_unique<SegmentGrowingImpl>(schema, conf, segment_id);
|
||||
}
|
||||
|
||||
|
|
|
@ -21,43 +21,51 @@
|
|||
namespace milvus::segcore {
|
||||
|
||||
void
|
||||
SegmentInternalInterface::FillPrimaryKeys(const query::Plan* plan, SearchResult& results) const {
|
||||
SegmentInternalInterface::FillPrimaryKeys(const query::Plan* plan,
|
||||
SearchResult& results) const {
|
||||
std::shared_lock lck(mutex_);
|
||||
AssertInfo(plan, "empty plan");
|
||||
auto size = results.distances_.size();
|
||||
AssertInfo(results.seg_offsets_.size() == size, "Size of result distances is not equal to size of ids");
|
||||
AssertInfo(results.seg_offsets_.size() == size,
|
||||
"Size of result distances is not equal to size of ids");
|
||||
Assert(results.primary_keys_.size() == 0);
|
||||
results.primary_keys_.resize(size);
|
||||
|
||||
auto pk_field_id_opt = get_schema().get_primary_field_id();
|
||||
AssertInfo(pk_field_id_opt.has_value(), "Cannot get primary key offset from schema");
|
||||
AssertInfo(pk_field_id_opt.has_value(),
|
||||
"Cannot get primary key offset from schema");
|
||||
auto pk_field_id = pk_field_id_opt.value();
|
||||
AssertInfo(IsPrimaryKeyDataType(get_schema()[pk_field_id].get_data_type()),
|
||||
"Primary key field is not INT64 or VARCHAR type");
|
||||
auto field_data = bulk_subscript(pk_field_id, results.seg_offsets_.data(), size);
|
||||
auto field_data =
|
||||
bulk_subscript(pk_field_id, results.seg_offsets_.data(), size);
|
||||
results.pk_type_ = DataType(field_data->type());
|
||||
|
||||
ParsePksFromFieldData(results.primary_keys_, *field_data.get());
|
||||
}
|
||||
|
||||
void
|
||||
SegmentInternalInterface::FillTargetEntry(const query::Plan* plan, SearchResult& results) const {
|
||||
SegmentInternalInterface::FillTargetEntry(const query::Plan* plan,
|
||||
SearchResult& results) const {
|
||||
std::shared_lock lck(mutex_);
|
||||
AssertInfo(plan, "empty plan");
|
||||
auto size = results.distances_.size();
|
||||
AssertInfo(results.seg_offsets_.size() == size, "Size of result distances is not equal to size of ids");
|
||||
AssertInfo(results.seg_offsets_.size() == size,
|
||||
"Size of result distances is not equal to size of ids");
|
||||
|
||||
// fill other entries except primary key by result_offset
|
||||
for (auto field_id : plan->target_entries_) {
|
||||
auto field_data = bulk_subscript(field_id, results.seg_offsets_.data(), size);
|
||||
auto field_data =
|
||||
bulk_subscript(field_id, results.seg_offsets_.data(), size);
|
||||
results.output_fields_data_[field_id] = std::move(field_data);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<SearchResult>
|
||||
SegmentInternalInterface::Search(const query::Plan* plan,
|
||||
const query::PlaceholderGroup* placeholder_group,
|
||||
Timestamp timestamp) const {
|
||||
SegmentInternalInterface::Search(
|
||||
const query::Plan* plan,
|
||||
const query::PlaceholderGroup* placeholder_group,
|
||||
Timestamp timestamp) const {
|
||||
std::shared_lock lck(mutex_);
|
||||
check_search(plan);
|
||||
query::ExecPlanNodeVisitor visitor(*this, timestamp, placeholder_group);
|
||||
|
@ -68,24 +76,30 @@ SegmentInternalInterface::Search(const query::Plan* plan,
|
|||
}
|
||||
|
||||
std::unique_ptr<proto::segcore::RetrieveResults>
|
||||
SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan, Timestamp timestamp) const {
|
||||
SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan,
|
||||
Timestamp timestamp) const {
|
||||
std::shared_lock lck(mutex_);
|
||||
auto results = std::make_unique<proto::segcore::RetrieveResults>();
|
||||
query::ExecPlanNodeVisitor visitor(*this, timestamp);
|
||||
auto retrieve_results = visitor.get_retrieve_result(*plan->plan_node_);
|
||||
retrieve_results.segment_ = (void*)this;
|
||||
results->mutable_offset()->Add(retrieve_results.result_offsets_.begin(), retrieve_results.result_offsets_.end());
|
||||
results->mutable_offset()->Add(retrieve_results.result_offsets_.begin(),
|
||||
retrieve_results.result_offsets_.end());
|
||||
|
||||
auto fields_data = results->mutable_fields_data();
|
||||
auto ids = results->mutable_ids();
|
||||
auto pk_field_id = plan->schema_.get_primary_field_id();
|
||||
for (auto field_id : plan->field_ids_) {
|
||||
if (SystemProperty::Instance().IsSystem(field_id)) {
|
||||
auto system_type = SystemProperty::Instance().GetSystemFieldType(field_id);
|
||||
auto system_type =
|
||||
SystemProperty::Instance().GetSystemFieldType(field_id);
|
||||
|
||||
auto size = retrieve_results.result_offsets_.size();
|
||||
FixedVector<int64_t> output(size);
|
||||
bulk_subscript(system_type, retrieve_results.result_offsets_.data(), size, output.data());
|
||||
bulk_subscript(system_type,
|
||||
retrieve_results.result_offsets_.data(),
|
||||
size,
|
||||
output.data());
|
||||
|
||||
auto data_array = std::make_unique<DataArray>();
|
||||
data_array->set_field_id(field_id.get());
|
||||
|
@ -102,8 +116,9 @@ SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan, Timestamp ti
|
|||
|
||||
auto& field_meta = plan->schema_[field_id];
|
||||
|
||||
auto col =
|
||||
bulk_subscript(field_id, retrieve_results.result_offsets_.data(), retrieve_results.result_offsets_.size());
|
||||
auto col = bulk_subscript(field_id,
|
||||
retrieve_results.result_offsets_.data(),
|
||||
retrieve_results.result_offsets_.size());
|
||||
auto col_data = col.release();
|
||||
fields_data->AddAllocated(col_data);
|
||||
if (pk_field_id.has_value() && pk_field_id.value() == field_id) {
|
||||
|
@ -111,7 +126,8 @@ SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan, Timestamp ti
|
|||
case DataType::INT64: {
|
||||
auto int_ids = ids->mutable_int_id();
|
||||
auto src_data = col_data->scalars().long_data();
|
||||
int_ids->mutable_data()->Add(src_data.data().begin(), src_data.data().end());
|
||||
int_ids->mutable_data()->Add(src_data.data().begin(),
|
||||
src_data.data().end());
|
||||
break;
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
|
|
|
@ -48,7 +48,9 @@ class SegmentInterface {
|
|||
FillTargetEntry(const query::Plan* plan, SearchResult& results) const = 0;
|
||||
|
||||
virtual std::unique_ptr<SearchResult>
|
||||
Search(const query::Plan* Plan, const query::PlaceholderGroup* placeholder_group, Timestamp timestamp) const = 0;
|
||||
Search(const query::Plan* Plan,
|
||||
const query::PlaceholderGroup* placeholder_group,
|
||||
Timestamp timestamp) const = 0;
|
||||
|
||||
virtual std::unique_ptr<proto::segcore::RetrieveResults>
|
||||
Retrieve(const query::RetrievePlan* Plan, Timestamp timestamp) const = 0;
|
||||
|
@ -73,7 +75,10 @@ class SegmentInterface {
|
|||
PreDelete(int64_t size) = 0;
|
||||
|
||||
virtual Status
|
||||
Delete(int64_t reserved_offset, int64_t size, const IdArray* pks, const Timestamp* timestamps) = 0;
|
||||
Delete(int64_t reserved_offset,
|
||||
int64_t size,
|
||||
const IdArray* pks,
|
||||
const Timestamp* timestamps) = 0;
|
||||
|
||||
virtual void
|
||||
LoadDeletedRecord(const LoadDeletedRecordInfo& info) = 0;
|
||||
|
@ -112,13 +117,16 @@ class SegmentInternalInterface : public SegmentInterface {
|
|||
Timestamp timestamp) const override;
|
||||
|
||||
void
|
||||
FillPrimaryKeys(const query::Plan* plan, SearchResult& results) const override;
|
||||
FillPrimaryKeys(const query::Plan* plan,
|
||||
SearchResult& results) const override;
|
||||
|
||||
void
|
||||
FillTargetEntry(const query::Plan* plan, SearchResult& results) const override;
|
||||
FillTargetEntry(const query::Plan* plan,
|
||||
SearchResult& results) const override;
|
||||
|
||||
std::unique_ptr<proto::segcore::RetrieveResults>
|
||||
Retrieve(const query::RetrievePlan* plan, Timestamp timestamp) const override;
|
||||
Retrieve(const query::RetrievePlan* plan,
|
||||
Timestamp timestamp) const override;
|
||||
|
||||
virtual bool
|
||||
HasIndex(FieldId field_id) const = 0;
|
||||
|
@ -142,7 +150,9 @@ class SegmentInternalInterface : public SegmentInterface {
|
|||
SearchResult& output) const = 0;
|
||||
|
||||
virtual void
|
||||
mask_with_delete(BitsetType& bitset, int64_t ins_barrier, Timestamp timestamp) const = 0;
|
||||
mask_with_delete(BitsetType& bitset,
|
||||
int64_t ins_barrier,
|
||||
Timestamp timestamp) const = 0;
|
||||
|
||||
// count of chunk that has index available
|
||||
virtual int64_t
|
||||
|
@ -153,7 +163,8 @@ class SegmentInternalInterface : public SegmentInterface {
|
|||
num_chunk_data(FieldId field_id) const = 0;
|
||||
|
||||
virtual void
|
||||
mask_with_timestamps(BitsetType& bitset_chunk, Timestamp timestamp) const = 0;
|
||||
mask_with_timestamps(BitsetType& bitset_chunk,
|
||||
Timestamp timestamp) const = 0;
|
||||
|
||||
// count of chunks
|
||||
virtual int64_t
|
||||
|
@ -186,11 +197,16 @@ class SegmentInternalInterface : public SegmentInterface {
|
|||
|
||||
// calculate output[i] = Vec[seg_offsets[i]}, where Vec binds to system_type
|
||||
virtual void
|
||||
bulk_subscript(SystemFieldType system_type, const int64_t* seg_offsets, int64_t count, void* output) const = 0;
|
||||
bulk_subscript(SystemFieldType system_type,
|
||||
const int64_t* seg_offsets,
|
||||
int64_t count,
|
||||
void* output) const = 0;
|
||||
|
||||
// calculate output[i] = Vec[seg_offsets[i]}, where Vec binds to field_offset
|
||||
virtual std::unique_ptr<DataArray>
|
||||
bulk_subscript(FieldId field_id, const int64_t* seg_offsets, int64_t count) const = 0;
|
||||
bulk_subscript(FieldId field_id,
|
||||
const int64_t* seg_offsets,
|
||||
int64_t count) const = 0;
|
||||
|
||||
virtual void
|
||||
check_search(const query::Plan* plan) const = 0;
|
||||
|
|
|
@ -66,7 +66,8 @@ SegmentSealedImpl::LoadVecIndex(const LoadIndexInfo& info) {
|
|||
auto field_id = FieldId(info.field_id);
|
||||
auto& field_meta = schema_->operator[](field_id);
|
||||
|
||||
AssertInfo(info.index_params.count("metric_type"), "Can't get metric_type in index_params");
|
||||
AssertInfo(info.index_params.count("metric_type"),
|
||||
"Can't get metric_type in index_params");
|
||||
auto metric_type = info.index_params.at("metric_type");
|
||||
auto row_count = info.index->Count();
|
||||
AssertInfo(row_count > 0, "Index count is 0");
|
||||
|
@ -74,17 +75,24 @@ SegmentSealedImpl::LoadVecIndex(const LoadIndexInfo& info) {
|
|||
std::unique_lock lck(mutex_);
|
||||
// Don't allow vector raw data and index exist at the same time
|
||||
AssertInfo(!get_bit(field_data_ready_bitset_, field_id),
|
||||
"vector index can't be loaded when raw data exists at field " + std::to_string(field_id.get()));
|
||||
AssertInfo(!get_bit(index_ready_bitset_, field_id),
|
||||
"vector index has been exist at " + std::to_string(field_id.get()));
|
||||
"vector index can't be loaded when raw data exists at field " +
|
||||
std::to_string(field_id.get()));
|
||||
AssertInfo(
|
||||
!get_bit(index_ready_bitset_, field_id),
|
||||
"vector index has been exist at " + std::to_string(field_id.get()));
|
||||
if (row_count_opt_.has_value()) {
|
||||
AssertInfo(row_count_opt_.value() == row_count,
|
||||
"field (" + std::to_string(field_id.get()) + ") data has different row count (" +
|
||||
std::to_string(row_count) + ") than other column's row count (" +
|
||||
"field (" + std::to_string(field_id.get()) +
|
||||
") data has different row count (" +
|
||||
std::to_string(row_count) +
|
||||
") than other column's row count (" +
|
||||
std::to_string(row_count_opt_.value()) + ")");
|
||||
}
|
||||
AssertInfo(!vector_indexings_.is_ready(field_id), "vec index is not ready");
|
||||
vector_indexings_.append_field_indexing(field_id, metric_type, std::move(const_cast<LoadIndexInfo&>(info).index));
|
||||
vector_indexings_.append_field_indexing(
|
||||
field_id,
|
||||
metric_type,
|
||||
std::move(const_cast<LoadIndexInfo&>(info).index));
|
||||
|
||||
set_bit(index_ready_bitset_, field_id, true);
|
||||
update_row_count(row_count);
|
||||
|
@ -103,24 +111,30 @@ SegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) {
|
|||
std::unique_lock lck(mutex_);
|
||||
// Don't allow scalar raw data and index exist at the same time
|
||||
AssertInfo(!get_bit(field_data_ready_bitset_, field_id),
|
||||
"scalar index can't be loaded when raw data exists at field " + std::to_string(field_id.get()));
|
||||
AssertInfo(!get_bit(index_ready_bitset_, field_id),
|
||||
"scalar index has been exist at " + std::to_string(field_id.get()));
|
||||
"scalar index can't be loaded when raw data exists at field " +
|
||||
std::to_string(field_id.get()));
|
||||
AssertInfo(
|
||||
!get_bit(index_ready_bitset_, field_id),
|
||||
"scalar index has been exist at " + std::to_string(field_id.get()));
|
||||
if (row_count_opt_.has_value()) {
|
||||
AssertInfo(row_count_opt_.value() == row_count,
|
||||
"field (" + std::to_string(field_id.get()) + ") data has different row count (" +
|
||||
std::to_string(row_count) + ") than other column's row count (" +
|
||||
"field (" + std::to_string(field_id.get()) +
|
||||
") data has different row count (" +
|
||||
std::to_string(row_count) +
|
||||
") than other column's row count (" +
|
||||
std::to_string(row_count_opt_.value()) + ")");
|
||||
}
|
||||
|
||||
scalar_indexings_[field_id] = std::move(const_cast<LoadIndexInfo&>(info).index);
|
||||
scalar_indexings_[field_id] =
|
||||
std::move(const_cast<LoadIndexInfo&>(info).index);
|
||||
// reverse pk from scalar index and set pks to offset
|
||||
if (schema_->get_primary_field_id() == field_id) {
|
||||
AssertInfo(field_id.get() != -1, "Primary key is -1");
|
||||
AssertInfo(insert_record_.empty_pks(), "already exists");
|
||||
switch (field_meta.get_data_type()) {
|
||||
case DataType::INT64: {
|
||||
auto int64_index = dynamic_cast<index::ScalarIndex<int64_t>*>(scalar_indexings_[field_id].get());
|
||||
auto int64_index = dynamic_cast<index::ScalarIndex<int64_t>*>(
|
||||
scalar_indexings_[field_id].get());
|
||||
for (int i = 0; i < row_count; ++i) {
|
||||
insert_record_.insert_pk(int64_index->Reverse_Lookup(i), i);
|
||||
}
|
||||
|
@ -128,9 +142,12 @@ SegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) {
|
|||
break;
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
auto string_index = dynamic_cast<index::ScalarIndex<std::string>*>(scalar_indexings_[field_id].get());
|
||||
auto string_index =
|
||||
dynamic_cast<index::ScalarIndex<std::string>*>(
|
||||
scalar_indexings_[field_id].get());
|
||||
for (int i = 0; i < row_count; ++i) {
|
||||
insert_record_.insert_pk(string_index->Reverse_Lookup(i), i);
|
||||
insert_record_.insert_pk(string_index->Reverse_Lookup(i),
|
||||
i);
|
||||
}
|
||||
insert_record_.seal_pks();
|
||||
break;
|
||||
|
@ -155,15 +172,21 @@ SegmentSealedImpl::LoadFieldData(const LoadFieldDataInfo& info) {
|
|||
AssertInfo(info.field_data != nullptr, "Field info blob is null");
|
||||
auto size = info.row_count;
|
||||
if (row_count_opt_.has_value()) {
|
||||
AssertInfo(row_count_opt_.value() == size,
|
||||
fmt::format("field {} has different row count {} to other column's {}", field_id.get(), size,
|
||||
row_count_opt_.value()));
|
||||
AssertInfo(
|
||||
row_count_opt_.value() == size,
|
||||
fmt::format(
|
||||
"field {} has different row count {} to other column's {}",
|
||||
field_id.get(),
|
||||
size,
|
||||
row_count_opt_.value()));
|
||||
}
|
||||
|
||||
if (SystemProperty::Instance().IsSystem(field_id)) {
|
||||
auto system_field_type = SystemProperty::Instance().GetSystemFieldType(field_id);
|
||||
auto system_field_type =
|
||||
SystemProperty::Instance().GetSystemFieldType(field_id);
|
||||
if (system_field_type == SystemFieldType::Timestamp) {
|
||||
auto timestamps = reinterpret_cast<const Timestamp*>(info.field_data->scalars().long_data().data().data());
|
||||
auto timestamps = reinterpret_cast<const Timestamp*>(
|
||||
info.field_data->scalars().long_data().data().data());
|
||||
|
||||
TimestampIndex index;
|
||||
auto min_slice_length = size < 4096 ? 1 : 4096;
|
||||
|
@ -176,15 +199,19 @@ SegmentSealedImpl::LoadFieldData(const LoadFieldDataInfo& info) {
|
|||
AssertInfo(insert_record_.timestamps_.empty(), "already exists");
|
||||
insert_record_.timestamps_.fill_chunk_data(timestamps, size);
|
||||
insert_record_.timestamp_index_ = std::move(index);
|
||||
AssertInfo(insert_record_.timestamps_.num_chunk() == 1, "num chunk not equal to 1 for sealed segment");
|
||||
AssertInfo(insert_record_.timestamps_.num_chunk() == 1,
|
||||
"num chunk not equal to 1 for sealed segment");
|
||||
} else {
|
||||
AssertInfo(system_field_type == SystemFieldType::RowId, "System field type of id column is not RowId");
|
||||
auto row_ids = reinterpret_cast<const idx_t*>(info.field_data->scalars().long_data().data().data());
|
||||
AssertInfo(system_field_type == SystemFieldType::RowId,
|
||||
"System field type of id column is not RowId");
|
||||
auto row_ids = reinterpret_cast<const idx_t*>(
|
||||
info.field_data->scalars().long_data().data().data());
|
||||
// write data under lock
|
||||
std::unique_lock lck(mutex_);
|
||||
AssertInfo(insert_record_.row_ids_.empty(), "already exists");
|
||||
insert_record_.row_ids_.fill_chunk_data(row_ids, size);
|
||||
AssertInfo(insert_record_.row_ids_.num_chunk() == 1, "num chunk not equal to 1 for sealed segment");
|
||||
AssertInfo(insert_record_.row_ids_.num_chunk() == 1,
|
||||
"num chunk not equal to 1 for sealed segment");
|
||||
}
|
||||
++system_ready_count_;
|
||||
} else {
|
||||
|
@ -198,7 +225,8 @@ SegmentSealedImpl::LoadFieldData(const LoadFieldDataInfo& info) {
|
|||
std::unique_lock lck(mutex_);
|
||||
|
||||
// Don't allow raw data and index exist at the same time
|
||||
AssertInfo(!get_bit(index_ready_bitset_, field_id), "field data can't be loaded when indexing exists");
|
||||
AssertInfo(!get_bit(index_ready_bitset_, field_id),
|
||||
"field data can't be loaded when indexing exists");
|
||||
|
||||
void* field_data = nullptr;
|
||||
size_t size = 0;
|
||||
|
@ -250,7 +278,8 @@ SegmentSealedImpl::LoadDeletedRecord(const LoadDeletedRecordInfo& info) {
|
|||
auto reserved_begin = deleted_record_.reserved.fetch_add(size);
|
||||
deleted_record_.pks_.set_data_raw(reserved_begin, pks.data(), size);
|
||||
deleted_record_.timestamps_.set_data_raw(reserved_begin, timestamps, size);
|
||||
deleted_record_.ack_responder_.AddSegment(reserved_begin, reserved_begin + size);
|
||||
deleted_record_.ack_responder_.AddSegment(reserved_begin,
|
||||
reserved_begin + size);
|
||||
}
|
||||
|
||||
// internal API: support scalar index only
|
||||
|
@ -290,19 +319,24 @@ SegmentSealedImpl::chunk_data_impl(FieldId field_id, int64_t chunk_id) const {
|
|||
auto field_data = it->second;
|
||||
return SpanBase(field_data, get_row_count(), element_sizeof);
|
||||
}
|
||||
if (auto it = variable_fields_.find(field_id); it != variable_fields_.end()) {
|
||||
if (auto it = variable_fields_.find(field_id);
|
||||
it != variable_fields_.end()) {
|
||||
auto& field = it->second;
|
||||
return SpanBase(field.views().data(), field.views().size(), sizeof(std::string_view));
|
||||
return SpanBase(field.views().data(),
|
||||
field.views().size(),
|
||||
sizeof(std::string_view));
|
||||
}
|
||||
auto field_data = insert_record_.get_field_data_base(field_id);
|
||||
AssertInfo(field_data->num_chunk() == 1, "num chunk not equal to 1 for sealed segment");
|
||||
AssertInfo(field_data->num_chunk() == 1,
|
||||
"num chunk not equal to 1 for sealed segment");
|
||||
return field_data->get_span_base(0);
|
||||
}
|
||||
|
||||
const index::IndexBase*
|
||||
SegmentSealedImpl::chunk_index_impl(FieldId field_id, int64_t chunk_id) const {
|
||||
AssertInfo(scalar_indexings_.find(field_id) != scalar_indexings_.end(),
|
||||
"Cannot find scalar_indexing with field_id: " + std::to_string(field_id.get()));
|
||||
"Cannot find scalar_indexing with field_id: " +
|
||||
std::to_string(field_id.get()));
|
||||
auto ptr = scalar_indexings_.at(field_id).get();
|
||||
return ptr;
|
||||
}
|
||||
|
@ -333,17 +367,21 @@ SegmentSealedImpl::get_schema() const {
|
|||
}
|
||||
|
||||
void
|
||||
SegmentSealedImpl::mask_with_delete(BitsetType& bitset, int64_t ins_barrier, Timestamp timestamp) const {
|
||||
SegmentSealedImpl::mask_with_delete(BitsetType& bitset,
|
||||
int64_t ins_barrier,
|
||||
Timestamp timestamp) const {
|
||||
auto del_barrier = get_barrier(get_deleted_record(), timestamp);
|
||||
if (del_barrier == 0) {
|
||||
return;
|
||||
}
|
||||
auto bitmap_holder = get_deleted_bitmap(del_barrier, ins_barrier, deleted_record_, insert_record_, timestamp);
|
||||
auto bitmap_holder = get_deleted_bitmap(
|
||||
del_barrier, ins_barrier, deleted_record_, insert_record_, timestamp);
|
||||
if (!bitmap_holder || !bitmap_holder->bitmap_ptr) {
|
||||
return;
|
||||
}
|
||||
auto& delete_bitset = *bitmap_holder->bitmap_ptr;
|
||||
AssertInfo(delete_bitset.size() == bitset.size(), "Deleted bitmap size not equal to filtered bitmap size");
|
||||
AssertInfo(delete_bitset.size() == bitset.size(),
|
||||
"Deleted bitmap size not equal to filtered bitmap size");
|
||||
bitset |= delete_bitset;
|
||||
}
|
||||
|
||||
|
@ -358,25 +396,42 @@ SegmentSealedImpl::vector_search(SearchInfo& search_info,
|
|||
auto field_id = search_info.field_id_;
|
||||
auto& field_meta = schema_->operator[](field_id);
|
||||
|
||||
AssertInfo(field_meta.is_vector(), "The meta type of vector field is not vector type");
|
||||
AssertInfo(field_meta.is_vector(),
|
||||
"The meta type of vector field is not vector type");
|
||||
if (get_bit(index_ready_bitset_, field_id)) {
|
||||
AssertInfo(vector_indexings_.is_ready(field_id),
|
||||
"vector indexes isn't ready for field " + std::to_string(field_id.get()));
|
||||
query::SearchOnSealedIndex(*schema_, vector_indexings_, search_info, query_data, query_count, bitset, output);
|
||||
"vector indexes isn't ready for field " +
|
||||
std::to_string(field_id.get()));
|
||||
query::SearchOnSealedIndex(*schema_,
|
||||
vector_indexings_,
|
||||
search_info,
|
||||
query_data,
|
||||
query_count,
|
||||
bitset,
|
||||
output);
|
||||
} else {
|
||||
AssertInfo(get_bit(field_data_ready_bitset_, field_id),
|
||||
"Field Data is not loaded: " + std::to_string(field_id.get()));
|
||||
AssertInfo(
|
||||
get_bit(field_data_ready_bitset_, field_id),
|
||||
"Field Data is not loaded: " + std::to_string(field_id.get()));
|
||||
AssertInfo(row_count_opt_.has_value(), "Can't get row count value");
|
||||
auto row_count = row_count_opt_.value();
|
||||
auto vec_data = fixed_fields_.at(field_id);
|
||||
query::SearchOnSealed(*schema_, vec_data, search_info, query_data, query_count, row_count, bitset, output);
|
||||
query::SearchOnSealed(*schema_,
|
||||
vec_data,
|
||||
search_info,
|
||||
query_data,
|
||||
query_count,
|
||||
row_count,
|
||||
bitset,
|
||||
output);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
SegmentSealedImpl::DropFieldData(const FieldId field_id) {
|
||||
if (SystemProperty::Instance().IsSystem(field_id)) {
|
||||
auto system_field_type = SystemProperty::Instance().GetSystemFieldType(field_id);
|
||||
auto system_field_type =
|
||||
SystemProperty::Instance().GetSystemFieldType(field_id);
|
||||
|
||||
std::unique_lock lck(mutex_);
|
||||
--system_ready_count_;
|
||||
|
@ -398,10 +453,12 @@ SegmentSealedImpl::DropFieldData(const FieldId field_id) {
|
|||
void
|
||||
SegmentSealedImpl::DropIndex(const FieldId field_id) {
|
||||
AssertInfo(!SystemProperty::Instance().IsSystem(field_id),
|
||||
"Field id:" + std::to_string(field_id.get()) + " isn't one of system type when drop index");
|
||||
"Field id:" + std::to_string(field_id.get()) +
|
||||
" isn't one of system type when drop index");
|
||||
auto& field_meta = schema_->operator[](field_id);
|
||||
AssertInfo(field_meta.is_vector(),
|
||||
"Field meta of offset:" + std::to_string(field_id.get()) + " is not vector type");
|
||||
"Field meta of offset:" + std::to_string(field_id.get()) +
|
||||
" is not vector type");
|
||||
|
||||
std::unique_lock lck(mutex_);
|
||||
vector_indexings_.drop_field_indexing(field_id);
|
||||
|
@ -411,23 +468,29 @@ SegmentSealedImpl::DropIndex(const FieldId field_id) {
|
|||
void
|
||||
SegmentSealedImpl::check_search(const query::Plan* plan) const {
|
||||
AssertInfo(plan, "Search plan is null");
|
||||
AssertInfo(plan->extra_info_opt_.has_value(), "Extra info of search plan doesn't have value");
|
||||
AssertInfo(plan->extra_info_opt_.has_value(),
|
||||
"Extra info of search plan doesn't have value");
|
||||
|
||||
if (!is_system_field_ready()) {
|
||||
PanicInfo("failed to load row ID or timestamp, potential missing bin logs or empty segments. Segment ID = " +
|
||||
std::to_string(this->id_));
|
||||
PanicInfo(
|
||||
"failed to load row ID or timestamp, potential missing bin logs or "
|
||||
"empty segments. Segment ID = " +
|
||||
std::to_string(this->id_));
|
||||
}
|
||||
|
||||
auto& request_fields = plan->extra_info_opt_.value().involved_fields_;
|
||||
auto field_ready_bitset = field_data_ready_bitset_ | index_ready_bitset_;
|
||||
AssertInfo(request_fields.size() == field_ready_bitset.size(),
|
||||
"Request fields size not equal to field ready bitset size when check search");
|
||||
"Request fields size not equal to field ready bitset size when "
|
||||
"check search");
|
||||
auto absent_fields = request_fields - field_ready_bitset;
|
||||
|
||||
if (absent_fields.any()) {
|
||||
auto field_id = FieldId(absent_fields.find_first() + START_USER_FIELDID);
|
||||
auto field_id =
|
||||
FieldId(absent_fields.find_first() + START_USER_FIELDID);
|
||||
auto& field_meta = schema_->operator[](field_id);
|
||||
PanicInfo("User Field(" + field_meta.get_name().get() + ") is not loaded");
|
||||
PanicInfo("User Field(" + field_meta.get_name().get() +
|
||||
") is not loaded");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -445,18 +508,27 @@ SegmentSealedImpl::bulk_subscript(SystemFieldType system_type,
|
|||
const int64_t* seg_offsets,
|
||||
int64_t count,
|
||||
void* output) const {
|
||||
AssertInfo(is_system_field_ready(), "System field isn't ready when do bulk_insert");
|
||||
AssertInfo(is_system_field_ready(),
|
||||
"System field isn't ready when do bulk_insert");
|
||||
switch (system_type) {
|
||||
case SystemFieldType::Timestamp:
|
||||
AssertInfo(insert_record_.timestamps_.num_chunk() == 1,
|
||||
"num chunk of timestamp not equal to 1 for sealed segment");
|
||||
bulk_subscript_impl<Timestamp>(this->insert_record_.timestamps_.get_chunk_data(0), seg_offsets, count,
|
||||
output);
|
||||
AssertInfo(
|
||||
insert_record_.timestamps_.num_chunk() == 1,
|
||||
"num chunk of timestamp not equal to 1 for sealed segment");
|
||||
bulk_subscript_impl<Timestamp>(
|
||||
this->insert_record_.timestamps_.get_chunk_data(0),
|
||||
seg_offsets,
|
||||
count,
|
||||
output);
|
||||
break;
|
||||
case SystemFieldType::RowId:
|
||||
AssertInfo(insert_record_.row_ids_.num_chunk() == 1,
|
||||
"num chunk of rowID not equal to 1 for sealed segment");
|
||||
bulk_subscript_impl<int64_t>(this->insert_record_.row_ids_.get_chunk_data(0), seg_offsets, count, output);
|
||||
bulk_subscript_impl<int64_t>(
|
||||
this->insert_record_.row_ids_.get_chunk_data(0),
|
||||
seg_offsets,
|
||||
count,
|
||||
output);
|
||||
break;
|
||||
default:
|
||||
PanicInfo("unknown subscript fields");
|
||||
|
@ -465,7 +537,10 @@ SegmentSealedImpl::bulk_subscript(SystemFieldType system_type,
|
|||
|
||||
template <typename T>
|
||||
void
|
||||
SegmentSealedImpl::bulk_subscript_impl(const void* src_raw, const int64_t* seg_offsets, int64_t count, void* dst_raw) {
|
||||
SegmentSealedImpl::bulk_subscript_impl(const void* src_raw,
|
||||
const int64_t* seg_offsets,
|
||||
int64_t count,
|
||||
void* dst_raw) {
|
||||
static_assert(IsScalar<T>);
|
||||
auto src = reinterpret_cast<const T*>(src_raw);
|
||||
auto dst = reinterpret_cast<T*>(dst_raw);
|
||||
|
@ -495,14 +570,19 @@ SegmentSealedImpl::bulk_subscript_impl(const VariableField& field,
|
|||
|
||||
// for vector
|
||||
void
|
||||
SegmentSealedImpl::bulk_subscript_impl(
|
||||
int64_t element_sizeof, const void* src_raw, const int64_t* seg_offsets, int64_t count, void* dst_raw) {
|
||||
SegmentSealedImpl::bulk_subscript_impl(int64_t element_sizeof,
|
||||
const void* src_raw,
|
||||
const int64_t* seg_offsets,
|
||||
int64_t count,
|
||||
void* dst_raw) {
|
||||
auto src_vec = reinterpret_cast<const char*>(src_raw);
|
||||
auto dst_vec = reinterpret_cast<char*>(dst_raw);
|
||||
for (int64_t i = 0; i < count; ++i) {
|
||||
auto offset = seg_offsets[i];
|
||||
auto dst = dst_vec + i * element_sizeof;
|
||||
const char* src = (offset == INVALID_SEG_OFFSET ? nullptr : (src_vec + element_sizeof * offset));
|
||||
const char* src = (offset == INVALID_SEG_OFFSET
|
||||
? nullptr
|
||||
: (src_vec + element_sizeof * offset));
|
||||
if (!src) {
|
||||
continue;
|
||||
}
|
||||
|
@ -520,7 +600,9 @@ SegmentSealedImpl::fill_with_empty(FieldId field_id, int64_t count) const {
|
|||
}
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
SegmentSealedImpl::bulk_subscript(FieldId field_id, const int64_t* seg_offsets, int64_t count) const {
|
||||
SegmentSealedImpl::bulk_subscript(FieldId field_id,
|
||||
const int64_t* seg_offsets,
|
||||
int64_t count) const {
|
||||
auto& field_meta = schema_->operator[](field_id);
|
||||
// if count == 0, return empty data array
|
||||
if (count == 0) {
|
||||
|
@ -530,7 +612,8 @@ SegmentSealedImpl::bulk_subscript(FieldId field_id, const int64_t* seg_offsets,
|
|||
if (HasIndex(field_id)) {
|
||||
// if field has load scalar index, reverse raw data from index
|
||||
if (!datatype_is_vector(field_meta.get_data_type())) {
|
||||
AssertInfo(num_chunk() == 1, "num chunk not equal to 1 for sealed segment");
|
||||
AssertInfo(num_chunk() == 1,
|
||||
"num chunk not equal to 1 for sealed segment");
|
||||
auto index = chunk_index_impl(field_id, 0);
|
||||
return ReverseDataFromIndex(index, seg_offsets, count, field_meta);
|
||||
}
|
||||
|
@ -547,12 +630,18 @@ SegmentSealedImpl::bulk_subscript(FieldId field_id, const int64_t* seg_offsets,
|
|||
case DataType::VARCHAR:
|
||||
case DataType::STRING: {
|
||||
FixedVector<std::string> output(count);
|
||||
bulk_subscript_impl<std::string>(variable_fields_.at(field_id), seg_offsets, count, output.data());
|
||||
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
|
||||
bulk_subscript_impl<std::string>(variable_fields_.at(field_id),
|
||||
seg_offsets,
|
||||
count,
|
||||
output.data());
|
||||
return CreateScalarDataArrayFrom(
|
||||
output.data(), count, field_meta);
|
||||
}
|
||||
|
||||
default:
|
||||
PanicInfo(fmt::format("unsupported data type: {}", datatype_name(field_meta.get_data_type())));
|
||||
PanicInfo(
|
||||
fmt::format("unsupported data type: {}",
|
||||
datatype_name(field_meta.get_data_type())));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -560,44 +649,55 @@ SegmentSealedImpl::bulk_subscript(FieldId field_id, const int64_t* seg_offsets,
|
|||
switch (field_meta.get_data_type()) {
|
||||
case DataType::BOOL: {
|
||||
FixedVector<bool> output(count);
|
||||
bulk_subscript_impl<bool>(src_vec, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<bool>(
|
||||
src_vec, seg_offsets, count, output.data());
|
||||
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
case DataType::INT8: {
|
||||
FixedVector<int8_t> output(count);
|
||||
bulk_subscript_impl<int8_t>(src_vec, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<int8_t>(
|
||||
src_vec, seg_offsets, count, output.data());
|
||||
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
case DataType::INT16: {
|
||||
FixedVector<int16_t> output(count);
|
||||
bulk_subscript_impl<int16_t>(src_vec, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<int16_t>(
|
||||
src_vec, seg_offsets, count, output.data());
|
||||
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
case DataType::INT32: {
|
||||
FixedVector<int32_t> output(count);
|
||||
bulk_subscript_impl<int32_t>(src_vec, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<int32_t>(
|
||||
src_vec, seg_offsets, count, output.data());
|
||||
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
case DataType::INT64: {
|
||||
FixedVector<int64_t> output(count);
|
||||
bulk_subscript_impl<int64_t>(src_vec, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<int64_t>(
|
||||
src_vec, seg_offsets, count, output.data());
|
||||
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
case DataType::FLOAT: {
|
||||
FixedVector<float> output(count);
|
||||
bulk_subscript_impl<float>(src_vec, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<float>(
|
||||
src_vec, seg_offsets, count, output.data());
|
||||
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
case DataType::DOUBLE: {
|
||||
FixedVector<double> output(count);
|
||||
bulk_subscript_impl<double>(src_vec, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl<double>(
|
||||
src_vec, seg_offsets, count, output.data());
|
||||
return CreateScalarDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
|
||||
case DataType::VECTOR_FLOAT:
|
||||
case DataType::VECTOR_BINARY: {
|
||||
aligned_vector<char> output(field_meta.get_sizeof() * count);
|
||||
bulk_subscript_impl(field_meta.get_sizeof(), src_vec, seg_offsets, count, output.data());
|
||||
bulk_subscript_impl(field_meta.get_sizeof(),
|
||||
src_vec,
|
||||
seg_offsets,
|
||||
count,
|
||||
output.data());
|
||||
return CreateVectorDataArrayFrom(output.data(), count, field_meta);
|
||||
}
|
||||
|
||||
|
@ -624,7 +724,8 @@ SegmentSealedImpl::HasFieldData(FieldId field_id) const {
|
|||
}
|
||||
|
||||
std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
|
||||
SegmentSealedImpl::search_ids(const IdArray& id_array, Timestamp timestamp) const {
|
||||
SegmentSealedImpl::search_ids(const IdArray& id_array,
|
||||
Timestamp timestamp) const {
|
||||
AssertInfo(id_array.has_int_id(), "Id array doesn't have int_id element");
|
||||
auto field_id = schema_->get_primary_field_id().value_or(FieldId(-1));
|
||||
AssertInfo(field_id.get() != -1, "Primary key is -1");
|
||||
|
@ -641,11 +742,13 @@ SegmentSealedImpl::search_ids(const IdArray& id_array, Timestamp timestamp) cons
|
|||
for (auto offset : segOffsets) {
|
||||
switch (data_type) {
|
||||
case DataType::INT64: {
|
||||
res_id_arr->mutable_int_id()->add_data(std::get<int64_t>(pk));
|
||||
res_id_arr->mutable_int_id()->add_data(
|
||||
std::get<int64_t>(pk));
|
||||
break;
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
res_id_arr->mutable_str_id()->add_data(std::get<std::string>(pk));
|
||||
res_id_arr->mutable_str_id()->add_data(
|
||||
std::get<std::string>(pk));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
|
@ -659,7 +762,10 @@ SegmentSealedImpl::search_ids(const IdArray& id_array, Timestamp timestamp) cons
|
|||
}
|
||||
|
||||
Status
|
||||
SegmentSealedImpl::Delete(int64_t reserved_offset, int64_t size, const IdArray* ids, const Timestamp* timestamps_raw) {
|
||||
SegmentSealedImpl::Delete(int64_t reserved_offset,
|
||||
int64_t size,
|
||||
const IdArray* ids,
|
||||
const Timestamp* timestamps_raw) {
|
||||
auto field_id = schema_->get_primary_field_id().value_or(FieldId(-1));
|
||||
AssertInfo(field_id.get() != -1, "Primary key is -1");
|
||||
auto& field_meta = schema_->operator[](field_id);
|
||||
|
@ -680,14 +786,17 @@ SegmentSealedImpl::Delete(int64_t reserved_offset, int64_t size, const IdArray*
|
|||
sort_timestamps[i] = t;
|
||||
sort_pks[i] = pk;
|
||||
}
|
||||
deleted_record_.timestamps_.set_data_raw(reserved_offset, sort_timestamps.data(), size);
|
||||
deleted_record_.timestamps_.set_data_raw(
|
||||
reserved_offset, sort_timestamps.data(), size);
|
||||
deleted_record_.pks_.set_data_raw(reserved_offset, sort_pks.data(), size);
|
||||
deleted_record_.ack_responder_.AddSegment(reserved_offset, reserved_offset + size);
|
||||
deleted_record_.ack_responder_.AddSegment(reserved_offset,
|
||||
reserved_offset + size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<SegOffset>
|
||||
SegmentSealedImpl::search_ids(const BitsetType& bitset, Timestamp timestamp) const {
|
||||
SegmentSealedImpl::search_ids(const BitsetType& bitset,
|
||||
Timestamp timestamp) const {
|
||||
std::vector<SegOffset> dst_offset;
|
||||
for (int i = 0; i < bitset.size(); i++) {
|
||||
if (bitset[i]) {
|
||||
|
@ -701,7 +810,8 @@ SegmentSealedImpl::search_ids(const BitsetType& bitset, Timestamp timestamp) con
|
|||
}
|
||||
|
||||
std::vector<SegOffset>
|
||||
SegmentSealedImpl::search_ids(const BitsetView& bitset, Timestamp timestamp) const {
|
||||
SegmentSealedImpl::search_ids(const BitsetView& bitset,
|
||||
Timestamp timestamp) const {
|
||||
std::vector<SegOffset> dst_offset;
|
||||
for (int i = 0; i < bitset.size(); i++) {
|
||||
if (!bitset.test(i)) {
|
||||
|
@ -723,7 +833,8 @@ SegmentSealedImpl::debug() const {
|
|||
}
|
||||
|
||||
void
|
||||
SegmentSealedImpl::LoadSegmentMeta(const proto::segcore::LoadSegmentMeta& segment_meta) {
|
||||
SegmentSealedImpl::LoadSegmentMeta(
|
||||
const proto::segcore::LoadSegmentMeta& segment_meta) {
|
||||
std::unique_lock lck(mutex_);
|
||||
std::vector<int64_t> slice_lengths;
|
||||
for (auto& info : segment_meta.metas()) {
|
||||
|
@ -740,11 +851,14 @@ SegmentSealedImpl::get_active_count(Timestamp ts) const {
|
|||
}
|
||||
|
||||
void
|
||||
SegmentSealedImpl::mask_with_timestamps(BitsetType& bitset_chunk, Timestamp timestamp) const {
|
||||
SegmentSealedImpl::mask_with_timestamps(BitsetType& bitset_chunk,
|
||||
Timestamp timestamp) const {
|
||||
// TODO change the
|
||||
AssertInfo(insert_record_.timestamps_.num_chunk() == 1, "num chunk not equal to 1 for sealed segment");
|
||||
AssertInfo(insert_record_.timestamps_.num_chunk() == 1,
|
||||
"num chunk not equal to 1 for sealed segment");
|
||||
const auto& timestamps_data = insert_record_.timestamps_.get_chunk(0);
|
||||
AssertInfo(timestamps_data.size() == get_row_count(), "Timestamp size not equal to row count");
|
||||
AssertInfo(timestamps_data.size() == get_row_count(),
|
||||
"Timestamp size not equal to row count");
|
||||
auto range = insert_record_.timestamp_index_.get_active_range(timestamp);
|
||||
|
||||
// range == (size_, size_) and size_ is this->timestamps_.size().
|
||||
|
@ -760,7 +874,8 @@ SegmentSealedImpl::mask_with_timestamps(BitsetType& bitset_chunk, Timestamp time
|
|||
bitset_chunk.set();
|
||||
return;
|
||||
}
|
||||
auto mask = TimestampIndex::GenerateBitset(timestamp, range, timestamps_data.data(), timestamps_data.size());
|
||||
auto mask = TimestampIndex::GenerateBitset(
|
||||
timestamp, range, timestamps_data.data(), timestamps_data.size());
|
||||
bitset_chunk |= mask;
|
||||
}
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue