enhance: Refactor runtime and expr framework (#28166)

#28165

Signed-off-by: luzhang <luzhang@zilliz.com>
Co-authored-by: luzhang <luzhang@zilliz.com>
pull/29293/head
zhagnlu 2023-12-18 12:04:42 +08:00 committed by GitHub
parent 438f39e268
commit a602171d06
133 changed files with 13062 additions and 917 deletions

View File

@ -30,7 +30,7 @@ ifdef USE_ASAN
use_asan =${USE_ASAN}
endif
use_dynamic_simd = OFF
use_dynamic_simd = ON
ifdef USE_DYNAMIC_SIMD
use_dynamic_simd = ${USE_DYNAMIC_SIMD}
endif

View File

@ -5,6 +5,11 @@
# ``INFO``, ``WARNING``, ``ERROR``, and ``FATAL`` are 0, 1, 2, and 3
--minloglevel=0
--log_dir=/var/lib/milvus/logs/
# use vlog to implement debug and trace logging
# if vmodule is set to 5, debug level is enabled
# if vmodule is set to 6, trace level is enabled
# the default is 4, which enables neither debug nor trace
--v=4
# MB
--max_log_size=200
--stop_logging_if_full_disk=true
--stop_logging_if_full_disk=true

View File

@ -288,6 +288,7 @@ queryNode:
# This parameter is only useful when enable-disk = true.
# And this value should be a number greater than 1 and less than 32.
chunkRows: 1024 # The number of vectors in a chunk.
exprEvalBatchSize: 8192 # The batch size used by the executor when fetching the next batch
interimIndex: # build a temporary vector index for growing segments or binlogs to accelerate search
enableIndex: true
nlist: 128 # segment index nlist

View File

@ -36,7 +36,7 @@ class MilvusConan(ConanFile):
"xz_utils/5.4.0",
"prometheus-cpp/1.1.0",
"re2/20230301",
"folly/2023.10.30.04@milvus/dev",
"folly/2023.10.30.05@milvus/dev",
"google-cloud-cpp/2.5.0@milvus/dev",
"opentelemetry-cpp/1.8.1.1@milvus/dev",
"librdkafka/1.9.1",
@ -44,6 +44,9 @@ class MilvusConan(ConanFile):
)
generators = ("cmake", "cmake_find_package")
default_options = {
"libevent:shared": True,
"double-conversion:shared": True,
"folly:shared": True,
"librdkafka:shared": True,
"librdkafka:zstd": True,
"librdkafka:ssl": True,

View File

@ -32,6 +32,7 @@ add_subdirectory( index )
add_subdirectory( query )
add_subdirectory( segcore )
add_subdirectory( indexbuilder )
add_subdirectory( exec )
if(USE_DYNAMIC_SIMD)
add_subdirectory( simd )
endif()

View File

@ -22,6 +22,7 @@ set(COMMON_SRC
Tracer.cpp
IndexMeta.cpp
EasyAssert.cpp
FieldData.cpp
)
add_library(milvus_common SHARED ${COMMON_SRC})

View File

@ -27,6 +27,7 @@ int64_t MIDDLE_PRIORITY_THREAD_CORE_COEFFICIENT =
int64_t LOW_PRIORITY_THREAD_CORE_COEFFICIENT =
DEFAULT_LOW_PRIORITY_THREAD_CORE_COEFFICIENT;
int CPU_NUM = DEFAULT_CPU_NUM;
int64_t EXEC_EVAL_EXPR_BATCH_SIZE = DEFAULT_EXEC_EVAL_EXPR_BATCH_SIZE;
void
SetIndexSliceSize(const int64_t size) {
@ -56,6 +57,13 @@ SetLowPriorityThreadCoreCoefficient(const int64_t coefficient) {
<< LOW_PRIORITY_THREAD_CORE_COEFFICIENT;
}
void
SetDefaultExecEvalExprBatchSize(int64_t val) {
EXEC_EVAL_EXPR_BATCH_SIZE = val;
LOG_SEGCORE_INFO_ << "set default expr eval batch size: "
<< EXEC_EVAL_EXPR_BATCH_SIZE;
}
void
SetCpuNum(const int num) {
CPU_NUM = num;

View File

@ -26,6 +26,7 @@ extern int64_t HIGH_PRIORITY_THREAD_CORE_COEFFICIENT;
extern int64_t MIDDLE_PRIORITY_THREAD_CORE_COEFFICIENT;
extern int64_t LOW_PRIORITY_THREAD_CORE_COEFFICIENT;
extern int CPU_NUM;
extern int64_t EXEC_EVAL_EXPR_BATCH_SIZE;
void
SetIndexSliceSize(const int64_t size);
@ -42,4 +43,7 @@ SetLowPriorityThreadCoreCoefficient(const int64_t coefficient);
void
SetCpuNum(const int core);
void
SetDefaultExecEvalExprBatchSize(int64_t val);
} // namespace milvus
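
A minimal usage sketch for the new setter (hypothetical caller, not part of this diff):

#include "common/Common.h"

// Hypothetical initialization hook: override the default expression-evaluation
// batch size (DEFAULT_EXEC_EVAL_EXPR_BATCH_SIZE = 8192) before serving queries.
void ConfigureExprEvalBatchSize(int64_t batch_size) {
    milvus::SetDefaultExecEvalExprBatchSize(batch_size);
}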

View File

@ -39,6 +39,10 @@ const char INDEX_BUILD_ID_KEY[] = "indexBuildID";
const char INDEX_ROOT_PATH[] = "index_files";
const char RAWDATA_ROOT_PATH[] = "raw_datas";
const char DEFAULT_PLANNODE_ID[] = "0";
const char DEAFULT_QUERY_ID[] = "0";
const char DEFAULT_TASK_ID[] = "0";
const int64_t DEFAULT_FIELD_MAX_MEMORY_LIMIT = 64 << 20; // bytes
const int64_t DEFAULT_HIGH_PRIORITY_THREAD_CORE_COEFFICIENT = 10;
const int64_t DEFAULT_MIDDLE_PRIORITY_THREAD_CORE_COEFFICIENT = 5;
@ -48,6 +52,8 @@ const int64_t DEFAULT_INDEX_FILE_SLICE_SIZE = 4 << 20; // bytes
const int DEFAULT_CPU_NUM = 1;
const int64_t DEFAULT_EXEC_EVAL_EXPR_BATCH_SIZE = 8192;
constexpr const char* RADIUS = knowhere::meta::RADIUS;
constexpr const char* RANGE_FILTER = knowhere::meta::RANGE_FILTER;

View File

@ -0,0 +1,218 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
#include <stdexcept>
#include <string>
namespace milvus {
class NotImplementedException : public std::exception {
public:
explicit NotImplementedException(const std::string& msg)
: std::exception(), exception_message_(msg) {
}
const char*
what() const noexcept {
return exception_message_.c_str();
}
virtual ~NotImplementedException() {
}
private:
std::string exception_message_;
};
class NotSupportedDataTypeException : public std::exception {
public:
explicit NotSupportedDataTypeException(const std::string& msg)
: std::exception(), exception_message_(msg) {
}
const char*
what() const noexcept {
return exception_message_.c_str();
}
virtual ~NotSupportedDataTypeException() {
}
private:
std::string exception_message_;
};
class UnistdException : public std::runtime_error {
public:
explicit UnistdException(const std::string& msg) : std::runtime_error(msg) {
}
virtual ~UnistdException() {
}
};
// Exceptions for storage module
class LocalChunkManagerException : public std::runtime_error {
public:
explicit LocalChunkManagerException(const std::string& msg)
: std::runtime_error(msg) {
}
virtual ~LocalChunkManagerException() {
}
};
class InvalidPathException : public LocalChunkManagerException {
public:
explicit InvalidPathException(const std::string& msg)
: LocalChunkManagerException(msg) {
}
virtual ~InvalidPathException() {
}
};
class OpenFileException : public LocalChunkManagerException {
public:
explicit OpenFileException(const std::string& msg)
: LocalChunkManagerException(msg) {
}
virtual ~OpenFileException() {
}
};
class CreateFileException : public LocalChunkManagerException {
public:
explicit CreateFileException(const std::string& msg)
: LocalChunkManagerException(msg) {
}
virtual ~CreateFileException() {
}
};
class ReadFileException : public LocalChunkManagerException {
public:
explicit ReadFileException(const std::string& msg)
: LocalChunkManagerException(msg) {
}
virtual ~ReadFileException() {
}
};
class WriteFileException : public LocalChunkManagerException {
public:
explicit WriteFileException(const std::string& msg)
: LocalChunkManagerException(msg) {
}
virtual ~WriteFileException() {
}
};
class PathAlreadyExistException : public LocalChunkManagerException {
public:
explicit PathAlreadyExistException(const std::string& msg)
: LocalChunkManagerException(msg) {
}
virtual ~PathAlreadyExistException() {
}
};
class DirNotExistException : public LocalChunkManagerException {
public:
explicit DirNotExistException(const std::string& msg)
: LocalChunkManagerException(msg) {
}
virtual ~DirNotExistException() {
}
};
class MinioException : public std::runtime_error {
public:
explicit MinioException(const std::string& msg) : std::runtime_error(msg) {
}
virtual ~MinioException() {
}
};
class InvalidBucketNameException : public MinioException {
public:
explicit InvalidBucketNameException(const std::string& msg)
: MinioException(msg) {
}
virtual ~InvalidBucketNameException() {
}
};
class ObjectNotExistException : public MinioException {
public:
explicit ObjectNotExistException(const std::string& msg)
: MinioException(msg) {
}
virtual ~ObjectNotExistException() {
}
};
class S3ErrorException : public MinioException {
public:
explicit S3ErrorException(const std::string& msg) : MinioException(msg) {
}
virtual ~S3ErrorException() {
}
};
class DiskANNFileManagerException : public std::runtime_error {
public:
explicit DiskANNFileManagerException(const std::string& msg)
: std::runtime_error(msg) {
}
virtual ~DiskANNFileManagerException() {
}
};
class ArrowException : public std::runtime_error {
public:
explicit ArrowException(const std::string& msg) : std::runtime_error(msg) {
}
virtual ~ArrowException() {
}
};
// Exceptions for executor module
class ExecDriverException : public std::exception {
public:
explicit ExecDriverException(const std::string& msg)
: std::exception(), exception_message_(msg) {
}
const char*
what() const noexcept {
return exception_message_.c_str();
}
virtual ~ExecDriverException() {
}
private:
std::string exception_message_;
};
class ExecOperatorException : public std::exception {
public:
explicit ExecOperatorException(const std::string& msg)
: std::exception(), exception_message_(msg) {
}
const char*
what() const noexcept {
return exception_message_.c_str();
}
virtual ~ExecOperatorException() {
}
private:
std::string exception_message_;
};
} // namespace milvus
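
A minimal sketch of how the new executor exceptions are expected to surface (hypothetical call site, assuming this header is included):

#include <iostream>

// Hypothetical call site: errors raised inside drivers/operators are rethrown
// as the exec-specific exception types declared above.
void ReportExecFailure() {
    try {
        throw milvus::ExecOperatorException("operator failed: example message");
    } catch (const std::exception& e) {
        std::cerr << e.what() << std::endl;  // prints the stored message
    }
}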

View File

@ -14,15 +14,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "storage/FieldData.h"
#include "common/FieldData.h"
#include "arrow/array/array_binary.h"
#include "common/Array.h"
#include "common/EasyAssert.h"
#include "common/Exception.h"
#include "common/FieldDataInterface.h"
#include "common/Json.h"
#include "simdjson/padded_string.h"
#include "common/Array.h"
#include "FieldDataInterface.h"
namespace milvus::storage {
namespace milvus {
template <typename Type, bool is_scalar>
void
@ -183,4 +185,33 @@ template class FieldDataImpl<int8_t, false>;
template class FieldDataImpl<float, false>;
template class FieldDataImpl<float16, false>;
} // namespace milvus::storage
FieldDataPtr
InitScalarFieldData(const DataType& type, int64_t cap_rows) {
switch (type) {
case DataType::BOOL:
return std::make_shared<FieldData<bool>>(type, cap_rows);
case DataType::INT8:
return std::make_shared<FieldData<int8_t>>(type, cap_rows);
case DataType::INT16:
return std::make_shared<FieldData<int16_t>>(type, cap_rows);
case DataType::INT32:
return std::make_shared<FieldData<int32_t>>(type, cap_rows);
case DataType::INT64:
return std::make_shared<FieldData<int64_t>>(type, cap_rows);
case DataType::FLOAT:
return std::make_shared<FieldData<float>>(type, cap_rows);
case DataType::DOUBLE:
return std::make_shared<FieldData<double>>(type, cap_rows);
case DataType::STRING:
case DataType::VARCHAR:
return std::make_shared<FieldData<std::string>>(type, cap_rows);
case DataType::JSON:
return std::make_shared<FieldData<Json>>(type, cap_rows);
default:
throw NotSupportedDataTypeException(
"InitScalarFieldData not support data type " +
datatype_name(type));
}
}
} // namespace milvus
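
A minimal usage sketch of the new InitScalarFieldData factory (assumed call site, not part of this diff):

#include <vector>
#include "common/FieldData.h"

// Hypothetical helper: allocate a scalar FieldData by DataType and fill it.
milvus::FieldDataPtr MakeInt64Field(const std::vector<int64_t>& values) {
    auto field = milvus::InitScalarFieldData(milvus::DataType::INT64,
                                             static_cast<int64_t>(values.size()));
    field->FillFieldData(values.data(), values.size());
    return field;
}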

View File

@ -21,10 +21,10 @@
#include <oneapi/tbb/concurrent_queue.h>
#include "storage/FieldDataInterface.h"
#include "common/FieldDataInterface.h"
#include "common/Channel.h"
namespace milvus::storage {
namespace milvus {
template <typename Type>
class FieldData : public FieldDataImpl<Type, true> {
@ -34,6 +34,11 @@ class FieldData : public FieldDataImpl<Type, true> {
: FieldDataImpl<Type, true>::FieldDataImpl(
1, data_type, buffered_num_rows) {
}
static_assert(IsScalar<Type> || std::is_same_v<Type, PkType>);
explicit FieldData(DataType data_type, FixedVector<Type>&& inner_data)
: FieldDataImpl<Type, true>::FieldDataImpl(
1, data_type, std::move(inner_data)) {
}
};
template <>
@ -106,7 +111,10 @@ class FieldData<Float16Vector> : public FieldDataImpl<float16, false> {
};
using FieldDataPtr = std::shared_ptr<FieldDataBase>;
using FieldDataChannel = Channel<storage::FieldDataPtr>;
using FieldDataChannel = Channel<FieldDataPtr>;
using FieldDataChannelPtr = std::shared_ptr<FieldDataChannel>;
} // namespace milvus::storage
FieldDataPtr
InitScalarFieldData(const DataType& type, int64_t cap_rows);
} // namespace milvus

View File

@ -33,7 +33,7 @@
#include "common/EasyAssert.h"
#include "common/Array.h"
namespace milvus::storage {
namespace milvus {
using DataType = milvus::DataType;
@ -49,8 +49,8 @@ class FieldDataBase {
virtual void
FillFieldData(const std::shared_ptr<arrow::Array> array) = 0;
virtual const void*
Data() const = 0;
virtual void*
Data() = 0;
virtual const void*
RawValue(ssize_t offset) const = 0;
@ -109,6 +109,12 @@ class FieldDataImpl : public FieldDataBase {
field_data_.resize(num_rows_ * dim_);
}
explicit FieldDataImpl(size_t dim, DataType type, Chunk&& field_data)
: FieldDataBase(type), dim_(is_scalar ? 1 : dim) {
num_rows_ = field_data.size() / dim;
field_data_ = std::move(field_data);
}
void
FillFieldData(const void* source, ssize_t element_count) override;
@ -126,8 +132,8 @@ class FieldDataImpl : public FieldDataBase {
return "FieldDataImpl";
}
const void*
Data() const override {
void*
Data() override {
return field_data_.data();
}
@ -332,4 +338,4 @@ class FieldDataArrayImpl : public FieldDataImpl<Array, true> {
}
};
} // namespace milvus::storage
} // namespace milvus

View File

@ -0,0 +1,75 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <folly/Unit.h>
#include <folly/futures/Future.h>
#include "log/Log.h"
namespace milvus {
template <class T>
class MilvusPromise : public folly::Promise<T> {
public:
MilvusPromise() : folly::Promise<T>() {
}
explicit MilvusPromise(const std::string& context)
: folly::Promise<T>(), context_(context) {
}
MilvusPromise(folly::futures::detail::EmptyConstruct,
const std::string& context) noexcept
: folly::Promise<T>(folly::Promise<T>::makeEmpty()), context_(context) {
}
~MilvusPromise() {
if (!this->isFulfilled()) {
LOG_SEGCORE_WARNING_
<< "PROMISE: Unfulfilled promise is being deleted. Context: "
<< context_;
}
}
explicit MilvusPromise(MilvusPromise<T>&& other)
: folly::Promise<T>(std::move(other)),
context_(std::move(other.context_)) {
}
MilvusPromise&
operator=(MilvusPromise<T>&& other) noexcept {
folly::Promise<T>::operator=(std::move(other));
context_ = std::move(other.context_);
return *this;
}
static MilvusPromise
MakeEmpty(const std::string& context = "") noexcept {
return MilvusPromise<T>(folly::futures::detail::EmptyConstruct{},
context);
}
private:
/// Optional parameter to understand where this promise was created.
std::string context_;
};
using ContinuePromise = MilvusPromise<folly::Unit>;
using ContinueFuture = folly::SemiFuture<folly::Unit>;
} // namespace milvus
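
A small sketch of the intended promise/future usage (assumed, mirroring folly semantics; an unfulfilled MilvusPromise logs a warning on destruction):

#include "common/Promise.h"

void PromiseSketch() {
    milvus::ContinuePromise promise("example: wait for upstream operator");
    milvus::ContinueFuture future = promise.getSemiFuture();
    promise.setValue();  // fulfill; the future becomes ready
}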

View File

@ -95,6 +95,10 @@ enum class DataType {
ARRAY = 22,
JSON = 23,
// Some special Data type, start from after 50
// just for internal use now, may sync proto in future
ROW = 50,
VECTOR_BINARY = 100,
VECTOR_FLOAT = 101,
VECTOR_FLOAT16 = 102,
@ -182,8 +186,138 @@ using MayConstRef = std::conditional_t<std::is_same_v<T, std::string> ||
const T&,
T>;
static_assert(std::is_same_v<const std::string&, MayConstRef<std::string>>);
template <DataType T>
struct TypeTraits {};
template <>
struct TypeTraits<DataType::NONE> {
static constexpr const char* Name = "NONE";
};
template <>
struct TypeTraits<DataType::BOOL> {
using NativeType = bool;
static constexpr DataType TypeKind = DataType::BOOL;
static constexpr bool IsPrimitiveType = true;
static constexpr bool IsFixedWidth = true;
static constexpr const char* Name = "BOOL";
};
template <>
struct TypeTraits<DataType::INT8> {
using NativeType = int8_t;
static constexpr DataType TypeKind = DataType::INT8;
static constexpr bool IsPrimitiveType = true;
static constexpr bool IsFixedWidth = true;
static constexpr const char* Name = "INT8";
};
template <>
struct TypeTraits<DataType::INT16> {
using NativeType = int16_t;
static constexpr DataType TypeKind = DataType::INT16;
static constexpr bool IsPrimitiveType = true;
static constexpr bool IsFixedWidth = true;
static constexpr const char* Name = "INT16";
};
template <>
struct TypeTraits<DataType::INT32> {
using NativeType = int32_t;
static constexpr DataType TypeKind = DataType::INT32;
static constexpr bool IsPrimitiveType = true;
static constexpr bool IsFixedWidth = true;
static constexpr const char* Name = "INT32";
};
template <>
struct TypeTraits<DataType::INT64> {
using NativeType = int64_t;
static constexpr DataType TypeKind = DataType::INT64;
static constexpr bool IsPrimitiveType = true;
static constexpr bool IsFixedWidth = true;
static constexpr const char* Name = "INT64";
};
template <>
struct TypeTraits<DataType::FLOAT> {
using NativeType = float;
static constexpr DataType TypeKind = DataType::FLOAT;
static constexpr bool IsPrimitiveType = true;
static constexpr bool IsFixedWidth = true;
static constexpr const char* Name = "FLOAT";
};
template <>
struct TypeTraits<DataType::DOUBLE> {
using NativeType = double;
static constexpr DataType TypeKind = DataType::DOUBLE;
static constexpr bool IsPrimitiveType = true;
static constexpr bool IsFixedWidth = true;
static constexpr const char* Name = "DOUBLE";
};
template <>
struct TypeTraits<DataType::VARCHAR> {
using NativeType = std::string;
static constexpr DataType TypeKind = DataType::VARCHAR;
static constexpr bool IsPrimitiveType = true;
static constexpr bool IsFixedWidth = false;
static constexpr const char* Name = "VARCHAR";
};
template <>
struct TypeTraits<DataType::STRING> : public TypeTraits<DataType::VARCHAR> {
static constexpr DataType TypeKind = DataType::STRING;
static constexpr const char* Name = "STRING";
};
template <>
struct TypeTraits<DataType::ARRAY> {
using NativeType = void;
static constexpr DataType TypeKind = DataType::ARRAY;
static constexpr bool IsPrimitiveType = false;
static constexpr bool IsFixedWidth = false;
static constexpr const char* Name = "ARRAY";
};
template <>
struct TypeTraits<DataType::JSON> {
using NativeType = void;
static constexpr DataType TypeKind = DataType::JSON;
static constexpr bool IsPrimitiveType = false;
static constexpr bool IsFixedWidth = false;
static constexpr const char* Name = "JSON";
};
template <>
struct TypeTraits<DataType::ROW> {
using NativeType = void;
static constexpr DataType TypeKind = DataType::ROW;
static constexpr bool IsPrimitiveType = false;
static constexpr bool IsFixedWidth = false;
static constexpr const char* Name = "ROW";
};
template <>
struct TypeTraits<DataType::VECTOR_BINARY> {
using NativeType = uint8_t;
static constexpr DataType TypeKind = DataType::VECTOR_BINARY;
static constexpr bool IsPrimitiveType = false;
static constexpr bool IsFixedWidth = false;
static constexpr const char* Name = "VECTOR_BINARY";
};
template <>
struct TypeTraits<DataType::VECTOR_FLOAT> {
using NativeType = float;
static constexpr DataType TypeKind = DataType::VECTOR_FLOAT;
static constexpr bool IsPrimitiveType = false;
static constexpr bool IsFixedWidth = false;
static constexpr const char* Name = "VECTOR_FLOAT";
};
} // namespace milvus
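
A compile-time sketch of consuming the traits above (assertion values chosen by the editor as examples):

#include <type_traits>

static_assert(std::is_same_v<milvus::TypeTraits<milvus::DataType::FLOAT>::NativeType, float>,
              "FLOAT maps to float");
static_assert(!milvus::TypeTraits<milvus::DataType::VARCHAR>::IsFixedWidth,
              "VARCHAR is variable width");
constexpr const char* kJsonTypeName = milvus::TypeTraits<milvus::DataType::JSON>::Name;  // "JSON"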
//
template <>
struct fmt::formatter<milvus::DataType> : formatter<string_view> {
auto
@ -226,6 +360,9 @@ struct fmt::formatter<milvus::DataType> : formatter<string_view> {
case milvus::DataType::JSON:
name = "JSON";
break;
case milvus::DataType::ROW:
name = "ROW";
break;
case milvus::DataType::VECTOR_BINARY:
name = "VECTOR_BINARY";
break;

View File

@ -192,4 +192,17 @@ is_in_disk_list(const IndexType& index_type) {
return is_in_list<IndexType>(index_type, DISK_INDEX_LIST);
}
template <typename T>
std::string
Join(const std::vector<T>& items, const std::string& delimiter) {
std::stringstream ss;
for (size_t i = 0; i < items.size(); ++i) {
if (i > 0) {
ss << delimiter;
}
ss << items[i];
}
return ss.str();
}
} // namespace milvus
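
A usage sketch for the new Join helper (example values chosen by the editor):

#include <string>
#include <vector>

void JoinSketch() {
    std::vector<int64_t> segment_ids = {101, 102, 103};
    std::string joined = milvus::Join(segment_ids, ", ");  // "101, 102, 103"
}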

View File

@ -0,0 +1,141 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <optional>
#include <string>
#include "common/FieldData.h"
namespace milvus {
/**
* @brief base class for vectors of different types
* @todo implement full null value support
*/
class BaseVector {
public:
BaseVector(DataType data_type,
size_t length,
std::optional<size_t> null_count = std::nullopt)
: type_kind_(data_type), length_(length), null_count_(null_count) {
}
virtual ~BaseVector() = default;
int64_t
size() {
return length_;
}
DataType
type() {
return type_kind_;
}
protected:
DataType type_kind_;
size_t length_;
std::optional<size_t> null_count_;
};
using VectorPtr = std::shared_ptr<BaseVector>;
/**
* @brief Single vector for scalar types
* @todo use a memory pool && buffer to replace FieldData
*/
class ColumnVector final : public BaseVector {
public:
ColumnVector(DataType data_type,
size_t length,
std::optional<size_t> null_count = std::nullopt)
: BaseVector(data_type, length, null_count) {
values_ = InitScalarFieldData(data_type, length);
}
ColumnVector(FixedVector<bool>&& data)
: BaseVector(DataType::BOOL, data.size()) {
values_ =
std::make_shared<FieldData<bool>>(DataType::BOOL, std::move(data));
}
virtual ~ColumnVector() override {
values_.reset();
}
void*
GetRawData() {
return values_->Data();
}
template <typename As>
const As*
RawAsValues() const {
return reinterpret_cast<const As*>(values_->Data());
}
private:
FieldDataPtr values_;
};
using ColumnVectorPtr = std::shared_ptr<ColumnVector>;
/**
* @brief Multiple vectors for scalar types,
* mainly used to pass internal results within the segcore scalar engine
*/
class RowVector : public BaseVector {
public:
RowVector(std::vector<DataType>& data_types,
size_t length,
std::optional<size_t> null_count = std::nullopt)
: BaseVector(DataType::ROW, length, null_count) {
for (auto& type : data_types) {
children_values_.emplace_back(
std::make_shared<ColumnVector>(type, length));
}
}
RowVector(const std::vector<VectorPtr>& children)
: BaseVector(DataType::ROW, 0) {
for (auto& child : children) {
children_values_.push_back(child);
if (child->size() > length_) {
length_ = child->size();
}
}
}
const std::vector<VectorPtr>&
childrens() {
return children_values_;
}
VectorPtr
child(int index) {
assert(index < children_values_.size());
return children_values_[index];
}
private:
std::vector<VectorPtr> children_values_;
};
using RowVectorPtr = std::shared_ptr<RowVector>;
} // namespace milvus
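
A short sketch of composing the vector wrappers (assumed usage; FixedVector comes from common/Types.h and is assumed to be vector-like):

#include <cassert>
#include <memory>

void VectorSketch() {
    // Build a boolean ColumnVector from a filter bitmap ...
    milvus::FixedVector<bool> bitmap(128, true);
    auto col = std::make_shared<milvus::ColumnVector>(std::move(bitmap));

    // ... and wrap it into a RowVector, whose length is the max child length.
    milvus::RowVector row({col});
    assert(row.size() == 128);
}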

View File

@ -25,7 +25,7 @@
#include "common/Tracer.h"
#include "log/Log.h"
std::once_flag flag1, flag2, flag3, flag4, flag5;
std::once_flag flag1, flag2, flag3, flag4, flag5, flag6;
std::once_flag traceFlag;
void
@ -70,6 +70,14 @@ InitCpuNum(const int value) {
flag3, [](int value) { milvus::SetCpuNum(value); }, value);
}
void
InitDefaultExprEvalBatchSize(int64_t val) {
std::call_once(
flag6,
[](int val) { milvus::SetDefaultExecEvalExprBatchSize(val); },
val);
}
void
InitTrace(CTraceConfig* config) {
auto traceConfig = milvus::tracer::TraceConfig{config->exporter,

View File

@ -36,6 +36,9 @@ InitMiddlePriorityThreadCoreCoefficient(const int64_t);
void
InitLowPriorityThreadCoreCoefficient(const int64_t);
void
InitDefaultExprEvalBatchSize(int64_t val);
void
InitCpuNum(const int);

View File

@ -0,0 +1,36 @@
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License
set(MILVUS_EXEC_SRCS
expression/Expr.cpp
expression/UnaryExpr.cpp
expression/ConjunctExpr.cpp
expression/LogicalUnaryExpr.cpp
expression/LogicalBinaryExpr.cpp
expression/TermExpr.cpp
expression/BinaryArithOpEvalRangeExpr.cpp
expression/BinaryRangeExpr.cpp
expression/AlwaysTrueExpr.cpp
expression/CompareExpr.cpp
expression/JsonContainsExpr.cpp
expression/ExistsExpr.cpp
operator/FilterBits.cpp
operator/Operator.cpp
Driver.cpp
Task.cpp
)
add_library(milvus_exec STATIC ${MILVUS_EXEC_SRCS})
if(USE_DYNAMIC_SIMD)
target_link_libraries(milvus_exec milvus_common milvus_simd milvus-storage ${CONAN_LIBS})
else()
target_link_libraries(milvus_exec milvus_common milvus-storage ${CONAN_LIBS})
endif()

View File

@ -0,0 +1,355 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "Driver.h"
#include <cassert>
#include <memory>
#include "exec/operator/CallbackSink.h"
#include "exec/operator/FilterBits.h"
#include "exec/operator/Operator.h"
#include "exec/Task.h"
#include "common/EasyAssert.h"
namespace milvus {
namespace exec {
std::atomic_uint64_t BlockingState::num_blocked_drivers_{0};
std::shared_ptr<QueryConfig>
DriverContext::GetQueryConfig() {
return task_->query_context()->query_config();
}
std::shared_ptr<Driver>
DriverFactory::CreateDriver(std::unique_ptr<DriverContext> ctx,
std::function<int(int pipelineid)> num_drivers) {
auto driver = std::shared_ptr<Driver>(new Driver());
ctx->driver_ = driver.get();
std::vector<std::unique_ptr<Operator>> operators;
operators.reserve(plannodes_.size());
for (size_t i = 0; i < plannodes_.size(); ++i) {
auto id = operators.size();
auto plannode = plannodes_[i];
if (auto filternode =
std::dynamic_pointer_cast<const plan::FilterBitsNode>(
plannode)) {
operators.push_back(
std::make_unique<FilterBits>(id, ctx.get(), filternode));
}
// TODO: add more operators
}
if (consumer_supplier_) {
operators.push_back(consumer_supplier_(operators.size(), ctx.get()));
}
driver->Init(std::move(ctx), std::move(operators));
return driver;
}
void
Driver::Enqueue(std::shared_ptr<Driver> driver) {
if (driver->closed_) {
return;
}
driver->get_task()->query_context()->executor()->add(
[driver]() { Driver::Run(driver); });
}
void
Driver::Run(std::shared_ptr<Driver> self) {
std::shared_ptr<BlockingState> blocking_state;
RowVectorPtr result;
auto reason = self->RunInternal(self, blocking_state, result);
AssertInfo(result == nullptr,
"The last operator (sink) must not produce any results.");
if (reason == StopReason::kBlock) {
return;
}
switch (reason) {
case StopReason::kBlock:
BlockingState::SetResume(blocking_state);
return;
case StopReason::kYield:
Enqueue(self);
case StopReason::kPause:
case StopReason::kTerminate:
case StopReason::kAlreadyTerminated:
case StopReason::kAtEnd:
return;
default:
AssertInfo(false, "Unhandled stop reason");
}
}
void
Driver::Init(std::unique_ptr<DriverContext> ctx,
std::vector<std::unique_ptr<Operator>> operators) {
assert(ctx != nullptr);
ctx_ = std::move(ctx);
AssertInfo(operators.size() != 0, "operators in driver must not be empty");
operators_ = std::move(operators);
current_operator_index_ = operators_.size() - 1;
}
void
Driver::Close() {
if (closed_) {
return;
}
for (auto& op : operators_) {
op->Close();
}
closed_ = true;
Task::RemoveDriver(ctx_->task_, this);
}
RowVectorPtr
Driver::Next(std::shared_ptr<BlockingState>& blocking_state) {
auto self = shared_from_this();
RowVectorPtr result;
auto stop = RunInternal(self, blocking_state, result);
Assert(stop == StopReason::kBlock || stop == StopReason::kAtEnd ||
stop == StopReason::kAlreadyTerminated);
return result;
}
#define CALL_OPERATOR(call_func, operator, method_name) \
try { \
call_func; \
} catch (SegcoreError & e) { \
auto err_msg = fmt::format( \
"Operator::{} failed for [Operator:{}, plan node id: " \
"{}] : {}", \
method_name, \
operator->get_operator_type(), \
operator->get_plannode_id(), \
e.what()); \
LOG_SEGCORE_ERROR_ << err_msg; \
throw ExecOperatorException(err_msg); \
} catch (std::exception & e) { \
throw ExecOperatorException( \
fmt::format("Operator::{} failed for [Operator:{}, plan node id: " \
"{}] : {}", \
method_name, \
operator->get_operator_type(), \
operator->get_plannode_id(), \
e.what())); \
}
StopReason
Driver::RunInternal(std::shared_ptr<Driver>& self,
std::shared_ptr<BlockingState>& blocking_state,
RowVectorPtr& result) {
try {
int num_operators = operators_.size();
ContinueFuture future;
for (;;) {
for (int32_t i = num_operators - 1; i >= 0; --i) {
auto op = operators_[i].get();
current_operator_index_ = i;
CALL_OPERATOR(
blocking_reason_ = op->IsBlocked(&future), op, "IsBlocked");
if (blocking_reason_ != BlockingReason::kNotBlocked) {
blocking_state = std::make_shared<BlockingState>(
self, std::move(future), op, blocking_reason_);
return StopReason::kBlock;
}
Operator* next_op = nullptr;
if (i < operators_.size() - 1) {
next_op = operators_[i + 1].get();
CALL_OPERATOR(
blocking_reason_ = next_op->IsBlocked(&future),
next_op,
"IsBlocked");
if (blocking_reason_ != BlockingReason::kNotBlocked) {
blocking_state = std::make_shared<BlockingState>(
self, std::move(future), next_op, blocking_reason_);
return StopReason::kBlock;
}
bool needs_input;
CALL_OPERATOR(needs_input = next_op->NeedInput(),
next_op,
"NeedInput");
if (needs_input) {
RowVectorPtr result;
{
CALL_OPERATOR(
result = op->GetOutput(), op, "GetOutput");
if (result) {
AssertInfo(
result->size() > 0,
fmt::format(
"GetOutput must return nullptr or "
"a non-empty vector: {}",
op->get_operator_type()));
}
}
if (result) {
CALL_OPERATOR(
next_op->AddInput(result), next_op, "AddInput");
i += 2;
continue;
} else {
CALL_OPERATOR(
blocking_reason_ = op->IsBlocked(&future),
op,
"IsBlocked");
if (blocking_reason_ !=
BlockingReason::kNotBlocked) {
blocking_state =
std::make_shared<BlockingState>(
self,
std::move(future),
next_op,
blocking_reason_);
return StopReason::kBlock;
}
if (op->IsFinished()) {
CALL_OPERATOR(next_op->NoMoreInput(),
next_op,
"NoMoreInput");
break;
}
}
}
} else {
{
CALL_OPERATOR(
result = op->GetOutput(), op, "GetOutput");
if (result) {
AssertInfo(
result->size() > 0,
fmt::format("GetOutput must return nullptr or "
"a non-empty vector: {}",
op->get_operator_type()));
blocking_reason_ = BlockingReason::kWaitForConsumer;
return StopReason::kBlock;
}
}
if (op->IsFinished()) {
Close();
return StopReason::kAtEnd;
}
continue;
}
}
}
} catch (std::exception& e) {
get_task()->SetError(std::current_exception());
return StopReason::kAlreadyTerminated;
}
}
static bool
MustStartNewPipeline(std::shared_ptr<const plan::PlanNode> plannode,
int source_id) {
//TODO: support LocalMerge and other shuffle
return source_id != 0;
}
OperatorSupplier
MakeConsumerSupplier(ConsumerSupplier supplier) {
if (supplier) {
return [supplier](int32_t operator_id, DriverContext* ctx) {
return std::make_unique<CallbackSink>(operator_id, ctx, supplier());
};
}
return nullptr;
}
uint32_t
MaxDrivers(const DriverFactory* factory, const QueryConfig& config) {
return 1;
}
static void
SplitPlan(const std::shared_ptr<const plan::PlanNode>& plannode,
std::vector<std::shared_ptr<const plan::PlanNode>>* current_plannodes,
const std::shared_ptr<const plan::PlanNode>& consumer_node,
OperatorSupplier operator_supplier,
std::vector<std::unique_ptr<DriverFactory>>* driver_factories) {
if (!current_plannodes) {
driver_factories->push_back(std::make_unique<DriverFactory>());
current_plannodes = &driver_factories->back()->plannodes_;
driver_factories->back()->consumer_supplier_ = operator_supplier;
driver_factories->back()->consumer_node_ = consumer_node;
}
auto sources = plannode->sources();
if (sources.empty()) {
driver_factories->back()->is_input_driver_ = true;
} else {
for (int i = 0; i < sources.size(); ++i) {
SplitPlan(
sources[i],
MustStartNewPipeline(plannode, i) ? nullptr : current_plannodes,
plannode,
nullptr,
driver_factories);
}
}
current_plannodes->push_back(plannode);
}
void
LocalPlanner::Plan(
const plan::PlanFragment& fragment,
ConsumerSupplier consumer_supplier,
std::vector<std::unique_ptr<DriverFactory>>* driver_factories,
const QueryConfig& config,
uint32_t max_drivers) {
SplitPlan(fragment.plan_node_,
nullptr,
nullptr,
MakeConsumerSupplier(consumer_supplier),
driver_factories);
(*driver_factories)[0]->is_output_driver_ = true;
for (auto& factory : *driver_factories) {
factory->max_drivers_ = MaxDrivers(factory.get(), config);
factory->num_drivers_ = std::min(factory->max_drivers_, max_drivers);
if (factory->is_group_execution_) {
factory->num_total_drivers_ =
factory->num_drivers_ * fragment.num_splitgroups_;
} else {
factory->num_total_drivers_ = factory->num_drivers_;
}
}
}
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,254 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "common/Types.h"
#include "common/Promise.h"
#include "exec/QueryContext.h"
#include "plan/PlanNode.h"
namespace milvus {
namespace exec {
enum class StopReason {
// Keep running.
kNone,
// Go off thread and do not schedule more activity.
kPause,
// Stop and free all. This is returned once and the thread that gets
// this value is responsible for freeing the state associated with
// the thread. Other threads will get kAlreadyTerminated after the
// first thread has received kTerminate.
kTerminate,
kAlreadyTerminated,
// Go off thread and then enqueue to the back of the runnable queue.
kYield,
// Must wait for external events.
kBlock,
// No more data to produce.
kAtEnd,
kAlreadyOnThread
};
enum class BlockingReason {
kNotBlocked,
kWaitForConsumer,
kWaitForSplit,
kWaitForExchange,
kWaitForJoinBuild,
/// For a build operator, it is blocked waiting for the probe operators to
/// finish probing before build the next hash table from one of the previously
/// spilled partition data.
/// For a probe operator, it is blocked waiting for all its peer probe
/// operators to finish probing before notifying the build operators to build
/// the next hash table from the previously spilled data.
kWaitForJoinProbe,
kWaitForMemory,
kWaitForConnector,
/// Build operator is blocked waiting for all its peers to stop to run group
/// spill on all of them.
kWaitForSpill,
};
class Driver;
class Operator;
class Task;
class BlockingState {
public:
BlockingState(std::shared_ptr<Driver> driver,
ContinueFuture&& future,
Operator* op,
BlockingReason reason)
: driver_(std::move(driver)),
future_(std::move(future)),
operator_(op),
reason_(reason) {
num_blocked_drivers_++;
}
~BlockingState() {
num_blocked_drivers_--;
}
static void
SetResume(std::shared_ptr<BlockingState> state) {
}
Operator*
op() {
return operator_;
}
BlockingReason
reason() {
return reason_;
}
// Moves out the blocking future stored inside. Can be called only once. Used
// in single-threaded execution.
ContinueFuture
future() {
return std::move(future_);
}
// Returns total number of drivers process wide that are currently in blocked
// state.
static uint64_t
get_num_blocked_drivers() {
return num_blocked_drivers_;
}
private:
std::shared_ptr<Driver> driver_;
ContinueFuture future_;
Operator* operator_;
BlockingReason reason_;
static std::atomic_uint64_t num_blocked_drivers_;
};
struct DriverContext {
int driverid_;
int pipelineid_;
uint32_t split_groupid_;
uint32_t partitionid_;
std::shared_ptr<Task> task_;
Driver* driver_;
explicit DriverContext(std::shared_ptr<Task> task,
int driverid,
int pipelineid,
uint32_t split_group_id,
uint32_t partition_id)
: driverid_(driverid),
pipelineid_(pipelineid),
split_groupid_(split_group_id),
partitionid_(partition_id),
task_(task) {
}
std::shared_ptr<QueryConfig>
GetQueryConfig();
};
using OperatorSupplier = std::function<std::unique_ptr<Operator>(
int32_t operatorid, DriverContext* ctx)>;
struct DriverFactory {
std::vector<std::shared_ptr<const plan::PlanNode>> plannodes_;
OperatorSupplier consumer_supplier_;
// The (local) node that will consume results supplied by this pipeline.
// Can be null. We use that to determine the max drivers.
std::shared_ptr<const plan::PlanNode> consumer_node_;
uint32_t max_drivers_;
uint32_t num_drivers_;
uint32_t num_total_drivers_;
bool is_group_execution_;
bool is_input_driver_;
bool is_output_driver_;
std::shared_ptr<Driver>
CreateDriver(std::unique_ptr<DriverContext> ctx,
// TODO: support exchange function
// std::shared_ptr<ExchangeClient> exchange_client,
std::function<int(int pipelineid)> num_driver);
// TODO: support distributed compute
bool
SupportSingleThreadExecution() const {
return true;
}
};
class Driver : public std::enable_shared_from_this<Driver> {
public:
static void
Enqueue(std::shared_ptr<Driver> instance);
RowVectorPtr
Next(std::shared_ptr<BlockingState>& blocking_state);
DriverContext*
get_driver_context() const {
return ctx_.get();
}
const std::shared_ptr<Task>&
get_task() const {
return ctx_->task_;
}
BlockingReason
GetBlockingReason() const {
return blocking_reason_;
}
void
Init(std::unique_ptr<DriverContext> driver_ctx,
std::vector<std::unique_ptr<Operator>> operators);
private:
Driver() = default;
void
EnqueueInternal() {
}
static void
Run(std::shared_ptr<Driver> self);
StopReason
RunInternal(std::shared_ptr<Driver>& self,
std::shared_ptr<BlockingState>& blocking_state,
RowVectorPtr& result);
void
Close();
std::unique_ptr<DriverContext> ctx_;
std::atomic_bool closed_{false};
std::vector<std::unique_ptr<Operator>> operators_;
size_t current_operator_index_{0};
BlockingReason blocking_reason_{BlockingReason::kNotBlocked};
friend struct DriverFactory;
};
using Consumer = std::function<BlockingReason(RowVectorPtr, ContinueFuture*)>;
using ConsumerSupplier = std::function<Consumer()>;
class LocalPlanner {
public:
static void
Plan(const plan::PlanFragment& fragment,
ConsumerSupplier consumer_supplier,
std::vector<std::unique_ptr<DriverFactory>>* driver_factories,
const QueryConfig& config,
uint32_t max_drivers);
};
} // namespace exec
} // namespace milvus
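
A sketch of splitting a plan fragment into driver factories with LocalPlanner (fragment and config are placeholders, assumed to be built elsewhere):

void PlanSketch(const milvus::plan::PlanFragment& fragment,
                const milvus::exec::QueryConfig& config) {
    std::vector<std::unique_ptr<milvus::exec::DriverFactory>> factories;
    milvus::exec::LocalPlanner::Plan(
        fragment, /*consumer_supplier=*/nullptr, &factories, config, /*max_drivers=*/1);
    // factories[0] is marked as the output driver; each factory later creates
    // its drivers via DriverFactory::CreateDriver.
}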

View File

@ -0,0 +1,257 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include <folly/Executor.h>
#include <folly/executors/CPUThreadPoolExecutor.h>
#include <folly/Optional.h>
#include "common/Common.h"
#include "common/Types.h"
#include "common/Exception.h"
#include "segcore/SegmentInterface.h"
namespace milvus {
namespace exec {
enum class ContextScope { GLOBAL = 0, SESSION = 1, QUERY = 2, Executor = 3 };
class BaseConfig {
public:
virtual folly::Optional<std::string>
Get(const std::string& key) const = 0;
template <typename T>
folly::Optional<T>
Get(const std::string& key) const {
auto val = Get(key);
if (val.hasValue()) {
return folly::to<T>(val.value());
} else {
return folly::none;
}
}
template <typename T>
T
Get(const std::string& key, const T& default_value) const {
auto val = Get(key);
if (val.hasValue()) {
return folly::to<T>(val.value());
} else {
return default_value;
}
}
virtual bool
IsValueExists(const std::string& key) const = 0;
virtual const std::unordered_map<std::string, std::string>&
values() const {
throw NotImplementedException("method values() is not supported");
}
virtual ~BaseConfig() = default;
};
class MemConfig : public BaseConfig {
public:
explicit MemConfig(
const std::unordered_map<std::string, std::string>& values)
: values_(values) {
}
explicit MemConfig() : values_{} {
}
explicit MemConfig(std::unordered_map<std::string, std::string>&& values)
: values_(std::move(values)) {
}
folly::Optional<std::string>
Get(const std::string& key) const override {
folly::Optional<std::string> val;
auto it = values_.find(key);
if (it != values_.end()) {
val = it->second;
}
return val;
}
bool
IsValueExists(const std::string& key) const override {
return values_.find(key) != values_.end();
}
const std::unordered_map<std::string, std::string>&
values() const override {
return values_;
}
private:
std::unordered_map<std::string, std::string> values_;
};
class QueryConfig : public MemConfig {
public:
// Whether to use the simplified expression evaluation path. False by default.
static constexpr const char* kExprEvalSimplified =
"expression.eval_simplified";
static constexpr const char* kExprEvalBatchSize =
"expression.eval_batch_size";
QueryConfig(const std::unordered_map<std::string, std::string>& values)
: MemConfig(values) {
}
QueryConfig() = default;
bool
get_expr_eval_simplified() const {
return BaseConfig::Get<bool>(kExprEvalSimplified, false);
}
int64_t
get_expr_batch_size() const {
return BaseConfig::Get<int64_t>(kExprEvalBatchSize,
EXEC_EVAL_EXPR_BATCH_SIZE);
}
};
class Context {
public:
explicit Context(ContextScope scope,
const std::shared_ptr<const Context> parent = nullptr)
: scope_(scope), parent_(parent) {
}
ContextScope
scope() const {
return scope_;
}
std::shared_ptr<const Context>
parent() const {
return parent_;
}
// // TODO: support dynamic update
// void
// set_config(const std::shared_ptr<const Config>& config) {
// std::atomic_exchange(&config_, config);
// }
// std::shared_ptr<const config>
// get_config() {
// return config_;
// }
private:
ContextScope scope_;
std::shared_ptr<const Context> parent_;
//std::shared_ptr<const Config> config_;
};
class QueryContext : public Context {
public:
QueryContext(const std::string& query_id,
const milvus::segcore::SegmentInternalInterface* segment,
milvus::Timestamp timestamp,
std::shared_ptr<QueryConfig> query_config =
std::make_shared<QueryConfig>(),
folly::Executor* executor = nullptr,
std::unordered_map<std::string, std::shared_ptr<Config>>
connector_configs = {})
: Context(ContextScope::QUERY),
query_id_(query_id),
segment_(segment),
query_timestamp_(timestamp),
query_config_(query_config),
executor_(executor) {
}
folly::Executor*
executor() const {
return executor_;
}
const std::unordered_map<std::string, std::shared_ptr<Config>>&
connector_configs() const {
return connector_configs_;
}
std::shared_ptr<QueryConfig>
query_config() const {
return query_config_;
}
std::string
query_id() const {
return query_id_;
}
const milvus::segcore::SegmentInternalInterface*
get_segment() {
return segment_;
}
milvus::Timestamp
get_query_timestamp() {
return query_timestamp_;
}
private:
folly::Executor* executor_;
//folly::Executor::KeepAlive<> executor_keepalive_;
std::unordered_map<std::string, std::shared_ptr<Config>> connector_configs_;
std::shared_ptr<QueryConfig> query_config_;
std::string query_id_;
// the current segment that the query executes on
const milvus::segcore::SegmentInternalInterface* segment_;
// timestamp this query generate
milvus::Timestamp query_timestamp_;
};
// Represent the state of one thread of query execution.
// TODO: add more class member such as memory pool
class ExecContext : public Context {
public:
ExecContext(QueryContext* query_context)
: Context(ContextScope::Executor), query_context_(query_context) {
}
QueryContext*
get_query_context() const {
return query_context_;
}
std::shared_ptr<QueryConfig>
get_query_config() const {
return query_context_->query_config();
}
private:
QueryContext* query_context_;
};
} // namespace exec
} // namespace milvus
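
A short sketch of overriding the expression batch size through QueryConfig (values chosen by the editor):

void QueryConfigSketch() {
    std::unordered_map<std::string, std::string> values{
        {milvus::exec::QueryConfig::kExprEvalBatchSize, "4096"}};
    auto config = std::make_shared<milvus::exec::QueryConfig>(values);
    // Falls back to EXEC_EVAL_EXPR_BATCH_SIZE (default 8192) when the key is unset.
    int64_t batch_size = config->get_expr_batch_size();  // 4096
}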

View File

@ -0,0 +1,230 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "Task.h"
#include <boost/lexical_cast.hpp>
#include <boost/uuid/uuid_generators.hpp>
#include <boost/uuid/uuid_io.hpp>
namespace milvus {
namespace exec {
// Special group id to reflect the ungrouped execution.
constexpr uint32_t kUngroupedGroupId{std::numeric_limits<uint32_t>::max()};
std::string
MakeUuid() {
return boost::lexical_cast<std::string>(boost::uuids::random_generator()());
}
std::shared_ptr<Task>
Task::Create(const std::string& task_id,
plan::PlanFragment plan_fragment,
int destination,
std::shared_ptr<QueryContext> query_context,
Consumer consumer,
std::function<void(std::exception_ptr)> on_error) {
return Task::Create(task_id,
std::move(plan_fragment),
destination,
std::move(query_context),
(consumer ? [c = std::move(consumer)]() { return c; }
: ConsumerSupplier{}),
std::move(on_error));
}
std::shared_ptr<Task>
Task::Create(const std::string& task_id,
const plan::PlanFragment& plan_fragment,
int destination,
std::shared_ptr<QueryContext> query_ctx,
ConsumerSupplier supplier,
std::function<void(std::exception_ptr)> on_error) {
return std::shared_ptr<Task>(new Task(task_id,
std::move(plan_fragment),
destination,
std::move(query_ctx),
std::move(supplier),
std::move(on_error)));
}
void
Task::SetError(const std::exception_ptr& exception) {
{
std::lock_guard<std::mutex> l(mutex_);
if (!IsRunningLocked()) {
return;
}
if (exception_ != nullptr) {
return;
}
exception_ = exception;
}
Terminate(TaskState::kFailed);
if (on_error_) {
on_error_(exception_);
}
}
void
Task::SetError(const std::string& message) {
try {
throw std::runtime_error(message);
} catch (const std::runtime_error& e) {
SetError(std::current_exception());
}
}
void
Task::CreateDriversLocked(std::shared_ptr<Task>& self,
uint32_t split_group_id,
std::vector<std::shared_ptr<Driver>>& out) {
const bool is_group_execution_drivers =
(split_group_id != kUngroupedGroupId);
const auto num_pipelines = driver_factories_.size();
for (auto pipeline = 0; pipeline < num_pipelines; ++pipeline) {
auto& factory = driver_factories_[pipeline];
if (factory->is_group_execution_ != is_group_execution_drivers) {
continue;
}
const uint32_t driverid_offset =
factory->num_drivers_ *
(is_group_execution_drivers ? split_group_id : 0);
for (uint32_t partition_id = 0; partition_id < factory->num_drivers_;
++partition_id) {
out.emplace_back(factory->CreateDriver(
std::make_unique<DriverContext>(self,
driverid_offset + partition_id,
pipeline,
split_group_id,
partition_id),
[self](size_t i) {
return i < self->driver_factories_.size()
? self->driver_factories_[i]->num_total_drivers_
: 0;
}));
}
}
}
RowVectorPtr
Task::Next(ContinueFuture* future) {
// NOTE: Task::Next drives single-threaded execution
AssertInfo(plan_fragment_.execution_strategy_ ==
plan::ExecutionStrategy::kUngrouped,
"Single-threaded execution supports only ungrouped execution");
AssertInfo(state_ == TaskState::kRunning,
"Task has already finished processing.");
if (driver_factories_.empty()) {
AssertInfo(
consumer_supplier_ == nullptr,
"Single-threaded execution doesn't support delivering results to a "
"callback");
LocalPlanner::Plan(plan_fragment_,
nullptr,
&driver_factories_,
*query_context_->query_config(),
1);
for (const auto& factory : driver_factories_) {
assert(factory->SupportSingleThreadExecution());
num_ungrouped_drivers_ += factory->num_drivers_;
num_total_drivers_ += factory->num_total_drivers_;
}
auto self = shared_from_this();
std::vector<std::shared_ptr<Driver>> drivers;
drivers.reserve(num_ungrouped_drivers_);
CreateDriversLocked(self, kUngroupedGroupId, drivers);
drivers_ = std::move(drivers);
}
const auto num_drivers = drivers_.size();
std::vector<ContinueFuture> futures;
futures.resize(num_drivers);
for (;;) {
int runnable_drivers = 0;
int blocked_drivers = 0;
for (auto i = 0; i < num_drivers; ++i) {
if (drivers_[i] == nullptr) {
continue;
}
if (!futures[i].isReady()) {
++blocked_drivers;
continue;
}
++runnable_drivers;
std::shared_ptr<BlockingState> blocking_state;
auto result = drivers_[i]->Next(blocking_state);
if (result) {
return result;
}
if (blocking_state) {
futures[i] = blocking_state->future();
}
if (error()) {
std::rethrow_exception(error());
}
}
if (runnable_drivers == 0) {
if (blocked_drivers > 0) {
if (!future) {
throw ExecDriverException(
"Cannot make progress as all remaining drivers are "
"blocked and user are not expected to wait.");
} else {
std::vector<ContinueFuture> not_ready_futures;
for (auto& continue_future : futures) {
if (!continue_future.isReady()) {
not_ready_futures.emplace_back(
std::move(continue_future));
}
}
*future =
folly::collectAll(std::move(not_ready_futures)).unit();
}
}
return nullptr;
}
}
}
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,205 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "common/Types.h"
#include "exec/Driver.h"
#include "exec/QueryContext.h"
#include "plan/PlanNode.h"
namespace milvus {
namespace exec {
enum class TaskState { kRunning, kFinished, kCanceled, kAborted, kFailed };
std::string
MakeUuid();
class Task : public std::enable_shared_from_this<Task> {
public:
static std::shared_ptr<Task>
Create(const std::string& task_id,
plan::PlanFragment plan_fragment,
int destination,
std::shared_ptr<QueryContext> query_context,
Consumer consumer = nullptr,
std::function<void(std::exception_ptr)> on_error = nullptr);
static std::shared_ptr<Task>
Create(const std::string& task_id,
const plan::PlanFragment& plan_fragment,
int destination,
std::shared_ptr<QueryContext> query_ctx,
ConsumerSupplier supplier,
std::function<void(std::exception_ptr)> on_error = nullptr);
Task(const std::string& task_id,
plan::PlanFragment plan_fragment,
int destination,
std::shared_ptr<QueryContext> query_ctx,
ConsumerSupplier consumer_supplier,
std::function<void(std::exception_ptr)> on_error)
: uuid_{MakeUuid()},
taskid_(task_id),
plan_fragment_(std::move(plan_fragment)),
destination_(destination),
query_context_(std::move(query_ctx)),
consumer_supplier_(std::move(consumer_supplier)),
on_error_(on_error) {
}
~Task() {
}
const std::string&
uuid() const {
return uuid_;
}
const std::string&
taskid() const {
return taskid_;
}
const int
destination() const {
return destination_;
}
const std::shared_ptr<QueryContext>&
query_context() const {
return query_context_;
}
static void
Start(std::shared_ptr<Task> self,
uint32_t max_drivers,
uint32_t concurrent_split_groups = 1);
static void
RemoveDriver(std::shared_ptr<Task> self, Driver* instance) {
std::lock_guard<std::mutex> lock(self->mutex_);
for (auto& driver_ptr : self->drivers_) {
if (driver_ptr.get() != instance) {
continue;
}
driver_ptr = nullptr;
self->DriverClosedLocked();
}
}
bool
SupportsSingleThreadedExecution() const {
if (consumer_supplier_) {
return false;
}
return true;
}
RowVectorPtr
Next(ContinueFuture* future = nullptr);
void
CreateDriversLocked(std::shared_ptr<Task>& self,
uint32_t split_groupid,
std::vector<std::shared_ptr<Driver>>& out);
void
SetError(const std::exception_ptr& exception);
void
SetError(const std::string& message);
bool
IsRunning() const {
std::lock_guard<std::mutex> l(mutex_);
return (state_ == TaskState::kRunning);
}
bool
IsFinished() const {
std::lock_guard<std::mutex> l(mutex_);
return (state_ == TaskState::kFinished);
}
bool
IsRunningLocked() const {
return (state_ == TaskState::kRunning);
}
bool
IsFinishedLocked() const {
return (state_ == TaskState::kFinished);
}
void
Terminate(TaskState state) {
}
std::exception_ptr
error() const {
std::lock_guard<std::mutex> l(mutex_);
return exception_;
}
void
DriverClosedLocked() {
if (IsRunningLocked()) {
--num_running_drivers_;
}
num_finished_drivers_++;
}
private:
std::string uuid_;
std::string taskid_;
plan::PlanFragment plan_fragment_;
int destination_;
std::shared_ptr<QueryContext> query_context_;
std::exception_ptr exception_ = nullptr;
std::function<void(std::exception_ptr)> on_error_;
std::vector<std::unique_ptr<DriverFactory>> driver_factories_;
std::vector<std::shared_ptr<Driver>> drivers_;
ConsumerSupplier consumer_supplier_;
mutable std::mutex mutex_;
TaskState state_ = TaskState::kRunning;
uint32_t num_running_drivers_{0};
uint32_t num_total_drivers_{0};
uint32_t num_ungrouped_drivers_{0};
uint32_t num_finished_drivers_{0};
};
} // namespace exec
} // namespace milvus
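
A single-threaded execution sketch against the Task interface (plan_fragment and query_context are placeholders assumed to be built elsewhere):

void RunTaskSketch(const milvus::plan::PlanFragment& plan_fragment,
                   std::shared_ptr<milvus::exec::QueryContext> query_context) {
    auto task = milvus::exec::Task::Create(
        "hypothetical-task-id", plan_fragment, /*destination=*/0, query_context);
    // Next() pulls one RowVector batch at a time and returns nullptr at the end.
    while (auto result = task->Next()) {
        // consume the batch of filter results
    }
}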

View File

@ -0,0 +1,45 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "AlwaysTrueExpr.h"
namespace milvus {
namespace exec {
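// Produces a boolean batch in which every row evaluates to true. Each call
// emits at most batch_size_ rows, advancing current_pos_ so that successive
// calls walk the segment; a null result signals that all rows have been
// consumed.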
void
PhyAlwaysTrueExpr::Eval(EvalCtx& context, VectorPtr& result) {
int64_t real_batch_size = current_pos_ + batch_size_ >= num_rows_
? num_rows_ - current_pos_
: batch_size_;
if (real_batch_size == 0) {
result = nullptr;
return;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res_bool = (bool*)res_vec->GetRawData();
for (size_t i = 0; i < real_batch_size; ++i) {
res_bool[i] = true;
}
result = res_vec;
current_pos_ += real_batch_size;
}
} // namespace exec
} // namespace milvus


@ -0,0 +1,56 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fmt/core.h>
#include "common/EasyAssert.h"
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/expression/Expr.h"
#include "segcore/SegmentInterface.h"
namespace milvus {
namespace exec {
class PhyAlwaysTrueExpr : public Expr {
public:
PhyAlwaysTrueExpr(
const std::vector<std::shared_ptr<Expr>>& input,
const std::shared_ptr<const milvus::expr::AlwaysTrueExpr>& expr,
const std::string& name,
const segcore::SegmentInternalInterface* segment,
Timestamp query_timestamp,
int64_t batch_size)
: Expr(DataType::BOOL, std::move(input), name),
expr_(expr),
batch_size_(batch_size) {
num_rows_ = segment->get_active_count(query_timestamp);
}
void
Eval(EvalCtx& context, VectorPtr& result) override;
private:
std::shared_ptr<const milvus::expr::AlwaysTrueExpr> expr_;
int64_t num_rows_;
int64_t current_pos_{0};
int64_t batch_size_;
};
} // namespace exec
} // namespace milvus


@ -0,0 +1,748 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "BinaryArithOpEvalRangeExpr.h"
namespace milvus {
namespace exec {
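// Eval dispatches on the filtered column's data type: scalar columns go
// through the typed ExecRangeVisitorImpl<T> path, while JSON and ARRAY
// columns dispatch a second time on the literal value type carried by the
// plan, since the stored element type is only known at runtime.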
void
PhyBinaryArithOpEvalRangeExpr::Eval(EvalCtx& context, VectorPtr& result) {
switch (expr_->column_.data_type_) {
case DataType::BOOL: {
result = ExecRangeVisitorImpl<bool>();
break;
}
case DataType::INT8: {
result = ExecRangeVisitorImpl<int8_t>();
break;
}
case DataType::INT16: {
result = ExecRangeVisitorImpl<int16_t>();
break;
}
case DataType::INT32: {
result = ExecRangeVisitorImpl<int32_t>();
break;
}
case DataType::INT64: {
result = ExecRangeVisitorImpl<int64_t>();
break;
}
case DataType::FLOAT: {
result = ExecRangeVisitorImpl<float>();
break;
}
case DataType::DOUBLE: {
result = ExecRangeVisitorImpl<double>();
break;
}
case DataType::JSON: {
auto value_type = expr_->value_.val_case();
switch (value_type) {
case proto::plan::GenericValue::ValCase::kBoolVal: {
result = ExecRangeVisitorImplForJson<bool>();
break;
}
case proto::plan::GenericValue::ValCase::kInt64Val: {
result = ExecRangeVisitorImplForJson<int64_t>();
break;
}
case proto::plan::GenericValue::ValCase::kFloatVal: {
result = ExecRangeVisitorImplForJson<double>();
break;
}
default: {
PanicInfo(
DataTypeInvalid,
fmt::format("unsupported value type {} in expression",
value_type));
}
}
break;
}
case DataType::ARRAY: {
auto value_type = expr_->value_.val_case();
switch (value_type) {
case proto::plan::GenericValue::ValCase::kInt64Val: {
result = ExecRangeVisitorImplForArray<int64_t>();
break;
}
case proto::plan::GenericValue::ValCase::kFloatVal: {
result = ExecRangeVisitorImplForArray<double>();
break;
}
default: {
PanicInfo(
DataTypeInvalid,
fmt::format("unsupported value type {} in expression",
value_type));
}
}
break;
}
default:
PanicInfo(DataTypeInvalid,
fmt::format("unsupported data type: {}",
expr_->column_.data_type_));
}
}
template <typename ValueType>
VectorPtr
PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() {
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
std::string_view,
ValueType>;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
auto op_type = expr_->op_type_;
auto arith_type = expr_->arith_op_type_;
auto value = GetValueFromProto<ValueType>(expr_->value_);
auto right_operand = GetValueFromProto<ValueType>(expr_->right_operand_);
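// The two helper macros below evaluate `x <arith> right_operand <cmp> val`
// for each row, where x is looked up from the JSON document at `pointer`.
// When an int64_t lookup fails they retry the field as double before giving
// up; rows whose field is missing or has an incompatible type evaluate to
// false for Equal and to true for NotEqual.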
#define BinaryArithRangeJSONCompare(cmp) \
do { \
for (size_t i = 0; i < size; ++i) { \
auto x = data[i].template at<GetType>(pointer); \
if (x.error()) { \
if constexpr (std::is_same_v<GetType, int64_t>) { \
auto x = data[i].template at<double>(pointer); \
res[i] = !x.error() && (cmp); \
continue; \
} \
res[i] = false; \
continue; \
} \
res[i] = (cmp); \
} \
} while (false)
#define BinaryArithRangeJSONCompareNotEqual(cmp) \
do { \
for (size_t i = 0; i < size; ++i) { \
auto x = data[i].template at<GetType>(pointer); \
if (x.error()) { \
if constexpr (std::is_same_v<GetType, int64_t>) { \
auto x = data[i].template at<double>(pointer); \
res[i] = x.error() || (cmp); \
continue; \
} \
res[i] = true; \
continue; \
} \
res[i] = (cmp); \
} \
} while (false)
auto execute_sub_batch = [op_type, arith_type](const milvus::Json* data,
const int size,
bool* res,
ValueType val,
ValueType right_operand,
const std::string& pointer) {
switch (op_type) {
case proto::plan::OpType::Equal: {
switch (arith_type) {
case proto::plan::ArithOpType::Add: {
BinaryArithRangeJSONCompare(x.value() + right_operand ==
val);
break;
}
case proto::plan::ArithOpType::Sub: {
BinaryArithRangeJSONCompare(x.value() - right_operand ==
val);
break;
}
case proto::plan::ArithOpType::Mul: {
BinaryArithRangeJSONCompare(x.value() * right_operand ==
val);
break;
}
case proto::plan::ArithOpType::Div: {
BinaryArithRangeJSONCompare(x.value() / right_operand ==
val);
break;
}
case proto::plan::ArithOpType::Mod: {
BinaryArithRangeJSONCompare(
static_cast<ValueType>(
fmod(x.value(), right_operand)) == val);
break;
}
case proto::plan::ArithOpType::ArrayLength: {
for (size_t i = 0; i < size; ++i) {
int array_length = 0;
auto doc = data[i].doc();
auto array = doc.at_pointer(pointer).get_array();
if (!array.error()) {
array_length = array.count_elements();
}
res[i] = array_length == val;
}
break;
}
default:
PanicInfo(
OpTypeInvalid,
fmt::format("unsupported arith type for binary "
"arithmetic eval expr: {}",
arith_type));
}
break;
}
case proto::plan::OpType::NotEqual: {
switch (arith_type) {
case proto::plan::ArithOpType::Add: {
BinaryArithRangeJSONCompareNotEqual(
x.value() + right_operand != val);
break;
}
case proto::plan::ArithOpType::Sub: {
BinaryArithRangeJSONCompareNotEqual(
x.value() - right_operand != val);
break;
}
case proto::plan::ArithOpType::Mul: {
BinaryArithRangeJSONCompareNotEqual(
x.value() * right_operand != val);
break;
}
case proto::plan::ArithOpType::Div: {
BinaryArithRangeJSONCompareNotEqual(
x.value() / right_operand != val);
break;
}
case proto::plan::ArithOpType::Mod: {
BinaryArithRangeJSONCompareNotEqual(
static_cast<ValueType>(
fmod(x.value(), right_operand)) != val);
break;
}
case proto::plan::ArithOpType::ArrayLength: {
for (size_t i = 0; i < size; ++i) {
int array_length = 0;
auto doc = data[i].doc();
auto array = doc.at_pointer(pointer).get_array();
if (!array.error()) {
array_length = array.count_elements();
}
res[i] = array_length != val;
}
break;
}
default:
PanicInfo(
OpTypeInvalid,
fmt::format("unsupported arith type for binary "
"arithmetic eval expr: {}",
arith_type));
}
break;
}
default:
PanicInfo(OpTypeInvalid,
fmt::format("unsupported operator type for binary "
"arithmetic eval expr: {}",
op_type));
}
};
int64_t processed_size = ProcessDataChunks<milvus::Json>(execute_sub_batch,
std::nullptr_t{},
res,
value,
right_operand,
pointer);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
template <typename ValueType>
VectorPtr
PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() {
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
std::string_view,
ValueType>;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
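// For array columns the first element of nested_path_, when present, selects
// the element offset to compare; rows whose array is shorter than that offset
// evaluate to false (see BinaryArithRangeArrayCompare below).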
int index = -1;
if (expr_->column_.nested_path_.size() > 0) {
index = std::stoi(expr_->column_.nested_path_[0]);
}
auto op_type = expr_->op_type_;
auto arith_type = expr_->arith_op_type_;
auto value = GetValueFromProto<ValueType>(expr_->value_);
auto right_operand =
arith_type != proto::plan::ArithOpType::ArrayLength
? GetValueFromProto<ValueType>(expr_->right_operand_)
: ValueType();
#define BinaryArithRangeArrayCompare(cmp) \
do { \
for (size_t i = 0; i < size; ++i) { \
if (index >= data[i].length()) { \
res[i] = false; \
continue; \
} \
auto value = data[i].get_data<GetType>(index); \
res[i] = (cmp); \
} \
} while (false)
auto execute_sub_batch = [op_type, arith_type](const ArrayView* data,
const int size,
bool* res,
ValueType val,
ValueType right_operand,
int index) {
switch (op_type) {
case proto::plan::OpType::Equal: {
switch (arith_type) {
case proto::plan::ArithOpType::Add: {
BinaryArithRangeArrayCompare(value + right_operand ==
val);
break;
}
case proto::plan::ArithOpType::Sub: {
BinaryArithRangeArrayCompare(value - right_operand ==
val);
break;
}
case proto::plan::ArithOpType::Mul: {
BinaryArithRangeArrayCompare(value * right_operand ==
val);
break;
}
case proto::plan::ArithOpType::Div: {
BinaryArithRangeArrayCompare(value / right_operand ==
val);
break;
}
case proto::plan::ArithOpType::Mod: {
BinaryArithRangeArrayCompare(
static_cast<ValueType>(
fmod(value, right_operand)) == val);
break;
}
case proto::plan::ArithOpType::ArrayLength: {
for (size_t i = 0; i < size; ++i) {
res[i] = data[i].length() == val;
}
break;
}
default:
PanicInfo(
OpTypeInvalid,
fmt::format("unsupported arith type for binary "
"arithmetic eval expr: {}",
arith_type));
}
break;
}
case proto::plan::OpType::NotEqual: {
switch (arith_type) {
case proto::plan::ArithOpType::Add: {
BinaryArithRangeArrayCompare(value + right_operand !=
val);
break;
}
case proto::plan::ArithOpType::Sub: {
BinaryArithRangeArrayCompare(value - right_operand !=
val);
break;
}
case proto::plan::ArithOpType::Mul: {
BinaryArithRangeArrayCompare(value * right_operand !=
val);
break;
}
case proto::plan::ArithOpType::Div: {
BinaryArithRangeArrayCompare(value / right_operand !=
val);
break;
}
case proto::plan::ArithOpType::Mod: {
BinaryArithRangeArrayCompare(
static_cast<ValueType>(
fmod(value, right_operand)) != val);
break;
}
case proto::plan::ArithOpType::ArrayLength: {
for (size_t i = 0; i < size; ++i) {
res[i] = data[i].length() != val;
}
break;
}
default:
PanicInfo(
OpTypeInvalid,
fmt::format("unsupported arith type for binary "
"arithmetic eval expr: {}",
arith_type));
}
break;
}
default:
PanicInfo(OpTypeInvalid,
fmt::format("unsupported operator type for binary "
"arithmetic eval expr: {}",
op_type));
}
};
int64_t processed_size = ProcessDataChunks<milvus::ArrayView>(
execute_sub_batch, std::nullptr_t{}, res, value, right_operand, index);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
template <typename T>
VectorPtr
PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImpl() {
if (is_index_mode_) {
return ExecRangeVisitorImplForIndex<T>();
} else {
return ExecRangeVisitorImplForData<T>();
}
}
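// Index path: values are reconstructed chunk by chunk through
// ScalarIndex<T>::Reverse_Lookup and compared with a compile-time specialized
// ArithOpIndexFunc, so the operator/arith-op dispatch happens once per batch
// rather than once per row.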
template <typename T>
VectorPtr
PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForIndex() {
using Index = index::ScalarIndex<T>;
typedef std::conditional_t<std::is_integral_v<T> &&
!std::is_same_v<bool, T>,
int64_t,
T>
HighPrecisionType;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto value = GetValueFromProto<HighPrecisionType>(expr_->value_);
auto right_operand =
GetValueFromProto<HighPrecisionType>(expr_->right_operand_);
auto op_type = expr_->op_type_;
auto arith_type = expr_->arith_op_type_;
auto sub_batch_size = size_per_chunk_;
auto execute_sub_batch = [op_type, arith_type, sub_batch_size](
Index* index_ptr,
HighPrecisionType value,
HighPrecisionType right_operand) {
FixedVector<bool> res;
switch (op_type) {
case proto::plan::OpType::Equal: {
switch (arith_type) {
case proto::plan::ArithOpType::Add: {
ArithOpIndexFunc<T,
proto::plan::OpType::Equal,
proto::plan::ArithOpType::Add>
func;
res = std::move(func(
index_ptr, sub_batch_size, value, right_operand));
break;
}
case proto::plan::ArithOpType::Sub: {
ArithOpIndexFunc<T,
proto::plan::OpType::Equal,
proto::plan::ArithOpType::Sub>
func;
res = std::move(func(
index_ptr, sub_batch_size, value, right_operand));
break;
}
case proto::plan::ArithOpType::Mul: {
ArithOpIndexFunc<T,
proto::plan::OpType::Equal,
proto::plan::ArithOpType::Mul>
func;
res = std::move(func(
index_ptr, sub_batch_size, value, right_operand));
break;
}
case proto::plan::ArithOpType::Div: {
ArithOpIndexFunc<T,
proto::plan::OpType::Equal,
proto::plan::ArithOpType::Div>
func;
res = std::move(func(
index_ptr, sub_batch_size, value, right_operand));
break;
}
case proto::plan::ArithOpType::Mod: {
ArithOpIndexFunc<T,
proto::plan::OpType::Equal,
proto::plan::ArithOpType::Mod>
func;
res = std::move(func(
index_ptr, sub_batch_size, value, right_operand));
break;
}
default:
PanicInfo(
OpTypeInvalid,
fmt::format("unsupported arith type for binary "
"arithmetic eval expr: {}",
arith_type));
}
break;
}
case proto::plan::OpType::NotEqual: {
switch (arith_type) {
case proto::plan::ArithOpType::Add: {
ArithOpIndexFunc<T,
proto::plan::OpType::NotEqual,
proto::plan::ArithOpType::Add>
func;
res = std::move(func(
index_ptr, sub_batch_size, value, right_operand));
break;
}
case proto::plan::ArithOpType::Sub: {
ArithOpIndexFunc<T,
proto::plan::OpType::NotEqual,
proto::plan::ArithOpType::Sub>
func;
res = std::move(func(
index_ptr, sub_batch_size, value, right_operand));
break;
}
case proto::plan::ArithOpType::Mul: {
ArithOpIndexFunc<T,
proto::plan::OpType::NotEqual,
proto::plan::ArithOpType::Mul>
func;
res = std::move(func(
index_ptr, sub_batch_size, value, right_operand));
break;
}
case proto::plan::ArithOpType::Div: {
ArithOpIndexFunc<T,
proto::plan::OpType::NotEqual,
proto::plan::ArithOpType::Div>
func;
res = std::move(func(
index_ptr, sub_batch_size, value, right_operand));
break;
}
case proto::plan::ArithOpType::Mod: {
ArithOpIndexFunc<T,
proto::plan::OpType::NotEqual,
proto::plan::ArithOpType::Mod>
func;
res = std::move(func(
index_ptr, sub_batch_size, value, right_operand));
break;
}
default:
PanicInfo(
OpTypeInvalid,
fmt::format("unsupported arith type for binary "
"arithmetic eval expr: {}",
arith_type));
}
break;
}
default:
PanicInfo(OpTypeInvalid,
fmt::format("unsupported operator type for binary "
"arithmetic eval expr: {}",
op_type));
}
return res;
};
auto res = ProcessIndexChunks<T>(execute_sub_batch, value, right_operand);
AssertInfo(res.size() == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
res.size(),
real_batch_size));
return std::make_shared<ColumnVector>(std::move(res));
}
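// Raw-data path: integral operands are widened to int64_t (HighPrecisionType)
// so arithmetic on the literal does not overflow the narrower column type,
// then ProcessDataChunks applies the specialized element functor to each
// chunk slice of the current batch.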
template <typename T>
VectorPtr
PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() {
typedef std::conditional_t<std::is_integral_v<T> &&
!std::is_same_v<bool, T>,
int64_t,
T>
HighPrecisionType;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto value = GetValueFromProto<HighPrecisionType>(expr_->value_);
auto right_operand =
GetValueFromProto<HighPrecisionType>(expr_->right_operand_);
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
auto op_type = expr_->op_type_;
auto arith_type = expr_->arith_op_type_;
auto execute_sub_batch = [op_type, arith_type](
const T* data,
const int size,
bool* res,
HighPrecisionType value,
HighPrecisionType right_operand) {
switch (op_type) {
case proto::plan::OpType::Equal: {
switch (arith_type) {
case proto::plan::ArithOpType::Add: {
ArithOpElementFunc<T,
proto::plan::OpType::Equal,
proto::plan::ArithOpType::Add>
func;
func(data, size, value, right_operand, res);
break;
}
case proto::plan::ArithOpType::Sub: {
ArithOpElementFunc<T,
proto::plan::OpType::Equal,
proto::plan::ArithOpType::Sub>
func;
func(data, size, value, right_operand, res);
break;
}
case proto::plan::ArithOpType::Mul: {
ArithOpElementFunc<T,
proto::plan::OpType::Equal,
proto::plan::ArithOpType::Mul>
func;
func(data, size, value, right_operand, res);
break;
}
case proto::plan::ArithOpType::Div: {
ArithOpElementFunc<T,
proto::plan::OpType::Equal,
proto::plan::ArithOpType::Div>
func;
func(data, size, value, right_operand, res);
break;
}
case proto::plan::ArithOpType::Mod: {
ArithOpElementFunc<T,
proto::plan::OpType::Equal,
proto::plan::ArithOpType::Mod>
func;
func(data, size, value, right_operand, res);
break;
}
default:
PanicInfo(
OpTypeInvalid,
fmt::format("unsupported arith type for binary "
"arithmetic eval expr: {}",
arith_type));
}
break;
}
case proto::plan::OpType::NotEqual: {
switch (arith_type) {
case proto::plan::ArithOpType::Add: {
ArithOpElementFunc<T,
proto::plan::OpType::NotEqual,
proto::plan::ArithOpType::Add>
func;
func(data, size, value, right_operand, res);
break;
}
case proto::plan::ArithOpType::Sub: {
ArithOpElementFunc<T,
proto::plan::OpType::NotEqual,
proto::plan::ArithOpType::Sub>
func;
func(data, size, value, right_operand, res);
break;
}
case proto::plan::ArithOpType::Mul: {
ArithOpElementFunc<T,
proto::plan::OpType::NotEqual,
proto::plan::ArithOpType::Mul>
func;
func(data, size, value, right_operand, res);
break;
}
case proto::plan::ArithOpType::Div: {
ArithOpElementFunc<T,
proto::plan::OpType::NotEqual,
proto::plan::ArithOpType::Div>
func;
func(data, size, value, right_operand, res);
break;
}
case proto::plan::ArithOpType::Mod: {
ArithOpElementFunc<T,
proto::plan::OpType::NotEqual,
proto::plan::ArithOpType::Mod>
func;
func(data, size, value, right_operand, res);
break;
}
default:
PanicInfo(
OpTypeInvalid,
fmt::format("unsupported arith type for binary "
"arithmetic eval expr: {}",
arith_type));
}
break;
}
default:
PanicInfo(OpTypeInvalid,
fmt::format("unsupported operator type for binary "
"arithmetic eval expr: {}",
op_type));
}
};
int64_t processed_size = ProcessDataChunks<T>(
execute_sub_batch, std::nullptr_t{}, res, value, right_operand);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
} //namespace exec
} // namespace milvus


@ -0,0 +1,213 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <fmt/core.h>
#include "common/EasyAssert.h"
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/expression/Expr.h"
#include "segcore/SegmentInterface.h"
namespace milvus {
namespace exec {
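// ArithOpElementFunc and ArithOpIndexFunc fix the comparison operator and the
// arithmetic operator as template parameters, so the per-row loop body is
// resolved at compile time; the element variant reads raw chunk data while
// the index variant reverse-looks-up values from a scalar index.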
template <typename T,
proto::plan::OpType cmp_op,
proto::plan::ArithOpType arith_op>
struct ArithOpElementFunc {
typedef std::conditional_t<std::is_integral_v<T> &&
!std::is_same_v<bool, T>,
int64_t,
T>
HighPrecisonType;
void
operator()(const T* src,
size_t size,
HighPrecisonType val,
HighPrecisonType right_operand,
bool* res) {
for (int i = 0; i < size; ++i) {
if constexpr (cmp_op == proto::plan::OpType::Equal) {
if constexpr (arith_op == proto::plan::ArithOpType::Add) {
res[i] = (src[i] + right_operand) == val;
} else if constexpr (arith_op ==
proto::plan::ArithOpType::Sub) {
res[i] = (src[i] - right_operand) == val;
} else if constexpr (arith_op ==
proto::plan::ArithOpType::Mul) {
res[i] = (src[i] * right_operand) == val;
} else if constexpr (arith_op ==
proto::plan::ArithOpType::Div) {
res[i] = (src[i] / right_operand) == val;
} else if constexpr (arith_op ==
proto::plan::ArithOpType::Mod) {
res[i] = (fmod(src[i], right_operand)) == val;
} else {
PanicInfo(
OpTypeInvalid,
fmt::format(
"unsupported arith type:{} for ArithOpElementFunc",
arith_op));
}
} else if constexpr (cmp_op == proto::plan::OpType::NotEqual) {
if constexpr (arith_op == proto::plan::ArithOpType::Add) {
res[i] = (src[i] + right_operand) != val;
} else if constexpr (arith_op ==
proto::plan::ArithOpType::Sub) {
res[i] = (src[i] - right_operand) != val;
} else if constexpr (arith_op ==
proto::plan::ArithOpType::Mul) {
res[i] = (src[i] * right_operand) != val;
} else if constexpr (arith_op ==
proto::plan::ArithOpType::Div) {
res[i] = (src[i] / right_operand) != val;
} else if constexpr (arith_op ==
proto::plan::ArithOpType::Mod) {
res[i] = (fmod(src[i], right_operand)) != val;
} else {
PanicInfo(
OpTypeInvalid,
fmt::format(
"unsupported arith type:{} for ArithOpElementFunc",
arith_op));
}
}
}
}
};
template <typename T,
proto::plan::OpType cmp_op,
proto::plan::ArithOpType arith_op>
struct ArithOpIndexFunc {
typedef std::conditional_t<std::is_integral_v<T> &&
!std::is_same_v<bool, T>,
int64_t,
T>
HighPrecisonType;
using Index = index::ScalarIndex<T>;
FixedVector<bool>
operator()(Index* index,
size_t size,
HighPrecisonType val,
HighPrecisonType right_operand) {
FixedVector<bool> res_vec(size);
bool* res = res_vec.data();
for (size_t i = 0; i < size; ++i) {
if constexpr (cmp_op == proto::plan::OpType::Equal) {
if constexpr (arith_op == proto::plan::ArithOpType::Add) {
res[i] = (index->Reverse_Lookup(i) + right_operand) == val;
} else if constexpr (arith_op ==
proto::plan::ArithOpType::Sub) {
res[i] = (index->Reverse_Lookup(i) - right_operand) == val;
} else if constexpr (arith_op ==
proto::plan::ArithOpType::Mul) {
res[i] = (index->Reverse_Lookup(i) * right_operand) == val;
} else if constexpr (arith_op ==
proto::plan::ArithOpType::Div) {
res[i] = (index->Reverse_Lookup(i) / right_operand) == val;
} else if constexpr (arith_op ==
proto::plan::ArithOpType::Mod) {
res[i] =
(fmod(index->Reverse_Lookup(i), right_operand)) == val;
} else {
PanicInfo(
OpTypeInvalid,
fmt::format(
"unsupported arith type:{} for ArithOpElementFunc",
arith_op));
}
} else if constexpr (cmp_op == proto::plan::OpType::NotEqual) {
if constexpr (arith_op == proto::plan::ArithOpType::Add) {
res[i] = (index->Reverse_Lookup(i) + right_operand) != val;
} else if constexpr (arith_op ==
proto::plan::ArithOpType::Sub) {
res[i] = (index->Reverse_Lookup(i) - right_operand) != val;
} else if constexpr (arith_op ==
proto::plan::ArithOpType::Mul) {
res[i] = (index->Reverse_Lookup(i) * right_operand) != val;
} else if constexpr (arith_op ==
proto::plan::ArithOpType::Div) {
res[i] = (index->Reverse_Lookup(i) / right_operand) != val;
} else if constexpr (arith_op ==
proto::plan::ArithOpType::Mod) {
res[i] =
(fmod(index->Reverse_Lookup(i), right_operand)) != val;
} else {
PanicInfo(
OpTypeInvalid,
fmt::format(
"unsupported arith type:{} for ArithOpElementFunc",
arith_op));
}
}
}
return res_vec;
}
};
class PhyBinaryArithOpEvalRangeExpr : public SegmentExpr {
public:
PhyBinaryArithOpEvalRangeExpr(
const std::vector<std::shared_ptr<Expr>>& input,
const std::shared_ptr<const milvus::expr::BinaryArithOpEvalRangeExpr>&
expr,
const std::string& name,
const segcore::SegmentInternalInterface* segment,
Timestamp query_timestamp,
int64_t batch_size)
: SegmentExpr(std::move(input),
name,
segment,
expr->column_.field_id_,
query_timestamp,
batch_size),
expr_(expr) {
}
void
Eval(EvalCtx& context, VectorPtr& result) override;
private:
template <typename T>
VectorPtr
ExecRangeVisitorImpl();
template <typename T>
VectorPtr
ExecRangeVisitorImplForIndex();
template <typename T>
VectorPtr
ExecRangeVisitorImplForData();
template <typename ValueType>
VectorPtr
ExecRangeVisitorImplForJson();
template <typename ValueType>
VectorPtr
ExecRangeVisitorImplForArray();
private:
std::shared_ptr<const milvus::expr::BinaryArithOpEvalRangeExpr> expr_;
};
} // namespace exec
} // namespace milvus
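// Illustrative only: evaluating "(x + 3) == 8" over a raw int32 chunk with
// the element functor above; chunk_data, chunk_size and res are hypothetical.
//
//   ArithOpElementFunc<int32_t,
//                      proto::plan::OpType::Equal,
//                      proto::plan::ArithOpType::Add>
//       func;
//   func(chunk_data, chunk_size, /*val=*/8, /*right_operand=*/3, res);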


@ -0,0 +1,392 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "BinaryRangeExpr.h"
#include "query/Utils.h"
namespace milvus {
namespace exec {
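// Eval dispatches on the column type of the range filter. VARCHAR columns use
// std::string on growing segments and std::string_view on sealed ones; JSON
// and ARRAY columns dispatch a second time on the type of the lower-bound
// literal.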
void
PhyBinaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
switch (expr_->column_.data_type_) {
case DataType::BOOL: {
result = ExecRangeVisitorImpl<bool>();
break;
}
case DataType::INT8: {
result = ExecRangeVisitorImpl<int8_t>();
break;
}
case DataType::INT16: {
result = ExecRangeVisitorImpl<int16_t>();
break;
}
case DataType::INT32: {
result = ExecRangeVisitorImpl<int32_t>();
break;
}
case DataType::INT64: {
result = ExecRangeVisitorImpl<int64_t>();
break;
}
case DataType::FLOAT: {
result = ExecRangeVisitorImpl<float>();
break;
}
case DataType::DOUBLE: {
result = ExecRangeVisitorImpl<double>();
break;
}
case DataType::VARCHAR: {
if (segment_->type() == SegmentType::Growing) {
result = ExecRangeVisitorImpl<std::string>();
} else {
result = ExecRangeVisitorImpl<std::string_view>();
}
break;
}
case DataType::JSON: {
auto value_type = expr_->lower_val_.val_case();
switch (value_type) {
case proto::plan::GenericValue::ValCase::kInt64Val: {
result = ExecRangeVisitorImplForJson<int64_t>();
break;
}
case proto::plan::GenericValue::ValCase::kFloatVal: {
result = ExecRangeVisitorImplForJson<double>();
break;
}
case proto::plan::GenericValue::ValCase::kStringVal: {
result = ExecRangeVisitorImplForJson<std::string>();
break;
}
default: {
PanicInfo(
DataTypeInvalid,
fmt::format("unsupported value type {} in expression",
value_type));
}
}
break;
}
case DataType::ARRAY: {
auto value_type = expr_->lower_val_.val_case();
switch (value_type) {
case proto::plan::GenericValue::ValCase::kInt64Val: {
result = ExecRangeVisitorImplForArray<int64_t>();
break;
}
case proto::plan::GenericValue::ValCase::kFloatVal: {
result = ExecRangeVisitorImplForArray<double>();
break;
}
case proto::plan::GenericValue::ValCase::kStringVal: {
result = ExecRangeVisitorImplForArray<std::string>();
break;
}
default: {
PanicInfo(
DataTypeInvalid,
fmt::format("unsupported value type {} in expression",
value_type));
}
}
break;
}
default:
PanicInfo(DataTypeInvalid,
fmt::format("unsupported data type: {}",
expr_->column_.data_type_));
}
}
template <typename T>
VectorPtr
PhyBinaryRangeFilterExpr::ExecRangeVisitorImpl() {
if (is_index_mode_) {
return ExecRangeVisitorImplForIndex<T>();
} else {
return ExecRangeVisitorImplForData<T>();
}
}
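// PreCheckOverflow clamps range bounds that fall outside the value range of
// the column type T. If the requested range cannot contain any value of T at
// all, it short-circuits by returning a result vector covering the next batch
// (reusing cached_overflow_res_ when the batch size matches) so the data or
// index scan is skipped entirely.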
template <typename T, typename IndexInnerType, typename HighPrecisionType>
ColumnVectorPtr
PhyBinaryRangeFilterExpr::PreCheckOverflow(HighPrecisionType& val1,
HighPrecisionType& val2,
bool& lower_inclusive,
bool& upper_inclusive) {
lower_inclusive = expr_->lower_inclusive_;
upper_inclusive = expr_->upper_inclusive_;
val1 = GetValueFromProto<HighPrecisionType>(expr_->lower_val_);
val2 = GetValueFromProto<HighPrecisionType>(expr_->upper_val_);
auto get_next_overflow_batch = [this]() -> ColumnVectorPtr {
int64_t batch_size = overflow_check_pos_ + batch_size_ >= num_rows_
? num_rows_ - overflow_check_pos_
: batch_size_;
overflow_check_pos_ += batch_size;
if (cached_overflow_res_ != nullptr &&
cached_overflow_res_->size() == batch_size) {
return cached_overflow_res_;
}
auto res = std::make_shared<ColumnVector>(DataType::BOOL, batch_size);
return res;
};
if constexpr (std::is_integral_v<T> && !std::is_same_v<bool, T>) {
if (milvus::query::gt_ub<T>(val1)) {
return get_next_overflow_batch();
} else if (milvus::query::lt_lb<T>(val1)) {
val1 = std::numeric_limits<T>::min();
lower_inclusive = true;
}
if (milvus::query::gt_ub<T>(val2)) {
val2 = std::numeric_limits<T>::max();
upper_inclusive = true;
} else if (milvus::query::lt_lb<T>(val2)) {
return get_next_overflow_batch();
}
}
return nullptr;
}
template <typename T>
VectorPtr
PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForIndex() {
typedef std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
IndexInnerType;
using Index = index::ScalarIndex<IndexInnerType>;
typedef std::conditional_t<std::is_integral_v<IndexInnerType> &&
!std::is_same_v<bool, T>,
int64_t,
IndexInnerType>
HighPrecisionType;
HighPrecisionType val1;
HighPrecisionType val2;
bool lower_inclusive = false;
bool upper_inclusive = false;
if (auto res =
PreCheckOverflow<T>(val1, val2, lower_inclusive, upper_inclusive)) {
return res;
}
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto execute_sub_batch =
[lower_inclusive, upper_inclusive](
Index* index_ptr, HighPrecisionType val1, HighPrecisionType val2) {
BinaryRangeIndexFunc<T> func;
return std::move(
func(index_ptr, val1, val2, lower_inclusive, upper_inclusive));
};
auto res = ProcessIndexChunks<T>(execute_sub_batch, val1, val2);
AssertInfo(res.size() == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
res.size(),
real_batch_size));
return std::make_shared<ColumnVector>(std::move(res));
}
template <typename T>
VectorPtr
PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForData() {
typedef std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
IndexInnerType;
using Index = index::ScalarIndex<IndexInnerType>;
typedef std::conditional_t<std::is_integral_v<IndexInnerType> &&
!std::is_same_v<bool, T>,
int64_t,
IndexInnerType>
HighPrecisionType;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
HighPrecisionType val1;
HighPrecisionType val2;
bool lower_inclusive = false;
bool upper_inclusive = false;
if (auto res =
PreCheckOverflow<T>(val1, val2, lower_inclusive, upper_inclusive)) {
return res;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
auto execute_sub_batch = [lower_inclusive, upper_inclusive](
const T* data,
const int size,
bool* res,
HighPrecisionType val1,
HighPrecisionType val2) {
if (lower_inclusive && upper_inclusive) {
BinaryRangeElementFunc<T, true, true> func;
func(val1, val2, data, size, res);
} else if (lower_inclusive && !upper_inclusive) {
BinaryRangeElementFunc<T, true, false> func;
func(val1, val2, data, size, res);
} else if (!lower_inclusive && upper_inclusive) {
BinaryRangeElementFunc<T, false, true> func;
func(val1, val2, data, size, res);
} else {
BinaryRangeElementFunc<T, false, false> func;
func(val1, val2, data, size, res);
}
};
auto skip_index_func =
[val1, val2, lower_inclusive, upper_inclusive](
const SkipIndex& skip_index, FieldId field_id, int64_t chunk_id) {
if (lower_inclusive && upper_inclusive) {
return skip_index.CanSkipBinaryRange<T>(
field_id, chunk_id, val1, val2, true, true);
} else if (lower_inclusive && !upper_inclusive) {
return skip_index.CanSkipBinaryRange<T>(
field_id, chunk_id, val1, val2, true, false);
} else if (!lower_inclusive && upper_inclusive) {
return skip_index.CanSkipBinaryRange<T>(
field_id, chunk_id, val1, val2, false, true);
} else {
return skip_index.CanSkipBinaryRange<T>(
field_id, chunk_id, val1, val2, false, false);
}
};
int64_t processed_size = ProcessDataChunks<T>(
execute_sub_batch, skip_index_func, res, val1, val2);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
template <typename ValueType>
VectorPtr
PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForJson() {
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
std::string_view,
ValueType>;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
bool lower_inclusive = expr_->lower_inclusive_;
bool upper_inclusive = expr_->upper_inclusive_;
ValueType val1 = GetValueFromProto<ValueType>(expr_->lower_val_);
ValueType val2 = GetValueFromProto<ValueType>(expr_->upper_val_);
auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
auto execute_sub_batch = [lower_inclusive, upper_inclusive, pointer](
const milvus::Json* data,
const int size,
bool* res,
ValueType val1,
ValueType val2) {
if (lower_inclusive && upper_inclusive) {
BinaryRangeElementFuncForJson<ValueType, true, true> func;
func(val1, val2, pointer, data, size, res);
} else if (lower_inclusive && !upper_inclusive) {
BinaryRangeElementFuncForJson<ValueType, true, false> func;
func(val1, val2, pointer, data, size, res);
} else if (!lower_inclusive && upper_inclusive) {
BinaryRangeElementFuncForJson<ValueType, false, true> func;
func(val1, val2, pointer, data, size, res);
} else {
BinaryRangeElementFuncForJson<ValueType, false, false> func;
func(val1, val2, pointer, data, size, res);
}
};
int64_t processed_size = ProcessDataChunks<milvus::Json>(
execute_sub_batch, std::nullptr_t{}, res, val1, val2);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
template <typename ValueType>
VectorPtr
PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForArray() {
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
std::string_view,
ValueType>;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
bool lower_inclusive = expr_->lower_inclusive_;
bool upper_inclusive = expr_->upper_inclusive_;
ValueType val1 = GetValueFromProto<ValueType>(expr_->lower_val_);
ValueType val2 = GetValueFromProto<ValueType>(expr_->upper_val_);
int index = -1;
if (expr_->column_.nested_path_.size() > 0) {
index = std::stoi(expr_->column_.nested_path_[0]);
}
auto execute_sub_batch = [lower_inclusive, upper_inclusive](
const milvus::ArrayView* data,
const int size,
bool* res,
ValueType val1,
ValueType val2,
int index) {
if (lower_inclusive && upper_inclusive) {
BinaryRangeElementFuncForArray<ValueType, true, true> func;
func(val1, val2, index, data, size, res);
} else if (lower_inclusive && !upper_inclusive) {
BinaryRangeElementFuncForArray<ValueType, true, false> func;
func(val1, val2, index, data, size, res);
} else if (!lower_inclusive && upper_inclusive) {
BinaryRangeElementFuncForArray<ValueType, false, true> func;
func(val1, val2, index, data, size, res);
} else {
BinaryRangeElementFuncForArray<ValueType, false, false> func;
func(val1, val2, index, data, size, res);
}
};
int64_t processed_size = ProcessDataChunks<milvus::ArrayView>(
execute_sub_batch, std::nullptr_t{}, res, val1, val2, index);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
} // namespace exec
} // namespace milvus


@ -0,0 +1,228 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fmt/core.h>
#include "common/EasyAssert.h"
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/expression/Expr.h"
#include "segcore/SegmentInterface.h"
namespace milvus {
namespace exec {
template <typename T, bool lower_inclusive, bool upper_inclusive>
struct BinaryRangeElementFunc {
typedef std::conditional_t<std::is_integral_v<T> &&
!std::is_same_v<bool, T>,
int64_t,
T>
HighPrecisionType;
void
operator()(T val1, T val2, const T* src, size_t n, bool* res) {
for (size_t i = 0; i < n; ++i) {
if constexpr (lower_inclusive && upper_inclusive) {
res[i] = val1 <= src[i] && src[i] <= val2;
} else if constexpr (lower_inclusive && !upper_inclusive) {
res[i] = val1 <= src[i] && src[i] < val2;
} else if constexpr (!lower_inclusive && upper_inclusive) {
res[i] = val1 < src[i] && src[i] <= val2;
} else {
res[i] = val1 < src[i] && src[i] < val2;
}
}
}
};
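// BinaryRangeJSONCompare extracts the JSON field at `pointer` once per row
// and evaluates the range predicate on it; if an int64_t lookup fails it
// retries the field as double, and rows where the field is missing or has an
// incompatible type evaluate to false.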
#define BinaryRangeJSONCompare(cmp) \
do { \
auto x = src[i].template at<GetType>(pointer); \
if (x.error()) { \
if constexpr (std::is_same_v<GetType, int64_t>) { \
auto x = src[i].template at<double>(pointer); \
if (!x.error()) { \
auto value = x.value(); \
res[i] = (cmp); \
break; \
} \
} \
res[i] = false; \
break; \
} \
auto value = x.value(); \
res[i] = (cmp); \
} while (false)
template <typename ValueType, bool lower_inclusive, bool upper_inclusive>
struct BinaryRangeElementFuncForJson {
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
std::string_view,
ValueType>;
void
operator()(ValueType val1,
ValueType val2,
const std::string& pointer,
const milvus::Json* src,
size_t n,
bool* res) {
for (size_t i = 0; i < n; ++i) {
if constexpr (lower_inclusive && upper_inclusive) {
BinaryRangeJSONCompare(val1 <= value && value <= val2);
} else if constexpr (lower_inclusive && !upper_inclusive) {
BinaryRangeJSONCompare(val1 <= value && value < val2);
} else if constexpr (!lower_inclusive && upper_inclusive) {
BinaryRangeJSONCompare(val1 < value && value <= val2);
} else {
BinaryRangeJSONCompare(val1 < value && value < val2);
}
}
}
};
template <typename ValueType, bool lower_inclusive, bool upper_inclusive>
struct BinaryRangeElementFuncForArray {
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
std::string_view,
ValueType>;
void
operator()(ValueType val1,
ValueType val2,
int index,
const milvus::ArrayView* src,
size_t n,
bool* res) {
for (size_t i = 0; i < n; ++i) {
if constexpr (lower_inclusive && upper_inclusive) {
if (index >= src[i].length()) {
res[i] = false;
continue;
}
auto value = src[i].get_data<GetType>(index);
res[i] = val1 <= value && value <= val2;
} else if constexpr (lower_inclusive && !upper_inclusive) {
if (index >= src[i].length()) {
res[i] = false;
continue;
}
auto value = src[i].get_data<GetType>(index);
res[i] = val1 <= value && value < val2;
} else if constexpr (!lower_inclusive && upper_inclusive) {
if (index >= src[i].length()) {
res[i] = false;
continue;
}
auto value = src[i].get_data<GetType>(index);
res[i] = val1 < value && value <= val2;
} else {
if (index >= src[i].length()) {
res[i] = false;
continue;
}
auto value = src[i].get_data<GetType>(index);
res[i] = val1 < value && value < val2;
}
}
}
};
template <typename T>
struct BinaryRangeIndexFunc {
typedef std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
IndexInnerType;
using Index = index::ScalarIndex<IndexInnerType>;
typedef std::conditional_t<std::is_integral_v<IndexInnerType> &&
!std::is_same_v<bool, T>,
int64_t,
IndexInnerType>
HighPrecisionType;
FixedVector<bool>
operator()(Index* index,
IndexInnerType val1,
IndexInnerType val2,
bool lower_inclusive,
bool upper_inclusive) {
return index->Range(val1, lower_inclusive, val2, upper_inclusive);
}
};
class PhyBinaryRangeFilterExpr : public SegmentExpr {
public:
PhyBinaryRangeFilterExpr(
const std::vector<std::shared_ptr<Expr>>& input,
const std::shared_ptr<const milvus::expr::BinaryRangeFilterExpr>& expr,
const std::string& name,
const segcore::SegmentInternalInterface* segment,
Timestamp query_timestamp,
int64_t batch_size)
: SegmentExpr(std::move(input),
name,
segment,
expr->column_.field_id_,
query_timestamp,
batch_size),
expr_(expr) {
}
void
Eval(EvalCtx& context, VectorPtr& result) override;
private:
// Check overflow and cache result for performance
template <
typename T,
typename IndexInnerType = std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>,
typename HighPrecisionType = std::conditional_t<
std::is_integral_v<IndexInnerType> && !std::is_same_v<bool, T>,
int64_t,
IndexInnerType>>
ColumnVectorPtr
PreCheckOverflow(HighPrecisionType& val1,
HighPrecisionType& val2,
bool& lower_inclusive,
bool& upper_inclusive);
template <typename T>
VectorPtr
ExecRangeVisitorImpl();
template <typename T>
VectorPtr
ExecRangeVisitorImplForIndex();
template <typename T>
VectorPtr
ExecRangeVisitorImplForData();
template <typename ValueType>
VectorPtr
ExecRangeVisitorImplForJson();
template <typename ValueType>
VectorPtr
ExecRangeVisitorImplForArray();
private:
std::shared_ptr<const milvus::expr::BinaryRangeFilterExpr> expr_;
ColumnVectorPtr cached_overflow_res_{nullptr};
int64_t overflow_check_pos_{0};
};
} // namespace exec
} // namespace milvus


@ -0,0 +1,319 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "CompareExpr.h"
#include "query/Relational.h"
namespace milvus {
namespace exec {
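// A column-vs-column compare has two execution paths: when both fields are
// plain raw data and neither side is a string, Eval uses typed element
// functors that compare the two chunks directly; otherwise it falls back to a
// ChunkDataAccessor-based dispatcher that funnels every value through a
// boost::variant and the Relational visitor.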
bool
PhyCompareFilterExpr::IsStringExpr() {
return expr_->left_data_type_ == DataType::VARCHAR ||
expr_->right_data_type_ == DataType::VARCHAR;
}
int64_t
PhyCompareFilterExpr::GetNextBatchSize() {
auto current_rows =
segment_->type() == SegmentType::Growing
? current_chunk_id_ * size_per_chunk_ + current_chunk_pos_
: current_chunk_pos_;
return current_rows + batch_size_ >= num_rows_ ? num_rows_ - current_rows
: batch_size_;
}
template <typename T>
ChunkDataAccessor
PhyCompareFilterExpr::GetChunkData(FieldId field_id,
int chunk_id,
int data_barrier) {
if (chunk_id >= data_barrier) {
auto& indexing = segment_->chunk_scalar_index<T>(field_id, chunk_id);
if (indexing.HasRawData()) {
return [&indexing](int i) -> const number {
return indexing.Reverse_Lookup(i);
};
}
}
auto chunk_data = segment_->chunk_data<T>(field_id, chunk_id).data();
return [chunk_data](int i) -> const number { return chunk_data[i]; };
}
template <>
ChunkDataAccessor
PhyCompareFilterExpr::GetChunkData<std::string>(FieldId field_id,
int chunk_id,
int data_barrier) {
if (chunk_id >= data_barrier) {
auto& indexing =
segment_->chunk_scalar_index<std::string>(field_id, chunk_id);
if (indexing.HasRawData()) {
return [&indexing](int i) -> const std::string {
return indexing.Reverse_Lookup(i);
};
}
}
if (segment_->type() == SegmentType::Growing) {
auto chunk_data =
segment_->chunk_data<std::string>(field_id, chunk_id).data();
return [chunk_data](int i) -> const number { return chunk_data[i]; };
} else {
auto chunk_data =
segment_->chunk_data<std::string_view>(field_id, chunk_id).data();
return [chunk_data](int i) -> const number {
return std::string(chunk_data[i]);
};
}
}
ChunkDataAccessor
PhyCompareFilterExpr::GetChunkData(DataType data_type,
FieldId field_id,
int chunk_id,
int data_barrier) {
switch (data_type) {
case DataType::BOOL:
return GetChunkData<bool>(field_id, chunk_id, data_barrier);
case DataType::INT8:
return GetChunkData<int8_t>(field_id, chunk_id, data_barrier);
case DataType::INT16:
return GetChunkData<int16_t>(field_id, chunk_id, data_barrier);
case DataType::INT32:
return GetChunkData<int32_t>(field_id, chunk_id, data_barrier);
case DataType::INT64:
return GetChunkData<int64_t>(field_id, chunk_id, data_barrier);
case DataType::FLOAT:
return GetChunkData<float>(field_id, chunk_id, data_barrier);
case DataType::DOUBLE:
return GetChunkData<double>(field_id, chunk_id, data_barrier);
case DataType::VARCHAR: {
return GetChunkData<std::string>(field_id, chunk_id, data_barrier);
}
default:
PanicInfo(DataTypeInvalid,
fmt::format("unsupported data type: {}", data_type));
}
}
template <typename OpType>
VectorPtr
PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) {
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
auto left_data_barrier = segment_->num_chunk_data(expr_->left_field_id_);
auto right_data_barrier = segment_->num_chunk_data(expr_->right_field_id_);
int64_t processed_rows = 0;
for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_;
++chunk_id) {
auto chunk_size = chunk_id == num_chunk_ - 1
? num_rows_ - chunk_id * size_per_chunk_
: size_per_chunk_;
auto left = GetChunkData(expr_->left_data_type_,
expr_->left_field_id_,
chunk_id,
left_data_barrier);
auto right = GetChunkData(expr_->right_data_type_,
expr_->right_field_id_,
chunk_id,
right_data_barrier);
for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0;
i < chunk_size;
++i) {
res[processed_rows++] = boost::apply_visitor(
milvus::query::Relational<decltype(op)>{}, left(i), right(i));
if (processed_rows >= batch_size_) {
current_chunk_id_ = chunk_id;
current_chunk_pos_ = i + 1;
return res_vec;
}
}
}
return res_vec;
}
void
PhyCompareFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
// For segment both fields has no index, can use SIMD to speed up.
// Avoiding too much call stack that blocks SIMD.
if (!is_left_indexed_ && !is_right_indexed_ && !IsStringExpr()) {
result = ExecCompareExprDispatcherForBothDataSegment();
return;
}
result = ExecCompareExprDispatcherForHybridSegment();
}
VectorPtr
PhyCompareFilterExpr::ExecCompareExprDispatcherForHybridSegment() {
switch (expr_->op_type_) {
case OpType::Equal: {
return ExecCompareExprDispatcher(std::equal_to<>{});
}
case OpType::NotEqual: {
return ExecCompareExprDispatcher(std::not_equal_to<>{});
}
case OpType::GreaterEqual: {
return ExecCompareExprDispatcher(std::greater_equal<>{});
}
case OpType::GreaterThan: {
return ExecCompareExprDispatcher(std::greater<>{});
}
case OpType::LessEqual: {
return ExecCompareExprDispatcher(std::less_equal<>{});
}
case OpType::LessThan: {
return ExecCompareExprDispatcher(std::less<>{});
}
case OpType::PrefixMatch: {
return ExecCompareExprDispatcher(
milvus::query::MatchOp<OpType::PrefixMatch>{});
}
// case OpType::PostfixMatch: {
// }
default: {
PanicInfo(OpTypeInvalid,
fmt::format("unsupported optype: {}", expr_->op_type_));
}
}
}
VectorPtr
PhyCompareFilterExpr::ExecCompareExprDispatcherForBothDataSegment() {
switch (expr_->left_data_type_) {
case DataType::BOOL:
return ExecCompareLeftType<bool>();
case DataType::INT8:
return ExecCompareLeftType<int8_t>();
case DataType::INT16:
return ExecCompareLeftType<int16_t>();
case DataType::INT32:
return ExecCompareLeftType<int32_t>();
case DataType::INT64:
return ExecCompareLeftType<int64_t>();
case DataType::FLOAT:
return ExecCompareLeftType<float>();
case DataType::DOUBLE:
return ExecCompareLeftType<double>();
default:
PanicInfo(
DataTypeInvalid,
fmt::format("unsupported left datatype:{} of compare expr",
expr_->left_data_type_));
}
}
template <typename T>
VectorPtr
PhyCompareFilterExpr::ExecCompareLeftType() {
switch (expr_->right_data_type_) {
case DataType::BOOL:
return ExecCompareRightType<T, bool>();
case DataType::INT8:
return ExecCompareRightType<T, int8_t>();
case DataType::INT16:
return ExecCompareRightType<T, int16_t>();
case DataType::INT32:
return ExecCompareRightType<T, int32_t>();
case DataType::INT64:
return ExecCompareRightType<T, int64_t>();
case DataType::FLOAT:
return ExecCompareRightType<T, float>();
case DataType::DOUBLE:
return ExecCompareRightType<T, double>();
default:
PanicInfo(
DataTypeInvalid,
fmt::format("unsupported right datatype:{} of compare expr",
expr_->right_data_type_));
}
}
template <typename T, typename U>
VectorPtr
PhyCompareFilterExpr::ExecCompareRightType() {
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
auto expr_type = expr_->op_type_;
auto execute_sub_batch = [expr_type](const T* left,
const U* right,
const int size,
bool* res) {
switch (expr_type) {
case proto::plan::GreaterThan: {
CompareElementFunc<T, U, proto::plan::GreaterThan> func;
func(left, right, size, res);
break;
}
case proto::plan::GreaterEqual: {
CompareElementFunc<T, U, proto::plan::GreaterEqual> func;
func(left, right, size, res);
break;
}
case proto::plan::LessThan: {
CompareElementFunc<T, U, proto::plan::LessThan> func;
func(left, right, size, res);
break;
}
case proto::plan::LessEqual: {
CompareElementFunc<T, U, proto::plan::LessEqual> func;
func(left, right, size, res);
break;
}
case proto::plan::Equal: {
CompareElementFunc<T, U, proto::plan::Equal> func;
func(left, right, size, res);
break;
}
case proto::plan::NotEqual: {
CompareElementFunc<T, U, proto::plan::NotEqual> func;
func(left, right, size, res);
break;
}
default:
PanicInfo(
OpTypeInvalid,
fmt::format(
"unsupported operator type for compare column expr: {}",
expr_type));
}
};
int64_t processed_size =
ProcessBothDataChunks<T, U>(execute_sub_batch, res);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
} // namespace exec
} // namespace milvus


@ -0,0 +1,186 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fmt/core.h>
#include <boost/variant.hpp>
#include "common/EasyAssert.h"
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/expression/Expr.h"
#include "segcore/SegmentInterface.h"
namespace milvus {
namespace exec {
using number = boost::variant<bool,
int8_t,
int16_t,
int32_t,
int64_t,
float,
double,
std::string>;
using ChunkDataAccessor = std::function<const number(int)>;
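// A ChunkDataAccessor hides whether a value comes from raw chunk data or from
// a scalar index (via Reverse_Lookup) and returns it as a boost::variant, so
// the Relational visitor can compare columns of different element types row
// by row.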
template <typename T, typename U, proto::plan::OpType op>
struct CompareElementFunc {
void
operator()(const T* left, const U* right, size_t size, bool* res) {
for (int i = 0; i < size; ++i) {
if constexpr (op == proto::plan::OpType::Equal) {
res[i] = left[i] == right[i];
} else if constexpr (op == proto::plan::OpType::NotEqual) {
res[i] = left[i] != right[i];
} else if constexpr (op == proto::plan::OpType::GreaterThan) {
res[i] = left[i] > right[i];
} else if constexpr (op == proto::plan::OpType::LessThan) {
res[i] = left[i] < right[i];
} else if constexpr (op == proto::plan::OpType::GreaterEqual) {
res[i] = left[i] >= right[i];
} else if constexpr (op == proto::plan::OpType::LessEqual) {
res[i] = left[i] <= right[i];
} else {
PanicInfo(
OpTypeInvalid,
fmt::format("unsupported op_type:{} for CompareElementFunc",
op));
}
}
}
};
class PhyCompareFilterExpr : public Expr {
public:
PhyCompareFilterExpr(
const std::vector<std::shared_ptr<Expr>>& input,
const std::shared_ptr<const milvus::expr::CompareExpr>& expr,
const std::string& name,
const segcore::SegmentInternalInterface* segment,
Timestamp query_timestamp,
int64_t batch_size)
: Expr(DataType::BOOL, std::move(input), name),
left_field_(expr->left_field_id_),
right_field_(expr->right_field_id_),
segment_(segment),
query_timestamp_(query_timestamp),
batch_size_(batch_size),
expr_(expr) {
is_left_indexed_ = segment_->HasIndex(left_field_);
is_right_indexed_ = segment_->HasIndex(right_field_);
num_rows_ = segment_->get_active_count(query_timestamp_);
num_chunk_ = is_left_indexed_
? segment_->num_chunk_index(expr_->left_field_id_)
: segment_->num_chunk_data(expr_->left_field_id_);
size_per_chunk_ = segment_->size_per_chunk();
AssertInfo(
batch_size_ > 0,
fmt::format("expr batch size should greater than zero, but now: {}",
batch_size_));
}
void
Eval(EvalCtx& context, VectorPtr& result) override;
private:
int64_t
GetNextBatchSize();
bool
IsStringExpr();
template <typename T>
ChunkDataAccessor
GetChunkData(FieldId field_id, int chunk_id, int data_barrier);
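// Walks the left and right columns chunk by chunk starting from the saved
// cursor, caps each slice so that at most batch_size_ rows are produced, and
// records current_chunk_id_/current_chunk_pos_ so the next call resumes where
// this one stopped.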
template <typename T, typename U, typename FUNC, typename... ValTypes>
int64_t
ProcessBothDataChunks(FUNC func, bool* res, ValTypes... values) {
int64_t processed_size = 0;
for (size_t i = current_chunk_id_; i < num_chunk_; i++) {
auto left_chunk = segment_->chunk_data<T>(left_field_, i);
auto right_chunk = segment_->chunk_data<U>(right_field_, i);
auto data_pos = (i == current_chunk_id_) ? current_chunk_pos_ : 0;
auto size = (i == (num_chunk_ - 1))
? (segment_->type() == SegmentType::Growing
? num_rows_ % size_per_chunk_ - data_pos
: num_rows_ - data_pos)
: size_per_chunk_ - data_pos;
if (processed_size + size >= batch_size_) {
size = batch_size_ - processed_size;
}
const T* left_data = left_chunk.data() + data_pos;
const U* right_data = right_chunk.data() + data_pos;
func(left_data, right_data, size, res + processed_size, values...);
processed_size += size;
if (processed_size >= batch_size_) {
current_chunk_id_ = i;
current_chunk_pos_ = data_pos + size;
break;
}
}
return processed_size;
}
ChunkDataAccessor
GetChunkData(DataType data_type,
FieldId field_id,
int chunk_id,
int data_barrier);
template <typename OpType>
VectorPtr
ExecCompareExprDispatcher(OpType op);
VectorPtr
ExecCompareExprDispatcherForHybridSegment();
VectorPtr
ExecCompareExprDispatcherForBothDataSegment();
template <typename T>
VectorPtr
ExecCompareLeftType();
template <typename T, typename U>
VectorPtr
ExecCompareRightType();
private:
const FieldId left_field_;
const FieldId right_field_;
bool is_left_indexed_;
bool is_right_indexed_;
int64_t num_rows_{0};
int64_t num_chunk_{0};
int64_t current_chunk_id_{0};
int64_t current_chunk_pos_{0};
int64_t size_per_chunk_{0};
const segcore::SegmentInternalInterface* segment_;
Timestamp query_timestamp_;
int64_t batch_size_;
std::shared_ptr<const milvus::expr::CompareExpr> expr_;
};
} //namespace exec
} // namespace milvus

View File

@ -0,0 +1,131 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ConjunctExpr.h"
#include "simd/hook.h"
namespace milvus {
namespace exec {
DataType
PhyConjunctFilterExpr::ResolveType(const std::vector<DataType>& inputs) {
AssertInfo(
inputs.size() > 0,
fmt::format(
"Conjunct expressions expect at least one argument, received: {}",
inputs.size()));
for (const auto& type : inputs) {
AssertInfo(
type == DataType::BOOL,
fmt::format("Conjunct expressions expect BOOLEAN, received: {}",
type));
}
return DataType::BOOL;
}
static bool
AllTrue(ColumnVectorPtr& vec) {
bool* data = static_cast<bool*>(vec->GetRawData());
#if defined(USE_DYNAMIC_SIMD)
return milvus::simd::all_true(data, vec->size());
#else
for (int i = 0; i < vec->size(); ++i) {
if (!data[i]) {
return false;
}
}
return true;
#endif
}
static void
AllSet(ColumnVectorPtr& vec) {
bool* data = static_cast<bool*>(vec->GetRawData());
for (int i = 0; i < vec->size(); ++i) {
data[i] = true;
}
}
static void
AllReset(ColumnVectorPtr& vec) {
bool* data = static_cast<bool*>(vec->GetRawData());
for (int i = 0; i < vec->size(); ++i) {
data[i] = false;
}
}
static bool
AllFalse(ColumnVectorPtr& vec) {
bool* data = static_cast<bool*>(vec->GetRawData());
#if defined(USE_DYNAMIC_SIMD)
return milvus::simd::all_false(data, vec->size());
#else
for (int i = 0; i < vec->size(); ++i) {
if (data[i]) {
return false;
}
}
return true;
#endif
}
int64_t
PhyConjunctFilterExpr::UpdateResult(ColumnVectorPtr& input_result,
EvalCtx& ctx,
ColumnVectorPtr& result) {
if (is_and_) {
ConjunctElementFunc<true> func;
return func(input_result, result);
} else {
ConjunctElementFunc<false> func;
return func(input_result, result);
}
}
bool
PhyConjunctFilterExpr::CanSkipNextExprs(ColumnVectorPtr& vec) {
if ((is_and_ && AllFalse(vec)) || (!is_and_ && AllTrue(vec))) {
return true;
}
return false;
}
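// Evaluate the child expressions left to right, folding each result into the
// running result; stop early once no row can change any more (all false for
// AND, all true for OR) or once no rows remain active.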
void
PhyConjunctFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
for (int i = 0; i < inputs_.size(); ++i) {
VectorPtr input_result;
inputs_[i]->Eval(context, input_result);
if (i == 0) {
result = input_result;
auto all_flat_result = GetColumnVector(result);
if (CanSkipNextExprs(all_flat_result)) {
return;
}
continue;
}
auto input_flat_result = GetColumnVector(input_result);
auto all_flat_result = GetColumnVector(result);
auto active_rows =
UpdateResult(input_flat_result, context, all_flat_result);
if (active_rows == 0) {
return;
}
}
}
} //namespace exec
} // namespace milvus

View File

@ -0,0 +1,89 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fmt/core.h>
#include "common/EasyAssert.h"
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/expression/Expr.h"
#include "segcore/SegmentInterface.h"
namespace milvus {
namespace exec {
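// Fold one child expression's boolean result into the accumulated result:
// AND-folds (is_and == true) or OR-folds (is_and == false) element by element
// and returns the number of rows whose outcome is still undecided
// (still true for AND, still false for OR).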
template <bool is_and>
struct ConjunctElementFunc {
int64_t
operator()(ColumnVectorPtr& input_result, ColumnVectorPtr& result) {
bool* input_data = static_cast<bool*>(input_result->GetRawData());
bool* res_data = static_cast<bool*>(result->GetRawData());
int64_t activate_rows = 0;
for (int i = 0; i < result->size(); ++i) {
if constexpr (is_and) {
res_data[i] &= input_data[i];
if (res_data[i]) {
activate_rows++;
}
} else {
res_data[i] |= input_data[i];
if (!res_data[i]) {
activate_rows++;
}
}
}
return activate_rows;
}
};
class PhyConjunctFilterExpr : public Expr {
public:
PhyConjunctFilterExpr(std::vector<ExprPtr>&& inputs, bool is_and)
: Expr(DataType::BOOL, std::move(inputs), is_and ? "and" : "or"),
is_and_(is_and) {
std::vector<DataType> input_types;
input_types.reserve(inputs_.size());
std::transform(inputs_.begin(),
inputs_.end(),
std::back_inserter(input_types),
[](const ExprPtr& expr) { return expr->type(); });
ResolveType(input_types);
}
void
Eval(EvalCtx& context, VectorPtr& result) override;
private:
int64_t
UpdateResult(ColumnVectorPtr& input_result,
EvalCtx& ctx,
ColumnVectorPtr& result);
static DataType
ResolveType(const std::vector<DataType>& inputs);
bool
CanSkipNextExprs(ColumnVectorPtr& vec);
// true if conjunction (and), false if disjunction (or).
bool is_and_;
std::vector<int32_t> input_order_;
};
} //namespace exec
} // namespace milvus

View File

@ -0,0 +1,62 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <memory>
#include <string>
#include <vector>
#include "common/Vector.h"
#include "exec/QueryContext.h"
namespace milvus {
namespace exec {
class ExprSet;
class EvalCtx {
public:
EvalCtx(ExecContext* exec_ctx, ExprSet* expr_set, RowVector* row)
        : exec_ctx_(exec_ctx), expr_set_(expr_set), row_(row) {
assert(exec_ctx_ != nullptr);
assert(expr_set_ != nullptr);
// assert(row_ != nullptr);
}
explicit EvalCtx(ExecContext* exec_ctx)
: exec_ctx_(exec_ctx), expr_set_(nullptr), row_(nullptr) {
}
ExecContext*
get_exec_context() {
return exec_ctx_;
}
std::shared_ptr<QueryConfig>
get_query_config() {
return exec_ctx_->get_query_config();
}
private:
ExecContext* exec_ctx_;
ExprSet* expr_set_;
RowVector* row_;
bool input_no_nulls_;
};
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,72 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ExistsExpr.h"
#include "common/Json.h"
namespace milvus {
namespace exec {
void
PhyExistsFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
switch (expr_->column_.data_type_) {
case DataType::JSON: {
if (is_index_mode_) {
PanicInfo(ExprInvalid,
"exists expr for json index mode not supportted");
}
result = EvalJsonExistsForDataSegment();
break;
}
default:
PanicInfo(DataTypeInvalid,
fmt::format("unsupported data type: {}",
expr_->column_.data_type_));
}
}
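// For each row, test whether the JSON document contains the configured nested
// path; the per-row check is pushed through ProcessDataChunks so the batch and
// cursor bookkeeping stays in SegmentExpr.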
VectorPtr
PhyExistsFilterExpr::EvalJsonExistsForDataSegment() {
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
auto execute_sub_batch = [](const milvus::Json* data,
const int size,
bool* res,
const std::string& pointer) {
for (int i = 0; i < size; ++i) {
res[i] = data[i].exist(pointer);
}
};
int64_t processed_size = ProcessDataChunks<Json>(
execute_sub_batch, std::nullptr_t{}, res, pointer);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
} //namespace exec
} // namespace milvus

View File

@ -0,0 +1,66 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fmt/core.h>
#include "common/EasyAssert.h"
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/expression/Expr.h"
#include "segcore/SegmentInterface.h"
namespace milvus {
namespace exec {
template <typename T>
struct ExistsElementFunc {
void
operator()(const T* src, size_t size, T val, bool* res) {
}
};
class PhyExistsFilterExpr : public SegmentExpr {
public:
PhyExistsFilterExpr(
const std::vector<std::shared_ptr<Expr>>& input,
const std::shared_ptr<const milvus::expr::ExistsExpr>& expr,
const std::string& name,
const segcore::SegmentInternalInterface* segment,
Timestamp query_timestamp,
int64_t batch_size)
: SegmentExpr(std::move(input),
name,
segment,
expr->column_.field_id_,
query_timestamp,
batch_size),
expr_(expr) {
}
void
Eval(EvalCtx& context, VectorPtr& result) override;
private:
VectorPtr
EvalJsonExistsForDataSegment();
private:
std::shared_ptr<const milvus::expr::ExistsExpr> expr_;
};
} //namespace exec
} // namespace milvus

View File

@ -0,0 +1,255 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "Expr.h"
#include "exec/expression/AlwaysTrueExpr.h"
#include "exec/expression/BinaryArithOpEvalRangeExpr.h"
#include "exec/expression/BinaryRangeExpr.h"
#include "exec/expression/CompareExpr.h"
#include "exec/expression/ConjunctExpr.h"
#include "exec/expression/ExistsExpr.h"
#include "exec/expression/JsonContainsExpr.h"
#include "exec/expression/LogicalBinaryExpr.h"
#include "exec/expression/LogicalUnaryExpr.h"
#include "exec/expression/TermExpr.h"
#include "exec/expression/UnaryExpr.h"
namespace milvus {
namespace exec {
void
ExprSet::Eval(int32_t begin,
int32_t end,
bool initialize,
EvalCtx& context,
std::vector<VectorPtr>& results) {
results.resize(exprs_.size());
for (size_t i = begin; i < end; ++i) {
exprs_[i]->Eval(context, results[i]);
}
}
std::vector<ExprPtr>
CompileExpressions(const std::vector<expr::TypedExprPtr>& sources,
ExecContext* context,
const std::unordered_set<std::string>& flatten_candidate,
bool enable_constant_folding) {
std::vector<std::shared_ptr<Expr>> exprs;
exprs.reserve(sources.size());
for (auto& source : sources) {
exprs.emplace_back(CompileExpression(source,
context->get_query_context(),
flatten_candidate,
enable_constant_folding));
}
return exprs;
}
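// Helpers for flattening nested AND/OR trees: a chain like (a AND (b AND c))
// is compiled into a single PhyConjunctFilterExpr with inputs [a, b, c] so
// short-circuiting can work across the whole conjunction.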
static std::optional<std::string>
ShouldFlatten(const expr::TypedExprPtr& expr,
const std::unordered_set<std::string>& flat_candidates = {}) {
if (auto call =
std::dynamic_pointer_cast<const expr::LogicalBinaryExpr>(expr)) {
if (call->op_type_ == expr::LogicalBinaryExpr::OpType::And ||
call->op_type_ == expr::LogicalBinaryExpr::OpType::Or) {
return call->name();
}
}
return std::nullopt;
}
static bool
IsCall(const expr::TypedExprPtr& expr, const std::string& name) {
if (auto call =
std::dynamic_pointer_cast<const expr::LogicalBinaryExpr>(expr)) {
return call->name() == name;
}
return false;
}
static bool
AllInputTypeEqual(const expr::TypedExprPtr& expr) {
const auto& inputs = expr->inputs();
for (int i = 1; i < inputs.size(); i++) {
if (inputs[0]->type() != inputs[i]->type()) {
return false;
}
}
return true;
}
static void
FlattenInput(const expr::TypedExprPtr& input,
const std::string& flatten_call,
std::vector<expr::TypedExprPtr>& flat) {
if (IsCall(input, flatten_call) && AllInputTypeEqual(input)) {
for (auto& child : input->inputs()) {
FlattenInput(child, flatten_call, flat);
}
} else {
flat.emplace_back(input);
}
}
std::vector<ExprPtr>
CompileInputs(const expr::TypedExprPtr& expr,
QueryContext* context,
const std::unordered_set<std::string>& flatten_cadidates) {
std::vector<ExprPtr> compiled_inputs;
auto flatten = ShouldFlatten(expr);
for (auto& input : expr->inputs()) {
if (dynamic_cast<const expr::InputTypeExpr*>(input.get())) {
AssertInfo(
dynamic_cast<const expr::FieldAccessTypeExpr*>(expr.get()),
"An InputReference can only occur under a FieldReference");
} else {
if (flatten.has_value()) {
std::vector<expr::TypedExprPtr> flat_exprs;
FlattenInput(input, flatten.value(), flat_exprs);
for (auto& input : flat_exprs) {
compiled_inputs.push_back(CompileExpression(
input, context, flatten_cadidates, false));
}
} else {
compiled_inputs.push_back(CompileExpression(
input, context, flatten_cadidates, false));
}
}
}
return compiled_inputs;
}
ExprPtr
CompileExpression(const expr::TypedExprPtr& expr,
QueryContext* context,
const std::unordered_set<std::string>& flatten_candidates,
bool enable_constant_folding) {
ExprPtr result;
auto result_type = expr->type();
auto compiled_inputs = CompileInputs(expr, context, flatten_candidates);
auto GetTypes = [](const std::vector<ExprPtr>& exprs) {
std::vector<DataType> types;
for (auto& expr : exprs) {
types.push_back(expr->type());
}
return types;
};
auto input_types = GetTypes(compiled_inputs);
if (auto call = dynamic_cast<const expr::CallTypeExpr*>(expr.get())) {
// TODO: support function register and search mode
} else if (auto casted_expr = std::dynamic_pointer_cast<
const milvus::expr::UnaryRangeFilterExpr>(expr)) {
result = std::make_shared<PhyUnaryRangeFilterExpr>(
compiled_inputs,
casted_expr,
"PhyUnaryRangeFilterExpr",
context->get_segment(),
context->get_query_timestamp(),
context->query_config()->get_expr_batch_size());
} else if (auto casted_expr = std::dynamic_pointer_cast<
const milvus::expr::LogicalUnaryExpr>(expr)) {
result = std::make_shared<PhyLogicalUnaryExpr>(
compiled_inputs, casted_expr, "PhyLogicalUnaryExpr");
} else if (auto casted_expr = std::dynamic_pointer_cast<
const milvus::expr::TermFilterExpr>(expr)) {
result = std::make_shared<PhyTermFilterExpr>(
compiled_inputs,
casted_expr,
"PhyTermFilterExpr",
context->get_segment(),
context->get_query_timestamp(),
context->query_config()->get_expr_batch_size());
} else if (auto casted_expr = std::dynamic_pointer_cast<
const milvus::expr::LogicalBinaryExpr>(expr)) {
if (casted_expr->op_type_ ==
milvus::expr::LogicalBinaryExpr::OpType::And ||
casted_expr->op_type_ ==
milvus::expr::LogicalBinaryExpr::OpType::Or) {
result = std::make_shared<PhyConjunctFilterExpr>(
std::move(compiled_inputs),
casted_expr->op_type_ ==
milvus::expr::LogicalBinaryExpr::OpType::And);
} else {
result = std::make_shared<PhyLogicalBinaryExpr>(
compiled_inputs, casted_expr, "PhyLogicalBinaryExpr");
}
} else if (auto casted_expr = std::dynamic_pointer_cast<
const milvus::expr::BinaryRangeFilterExpr>(expr)) {
result = std::make_shared<PhyBinaryRangeFilterExpr>(
compiled_inputs,
casted_expr,
"PhyBinaryRangeFilterExpr",
context->get_segment(),
context->get_query_timestamp(),
context->query_config()->get_expr_batch_size());
} else if (auto casted_expr = std::dynamic_pointer_cast<
const milvus::expr::AlwaysTrueExpr>(expr)) {
result = std::make_shared<PhyAlwaysTrueExpr>(
compiled_inputs,
casted_expr,
"PhyAlwaysTrueExpr",
context->get_segment(),
context->get_query_timestamp(),
context->query_config()->get_expr_batch_size());
} else if (auto casted_expr = std::dynamic_pointer_cast<
const milvus::expr::BinaryArithOpEvalRangeExpr>(expr)) {
result = std::make_shared<PhyBinaryArithOpEvalRangeExpr>(
compiled_inputs,
casted_expr,
"PhyBinaryArithOpEvalRangeExpr",
context->get_segment(),
context->get_query_timestamp(),
context->query_config()->get_expr_batch_size());
} else if (auto casted_expr =
std::dynamic_pointer_cast<const milvus::expr::CompareExpr>(
expr)) {
result = std::make_shared<PhyCompareFilterExpr>(
compiled_inputs,
casted_expr,
"PhyCompareFilterExpr",
context->get_segment(),
context->get_query_timestamp(),
context->query_config()->get_expr_batch_size());
} else if (auto casted_expr =
std::dynamic_pointer_cast<const milvus::expr::ExistsExpr>(
expr)) {
result = std::make_shared<PhyExistsFilterExpr>(
compiled_inputs,
casted_expr,
"PhyExistsFilterExpr",
context->get_segment(),
context->get_query_timestamp(),
context->query_config()->get_expr_batch_size());
} else if (auto casted_expr = std::dynamic_pointer_cast<
const milvus::expr::JsonContainsExpr>(expr)) {
result = std::make_shared<PhyJsonContainsFilterExpr>(
compiled_inputs,
casted_expr,
"PhyJsonContainsFilterExpr",
context->get_segment(),
context->get_query_timestamp(),
context->query_config()->get_expr_batch_size());
}
return result;
}
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,324 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "common/Types.h"
#include "exec/expression/EvalCtx.h"
#include "exec/expression/VectorFunction.h"
#include "exec/expression/Utils.h"
#include "exec/QueryContext.h"
#include "expr/ITypeExpr.h"
#include "query/PlanProto.h"
namespace milvus {
namespace exec {
class Expr {
public:
Expr(DataType type,
const std::vector<std::shared_ptr<Expr>>&& inputs,
const std::string& name)
: type_(type),
inputs_(std::move(inputs)),
name_(name),
vector_func_(nullptr) {
}
Expr(DataType type,
const std::vector<std::shared_ptr<Expr>>&& inputs,
std::shared_ptr<VectorFunction> vec_func,
const std::string& name)
: type_(type),
inputs_(std::move(inputs)),
name_(name),
vector_func_(vec_func) {
}
virtual ~Expr() = default;
const DataType&
type() const {
return type_;
}
std::string
get_name() {
return name_;
}
virtual void
Eval(EvalCtx& context, VectorPtr& result) {
}
protected:
DataType type_;
const std::vector<std::shared_ptr<Expr>> inputs_;
std::string name_;
std::shared_ptr<VectorFunction> vector_func_;
};
using ExprPtr = std::shared_ptr<milvus::exec::Expr>;
using SkipFunc = bool (*)(const milvus::SkipIndex&, FieldId, int);
class SegmentExpr : public Expr {
public:
SegmentExpr(const std::vector<ExprPtr>&& input,
const std::string& name,
const segcore::SegmentInternalInterface* segment,
const FieldId& field_id,
Timestamp query_timestamp,
int64_t batch_size)
: Expr(DataType::BOOL, std::move(input), name),
segment_(segment),
field_id_(field_id),
query_timestamp_(query_timestamp),
batch_size_(batch_size) {
num_rows_ = segment_->get_active_count(query_timestamp_);
size_per_chunk_ = segment_->size_per_chunk();
AssertInfo(
batch_size_ > 0,
fmt::format("expr batch size should greater than zero, but now: {}",
batch_size_));
InitSegmentExpr();
}
void
InitSegmentExpr() {
auto& schema = segment_->get_schema();
auto& field_meta = schema[field_id_];
if (schema.get_primary_field_id().has_value() &&
schema.get_primary_field_id().value() == field_id_ &&
IsPrimaryKeyDataType(field_meta.get_data_type())) {
is_pk_field_ = true;
pk_type_ = field_meta.get_data_type();
}
is_index_mode_ = segment_->HasIndex(field_id_);
if (is_index_mode_) {
num_index_chunk_ = segment_->num_chunk_index(field_id_);
} else {
num_data_chunk_ = segment_->num_chunk_data(field_id_);
}
}
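    // Number of rows the next batch should cover: batch_size_ rows, or
    // whatever remains before num_rows_ when the cursor is near the end of
    // the segment.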
int64_t
GetNextBatchSize() {
auto current_chunk =
is_index_mode_ ? current_index_chunk_ : current_data_chunk_;
auto current_chunk_pos =
is_index_mode_ ? current_index_chunk_pos_ : current_data_chunk_pos_;
auto current_rows = current_chunk * size_per_chunk_ + current_chunk_pos;
return current_rows + batch_size_ >= num_rows_
? num_rows_ - current_rows
: batch_size_;
}
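    // Apply `func` chunk by chunk over the field's raw data, resuming from the
    // previous batch position, skipping chunks that `skip_func` rules out via
    // the skip index, and stopping once batch_size_ rows have been produced.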
template <typename T, typename FUNC, typename... ValTypes>
int64_t
ProcessDataChunks(
FUNC func,
std::function<bool(const milvus::SkipIndex&, FieldId, int)> skip_func,
bool* res,
ValTypes... values) {
int64_t processed_size = 0;
for (size_t i = current_data_chunk_; i < num_data_chunk_; i++) {
auto data_pos =
(i == current_data_chunk_) ? current_data_chunk_pos_ : 0;
auto size = (i == (num_data_chunk_ - 1))
? (segment_->type() == SegmentType::Growing
? num_rows_ % size_per_chunk_ - data_pos
: num_rows_ - data_pos)
: size_per_chunk_ - data_pos;
size = std::min(size, batch_size_ - processed_size);
auto& skip_index = segment_->GetSkipIndex();
if (!skip_func || !skip_func(skip_index, field_id_, i)) {
auto chunk = segment_->chunk_data<T>(field_id_, i);
const T* data = chunk.data() + data_pos;
func(data, size, res + processed_size, values...);
}
processed_size += size;
if (processed_size >= batch_size_) {
current_data_chunk_ = i;
current_data_chunk_pos_ = data_pos + size;
break;
}
}
return processed_size;
}
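    // Copy the slice of a cached per-chunk index result that belongs to the
    // current batch into `result`, honoring the resume position within the
    // chunk.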
int
ProcessIndexOneChunk(FixedVector<bool>& result,
size_t chunk_id,
const FixedVector<bool>& chunk_res,
int processed_rows) {
auto data_pos =
chunk_id == current_index_chunk_ ? current_index_chunk_pos_ : 0;
auto size = std::min(
std::min(size_per_chunk_ - data_pos, batch_size_ - processed_rows),
int64_t(chunk_res.size()));
result.insert(result.end(),
chunk_res.begin() + data_pos,
chunk_res.begin() + data_pos + size);
return size;
}
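    // Run `func` against the scalar index of each chunk, caching the full
    // per-chunk bitmap (cached_index_chunk_res_) so it is computed only once
    // per chunk, then slice out the rows belonging to the current batch.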
template <typename T, typename FUNC, typename... ValTypes>
FixedVector<bool>
ProcessIndexChunks(FUNC func, ValTypes... values) {
typedef std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
IndexInnerType;
using Index = index::ScalarIndex<IndexInnerType>;
FixedVector<bool> result;
int processed_rows = 0;
for (size_t i = current_index_chunk_; i < num_index_chunk_; i++) {
            // Cache the per-chunk index result so it can be reused across
            // batch loops; this avoids re-executing the index query for every
            // batch, which would be quite expensive.
if (cached_index_chunk_id_ != i) {
const Index& index =
segment_->chunk_scalar_index<IndexInnerType>(field_id_, i);
auto* index_ptr = const_cast<Index*>(&index);
cached_index_chunk_res_ = std::move(func(index_ptr, values...));
cached_index_chunk_id_ = i;
}
auto size = ProcessIndexOneChunk(
result, i, cached_index_chunk_res_, processed_rows);
if (processed_rows + size >= batch_size_) {
current_index_chunk_ = i;
current_index_chunk_pos_ = i == current_index_chunk_
? current_index_chunk_pos_ + size
: size;
break;
}
processed_rows += size;
}
return result;
}
protected:
const segcore::SegmentInternalInterface* segment_;
const FieldId field_id_;
bool is_pk_field_{false};
DataType pk_type_;
Timestamp query_timestamp_;
int64_t batch_size_;
    // State tracking the position the expr has computed up to,
    // because the expr may be called once per batch.
bool is_index_mode_{false};
bool is_data_mode_{false};
int64_t num_rows_{0};
int64_t num_data_chunk_{0};
int64_t num_index_chunk_{0};
int64_t current_data_chunk_{0};
int64_t current_data_chunk_pos_{0};
int64_t current_index_chunk_{0};
int64_t current_index_chunk_pos_{0};
int64_t size_per_chunk_{0};
    // Cache of the index-scan result to avoid searching the index for every batch
int64_t cached_index_chunk_id_{-1};
FixedVector<bool> cached_index_chunk_res_{};
};
std::vector<ExprPtr>
CompileExpressions(const std::vector<expr::TypedExprPtr>& logical_exprs,
ExecContext* context,
const std::unordered_set<std::string>& flatten_cadidates =
std::unordered_set<std::string>(),
bool enable_constant_folding = false);
std::vector<ExprPtr>
CompileInputs(const expr::TypedExprPtr& expr,
QueryContext* config,
const std::unordered_set<std::string>& flatten_cadidates);
ExprPtr
CompileExpression(const expr::TypedExprPtr& expr,
QueryContext* context,
const std::unordered_set<std::string>& flatten_cadidates,
bool enable_constant_folding);
class ExprSet {
public:
explicit ExprSet(const std::vector<expr::TypedExprPtr>& logical_exprs,
                     ExecContext* exec_ctx)
        : exec_ctx_(exec_ctx) {
exprs_ = CompileExpressions(logical_exprs, exec_ctx);
}
virtual ~ExprSet() = default;
void
Eval(EvalCtx& ctx, std::vector<VectorPtr>& results) {
Eval(0, exprs_.size(), true, ctx, results);
}
virtual void
Eval(int32_t begin,
int32_t end,
bool initialize,
EvalCtx& ctx,
std::vector<VectorPtr>& result);
void
Clear() {
exprs_.clear();
}
ExecContext*
get_exec_context() const {
return exec_ctx_;
}
size_t
size() const {
return exprs_.size();
}
const std::vector<std::shared_ptr<Expr>>&
exprs() const {
return exprs_;
}
const std::shared_ptr<Expr>&
expr(int32_t index) const {
return exprs_[index];
}
private:
std::vector<std::shared_ptr<Expr>> exprs_;
ExecContext* exec_ctx_;
};
} //namespace exec
} // namespace milvus

View File

@ -0,0 +1,740 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "JsonContainsExpr.h"
namespace milvus {
namespace exec {
void
PhyJsonContainsFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
switch (expr_->column_.data_type_) {
case DataType::ARRAY:
case DataType::JSON: {
if (is_index_mode_) {
PanicInfo(
ExprInvalid,
"exists expr for json or array index mode not supportted");
}
result = EvalJsonContainsForDataSegment();
break;
}
default:
PanicInfo(DataTypeInvalid,
fmt::format("unsupported data type: {}",
expr_->column_.data_type_));
}
}
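// Dispatch on the contains operator (Contains/ContainsAny vs ContainsAll) and
// on the element value type; array columns and JSON columns take different
// typed execution paths, and mixed-type value lists fall back to the
// *WithDiffType variants.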
VectorPtr
PhyJsonContainsFilterExpr::EvalJsonContainsForDataSegment() {
auto data_type = expr_->column_.data_type_;
switch (expr_->op_) {
case proto::plan::JSONContainsExpr_JSONOp_Contains:
case proto::plan::JSONContainsExpr_JSONOp_ContainsAny: {
if (datatype_is_array(data_type)) {
auto val_type = expr_->vals_[0].val_case();
switch (val_type) {
case proto::plan::GenericValue::kBoolVal: {
return ExecArrayContains<bool>();
}
case proto::plan::GenericValue::kInt64Val: {
return ExecArrayContains<int64_t>();
}
case proto::plan::GenericValue::kFloatVal: {
return ExecArrayContains<double>();
}
case proto::plan::GenericValue::kStringVal: {
return ExecArrayContains<std::string>();
}
default:
PanicInfo(
DataTypeInvalid,
fmt::format("unsupported data type {}", val_type));
}
} else {
if (expr_->same_type_) {
auto val_type = expr_->vals_[0].val_case();
switch (val_type) {
case proto::plan::GenericValue::kBoolVal: {
return ExecJsonContains<bool>();
}
case proto::plan::GenericValue::kInt64Val: {
return ExecJsonContains<int64_t>();
}
case proto::plan::GenericValue::kFloatVal: {
return ExecJsonContains<double>();
}
case proto::plan::GenericValue::kStringVal: {
return ExecJsonContains<std::string>();
}
case proto::plan::GenericValue::kArrayVal: {
return ExecJsonContainsArray();
}
default:
PanicInfo(DataTypeInvalid,
fmt::format("unsupported data type:{}",
val_type));
}
} else {
return ExecJsonContainsWithDiffType();
}
}
break;
}
case proto::plan::JSONContainsExpr_JSONOp_ContainsAll: {
if (datatype_is_array(data_type)) {
auto val_type = expr_->vals_[0].val_case();
switch (val_type) {
case proto::plan::GenericValue::kBoolVal: {
return ExecArrayContainsAll<bool>();
}
case proto::plan::GenericValue::kInt64Val: {
return ExecArrayContainsAll<int64_t>();
}
case proto::plan::GenericValue::kFloatVal: {
return ExecArrayContainsAll<double>();
}
case proto::plan::GenericValue::kStringVal: {
return ExecArrayContainsAll<std::string>();
}
default:
PanicInfo(
DataTypeInvalid,
fmt::format("unsupported data type {}", val_type));
}
} else {
if (expr_->same_type_) {
auto val_type = expr_->vals_[0].val_case();
switch (val_type) {
case proto::plan::GenericValue::kBoolVal: {
return ExecJsonContainsAll<bool>();
}
case proto::plan::GenericValue::kInt64Val: {
return ExecJsonContainsAll<int64_t>();
}
case proto::plan::GenericValue::kFloatVal: {
return ExecJsonContainsAll<double>();
}
case proto::plan::GenericValue::kStringVal: {
return ExecJsonContainsAll<std::string>();
}
case proto::plan::GenericValue::kArrayVal: {
return ExecJsonContainsAllArray();
}
default:
PanicInfo(DataTypeInvalid,
fmt::format("unsupported data type:{}",
val_type));
}
} else {
return ExecJsonContainsAllWithDiffType();
}
}
break;
}
default:
PanicInfo(ExprInvalid,
fmt::format("unsupported json contains type {}",
proto::plan::JSONContainsExpr_JSONOp_Name(
expr_->op_)));
}
}
template <typename ExprValueType>
VectorPtr
PhyJsonContainsFilterExpr::ExecArrayContains() {
using GetType =
std::conditional_t<std::is_same_v<ExprValueType, std::string>,
std::string_view,
ExprValueType>;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
AssertInfo(expr_->column_.nested_path_.size() == 0,
"[ExecArrayContains]nested path must be null");
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
std::unordered_set<GetType> elements;
for (auto const& element : expr_->vals_) {
elements.insert(GetValueFromProto<GetType>(element));
}
auto execute_sub_batch = [](const milvus::ArrayView* data,
const int size,
bool* res,
const std::unordered_set<GetType>& elements) {
auto executor = [&](size_t i) {
const auto& array = data[i];
for (int j = 0; j < array.length(); ++j) {
if (elements.count(array.template get_data<GetType>(j)) > 0) {
return true;
}
}
return false;
};
for (int i = 0; i < size; ++i) {
res[i] = executor(i);
}
};
int64_t processed_size = ProcessDataChunks<milvus::ArrayView>(
execute_sub_batch, std::nullptr_t{}, res, elements);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
template <typename ExprValueType>
VectorPtr
PhyJsonContainsFilterExpr::ExecJsonContains() {
using GetType =
std::conditional_t<std::is_same_v<ExprValueType, std::string>,
std::string_view,
ExprValueType>;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
std::unordered_set<GetType> elements;
auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
for (auto const& element : expr_->vals_) {
elements.insert(GetValueFromProto<GetType>(element));
}
auto execute_sub_batch = [](const milvus::Json* data,
const int size,
bool* res,
const std::string& pointer,
const std::unordered_set<GetType>& elements) {
auto executor = [&](size_t i) {
auto doc = data[i].doc();
auto array = doc.at_pointer(pointer).get_array();
if (array.error()) {
return false;
}
for (auto&& it : array) {
auto val = it.template get<GetType>();
if (val.error()) {
continue;
}
if (elements.count(val.value()) > 0) {
return true;
}
}
return false;
};
for (size_t i = 0; i < size; ++i) {
res[i] = executor(i);
}
};
int64_t processed_size = ProcessDataChunks<Json>(
execute_sub_batch, std::nullptr_t{}, res, pointer, elements);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
VectorPtr
PhyJsonContainsFilterExpr::ExecJsonContainsArray() {
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
std::vector<proto::plan::Array> elements;
for (auto const& element : expr_->vals_) {
elements.emplace_back(GetValueFromProto<proto::plan::Array>(element));
}
auto execute_sub_batch =
[](const milvus::Json* data,
const int size,
bool* res,
const std::string& pointer,
const std::vector<proto::plan::Array>& elements) {
auto executor = [&](size_t i) -> bool {
auto doc = data[i].doc();
auto array = doc.at_pointer(pointer).get_array();
if (array.error()) {
return false;
}
for (auto&& it : array) {
auto val = it.get_array();
if (val.error()) {
continue;
}
std::vector<
simdjson::simdjson_result<simdjson::ondemand::value>>
json_array;
json_array.reserve(val.count_elements());
for (auto&& e : val) {
json_array.emplace_back(e);
}
for (auto const& element : elements) {
if (CompareTwoJsonArray(json_array, element)) {
return true;
}
}
}
return false;
};
for (size_t i = 0; i < size; ++i) {
res[i] = executor(i);
}
};
int64_t processed_size = ProcessDataChunks<milvus::Json>(
execute_sub_batch, std::nullptr_t{}, res, pointer, elements);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
template <typename ExprValueType>
VectorPtr
PhyJsonContainsFilterExpr::ExecArrayContainsAll() {
using GetType =
std::conditional_t<std::is_same_v<ExprValueType, std::string>,
std::string_view,
ExprValueType>;
AssertInfo(expr_->column_.nested_path_.size() == 0,
"[ExecArrayContainsAll]nested path must be null");
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
std::unordered_set<GetType> elements;
for (auto const& element : expr_->vals_) {
elements.insert(GetValueFromProto<GetType>(element));
}
auto execute_sub_batch = [](const milvus::ArrayView* data,
const int size,
bool* res,
const std::unordered_set<GetType>& elements) {
auto executor = [&](size_t i) {
std::unordered_set<GetType> tmp_elements(elements);
// Note: array can only be iterated once
for (int j = 0; j < data[i].length(); ++j) {
tmp_elements.erase(data[i].template get_data<GetType>(j));
if (tmp_elements.size() == 0) {
return true;
}
}
return tmp_elements.size() == 0;
};
for (int i = 0; i < size; ++i) {
res[i] = executor(i);
}
};
int64_t processed_size = ProcessDataChunks<milvus::ArrayView>(
execute_sub_batch, std::nullptr_t{}, res, elements);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
template <typename ExprValueType>
VectorPtr
PhyJsonContainsFilterExpr::ExecJsonContainsAll() {
using GetType =
std::conditional_t<std::is_same_v<ExprValueType, std::string>,
std::string_view,
ExprValueType>;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
std::unordered_set<GetType> elements;
for (auto const& element : expr_->vals_) {
elements.insert(GetValueFromProto<GetType>(element));
}
auto execute_sub_batch = [](const milvus::Json* data,
const int size,
bool* res,
const std::string& pointer,
const std::unordered_set<GetType>& elements) {
auto executor = [&](const size_t i) -> bool {
auto doc = data[i].doc();
auto array = doc.at_pointer(pointer).get_array();
if (array.error()) {
return false;
}
std::unordered_set<GetType> tmp_elements(elements);
// Note: array can only be iterated once
for (auto&& it : array) {
auto val = it.template get<GetType>();
if (val.error()) {
continue;
}
tmp_elements.erase(val.value());
if (tmp_elements.size() == 0) {
return true;
}
}
return tmp_elements.size() == 0;
};
for (size_t i = 0; i < size; ++i) {
res[i] = executor(i);
}
};
int64_t processed_size = ProcessDataChunks<Json>(
execute_sub_batch, std::nullptr_t{}, res, pointer, elements);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
VectorPtr
PhyJsonContainsFilterExpr::ExecJsonContainsAllWithDiffType() {
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
auto elements = expr_->vals_;
std::unordered_set<int> elements_index;
int i = 0;
for (auto& element : elements) {
elements_index.insert(i);
i++;
}
auto execute_sub_batch =
[](const milvus::Json* data,
const int size,
bool* res,
const std::string& pointer,
const std::vector<proto::plan::GenericValue>& elements,
const std::unordered_set<int> elements_index) {
auto executor = [&](size_t i) -> bool {
const auto& json = data[i];
auto doc = json.doc();
auto array = doc.at_pointer(pointer).get_array();
if (array.error()) {
return false;
}
std::unordered_set<int> tmp_elements_index(elements_index);
for (auto&& it : array) {
int i = -1;
for (auto& element : elements) {
i++;
switch (element.val_case()) {
case proto::plan::GenericValue::kBoolVal: {
auto val = it.template get<bool>();
if (val.error()) {
continue;
}
if (val.value() == element.bool_val()) {
tmp_elements_index.erase(i);
}
break;
}
case proto::plan::GenericValue::kInt64Val: {
auto val = it.template get<int64_t>();
if (val.error()) {
continue;
}
if (val.value() == element.int64_val()) {
tmp_elements_index.erase(i);
}
break;
}
case proto::plan::GenericValue::kFloatVal: {
auto val = it.template get<double>();
if (val.error()) {
continue;
}
if (val.value() == element.float_val()) {
tmp_elements_index.erase(i);
}
break;
}
case proto::plan::GenericValue::kStringVal: {
auto val = it.template get<std::string_view>();
if (val.error()) {
continue;
}
if (val.value() == element.string_val()) {
tmp_elements_index.erase(i);
}
break;
}
case proto::plan::GenericValue::kArrayVal: {
auto val = it.get_array();
if (val.error()) {
continue;
}
if (CompareTwoJsonArray(val,
element.array_val())) {
tmp_elements_index.erase(i);
}
break;
}
default:
PanicInfo(
DataTypeInvalid,
fmt::format("unsupported data type {}",
element.val_case()));
}
if (tmp_elements_index.size() == 0) {
return true;
}
}
if (tmp_elements_index.size() == 0) {
return true;
}
}
return tmp_elements_index.size() == 0;
};
for (size_t i = 0; i < size; ++i) {
res[i] = executor(i);
}
};
int64_t processed_size = ProcessDataChunks<Json>(execute_sub_batch,
std::nullptr_t{},
res,
pointer,
elements,
elements_index);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
VectorPtr
PhyJsonContainsFilterExpr::ExecJsonContainsAllArray() {
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
std::vector<proto::plan::Array> elements;
for (auto const& element : expr_->vals_) {
elements.emplace_back(GetValueFromProto<proto::plan::Array>(element));
}
auto execute_sub_batch =
[](const milvus::Json* data,
const int size,
bool* res,
const std::string& pointer,
const std::vector<proto::plan::Array>& elements) {
auto executor = [&](const size_t i) {
auto doc = data[i].doc();
auto array = doc.at_pointer(pointer).get_array();
if (array.error()) {
return false;
}
std::unordered_set<int> exist_elements_index;
for (auto&& it : array) {
auto val = it.get_array();
if (val.error()) {
continue;
}
std::vector<
simdjson::simdjson_result<simdjson::ondemand::value>>
json_array;
json_array.reserve(val.count_elements());
for (auto&& e : val) {
json_array.emplace_back(e);
}
for (int index = 0; index < elements.size(); ++index) {
if (CompareTwoJsonArray(json_array, elements[index])) {
exist_elements_index.insert(index);
}
}
if (exist_elements_index.size() == elements.size()) {
return true;
}
}
return exist_elements_index.size() == elements.size();
};
for (size_t i = 0; i < size; ++i) {
res[i] = executor(i);
}
};
int64_t processed_size = ProcessDataChunks<Json>(
execute_sub_batch, std::nullptr_t{}, res, pointer, elements);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
VectorPtr
PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() {
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
auto elements = expr_->vals_;
std::unordered_set<int> elements_index;
int i = 0;
for (auto& element : elements) {
elements_index.insert(i);
i++;
}
auto execute_sub_batch =
[](const milvus::Json* data,
const int size,
bool* res,
const std::string& pointer,
const std::vector<proto::plan::GenericValue>& elements) {
auto executor = [&](const size_t i) {
auto& json = data[i];
auto doc = json.doc();
auto array = doc.at_pointer(pointer).get_array();
if (array.error()) {
return false;
}
// Note: array can only be iterated once
for (auto&& it : array) {
for (auto const& element : elements) {
switch (element.val_case()) {
case proto::plan::GenericValue::kBoolVal: {
auto val = it.template get<bool>();
if (val.error()) {
continue;
}
if (val.value() == element.bool_val()) {
return true;
}
break;
}
case proto::plan::GenericValue::kInt64Val: {
auto val = it.template get<int64_t>();
if (val.error()) {
continue;
}
if (val.value() == element.int64_val()) {
return true;
}
break;
}
case proto::plan::GenericValue::kFloatVal: {
auto val = it.template get<double>();
if (val.error()) {
continue;
}
if (val.value() == element.float_val()) {
return true;
}
break;
}
case proto::plan::GenericValue::kStringVal: {
auto val = it.template get<std::string_view>();
if (val.error()) {
continue;
}
if (val.value() == element.string_val()) {
return true;
}
break;
}
case proto::plan::GenericValue::kArrayVal: {
auto val = it.get_array();
if (val.error()) {
continue;
}
if (CompareTwoJsonArray(val,
element.array_val())) {
return true;
}
break;
}
default:
PanicInfo(
DataTypeInvalid,
fmt::format("unsupported data type {}",
element.val_case()));
}
}
}
return false;
};
for (size_t i = 0; i < size; ++i) {
res[i] = executor(i);
}
};
int64_t processed_size = ProcessDataChunks<Json>(
execute_sub_batch, std::nullptr_t{}, res, pointer, elements);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
} //namespace exec
} // namespace milvus

View File

@ -0,0 +1,87 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fmt/core.h>
#include "common/EasyAssert.h"
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/expression/Expr.h"
#include "segcore/SegmentInterface.h"
namespace milvus {
namespace exec {
class PhyJsonContainsFilterExpr : public SegmentExpr {
public:
PhyJsonContainsFilterExpr(
const std::vector<std::shared_ptr<Expr>>& input,
const std::shared_ptr<const milvus::expr::JsonContainsExpr>& expr,
const std::string& name,
const segcore::SegmentInternalInterface* segment,
Timestamp query_timestamp,
int64_t batch_size)
: SegmentExpr(std::move(input),
name,
segment,
expr->column_.field_id_,
query_timestamp,
batch_size),
expr_(expr) {
}
void
Eval(EvalCtx& context, VectorPtr& result) override;
private:
VectorPtr
EvalJsonContainsForDataSegment();
template <typename ExprValueType>
VectorPtr
ExecJsonContains();
template <typename ExprValueType>
VectorPtr
ExecArrayContains();
template <typename ExprValueType>
VectorPtr
ExecJsonContainsAll();
template <typename ExprValueType>
VectorPtr
ExecArrayContainsAll();
VectorPtr
ExecJsonContainsArray();
VectorPtr
ExecJsonContainsAllArray();
VectorPtr
ExecJsonContainsAllWithDiffType();
VectorPtr
ExecJsonContainsWithDiffType();
private:
std::shared_ptr<const milvus::expr::JsonContainsExpr> expr_;
};
} //namespace exec
} // namespace milvus

View File

@ -0,0 +1,51 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "LogicalBinaryExpr.h"
namespace milvus {
namespace exec {
void
PhyLogicalBinaryExpr::Eval(EvalCtx& context, VectorPtr& result) {
AssertInfo(inputs_.size() == 2,
fmt::format("logical binary expr must has two input, but now {}",
inputs_.size()));
VectorPtr left;
inputs_[0]->Eval(context, left);
VectorPtr right;
inputs_[1]->Eval(context, right);
auto lflat = GetColumnVector(left);
auto rflat = GetColumnVector(right);
auto size = left->size();
bool* ldata = static_cast<bool*>(lflat->GetRawData());
bool* rdata = static_cast<bool*>(rflat->GetRawData());
if (expr_->op_type_ == expr::LogicalBinaryExpr::OpType::And) {
LogicalElementFunc<LogicalOpType::And> func;
func(ldata, rdata, size);
} else if (expr_->op_type_ == expr::LogicalBinaryExpr::OpType::Or) {
LogicalElementFunc<LogicalOpType::Or> func;
func(ldata, rdata, size);
} else {
PanicInfo(OpTypeInvalid,
fmt::format("unsupported logical operator: {}",
expr_->GetOpTypeString()));
}
result = std::move(left);
}
} //namespace exec
} // namespace milvus

View File

@ -0,0 +1,78 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fmt/core.h>
#include "common/EasyAssert.h"
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/expression/Expr.h"
#include "segcore/SegmentInterface.h"
#include "simd/hook.h"
namespace milvus {
namespace exec {
enum class LogicalOpType { Invalid = 0, And = 1, Or = 2, Xor = 3, Minus = 4 };
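// In-place boolean combine of two result buffers: left[i] = left[i] OP right[i].
// Uses the dynamic SIMD kernels when available, otherwise a scalar loop.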
template <LogicalOpType op>
struct LogicalElementFunc {
void
operator()(bool* left, bool* right, int n) {
#if defined(USE_DYNAMIC_SIMD)
if constexpr (op == LogicalOpType::And) {
milvus::simd::and_bool(left, right, n);
} else if constexpr (op == LogicalOpType::Or) {
milvus::simd::or_bool(left, right, n);
} else {
PanicInfo(OpTypeInvalid,
fmt::format("unsupported logical operator: {}", op));
}
#else
for (size_t i = 0; i < n; ++i) {
if constexpr (op == LogicalOpType::And) {
left[i] &= right[i];
} else if constexpr (op == LogicalOpType::Or) {
left[i] |= right[i];
} else {
PanicInfo(OpTypeInvalid,
fmt::format("unsupported logical operator: {}", op));
}
}
#endif
}
};
class PhyLogicalBinaryExpr : public Expr {
public:
PhyLogicalBinaryExpr(
const std::vector<std::shared_ptr<Expr>>& input,
const std::shared_ptr<const milvus::expr::LogicalBinaryExpr>& expr,
const std::string& name)
: Expr(DataType::BOOL, std::move(input), name), expr_(expr) {
}
void
Eval(EvalCtx& context, VectorPtr& result) override;
private:
std::shared_ptr<const milvus::expr::LogicalBinaryExpr> expr_;
};
} //namespace exec
} // namespace milvus

View File

@ -0,0 +1,44 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "LogicalUnaryExpr.h"
#include "simd/hook.h"
namespace milvus {
namespace exec {
void
PhyLogicalUnaryExpr::Eval(EvalCtx& context, VectorPtr& result) {
AssertInfo(inputs_.size() == 1,
fmt::format("logical unary expr must has one input, but now {}",
inputs_.size()));
inputs_[0]->Eval(context, result);
if (expr_->op_type_ == milvus::expr::LogicalUnaryExpr::OpType::LogicalNot) {
auto flat_vec = GetColumnVector(result);
bool* data = static_cast<bool*>(flat_vec->GetRawData());
#if defined(USE_DYNAMIC_SIMD)
milvus::simd::invert_bool(data, flat_vec->size());
#else
for (int i = 0; i < flat_vec->size(); ++i) {
data[i] = !data[i];
}
#endif
}
}
} //namespace exec
} // namespace milvus

View File

@ -0,0 +1,47 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fmt/core.h>
#include "common/EasyAssert.h"
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/expression/Expr.h"
#include "segcore/SegmentInterface.h"
namespace milvus {
namespace exec {
class PhyLogicalUnaryExpr : public Expr {
public:
PhyLogicalUnaryExpr(
const std::vector<std::shared_ptr<Expr>>& input,
const std::shared_ptr<const milvus::expr::LogicalUnaryExpr>& expr,
const std::string& name)
: Expr(DataType::BOOL, std::move(input), name), expr_(expr) {
}
void
Eval(EvalCtx& context, VectorPtr& result) override;
private:
std::shared_ptr<const milvus::expr::LogicalUnaryExpr> expr_;
};
} //namespace exec
} // namespace milvus

View File

@ -0,0 +1,540 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "TermExpr.h"
#include "query/Utils.h"
namespace milvus {
namespace exec {
void
PhyTermFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
if (is_pk_field_) {
result = ExecPkTermImpl();
return;
}
switch (expr_->column_.data_type_) {
case DataType::BOOL: {
result = ExecVisitorImpl<bool>();
break;
}
case DataType::INT8: {
result = ExecVisitorImpl<int8_t>();
break;
}
case DataType::INT16: {
result = ExecVisitorImpl<int16_t>();
break;
}
case DataType::INT32: {
result = ExecVisitorImpl<int32_t>();
break;
}
case DataType::INT64: {
result = ExecVisitorImpl<int64_t>();
break;
}
case DataType::FLOAT: {
result = ExecVisitorImpl<float>();
break;
}
case DataType::DOUBLE: {
result = ExecVisitorImpl<double>();
break;
}
case DataType::VARCHAR: {
if (segment_->type() == SegmentType::Growing) {
result = ExecVisitorImpl<std::string>();
} else {
result = ExecVisitorImpl<std::string_view>();
}
break;
}
case DataType::JSON: {
if (expr_->vals_.size() == 0) {
result = ExecVisitorImplTemplateJson<bool>();
break;
}
auto type = expr_->vals_[0].val_case();
switch (type) {
case proto::plan::GenericValue::ValCase::kBoolVal:
result = ExecVisitorImplTemplateJson<bool>();
break;
case proto::plan::GenericValue::ValCase::kInt64Val:
result = ExecVisitorImplTemplateJson<int64_t>();
break;
case proto::plan::GenericValue::ValCase::kFloatVal:
result = ExecVisitorImplTemplateJson<double>();
break;
case proto::plan::GenericValue::ValCase::kStringVal:
result = ExecVisitorImplTemplateJson<std::string>();
break;
default:
PanicInfo(DataTypeInvalid,
fmt::format("unknown data type: {}", type));
}
break;
}
case DataType::ARRAY: {
if (expr_->vals_.size() == 0) {
result = ExecVisitorImplTemplateArray<bool>();
break;
}
auto type = expr_->vals_[0].val_case();
switch (type) {
case proto::plan::GenericValue::ValCase::kBoolVal:
result = ExecVisitorImplTemplateArray<bool>();
break;
case proto::plan::GenericValue::ValCase::kInt64Val:
result = ExecVisitorImplTemplateArray<int64_t>();
break;
case proto::plan::GenericValue::ValCase::kFloatVal:
result = ExecVisitorImplTemplateArray<double>();
break;
case proto::plan::GenericValue::ValCase::kStringVal:
result = ExecVisitorImplTemplateArray<std::string>();
break;
default:
PanicInfo(DataTypeInvalid,
fmt::format("unknown data type: {}", type));
}
break;
}
default:
PanicInfo(DataTypeInvalid,
fmt::format("unsupported data type: {}",
expr_->column_.data_type_));
}
}
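// Build an IdArray from the term values and look up their segment offsets
// through the primary-key index. The hits are cached both as a bitmap
// (cached_bits_) and as an offset column (cached_offsets_), so later batches
// can be answered without another index lookup.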
void
PhyTermFilterExpr::InitPkCacheOffset() {
auto id_array = std::make_unique<IdArray>();
switch (pk_type_) {
case DataType::INT64: {
auto dst_ids = id_array->mutable_int_id();
for (const auto& id : expr_->vals_) {
dst_ids->add_data(GetValueFromProto<int64_t>(id));
}
break;
}
case DataType::VARCHAR: {
auto dst_ids = id_array->mutable_str_id();
for (const auto& id : expr_->vals_) {
dst_ids->add_data(GetValueFromProto<std::string>(id));
}
break;
}
default: {
PanicInfo(DataTypeInvalid,
fmt::format("unsupported data type {}", pk_type_));
}
}
auto [uids, seg_offsets] =
segment_->search_ids(*id_array, query_timestamp_);
cached_bits_.resize(num_rows_);
cached_offsets_ =
std::make_shared<ColumnVector>(DataType::INT64, seg_offsets.size());
int64_t* cached_offsets_ptr = (int64_t*)cached_offsets_->GetRawData();
int i = 0;
for (const auto& offset : seg_offsets) {
auto _offset = (int64_t)offset.get();
cached_bits_[_offset] = true;
cached_offsets_ptr[i++] = _offset;
}
cached_offsets_inited_ = true;
}
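// Answer "pk in (...)" from the cached pk lookup: the returned RowVector
// carries the boolean bitmap for this batch as its first child and the
// matched segment offsets as its second child.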
VectorPtr
PhyTermFilterExpr::ExecPkTermImpl() {
if (!cached_offsets_inited_) {
InitPkCacheOffset();
}
auto real_batch_size = current_data_chunk_pos_ + batch_size_ >= num_rows_
? num_rows_ - current_data_chunk_pos_
: batch_size_;
current_data_chunk_pos_ += real_batch_size;
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
for (size_t i = 0; i < real_batch_size; ++i) {
res[i] = cached_bits_[i];
}
std::vector<VectorPtr> vecs{res_vec, cached_offsets_};
return std::make_shared<RowVector>(vecs);
}
template <typename ValueType>
VectorPtr
PhyTermFilterExpr::ExecVisitorImplTemplateJson() {
if (expr_->is_in_field_) {
return ExecTermJsonVariableInField<ValueType>();
} else {
return ExecTermJsonFieldInVariable<ValueType>();
}
}
template <typename ValueType>
VectorPtr
PhyTermFilterExpr::ExecVisitorImplTemplateArray() {
if (expr_->is_in_field_) {
return ExecTermArrayVariableInField<ValueType>();
} else {
return ExecTermArrayFieldInVariable<ValueType>();
}
}
template <typename ValueType>
VectorPtr
PhyTermFilterExpr::ExecTermArrayVariableInField() {
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
std::string_view,
ValueType>;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
AssertInfo(expr_->vals_.size() == 1,
"element length in json array must be one");
ValueType target_val = GetValueFromProto<ValueType>(expr_->vals_[0]);
auto execute_sub_batch = [](const ArrayView* data,
const int size,
bool* res,
const ValueType& target_val) {
auto executor = [&](size_t i) {
for (int j = 0; j < data[i].length(); ++j) {
auto val = data[i].template get_data<GetType>(j);
if (val == target_val) {
return true;
}
}
return false;
};
for (int i = 0; i < size; ++i) {
res[i] = executor(i);
}
};
int64_t processed_size = ProcessDataChunks<milvus::ArrayView>(
execute_sub_batch, std::nullptr_t{}, res, target_val);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
template <typename ValueType>
VectorPtr
PhyTermFilterExpr::ExecTermArrayFieldInVariable() {
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
std::string_view,
ValueType>;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
int index = -1;
if (expr_->column_.nested_path_.size() > 0) {
index = std::stoi(expr_->column_.nested_path_[0]);
}
std::unordered_set<ValueType> term_set;
for (const auto& element : expr_->vals_) {
term_set.insert(GetValueFromProto<ValueType>(element));
}
if (term_set.empty()) {
for (size_t i = 0; i < real_batch_size; ++i) {
res[i] = false;
}
return res_vec;
}
auto execute_sub_batch = [](const ArrayView* data,
const int size,
bool* res,
int index,
const std::unordered_set<ValueType>& term_set) {
if (term_set.empty()) {
for (int i = 0; i < size; ++i) {
res[i] = false;
}
return;
}
for (int i = 0; i < size; ++i) {
if (index >= data[i].length()) {
res[i] = false;
continue;
}
auto value = data[i].get_data<GetType>(index);
res[i] = term_set.find(ValueType(value)) != term_set.end();
}
};
int64_t processed_size = ProcessDataChunks<milvus::ArrayView>(
execute_sub_batch, std::nullptr_t{}, res, index, term_set);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
template <typename ValueType>
VectorPtr
PhyTermFilterExpr::ExecTermJsonVariableInField() {
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
std::string_view,
ValueType>;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
AssertInfo(expr_->vals_.size() == 1,
"element length in json array must be one");
ValueType val = GetValueFromProto<ValueType>(expr_->vals_[0]);
auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
auto execute_sub_batch = [](const Json* data,
const int size,
bool* res,
const std::string pointer,
const ValueType& target_val) {
auto executor = [&](size_t i) {
auto doc = data[i].doc();
auto array = doc.at_pointer(pointer).get_array();
if (array.error())
return false;
for (auto it = array.begin(); it != array.end(); ++it) {
auto val = (*it).template get<GetType>();
if (val.error()) {
return false;
}
if (val.value() == target_val) {
return true;
}
}
return false;
};
for (size_t i = 0; i < size; ++i) {
res[i] = executor(i);
}
};
int64_t processed_size = ProcessDataChunks<milvus::Json>(
execute_sub_batch, std::nullptr_t{}, res, pointer, val);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
template <typename ValueType>
VectorPtr
PhyTermFilterExpr::ExecTermJsonFieldInVariable() {
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
std::string_view,
ValueType>;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
std::unordered_set<ValueType> term_set;
for (const auto& element : expr_->vals_) {
term_set.insert(GetValueFromProto<ValueType>(element));
}
if (term_set.empty()) {
for (size_t i = 0; i < real_batch_size; ++i) {
res[i] = false;
}
return res_vec;
}
auto execute_sub_batch = [](const Json* data,
const int size,
bool* res,
const std::string pointer,
const std::unordered_set<ValueType>& terms) {
auto executor = [&](size_t i) {
auto x = data[i].template at<GetType>(pointer);
if (x.error()) {
if constexpr (std::is_same_v<GetType, std::int64_t>) {
auto x = data[i].template at<double>(pointer);
if (x.error()) {
return false;
}
auto value = x.value();
// if the term set is {1}, and the value is 1.1, we should not return true.
return std::floor(value) == value &&
terms.find(ValueType(value)) != terms.end();
}
return false;
}
return terms.find(ValueType(x.value())) != terms.end();
};
for (size_t i = 0; i < size; ++i) {
res[i] = executor(i);
}
};
int64_t processed_size = ProcessDataChunks<milvus::Json>(
execute_sub_batch, std::nullptr_t{}, res, pointer, term_set);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
template <typename T>
VectorPtr
PhyTermFilterExpr::ExecVisitorImpl() {
if (is_index_mode_) {
return ExecVisitorImplForIndex<T>();
} else {
return ExecVisitorImplForData<T>();
}
}
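// Index path: string_view terms are materialized as std::string for the
// scalar index, and integral terms that overflow T's range are dropped
// before probing the index, since they can never match.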
template <typename T>
VectorPtr
PhyTermFilterExpr::ExecVisitorImplForIndex() {
typedef std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
IndexInnerType;
using Index = index::ScalarIndex<IndexInnerType>;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
std::vector<IndexInnerType> vals;
for (auto& val : expr_->vals_) {
auto converted_val = GetValueFromProto<T>(val);
// Skip integral values that overflow T's range
if constexpr (std::is_integral_v<T>) {
if (milvus::query::out_of_range<T>(converted_val)) {
continue;
}
}
vals.emplace_back(converted_val);
}
auto execute_sub_batch = [](Index* index_ptr,
const std::vector<IndexInnerType>& vals) {
TermIndexFunc<T> func;
return func(index_ptr, vals.size(), vals.data());
};
auto res = ProcessIndexChunks<T>(execute_sub_batch, vals);
AssertInfo(res.size() == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
res.size(),
real_batch_size));
return std::make_shared<ColumnVector>(std::move(res));
}
template <>
VectorPtr
PhyTermFilterExpr::ExecVisitorImplForIndex<bool>() {
using Index = index::ScalarIndex<bool>;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
std::vector<uint8_t> vals;
for (auto& val : expr_->vals_) {
vals.emplace_back(GetValueFromProto<bool>(val) ? 1 : 0);
}
auto execute_sub_batch = [](Index* index_ptr,
const std::vector<uint8_t>& vals) {
TermIndexFunc<bool> func;
return std::move(func(index_ptr, vals.size(), (bool*)vals.data()));
};
auto res = ProcessIndexChunks<bool>(execute_sub_batch, vals);
return std::make_shared<ColumnVector>(std::move(res));
}
template <typename T>
VectorPtr
PhyTermFilterExpr::ExecVisitorImplForData() {
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
std::vector<T> vals;
for (auto& val : expr_->vals_) {
// Skip integral values that overflow T's range; they can never match
bool overflowed = false;
auto converted_val = GetValueFromProtoWithOverflow<T>(val, overflowed);
if (!overflowed) {
vals.emplace_back(converted_val);
}
}
std::unordered_set<T> vals_set(vals.begin(), vals.end());
auto execute_sub_batch = [](const T* data,
const int size,
bool* res,
const std::unordered_set<T>& vals) {
TermElementFuncSet<T> func;
for (size_t i = 0; i < size; ++i) {
res[i] = func(vals, data[i]);
}
};
int64_t processed_size = ProcessDataChunks<T>(
execute_sub_batch, std::nullptr_t{}, res, vals_set);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,134 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fmt/core.h>
#include "common/EasyAssert.h"
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/expression/Expr.h"
#include "segcore/SegmentInterface.h"
namespace milvus {
namespace exec {
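// Element-level helpers for the term filter: a linear scan over a flat value
// list, a hash-set membership test, and a scalar-index In() probe.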
template <typename T>
struct TermElementFuncFlat {
bool
operator()(const T* src, size_t n, T val) {
for (size_t i = 0; i < n; ++i) {
if (src[i] == val) {
return true;
}
}
return false;
}
};
template <typename T>
struct TermElementFuncSet {
bool
operator()(const std::unordered_set<T>& srcs, T val) {
return srcs.find(val) != srcs.end();
}
};
template <typename T>
struct TermIndexFunc {
typedef std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
IndexInnerType;
using Index = index::ScalarIndex<IndexInnerType>;
FixedVector<bool>
operator()(Index* index, size_t n, const IndexInnerType* val) {
return index->In(n, val);
}
};
class PhyTermFilterExpr : public SegmentExpr {
public:
PhyTermFilterExpr(
const std::vector<std::shared_ptr<Expr>>& input,
const std::shared_ptr<const milvus::expr::TermFilterExpr>& expr,
const std::string& name,
const segcore::SegmentInternalInterface* segment,
Timestamp query_timestamp,
int64_t batch_size)
: SegmentExpr(std::move(input),
name,
segment,
expr->column_.field_id_,
query_timestamp,
batch_size),
expr_(expr) {
}
void
Eval(EvalCtx& context, VectorPtr& result) override;
private:
void
InitPkCacheOffset();
VectorPtr
ExecPkTermImpl();
template <typename T>
VectorPtr
ExecVisitorImpl();
template <typename T>
VectorPtr
ExecVisitorImplForIndex();
template <typename T>
VectorPtr
ExecVisitorImplForData();
template <typename ValueType>
VectorPtr
ExecVisitorImplTemplateJson();
template <typename ValueType>
VectorPtr
ExecTermJsonVariableInField();
template <typename ValueType>
VectorPtr
ExecTermJsonFieldInVariable();
template <typename ValueType>
VectorPtr
ExecVisitorImplTemplateArray();
template <typename ValueType>
VectorPtr
ExecTermArrayVariableInField();
template <typename ValueType>
VectorPtr
ExecTermArrayFieldInVariable();
private:
std::shared_ptr<const milvus::expr::TermFilterExpr> expr_;
// If expr is like "pk in (..)", can use pk index to optimize
bool cached_offsets_inited_{false};
ColumnVectorPtr cached_offsets_;
FixedVector<bool> cached_bits_;
};
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,593 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "UnaryExpr.h"
#include "common/Json.h"
namespace milvus {
namespace exec {
void
PhyUnaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
switch (expr_->column_.data_type_) {
case DataType::BOOL: {
result = ExecRangeVisitorImpl<bool>();
break;
}
case DataType::INT8: {
result = ExecRangeVisitorImpl<int8_t>();
break;
}
case DataType::INT16: {
result = ExecRangeVisitorImpl<int16_t>();
break;
}
case DataType::INT32: {
result = ExecRangeVisitorImpl<int32_t>();
break;
}
case DataType::INT64: {
result = ExecRangeVisitorImpl<int64_t>();
break;
}
case DataType::FLOAT: {
result = ExecRangeVisitorImpl<float>();
break;
}
case DataType::DOUBLE: {
result = ExecRangeVisitorImpl<double>();
break;
}
case DataType::VARCHAR: {
if (segment_->type() == SegmentType::Growing) {
result = ExecRangeVisitorImpl<std::string>();
} else {
result = ExecRangeVisitorImpl<std::string_view>();
}
break;
}
case DataType::JSON: {
auto val_type = expr_->val_.val_case();
switch (val_type) {
case proto::plan::GenericValue::ValCase::kBoolVal:
result = ExecRangeVisitorImplJson<bool>();
break;
case proto::plan::GenericValue::ValCase::kInt64Val:
result = ExecRangeVisitorImplJson<int64_t>();
break;
case proto::plan::GenericValue::ValCase::kFloatVal:
result = ExecRangeVisitorImplJson<double>();
break;
case proto::plan::GenericValue::ValCase::kStringVal:
result = ExecRangeVisitorImplJson<std::string>();
break;
case proto::plan::GenericValue::ValCase::kArrayVal:
result = ExecRangeVisitorImplJson<proto::plan::Array>();
break;
default:
PanicInfo(DataTypeInvalid,
fmt::format("unknown data type: {}", val_type));
}
break;
}
case DataType::ARRAY: {
auto val_type = expr_->val_.val_case();
switch (val_type) {
case proto::plan::GenericValue::ValCase::kBoolVal:
result = ExecRangeVisitorImplArray<bool>();
break;
case proto::plan::GenericValue::ValCase::kInt64Val:
result = ExecRangeVisitorImplArray<int64_t>();
break;
case proto::plan::GenericValue::ValCase::kFloatVal:
result = ExecRangeVisitorImplArray<double>();
break;
case proto::plan::GenericValue::ValCase::kStringVal:
result = ExecRangeVisitorImplArray<std::string>();
break;
case proto::plan::GenericValue::ValCase::kArrayVal:
result = ExecRangeVisitorImplArray<proto::plan::Array>();
break;
default:
PanicInfo(DataTypeInvalid,
fmt::format("unknown data type: {}", val_type));
}
break;
}
default:
PanicInfo(DataTypeInvalid,
fmt::format("unsupported data type: {}",
expr_->column_.data_type_));
}
}
template <typename ValueType>
VectorPtr
PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArray() {
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
std::string_view,
ValueType>;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
ValueType val = GetValueFromProto<ValueType>(expr_->val_);
auto op_type = expr_->op_type_;
int index = -1;
if (expr_->column_.nested_path_.size() > 0) {
index = std::stoi(expr_->column_.nested_path_[0]);
}
auto execute_sub_batch = [op_type](const milvus::ArrayView* data,
const int size,
bool* res,
ValueType val,
int index) {
switch (op_type) {
case proto::plan::GreaterThan: {
UnaryElementFuncForArray<ValueType, proto::plan::GreaterThan>
func;
func(data, size, val, index, res);
break;
}
case proto::plan::GreaterEqual: {
UnaryElementFuncForArray<ValueType, proto::plan::GreaterEqual>
func;
func(data, size, val, index, res);
break;
}
case proto::plan::LessThan: {
UnaryElementFuncForArray<ValueType, proto::plan::LessThan> func;
func(data, size, val, index, res);
break;
}
case proto::plan::LessEqual: {
UnaryElementFuncForArray<ValueType, proto::plan::LessEqual>
func;
func(data, size, val, index, res);
break;
}
case proto::plan::Equal: {
UnaryElementFuncForArray<ValueType, proto::plan::Equal> func;
func(data, size, val, index, res);
break;
}
case proto::plan::NotEqual: {
UnaryElementFuncForArray<ValueType, proto::plan::NotEqual> func;
func(data, size, val, index, res);
break;
}
case proto::plan::PrefixMatch: {
UnaryElementFuncForArray<ValueType, proto::plan::PrefixMatch>
func;
func(data, size, val, index, res);
break;
}
default:
PanicInfo(
OpTypeInvalid,
fmt::format("unsupported operator type for unary expr: {}",
op_type));
}
};
int64_t processed_size = ProcessDataChunks<milvus::ArrayView>(
execute_sub_batch, std::nullptr_t{}, res, val, index);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
template <typename ExprValueType>
VectorPtr
PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() {
using GetType =
std::conditional_t<std::is_same_v<ExprValueType, std::string>,
std::string_view,
ExprValueType>;
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
ExprValueType val = GetValueFromProto<ExprValueType>(expr_->val_);
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
auto op_type = expr_->op_type_;
auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
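// The two macros below look up the JSON pointer as GetType; when an int64 was
// requested but the stored value is a double, they retry as double so numeric
// comparisons still work. The NotEqual variant treats a missing or mismatched
// value as "not equal" (true) rather than false.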
#define UnaryRangeJSONCompare(cmp) \
do { \
auto x = data[i].template at<GetType>(pointer); \
if (x.error()) { \
if constexpr (std::is_same_v<GetType, int64_t>) { \
auto x = data[i].template at<double>(pointer); \
res[i] = !x.error() && (cmp); \
break; \
} \
res[i] = false; \
break; \
} \
res[i] = (cmp); \
} while (false)
#define UnaryRangeJSONCompareNotEqual(cmp) \
do { \
auto x = data[i].template at<GetType>(pointer); \
if (x.error()) { \
if constexpr (std::is_same_v<GetType, int64_t>) { \
auto x = data[i].template at<double>(pointer); \
res[i] = x.error() || (cmp); \
break; \
} \
res[i] = true; \
break; \
} \
res[i] = (cmp); \
} while (false)
auto execute_sub_batch = [op_type, pointer](const milvus::Json* data,
const int size,
bool* res,
ExprValueType val) {
switch (op_type) {
case proto::plan::GreaterThan: {
for (size_t i = 0; i < size; ++i) {
if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
res[i] = false;
} else {
UnaryRangeJSONCompare(x.value() > val);
}
}
break;
}
case proto::plan::GreaterEqual: {
for (size_t i = 0; i < size; ++i) {
if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
res[i] = false;
} else {
UnaryRangeJSONCompare(x.value() >= val);
}
}
break;
}
case proto::plan::LessThan: {
for (size_t i = 0; i < size; ++i) {
if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
res[i] = false;
} else {
UnaryRangeJSONCompare(x.value() < val);
}
}
break;
}
case proto::plan::LessEqual: {
for (size_t i = 0; i < size; ++i) {
if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
res[i] = false;
} else {
UnaryRangeJSONCompare(x.value() <= val);
}
}
break;
}
case proto::plan::Equal: {
for (size_t i = 0; i < size; ++i) {
if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
auto doc = data[i].doc();
auto array = doc.at_pointer(pointer).get_array();
if (array.error()) {
res[i] = false;
continue;
}
res[i] = CompareTwoJsonArray(array, val);
} else {
UnaryRangeJSONCompare(x.value() == val);
}
}
break;
}
case proto::plan::NotEqual: {
for (size_t i = 0; i < size; ++i) {
if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
auto doc = data[i].doc();
auto array = doc.at_pointer(pointer).get_array();
if (array.error()) {
res[i] = false;
continue;
}
res[i] = !CompareTwoJsonArray(array, val);
} else {
UnaryRangeJSONCompareNotEqual(x.value() != val);
}
}
break;
}
case proto::plan::PrefixMatch: {
for (size_t i = 0; i < size; ++i) {
if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
res[i] = false;
} else {
UnaryRangeJSONCompare(milvus::query::Match(
ExprValueType(x.value()), val, op_type));
}
}
break;
}
default:
PanicInfo(
OpTypeInvalid,
fmt::format("unsupported operator type for unary expr: {}",
op_type));
}
};
int64_t processed_size = ProcessDataChunks<milvus::Json>(
execute_sub_batch, std::nullptr_t{}, res, val);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
template <typename T>
VectorPtr
PhyUnaryRangeFilterExpr::ExecRangeVisitorImpl() {
if (is_index_mode_) {
return ExecRangeVisitorImplForIndex<T>();
} else {
return ExecRangeVisitorImplForData<T>();
}
}
template <typename T>
VectorPtr
PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForIndex() {
typedef std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
IndexInnerType;
using Index = index::ScalarIndex<IndexInnerType>;
if (auto res = PreCheckOverflow<T>()) {
return res;
}
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto op_type = expr_->op_type_;
auto execute_sub_batch = [op_type](Index* index_ptr, IndexInnerType val) {
FixedVector<bool> res;
switch (op_type) {
case proto::plan::GreaterThan: {
UnaryIndexFunc<T, proto::plan::GreaterThan> func;
res = std::move(func(index_ptr, val));
break;
}
case proto::plan::GreaterEqual: {
UnaryIndexFunc<T, proto::plan::GreaterEqual> func;
res = std::move(func(index_ptr, val));
break;
}
case proto::plan::LessThan: {
UnaryIndexFunc<T, proto::plan::LessThan> func;
res = std::move(func(index_ptr, val));
break;
}
case proto::plan::LessEqual: {
UnaryIndexFunc<T, proto::plan::LessEqual> func;
res = std::move(func(index_ptr, val));
break;
}
case proto::plan::Equal: {
UnaryIndexFunc<T, proto::plan::Equal> func;
res = std::move(func(index_ptr, val));
break;
}
case proto::plan::NotEqual: {
UnaryIndexFunc<T, proto::plan::NotEqual> func;
res = std::move(func(index_ptr, val));
break;
}
case proto::plan::PrefixMatch: {
UnaryIndexFunc<T, proto::plan::PrefixMatch> func;
res = std::move(func(index_ptr, val));
break;
}
default:
PanicInfo(
OpTypeInvalid,
fmt::format("unsupported operator type for unary expr: {}",
op_type));
}
return res;
};
auto val = GetValueFromProto<IndexInnerType>(expr_->val_);
auto res = ProcessIndexChunks<T>(execute_sub_batch, val);
AssertInfo(res.size() == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
res.size(),
real_batch_size));
return std::make_shared<ColumnVector>(std::move(res));
}
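// If the comparison value overflows T's range the outcome is constant for the
// whole column (e.g. Equal -> all false, NotEqual -> all true), so the batch
// result is produced directly and cached instead of scanning the data.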
template <typename T>
ColumnVectorPtr
PhyUnaryRangeFilterExpr::PreCheckOverflow() {
if constexpr (std::is_integral_v<T> && !std::is_same_v<T, bool>) {
int64_t val = GetValueFromProto<int64_t>(expr_->val_);
if (milvus::query::out_of_range<T>(val)) {
int64_t batch_size = overflow_check_pos_ + batch_size_ >= num_rows_
? num_rows_ - overflow_check_pos_
: batch_size_;
overflow_check_pos_ += batch_size;
if (cached_overflow_res_ != nullptr &&
cached_overflow_res_->size() == batch_size) {
return cached_overflow_res_;
}
switch (expr_->op_type_) {
case proto::plan::GreaterThan:
case proto::plan::GreaterEqual: {
auto res_vec = std::make_shared<ColumnVector>(
DataType::BOOL, batch_size);
cached_overflow_res_ = res_vec;
bool* res = (bool*)res_vec->GetRawData();
if (milvus::query::lt_lb<T>(val)) {
for (size_t i = 0; i < batch_size; ++i) {
res[i] = true;
}
return res_vec;
}
return res_vec;
}
case proto::plan::LessThan:
case proto::plan::LessEqual: {
auto res_vec = std::make_shared<ColumnVector>(
DataType::BOOL, batch_size);
cached_overflow_res_ = res_vec;
bool* res = (bool*)res_vec->GetRawData();
if (milvus::query::gt_ub<T>(val)) {
for (size_t i = 0; i < batch_size; ++i) {
res[i] = true;
}
return res_vec;
}
return res_vec;
}
case proto::plan::Equal: {
auto res_vec = std::make_shared<ColumnVector>(
DataType::BOOL, batch_size);
cached_overflow_res_ = res_vec;
bool* res = (bool*)res_vec->GetRawData();
for (size_t i = 0; i < batch_size; ++i) {
res[i] = false;
}
return res_vec;
}
case proto::plan::NotEqual: {
auto res_vec = std::make_shared<ColumnVector>(
DataType::BOOL, batch_size);
cached_overflow_res_ = res_vec;
bool* res = (bool*)res_vec->GetRawData();
for (size_t i = 0; i < batch_size; ++i) {
res[i] = true;
}
return res_vec;
}
default: {
PanicInfo(OpTypeInvalid,
fmt::format("unsupported range node {}",
expr_->op_type_));
}
}
}
}
return nullptr;
}
template <typename T>
VectorPtr
PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForData() {
typedef std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
IndexInnerType;
if (auto res = PreCheckOverflow<T>()) {
return res;
}
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
IndexInnerType val = GetValueFromProto<IndexInnerType>(expr_->val_);
auto res_vec =
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
bool* res = (bool*)res_vec->GetRawData();
auto expr_type = expr_->op_type_;
auto execute_sub_batch = [expr_type](const T* data,
const int size,
bool* res,
IndexInnerType val) {
switch (expr_type) {
case proto::plan::GreaterThan: {
UnaryElementFunc<T, proto::plan::GreaterThan> func;
func(data, size, val, res);
break;
}
case proto::plan::GreaterEqual: {
UnaryElementFunc<T, proto::plan::GreaterEqual> func;
func(data, size, val, res);
break;
}
case proto::plan::LessThan: {
UnaryElementFunc<T, proto::plan::LessThan> func;
func(data, size, val, res);
break;
}
case proto::plan::LessEqual: {
UnaryElementFunc<T, proto::plan::LessEqual> func;
func(data, size, val, res);
break;
}
case proto::plan::Equal: {
UnaryElementFunc<T, proto::plan::Equal> func;
func(data, size, val, res);
break;
}
case proto::plan::NotEqual: {
UnaryElementFunc<T, proto::plan::NotEqual> func;
func(data, size, val, res);
break;
}
case proto::plan::PrefixMatch: {
UnaryElementFunc<T, proto::plan::PrefixMatch> func;
func(data, size, val, res);
break;
}
default:
PanicInfo(
OpTypeInvalid,
fmt::format("unsupported operator type for unary expr: {}",
expr_type));
}
};
auto skip_index_func = [expr_type, val](const SkipIndex& skip_index,
FieldId field_id,
int64_t chunk_id) {
return skip_index.CanSkipUnaryRange<T>(
field_id, chunk_id, expr_type, val);
};
int64_t processed_size =
ProcessDataChunks<T>(execute_sub_batch, skip_index_func, res, val);
AssertInfo(processed_size == real_batch_size,
fmt::format("internal error: expr processed rows {} not equal "
"expect batch size {}",
processed_size,
real_batch_size));
return res_vec;
}
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,220 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fmt/core.h>
#include "common/EasyAssert.h"
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/expression/Expr.h"
#include "index/Meta.h"
#include "segcore/SegmentInterface.h"
#include "query/Utils.h"
namespace milvus {
namespace exec {
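// Element-wise comparison functor for raw column data; the operator is a
// compile-time template parameter so each instantiation compiles to a tight
// loop without a per-row switch.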
template <typename T, proto::plan::OpType op>
struct UnaryElementFunc {
typedef std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
IndexInnerType;
void
operator()(const T* src, size_t size, IndexInnerType val, bool* res) {
for (int i = 0; i < size; ++i) {
if constexpr (op == proto::plan::OpType::Equal) {
res[i] = src[i] == val;
} else if constexpr (op == proto::plan::OpType::NotEqual) {
res[i] = src[i] != val;
} else if constexpr (op == proto::plan::OpType::GreaterThan) {
res[i] = src[i] > val;
} else if constexpr (op == proto::plan::OpType::LessThan) {
res[i] = src[i] < val;
} else if constexpr (op == proto::plan::OpType::GreaterEqual) {
res[i] = src[i] >= val;
} else if constexpr (op == proto::plan::OpType::LessEqual) {
res[i] = src[i] <= val;
} else if constexpr (op == proto::plan::OpType::PrefixMatch) {
res[i] = milvus::query::Match(
src[i], val, proto::plan::OpType::PrefixMatch);
} else {
PanicInfo(
OpTypeInvalid,
fmt::format("unsupported op_type:{} for UnaryElementFunc",
op));
}
}
}
};
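// UnaryArrayCompare compares the array element at `index` with the target
// value; rows whose array is shorter than `index` evaluate to false, and
// whole-array targets (proto::plan::Array) are only meaningful for
// Equal/NotEqual, so the other operators yield false.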
#define UnaryArrayCompare(cmp) \
do { \
if constexpr (std::is_same_v<GetType, proto::plan::Array>) { \
res[i] = false; \
} else { \
if (index >= src[i].length()) { \
res[i] = false; \
continue; \
} \
auto array_data = src[i].template get_data<GetType>(index); \
res[i] = (cmp); \
} \
} while (false)
template <typename ValueType, proto::plan::OpType op>
struct UnaryElementFuncForArray {
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
std::string_view,
ValueType>;
void
operator()(const ArrayView* src,
size_t size,
ValueType val,
int index,
bool* res) {
for (int i = 0; i < size; ++i) {
if constexpr (op == proto::plan::OpType::Equal) {
if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
res[i] = src[i].is_same_array(val);
} else {
if (index >= src[i].length()) {
res[i] = false;
continue;
}
auto array_data = src[i].template get_data<GetType>(index);
res[i] = array_data == val;
}
} else if constexpr (op == proto::plan::OpType::NotEqual) {
if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
res[i] = !src[i].is_same_array(val);
} else {
if (index >= src[i].length()) {
res[i] = false;
continue;
}
auto array_data = src[i].template get_data<GetType>(index);
res[i] = array_data != val;
}
} else if constexpr (op == proto::plan::OpType::GreaterThan) {
UnaryArrayCompare(array_data > val);
} else if constexpr (op == proto::plan::OpType::LessThan) {
UnaryArrayCompare(array_data < val);
} else if constexpr (op == proto::plan::OpType::GreaterEqual) {
UnaryArrayCompare(array_data >= val);
} else if constexpr (op == proto::plan::OpType::LessEqual) {
UnaryArrayCompare(array_data <= val);
} else if constexpr (op == proto::plan::OpType::PrefixMatch) {
UnaryArrayCompare(milvus::query::Match(array_data, val, op));
} else {
PanicInfo(OpTypeInvalid,
fmt::format("unsupported op_type:{} for "
"UnaryElementFuncForArray",
op));
}
}
}
};
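// Index-side counterpart of UnaryElementFunc: translates the operator into
// the matching ScalarIndex call (In/NotIn/Range, or a prefix-match Query).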
template <typename T, proto::plan::OpType op>
struct UnaryIndexFunc {
typedef std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
IndexInnerType;
using Index = index::ScalarIndex<IndexInnerType>;
FixedVector<bool>
operator()(Index* index, IndexInnerType val) {
if constexpr (op == proto::plan::OpType::Equal) {
return index->In(1, &val);
} else if constexpr (op == proto::plan::OpType::NotEqual) {
return index->NotIn(1, &val);
} else if constexpr (op == proto::plan::OpType::GreaterThan) {
return index->Range(val, OpType::GreaterThan);
} else if constexpr (op == proto::plan::OpType::LessThan) {
return index->Range(val, OpType::LessThan);
} else if constexpr (op == proto::plan::OpType::GreaterEqual) {
return index->Range(val, OpType::GreaterEqual);
} else if constexpr (op == proto::plan::OpType::LessEqual) {
return index->Range(val, OpType::LessEqual);
} else if constexpr (op == proto::plan::OpType::PrefixMatch) {
auto dataset = std::make_unique<Dataset>();
dataset->Set(milvus::index::OPERATOR_TYPE,
proto::plan::OpType::PrefixMatch);
dataset->Set(milvus::index::PREFIX_VALUE, val);
return index->Query(std::move(dataset));
} else {
PanicInfo(
OpTypeInvalid,
fmt::format("unsupported op_type:{} for UnaryIndexFunc", op));
}
}
};
class PhyUnaryRangeFilterExpr : public SegmentExpr {
public:
PhyUnaryRangeFilterExpr(
const std::vector<std::shared_ptr<Expr>>& input,
const std::shared_ptr<const milvus::expr::UnaryRangeFilterExpr>& expr,
const std::string& name,
const segcore::SegmentInternalInterface* segment,
Timestamp query_timestamp,
int64_t batch_size)
: SegmentExpr(std::move(input),
name,
segment,
expr->column_.field_id_,
query_timestamp,
batch_size),
expr_(expr) {
}
void
Eval(EvalCtx& context, VectorPtr& result) override;
private:
template <typename T>
VectorPtr
ExecRangeVisitorImpl();
template <typename T>
VectorPtr
ExecRangeVisitorImplForIndex();
template <typename T>
VectorPtr
ExecRangeVisitorImplForData();
template <typename ExprValueType>
VectorPtr
ExecRangeVisitorImplJson();
template <typename ExprValueType>
VectorPtr
ExecRangeVisitorImplArray();
// Check overflow and cache the result for performance
template <typename T>
ColumnVectorPtr
PreCheckOverflow();
private:
std::shared_ptr<const milvus::expr::UnaryRangeFilterExpr> expr_;
ColumnVectorPtr cached_overflow_res_{nullptr};
int64_t overflow_check_pos_{0};
};
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,166 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fmt/core.h>
#include "common/EasyAssert.h"
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/expression/Expr.h"
#include "segcore/SegmentInterface.h"
#include "query/Utils.h"
namespace milvus {
namespace exec {
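// Extract the boolean ColumnVector from an expression result: either the
// result itself, or the first child when the result is a RowVector (used when
// matched offsets are returned alongside the bitmap).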
static ColumnVectorPtr
GetColumnVector(const VectorPtr& result) {
ColumnVectorPtr res;
if (auto convert_vector = std::dynamic_pointer_cast<ColumnVector>(result)) {
res = convert_vector;
} else if (auto convert_vector =
std::dynamic_pointer_cast<RowVector>(result)) {
if (auto convert_flat_vector = std::dynamic_pointer_cast<ColumnVector>(
convert_vector->child(0))) {
res = convert_flat_vector;
} else {
PanicInfo(
UnexpectedError,
"RowVector result must have a first ColumnVector children");
}
} else {
PanicInfo(UnexpectedError,
"expr result must have a ColumnVector or RowVector result");
}
return res;
}
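// Compare a simdjson array with a plan::Array literal element by element;
// lengths and every element's type and value must match.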
template <typename T>
bool
CompareTwoJsonArray(T arr1, const proto::plan::Array& arr2) {
int json_array_length = 0;
if constexpr (std::is_same_v<
T,
simdjson::simdjson_result<simdjson::ondemand::array>>) {
json_array_length = arr1.count_elements();
}
if constexpr (std::is_same_v<T,
std::vector<simdjson::simdjson_result<
simdjson::ondemand::value>>>) {
json_array_length = arr1.size();
}
if (arr2.array_size() != json_array_length) {
return false;
}
int i = 0;
for (auto&& it : arr1) {
switch (arr2.array(i).val_case()) {
case proto::plan::GenericValue::kBoolVal: {
auto val = it.template get<bool>();
if (val.error() || val.value() != arr2.array(i).bool_val()) {
return false;
}
break;
}
case proto::plan::GenericValue::kInt64Val: {
auto val = it.template get<int64_t>();
if (val.error() || val.value() != arr2.array(i).int64_val()) {
return false;
}
break;
}
case proto::plan::GenericValue::kFloatVal: {
auto val = it.template get<double>();
if (val.error() || val.value() != arr2.array(i).float_val()) {
return false;
}
break;
}
case proto::plan::GenericValue::kStringVal: {
auto val = it.template get<std::string_view>();
if (val.error() || val.value() != arr2.array(i).string_val()) {
return false;
}
break;
}
default:
PanicInfo(DataTypeInvalid,
fmt::format("unsupported data type {}",
arr2.array(i).val_case()));
}
i++;
}
return true;
}
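// Convert a plan::GenericValue into the requested C++ type; integral
// conversions report overflow through the `overflowed` out-parameter instead
// of throwing.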
template <typename T>
T
GetValueFromProtoInternal(const milvus::proto::plan::GenericValue& value_proto,
bool& overflowed) {
if constexpr (std::is_same_v<T, bool>) {
Assert(value_proto.val_case() ==
milvus::proto::plan::GenericValue::kBoolVal);
return static_cast<T>(value_proto.bool_val());
} else if constexpr (std::is_integral_v<T>) {
Assert(value_proto.val_case() ==
milvus::proto::plan::GenericValue::kInt64Val);
auto val = value_proto.int64_val();
if (milvus::query::out_of_range<T>(val)) {
overflowed = true;
return T();
} else {
return static_cast<T>(val);
}
} else if constexpr (std::is_floating_point_v<T>) {
Assert(value_proto.val_case() ==
milvus::proto::plan::GenericValue::kFloatVal);
return static_cast<T>(value_proto.float_val());
} else if constexpr (std::is_same_v<T, std::string> ||
std::is_same_v<T, std::string_view>) {
Assert(value_proto.val_case() ==
milvus::proto::plan::GenericValue::kStringVal);
return static_cast<T>(value_proto.string_val());
} else if constexpr (std::is_same_v<T, proto::plan::Array>) {
Assert(value_proto.val_case() ==
milvus::proto::plan::GenericValue::kArrayVal);
return static_cast<T>(value_proto.array_val());
} else if constexpr (std::is_same_v<T, milvus::proto::plan::GenericValue>) {
return static_cast<T>(value_proto);
} else {
PanicInfo(Unsupported,
fmt::format("unsupported generic value {}",
value_proto.DebugString()));
}
}
template <typename T>
T
GetValueFromProto(const milvus::proto::plan::GenericValue& value_proto) {
bool dummy_overflowed = false;
return GetValueFromProtoInternal<T>(value_proto, dummy_overflowed);
}
template <typename T>
T
GetValueFromProtoWithOverflow(
const milvus::proto::plan::GenericValue& value_proto, bool& overflowed) {
return GetValueFromProtoInternal<T>(value_proto, overflowed);
}
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,47 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "common/Vector.h"
#include "exec/QueryContext.h"
namespace milvus {
namespace exec {
class VectorFunction {
public:
virtual ~VectorFunction() = default;
virtual void
Apply(std::vector<VectorPtr>& args,
DataType output_type,
EvalCtx& context,
VectorPtr& result) const = 0;
};
std::shared_ptr<VectorFunction>
GetVectorFunction(const std::string& name,
const std::vector<DataType>& input_types,
const QueryConfig& config);
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,89 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "exec/operator/Operator.h"
namespace milvus {
namespace exec {
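// Sink operator that forwards every input batch to a user-supplied callback;
// the callback may report a blocking reason, which is surfaced through
// IsBlocked(), and it is invoked once more with nullptr on Close().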
class CallbackSink : public Operator {
public:
CallbackSink(
int32_t operator_id,
DriverContext* ctx,
std::function<BlockingReason(RowVectorPtr, ContinueFuture*)> callback)
: Operator(ctx, DataType::NONE, operator_id, "N/A", "CallbackSink"),
callback_(callback) {
}
void
AddInput(RowVectorPtr& input) override {
blocking_reason_ = callback_(input, &future_);
}
RowVectorPtr
GetOutput() override {
return nullptr;
}
void
NoMoreInput() override {
Operator::NoMoreInput();
Close();
}
bool
NeedInput() const override {
return callback_ != nullptr;
}
bool
IsFilter() override {
return false;
}
bool
IsFinished() override {
return no_more_input_;
}
BlockingReason
IsBlocked(ContinueFuture* future) override {
if (blocking_reason_ != BlockingReason::kNotBlocked) {
*future = std::move(future_);
blocking_reason_ = BlockingReason::kNotBlocked;
return BlockingReason::kWaitForConsumer;
}
return BlockingReason::kNotBlocked;
}
private:
void
Close() override {
if (callback_) {
callback_(nullptr, nullptr);
callback_ = nullptr;
}
}
ContinueFuture future_;
BlockingReason blocking_reason_{BlockingReason::kNotBlocked};
std::function<BlockingReason(RowVectorPtr, ContinueFuture*)> callback_;
};
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,83 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "FilterBits.h"
namespace milvus {
namespace exec {
FilterBits::FilterBits(
int32_t operator_id,
DriverContext* driverctx,
const std::shared_ptr<const plan::FilterBitsNode>& filter)
: Operator(driverctx,
filter->output_type(),
operator_id,
filter->id(),
"FilterBits") {
ExecContext* exec_context = operator_context_->get_exec_context();
QueryContext* query_context = exec_context->get_query_context();
std::vector<expr::TypedExprPtr> filters;
filters.emplace_back(filter->filter());
exprs_ = std::make_unique<ExprSet>(filters, exec_context);
need_process_rows_ = query_context->get_segment()->get_active_count(
query_context->get_query_timestamp());
num_processed_rows_ = 0;
}
void
FilterBits::AddInput(RowVectorPtr& input) {
input_ = std::move(input);
}
bool
FilterBits::AllInputProcessed() {
if (num_processed_rows_ == need_process_rows_) {
input_ = nullptr;
return true;
}
return false;
}
bool
FilterBits::IsFinished() {
return AllInputProcessed();
}
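// Evaluate the filter expression set for the next chunk and track how many
// rows have been processed; a RowVector result means the bitmap is
// accompanied by matched offsets, so the row count comes from its first child.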
RowVectorPtr
FilterBits::GetOutput() {
if (AllInputProcessed()) {
return nullptr;
}
EvalCtx eval_ctx(
operator_context_->get_exec_context(), exprs_.get(), input_.get());
exprs_->Eval(0, 1, true, eval_ctx, results_);
AssertInfo(results_.size() == 1 && results_[0] != nullptr,
"FilterBits result size should be one and not be nullptr");
if (results_[0]->type() == DataType::ROW) {
auto row_vec = std::dynamic_pointer_cast<RowVector>(results_[0]);
num_processed_rows_ += row_vec->child(0)->size();
} else {
num_processed_rows_ += results_[0]->size();
}
return std::make_shared<RowVector>(results_);
}
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,74 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "exec/Driver.h"
#include "exec/expression/Expr.h"
#include "exec/operator/Operator.h"
#include "exec/QueryContext.h"
namespace milvus {
namespace exec {
class FilterBits : public Operator {
public:
FilterBits(int32_t operator_id,
DriverContext* ctx,
const std::shared_ptr<const plan::FilterBitsNode>& filter);
bool
IsFilter() override {
return true;
}
bool
NeedInput() const override {
return !input_;
}
void
AddInput(RowVectorPtr& input) override;
RowVectorPtr
GetOutput() override;
bool
IsFinished() override;
void
Close() override {
Operator::Close();
exprs_->Clear();
}
BlockingReason
IsBlocked(ContinueFuture* /* unused */) override {
return BlockingReason::kNotBlocked;
}
bool
AllInputProcessed();
private:
std::unique_ptr<ExprSet> exprs_;
int64_t num_processed_rows_;
int64_t need_process_rows_;
};
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,21 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "Operator.h"
namespace milvus {
namespace exec {}
} // namespace milvus

View File

@ -0,0 +1,197 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/Driver.h"
#include "exec/Task.h"
#include "exec/QueryContext.h"
#include "plan/PlanNode.h"
namespace milvus {
namespace exec {
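// Per-operator context: remembers the owning driver, plan node id and
// operator id, and lazily builds an ExecContext from the task's query context
// on first use.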
class OperatorContext {
public:
OperatorContext(DriverContext* driverCtx,
const plan::PlanNodeId& plannodeid,
int32_t operator_id,
const std::string& operator_type = "")
: driver_context_(driverCtx),
plannode_id_(plannodeid),
operator_id_(operator_id),
operator_type_(operator_type) {
}
ExecContext*
get_exec_context() const {
if (!exec_context_) {
exec_context_ = std::make_unique<ExecContext>(
driver_context_->task_->query_context().get());
}
return exec_context_.get();
}
const std::shared_ptr<Task>&
get_task() const {
return driver_context_->task_;
}
const std::string&
get_task_id() const {
return driver_context_->task_->taskid();
}
DriverContext*
get_driver_context() const {
return driver_context_;
}
const plan::PlanNodeId&
get_plannode_id() const {
return plannode_id_;
}
const std::string&
get_operator_type() const {
return operator_type_;
}
const int32_t
get_operator_id() const {
return operator_id_;
}
private:
DriverContext* driver_context_;
plan::PlanNodeId plannode_id_;
int32_t operator_id_;
std::string operator_type_;
mutable std::unique_ptr<ExecContext> exec_context_;
};
class Operator {
public:
Operator(DriverContext* ctx,
DataType output_type,
int32_t operator_id,
const std::string& plannode_id,
const std::string& operator_type = "")
: operator_context_(std::make_unique<OperatorContext>(
ctx, plannode_id, operator_id, operator_type)) {
}
virtual ~Operator() = default;
virtual bool
NeedInput() const = 0;
virtual void
AddInput(RowVectorPtr& input) = 0;
virtual void
NoMoreInput() {
no_more_input_ = true;
}
virtual RowVectorPtr
GetOutput() = 0;
virtual bool
IsFinished() = 0;
virtual bool
IsFilter() = 0;
virtual BlockingReason
IsBlocked(ContinueFuture* future) = 0;
virtual void
Close() {
input_ = nullptr;
results_.clear();
}
virtual bool
PreserveOrder() const {
return false;
}
const std::string&
get_operator_type() const {
return operator_context_->get_operator_type();
}
const int32_t
get_operator_id() const {
return operator_context_->get_operator_id();
}
const plan::PlanNodeId&
get_plannode_id() const {
return operator_context_->get_plannode_id();
}
protected:
std::unique_ptr<OperatorContext> operator_context_;
DataType output_type_;
RowVectorPtr input_;
bool no_more_input_{false};
std::vector<VectorPtr> results_;
};
class SourceOperator : public Operator {
public:
SourceOperator(DriverContext* driver_ctx,
DataType out_type,
int32_t operator_id,
const std::string& plannode_id,
const std::string& operator_type)
: Operator(
driver_ctx, out_type, operator_id, plannode_id, operator_type) {
}
bool
NeedInput() const override {
return false;
}
void
AddInput(RowVectorPtr& /* unused */) override {
throw NotImplementedException(
"SourceOperator does not support addInput()");
}
void
NoMoreInput() override {
throw NotImplementedException(
"SourceOperator does not support noMoreInput()");
}
};
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,557 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include <fmt/core.h>
#include "common/Schema.h"
#include "common/Types.h"
#include "pb/plan.pb.h"
namespace milvus {
namespace expr {
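// Lightweight description of a referenced column: field id, data type and the
// nested path used for JSON/array element access.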
struct ColumnInfo {
FieldId field_id_;
DataType data_type_;
std::vector<std::string> nested_path_;
ColumnInfo(const proto::plan::ColumnInfo& column_info)
: field_id_(column_info.field_id()),
data_type_(static_cast<DataType>(column_info.data_type())),
nested_path_(column_info.nested_path().begin(),
column_info.nested_path().end()) {
}
ColumnInfo(FieldId field_id,
DataType data_type,
std::vector<std::string> nested_path = {})
: field_id_(field_id),
data_type_(data_type),
nested_path_(std::move(nested_path)) {
}
bool
operator==(const ColumnInfo& other) {
if (field_id_ != other.field_id_) {
return false;
}
if (data_type_ != other.data_type_) {
return false;
}
for (int i = 0; i < nested_path_.size(); ++i) {
if (nested_path_[i] != other.nested_path_[i]) {
return false;
}
}
return true;
}
std::string
ToString() const {
return fmt::format("[FieldId:{}, data_type:{}, nested_path:{}]",
std::to_string(field_id_.get()),
data_type_,
milvus::Join(nested_path_, ","));
}
};
/**
* @brief Base class for all exprs
* a strongly-typed expression, such as literal, function call, etc...
*/
class ITypeExpr {
public:
explicit ITypeExpr(DataType type) : type_(type), inputs_{} {
}
ITypeExpr(DataType type,
std::vector<std::shared_ptr<const ITypeExpr>> inputs)
: type_(type), inputs_{std::move(inputs)} {
}
virtual ~ITypeExpr() = default;
const std::vector<std::shared_ptr<const ITypeExpr>>&
inputs() const {
return inputs_;
}
DataType
type() const {
return type_;
}
virtual std::string
ToString() const = 0;
const std::vector<std::shared_ptr<const ITypeExpr>>&
inputs() {
return inputs_;
}
protected:
DataType type_;
std::vector<std::shared_ptr<const ITypeExpr>> inputs_;
};
using TypedExprPtr = std::shared_ptr<const ITypeExpr>;
class InputTypeExpr : public ITypeExpr {
public:
InputTypeExpr(DataType type) : ITypeExpr(type) {
}
std::string
ToString() const override {
return "ROW";
}
};
using InputTypeExprPtr = std::shared_ptr<const InputTypeExpr>;
class CallTypeExpr : public ITypeExpr {
public:
    CallTypeExpr(DataType type,
                 const std::vector<TypedExprPtr>& inputs,
                 std::string fun_name)
        : ITypeExpr{type, inputs}, name_(std::move(fun_name)) {
    }
virtual ~CallTypeExpr() = default;
virtual const std::string&
name() const {
return name_;
}
std::string
ToString() const override {
std::string str{};
str += name();
str += "(";
for (size_t i = 0; i < inputs_.size(); ++i) {
if (i != 0) {
str += ",";
}
str += inputs_[i]->ToString();
}
str += ")";
return str;
}
private:
std::string name_;
};
using CallTypeExprPtr = std::shared_ptr<const CallTypeExpr>;
class FieldAccessTypeExpr : public ITypeExpr {
public:
FieldAccessTypeExpr(DataType type, const std::string& name)
: ITypeExpr{type}, name_(name), is_input_column_(true) {
}
FieldAccessTypeExpr(DataType type,
const TypedExprPtr& input,
const std::string& name)
: ITypeExpr{type, {std::move(input)}}, name_(name) {
is_input_column_ =
dynamic_cast<const InputTypeExpr*>(inputs_[0].get()) != nullptr;
}
bool
is_input_column() const {
return is_input_column_;
}
std::string
ToString() const override {
if (inputs_.empty()) {
return fmt::format("{}", name_);
}
return fmt::format("{}[{}]", inputs_[0]->ToString(), name_);
}
private:
std::string name_;
bool is_input_column_;
};
using FieldAccessTypeExprPtr = std::shared_ptr<const FieldAccessTypeExpr>;
/**
* @brief Base class for all milvus filter exprs, output type must be BOOL
* a strongly-typed expression, such as literal, function call, etc...
*/
class ITypeFilterExpr : public ITypeExpr {
public:
ITypeFilterExpr() : ITypeExpr(DataType::BOOL) {
}
ITypeFilterExpr(std::vector<std::shared_ptr<const ITypeExpr>> inputs)
: ITypeExpr(DataType::BOOL, std::move(inputs)) {
}
virtual ~ITypeFilterExpr() = default;
};
class UnaryRangeFilterExpr : public ITypeFilterExpr {
public:
explicit UnaryRangeFilterExpr(const ColumnInfo& column,
proto::plan::OpType op_type,
const proto::plan::GenericValue& val)
: ITypeFilterExpr(), column_(column), op_type_(op_type), val_(val) {
}
std::string
ToString() const override {
std::stringstream ss;
ss << "UnaryRangeFilterExpr: {columnInfo:" << column_.ToString()
<< " op_type:" << milvus::proto::plan::OpType_Name(op_type_)
<< " val:" << val_.DebugString() << "}";
return ss.str();
}
public:
const ColumnInfo column_;
const proto::plan::OpType op_type_;
const proto::plan::GenericValue val_;
};
class AlwaysTrueExpr : public ITypeFilterExpr {
public:
explicit AlwaysTrueExpr() {
}
std::string
ToString() const override {
return "AlwaysTrue expr";
}
};
class ExistsExpr : public ITypeFilterExpr {
public:
explicit ExistsExpr(const ColumnInfo& column)
: ITypeFilterExpr(), column_(column) {
}
std::string
ToString() const override {
return "{Exists Expression - Column: " + column_.ToString() + "}";
}
const ColumnInfo column_;
};
class LogicalUnaryExpr : public ITypeFilterExpr {
public:
enum class OpType { Invalid = 0, LogicalNot = 1 };
explicit LogicalUnaryExpr(const OpType op_type, const TypedExprPtr& child)
: op_type_(op_type) {
inputs_.emplace_back(child);
}
std::string
ToString() const override {
std::string opTypeString;
switch (op_type_) {
case OpType::LogicalNot:
opTypeString = "Logical NOT";
break;
default:
opTypeString = "Invalid Operator";
break;
}
return fmt::format("LogicalUnaryExpr:[{} - Child: {}]",
opTypeString,
inputs_[0]->ToString());
}
const OpType op_type_;
};
class TermFilterExpr : public ITypeFilterExpr {
public:
explicit TermFilterExpr(const ColumnInfo& column,
const std::vector<proto::plan::GenericValue>& vals,
bool is_in_field = false)
: ITypeFilterExpr(),
column_(column),
vals_(vals),
is_in_field_(is_in_field) {
}
std::string
ToString() const override {
std::string values;
for (const auto& val : vals_) {
values += val.DebugString() + ", ";
}
std::stringstream ss;
ss << "TermFilterExpr:[Column: " << column_.ToString() << ", Values: ["
<< values << "]"
<< ", Is In Field: " << (is_in_field_ ? "true" : "false") << "]";
return ss.str();
}
public:
const ColumnInfo column_;
const std::vector<proto::plan::GenericValue> vals_;
const bool is_in_field_;
};
class LogicalBinaryExpr : public ITypeFilterExpr {
public:
enum class OpType { Invalid = 0, And = 1, Or = 2 };
explicit LogicalBinaryExpr(OpType op_type,
const TypedExprPtr& left,
const TypedExprPtr& right)
: ITypeFilterExpr(), op_type_(op_type) {
inputs_.emplace_back(left);
inputs_.emplace_back(right);
}
std::string
GetOpTypeString() const {
switch (op_type_) {
case OpType::Invalid:
return "Invalid";
case OpType::And:
return "And";
case OpType::Or:
return "Or";
default:
return "Unknown"; // Handle the default case if necessary
}
}
std::string
ToString() const override {
return fmt::format("LogicalBinaryExpr:[{} - Left: {}, Right: {}]",
GetOpTypeString(),
inputs_[0]->ToString(),
inputs_[1]->ToString());
}
std::string
name() const {
return GetOpTypeString();
}
public:
const OpType op_type_;
};
class BinaryRangeFilterExpr : public ITypeFilterExpr {
public:
BinaryRangeFilterExpr(const ColumnInfo& column,
const proto::plan::GenericValue& lower_value,
const proto::plan::GenericValue& upper_value,
bool lower_inclusive,
bool upper_inclusive)
: ITypeFilterExpr(),
column_(column),
lower_val_(lower_value),
upper_val_(upper_value),
lower_inclusive_(lower_inclusive),
upper_inclusive_(upper_inclusive) {
}
std::string
ToString() const override {
std::stringstream ss;
ss << "BinaryRangeFilterExpr:[Column: " << column_.ToString()
<< ", Lower Value: " << lower_val_.DebugString()
<< ", Upper Value: " << upper_val_.DebugString()
<< ", Lower Inclusive: " << (lower_inclusive_ ? "true" : "false")
<< ", Upper Inclusive: " << (upper_inclusive_ ? "true" : "false")
<< "]";
return ss.str();
}
const ColumnInfo column_;
const proto::plan::GenericValue lower_val_;
const proto::plan::GenericValue upper_val_;
const bool lower_inclusive_;
const bool upper_inclusive_;
};
class BinaryArithOpEvalRangeExpr : public ITypeFilterExpr {
public:
BinaryArithOpEvalRangeExpr(const ColumnInfo& column,
const proto::plan::OpType op_type,
const proto::plan::ArithOpType arith_op_type,
const proto::plan::GenericValue value,
const proto::plan::GenericValue right_operand)
: column_(column),
op_type_(op_type),
arith_op_type_(arith_op_type),
right_operand_(right_operand),
value_(value) {
}
std::string
ToString() const override {
std::stringstream ss;
ss << "BinaryArithOpEvalRangeExpr:[Column: " << column_.ToString()
<< ", Operator Type: " << milvus::proto::plan::OpType_Name(op_type_)
<< ", Arith Operator Type: "
<< milvus::proto::plan::ArithOpType_Name(arith_op_type_)
<< ", Value: " << value_.DebugString()
<< ", Right Operand: " << right_operand_.DebugString() << "]";
return ss.str();
}
public:
const ColumnInfo column_;
const proto::plan::OpType op_type_;
const proto::plan::ArithOpType arith_op_type_;
const proto::plan::GenericValue right_operand_;
const proto::plan::GenericValue value_;
};
class CompareExpr : public ITypeFilterExpr {
public:
CompareExpr(const FieldId& left_field,
const FieldId& right_field,
DataType left_data_type,
DataType right_data_type,
proto::plan::OpType op_type)
: left_field_id_(left_field),
right_field_id_(right_field),
left_data_type_(left_data_type),
right_data_type_(right_data_type),
op_type_(op_type) {
}
    std::string
    ToString() const override {
        return fmt::format(
            "CompareExpr:[Left Field ID: {}, Right Field ID: {}, "
            "Left Data Type: {}, Right Data Type: {}, Operator: {}]",
            left_field_id_.get(),
            right_field_id_.get(),
            left_data_type_,
            right_data_type_,
            milvus::proto::plan::OpType_Name(op_type_));
    }
public:
const FieldId left_field_id_;
const FieldId right_field_id_;
const DataType left_data_type_;
const DataType right_data_type_;
const proto::plan::OpType op_type_;
};
class JsonContainsExpr : public ITypeFilterExpr {
public:
JsonContainsExpr(ColumnInfo column,
ContainsType op,
const bool same_type,
const std::vector<proto::plan::GenericValue>& vals)
: column_(column),
op_(op),
same_type_(same_type),
vals_(std::move(vals)) {
}
std::string
ToString() const override {
std::string values;
for (const auto& val : vals_) {
values += val.DebugString() + ", ";
}
return fmt::format(
"JsonContainsExpr:[Column: {}, Operator: {}, Same Type: {}, "
"Values: [{}]]",
column_.ToString(),
JSONContainsExpr_JSONOp_Name(op_),
(same_type_ ? "true" : "false"),
values);
}
public:
const ColumnInfo column_;
ContainsType op_;
bool same_type_;
const std::vector<proto::plan::GenericValue> vals_;
};
} // namespace expr
} // namespace milvus
template <>
struct fmt::formatter<milvus::proto::plan::ArithOpType>
: formatter<string_view> {
auto
format(milvus::proto::plan::ArithOpType c, format_context& ctx) const {
using namespace milvus::proto::plan;
string_view name = "unknown";
switch (c) {
case ArithOpType::Unknown:
name = "Unknown";
break;
case ArithOpType::Add:
name = "Add";
break;
case ArithOpType::Sub:
name = "Sub";
break;
case ArithOpType::Mul:
name = "Mul";
break;
case ArithOpType::Div:
name = "Div";
break;
case ArithOpType::Mod:
name = "Mod";
break;
case ArithOpType::ArrayLength:
name = "ArrayLength";
break;
case ArithOpType::ArithOpType_INT_MIN_SENTINEL_DO_NOT_USE_:
name = "ArithOpType_INT_MIN_SENTINEL_DO_NOT_USE_";
break;
case ArithOpType::ArithOpType_INT_MAX_SENTINEL_DO_NOT_USE_:
name = "ArithOpType_INT_MAX_SENTINEL_DO_NOT_USE_";
break;
}
return formatter<string_view>::format(name, ctx);
}
};
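As a quick illustration of how these typed exprs compose, here is a hedged sketch; the field ids, data types, and GenericValue setters below are illustrative assumptions, not taken from this change.
// Build the filter (age > 18) AND (city IN {"SH"}) out of the classes above.
milvus::expr::TypedExprPtr
BuildExampleFilter() {
    using namespace milvus;
    proto::plan::GenericValue lower;
    lower.set_int64_val(18);
    auto age_gt = std::make_shared<expr::UnaryRangeFilterExpr>(
        expr::ColumnInfo(FieldId(100), DataType::INT64),  // hypothetical field
        proto::plan::OpType::GreaterThan,
        lower);
    proto::plan::GenericValue city;
    city.set_string_val("SH");
    auto city_in = std::make_shared<expr::TermFilterExpr>(
        expr::ColumnInfo(FieldId(101), DataType::VARCHAR),  // hypothetical field
        std::vector<proto::plan::GenericValue>{city});
    return std::make_shared<expr::LogicalBinaryExpr>(
        expr::LogicalBinaryExpr::OpType::And, age_gt, city_in);
}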

View File

@ -13,7 +13,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "common/Types.h"

View File

@ -68,7 +68,7 @@ ScalarIndexSort<T>::BuildV2(const Config& config) {
PanicInfo(S3Error, "failed to create scan iterator");
}
auto reader = res.value();
std::vector<storage::FieldDataPtr> field_datas;
std::vector<FieldDataPtr> field_datas;
for (auto rec = reader->Next(); rec != nullptr; rec = reader->Next()) {
if (!rec.ok()) {
PanicInfo(DataFormatBroken, "failed to read data");
@ -280,7 +280,7 @@ ScalarIndexSort<T>::LoadV2(const Config& config) {
index_files.push_back(b.name);
}
}
std::map<std::string, storage::FieldDataPtr> index_datas{};
std::map<std::string, FieldDataPtr> index_datas{};
for (auto& file_name : index_files) {
auto res = space_->GetBlobByteSize(file_name);
if (!res.ok()) {

View File

@ -24,11 +24,12 @@
#include "common/Types.h"
#include "common/EasyAssert.h"
#include "common/Exception.h"
#include "common/Utils.h"
#include "common/Slice.h"
#include "index/StringIndexMarisa.h"
#include "index/Utils.h"
#include "index/Index.h"
#include "common/Utils.h"
#include "common/Slice.h"
#include "storage/Util.h"
#include "storage/space.h"
@ -73,7 +74,7 @@ StringIndexMarisa::BuildV2(const Config& config) {
PanicInfo(S3Error, "failed to create scan iterator");
}
auto reader = res.value();
std::vector<storage::FieldDataPtr> field_datas;
std::vector<FieldDataPtr> field_datas;
for (auto rec = reader->Next(); rec != nullptr; rec = reader->Next()) {
if (!rec.ok()) {
PanicInfo(DataFormatBroken, "failed to read data");
@ -315,7 +316,7 @@ StringIndexMarisa::LoadV2(const Config& config) {
index_files.push_back(b.name);
}
}
std::map<std::string, storage::FieldDataPtr> index_datas{};
std::map<std::string, FieldDataPtr> index_datas{};
for (auto& file_name : index_files) {
auto res = space_->GetBlobByteSize(file_name);
if (!res.ok()) {

View File

@ -24,17 +24,18 @@
#include <vector>
#include <functional>
#include <iostream>
#include <unistd.h>
#include <google/protobuf/text_format.h>
#include "common/EasyAssert.h"
#include "common/Exception.h"
#include "common/File.h"
#include "common/FieldData.h"
#include "common/Slice.h"
#include "index/Utils.h"
#include "index/Meta.h"
#include <google/protobuf/text_format.h>
#include <unistd.h>
#include "common/EasyAssert.h"
#include "knowhere/comp/index_param.h"
#include "common/Slice.h"
#include "storage/FieldData.h"
#include "storage/Util.h"
#include "common/File.h"
#include "knowhere/comp/index_param.h"
namespace milvus::index {
@ -205,7 +206,7 @@ ParseConfigFromIndexParams(
}
void
AssembleIndexDatas(std::map<std::string, storage::FieldDataPtr>& index_datas) {
AssembleIndexDatas(std::map<std::string, FieldDataPtr>& index_datas) {
if (index_datas.find(INDEX_FILE_SLICE_META) != index_datas.end()) {
auto slice_meta = index_datas.at(INDEX_FILE_SLICE_META);
Config meta_data = Config::parse(std::string(
@ -237,9 +238,8 @@ AssembleIndexDatas(std::map<std::string, storage::FieldDataPtr>& index_datas) {
}
void
AssembleIndexDatas(
std::map<std::string, storage::FieldDataChannelPtr>& index_datas,
std::unordered_map<std::string, storage::FieldDataPtr>& result) {
AssembleIndexDatas(std::map<std::string, FieldDataChannelPtr>& index_datas,
std::unordered_map<std::string, FieldDataPtr>& result) {
if (auto meta_iter = index_datas.find(INDEX_FILE_SLICE_META);
meta_iter != index_datas.end()) {
auto raw_metadata_array =

View File

@ -28,9 +28,9 @@
#include <string>
#include "common/Types.h"
#include "common/FieldData.h"
#include "index/IndexInfo.h"
#include "storage/Types.h"
#include "storage/FieldData.h"
namespace milvus::index {
@ -114,12 +114,11 @@ ParseConfigFromIndexParams(
const std::map<std::string, std::string>& index_params);
void
AssembleIndexDatas(std::map<std::string, storage::FieldDataPtr>& index_datas);
AssembleIndexDatas(std::map<std::string, FieldDataPtr>& index_datas);
void
AssembleIndexDatas(
std::map<std::string, storage::FieldDataChannelPtr>& index_datas,
std::unordered_map<std::string, storage::FieldDataPtr>& result);
AssembleIndexDatas(std::map<std::string, FieldDataChannelPtr>& index_datas,
std::unordered_map<std::string, FieldDataPtr>& result);
// On Linux, read() (and similar system calls) will transfer at most 0x7ffff000 (2,147,479,552) bytes once
void

View File

@ -38,20 +38,20 @@
#include "knowhere/factory.h"
#include "knowhere/comp/time_recorder.h"
#include "common/BitsetView.h"
#include "common/Slice.h"
#include "common/Consts.h"
#include "common/FieldData.h"
#include "common/File.h"
#include "common/Slice.h"
#include "common/Tracer.h"
#include "common/RangeSearchHelper.h"
#include "common/Utils.h"
#include "log/Log.h"
#include "mmap/Types.h"
#include "storage/DataCodec.h"
#include "storage/FieldData.h"
#include "storage/MemFileManagerImpl.h"
#include "storage/ThreadPools.h"
#include "storage/Util.h"
#include "common/File.h"
#include "common/Tracer.h"
#include "storage/space.h"
#include "storage/Util.h"
namespace milvus::index {
@ -189,7 +189,7 @@ VectorMemIndex<T>::LoadV2(const Config& config) {
auto slice_meta_file = index_prefix + "/" + INDEX_FILE_SLICE_META;
auto res = space_->GetBlobByteSize(std::string(slice_meta_file));
std::map<std::string, storage::FieldDataPtr> index_datas{};
std::map<std::string, FieldDataPtr> index_datas{};
if (!res.ok() && !res.status().IsFileNotFound()) {
PanicInfo(DataFormatBroken, "failed to read blob");
@ -289,7 +289,7 @@ VectorMemIndex<T>::Load(const Config& config) {
auto parallel_degree =
static_cast<uint64_t>(DEFAULT_FIELD_MAX_MEMORY_LIMIT / FILE_SLICE_SIZE);
std::map<std::string, storage::FieldDataPtr> index_datas{};
std::map<std::string, FieldDataPtr> index_datas{};
// try to read slice meta first
std::string slice_meta_filepath;
@ -414,7 +414,7 @@ VectorMemIndex<T>::BuildV2(const Config& config) {
}
auto reader = res.value();
std::vector<storage::FieldDataPtr> field_datas;
std::vector<FieldDataPtr> field_datas;
for (auto rec : *reader) {
if (!rec.ok()) {
PanicInfo(IndexBuildError,

View File

@ -53,8 +53,13 @@
__FUNCTION__, \
GetThreadName().c_str())
#define LOG_SEGCORE_TRACE_ DLOG(INFO) << SEGCORE_MODULE_FUNCTION
#define LOG_SEGCORE_DEBUG_ DLOG(INFO) << SEGCORE_MODULE_FUNCTION
// GLOG has no debug or trace level,
// so VLOG is used to emulate them.
#define GLOG_DEBUG 5
#define GLOG_TRACE 6
#define LOG_SEGCORE_TRACE_ VLOG(GLOG_TRACE) << SEGCORE_MODULE_FUNCTION
#define LOG_SEGCORE_DEBUG_ VLOG(GLOG_DEBUG) << SEGCORE_MODULE_FUNCTION
#define LOG_SEGCORE_INFO_ LOG(INFO) << SEGCORE_MODULE_FUNCTION
#define LOG_SEGCORE_WARNING_ LOG(WARNING) << SEGCORE_MODULE_FUNCTION
#define LOG_SEGCORE_ERROR_ LOG(ERROR) << SEGCORE_MODULE_FUNCTION
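In effect, debug and trace output is gated by glog's vlog level rather than by minloglevel:
// Emitted only when the process runs with vlog level >= 5 (e.g. --v=5):
LOG_SEGCORE_DEBUG_ << "chunk rows: " << 1024;
// Emitted only when vlog level >= 6:
LOG_SEGCORE_TRACE_ << "entering batch eval loop";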

View File

@ -21,15 +21,15 @@
#include <cstring>
#include <filesystem>
#include "common/FieldMeta.h"
#include "common/Span.h"
#include "common/Array.h"
#include "common/EasyAssert.h"
#include "common/File.h"
#include "common/FieldMeta.h"
#include "common/FieldData.h"
#include "common/Span.h"
#include "fmt/format.h"
#include "log/Log.h"
#include "mmap/Utils.h"
#include "storage/FieldData.h"
#include "common/Array.h"
namespace milvus {
@ -156,7 +156,7 @@ class ColumnBase {
Span() const = 0;
void
AppendBatch(const storage::FieldDataPtr& data) {
AppendBatch(const FieldDataPtr& data) {
size_t required_size = size_ + data->Size();
if (required_size > cap_size_) {
Expand(required_size * 2 + padding_);

View File

@ -19,13 +19,13 @@
#include <memory>
#include <string>
#include <vector>
#include "storage/FieldData.h"
#include "common/FieldData.h"
namespace milvus {
struct FieldDataInfo {
FieldDataInfo() {
channel = std::make_shared<storage::FieldDataChannel>();
channel = std::make_shared<FieldDataChannel>();
}
FieldDataInfo(int64_t field_id,
@ -34,12 +34,12 @@ struct FieldDataInfo {
: field_id(field_id),
row_count(row_count),
mmap_dir_path(std::move(mmap_dir_path)) {
channel = std::make_shared<storage::FieldDataChannel>();
channel = std::make_shared<FieldDataChannel>();
}
FieldDataInfo(int64_t field_id,
size_t row_count,
storage::FieldDataChannelPtr channel)
FieldDataChannelPtr channel)
: field_id(field_id),
row_count(row_count),
channel(std::move(channel)) {
@ -48,7 +48,7 @@ struct FieldDataInfo {
FieldDataInfo(int64_t field_id,
size_t row_count,
std::string mmap_dir_path,
storage::FieldDataChannelPtr channel)
FieldDataChannelPtr channel)
: field_id(field_id),
row_count(row_count),
mmap_dir_path(std::move(mmap_dir_path)),
@ -57,9 +57,9 @@ struct FieldDataInfo {
FieldDataInfo(int64_t field_id,
size_t row_count,
const std::vector<storage::FieldDataPtr>& batch)
const std::vector<FieldDataPtr>& batch)
: field_id(field_id), row_count(row_count) {
channel = std::make_shared<storage::FieldDataChannel>();
channel = std::make_shared<FieldDataChannel>();
for (auto& data : batch) {
channel->push(data);
}
@ -69,11 +69,11 @@ struct FieldDataInfo {
FieldDataInfo(int64_t field_id,
size_t row_count,
std::string mmap_dir_path,
const std::vector<storage::FieldDataPtr>& batch)
const std::vector<FieldDataPtr>& batch)
: field_id(field_id),
row_count(row_count),
mmap_dir_path(std::move(mmap_dir_path)) {
channel = std::make_shared<storage::FieldDataChannel>();
channel = std::make_shared<FieldDataChannel>();
for (auto& data : batch) {
channel->push(data);
}
@ -83,6 +83,6 @@ struct FieldDataInfo {
int64_t field_id;
size_t row_count;
std::string mmap_dir_path;
storage::FieldDataChannelPtr channel;
FieldDataChannelPtr channel;
};
} // namespace milvus

View File

@ -32,7 +32,7 @@
namespace milvus {
inline size_t
GetDataSize(const std::vector<storage::FieldDataPtr>& datas) {
GetDataSize(const std::vector<FieldDataPtr>& datas) {
size_t total_size{0};
for (auto data : datas) {
total_size += data->Size();
@ -42,7 +42,7 @@ GetDataSize(const std::vector<storage::FieldDataPtr>& datas) {
}
inline void*
FillField(DataType data_type, const storage::FieldDataPtr data, void* dst) {
FillField(DataType data_type, const FieldDataPtr data, void* dst) {
char* dest = reinterpret_cast<char*>(dst);
if (datatype_is_variable(data_type)) {
switch (data_type) {
@ -80,7 +80,7 @@ FillField(DataType data_type, const storage::FieldDataPtr data, void* dst) {
inline size_t
WriteFieldData(File& file,
DataType data_type,
const storage::FieldDataPtr& data,
const FieldDataPtr& data,
std::vector<std::vector<uint64_t>>& element_indices) {
size_t total_written{0};
if (datatype_is_variable(data_type)) {

View File

@ -0,0 +1,287 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "common/Types.h"
#include "common/Vector.h"
#include "expr/ITypeExpr.h"
#include "common/EasyAssert.h"
#include "segcore/SegmentInterface.h"
namespace milvus {
namespace plan {
typedef std::string PlanNodeId;
/**
* @brief Base class for all logic plan node
*
*/
class PlanNode {
public:
explicit PlanNode(const PlanNodeId& id) : id_(id) {
}
virtual ~PlanNode() = default;
const PlanNodeId&
id() const {
return id_;
}
virtual DataType
output_type() const = 0;
virtual std::vector<std::shared_ptr<PlanNode>>
sources() const = 0;
virtual bool
RequireSplits() const {
return false;
}
virtual std::string
ToString() const = 0;
virtual std::string_view
name() const = 0;
private:
PlanNodeId id_;
};
using PlanNodePtr = std::shared_ptr<PlanNode>;
class SegmentNode : public PlanNode {
public:
SegmentNode(
const PlanNodeId& id,
const std::shared_ptr<milvus::segcore::SegmentInternalInterface>&
segment)
: PlanNode(id), segment_(segment) {
}
DataType
output_type() const override {
return DataType::ROW;
}
std::vector<std::shared_ptr<PlanNode>>
sources() const override {
return {};
}
std::string_view
name() const override {
return "SegmentNode";
}
std::string
ToString() const override {
return "SegmentNode";
}
private:
std::shared_ptr<milvus::segcore::SegmentInternalInterface> segment_;
};
class ValuesNode : public PlanNode {
public:
    ValuesNode(const PlanNodeId& id,
               const std::vector<RowVectorPtr>& values,
               bool parallelizable = false)
        : PlanNode(id),
          output_type_(values.empty() ? DataType::NONE : values[0]->type()),
          values_(values),
          parallelizable_(parallelizable) {
        AssertInfo(!values_.empty(), "ValuesNode must have at least one value");
    }
    ValuesNode(const PlanNodeId& id,
               std::vector<RowVectorPtr>&& values,
               bool parallelizable = false)
        : PlanNode(id),
          output_type_(values.empty() ? DataType::NONE : values[0]->type()),
          values_(std::move(values)),
          parallelizable_(parallelizable) {
        AssertInfo(!values_.empty(), "ValuesNode must have at least one value");
    }
DataType
output_type() const override {
return output_type_;
}
const std::vector<RowVectorPtr>&
values() const {
return values_;
}
std::vector<PlanNodePtr>
sources() const override {
return {};
}
bool
parallelizable() {
return parallelizable_;
}
std::string_view
name() const override {
return "Values";
}
std::string
ToString() const override {
return "Values";
}
private:
DataType output_type_;
const std::vector<RowVectorPtr> values_;
bool parallelizable_;
};
class FilterNode : public PlanNode {
public:
FilterNode(const PlanNodeId& id,
expr::TypedExprPtr filter,
std::vector<PlanNodePtr> sources)
: PlanNode(id),
sources_{std::move(sources)},
filter_(std::move(filter)) {
AssertInfo(
filter_->type() == DataType::BOOL,
fmt::format("Filter expression must be of type BOOLEAN, Got {}",
filter_->type()));
}
DataType
output_type() const override {
return sources_[0]->output_type();
}
std::vector<PlanNodePtr>
sources() const override {
return sources_;
}
const expr::TypedExprPtr&
filter() const {
return filter_;
}
std::string_view
name() const override {
return "Filter";
}
std::string
ToString() const override {
return "";
}
private:
const std::vector<PlanNodePtr> sources_;
const expr::TypedExprPtr filter_;
};
class FilterBitsNode : public PlanNode {
public:
FilterBitsNode(
const PlanNodeId& id,
expr::TypedExprPtr filter,
std::vector<PlanNodePtr> sources = std::vector<PlanNodePtr>{})
: PlanNode(id),
sources_{std::move(sources)},
filter_(std::move(filter)) {
AssertInfo(
filter_->type() == DataType::BOOL,
fmt::format("Filter expression must be of type BOOLEAN, Got {}",
filter_->type()));
}
DataType
output_type() const override {
return DataType::BOOL;
}
std::vector<PlanNodePtr>
sources() const override {
return sources_;
}
const expr::TypedExprPtr&
filter() const {
return filter_;
}
std::string_view
name() const override {
return "FilterBits";
}
std::string
ToString() const override {
return fmt::format("FilterBitsNode:[filter_expr:{}]",
filter_->ToString());
}
private:
const std::vector<PlanNodePtr> sources_;
const expr::TypedExprPtr filter_;
};
enum class ExecutionStrategy {
// Process splits as they come in any available driver.
kUngrouped,
// Process splits from each split group only in one driver.
// It is used when split groups represent separate partitions of the data on
// the grouping keys or join keys. In that case it is sufficient to keep only
// the keys from a single split group in a hash table used by group-by or
// join.
kGrouped,
};
struct PlanFragment {
std::shared_ptr<const PlanNode> plan_node_;
ExecutionStrategy execution_strategy_{ExecutionStrategy::kUngrouped};
int32_t num_splitgroups_{0};
PlanFragment() = default;
inline bool
IsGroupedExecution() const {
return execution_strategy_ == ExecutionStrategy::kGrouped;
}
explicit PlanFragment(std::shared_ptr<const PlanNode> top_node,
ExecutionStrategy strategy,
int32_t num_splitgroups)
: plan_node_(std::move(top_node)),
execution_strategy_(strategy),
num_splitgroups_(num_splitgroups) {
}
explicit PlanFragment(std::shared_ptr<const PlanNode> top_node)
: plan_node_(std::move(top_node)) {
}
};
} // namespace plan
} // namespace milvus
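A hedged sketch of how these nodes are meant to be assembled into an executable unit; the node id string is arbitrary and the helper is illustrative.
// Wrap a BOOL-typed filter expression in a single FilterBitsNode plan fragment.
milvus::plan::PlanFragment
MakeFilterFragment(const milvus::expr::TypedExprPtr& filter_expr) {
    auto node = std::make_shared<milvus::plan::FilterBitsNode>(
        "plannode-0", filter_expr);
    return milvus::plan::PlanFragment(node);
}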

View File

@ -20,10 +20,12 @@
#include "common/QueryInfo.h"
#include "query/Expr.h"
namespace milvus::plan {
class PlanNode;
};
namespace milvus::query {
class PlanNodeVisitor;
// Base of all Nodes
struct PlanNode {
public:
@ -36,6 +38,7 @@ using PlanNodePtr = std::unique_ptr<PlanNode>;
struct VectorPlanNode : PlanNode {
std::optional<ExprPtr> predicate_;
std::optional<std::shared_ptr<milvus::plan::PlanNode>> filter_plannode_;
SearchInfo search_info_;
std::string placeholder_tag_;
};
@ -64,6 +67,7 @@ struct RetrievePlanNode : PlanNode {
accept(PlanNodeVisitor&) override;
std::optional<ExprPtr> predicate_;
std::optional<std::shared_ptr<milvus::plan::PlanNode>> filter_plannode_;
bool is_count_;
int64_t limit_;
};

View File

@ -185,6 +185,12 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
}
}();
auto expr_parser = [&]() -> plan::PlanNodePtr {
auto expr = ParseExprs(anns_proto.predicates());
return std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID,
expr);
};
auto& query_info_proto = anns_proto.query_info();
SearchInfo search_info;
@ -210,6 +216,9 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
}();
plan_node->placeholder_tag_ = anns_proto.placeholder_tag();
plan_node->predicate_ = std::move(expr_opt);
if (anns_proto.has_predicates()) {
plan_node->filter_plannode_ = std::move(expr_parser());
}
plan_node->search_info_ = std::move(search_info);
return plan_node;
}
@ -227,7 +236,13 @@ ProtoParser::RetrievePlanNodeFromProto(
auto expr_opt = [&]() -> ExprPtr {
return ParseExpr(predicate_proto);
}();
auto expr_parser = [&]() -> plan::PlanNodePtr {
auto expr = ParseExprs(predicate_proto);
return std::make_shared<plan::FilterBitsNode>(
DEFAULT_PLANNODE_ID, expr);
}();
node->predicate_ = std::move(expr_opt);
node->filter_plannode_ = std::move(expr_parser);
} else {
auto& query = plan_node_proto.query();
if (query.has_predicates()) {
@ -235,7 +250,13 @@ ProtoParser::RetrievePlanNodeFromProto(
auto expr_opt = [&]() -> ExprPtr {
return ParseExpr(predicate_proto);
}();
auto expr_parser = [&]() -> plan::PlanNodePtr {
auto expr = ParseExprs(predicate_proto);
return std::make_shared<plan::FilterBitsNode>(
DEFAULT_PLANNODE_ID, expr);
}();
node->predicate_ = std::move(expr_opt);
node->filter_plannode_ = std::move(expr_parser);
}
node->is_count_ = query.is_count();
node->limit_ = query.limit();
@ -284,6 +305,16 @@ ProtoParser::CreateRetrievePlan(const proto::plan::PlanNode& plan_node_proto) {
return retrieve_plan;
}
expr::TypedExprPtr
ProtoParser::ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb) {
auto& column_info = expr_pb.column_info();
auto field_id = FieldId(column_info.field_id());
auto data_type = schema[field_id].get_data_type();
Assert(data_type == static_cast<DataType>(column_info.data_type()));
return std::make_shared<milvus::expr::UnaryRangeFilterExpr>(
expr::ColumnInfo(column_info), expr_pb.op(), expr_pb.value());
}
ExprPtr
ProtoParser::ParseUnaryRangeExpr(const proto::plan::UnaryRangeExpr& expr_pb) {
auto& column_info = expr_pb.column_info();
@ -352,6 +383,21 @@ ProtoParser::ParseUnaryRangeExpr(const proto::plan::UnaryRangeExpr& expr_pb) {
return result;
}
expr::TypedExprPtr
ProtoParser::ParseBinaryRangeExprs(
const proto::plan::BinaryRangeExpr& expr_pb) {
auto& columnInfo = expr_pb.column_info();
auto field_id = FieldId(columnInfo.field_id());
auto data_type = schema[field_id].get_data_type();
Assert(data_type == (DataType)columnInfo.data_type());
return std::make_shared<expr::BinaryRangeFilterExpr>(
columnInfo,
expr_pb.lower_value(),
expr_pb.upper_value(),
expr_pb.lower_inclusive(),
expr_pb.upper_inclusive());
}
ExprPtr
ProtoParser::ParseBinaryRangeExpr(const proto::plan::BinaryRangeExpr& expr_pb) {
auto& columnInfo = expr_pb.column_info();
@ -436,6 +482,27 @@ ProtoParser::ParseBinaryRangeExpr(const proto::plan::BinaryRangeExpr& expr_pb) {
return result;
}
expr::TypedExprPtr
ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) {
auto& left_column_info = expr_pb.left_column_info();
auto left_field_id = FieldId(left_column_info.field_id());
auto left_data_type = schema[left_field_id].get_data_type();
Assert(left_data_type ==
static_cast<DataType>(left_column_info.data_type()));
auto& right_column_info = expr_pb.right_column_info();
auto right_field_id = FieldId(right_column_info.field_id());
auto right_data_type = schema[right_field_id].get_data_type();
Assert(right_data_type ==
static_cast<DataType>(right_column_info.data_type()));
return std::make_shared<expr::CompareExpr>(left_field_id,
right_field_id,
left_data_type,
right_data_type,
expr_pb.op());
}
ExprPtr
ProtoParser::ParseCompareExpr(const proto::plan::CompareExpr& expr_pb) {
auto& left_column_info = expr_pb.left_column_info();
@ -461,6 +528,20 @@ ProtoParser::ParseCompareExpr(const proto::plan::CompareExpr& expr_pb) {
}();
}
expr::TypedExprPtr
ProtoParser::ParseTermExprs(const proto::plan::TermExpr& expr_pb) {
auto& columnInfo = expr_pb.column_info();
auto field_id = FieldId(columnInfo.field_id());
auto data_type = schema[field_id].get_data_type();
Assert(data_type == (DataType)columnInfo.data_type());
std::vector<::milvus::proto::plan::GenericValue> values;
for (size_t i = 0; i < expr_pb.values_size(); i++) {
values.emplace_back(expr_pb.values(i));
}
return std::make_shared<expr::TermFilterExpr>(
columnInfo, values, expr_pb.is_in_field());
}
ExprPtr
ProtoParser::ParseTermExpr(const proto::plan::TermExpr& expr_pb) {
auto& columnInfo = expr_pb.column_info();
@ -568,6 +649,14 @@ ProtoParser::ParseUnaryExpr(const proto::plan::UnaryExpr& expr_pb) {
return std::make_unique<LogicalUnaryExpr>(op, expr);
}
expr::TypedExprPtr
ProtoParser::ParseUnaryExprs(const proto::plan::UnaryExpr& expr_pb) {
auto op = static_cast<expr::LogicalUnaryExpr::OpType>(expr_pb.op());
Assert(op == expr::LogicalUnaryExpr::OpType::LogicalNot);
auto child_expr = this->ParseExprs(expr_pb.child());
return std::make_shared<expr::LogicalUnaryExpr>(op, child_expr);
}
ExprPtr
ProtoParser::ParseBinaryExpr(const proto::plan::BinaryExpr& expr_pb) {
auto op = static_cast<LogicalBinaryExpr::OpType>(expr_pb.op());
@ -576,6 +665,14 @@ ProtoParser::ParseBinaryExpr(const proto::plan::BinaryExpr& expr_pb) {
return std::make_unique<LogicalBinaryExpr>(op, left_expr, right_expr);
}
expr::TypedExprPtr
ProtoParser::ParseBinaryExprs(const proto::plan::BinaryExpr& expr_pb) {
auto op = static_cast<expr::LogicalBinaryExpr::OpType>(expr_pb.op());
auto left_expr = this->ParseExprs(expr_pb.left());
auto right_expr = this->ParseExprs(expr_pb.right());
return std::make_shared<expr::LogicalBinaryExpr>(op, left_expr, right_expr);
}
ExprPtr
ProtoParser::ParseBinaryArithOpEvalRangeExpr(
const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) {
@ -642,11 +739,35 @@ ProtoParser::ParseBinaryArithOpEvalRangeExpr(
return result;
}
expr::TypedExprPtr
ProtoParser::ParseBinaryArithOpEvalRangeExprs(
const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) {
auto& column_info = expr_pb.column_info();
auto field_id = FieldId(column_info.field_id());
auto data_type = schema[field_id].get_data_type();
Assert(data_type == static_cast<DataType>(column_info.data_type()));
return std::make_shared<expr::BinaryArithOpEvalRangeExpr>(
column_info,
expr_pb.op(),
expr_pb.arith_op(),
expr_pb.value(),
expr_pb.right_operand());
}
std::unique_ptr<ExistsExprImpl>
ExtractExistsExprImpl(const proto::plan::ExistsExpr& expr_proto) {
return std::make_unique<ExistsExprImpl>(expr_proto.info());
}
expr::TypedExprPtr
ProtoParser::ParseExistExprs(const proto::plan::ExistsExpr& expr_pb) {
auto& column_info = expr_pb.info();
auto field_id = FieldId(column_info.field_id());
auto data_type = schema[field_id].get_data_type();
Assert(data_type == static_cast<DataType>(column_info.data_type()));
return std::make_shared<expr::ExistsExpr>(column_info);
}
ExprPtr
ProtoParser::ParseExistExpr(const proto::plan::ExistsExpr& expr_pb) {
auto& column_info = expr_pb.info();
@ -718,6 +839,24 @@ ExtractJsonContainsExprImpl(const proto::plan::JSONContainsExpr& expr_proto) {
val_case);
}
expr::TypedExprPtr
ProtoParser::ParseJsonContainsExprs(
const proto::plan::JSONContainsExpr& expr_pb) {
auto& columnInfo = expr_pb.column_info();
auto field_id = FieldId(columnInfo.field_id());
auto data_type = schema[field_id].get_data_type();
Assert(data_type == (DataType)columnInfo.data_type());
std::vector<::milvus::proto::plan::GenericValue> values;
for (size_t i = 0; i < expr_pb.elements_size(); i++) {
values.emplace_back(expr_pb.elements(i));
}
return std::make_shared<expr::JsonContainsExpr>(
columnInfo,
expr_pb.op(),
expr_pb.elements_same_type(),
std::move(values));
}
ExprPtr
ProtoParser::ParseJsonContainsExpr(
const proto::plan::JSONContainsExpr& expr_pb) {
@ -755,6 +894,55 @@ ProtoParser::ParseJsonContainsExpr(
return result;
}
expr::TypedExprPtr
ProtoParser::CreateAlwaysTrueExprs() {
return std::make_shared<expr::AlwaysTrueExpr>();
}
expr::TypedExprPtr
ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb) {
using ppe = proto::plan::Expr;
switch (expr_pb.expr_case()) {
case ppe::kUnaryRangeExpr: {
return ParseUnaryRangeExprs(expr_pb.unary_range_expr());
}
case ppe::kBinaryExpr: {
return ParseBinaryExprs(expr_pb.binary_expr());
}
case ppe::kUnaryExpr: {
return ParseUnaryExprs(expr_pb.unary_expr());
}
case ppe::kTermExpr: {
return ParseTermExprs(expr_pb.term_expr());
}
case ppe::kBinaryRangeExpr: {
return ParseBinaryRangeExprs(expr_pb.binary_range_expr());
}
case ppe::kCompareExpr: {
return ParseCompareExprs(expr_pb.compare_expr());
}
case ppe::kBinaryArithOpEvalRangeExpr: {
return ParseBinaryArithOpEvalRangeExprs(
expr_pb.binary_arith_op_eval_range_expr());
}
case ppe::kExistsExpr: {
return ParseExistExprs(expr_pb.exists_expr());
}
case ppe::kAlwaysTrueExpr: {
return CreateAlwaysTrueExprs();
}
case ppe::kJsonContainsExpr: {
return ParseJsonContainsExprs(expr_pb.json_contains_expr());
}
default: {
std::string s;
google::protobuf::TextFormat::PrintToString(expr_pb, &s);
PanicInfo(ExprInvalid,
std::string("unsupported expr proto node: ") + s);
}
}
}
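For illustration, a hedged sketch of feeding this entry point a hand-built proto; the protobuf setter names follow the generated C++ API, and the field id and schema are assumptions.
// Assumes the parser's schema has an Int64 field with id 100.
milvus::expr::TypedExprPtr
ParseExampleExpr(milvus::query::ProtoParser& parser) {
    milvus::proto::plan::Expr expr_pb;
    auto* unary = expr_pb.mutable_unary_range_expr();
    unary->mutable_column_info()->set_field_id(100);
    unary->mutable_column_info()->set_data_type(
        milvus::proto::schema::DataType::Int64);
    unary->set_op(milvus::proto::plan::OpType::GreaterThan);
    unary->mutable_value()->set_int64_val(18);
    return parser.ParseExprs(expr_pb);  // yields a UnaryRangeFilterExpr
}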
ExprPtr
ProtoParser::ParseExpr(const proto::plan::Expr& expr_pb) {
using ppe = proto::plan::Expr;

View File

@ -18,6 +18,7 @@
#include "PlanNode.h"
#include "common/Schema.h"
#include "pb/plan.pb.h"
#include "plan/PlanNode.h"
namespace milvus::query {
@ -72,6 +73,40 @@ class ProtoParser {
std::unique_ptr<RetrievePlan>
CreateRetrievePlan(const proto::plan::PlanNode& plan_node_proto);
expr::TypedExprPtr
ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb);
expr::TypedExprPtr
ParseExprs(const proto::plan::Expr& expr_pb);
expr::TypedExprPtr
ParseBinaryArithOpEvalRangeExprs(
const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb);
expr::TypedExprPtr
ParseBinaryRangeExprs(const proto::plan::BinaryRangeExpr& expr_pb);
expr::TypedExprPtr
ParseCompareExprs(const proto::plan::CompareExpr& expr_pb);
expr::TypedExprPtr
ParseTermExprs(const proto::plan::TermExpr& expr_pb);
expr::TypedExprPtr
ParseUnaryExprs(const proto::plan::UnaryExpr& expr_pb);
expr::TypedExprPtr
ParseBinaryExprs(const proto::plan::BinaryExpr& expr_pb);
expr::TypedExprPtr
ParseExistExprs(const proto::plan::ExistsExpr& expr_pb);
expr::TypedExprPtr
ParseJsonContainsExprs(const proto::plan::JSONContainsExpr& expr_pb);
expr::TypedExprPtr
CreateAlwaysTrueExprs();
private:
const Schema& schema;
};

View File

@ -16,6 +16,7 @@
#include "query/Expr.h"
#include "common/Utils.h"
#include "simd/hook.h"
namespace milvus::query {
@ -70,4 +71,61 @@ inline bool
out_of_range(int64_t t) {
return gt_ub<T>(t) || lt_lb<T>(t);
}
inline void
AppendOneChunk(BitsetType& result, const bool* chunk_ptr, size_t chunk_len) {
// Append a value once instead of BITSET_BLOCK_BIT_SIZE times.
auto AppendBlock = [&result](const bool* ptr, int n) {
for (int i = 0; i < n; ++i) {
#if defined(USE_DYNAMIC_SIMD)
auto val = milvus::simd::get_bitset_block(ptr);
#else
BitsetBlockType val = 0;
            // This can use CPU SIMD optimization
uint8_t vals[BITSET_BLOCK_SIZE] = {0};
for (size_t j = 0; j < 8; ++j) {
for (size_t k = 0; k < BITSET_BLOCK_SIZE; ++k) {
vals[k] |= uint8_t(*(ptr + k * 8 + j)) << j;
}
}
for (size_t j = 0; j < BITSET_BLOCK_SIZE; ++j) {
val |= BitsetBlockType(vals[j]) << (8 * j);
}
#endif
result.append(val);
ptr += BITSET_BLOCK_SIZE * 8;
}
};
    // Append the bits that cannot be packed into a full block.
    // Usually n is less than BITSET_BLOCK_BIT_SIZE.
auto AppendBit = [&result](const bool* ptr, int n) {
for (int i = 0; i < n; ++i) {
bool bit = *ptr++;
result.push_back(bit);
}
};
size_t res_len = result.size();
int n_prefix =
res_len % BITSET_BLOCK_BIT_SIZE == 0
? 0
: std::min(BITSET_BLOCK_BIT_SIZE - res_len % BITSET_BLOCK_BIT_SIZE,
chunk_len);
AppendBit(chunk_ptr, n_prefix);
if (n_prefix == chunk_len)
return;
size_t n_block = (chunk_len - n_prefix) / BITSET_BLOCK_BIT_SIZE;
size_t n_suffix = (chunk_len - n_prefix) % BITSET_BLOCK_BIT_SIZE;
AppendBlock(chunk_ptr + n_prefix, n_block);
AppendBit(chunk_ptr + n_prefix + n_block * BITSET_BLOCK_BIT_SIZE, n_suffix);
return;
}
} // namespace milvus::query
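A worked example of the prefix/block/suffix split above, assuming BITSET_BLOCK_BIT_SIZE is 64 (the numbers are illustrative):
// result already holds 70 bits and a 200-bit chunk is appended:
//   prefix : 70 % 64 = 6, so 64 - 6 = 58 bits are pushed one by one
//   blocks : (200 - 58) / 64 = 2 full blocks are appended as whole words
//   suffix : (200 - 58) % 64 = 14 remaining bits are pushed one by one
// 58 + 2 * 64 + 14 == 200, and result stays block-aligned up to the suffix.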

View File

@ -19,6 +19,7 @@
#include "PlanNodeVisitor.h"
namespace milvus::query {
class ExecPlanNodeVisitor : public PlanNodeVisitor {
public:
void
@ -96,6 +97,24 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor {
return expr_use_pk_index_;
}
void
ExecuteExprNodeInternal(
const std::shared_ptr<milvus::plan::PlanNode>& plannode,
const milvus::segcore::SegmentInternalInterface* segment,
BitsetType& result,
bool& cache_offset_getted,
std::vector<int64_t>& cache_offset);
void
ExecuteExprNode(const std::shared_ptr<milvus::plan::PlanNode>& plannode,
const milvus::segcore::SegmentInternalInterface* segment,
BitsetType& result) {
bool get_cache_offset;
std::vector<int64_t> cache_offsets;
ExecuteExprNodeInternal(
plannode, segment, result, get_cache_offset, cache_offsets);
}
private:
template <typename VectorType>
void

View File

@ -16,9 +16,12 @@
#include "query/PlanImpl.h"
#include "query/SubSearchResult.h"
#include "query/generated/ExecExprVisitor.h"
#include "query/Utils.h"
#include "segcore/SegmentGrowing.h"
#include "common/Json.h"
#include "log/Log.h"
#include "plan/PlanNode.h"
#include "exec/Task.h"
namespace milvus::query {
@ -73,6 +76,63 @@ empty_search_result(int64_t num_queries, SearchInfo& search_info) {
return final_result;
}
void
ExecPlanNodeVisitor::ExecuteExprNodeInternal(
const std::shared_ptr<milvus::plan::PlanNode>& plannode,
const milvus::segcore::SegmentInternalInterface* segment,
BitsetType& bitset_holder,
bool& cache_offset_getted,
std::vector<int64_t>& cache_offset) {
bitset_holder.clear();
LOG_SEGCORE_INFO_ << "plannode:" << plannode->ToString();
auto plan = plan::PlanFragment(plannode);
// TODO: get query id from proxy
auto query_context = std::make_shared<milvus::exec::QueryContext>(
DEAFULT_QUERY_ID, segment, timestamp_);
auto task =
milvus::exec::Task::Create(DEFAULT_TASK_ID, plan, 0, query_context);
for (;;) {
auto result = task->Next();
if (!result) {
break;
}
auto childrens = result->childrens();
        AssertInfo(childrens.size() == 1,
                   "expr result vector must have exactly one child");
LOG_SEGCORE_DEBUG_ << "output result length:" << childrens[0]->size()
<< std::endl;
if (auto vec = std::dynamic_pointer_cast<ColumnVector>(childrens[0])) {
AppendOneChunk(bitset_holder,
static_cast<bool*>(vec->GetRawData()),
vec->size());
} else if (auto row =
std::dynamic_pointer_cast<RowVector>(childrens[0])) {
auto bit_vec =
std::dynamic_pointer_cast<ColumnVector>(row->child(0));
AppendOneChunk(bitset_holder,
static_cast<bool*>(bit_vec->GetRawData()),
bit_vec->size());
if (!cache_offset_getted) {
                // the offset cache is fetched only once because batch iteration is not supported
auto cache_offset_vec =
std::dynamic_pointer_cast<ColumnVector>(row->child(1));
auto cache_offset_vec_ptr =
(int64_t*)(cache_offset_vec->GetRawData());
for (size_t i = 0; i < cache_offset_vec->size(); ++i) {
cache_offset.push_back(cache_offset_vec_ptr[i]);
}
cache_offset_getted = true;
}
} else {
PanicInfo(UnexpectedError, "expr return type not matched");
}
}
// std::string s;
// boost::to_string(*bitset_holder, s);
// std::cout << bitset_holder->size() << " . " << s << std::endl;
}
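The two branches above reflect the result shapes the executor can hand back per batch (described here for orientation; inferred from the code rather than documented elsewhere):
//   ColumnVector -> a plain bool bitmask for the batch
//   RowVector    -> child(0): bool bitmask, child(1): int64 pk offsets cached
//                   by the expr module; offsets are collected from the first
//                   batch only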
template <typename VectorType>
void
ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
@ -98,10 +158,10 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
}
std::unique_ptr<BitsetType> bitset_holder;
if (node.predicate_.has_value()) {
bitset_holder = std::make_unique<BitsetType>(
ExecExprVisitor(*segment, this, active_count, timestamp_)
.call_child(*node.predicate_.value()));
if (node.filter_plannode_.has_value()) {
BitsetType expr_res;
ExecuteExprNode(node.filter_plannode_.value(), segment, expr_res);
bitset_holder = std::make_unique<BitsetType>(expr_res);
bitset_holder->flip();
} else {
bitset_holder = std::make_unique<BitsetType>(active_count, false);
@ -165,10 +225,16 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
bitset_holder.resize(active_count);
}
if (node.predicate_.has_value() && node.predicate_.value() != nullptr) {
bitset_holder =
ExecExprVisitor(*segment, this, active_count, timestamp_)
.call_child(*(node.predicate_.value()));
    // This flag indicates whether offsets were fetched from the expr module;
    // they speed up the mvcc filter in the timestamp_filter call below.
bool get_cache_offset = false;
std::vector<int64_t> cache_offsets;
if (node.filter_plannode_.has_value()) {
ExecuteExprNodeInternal(node.filter_plannode_.value(),
segment,
bitset_holder,
get_cache_offset,
cache_offsets);
bitset_holder.flip();
}
@ -189,9 +255,8 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
}
bool false_filtered_out = false;
if (GetExprUsePkIndex() && IsTermExpr(node.predicate_.value().get())) {
segment->timestamp_filter(
bitset_holder, expr_cached_pk_id_offsets_, timestamp_);
if (get_cache_offset) {
segment->timestamp_filter(bitset_holder, cache_offsets, timestamp_);
} else {
bitset_holder.flip();
false_filtered_out = true;

View File

@ -42,6 +42,6 @@ set(SEGCORE_FILES
SkipIndex.cpp)
add_library(milvus_segcore SHARED ${SEGCORE_FILES})
target_link_libraries(milvus_segcore milvus_query ${OpenMP_CXX_FLAGS} milvus-storage)
target_link_libraries(milvus_segcore milvus_query milvus_exec ${OpenMP_CXX_FLAGS} milvus-storage)
install(TARGETS milvus_segcore DESTINATION "${CMAKE_INSTALL_LIBDIR}")

View File

@ -25,13 +25,13 @@
#include <utility>
#include <vector>
#include "common/EasyAssert.h"
#include "common/FieldMeta.h"
#include "common/FieldData.h"
#include "common/Json.h"
#include "common/Span.h"
#include "common/Types.h"
#include "common/Utils.h"
#include "common/EasyAssert.h"
#include "storage/FieldData.h"
namespace milvus::segcore {
@ -103,7 +103,7 @@ class VectorBase {
virtual void
set_data_raw(ssize_t element_offset,
const std::vector<storage::FieldDataPtr>& data) = 0;
const std::vector<FieldDataPtr>& data) = 0;
void
set_data_raw(ssize_t element_offset,
@ -112,7 +112,7 @@ class VectorBase {
const FieldMeta& field_meta);
virtual void
fill_chunk_data(const std::vector<storage::FieldDataPtr>& data) = 0;
fill_chunk_data(const std::vector<FieldDataPtr>& data) = 0;
virtual SpanBase
get_span_base(int64_t chunk_id) const = 0;
@ -197,7 +197,7 @@ class ConcurrentVectorImpl : public VectorBase {
}
void
fill_chunk_data(const std::vector<storage::FieldDataPtr>& datas)
fill_chunk_data(const std::vector<FieldDataPtr>& datas)
override { // used only for sealed segment
        AssertInfo(chunks_.size() == 0,
                   "fill_chunk_data requires an empty concurrent vector");
@ -217,7 +217,7 @@ class ConcurrentVectorImpl : public VectorBase {
void
set_data_raw(ssize_t element_offset,
const std::vector<storage::FieldDataPtr>& datas) override {
const std::vector<FieldDataPtr>& datas) override {
for (auto& field_data : datas) {
auto num_rows = field_data->get_num_rows();
set_data_raw(element_offset, field_data->Data(), num_rows);

View File

@ -306,7 +306,7 @@ class IndexingRecord {
AppendingIndex(int64_t reserved_offset,
int64_t size,
FieldId fieldId,
const storage::FieldDataPtr data,
const FieldDataPtr data,
const InsertRecord<is_sealed>& record) {
if (is_in(fieldId)) {
auto& indexing = field_indexings_.at(fieldId);

View File

@ -424,7 +424,7 @@ struct InsertRecord {
}
void
insert_pks(const std::vector<storage::FieldDataPtr>& field_datas) {
insert_pks(const std::vector<FieldDataPtr>& field_datas) {
std::lock_guard lck(shared_mutex_);
int64_t offset = 0;
for (auto& data : field_datas) {

View File

@ -21,6 +21,7 @@
#include "common/Consts.h"
#include "common/EasyAssert.h"
#include "common/FieldData.h"
#include "common/Types.h"
#include "fmt/format.h"
#include "log/Log.h"
@ -29,7 +30,6 @@
#include "query/SearchOnSealed.h"
#include "segcore/SegmentGrowingImpl.h"
#include "segcore/Utils.h"
#include "storage/FieldData.h"
#include "storage/RemoteChunkManagerSingleton.h"
#include "storage/Util.h"
#include "storage/ThreadPools.h"
@ -58,8 +58,12 @@ SegmentGrowingImpl::mask_with_delete(BitsetType& bitset,
return;
}
auto& delete_bitset = *bitmap_holder->bitmap_ptr;
AssertInfo(delete_bitset.size() == bitset.size(),
"Deleted bitmap size not equal to filtered bitmap size");
AssertInfo(
delete_bitset.size() == bitset.size(),
fmt::format(
"Deleted bitmap size:{} not equal to filtered bitmap size:{}",
delete_bitset.size(),
bitset.size()));
bitset |= delete_bitset;
}
@ -177,12 +181,12 @@ SegmentGrowingImpl::LoadFieldData(const LoadFieldDataInfo& infos) {
for (auto& [id, info] : infos.field_infos) {
auto field_id = FieldId(id);
auto insert_files = info.insert_files;
auto channel = std::make_shared<storage::FieldDataChannel>();
auto channel = std::make_shared<FieldDataChannel>();
auto& pool =
ThreadPools::GetThreadPool(milvus::ThreadPoolPriority::MIDDLE);
auto load_future =
pool.Submit(LoadFieldDatasFromRemote, insert_files, channel);
auto field_data = CollectFieldDataChannel(channel);
auto field_data = storage::CollectFieldDataChannel(channel);
if (field_id == TimestampFieldID) {
// step 2: sort timestamp
// query node already guarantees that the timestamp is ordered, avoid field data copy in c++
@ -263,7 +267,8 @@ SegmentGrowingImpl::LoadFieldDataV2(const LoadFieldDataInfo& infos) {
std::shared_ptr<milvus_storage::Space> space = std::move(res.value());
auto load_future = pool.Submit(
LoadFieldDatasFromRemote2, space, schema_, field_data_info);
auto field_data = CollectFieldDataChannel(field_data_info.channel);
auto field_data =
milvus::storage::CollectFieldDataChannel(field_data_info.channel);
if (field_id == TimestampFieldID) {
// step 2: sort timestamp
// query node already guarantees that the timestamp is ordered, avoid field data copy in c++

View File

@ -235,7 +235,13 @@ class SegmentGrowingImpl : public SegmentGrowing {
bool
HasIndex(FieldId field_id) const override {
return true;
auto& field_meta = schema_->operator[](field_id);
if (datatype_is_vector(field_meta.get_data_type()) &&
indexing_record_.SyncDataWithIndex(field_id)) {
return true;
}
return false;
}
bool

View File

@ -47,6 +47,7 @@ class SegmentSealed : public SegmentInternalInterface {
}
};
using SegmentSealedPtr = std::unique_ptr<SegmentSealed>;
using SegmentSealedSPtr = std::shared_ptr<SegmentSealed>;
using SegmentSealedUPtr = std::unique_ptr<SegmentSealed>;
} // namespace milvus::segcore

View File

@ -33,6 +33,7 @@
#include "mmap/Column.h"
#include "common/Consts.h"
#include "common/FieldMeta.h"
#include "common/FieldData.h"
#include "common/Types.h"
#include "log/Log.h"
#include "pb/schema.pb.h"
@ -40,7 +41,6 @@
#include "query/ScalarIndex.h"
#include "query/SearchBruteForce.h"
#include "query/SearchOnSealed.h"
#include "storage/FieldData.h"
#include "storage/Util.h"
#include "storage/ThreadPools.h"
#include "storage/ChunkCacheSingleton.h"
@ -279,7 +279,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
if (system_field_type == SystemFieldType::Timestamp) {
std::vector<Timestamp> timestamps(num_rows);
int64_t offset = 0;
auto field_data = CollectFieldDataChannel(data.channel);
auto field_data = storage::CollectFieldDataChannel(data.channel);
for (auto& data : field_data) {
int64_t row_count = data->get_num_rows();
std::copy_n(static_cast<const Timestamp*>(data->Data()),
@ -307,7 +307,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
AssertInfo(system_field_type == SystemFieldType::RowId,
"System field type of id column is not RowId");
auto field_data = CollectFieldDataChannel(data.channel);
auto field_data = storage::CollectFieldDataChannel(data.channel);
// write data under lock
std::unique_lock lck(mutex_);
@ -335,7 +335,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
auto var_column =
std::make_shared<VariableColumn<std::string>>(
num_rows, field_meta);
storage::FieldDataPtr field_data;
FieldDataPtr field_data;
while (data.channel->pop(field_data)) {
for (auto i = 0; i < field_data->get_num_rows(); i++) {
auto str = static_cast<const std::string*>(
@ -354,7 +354,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
auto var_column =
std::make_shared<VariableColumn<milvus::Json>>(
num_rows, field_meta);
storage::FieldDataPtr field_data;
FieldDataPtr field_data;
while (data.channel->pop(field_data)) {
for (auto i = 0; i < field_data->get_num_rows(); i++) {
auto padded_string =
@ -374,7 +374,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
case milvus::DataType::ARRAY: {
auto var_column =
std::make_shared<ArrayColumn>(num_rows, field_meta);
storage::FieldDataPtr field_data;
FieldDataPtr field_data;
while (data.channel->pop(field_data)) {
for (auto i = 0; i < field_data->get_num_rows(); i++) {
auto rawValue = field_data->RawValue(i);
@ -398,7 +398,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
field_id, num_rows, field_data_size);
} else {
column = std::make_shared<Column>(num_rows, field_meta);
storage::FieldDataPtr field_data;
FieldDataPtr field_data;
while (data.channel->pop(field_data)) {
column->AppendBatch(field_data);
}
@ -469,7 +469,7 @@ SegmentSealedImpl::MapFieldData(const FieldId field_id, FieldDataInfo& data) {
auto data_size = 0;
std::vector<uint64_t> indices{};
std::vector<std::vector<uint64_t>> element_indices{};
storage::FieldDataPtr field_data;
FieldDataPtr field_data;
while (data.channel->pop(field_data)) {
data_size += field_data->Size();
auto written =
@ -669,8 +669,12 @@ SegmentSealedImpl::mask_with_delete(BitsetType& bitset,
return;
}
auto& delete_bitset = *bitmap_holder->bitmap_ptr;
AssertInfo(delete_bitset.size() == bitset.size(),
"Deleted bitmap size not equal to filtered bitmap size");
AssertInfo(
delete_bitset.size() == bitset.size(),
fmt::format(
"Deleted bitmap size:{} not equal to filtered bitmap size:{}",
delete_bitset.size(),
bitset.size()));
bitset |= delete_bitset;
}

View File

@ -299,7 +299,7 @@ class SegmentSealedImpl : public SegmentSealed {
vec_binlog_config_;
};
inline SegmentSealedPtr
inline SegmentSealedUPtr
CreateSealedSegment(
SchemaPtr schema,
IndexMetaPtr index_meta = nullptr,

View File

@ -14,14 +14,14 @@
#include <memory>
#include <string>
#include "common/Common.h"
#include "common/FieldData.h"
#include "index/ScalarIndex.h"
#include "log/Log.h"
#include "storage/FieldData.h"
#include "storage/RemoteChunkManagerSingleton.h"
#include "common/Common.h"
#include "storage/ThreadPool.h"
#include "storage/Util.h"
#include "mmap/Utils.h"
#include "storage/ThreadPool.h"
#include "storage/RemoteChunkManagerSingleton.h"
#include "storage/Util.h"
namespace milvus::segcore {
@ -50,7 +50,7 @@ ParsePksFromFieldData(std::vector<PkType>& pks, const DataArray& data) {
void
ParsePksFromFieldData(DataType data_type,
std::vector<PkType>& pks,
const std::vector<storage::FieldDataPtr>& datas) {
const std::vector<FieldDataPtr>& datas) {
int64_t offset = 0;
for (auto& field_data : datas) {
@ -737,7 +737,7 @@ LoadFieldDatasFromRemote2(std::shared_ptr<milvus_storage::Space> space,
// segcore use default remote chunk manager to load data from minio/s3
void
LoadFieldDatasFromRemote(std::vector<std::string>& remote_files,
storage::FieldDataChannelPtr channel) {
FieldDataChannelPtr channel) {
try {
auto parallel_degree = static_cast<uint64_t>(
DEFAULT_FIELD_MAX_MEMORY_LIMIT / FILE_SLICE_SIZE);

View File

@ -20,13 +20,13 @@
#include <utility>
#include <vector>
#include "common/FieldData.h"
#include "common/QueryResult.h"
// #include "common/Schema.h"
#include "common/Types.h"
#include "index/Index.h"
#include "segcore/DeletedRecord.h"
#include "segcore/InsertRecord.h"
#include "index/Index.h"
#include "storage/FieldData.h"
#include "storage/space.h"
namespace milvus::segcore {
@ -37,7 +37,7 @@ ParsePksFromFieldData(std::vector<PkType>& pks, const DataArray& data);
void
ParsePksFromFieldData(DataType data_type,
std::vector<PkType>& pks,
const std::vector<storage::FieldDataPtr>& datas);
const std::vector<FieldDataPtr>& datas);
void
ParsePksFromIDs(std::vector<PkType>& pks,
@ -159,7 +159,7 @@ ReverseDataFromIndex(const index::IndexBase* index,
void
LoadFieldDatasFromRemote(std::vector<std::string>& remote_files,
storage::FieldDataChannelPtr channel);
FieldDataChannelPtr channel);
void
LoadFieldDatasFromRemote2(std::shared_ptr<milvus_storage::Space> space,

View File

@ -10,21 +10,22 @@
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "segcore/segment_c.h"
#include <memory>
#include "common/FieldData.h"
#include "common/LoadInfo.h"
#include "common/Types.h"
#include "common/Tracer.h"
#include "common/type_c.h"
#include "google/protobuf/text_format.h"
#include "log/Log.h"
#include "mmap/Types.h"
#include "segcore/Collection.h"
#include "segcore/SegmentGrowingImpl.h"
#include "segcore/SegmentSealedImpl.h"
#include "segcore/Utils.h"
#include "storage/FieldData.h"
#include "storage/Util.h"
#include "mmap/Types.h"
#include "storage/space.h"
////////////////////////////// common interfaces //////////////////////////////
@ -292,8 +293,8 @@ LoadFieldRawData(CSegmentInterface c_segment,
}
auto field_data = milvus::storage::CreateFieldData(data_type, dim);
field_data->FillFieldData(data, row_count);
milvus::storage::FieldDataChannelPtr channel =
std::make_shared<milvus::storage::FieldDataChannel>();
milvus::FieldDataChannelPtr channel =
std::make_shared<milvus::FieldDataChannel>();
channel->push(field_data);
channel->close();
auto field_data_info = milvus::FieldDataInfo(

View File

@ -28,6 +28,10 @@ if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64")
set_source_files_properties(avx512.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512dq -mavx512bw")
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm*")
# TODO: add arm cpu simd
message ("simd using arm mode")
list(APPEND MILVUS_SIMD_SRCS
neon.cpp
)
endif()
add_library(milvus_simd ${MILVUS_SIMD_SRCS})

View File

@ -39,7 +39,7 @@ GetBitsetBlockAVX2(const bool* src) {
BitsetBlockType res[4];
_mm256_storeu_si256((__m256i*)res, tmpvec);
return res[0];
// __m128i tmpvec = _mm_loadu_si64(tmp);
// __m256i tmpvec = _mm_loadu_si64(tmp);
// BitsetBlockType res;
// _mm_storeu_si64(&res, tmpvec);
// return res;
@ -231,6 +231,80 @@ FindTermAVX2(const double* src, size_t vec_size, double val) {
return false;
}
bool
AllFalseAVX2(const bool* src, int64_t size) {
int num_chunk = size / 32;
// Adding 0x7F turns a canonical true byte (0x01) into 0x80, so the sign-bit
// mask below flags any true value; false (0x00) stays below 0x80.
__m256i highbit = _mm256_set1_epi8(0x7F);
for (size_t i = 0; i < num_chunk * 32; i += 32) {
__m256i data =
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
__m256i highbits = _mm256_add_epi8(data, highbit);
if (_mm256_movemask_epi8(highbits) != 0) {
return false;
}
}
for (size_t i = num_chunk * 32; i < size; ++i) {
if (src[i]) {
return false;
}
}
return true;
}
bool
AllTrueAVX2(const bool* src, int64_t size) {
int num_chunk = size / 32;
__m256i highbit = _mm256_set1_epi8(0x7F);
for (size_t i = 0; i < num_chunk * 32; i += 32) {
__m256i data =
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
__m256i highbits = _mm256_add_epi8(data, highbit);
// 256-bit movemask yields 32 sign bits; all of them must be set for all-true.
if (_mm256_movemask_epi8(highbits) != -1) {
return false;
}
}
for (size_t i = num_chunk * 32; i < size; ++i) {
if (!src[i]) {
return false;
}
}
return true;
}
void
AndBoolAVX2(bool* left, bool* right, int64_t size) {
int num_chunk = size / 32;
for (size_t i = 0; i < num_chunk * 32; i += 32) {
__m256i l_reg =
_mm256_loadu_si256(reinterpret_cast<__m256i*>(left + i));
__m256i r_reg =
_mm256_loadu_si256(reinterpret_cast<__m256i*>(right + i));
__m256i res = _mm256_and_si256(l_reg, r_reg);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(left + i), res);
}
for (size_t i = num_chunk * 32; i < size; ++i) {
left[i] &= right[i];
}
}
void
OrBoolAVX2(bool* left, bool* right, int64_t size) {
int num_chunk = size / 32;
for (size_t i = 0; i < num_chunk * 32; i += 32) {
__m256i l_reg =
_mm256_loadu_si256(reinterpret_cast<__m256i*>(left + i));
__m256i r_reg =
_mm256_loadu_si256(reinterpret_cast<__m256i*>(right + i));
__m256i res = _mm256_or_si256(l_reg, r_reg);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(left + i), res);
}
for (size_t i = num_chunk * 32; i < size; ++i) {
left[i] |= right[i];
}
}
} // namespace simd
} // namespace milvus

View File

@ -58,5 +58,17 @@ template <>
bool
FindTermAVX2(const double* src, size_t vec_size, double val);
bool
AllFalseAVX2(const bool* src, int64_t size);
bool
AllTrueAVX2(const bool* src, int64_t size);
void
AndBoolAVX2(bool* left, bool* right, int64_t size);
void
OrBoolAVX2(bool* left, bool* right, int64_t size);
} // namespace simd
} // namespace milvus

View File

@ -183,6 +183,39 @@ FindTermAVX512(const double* src, size_t vec_size, double val) {
}
return false;
}
void
AndBoolAVX512(bool* left, bool* right, int64_t size) {
int num_chunk = size / 64;
for (size_t i = 0; i < num_chunk * 64; i += 64) {
__m512i l_reg =
_mm512_loadu_si512(reinterpret_cast<__m512i*>(left + i));
__m512i r_reg =
_mm512_loadu_si512(reinterpret_cast<__m512i*>(right + i));
__m512i res = _mm512_and_si512(l_reg, r_reg);
_mm512_storeu_si512(reinterpret_cast<__m512i*>(left + i), res);
}
for (size_t i = num_chunk * 64; i < size; ++i) {
left[i] &= right[i];
}
}
void
OrBoolAVX512(bool* left, bool* right, int64_t size) {
int num_chunk = size / 64;
for (size_t i = 0; i < num_chunk * 64; i += 64) {
__m512i l_reg =
_mm512_loadu_si512(reinterpret_cast<__m512i*>(left + i));
__m512i r_reg =
_mm512_loadu_si512(reinterpret_cast<__m512i*>(right + i));
__m512i res = _mm512_or_si512(l_reg, r_reg);
_mm512_storeu_si512(reinterpret_cast<__m512i*>(left + i), res);
}
for (size_t i = num_chunk * 64; i < size; ++i) {
left[i] |= right[i];
}
}
} // namespace simd
} // namespace milvus
#endif

View File

@ -55,5 +55,11 @@ template <>
bool
FindTermAVX512(const double* src, size_t vec_size, double val);
void
AndBoolAVX512(bool* left, bool* right, int64_t size);
void
OrBoolAVX512(bool* left, bool* right, int64_t size);
} // namespace simd
} // namespace milvus

View File

@ -25,6 +25,8 @@
#include "sse2.h"
#include "sse4.h"
#include "instruction_set.h"
#elif defined(__ARM_NEON)
#include "neon.h"
#endif
namespace milvus {
@ -44,6 +46,12 @@ bool use_find_term_avx512;
#endif
decltype(get_bitset_block) get_bitset_block = GetBitsetBlockRef;
decltype(all_false) all_false = AllFalseRef;
decltype(all_true) all_true = AllTrueRef;
decltype(invert_bool) invert_bool = InvertBoolRef;
decltype(and_bool) and_bool = AndBoolRef;
decltype(or_bool) or_bool = OrBoolRef;
FindTermPtr<bool> find_term_bool = FindTermRef<bool>;
FindTermPtr<int8_t> find_term_int8 = FindTermRef<int8_t>;
FindTermPtr<int16_t> find_term_int16 = FindTermRef<int16_t>;
@ -161,9 +169,82 @@ find_term_hook() {
LOG_SEGCORE_INFO_ << "find term hook simd type: " << simd_type;
}
void
all_boolean_hook() {
static std::mutex hook_mutex;
std::lock_guard<std::mutex> lock(hook_mutex);
std::string simd_type = "REF";
#if defined(__x86_64__)
if (use_sse2 && cpu_support_sse2()) {
simd_type = "SSE2";
all_false = AllFalseSSE2;
all_true = AllTrueSSE2;
}
#elif defined(__ARM_NEON)
simd_type = "NEON";
all_false = AllFalseNEON;
all_true = AllTrueNEON;
#endif
LOG_SEGCORE_INFO_ << "AllFalse/AllTrue hook simd type: " << simd_type;
}
void
invert_boolean_hook() {
static std::mutex hook_mutex;
std::lock_guard<std::mutex> lock(hook_mutex);
std::string simd_type = "REF";
#if defined(__x86_64__)
if (use_sse2 && cpu_support_sse2()) {
simd_type = "SSE2";
invert_bool = InvertBoolSSE2;
}
#elif defined(__ARM_NEON)
simd_type = "NEON";
invert_bool = InvertBoolNEON;
#endif
LOG_SEGCORE_INFO_ << "InvertBoolean hook simd type: " << simd_type;
}
void
logical_boolean_hook() {
static std::mutex hook_mutex;
std::lock_guard<std::mutex> lock(hook_mutex);
std::string simd_type = "REF";
#if defined(__x86_64__)
if (use_avx512 && cpu_support_avx512()) {
simd_type = "AVX512";
and_bool = AndBoolAVX512;
or_bool = OrBoolAVX512;
} else if (use_avx2 && cpu_support_avx2()) {
simd_type = "AVX2";
and_bool = AndBoolAVX2;
or_bool = OrBoolAVX2;
} else if (use_sse2 && cpu_support_sse2()) {
simd_type = "SSE2";
and_bool = AndBoolSSE2;
or_bool = OrBoolSSE2;
}
#elif defined(__ARM_NEON)
simd_type = "NEON";
and_bool = AndBoolNEON;
or_bool = OrBoolNEON;
#endif
LOG_SEGCORE_INFO_ << "InvertBoolean hook simd type: " << simd_type;
}
void
boolean_hook() {
all_boolean_hook();
invert_boolean_hook();
logical_boolean_hook();
}
static int init_hook_ = []() {
bitset_hook();
find_term_hook();
boolean_hook();
return 0;
}();
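For context, callers are expected to go through the hooked function pointers rather than a specific SIMD variant, so the dispatch selected above applies transparently. A minimal usage sketch (not part of this diff; the header path and the AndFilterSketch helper are assumptions) could look like:
// Sketch only: assumes "simd/hook.h" exposes the pointers declared above and
// that bool values are stored canonically as 0x00/0x01, as the kernels expect.
#include <cstdint>
#include "simd/hook.h"

inline bool
AndFilterSketch(bool* dst, bool* other, int64_t n) {
    milvus::simd::and_bool(dst, other, n);    // dispatched to AVX512/AVX2/SSE2/NEON/Ref
    return !milvus::simd::all_false(dst, n);  // true if at least one row still matches
}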

View File

@ -19,6 +19,11 @@ namespace milvus {
namespace simd {
extern BitsetBlockType (*get_bitset_block)(const bool* src);
extern bool (*all_false)(const bool* src, int64_t size);
extern bool (*all_true)(const bool* src, int64_t size);
extern void (*invert_bool)(bool* src, int64_t size);
extern void (*and_bool)(bool* left, bool* right, int64_t size);
extern void (*or_bool)(bool* left, bool* right, int64_t size);
template <typename T>
using FindTermPtr = bool (*)(const T* src, size_t size, T val);
@ -63,6 +68,18 @@ bitset_hook();
void
find_term_hook();
void
boolean_hook();
void
all_boolean_hook();
void
invert_boolean_hook();
void
logical_boolean_hook();
template <typename T>
bool
find_term_func(const T* data, size_t size, T val) {

View File

@ -0,0 +1,123 @@
// Copyright (C) 2019-2023 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#if defined(__ARM_NEON)
#include "neon.h"
#include <cstddef>
#include <arm_neon.h>
namespace milvus {
namespace simd {
bool
AllFalseNEON(const bool* src, int64_t size) {
int num_chunk = size / 16;
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(src);
for (size_t i = 0; i < num_chunk * 16; i += 16) {
uint8x16_t data = vld1q_u8(ptr + i);
if (vmaxvq_u8(data) != 0) {
return false;
}
}
for (size_t i = num_chunk * 16; i < size; ++i) {
if (src[i]) {
return false;
}
}
return true;
}
bool
AllTrueNEON(const bool* src, int64_t size) {
int num_chunk = size / 16;
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(src);
for (size_t i = 0; i < num_chunk * 16; i += 16) {
uint8x16_t data = vld1q_u8(ptr + i);
if (vminvq_u8(data) == 0) {
return false;
}
}
for (size_t i = num_chunk * 16; i < size; ++i) {
if (!src[i]) {
return false;
}
}
return true;
}
void
InvertBoolNEON(bool* src, int64_t size) {
int num_chunk = size / 16;
uint8x16_t mask = vdupq_n_u8(0x01);
uint8_t* ptr = reinterpret_cast<uint8_t*>(src);
for (size_t i = 0; i < num_chunk * 16; i += 16) {
uint8x16_t data = vld1q_u8(ptr + i);
uint8x16_t flipped = veorq_u8(data, mask);
vst1q_u8(ptr + i, flipped);
}
for (size_t i = num_chunk * 16; i < size; ++i) {
src[i] = !src[i];
}
}
void
AndBoolNEON(bool* left, bool* right, int64_t size) {
int num_chunk = size / 16;
uint8_t* lptr = reinterpret_cast<uint8_t*>(left);
uint8_t* rptr = reinterpret_cast<uint8_t*>(right);
for (size_t i = 0; i < num_chunk * 16; i += 16) {
uint8x16_t l_reg = vld1q_u8(lptr + i);
uint8x16_t r_reg = vld1q_u8(rptr + i);
uint8x16_t res = vandq_u8(l_reg, r_reg);
vst1q_u8(lptr + i, res);
}
for (size_t i = num_chunk * 16; i < size; ++i) {
left[i] &= right[i];
}
}
void
OrBoolNEON(bool* left, bool* right, int64_t size) {
int num_chunk = size / 16;
uint8_t* lptr = reinterpret_cast<uint8_t*>(left);
uint8_t* rptr = reinterpret_cast<uint8_t*>(right);
for (size_t i = 0; i < num_chunk * 16; i += 16) {
uint8x16_t l_reg = vld1q_u8(lptr + i);
uint8x16_t r_reg = vld1q_u8(rptr + i);
uint8x16_t res = vorrq_u8(l_reg, r_reg);
vst1q_u8(lptr + i, res);
}
for (size_t i = num_chunk * 16; i < size; ++i) {
left[i] |= right[i];
}
}
} // namespace simd
} // namespace milvus
#endif

View File

@ -0,0 +1,37 @@
// Copyright (C) 2019-2023 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <cstdint>
#include "common.h"
namespace milvus {
namespace simd {
bool
AllFalseNEON(const bool* src, int64_t size);
bool
AllTrueNEON(const bool* src, int64_t size);
void
InvertBoolNEON(bool* src, int64_t size);
void
AndBoolNEON(bool* left, bool* right, int64_t size);
void
OrBoolNEON(bool* left, bool* right, int64_t size);
} // namespace simd
} // namespace milvus

View File

@ -29,5 +29,46 @@ GetBitsetBlockRef(const bool* src) {
return val;
}
bool
AllTrueRef(const bool* src, int64_t size) {
for (size_t i = 0; i < size; ++i) {
if (!src[i]) {
return false;
}
}
return true;
}
bool
AllFalseRef(const bool* src, int64_t size) {
for (size_t i = 0; i < size; ++i) {
if (src[i]) {
return false;
}
}
return true;
}
void
InvertBoolRef(bool* src, int64_t size) {
for (size_t i = 0; i < size; ++i) {
src[i] = !src[i];
}
}
void
AndBoolRef(bool* left, bool* right, int64_t size) {
for (size_t i = 0; i < size; ++i) {
left[i] &= right[i];
}
}
void
OrBoolRef(bool* left, bool* right, int64_t size) {
for (size_t i = 0; i < size; ++i) {
left[i] |= right[i];
}
}
} // namespace simd
} // namespace milvus
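Because the Ref implementations define the expected semantics, they can double as an oracle when validating the SIMD variants. A rough check (sketch only; the helper below, header names, and the random-fill scheme are assumptions, not part of this diff) might be:
// Sketch: compare AndBoolSSE2 against AndBoolRef on random 0/1 bytes.
#include <cstdint>
#include <random>
#include <vector>
#include "simd/ref.h"
#include "simd/sse2.h"

inline bool
AndBoolMatchesRef(int64_t n) {
    std::vector<uint8_t> a(n), b(n);
    std::mt19937 gen(42);
    for (int64_t i = 0; i < n; ++i) {
        a[i] = gen() & 0x01;  // canonical bool bytes: 0x00 or 0x01
        b[i] = gen() & 0x01;
    }
    std::vector<uint8_t> a_ref = a;
    milvus::simd::AndBoolSSE2(
        reinterpret_cast<bool*>(a.data()), reinterpret_cast<bool*>(b.data()), n);
    milvus::simd::AndBoolRef(
        reinterpret_cast<bool*>(a_ref.data()), reinterpret_cast<bool*>(b.data()), n);
    return a == a_ref;  // SIMD path must agree with the scalar reference
}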

View File

@ -19,6 +19,21 @@ namespace simd {
BitsetBlockType
GetBitsetBlockRef(const bool* src);
bool
AllTrueRef(const bool* src, int64_t size);
bool
AllFalseRef(const bool* src, int64_t size);
void
InvertBoolRef(bool* src, int64_t size);
void
AndBoolRef(bool* left, bool* right, int64_t size);
void
OrBoolRef(bool* left, bool* right, int64_t size);
template <typename T>
bool
FindTermRef(const T* src, size_t size, T val) {

View File

@ -256,6 +256,102 @@ FindTermSSE2(const double* src, size_t vec_size, double val) {
return false;
}
void
print_m128i(__m128i v) {
alignas(16) int result[4];
_mm_store_si128(reinterpret_cast<__m128i*>(result), v);
for (int i = 0; i < 4; ++i) {
std::cout << std::hex << result[i] << " ";
}
std::cout << std::endl;
}
bool
AllFalseSSE2(const bool* src, int64_t size) {
int num_chunk = size / 16;
__m128i highbit = _mm_set1_epi8(0x7F);
for (size_t i = 0; i < num_chunk * 16; i += 16) {
__m128i data =
_mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
__m128i highbits = _mm_add_epi8(data, highbit);
if (_mm_movemask_epi8(highbits) != 0) {
return false;
}
}
for (size_t i = num_chunk * 16; i < size; ++i) {
if (src[i]) {
return false;
}
}
return true;
}
bool
AllTrueSSE2(const bool* src, int64_t size) {
int num_chunk = size / 16;
__m128i highbit = _mm_set1_epi8(0x7F);
for (size_t i = 0; i < num_chunk * 16; i += 16) {
__m128i data =
_mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
__m128i highbits = _mm_add_epi8(data, highbit);
if (_mm_movemask_epi8(highbits) != 0xFFFF) {
return false;
}
}
for (size_t i = num_chunk * 16; i < size; ++i) {
if (!src[i]) {
return false;
}
}
return true;
}
void
InvertBoolSSE2(bool* src, int64_t size) {
int num_chunk = size / 16;
__m128i mask = _mm_set1_epi8(0x01);
for (size_t i = 0; i < num_chunk * 16; i += 16) {
__m128i data = _mm_loadu_si128(reinterpret_cast<__m128i*>(src + i));
__m128i flipped = _mm_xor_si128(data, mask);
_mm_storeu_si128(reinterpret_cast<__m128i*>(src + i), flipped);
}
for (size_t i = num_chunk * 16; i < size; ++i) {
src[i] = !src[i];
}
}
void
AndBoolSSE2(bool* left, bool* right, int64_t size) {
int num_chunk = size / 16;
for (size_t i = 0; i < num_chunk * 16; i += 16) {
__m128i l_reg = _mm_loadu_si128(reinterpret_cast<__m128i*>(left + i));
__m128i r_reg = _mm_loadu_si128(reinterpret_cast<__m128i*>(right + i));
__m128i res = _mm_and_si128(l_reg, r_reg);
_mm_storeu_si128(reinterpret_cast<__m128i*>(left + i), res);
}
for (size_t i = num_chunk * 16; i < size; ++i) {
left[i] &= right[i];
}
}
void
OrBoolSSE2(bool* left, bool* right, int64_t size) {
int num_chunk = size / 16;
for (size_t i = 0; i < num_chunk * 16; i += 16) {
__m128i l_reg = _mm_loadu_si128(reinterpret_cast<__m128i*>(left + i));
__m128i r_reg = _mm_loadu_si128(reinterpret_cast<__m128i*>(right + i));
__m128i res = _mm_or_si128(l_reg, r_reg);
_mm_storeu_si128(reinterpret_cast<__m128i*>(left + i), res);
}
for (size_t i = num_chunk * 16; i < size; ++i) {
left[i] |= right[i];
}
}
} // namespace simd
} // namespace milvus

View File

@ -24,6 +24,21 @@ namespace simd {
BitsetBlockType
GetBitsetBlockSSE2(const bool* src);
bool
AllFalseSSE2(const bool* src, int64_t size);
bool
AllTrueSSE2(const bool* src, int64_t size);
void
InvertBoolSSE2(bool* src, int64_t size);
void
AndBoolSSE2(bool* left, bool* right, int64_t size);
void
OrBoolSSE2(bool* left, bool* right, int64_t size);
template <typename T>
bool
FindTermSSE2(const T* src, size_t vec_size, T val) {

Some files were not shown because too many files have changed in this diff Show More