mirror of https://github.com/milvus-io/milvus.git
enhance: Refactor runtime and expr framework (#28166)
#28165 Signed-off-by: luzhang <luzhang@zilliz.com> Co-authored-by: luzhang <luzhang@zilliz.com>pull/29293/head
parent
438f39e268
commit
a602171d06
2
Makefile
2
Makefile
|
@ -30,7 +30,7 @@ ifdef USE_ASAN
|
|||
use_asan =${USE_ASAN}
|
||||
endif
|
||||
|
||||
use_dynamic_simd = OFF
|
||||
use_dynamic_simd = ON
|
||||
ifdef USE_DYNAMIC_SIMD
|
||||
use_dynamic_simd = ${USE_DYNAMIC_SIMD}
|
||||
endif
|
||||
|
|
|
@ -5,6 +5,11 @@
|
|||
# `INFO``, ``WARNING``, ``ERROR``, and ``FATAL`` are 0, 1, 2, and 3
|
||||
--minloglevel=0
|
||||
--log_dir=/var/lib/milvus/logs/
|
||||
# using vlog to implement debug and trace log
|
||||
# if set vmodule to 5, open debug level
|
||||
# if set vmodule to 6, open trace level
|
||||
# default 4, not open debug and trace
|
||||
--v=4
|
||||
# MB
|
||||
--max_log_size=200
|
||||
--stop_logging_if_full_disk=true
|
||||
--stop_logging_if_full_disk=true
|
||||
|
|
|
@ -288,6 +288,7 @@ queryNode:
|
|||
# This parameter is only useful when enable-disk = true.
|
||||
# And this value should be a number greater than 1 and less than 32.
|
||||
chunkRows: 1024 # The number of vectors in a chunk.
|
||||
exprEvalBatchSize: 8192 # The batch size for executor get next
|
||||
interimIndex: # build a vector temperate index for growing segment or binlog to accelerate search
|
||||
enableIndex: true
|
||||
nlist: 128 # segment index nlist
|
||||
|
|
|
@ -36,7 +36,7 @@ class MilvusConan(ConanFile):
|
|||
"xz_utils/5.4.0",
|
||||
"prometheus-cpp/1.1.0",
|
||||
"re2/20230301",
|
||||
"folly/2023.10.30.04@milvus/dev",
|
||||
"folly/2023.10.30.05@milvus/dev",
|
||||
"google-cloud-cpp/2.5.0@milvus/dev",
|
||||
"opentelemetry-cpp/1.8.1.1@milvus/dev",
|
||||
"librdkafka/1.9.1",
|
||||
|
@ -44,6 +44,9 @@ class MilvusConan(ConanFile):
|
|||
)
|
||||
generators = ("cmake", "cmake_find_package")
|
||||
default_options = {
|
||||
"libevent:shared": True,
|
||||
"double-conversion:shared": True,
|
||||
"folly:shared": True,
|
||||
"librdkafka:shared": True,
|
||||
"librdkafka:zstd": True,
|
||||
"librdkafka:ssl": True,
|
||||
|
|
|
@ -32,6 +32,7 @@ add_subdirectory( index )
|
|||
add_subdirectory( query )
|
||||
add_subdirectory( segcore )
|
||||
add_subdirectory( indexbuilder )
|
||||
add_subdirectory( exec )
|
||||
if(USE_DYNAMIC_SIMD)
|
||||
add_subdirectory( simd )
|
||||
endif()
|
||||
|
|
|
@ -22,6 +22,7 @@ set(COMMON_SRC
|
|||
Tracer.cpp
|
||||
IndexMeta.cpp
|
||||
EasyAssert.cpp
|
||||
FieldData.cpp
|
||||
)
|
||||
|
||||
add_library(milvus_common SHARED ${COMMON_SRC})
|
||||
|
|
|
@ -27,6 +27,7 @@ int64_t MIDDLE_PRIORITY_THREAD_CORE_COEFFICIENT =
|
|||
int64_t LOW_PRIORITY_THREAD_CORE_COEFFICIENT =
|
||||
DEFAULT_LOW_PRIORITY_THREAD_CORE_COEFFICIENT;
|
||||
int CPU_NUM = DEFAULT_CPU_NUM;
|
||||
int64_t EXEC_EVAL_EXPR_BATCH_SIZE = DEFAULT_EXEC_EVAL_EXPR_BATCH_SIZE;
|
||||
|
||||
void
|
||||
SetIndexSliceSize(const int64_t size) {
|
||||
|
@ -56,6 +57,13 @@ SetLowPriorityThreadCoreCoefficient(const int64_t coefficient) {
|
|||
<< LOW_PRIORITY_THREAD_CORE_COEFFICIENT;
|
||||
}
|
||||
|
||||
void
|
||||
SetDefaultExecEvalExprBatchSize(int64_t val) {
|
||||
EXEC_EVAL_EXPR_BATCH_SIZE = val;
|
||||
LOG_SEGCORE_INFO_ << "set default expr eval batch size: "
|
||||
<< EXEC_EVAL_EXPR_BATCH_SIZE;
|
||||
}
|
||||
|
||||
void
|
||||
SetCpuNum(const int num) {
|
||||
CPU_NUM = num;
|
||||
|
|
|
@ -26,6 +26,7 @@ extern int64_t HIGH_PRIORITY_THREAD_CORE_COEFFICIENT;
|
|||
extern int64_t MIDDLE_PRIORITY_THREAD_CORE_COEFFICIENT;
|
||||
extern int64_t LOW_PRIORITY_THREAD_CORE_COEFFICIENT;
|
||||
extern int CPU_NUM;
|
||||
extern int64_t EXEC_EVAL_EXPR_BATCH_SIZE;
|
||||
|
||||
void
|
||||
SetIndexSliceSize(const int64_t size);
|
||||
|
@ -42,4 +43,7 @@ SetLowPriorityThreadCoreCoefficient(const int64_t coefficient);
|
|||
void
|
||||
SetCpuNum(const int core);
|
||||
|
||||
void
|
||||
SetDefaultExecEvalExprBatchSize(int64_t val);
|
||||
|
||||
} // namespace milvus
|
||||
|
|
|
@ -39,6 +39,10 @@ const char INDEX_BUILD_ID_KEY[] = "indexBuildID";
|
|||
const char INDEX_ROOT_PATH[] = "index_files";
|
||||
const char RAWDATA_ROOT_PATH[] = "raw_datas";
|
||||
|
||||
const char DEFAULT_PLANNODE_ID[] = "0";
|
||||
const char DEAFULT_QUERY_ID[] = "0";
|
||||
const char DEFAULT_TASK_ID[] = "0";
|
||||
|
||||
const int64_t DEFAULT_FIELD_MAX_MEMORY_LIMIT = 64 << 20; // bytes
|
||||
const int64_t DEFAULT_HIGH_PRIORITY_THREAD_CORE_COEFFICIENT = 10;
|
||||
const int64_t DEFAULT_MIDDLE_PRIORITY_THREAD_CORE_COEFFICIENT = 5;
|
||||
|
@ -48,6 +52,8 @@ const int64_t DEFAULT_INDEX_FILE_SLICE_SIZE = 4 << 20; // bytes
|
|||
|
||||
const int DEFAULT_CPU_NUM = 1;
|
||||
|
||||
const int64_t DEFAULT_EXEC_EVAL_EXPR_BATCH_SIZE = 8192;
|
||||
|
||||
constexpr const char* RADIUS = knowhere::meta::RADIUS;
|
||||
constexpr const char* RANGE_FILTER = knowhere::meta::RANGE_FILTER;
|
||||
|
||||
|
|
|
@ -0,0 +1,218 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace milvus {
|
||||
|
||||
class NotImplementedException : public std::exception {
|
||||
public:
|
||||
explicit NotImplementedException(const std::string& msg)
|
||||
: std::exception(), exception_message_(msg) {
|
||||
}
|
||||
const char*
|
||||
what() const noexcept {
|
||||
return exception_message_.c_str();
|
||||
}
|
||||
virtual ~NotImplementedException() {
|
||||
}
|
||||
|
||||
private:
|
||||
std::string exception_message_;
|
||||
};
|
||||
|
||||
class NotSupportedDataTypeException : public std::exception {
|
||||
public:
|
||||
explicit NotSupportedDataTypeException(const std::string& msg)
|
||||
: std::exception(), exception_message_(msg) {
|
||||
}
|
||||
const char*
|
||||
what() const noexcept {
|
||||
return exception_message_.c_str();
|
||||
}
|
||||
virtual ~NotSupportedDataTypeException() {
|
||||
}
|
||||
|
||||
private:
|
||||
std::string exception_message_;
|
||||
};
|
||||
|
||||
class UnistdException : public std::runtime_error {
|
||||
public:
|
||||
explicit UnistdException(const std::string& msg) : std::runtime_error(msg) {
|
||||
}
|
||||
|
||||
virtual ~UnistdException() {
|
||||
}
|
||||
};
|
||||
|
||||
// Exceptions for storage module
|
||||
class LocalChunkManagerException : public std::runtime_error {
|
||||
public:
|
||||
explicit LocalChunkManagerException(const std::string& msg)
|
||||
: std::runtime_error(msg) {
|
||||
}
|
||||
virtual ~LocalChunkManagerException() {
|
||||
}
|
||||
};
|
||||
|
||||
class InvalidPathException : public LocalChunkManagerException {
|
||||
public:
|
||||
explicit InvalidPathException(const std::string& msg)
|
||||
: LocalChunkManagerException(msg) {
|
||||
}
|
||||
virtual ~InvalidPathException() {
|
||||
}
|
||||
};
|
||||
|
||||
class OpenFileException : public LocalChunkManagerException {
|
||||
public:
|
||||
explicit OpenFileException(const std::string& msg)
|
||||
: LocalChunkManagerException(msg) {
|
||||
}
|
||||
virtual ~OpenFileException() {
|
||||
}
|
||||
};
|
||||
|
||||
class CreateFileException : public LocalChunkManagerException {
|
||||
public:
|
||||
explicit CreateFileException(const std::string& msg)
|
||||
: LocalChunkManagerException(msg) {
|
||||
}
|
||||
virtual ~CreateFileException() {
|
||||
}
|
||||
};
|
||||
|
||||
class ReadFileException : public LocalChunkManagerException {
|
||||
public:
|
||||
explicit ReadFileException(const std::string& msg)
|
||||
: LocalChunkManagerException(msg) {
|
||||
}
|
||||
virtual ~ReadFileException() {
|
||||
}
|
||||
};
|
||||
|
||||
class WriteFileException : public LocalChunkManagerException {
|
||||
public:
|
||||
explicit WriteFileException(const std::string& msg)
|
||||
: LocalChunkManagerException(msg) {
|
||||
}
|
||||
virtual ~WriteFileException() {
|
||||
}
|
||||
};
|
||||
|
||||
class PathAlreadyExistException : public LocalChunkManagerException {
|
||||
public:
|
||||
explicit PathAlreadyExistException(const std::string& msg)
|
||||
: LocalChunkManagerException(msg) {
|
||||
}
|
||||
virtual ~PathAlreadyExistException() {
|
||||
}
|
||||
};
|
||||
|
||||
class DirNotExistException : public LocalChunkManagerException {
|
||||
public:
|
||||
explicit DirNotExistException(const std::string& msg)
|
||||
: LocalChunkManagerException(msg) {
|
||||
}
|
||||
virtual ~DirNotExistException() {
|
||||
}
|
||||
};
|
||||
|
||||
class MinioException : public std::runtime_error {
|
||||
public:
|
||||
explicit MinioException(const std::string& msg) : std::runtime_error(msg) {
|
||||
}
|
||||
virtual ~MinioException() {
|
||||
}
|
||||
};
|
||||
|
||||
class InvalidBucketNameException : public MinioException {
|
||||
public:
|
||||
explicit InvalidBucketNameException(const std::string& msg)
|
||||
: MinioException(msg) {
|
||||
}
|
||||
virtual ~InvalidBucketNameException() {
|
||||
}
|
||||
};
|
||||
|
||||
class ObjectNotExistException : public MinioException {
|
||||
public:
|
||||
explicit ObjectNotExistException(const std::string& msg)
|
||||
: MinioException(msg) {
|
||||
}
|
||||
virtual ~ObjectNotExistException() {
|
||||
}
|
||||
};
|
||||
class S3ErrorException : public MinioException {
|
||||
public:
|
||||
explicit S3ErrorException(const std::string& msg) : MinioException(msg) {
|
||||
}
|
||||
virtual ~S3ErrorException() {
|
||||
}
|
||||
};
|
||||
|
||||
class DiskANNFileManagerException : public std::runtime_error {
|
||||
public:
|
||||
explicit DiskANNFileManagerException(const std::string& msg)
|
||||
: std::runtime_error(msg) {
|
||||
}
|
||||
virtual ~DiskANNFileManagerException() {
|
||||
}
|
||||
};
|
||||
|
||||
class ArrowException : public std::runtime_error {
|
||||
public:
|
||||
explicit ArrowException(const std::string& msg) : std::runtime_error(msg) {
|
||||
}
|
||||
virtual ~ArrowException() {
|
||||
}
|
||||
};
|
||||
|
||||
// Exceptions for executor module
|
||||
class ExecDriverException : public std::exception {
|
||||
public:
|
||||
explicit ExecDriverException(const std::string& msg)
|
||||
: std::exception(), exception_message_(msg) {
|
||||
}
|
||||
const char*
|
||||
what() const noexcept {
|
||||
return exception_message_.c_str();
|
||||
}
|
||||
virtual ~ExecDriverException() {
|
||||
}
|
||||
|
||||
private:
|
||||
std::string exception_message_;
|
||||
};
|
||||
class ExecOperatorException : public std::exception {
|
||||
public:
|
||||
explicit ExecOperatorException(const std::string& msg)
|
||||
: std::exception(), exception_message_(msg) {
|
||||
}
|
||||
const char*
|
||||
what() const noexcept {
|
||||
return exception_message_.c_str();
|
||||
}
|
||||
virtual ~ExecOperatorException() {
|
||||
}
|
||||
|
||||
private:
|
||||
std::string exception_message_;
|
||||
};
|
||||
} // namespace milvus
|
|
@ -14,15 +14,17 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "storage/FieldData.h"
|
||||
#include "common/FieldData.h"
|
||||
|
||||
#include "arrow/array/array_binary.h"
|
||||
#include "common/Array.h"
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Exception.h"
|
||||
#include "common/FieldDataInterface.h"
|
||||
#include "common/Json.h"
|
||||
#include "simdjson/padded_string.h"
|
||||
#include "common/Array.h"
|
||||
#include "FieldDataInterface.h"
|
||||
|
||||
namespace milvus::storage {
|
||||
namespace milvus {
|
||||
|
||||
template <typename Type, bool is_scalar>
|
||||
void
|
||||
|
@ -183,4 +185,33 @@ template class FieldDataImpl<int8_t, false>;
|
|||
template class FieldDataImpl<float, false>;
|
||||
template class FieldDataImpl<float16, false>;
|
||||
|
||||
} // namespace milvus::storage
|
||||
FieldDataPtr
|
||||
InitScalarFieldData(const DataType& type, int64_t cap_rows) {
|
||||
switch (type) {
|
||||
case DataType::BOOL:
|
||||
return std::make_shared<FieldData<bool>>(type, cap_rows);
|
||||
case DataType::INT8:
|
||||
return std::make_shared<FieldData<int8_t>>(type, cap_rows);
|
||||
case DataType::INT16:
|
||||
return std::make_shared<FieldData<int16_t>>(type, cap_rows);
|
||||
case DataType::INT32:
|
||||
return std::make_shared<FieldData<int32_t>>(type, cap_rows);
|
||||
case DataType::INT64:
|
||||
return std::make_shared<FieldData<int64_t>>(type, cap_rows);
|
||||
case DataType::FLOAT:
|
||||
return std::make_shared<FieldData<float>>(type, cap_rows);
|
||||
case DataType::DOUBLE:
|
||||
return std::make_shared<FieldData<double>>(type, cap_rows);
|
||||
case DataType::STRING:
|
||||
case DataType::VARCHAR:
|
||||
return std::make_shared<FieldData<std::string>>(type, cap_rows);
|
||||
case DataType::JSON:
|
||||
return std::make_shared<FieldData<Json>>(type, cap_rows);
|
||||
default:
|
||||
throw NotSupportedDataTypeException(
|
||||
"InitScalarFieldData not support data type " +
|
||||
datatype_name(type));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace milvus
|
|
@ -21,10 +21,10 @@
|
|||
|
||||
#include <oneapi/tbb/concurrent_queue.h>
|
||||
|
||||
#include "storage/FieldDataInterface.h"
|
||||
#include "common/FieldDataInterface.h"
|
||||
#include "common/Channel.h"
|
||||
|
||||
namespace milvus::storage {
|
||||
namespace milvus {
|
||||
|
||||
template <typename Type>
|
||||
class FieldData : public FieldDataImpl<Type, true> {
|
||||
|
@ -34,6 +34,11 @@ class FieldData : public FieldDataImpl<Type, true> {
|
|||
: FieldDataImpl<Type, true>::FieldDataImpl(
|
||||
1, data_type, buffered_num_rows) {
|
||||
}
|
||||
static_assert(IsScalar<Type> || std::is_same_v<Type, PkType>);
|
||||
explicit FieldData(DataType data_type, FixedVector<Type>&& inner_data)
|
||||
: FieldDataImpl<Type, true>::FieldDataImpl(
|
||||
1, data_type, std::move(inner_data)) {
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
|
@ -106,7 +111,10 @@ class FieldData<Float16Vector> : public FieldDataImpl<float16, false> {
|
|||
};
|
||||
|
||||
using FieldDataPtr = std::shared_ptr<FieldDataBase>;
|
||||
using FieldDataChannel = Channel<storage::FieldDataPtr>;
|
||||
using FieldDataChannel = Channel<FieldDataPtr>;
|
||||
using FieldDataChannelPtr = std::shared_ptr<FieldDataChannel>;
|
||||
|
||||
} // namespace milvus::storage
|
||||
FieldDataPtr
|
||||
InitScalarFieldData(const DataType& type, int64_t cap_rows);
|
||||
|
||||
} // namespace milvus
|
|
@ -33,7 +33,7 @@
|
|||
#include "common/EasyAssert.h"
|
||||
#include "common/Array.h"
|
||||
|
||||
namespace milvus::storage {
|
||||
namespace milvus {
|
||||
|
||||
using DataType = milvus::DataType;
|
||||
|
||||
|
@ -49,8 +49,8 @@ class FieldDataBase {
|
|||
virtual void
|
||||
FillFieldData(const std::shared_ptr<arrow::Array> array) = 0;
|
||||
|
||||
virtual const void*
|
||||
Data() const = 0;
|
||||
virtual void*
|
||||
Data() = 0;
|
||||
|
||||
virtual const void*
|
||||
RawValue(ssize_t offset) const = 0;
|
||||
|
@ -109,6 +109,12 @@ class FieldDataImpl : public FieldDataBase {
|
|||
field_data_.resize(num_rows_ * dim_);
|
||||
}
|
||||
|
||||
explicit FieldDataImpl(size_t dim, DataType type, Chunk&& field_data)
|
||||
: FieldDataBase(type), dim_(is_scalar ? 1 : dim) {
|
||||
field_data_ = std::move(field_data);
|
||||
num_rows_ = field_data.size() / dim;
|
||||
}
|
||||
|
||||
void
|
||||
FillFieldData(const void* source, ssize_t element_count) override;
|
||||
|
||||
|
@ -126,8 +132,8 @@ class FieldDataImpl : public FieldDataBase {
|
|||
return "FieldDataImpl";
|
||||
}
|
||||
|
||||
const void*
|
||||
Data() const override {
|
||||
void*
|
||||
Data() override {
|
||||
return field_data_.data();
|
||||
}
|
||||
|
||||
|
@ -332,4 +338,4 @@ class FieldDataArrayImpl : public FieldDataImpl<Array, true> {
|
|||
}
|
||||
};
|
||||
|
||||
} // namespace milvus::storage
|
||||
} // namespace milvus
|
|
@ -0,0 +1,75 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <folly/Unit.h>
|
||||
#include <folly/futures/Future.h>
|
||||
|
||||
#include "log/Log.h"
|
||||
|
||||
namespace milvus {
|
||||
|
||||
template <class T>
|
||||
class MilvusPromise : public folly::Promise<T> {
|
||||
public:
|
||||
MilvusPromise() : folly::Promise<T>() {
|
||||
}
|
||||
|
||||
explicit MilvusPromise(const std::string& context)
|
||||
: folly::Promise<T>(), context_(context) {
|
||||
}
|
||||
|
||||
MilvusPromise(folly::futures::detail::EmptyConstruct,
|
||||
const std::string& context) noexcept
|
||||
: folly::Promise<T>(folly::Promise<T>::makeEmpty()), context_(context) {
|
||||
}
|
||||
|
||||
~MilvusPromise() {
|
||||
if (!this->isFulfilled()) {
|
||||
LOG_SEGCORE_WARNING_
|
||||
<< "PROMISE: Unfulfilled promise is being deleted. Context: "
|
||||
<< context_;
|
||||
}
|
||||
}
|
||||
|
||||
explicit MilvusPromise(MilvusPromise<T>&& other)
|
||||
: folly::Promise<T>(std::move(other)),
|
||||
context_(std::move(other.context_)) {
|
||||
}
|
||||
|
||||
MilvusPromise&
|
||||
operator=(MilvusPromise<T>&& other) noexcept {
|
||||
folly::Promise<T>::operator=(std::move(other));
|
||||
context_ = std::move(other.context_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
static MilvusPromise
|
||||
MakeEmpty(const std::string& context = "") noexcept {
|
||||
return MilvusPromise<T>(folly::futures::detail::EmptyConstruct{},
|
||||
context);
|
||||
}
|
||||
|
||||
private:
|
||||
/// Optional parameter to understand where this promise was created.
|
||||
std::string context_;
|
||||
};
|
||||
|
||||
using ContinuePromise = MilvusPromise<folly::Unit>;
|
||||
using ContinueFuture = folly::SemiFuture<folly::Unit>;
|
||||
|
||||
} // namespace milvus
|
|
@ -95,6 +95,10 @@ enum class DataType {
|
|||
ARRAY = 22,
|
||||
JSON = 23,
|
||||
|
||||
// Some special Data type, start from after 50
|
||||
// just for internal use now, may sync proto in future
|
||||
ROW = 50,
|
||||
|
||||
VECTOR_BINARY = 100,
|
||||
VECTOR_FLOAT = 101,
|
||||
VECTOR_FLOAT16 = 102,
|
||||
|
@ -182,8 +186,138 @@ using MayConstRef = std::conditional_t<std::is_same_v<T, std::string> ||
|
|||
const T&,
|
||||
T>;
|
||||
static_assert(std::is_same_v<const std::string&, MayConstRef<std::string>>);
|
||||
|
||||
template <DataType T>
|
||||
struct TypeTraits {};
|
||||
|
||||
template <>
|
||||
struct TypeTraits<DataType::NONE> {
|
||||
static constexpr const char* Name = "NONE";
|
||||
};
|
||||
template <>
|
||||
struct TypeTraits<DataType::BOOL> {
|
||||
using NativeType = bool;
|
||||
static constexpr DataType TypeKind = DataType::BOOL;
|
||||
static constexpr bool IsPrimitiveType = true;
|
||||
static constexpr bool IsFixedWidth = true;
|
||||
static constexpr const char* Name = "BOOL";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeTraits<DataType::INT8> {
|
||||
using NativeType = int8_t;
|
||||
static constexpr DataType TypeKind = DataType::INT8;
|
||||
static constexpr bool IsPrimitiveType = true;
|
||||
static constexpr bool IsFixedWidth = true;
|
||||
static constexpr const char* Name = "INT8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeTraits<DataType::INT16> {
|
||||
using NativeType = int16_t;
|
||||
static constexpr DataType TypeKind = DataType::INT16;
|
||||
static constexpr bool IsPrimitiveType = true;
|
||||
static constexpr bool IsFixedWidth = true;
|
||||
static constexpr const char* Name = "INT16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeTraits<DataType::INT32> {
|
||||
using NativeType = int32_t;
|
||||
static constexpr DataType TypeKind = DataType::INT32;
|
||||
static constexpr bool IsPrimitiveType = true;
|
||||
static constexpr bool IsFixedWidth = true;
|
||||
static constexpr const char* Name = "INT32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeTraits<DataType::INT64> {
|
||||
using NativeType = int32_t;
|
||||
static constexpr DataType TypeKind = DataType::INT64;
|
||||
static constexpr bool IsPrimitiveType = true;
|
||||
static constexpr bool IsFixedWidth = true;
|
||||
static constexpr const char* Name = "INT64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeTraits<DataType::FLOAT> {
|
||||
using NativeType = float;
|
||||
static constexpr DataType TypeKind = DataType::FLOAT;
|
||||
static constexpr bool IsPrimitiveType = true;
|
||||
static constexpr bool IsFixedWidth = true;
|
||||
static constexpr const char* Name = "FLOAT";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeTraits<DataType::DOUBLE> {
|
||||
using NativeType = double;
|
||||
static constexpr DataType TypeKind = DataType::DOUBLE;
|
||||
static constexpr bool IsPrimitiveType = true;
|
||||
static constexpr bool IsFixedWidth = true;
|
||||
static constexpr const char* Name = "DOUBLE";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeTraits<DataType::VARCHAR> {
|
||||
using NativeType = std::string;
|
||||
static constexpr DataType TypeKind = DataType::VARCHAR;
|
||||
static constexpr bool IsPrimitiveType = true;
|
||||
static constexpr bool IsFixedWidth = false;
|
||||
static constexpr const char* Name = "VARCHAR";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeTraits<DataType::STRING> : public TypeTraits<DataType::VARCHAR> {
|
||||
static constexpr DataType TypeKind = DataType::STRING;
|
||||
static constexpr const char* Name = "STRING";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeTraits<DataType::ARRAY> {
|
||||
using NativeType = void;
|
||||
static constexpr DataType TypeKind = DataType::ARRAY;
|
||||
static constexpr bool IsPrimitiveType = false;
|
||||
static constexpr bool IsFixedWidth = false;
|
||||
static constexpr const char* Name = "ARRAY";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeTraits<DataType::JSON> {
|
||||
using NativeType = void;
|
||||
static constexpr DataType TypeKind = DataType::JSON;
|
||||
static constexpr bool IsPrimitiveType = false;
|
||||
static constexpr bool IsFixedWidth = false;
|
||||
static constexpr const char* Name = "JSON";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeTraits<DataType::ROW> {
|
||||
using NativeType = void;
|
||||
static constexpr DataType TypeKind = DataType::ROW;
|
||||
static constexpr bool IsPrimitiveType = false;
|
||||
static constexpr bool IsFixedWidth = false;
|
||||
static constexpr const char* Name = "ROW";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeTraits<DataType::VECTOR_BINARY> {
|
||||
using NativeType = uint8_t;
|
||||
static constexpr DataType TypeKind = DataType::VECTOR_BINARY;
|
||||
static constexpr bool IsPrimitiveType = false;
|
||||
static constexpr bool IsFixedWidth = false;
|
||||
static constexpr const char* Name = "VECTOR_BINARY";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeTraits<DataType::VECTOR_FLOAT> {
|
||||
using NativeType = float;
|
||||
static constexpr DataType TypeKind = DataType::VECTOR_FLOAT;
|
||||
static constexpr bool IsPrimitiveType = false;
|
||||
static constexpr bool IsFixedWidth = false;
|
||||
static constexpr const char* Name = "VECTOR_FLOAT";
|
||||
};
|
||||
|
||||
} // namespace milvus
|
||||
//
|
||||
template <>
|
||||
struct fmt::formatter<milvus::DataType> : formatter<string_view> {
|
||||
auto
|
||||
|
@ -226,6 +360,9 @@ struct fmt::formatter<milvus::DataType> : formatter<string_view> {
|
|||
case milvus::DataType::JSON:
|
||||
name = "JSON";
|
||||
break;
|
||||
case milvus::DataType::ROW:
|
||||
name = "ROW";
|
||||
break;
|
||||
case milvus::DataType::VECTOR_BINARY:
|
||||
name = "VECTOR_BINARY";
|
||||
break;
|
||||
|
|
|
@ -192,4 +192,17 @@ is_in_disk_list(const IndexType& index_type) {
|
|||
return is_in_list<IndexType>(index_type, DISK_INDEX_LIST);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::string
|
||||
Join(const std::vector<T>& items, const std::string& delimiter) {
|
||||
std::stringstream ss;
|
||||
for (size_t i = 0; i < items.size(); ++i) {
|
||||
if (i > 0) {
|
||||
ss << delimiter;
|
||||
}
|
||||
ss << items[i];
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
} // namespace milvus
|
||||
|
|
|
@ -0,0 +1,141 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "common/FieldData.h"
|
||||
|
||||
namespace milvus {
|
||||
|
||||
/**
|
||||
* @brief base class for different type vector
|
||||
* @todo implement full null value support
|
||||
*/
|
||||
|
||||
class BaseVector {
|
||||
public:
|
||||
BaseVector(DataType data_type,
|
||||
size_t length,
|
||||
std::optional<size_t> null_count = std::nullopt)
|
||||
: type_kind_(data_type), length_(length), null_count_(null_count) {
|
||||
}
|
||||
virtual ~BaseVector() = default;
|
||||
|
||||
int64_t
|
||||
size() {
|
||||
return length_;
|
||||
}
|
||||
|
||||
DataType
|
||||
type() {
|
||||
return type_kind_;
|
||||
}
|
||||
|
||||
protected:
|
||||
DataType type_kind_;
|
||||
size_t length_;
|
||||
std::optional<size_t> null_count_;
|
||||
};
|
||||
|
||||
using VectorPtr = std::shared_ptr<BaseVector>;
|
||||
|
||||
/**
|
||||
* @brief Single vector for scalar types
|
||||
* @todo using memory pool && buffer replace FieldData
|
||||
*/
|
||||
class ColumnVector final : public BaseVector {
|
||||
public:
|
||||
ColumnVector(DataType data_type,
|
||||
size_t length,
|
||||
std::optional<size_t> null_count = std::nullopt)
|
||||
: BaseVector(data_type, length, null_count) {
|
||||
values_ = InitScalarFieldData(data_type, length);
|
||||
}
|
||||
|
||||
ColumnVector(FixedVector<bool>&& data)
|
||||
: BaseVector(DataType::BOOL, data.size()) {
|
||||
values_ =
|
||||
std::make_shared<FieldData<bool>>(DataType::BOOL, std::move(data));
|
||||
}
|
||||
|
||||
virtual ~ColumnVector() override {
|
||||
values_.reset();
|
||||
}
|
||||
|
||||
void*
|
||||
GetRawData() {
|
||||
return values_->Data();
|
||||
}
|
||||
|
||||
template <typename As>
|
||||
const As*
|
||||
RawAsValues() const {
|
||||
return reinterpret_cast<const As*>(values_->Data());
|
||||
}
|
||||
|
||||
private:
|
||||
FieldDataPtr values_;
|
||||
};
|
||||
|
||||
using ColumnVectorPtr = std::shared_ptr<ColumnVector>;
|
||||
|
||||
/**
|
||||
* @brief Multi vectors for scalar types
|
||||
* mainly using it to pass internal result in segcore scalar engine system
|
||||
*/
|
||||
class RowVector : public BaseVector {
|
||||
public:
|
||||
RowVector(std::vector<DataType>& data_types,
|
||||
size_t length,
|
||||
std::optional<size_t> null_count = std::nullopt)
|
||||
: BaseVector(DataType::ROW, length, null_count) {
|
||||
for (auto& type : data_types) {
|
||||
children_values_.emplace_back(
|
||||
std::make_shared<ColumnVector>(type, length));
|
||||
}
|
||||
}
|
||||
|
||||
RowVector(const std::vector<VectorPtr>& children)
|
||||
: BaseVector(DataType::ROW, 0) {
|
||||
for (auto& child : children) {
|
||||
children_values_.push_back(child);
|
||||
if (child->size() > length_) {
|
||||
length_ = child->size();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<VectorPtr>&
|
||||
childrens() {
|
||||
return children_values_;
|
||||
}
|
||||
|
||||
VectorPtr
|
||||
child(int index) {
|
||||
assert(index < children_values_.size());
|
||||
return children_values_[index];
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<VectorPtr> children_values_;
|
||||
};
|
||||
|
||||
using RowVectorPtr = std::shared_ptr<RowVector>;
|
||||
|
||||
} // namespace milvus
|
|
@ -25,7 +25,7 @@
|
|||
#include "common/Tracer.h"
|
||||
#include "log/Log.h"
|
||||
|
||||
std::once_flag flag1, flag2, flag3, flag4, flag5;
|
||||
std::once_flag flag1, flag2, flag3, flag4, flag5, flag6;
|
||||
std::once_flag traceFlag;
|
||||
|
||||
void
|
||||
|
@ -70,6 +70,14 @@ InitCpuNum(const int value) {
|
|||
flag3, [](int value) { milvus::SetCpuNum(value); }, value);
|
||||
}
|
||||
|
||||
void
|
||||
InitDefaultExprEvalBatchSize(int64_t val) {
|
||||
std::call_once(
|
||||
flag6,
|
||||
[](int val) { milvus::SetDefaultExecEvalExprBatchSize(val); },
|
||||
val);
|
||||
}
|
||||
|
||||
void
|
||||
InitTrace(CTraceConfig* config) {
|
||||
auto traceConfig = milvus::tracer::TraceConfig{config->exporter,
|
||||
|
|
|
@ -36,6 +36,9 @@ InitMiddlePriorityThreadCoreCoefficient(const int64_t);
|
|||
void
|
||||
InitLowPriorityThreadCoreCoefficient(const int64_t);
|
||||
|
||||
void
|
||||
InitDefaultExprEvalBatchSize(int64_t val);
|
||||
|
||||
void
|
||||
InitCpuNum(const int);
|
||||
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
set(MILVUS_EXEC_SRCS
|
||||
expression/Expr.cpp
|
||||
expression/UnaryExpr.cpp
|
||||
expression/ConjunctExpr.cpp
|
||||
expression/LogicalUnaryExpr.cpp
|
||||
expression/LogicalBinaryExpr.cpp
|
||||
expression/TermExpr.cpp
|
||||
expression/BinaryArithOpEvalRangeExpr.cpp
|
||||
expression/BinaryRangeExpr.cpp
|
||||
expression/AlwaysTrueExpr.cpp
|
||||
expression/CompareExpr.cpp
|
||||
expression/JsonContainsExpr.cpp
|
||||
expression/ExistsExpr.cpp
|
||||
operator/FilterBits.cpp
|
||||
operator/Operator.cpp
|
||||
Driver.cpp
|
||||
Task.cpp
|
||||
)
|
||||
|
||||
add_library(milvus_exec STATIC ${MILVUS_EXEC_SRCS})
|
||||
if(USE_DYNAMIC_SIMD)
|
||||
target_link_libraries(milvus_exec milvus_common milvus_simd milvus-storage ${CONAN_LIBS})
|
||||
else()
|
||||
target_link_libraries(milvus_exec milvus_common milvus-storage ${CONAN_LIBS})
|
||||
endif()
|
|
@ -0,0 +1,355 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "Driver.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <memory>
|
||||
|
||||
#include "exec/operator/CallbackSink.h"
|
||||
#include "exec/operator/FilterBits.h"
|
||||
#include "exec/operator/Operator.h"
|
||||
#include "exec/Task.h"
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
std::atomic_uint64_t BlockingState::num_blocked_drivers_{0};
|
||||
|
||||
std::shared_ptr<QueryConfig>
|
||||
DriverContext::GetQueryConfig() {
|
||||
return task_->query_context()->query_config();
|
||||
}
|
||||
|
||||
std::shared_ptr<Driver>
|
||||
DriverFactory::CreateDriver(std::unique_ptr<DriverContext> ctx,
|
||||
std::function<int(int pipelineid)> num_drivers) {
|
||||
auto driver = std::shared_ptr<Driver>(new Driver());
|
||||
ctx->driver_ = driver.get();
|
||||
std::vector<std::unique_ptr<Operator>> operators;
|
||||
operators.reserve(plannodes_.size());
|
||||
|
||||
for (size_t i = 0; i < plannodes_.size(); ++i) {
|
||||
auto id = operators.size();
|
||||
auto plannode = plannodes_[i];
|
||||
if (auto filternode =
|
||||
std::dynamic_pointer_cast<const plan::FilterBitsNode>(
|
||||
plannode)) {
|
||||
operators.push_back(
|
||||
std::make_unique<FilterBits>(id, ctx.get(), filternode));
|
||||
}
|
||||
// TODO: add more operators
|
||||
}
|
||||
|
||||
if (consumer_supplier_) {
|
||||
operators.push_back(consumer_supplier_(operators.size(), ctx.get()));
|
||||
}
|
||||
|
||||
driver->Init(std::move(ctx), std::move(operators));
|
||||
|
||||
return driver;
|
||||
}
|
||||
|
||||
void
|
||||
Driver::Enqueue(std::shared_ptr<Driver> driver) {
|
||||
if (driver->closed_) {
|
||||
return;
|
||||
}
|
||||
|
||||
driver->get_task()->query_context()->executor()->add(
|
||||
[driver]() { Driver::Run(driver); });
|
||||
}
|
||||
|
||||
void
|
||||
Driver::Run(std::shared_ptr<Driver> self) {
|
||||
std::shared_ptr<BlockingState> blocking_state;
|
||||
RowVectorPtr result;
|
||||
auto reason = self->RunInternal(self, blocking_state, result);
|
||||
|
||||
AssertInfo(result == nullptr,
|
||||
"The last operator (sink) must not produce any results.");
|
||||
|
||||
if (reason == StopReason::kBlock) {
|
||||
return;
|
||||
}
|
||||
|
||||
switch (reason) {
|
||||
case StopReason::kBlock:
|
||||
BlockingState::SetResume(blocking_state);
|
||||
return;
|
||||
case StopReason::kYield:
|
||||
Enqueue(self);
|
||||
case StopReason::kPause:
|
||||
case StopReason::kTerminate:
|
||||
case StopReason::kAlreadyTerminated:
|
||||
case StopReason::kAtEnd:
|
||||
return;
|
||||
default:
|
||||
AssertInfo(false, "Unhandled stop reason");
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
Driver::Init(std::unique_ptr<DriverContext> ctx,
|
||||
std::vector<std::unique_ptr<Operator>> operators) {
|
||||
assert(ctx != nullptr);
|
||||
ctx_ = std::move(ctx);
|
||||
AssertInfo(operators.size() != 0, "operators in driver must not empty");
|
||||
operators_ = std::move(operators);
|
||||
current_operator_index_ = operators_.size() - 1;
|
||||
}
|
||||
|
||||
void
|
||||
Driver::Close() {
|
||||
if (closed_) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto& op : operators_) {
|
||||
op->Close();
|
||||
}
|
||||
|
||||
closed_ = true;
|
||||
|
||||
Task::RemoveDriver(ctx_->task_, this);
|
||||
}
|
||||
|
||||
RowVectorPtr
Driver::Next(std::shared_ptr<BlockingState>& blocking_state) {
    // Keep ourselves alive for the duration of RunInternal.
    auto self = shared_from_this();

    RowVectorPtr output;
    const auto stop_reason = RunInternal(self, blocking_state, output);

    // Single-threaded Next may only stop by blocking, exhausting input,
    // or observing an earlier termination.
    Assert(stop_reason == StopReason::kBlock ||
           stop_reason == StopReason::kAtEnd ||
           stop_reason == StopReason::kAlreadyTerminated);
    return output;
}
|
||||
|
||||
// Invokes `call_func` on `operator`, converting any escaping exception
// into an ExecOperatorException tagged with the operator type, its plan
// node id, and `method_name`. A SegcoreError is additionally logged
// before being rethrown as ExecOperatorException.
// NOTE: kept as a macro (not a function) so `call_func` can be an
// arbitrary statement — e.g. an assignment — evaluated lazily inside
// the try block.
#define CALL_OPERATOR(call_func, operator, method_name) \
    try { \
        call_func; \
    } catch (SegcoreError & e) { \
        auto err_msg = fmt::format( \
            "Operator::{} failed for [Operator:{}, plan node id: " \
            "{}] : {}", \
            method_name, \
            operator->get_operator_type(), \
            operator->get_plannode_id(), \
            e.what()); \
        LOG_SEGCORE_ERROR_ << err_msg; \
        throw ExecOperatorException(err_msg); \
    } catch (std::exception & e) { \
        throw ExecOperatorException( \
            fmt::format("Operator::{} failed for [Operator:{}, plan node id: " \
                        "{}] : {}", \
                        method_name, \
                        operator->get_operator_type(), \
                        operator->get_plannode_id(), \
                        e.what())); \
    }
|
||||
|
||||
// Core pull loop of a driver. Walks the operator pipeline from the
// sink-most operator backwards, pulling output from producers and
// pushing it to the operator immediately downstream, until the driver
// blocks, delivers a result to the caller, or reaches end of data.
// Any exception is recorded on the task and reported as
// kAlreadyTerminated.
StopReason
Driver::RunInternal(std::shared_ptr<Driver>& self,
                    std::shared_ptr<BlockingState>& blocking_state,
                    RowVectorPtr& result) {
    try {
        int num_operators = operators_.size();
        ContinueFuture future;

        for (;;) {
            // Scan from the last operator towards the source, looking
            // for an operator pair (op -> next_op) that can move data.
            for (int32_t i = num_operators - 1; i >= 0; --i) {
                auto op = operators_[i].get();

                current_operator_index_ = i;
                CALL_OPERATOR(
                    blocking_reason_ = op->IsBlocked(&future), op, "IsBlocked");
                if (blocking_reason_ != BlockingReason::kNotBlocked) {
                    blocking_state = std::make_shared<BlockingState>(
                        self, std::move(future), op, blocking_reason_);
                    return StopReason::kBlock;
                }
                Operator* next_op = nullptr;

                if (i < operators_.size() - 1) {
                    // There is a downstream operator: try op -> next_op.
                    next_op = operators_[i + 1].get();
                    CALL_OPERATOR(
                        blocking_reason_ = next_op->IsBlocked(&future),
                        next_op,
                        "IsBlocked");
                    if (blocking_reason_ != BlockingReason::kNotBlocked) {
                        blocking_state = std::make_shared<BlockingState>(
                            self, std::move(future), next_op, blocking_reason_);
                        return StopReason::kBlock;
                    }

                    bool needs_input;
                    CALL_OPERATOR(needs_input = next_op->NeedInput(),
                                  next_op,
                                  "NeedInput");
                    if (needs_input) {
                        // NOTE: this local deliberately shadows the
                        // `result` output parameter — intermediate
                        // batches never leave this frame.
                        RowVectorPtr result;
                        {
                            CALL_OPERATOR(
                                result = op->GetOutput(), op, "GetOutput");
                            if (result) {
                                AssertInfo(
                                    result->size() > 0,
                                    fmt::format(
                                        "GetOutput must return nullptr or "
                                        "a non-empty vector: {}",
                                        op->get_operator_type()));
                            }
                        }
                        if (result) {
                            CALL_OPERATOR(
                                next_op->AddInput(result), next_op, "AddInput");
                            // Jump to next_op on the next iteration of
                            // the descending loop (i is decremented).
                            i += 2;
                            continue;
                        } else {
                            // Producer had nothing; re-check whether it
                            // is now blocked or finished.
                            CALL_OPERATOR(
                                blocking_reason_ = op->IsBlocked(&future),
                                op,
                                "IsBlocked");
                            if (blocking_reason_ !=
                                BlockingReason::kNotBlocked) {
                                // NOTE(review): `op` reported the block,
                                // but the state records `next_op` —
                                // looks like it should be `op`; confirm
                                // against the upstream design.
                                blocking_state =
                                    std::make_shared<BlockingState>(
                                        self,
                                        std::move(future),
                                        next_op,
                                        blocking_reason_);
                                return StopReason::kBlock;
                            }
                            if (op->IsFinished()) {
                                CALL_OPERATOR(next_op->NoMoreInput(),
                                              next_op,
                                              "NoMoreInput");
                                // Restart the scan from the sink.
                                break;
                            }
                        }
                    }
                } else {
                    // op is the last operator: its output goes to the
                    // caller, not to another operator.
                    {
                        CALL_OPERATOR(
                            result = op->GetOutput(), op, "GetOutput");
                        if (result) {
                            AssertInfo(
                                result->size() > 0,
                                fmt::format("GetOutput must return nullptr or "
                                            "a non-empty vector: {}",
                                            op->get_operator_type()));
                            // A produced batch pauses the driver until
                            // the consumer takes it.
                            blocking_reason_ = BlockingReason::kWaitForConsumer;
                            return StopReason::kBlock;
                        }
                    }
                    if (op->IsFinished()) {
                        Close();
                        return StopReason::kAtEnd;
                    }
                    continue;
                }
            }
        }
    } catch (std::exception& e) {
        // Record the failure on the task; the caller sees termination.
        get_task()->SetError(std::current_exception());
        return StopReason::kAlreadyTerminated;
    }
}
|
||||
|
||||
static bool
|
||||
MustStartNewPipeline(std::shared_ptr<const plan::PlanNode> plannode,
|
||||
int source_id) {
|
||||
//TODO: support LocalMerge and other shuffle
|
||||
return source_id != 0;
|
||||
}
|
||||
|
||||
// Adapts a ConsumerSupplier into an OperatorSupplier that terminates a
// pipeline with a CallbackSink feeding that consumer.
OperatorSupplier
MakeConsumerSupplier(ConsumerSupplier supplier) {
    // No consumer means no sink operator to append.
    if (!supplier) {
        return nullptr;
    }
    return [supplier](int32_t operator_id, DriverContext* ctx) {
        return std::make_unique<CallbackSink>(operator_id, ctx, supplier());
    };
}
|
||||
|
||||
// Returns the maximum driver count for one pipeline.
// TODO: execution is currently single-driver per pipeline, so both the
// factory and the query config are ignored for now.
uint32_t
MaxDrivers(const DriverFactory* factory, const QueryConfig& config) {
    return 1;
}
|
||||
|
||||
// Recursively partitions the plan tree into pipelines (one DriverFactory
// per pipeline). A node joins `current_plannodes` unless
// MustStartNewPipeline says its source begins a fresh pipeline, in which
// case the recursive call passes nullptr and a new factory is opened.
static void
SplitPlan(const std::shared_ptr<const plan::PlanNode>& plannode,
          std::vector<std::shared_ptr<const plan::PlanNode>>* current_plannodes,
          const std::shared_ptr<const plan::PlanNode>& consumer_node,
          OperatorSupplier operator_supplier,
          std::vector<std::unique_ptr<DriverFactory>>* driver_factories) {
    // No current pipeline: open a new factory and make it current.
    if (!current_plannodes) {
        driver_factories->push_back(std::make_unique<DriverFactory>());
        current_plannodes = &driver_factories->back()->plannodes_;
        driver_factories->back()->consumer_supplier_ = operator_supplier;
        driver_factories->back()->consumer_node_ = consumer_node;
    }

    auto sources = plannode->sources();
    if (sources.empty()) {
        // Leaf node: this pipeline reads its input directly.
        driver_factories->back()->is_input_driver_ = true;
    } else {
        for (int i = 0; i < sources.size(); ++i) {
            SplitPlan(
                sources[i],
                MustStartNewPipeline(plannode, i) ? nullptr : current_plannodes,
                plannode,
                nullptr,
                driver_factories);
        }
    }
    // Appended post-order, so producers precede their consumers in the
    // pipeline's node list.
    current_plannodes->push_back(plannode);
}
|
||||
|
||||
// Splits `fragment` into pipelines, marks the first factory as the
// output pipeline, and sizes the driver counts of each factory from the
// config and the caller-supplied `max_drivers` cap.
void
LocalPlanner::Plan(
    const plan::PlanFragment& fragment,
    ConsumerSupplier consumer_supplier,
    std::vector<std::unique_ptr<DriverFactory>>* driver_factories,
    const QueryConfig& config,
    uint32_t max_drivers) {
    SplitPlan(fragment.plan_node_,
              nullptr,
              nullptr,
              MakeConsumerSupplier(consumer_supplier),
              driver_factories);

    // The factory created first holds the root of the plan and thus
    // produces the fragment's output.
    (*driver_factories)[0]->is_output_driver_ = true;

    for (auto& factory : *driver_factories) {
        factory->max_drivers_ = MaxDrivers(factory.get(), config);
        factory->num_drivers_ = std::min(factory->max_drivers_, max_drivers);

        // Grouped execution runs the pipeline once per split group.
        if (factory->is_group_execution_) {
            factory->num_total_drivers_ =
                factory->num_drivers_ * fragment.num_splitgroups_;
        } else {
            factory->num_total_drivers_ = factory->num_drivers_;
        }
    }
}
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,254 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "common/Types.h"
|
||||
#include "common/Promise.h"
|
||||
#include "exec/QueryContext.h"
|
||||
#include "plan/PlanNode.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Why a driver's run loop returned control to its caller.
enum class StopReason {
    // Keep running.
    kNone,
    // Go off thread and do not schedule more activity.
    kPause,
    // Stop and free all. This is returned once and the thread that gets
    // this value is responsible for freeing the state associated with
    // the thread. Other threads will get kAlreadyTerminated after the
    // first thread has received kTerminate.
    kTerminate,
    // Terminate was already delivered to another thread.
    kAlreadyTerminated,
    // Go off thread and then enqueue to the back of the runnable queue.
    kYield,
    // Must wait for external events.
    kBlock,
    // No more data to produce.
    kAtEnd,
    // The driver is already running on the current thread.
    kAlreadyOnThread
};

// What an operator is waiting on when it reports itself blocked.
enum class BlockingReason {
    kNotBlocked,
    kWaitForConsumer,
    kWaitForSplit,
    kWaitForExchange,
    kWaitForJoinBuild,
    /// For a build operator, it is blocked waiting for the probe operators to
    /// finish probing before build the next hash table from one of the previously
    /// spilled partition data.
    /// For a probe operator, it is blocked waiting for all its peer probe
    /// operators to finish probing before notifying the build operators to build
    /// the next hash table from the previously spilled data.
    kWaitForJoinProbe,
    kWaitForMemory,
    kWaitForConnector,
    /// Build operator is blocked waiting for all its peers to stop to run group
    /// spill on all of them.
    kWaitForSpill,
};
|
||||
|
||||
class Driver;
|
||||
class Operator;
|
||||
class Task;
|
||||
class BlockingState {
|
||||
public:
|
||||
BlockingState(std::shared_ptr<Driver> driver,
|
||||
ContinueFuture&& future,
|
||||
Operator* op,
|
||||
BlockingReason reason)
|
||||
: driver_(std::move(driver_)),
|
||||
future_(std::move(future)),
|
||||
operator_(op),
|
||||
reason_(reason) {
|
||||
num_blocked_drivers_++;
|
||||
}
|
||||
|
||||
~BlockingState() {
|
||||
num_blocked_drivers_--;
|
||||
}
|
||||
|
||||
static void
|
||||
SetResume(std::shared_ptr<BlockingState> state) {
|
||||
}
|
||||
|
||||
Operator*
|
||||
op() {
|
||||
return operator_;
|
||||
}
|
||||
|
||||
BlockingReason
|
||||
reason() {
|
||||
return reason_;
|
||||
}
|
||||
|
||||
// Moves out the blocking future stored inside. Can be called only once. Used
|
||||
// in single-threaded execution.
|
||||
ContinueFuture
|
||||
future() {
|
||||
return std::move(future_);
|
||||
}
|
||||
|
||||
// Returns total number of drivers process wide that are currently in blocked
|
||||
// state.
|
||||
static uint64_t
|
||||
get_num_blocked_drivers() {
|
||||
return num_blocked_drivers_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Driver> driver_;
|
||||
ContinueFuture future_;
|
||||
Operator* operator_;
|
||||
BlockingReason reason_;
|
||||
|
||||
static std::atomic_uint64_t num_blocked_drivers_;
|
||||
};
|
||||
|
||||
// Per-driver identity and back-references: which task owns it, which
// pipeline/split group/partition slot it occupies.
struct DriverContext {
    int driverid_;
    int pipelineid_;
    uint32_t split_groupid_;
    uint32_t partitionid_;

    // Owning task (shared) and the driver itself (set by the factory,
    // not owned).
    std::shared_ptr<Task> task_;
    Driver* driver_;

    // NOTE: "pipilineid" is a typo for "pipelineid"; harmless since
    // parameter names do not affect callers.
    explicit DriverContext(std::shared_ptr<Task> task,
                           int driverid,
                           int pipilineid,
                           uint32_t split_group_id,
                           uint32_t partition_id)
        : driverid_(driverid),
          pipelineid_(pipilineid),
          split_groupid_(split_group_id),
          partitionid_(partition_id),
          task_(task) {
    }

    // Convenience accessor for the task's query config.
    std::shared_ptr<QueryConfig>
    GetQueryConfig();
};
|
||||
// Factory signature for the optional sink operator appended at the end
// of a pipeline.
using OperatorSupplier = std::function<std::unique_ptr<Operator>(
    int32_t operatorid, DriverContext* ctx)>;

// Blueprint of one pipeline: its plan nodes plus driver sizing decided
// by the planner.
struct DriverFactory {
    std::vector<std::shared_ptr<const plan::PlanNode>> plannodes_;
    OperatorSupplier consumer_supplier_;
    // The (local) node that will consume results supplied by this pipeline.
    // Can be null. We use that to determine the max drivers.
    std::shared_ptr<const plan::PlanNode> consumer_node_;
    uint32_t max_drivers_;
    uint32_t num_drivers_;
    uint32_t num_total_drivers_;

    // Pipeline role flags set by the planner.
    bool is_group_execution_;
    bool is_input_driver_;
    bool is_output_driver_;

    std::shared_ptr<Driver>
    CreateDriver(std::unique_ptr<DriverContext> ctx,
                 // TODO: support exchange function
                 //  std::shared_ptr<ExchangeClient> exchange_client,
                 std::function<int(int pipilineid)> num_driver);

    // TODO: support ditribution compute
    bool
    SupportSingleThreadExecution() const {
        return true;
    }
};
|
||||
|
||||
// One thread of pipeline execution. Construction goes through
// DriverFactory (the constructor is private); lifetime is managed by
// shared_ptr so a driver can keep itself alive while scheduled.
class Driver : public std::enable_shared_from_this<Driver> {
 public:
    // Schedules `instance` on its task's executor (no-op when closed).
    static void
    Enqueue(std::shared_ptr<Driver> instance);

    // Single-threaded pull: runs until a result is produced, the driver
    // blocks (blocking_state is then set), or input is exhausted.
    RowVectorPtr
    Next(std::shared_ptr<BlockingState>& blocking_state);

    DriverContext*
    get_driver_context() const {
        return ctx_.get();
    }

    const std::shared_ptr<Task>&
    get_task() const {
        return ctx_->task_;
    }

    BlockingReason
    GetBlockingReason() const {
        return blocking_reason_;
    }

    // Attaches the context and operator chain; called once by the factory.
    void
    Init(std::unique_ptr<DriverContext> driver_ctx,
         std::vector<std::unique_ptr<Operator>> operators);

 private:
    Driver() = default;

    void
    EnqueueInternal() {
    }

    // Executor entry point used by Enqueue.
    static void
    Run(std::shared_ptr<Driver> self);

    // Core pull loop shared by Run and Next.
    StopReason
    RunInternal(std::shared_ptr<Driver>& self,
                std::shared_ptr<BlockingState>& blocking_state,
                RowVectorPtr& result);

    // Closes all operators and detaches from the task; idempotent.
    void
    Close();

    std::unique_ptr<DriverContext> ctx_;

    std::atomic_bool closed_{false};

    // Operators ordered source -> sink.
    std::vector<std::unique_ptr<Operator>> operators_;

    size_t current_operator_index_{0};

    BlockingReason blocking_reason_{BlockingReason::kNotBlocked};

    friend struct DriverFactory;
};
|
||||
|
||||
// Callback that receives produced batches; returns a blocking reason
// and may hand back a future to wait on.
using Consumer = std::function<BlockingReason(RowVectorPtr, ContinueFuture*)>;
using ConsumerSupplier = std::function<Consumer()>;

// Turns a plan fragment into DriverFactory pipelines for local
// (in-process) execution.
class LocalPlanner {
 public:
    static void
    Plan(const plan::PlanFragment& fragment,
         ConsumerSupplier consumer_supplier,
         std::vector<std::unique_ptr<DriverFactory>>* driver_factories,
         const QueryConfig& config,
         uint32_t max_drivers);
};
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,257 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <folly/Executor.h>
|
||||
#include <folly/executors/CPUThreadPoolExecutor.h>
|
||||
#include <folly/Optional.h>
|
||||
|
||||
#include "common/Common.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Exception.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Nesting level a Context represents, from process-wide down to a
// single executor thread.
enum class ContextScope { GLOBAL = 0, SESSION = 1, QUERY = 2, Executor = 3 };

// Read-only string key/value configuration with typed accessors.
// Subclasses supply the storage (see MemConfig).
class BaseConfig {
 public:
    // Raw lookup; folly::none when the key is absent.
    virtual folly::Optional<std::string>
    Get(const std::string& key) const = 0;

    // Typed lookup: converts the stored string via folly::to<T>.
    template <typename T>
    folly::Optional<T>
    Get(const std::string& key) const {
        auto val = Get(key);
        if (val.hasValue()) {
            return folly::to<T>(val.value());
        } else {
            return folly::none;
        }
    }

    // Typed lookup with a fallback when the key is absent.
    template <typename T>
    T
    Get(const std::string& key, const T& default_value) const {
        auto val = Get(key);
        if (val.hasValue()) {
            return folly::to<T>(val.value());
        } else {
            return default_value;
        }
    }

    virtual bool
    IsValueExists(const std::string& key) const = 0;

    // Optional bulk view; only map-backed configs implement it.
    virtual const std::unordered_map<std::string, std::string>&
    values() const {
        throw NotImplementedException("method values() is not supported");
    }

    virtual ~BaseConfig() = default;
};
|
||||
|
||||
class MemConfig : public BaseConfig {
|
||||
public:
|
||||
explicit MemConfig(
|
||||
const std::unordered_map<std::string, std::string>& values)
|
||||
: values_(values) {
|
||||
}
|
||||
|
||||
explicit MemConfig() : values_{} {
|
||||
}
|
||||
|
||||
explicit MemConfig(std::unordered_map<std::string, std::string>&& values)
|
||||
: values_(std::move(values)) {
|
||||
}
|
||||
|
||||
folly::Optional<std::string>
|
||||
Get(const std::string& key) const override {
|
||||
folly::Optional<std::string> val;
|
||||
auto it = values_.find(key);
|
||||
if (it != values_.end()) {
|
||||
val = it->second;
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
bool
|
||||
IsValueExists(const std::string& key) const override {
|
||||
return values_.find(key) != values_.end();
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, std::string>&
|
||||
values() const override {
|
||||
return values_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, std::string> values_;
|
||||
};
|
||||
|
||||
// Per-query configuration: an in-memory map plus well-known keys with
// typed accessors and compiled-in defaults.
class QueryConfig : public MemConfig {
 public:
    // Whether to use the simplified expression evaluation path. False by default.
    static constexpr const char* kExprEvalSimplified =
        "expression.eval_simplified";

    // Number of rows an expression evaluates per batch.
    static constexpr const char* kExprEvalBatchSize =
        "expression.eval_batch_size";

    QueryConfig(const std::unordered_map<std::string, std::string>& values)
        : MemConfig(values) {
    }

    QueryConfig() = default;

    bool
    get_expr_eval_simplified() const {
        return BaseConfig::Get<bool>(kExprEvalSimplified, false);
    }

    // Falls back to the process-wide EXEC_EVAL_EXPR_BATCH_SIZE default.
    int64_t
    get_expr_batch_size() const {
        return BaseConfig::Get<int64_t>(kExprEvalBatchSize,
                                        EXEC_EVAL_EXPR_BATCH_SIZE);
    }
};
|
||||
|
||||
// Base of the context hierarchy: records its scope and an optional
// parent context of a wider scope.
class Context {
 public:
    explicit Context(ContextScope scope,
                     const std::shared_ptr<const Context> parent = nullptr)
        : scope_(scope), parent_(parent) {
    }

    ContextScope
    scope() const {
        return scope_;
    }

    std::shared_ptr<const Context>
    parent() const {
        return parent_;
    }
    // // TODO: support dynamic update
    // void
    // set_config(const std::shared_ptr<const Config>& config) {
    //     std::atomic_exchange(&config_, config);
    // }

    // std::shared_ptr<const config>
    // get_config() {
    //     return config_;
    // }

 private:
    ContextScope scope_;
    std::shared_ptr<const Context> parent_;
    //std::shared_ptr<const Config> config_;
};
|
||||
|
||||
class QueryContext : public Context {
|
||||
public:
|
||||
QueryContext(const std::string& query_id,
|
||||
const milvus::segcore::SegmentInternalInterface* segment,
|
||||
milvus::Timestamp timestamp,
|
||||
std::shared_ptr<QueryConfig> query_config =
|
||||
std::make_shared<QueryConfig>(),
|
||||
folly::Executor* executor = nullptr,
|
||||
std::unordered_map<std::string, std::shared_ptr<Config>>
|
||||
connector_configs = {})
|
||||
: Context(ContextScope::QUERY),
|
||||
query_id_(query_id),
|
||||
segment_(segment),
|
||||
query_timestamp_(timestamp),
|
||||
query_config_(query_config),
|
||||
executor_(executor) {
|
||||
}
|
||||
|
||||
folly::Executor*
|
||||
executor() const {
|
||||
return executor_;
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, std::shared_ptr<Config>>&
|
||||
connector_configs() const {
|
||||
return connector_configs_;
|
||||
}
|
||||
|
||||
std::shared_ptr<QueryConfig>
|
||||
query_config() const {
|
||||
return query_config_;
|
||||
}
|
||||
|
||||
std::string
|
||||
query_id() const {
|
||||
return query_id_;
|
||||
}
|
||||
|
||||
const milvus::segcore::SegmentInternalInterface*
|
||||
get_segment() {
|
||||
return segment_;
|
||||
}
|
||||
|
||||
milvus::Timestamp
|
||||
get_query_timestamp() {
|
||||
return query_timestamp_;
|
||||
}
|
||||
|
||||
private:
|
||||
folly::Executor* executor_;
|
||||
//folly::Executor::KeepAlive<> executor_keepalive_;
|
||||
std::unordered_map<std::string, std::shared_ptr<Config>> connector_configs_;
|
||||
std::shared_ptr<QueryConfig> query_config_;
|
||||
std::string query_id_;
|
||||
|
||||
// current segment that query execute in
|
||||
const milvus::segcore::SegmentInternalInterface* segment_;
|
||||
// timestamp this query generate
|
||||
milvus::Timestamp query_timestamp_;
|
||||
};
|
||||
|
||||
// Represent the state of one thread of query execution.
|
||||
// TODO: add more class member such as memory pool
|
||||
// Represent the state of one thread of query execution.
// TODO: add more class member such as memory pool
class ExecContext : public Context {
 public:
    // Binds this executor-scoped context to its owning query context
    // (not owned; must outlive the ExecContext).
    ExecContext(QueryContext* query_context)
        : Context(ContextScope::Executor), query_context_(query_context) {
    }

    QueryContext*
    get_query_context() const {
        return query_context_;
    }

    std::shared_ptr<QueryConfig>
    get_query_config() const {
        return query_context_->query_config();
    }

 private:
    QueryContext* query_context_;
};
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,230 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "Task.h"
|
||||
|
||||
#include <boost/lexical_cast.hpp>
|
||||
#include <boost/uuid/uuid_generators.hpp>
|
||||
#include <boost/uuid/uuid_io.hpp>
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Special group id to reflect the ungrouped execution.
constexpr uint32_t kUngroupedGroupId{std::numeric_limits<uint32_t>::max()};

// Produces a random UUID string, e.g. for task identifiers.
std::string
MakeUuid() {
    return boost::lexical_cast<std::string>(boost::uuids::random_generator()());
}
|
||||
|
||||
// Convenience overload: wraps a single Consumer into a ConsumerSupplier
// and forwards to the main factory. The explicit ConsumerSupplier{} in
// the ternary keeps both branches at the same type.
std::shared_ptr<Task>
Task::Create(const std::string& task_id,
             plan::PlanFragment plan_fragment,
             int destination,
             std::shared_ptr<QueryContext> query_context,
             Consumer consumer,
             std::function<void(std::exception_ptr)> on_error) {
    return Task::Create(task_id,
                        std::move(plan_fragment),
                        destination,
                        std::move(query_context),
                        (consumer ? [c = std::move(consumer)]() { return c; }
                                  : ConsumerSupplier{}),
                        std::move(on_error));
}
|
||||
|
||||
// Main factory. Task's constructor is private, so creation always goes
// through a shared_ptr (required for enable_shared_from_this usage).
// NOTE(review): plan_fragment is taken by const&, so the std::move in
// the constructor call degenerates to a copy — confirm whether
// pass-by-value was intended here.
std::shared_ptr<Task>
Task::Create(const std::string& task_id,
             const plan::PlanFragment& plan_fragment,
             int destination,
             std::shared_ptr<QueryContext> query_ctx,
             ConsumerSupplier supplier,
             std::function<void(std::exception_ptr)> on_error) {
    return std::shared_ptr<Task>(new Task(task_id,
                                          std::move(plan_fragment),
                                          destination,
                                          std::move(query_ctx),
                                          std::move(supplier),
                                          std::move(on_error)));
}
|
||||
|
||||
// Records the first error seen by the task, terminates it as failed,
// and invokes the registered error callback. Subsequent errors and
// errors arriving after the task stopped running are dropped.
void
Task::SetError(const std::exception_ptr& exception) {
    {
        std::lock_guard<std::mutex> l(mutex_);
        if (!IsRunningLocked()) {
            return;
        }

        // Keep only the first recorded error.
        if (exception_ != nullptr) {
            return;
        }
        exception_ = exception;
    }

    // Terminate and notify outside the lock to avoid re-entrancy
    // deadlocks with callbacks.
    Terminate(TaskState::kFailed);

    if (on_error_) {
        on_error_(exception_);
    }
}
|
||||
|
||||
void
|
||||
Task::SetError(const std::string& message) {
|
||||
try {
|
||||
throw std::runtime_error(message);
|
||||
} catch (const std::runtime_error& e) {
|
||||
SetError(std::current_exception());
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiates drivers for every pipeline whose execution mode matches
// `split_group_id` (kUngroupedGroupId selects ungrouped pipelines).
// The "Locked" suffix signals the caller must hold the task mutex.
void
Task::CreateDriversLocked(std::shared_ptr<Task>& self,
                          uint32_t split_group_id,
                          std::vector<std::shared_ptr<Driver>>& out) {
    const bool is_group_execution_drivers =
        (split_group_id != kUngroupedGroupId);
    const auto num_pipelines = driver_factories_.size();

    for (auto pipeline = 0; pipeline < num_pipelines; ++pipeline) {
        auto& factory = driver_factories_[pipeline];

        // Skip pipelines of the other execution mode.
        if (factory->is_group_execution_ != is_group_execution_drivers) {
            continue;
        }

        // Driver ids are globally unique per pipeline: each split group
        // gets its own contiguous id range.
        const uint32_t driverid_offset =
            factory->num_drivers_ *
            (is_group_execution_drivers ? split_group_id : 0);

        for (uint32_t partition_id = 0; partition_id < factory->num_drivers_;
             ++partition_id) {
            out.emplace_back(factory->CreateDriver(
                std::make_unique<DriverContext>(self,
                                                driverid_offset + partition_id,
                                                pipeline,
                                                split_group_id,
                                                partition_id),
                // Resolves a pipeline index to its total driver count.
                [self](size_t i) {
                    return i < self->driver_factories_.size()
                               ? self->driver_factories_[i]->num_total_drivers_
                               : 0;
                }));
        }
    }
}
|
||||
|
||||
// Single-threaded execution entry point: lazily plans the fragment and
// creates drivers on first call, then round-robins the drivers until
// one yields a batch, all finish (returns nullptr), or all remaining
// drivers are blocked (then `future` receives a combined wait handle).
RowVectorPtr
Task::Next(ContinueFuture* future) {
    // NOTE: Task::Next is single-threaded execution
    AssertInfo(plan_fragment_.execution_strategy_ ==
                   plan::ExecutionStrategy::kUngrouped,
               "Single-threaded execution supports only ungrouped execution");

    AssertInfo(state_ == TaskState::kRunning,
               "Task has already finished processing.");

    // First call: plan the fragment and create all ungrouped drivers.
    if (driver_factories_.empty()) {
        AssertInfo(
            consumer_supplier_ == nullptr,
            "Single-threaded execution doesn't support delivering results to a "
            "callback");

        LocalPlanner::Plan(plan_fragment_,
                           nullptr,
                           &driver_factories_,
                           *query_context_->query_config(),
                           1);

        for (const auto& factory : driver_factories_) {
            assert(factory->SupportSingleThreadExecution());
            num_ungrouped_drivers_ += factory->num_drivers_;
            num_total_drivers_ += factory->num_total_drivers_;
        }

        auto self = shared_from_this();
        std::vector<std::shared_ptr<Driver>> drivers;

        drivers.reserve(num_ungrouped_drivers_);
        CreateDriversLocked(self, kUngroupedGroupId, drivers);

        drivers_ = std::move(drivers);
    }

    const auto num_drivers = drivers_.size();

    // One future slot per driver; a default-constructed ContinueFuture
    // must read as ready so every driver gets a first turn — TODO
    // confirm this holds for the folly future alias in use.
    std::vector<ContinueFuture> futures;
    futures.resize(num_drivers);

    for (;;) {
        int runnable_drivers = 0;
        int blocked_drivers = 0;

        for (auto i = 0; i < num_drivers; ++i) {
            // A null slot means that driver is gone — presumably
            // cleared when it finishes; confirm against RemoveDriver.
            if (drivers_[i] == nullptr) {
                continue;
            }

            // Still waiting on a previously recorded future.
            if (!futures[i].isReady()) {
                ++blocked_drivers;
                continue;
            }

            ++runnable_drivers;

            std::shared_ptr<BlockingState> blocking_state;

            auto result = drivers_[i]->Next(blocking_state);

            // Deliver the first produced batch immediately.
            if (result) {
                return result;
            }

            // Driver blocked: remember its future for later rounds.
            if (blocking_state) {
                futures[i] = blocking_state->future();
            }

            if (error()) {
                std::rethrow_exception(error());
            }
        }

        if (runnable_drivers == 0) {
            if (blocked_drivers > 0) {
                if (!future) {
                    throw ExecDriverException(
                        "Cannot make progress as all remaining drivers are "
                        "blocked and user are not expected to wait.");
                } else {
                    // Collapse all pending futures into the single
                    // future handed back to the caller.
                    std::vector<ContinueFuture> not_ready_futures;
                    for (auto& continue_future : futures) {
                        if (!continue_future.isReady()) {
                            not_ready_futures.emplace_back(
                                std::move(continue_future));
                        }
                    }
                    *future =
                        folly::collectAll(std::move(not_ready_futures)).unit();
                }
            }
            return nullptr;
        }
    }
}
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,205 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "common/Types.h"
|
||||
#include "exec/Driver.h"
|
||||
#include "exec/QueryContext.h"
|
||||
#include "plan/PlanNode.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Lifecycle states of a Task. A task starts in kRunning and transitions to
// exactly one of the terminal states.
enum class TaskState { kRunning, kFinished, kCanceled, kAborted, kFailed };

// Generate a unique identifier string for a Task instance.
std::string
MakeUuid();
|
||||
class Task : public std::enable_shared_from_this<Task> {
|
||||
public:
|
||||
static std::shared_ptr<Task>
|
||||
Create(const std::string& task_id,
|
||||
plan::PlanFragment plan_fragment,
|
||||
int destination,
|
||||
std::shared_ptr<QueryContext> query_context,
|
||||
Consumer consumer = nullptr,
|
||||
std::function<void(std::exception_ptr)> on_error = nullptr);
|
||||
|
||||
static std::shared_ptr<Task>
|
||||
Create(const std::string& task_id,
|
||||
const plan::PlanFragment& plan_fragment,
|
||||
int destination,
|
||||
std::shared_ptr<QueryContext> query_ctx,
|
||||
ConsumerSupplier supplier,
|
||||
std::function<void(std::exception_ptr)> on_error = nullptr);
|
||||
|
||||
Task(const std::string& task_id,
|
||||
plan::PlanFragment plan_fragment,
|
||||
int destination,
|
||||
std::shared_ptr<QueryContext> query_ctx,
|
||||
ConsumerSupplier consumer_supplier,
|
||||
std::function<void(std::exception_ptr)> on_error)
|
||||
: uuid_{MakeUuid()},
|
||||
taskid_(task_id),
|
||||
plan_fragment_(std::move(plan_fragment)),
|
||||
destination_(destination),
|
||||
query_context_(std::move(query_ctx)),
|
||||
consumer_supplier_(std::move(consumer_supplier)),
|
||||
on_error_(on_error) {
|
||||
}
|
||||
|
||||
~Task() {
|
||||
}
|
||||
|
||||
const std::string&
|
||||
uuid() const {
|
||||
return uuid_;
|
||||
}
|
||||
|
||||
const std::string&
|
||||
taskid() const {
|
||||
return taskid_;
|
||||
}
|
||||
|
||||
const int
|
||||
destination() const {
|
||||
return destination_;
|
||||
}
|
||||
|
||||
const std::shared_ptr<QueryContext>&
|
||||
query_context() const {
|
||||
return query_context_;
|
||||
}
|
||||
|
||||
static void
|
||||
Start(std::shared_ptr<Task> self,
|
||||
uint32_t max_drivers,
|
||||
uint32_t concurrent_split_groups = 1);
|
||||
|
||||
static void
|
||||
RemoveDriver(std::shared_ptr<Task> self, Driver* instance) {
|
||||
std::lock_guard<std::mutex> lock(self->mutex_);
|
||||
for (auto& driver_ptr : self->drivers_) {
|
||||
if (driver_ptr.get() != instance) {
|
||||
continue;
|
||||
}
|
||||
driver_ptr = nullptr;
|
||||
self->DriverClosedLocked();
|
||||
}
|
||||
}
|
||||
|
||||
bool
|
||||
SupportsSingleThreadedExecution() const {
|
||||
if (consumer_supplier_) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
RowVectorPtr
|
||||
Next(ContinueFuture* future = nullptr);
|
||||
|
||||
void
|
||||
CreateDriversLocked(std::shared_ptr<Task>& self,
|
||||
uint32_t split_groupid,
|
||||
std::vector<std::shared_ptr<Driver>>& out);
|
||||
|
||||
void
|
||||
SetError(const std::exception_ptr& exception);
|
||||
|
||||
void
|
||||
SetError(const std::string& message);
|
||||
|
||||
bool
|
||||
IsRunning() const {
|
||||
std::lock_guard<std::mutex> l(mutex_);
|
||||
return (state_ == TaskState::kRunning);
|
||||
}
|
||||
|
||||
bool
|
||||
IsFinished() const {
|
||||
std::lock_guard<std::mutex> l(mutex_);
|
||||
return (state_ == TaskState::kFinished);
|
||||
}
|
||||
|
||||
bool
|
||||
IsRunningLocked() const {
|
||||
return (state_ == TaskState::kRunning);
|
||||
}
|
||||
|
||||
bool
|
||||
IsFinishedLocked() const {
|
||||
return (state_ == TaskState::kFinished);
|
||||
}
|
||||
|
||||
void
|
||||
Terminate(TaskState state) {
|
||||
}
|
||||
|
||||
std::exception_ptr
|
||||
error() const {
|
||||
std::lock_guard<std::mutex> l(mutex_);
|
||||
return exception_;
|
||||
}
|
||||
|
||||
void
|
||||
DriverClosedLocked() {
|
||||
if (IsRunningLocked()) {
|
||||
--num_running_drivers_;
|
||||
}
|
||||
|
||||
num_finished_drivers_++;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string uuid_;
|
||||
|
||||
std::string taskid_;
|
||||
|
||||
plan::PlanFragment plan_fragment_;
|
||||
|
||||
int destination_;
|
||||
|
||||
std::shared_ptr<QueryContext> query_context_;
|
||||
|
||||
std::exception_ptr exception_ = nullptr;
|
||||
|
||||
std::function<void(std::exception_ptr)> on_error_;
|
||||
|
||||
std::vector<std::unique_ptr<DriverFactory>> driver_factories_;
|
||||
|
||||
std::vector<std::shared_ptr<Driver>> drivers_;
|
||||
|
||||
ConsumerSupplier consumer_supplier_;
|
||||
|
||||
mutable std::mutex mutex_;
|
||||
|
||||
TaskState state_ = TaskState::kRunning;
|
||||
|
||||
uint32_t num_running_drivers_{0};
|
||||
|
||||
uint32_t num_total_drivers_{0};
|
||||
|
||||
uint32_t num_ungrouped_drivers_{0};
|
||||
|
||||
uint32_t num_finished_drivers_{0};
|
||||
};
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,45 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "AlwaysTrueExpr.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
void
|
||||
PhyAlwaysTrueExpr::Eval(EvalCtx& context, VectorPtr& result) {
|
||||
int64_t real_batch_size = current_pos_ + batch_size_ >= num_rows_
|
||||
? num_rows_ - current_pos_
|
||||
: batch_size_;
|
||||
|
||||
if (real_batch_size == 0) {
|
||||
result = nullptr;
|
||||
return;
|
||||
}
|
||||
|
||||
auto res_vec =
|
||||
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
|
||||
bool* res_bool = (bool*)res_vec->GetRawData();
|
||||
for (size_t i = 0; i < real_batch_size; ++i) {
|
||||
res_bool[i] = true;
|
||||
}
|
||||
|
||||
result = res_vec;
|
||||
current_pos_ += real_batch_size;
|
||||
}
|
||||
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,56 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fmt/core.h>
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Vector.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
class PhyAlwaysTrueExpr : public Expr {
|
||||
public:
|
||||
PhyAlwaysTrueExpr(
|
||||
const std::vector<std::shared_ptr<Expr>>& input,
|
||||
const std::shared_ptr<const milvus::expr::AlwaysTrueExpr>& expr,
|
||||
const std::string& name,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
Timestamp query_timestamp,
|
||||
int64_t batch_size)
|
||||
: Expr(DataType::BOOL, std::move(input), name),
|
||||
expr_(expr),
|
||||
batch_size_(batch_size) {
|
||||
num_rows_ = segment->get_active_count(query_timestamp);
|
||||
}
|
||||
|
||||
void
|
||||
Eval(EvalCtx& context, VectorPtr& result) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<const milvus::expr::AlwaysTrueExpr> expr_;
|
||||
int64_t num_rows_;
|
||||
int64_t current_pos_{0};
|
||||
int64_t batch_size_;
|
||||
};
|
||||
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,748 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "BinaryArithOpEvalRangeExpr.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Dispatch entry point for `column <arith op> right_operand <cmp> value`
// expressions. Selects the evaluation path from the column's data type:
// scalar columns dispatch on the element type directly; JSON and ARRAY
// columns additionally dispatch on the type of the literal value carried by
// the expression.
void
PhyBinaryArithOpEvalRangeExpr::Eval(EvalCtx& context, VectorPtr& result) {
    switch (expr_->column_.data_type_) {
        case DataType::BOOL: {
            result = ExecRangeVisitorImpl<bool>();
            break;
        }
        case DataType::INT8: {
            result = ExecRangeVisitorImpl<int8_t>();
            break;
        }
        case DataType::INT16: {
            result = ExecRangeVisitorImpl<int16_t>();
            break;
        }
        case DataType::INT32: {
            result = ExecRangeVisitorImpl<int32_t>();
            break;
        }
        case DataType::INT64: {
            result = ExecRangeVisitorImpl<int64_t>();
            break;
        }
        case DataType::FLOAT: {
            result = ExecRangeVisitorImpl<float>();
            break;
        }
        case DataType::DOUBLE: {
            result = ExecRangeVisitorImpl<double>();
            break;
        }
        case DataType::JSON: {
            // JSON values are dynamically typed: pick the extraction type
            // from the literal in the expression.
            auto value_type = expr_->value_.val_case();
            switch (value_type) {
                case proto::plan::GenericValue::ValCase::kBoolVal: {
                    result = ExecRangeVisitorImplForJson<bool>();
                    break;
                }
                case proto::plan::GenericValue::ValCase::kInt64Val: {
                    result = ExecRangeVisitorImplForJson<int64_t>();
                    break;
                }
                case proto::plan::GenericValue::ValCase::kFloatVal: {
                    // Float literals are evaluated in double precision.
                    result = ExecRangeVisitorImplForJson<double>();
                    break;
                }
                default: {
                    PanicInfo(
                        DataTypeInvalid,
                        fmt::format("unsupported value type {} in expression",
                                    value_type));
                }
            }
            break;
        }
        case DataType::ARRAY: {
            auto value_type = expr_->value_.val_case();
            switch (value_type) {
                case proto::plan::GenericValue::ValCase::kInt64Val: {
                    result = ExecRangeVisitorImplForArray<int64_t>();
                    break;
                }
                case proto::plan::GenericValue::ValCase::kFloatVal: {
                    result = ExecRangeVisitorImplForArray<double>();
                    break;
                }
                default: {
                    PanicInfo(
                        DataTypeInvalid,
                        fmt::format("unsupported value type {} in expression",
                                    value_type));
                }
            }
            break;
        }
        default:
            PanicInfo(DataTypeInvalid,
                      fmt::format("unsupported data type: {}",
                                  expr_->column_.data_type_));
    }
}
|
||||
|
||||
// Evaluate `json[path] <arith op> right_operand <op_type> value` over the
// next batch of JSON rows, producing a BOOL ColumnVector with one result per
// row. Rows where the path is missing or has the wrong type evaluate to
// false for Equal and true for NotEqual (see the two macros below).
template <typename ValueType>
VectorPtr
PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() {
    // Strings are extracted as string_view to avoid copying out of the JSON.
    using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
                                       std::string_view,
                                       ValueType>;
    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;
    }
    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();

    auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
    auto op_type = expr_->op_type_;
    auto arith_type = expr_->arith_op_type_;
    auto value = GetValueFromProto<ValueType>(expr_->value_);
    auto right_operand = GetValueFromProto<ValueType>(expr_->right_operand_);

// Equal-style comparison: extraction failure yields false, except that an
// int64 extraction failure is retried as double (JSON numbers may be stored
// as floating point) before giving up.
#define BinaryArithRangeJSONCompare(cmp)                                  \
    do {                                                                  \
        for (size_t i = 0; i < size; ++i) {                              \
            auto x = data[i].template at<GetType>(pointer);              \
            if (x.error()) {                                             \
                if constexpr (std::is_same_v<GetType, int64_t>) {        \
                    auto x = data[i].template at<double>(pointer);       \
                    res[i] = !x.error() && (cmp);                        \
                    continue;                                            \
                }                                                        \
                res[i] = false;                                          \
                continue;                                                \
            }                                                            \
            res[i] = (cmp);                                              \
        }                                                                \
    } while (false)

// NotEqual-style comparison: extraction failure yields true (a missing or
// mistyped value is "not equal"), with the same int64 -> double retry.
#define BinaryArithRangeJSONCompareNotEqual(cmp)                          \
    do {                                                                  \
        for (size_t i = 0; i < size; ++i) {                              \
            auto x = data[i].template at<GetType>(pointer);              \
            if (x.error()) {                                             \
                if constexpr (std::is_same_v<GetType, int64_t>) {        \
                    auto x = data[i].template at<double>(pointer);       \
                    res[i] = x.error() || (cmp);                         \
                    continue;                                            \
                }                                                        \
                res[i] = true;                                           \
                continue;                                                \
            }                                                            \
            res[i] = (cmp);                                              \
        }                                                                \
    } while (false)

    // Per-chunk kernel: dispatches on (op_type, arith_type) and fills `res`
    // for `size` rows starting at `data`.
    auto execute_sub_batch = [op_type, arith_type](const milvus::Json* data,
                                                   const int size,
                                                   bool* res,
                                                   ValueType val,
                                                   ValueType right_operand,
                                                   const std::string& pointer) {
        switch (op_type) {
            case proto::plan::OpType::Equal: {
                switch (arith_type) {
                    case proto::plan::ArithOpType::Add: {
                        BinaryArithRangeJSONCompare(x.value() + right_operand ==
                                                    val);
                        break;
                    }
                    case proto::plan::ArithOpType::Sub: {
                        BinaryArithRangeJSONCompare(x.value() - right_operand ==
                                                    val);
                        break;
                    }
                    case proto::plan::ArithOpType::Mul: {
                        BinaryArithRangeJSONCompare(x.value() * right_operand ==
                                                    val);
                        break;
                    }
                    case proto::plan::ArithOpType::Div: {
                        BinaryArithRangeJSONCompare(x.value() / right_operand ==
                                                    val);
                        break;
                    }
                    case proto::plan::ArithOpType::Mod: {
                        BinaryArithRangeJSONCompare(
                            static_cast<ValueType>(
                                fmod(x.value(), right_operand)) == val);
                        break;
                    }
                    case proto::plan::ArithOpType::ArrayLength: {
                        // Compare the element count of the array at `pointer`;
                        // a missing/non-array value counts as length 0.
                        for (size_t i = 0; i < size; ++i) {
                            int array_length = 0;
                            auto doc = data[i].doc();
                            auto array = doc.at_pointer(pointer).get_array();
                            if (!array.error()) {
                                array_length = array.count_elements();
                            }
                            res[i] = array_length == val;
                        }
                        break;
                    }
                    default:
                        PanicInfo(
                            OpTypeInvalid,
                            fmt::format("unsupported arith type for binary "
                                        "arithmetic eval expr: {}",
                                        arith_type));
                }
                break;
            }
            case proto::plan::OpType::NotEqual: {
                switch (arith_type) {
                    case proto::plan::ArithOpType::Add: {
                        BinaryArithRangeJSONCompareNotEqual(
                            x.value() + right_operand != val);
                        break;
                    }
                    case proto::plan::ArithOpType::Sub: {
                        BinaryArithRangeJSONCompareNotEqual(
                            x.value() - right_operand != val);
                        break;
                    }
                    case proto::plan::ArithOpType::Mul: {
                        BinaryArithRangeJSONCompareNotEqual(
                            x.value() * right_operand != val);
                        break;
                    }
                    case proto::plan::ArithOpType::Div: {
                        BinaryArithRangeJSONCompareNotEqual(
                            x.value() / right_operand != val);
                        break;
                    }
                    case proto::plan::ArithOpType::Mod: {
                        BinaryArithRangeJSONCompareNotEqual(
                            static_cast<ValueType>(
                                fmod(x.value(), right_operand)) != val);
                        break;
                    }
                    case proto::plan::ArithOpType::ArrayLength: {
                        for (size_t i = 0; i < size; ++i) {
                            int array_length = 0;
                            auto doc = data[i].doc();
                            auto array = doc.at_pointer(pointer).get_array();
                            if (!array.error()) {
                                array_length = array.count_elements();
                            }
                            res[i] = array_length != val;
                        }
                        break;
                    }
                    default:
                        PanicInfo(
                            OpTypeInvalid,
                            fmt::format("unsupported arith type for binary "
                                        "arithmetic eval expr: {}",
                                        arith_type));
                }
                break;
            }
            default:
                PanicInfo(OpTypeInvalid,
                          fmt::format("unsupported operator type for binary "
                                      "arithmetic eval expr: {}",
                                      op_type));
        }
    };
    int64_t processed_size = ProcessDataChunks<milvus::Json>(execute_sub_batch,
                                                             std::nullptr_t{},
                                                             res,
                                                             value,
                                                             right_operand,
                                                             pointer);
    AssertInfo(processed_size == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           processed_size,
                           real_batch_size));
    return res_vec;
}
|
||||
|
||||
// Evaluate `array[index] <arith op> right_operand <op_type> value` (or an
// ArrayLength comparison) over the next batch of ARRAY rows, producing a
// BOOL ColumnVector with one result per row.
template <typename ValueType>
VectorPtr
PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() {
    // Strings are read out of the array as string_view to avoid copies.
    using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
                                       std::string_view,
                                       ValueType>;
    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;
    }
    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();

    // nested_path_[0], when present, is the array element index to compare.
    // NOTE(review): when the path is empty `index` stays -1 and the compare
    // macro would call get_data(-1); presumably only ArrayLength (which
    // ignores `index`) reaches that case — confirm against the planner.
    int index = -1;
    if (expr_->column_.nested_path_.size() > 0) {
        index = std::stoi(expr_->column_.nested_path_[0]);
    }
    auto op_type = expr_->op_type_;
    auto arith_type = expr_->arith_op_type_;
    auto value = GetValueFromProto<ValueType>(expr_->value_);
    // ArrayLength has no arithmetic operand, so skip parsing it.
    auto right_operand =
        arith_type != proto::plan::ArithOpType::ArrayLength
            ? GetValueFromProto<ValueType>(expr_->right_operand_)
            : ValueType();

// Apply `cmp` to each row's element at `index`; rows whose array is too
// short evaluate to false.
#define BinaryArithRangeArrayCompare(cmp)                        \
    do {                                                         \
        for (size_t i = 0; i < size; ++i) {                      \
            if (index >= data[i].length()) {                     \
                res[i] = false;                                  \
                continue;                                        \
            }                                                    \
            auto value = data[i].get_data<GetType>(index);       \
            res[i] = (cmp);                                      \
        }                                                        \
    } while (false)

    // Per-chunk kernel: dispatches on (op_type, arith_type) and fills `res`
    // for `size` rows starting at `data`.
    auto execute_sub_batch = [op_type, arith_type](const ArrayView* data,
                                                   const int size,
                                                   bool* res,
                                                   ValueType val,
                                                   ValueType right_operand,
                                                   int index) {
        switch (op_type) {
            case proto::plan::OpType::Equal: {
                switch (arith_type) {
                    case proto::plan::ArithOpType::Add: {
                        BinaryArithRangeArrayCompare(value + right_operand ==
                                                     val);
                        break;
                    }
                    case proto::plan::ArithOpType::Sub: {
                        BinaryArithRangeArrayCompare(value - right_operand ==
                                                     val);

                        break;
                    }
                    case proto::plan::ArithOpType::Mul: {
                        BinaryArithRangeArrayCompare(value * right_operand ==
                                                     val);
                        break;
                    }
                    case proto::plan::ArithOpType::Div: {
                        BinaryArithRangeArrayCompare(value / right_operand ==
                                                     val);
                        break;
                    }
                    case proto::plan::ArithOpType::Mod: {
                        BinaryArithRangeArrayCompare(
                            static_cast<ValueType>(
                                fmod(value, right_operand)) == val);
                        break;
                    }
                    case proto::plan::ArithOpType::ArrayLength: {
                        // Compare the row's array length, ignoring `index`.
                        for (size_t i = 0; i < size; ++i) {
                            res[i] = data[i].length() == val;
                        }
                        break;
                    }
                    default:
                        PanicInfo(
                            OpTypeInvalid,
                            fmt::format("unsupported arith type for binary "
                                        "arithmetic eval expr: {}",
                                        arith_type));
                }
                break;
            }
            case proto::plan::OpType::NotEqual: {
                switch (arith_type) {
                    case proto::plan::ArithOpType::Add: {
                        BinaryArithRangeArrayCompare(value + right_operand !=
                                                     val);
                        break;
                    }
                    case proto::plan::ArithOpType::Sub: {
                        BinaryArithRangeArrayCompare(value - right_operand !=
                                                     val);
                        break;
                    }
                    case proto::plan::ArithOpType::Mul: {
                        BinaryArithRangeArrayCompare(value * right_operand !=
                                                     val);
                        break;
                    }
                    case proto::plan::ArithOpType::Div: {
                        BinaryArithRangeArrayCompare(value / right_operand !=
                                                     val);
                        break;
                    }
                    case proto::plan::ArithOpType::Mod: {
                        BinaryArithRangeArrayCompare(
                            static_cast<ValueType>(
                                fmod(value, right_operand)) != val);
                        break;
                    }
                    case proto::plan::ArithOpType::ArrayLength: {
                        for (size_t i = 0; i < size; ++i) {
                            res[i] = data[i].length() != val;
                        }
                        break;
                    }
                    default:
                        PanicInfo(
                            OpTypeInvalid,
                            fmt::format("unsupported arith type for binary "
                                        "arithmetic eval expr: {}",
                                        arith_type));
                }
                break;
            }
            default:
                PanicInfo(OpTypeInvalid,
                          fmt::format("unsupported operator type for binary "
                                      "arithmetic eval expr: {}",
                                      op_type));
        }
    };

    int64_t processed_size = ProcessDataChunks<milvus::ArrayView>(
        execute_sub_batch, std::nullptr_t{}, res, value, right_operand, index);
    AssertInfo(processed_size == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           processed_size,
                           real_batch_size));
    return res_vec;
}
|
||||
|
||||
template <typename T>
|
||||
VectorPtr
|
||||
PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImpl() {
|
||||
if (is_index_mode_) {
|
||||
return ExecRangeVisitorImplForIndex<T>();
|
||||
} else {
|
||||
return ExecRangeVisitorImplForData<T>();
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluate the expression against a scalar index instead of raw column data.
// Dispatches at compile time on (op_type, arith_type) by instantiating the
// matching ArithOpIndexFunc, which runs the comparison through the index.
template <typename T>
VectorPtr
PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForIndex() {
    using Index = index::ScalarIndex<T>;
    // Integral types (except bool) are widened to int64_t before the
    // arithmetic is applied.
    typedef std::conditional_t<std::is_integral_v<T> &&
                                   !std::is_same_v<bool, T>,
                               int64_t,
                               T>
        HighPrecisionType;
    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;
    }
    auto value = GetValueFromProto<HighPrecisionType>(expr_->value_);
    auto right_operand =
        GetValueFromProto<HighPrecisionType>(expr_->right_operand_);
    auto op_type = expr_->op_type_;
    auto arith_type = expr_->arith_op_type_;
    auto sub_batch_size = size_per_chunk_;

    // Per-chunk kernel: returns the boolean results for one index chunk.
    auto execute_sub_batch = [op_type, arith_type, sub_batch_size](
                                 Index* index_ptr,
                                 HighPrecisionType value,
                                 HighPrecisionType right_operand) {
        FixedVector<bool> res;
        switch (op_type) {
            case proto::plan::OpType::Equal: {
                switch (arith_type) {
                    case proto::plan::ArithOpType::Add: {
                        ArithOpIndexFunc<T,
                                         proto::plan::OpType::Equal,
                                         proto::plan::ArithOpType::Add>
                            func;
                        res = std::move(func(
                            index_ptr, sub_batch_size, value, right_operand));
                        break;
                    }
                    case proto::plan::ArithOpType::Sub: {
                        ArithOpIndexFunc<T,
                                         proto::plan::OpType::Equal,
                                         proto::plan::ArithOpType::Sub>
                            func;
                        res = std::move(func(
                            index_ptr, sub_batch_size, value, right_operand));
                        break;
                    }
                    case proto::plan::ArithOpType::Mul: {
                        ArithOpIndexFunc<T,
                                         proto::plan::OpType::Equal,
                                         proto::plan::ArithOpType::Mul>
                            func;
                        res = std::move(func(
                            index_ptr, sub_batch_size, value, right_operand));
                        break;
                    }
                    case proto::plan::ArithOpType::Div: {
                        ArithOpIndexFunc<T,
                                         proto::plan::OpType::Equal,
                                         proto::plan::ArithOpType::Div>
                            func;
                        res = std::move(func(
                            index_ptr, sub_batch_size, value, right_operand));
                        break;
                    }
                    case proto::plan::ArithOpType::Mod: {
                        ArithOpIndexFunc<T,
                                         proto::plan::OpType::Equal,
                                         proto::plan::ArithOpType::Mod>
                            func;
                        res = std::move(func(
                            index_ptr, sub_batch_size, value, right_operand));
                        break;
                    }
                    default:
                        PanicInfo(
                            OpTypeInvalid,
                            fmt::format("unsupported arith type for binary "
                                        "arithmetic eval expr: {}",
                                        arith_type));
                }
                break;
            }
            case proto::plan::OpType::NotEqual: {
                switch (arith_type) {
                    case proto::plan::ArithOpType::Add: {
                        ArithOpIndexFunc<T,
                                         proto::plan::OpType::NotEqual,
                                         proto::plan::ArithOpType::Add>
                            func;
                        res = std::move(func(
                            index_ptr, sub_batch_size, value, right_operand));
                        break;
                    }
                    case proto::plan::ArithOpType::Sub: {
                        ArithOpIndexFunc<T,
                                         proto::plan::OpType::NotEqual,
                                         proto::plan::ArithOpType::Sub>
                            func;
                        res = std::move(func(
                            index_ptr, sub_batch_size, value, right_operand));
                        break;
                    }
                    case proto::plan::ArithOpType::Mul: {
                        ArithOpIndexFunc<T,
                                         proto::plan::OpType::NotEqual,
                                         proto::plan::ArithOpType::Mul>
                            func;
                        res = std::move(func(
                            index_ptr, sub_batch_size, value, right_operand));
                        break;
                    }
                    case proto::plan::ArithOpType::Div: {
                        ArithOpIndexFunc<T,
                                         proto::plan::OpType::NotEqual,
                                         proto::plan::ArithOpType::Div>
                            func;
                        res = std::move(func(
                            index_ptr, sub_batch_size, value, right_operand));
                        break;
                    }
                    case proto::plan::ArithOpType::Mod: {
                        ArithOpIndexFunc<T,
                                         proto::plan::OpType::NotEqual,
                                         proto::plan::ArithOpType::Mod>
                            func;
                        res = std::move(func(
                            index_ptr, sub_batch_size, value, right_operand));
                        break;
                    }
                    default:
                        PanicInfo(
                            OpTypeInvalid,
                            fmt::format("unsupported arith type for binary "
                                        "arithmetic eval expr: {}",
                                        arith_type));
                }
                break;
            }
            default:
                PanicInfo(OpTypeInvalid,
                          fmt::format("unsupported operator type for binary "
                                      "arithmetic eval expr: {}",
                                      op_type));
        }
        return res;
    };
    auto res = ProcessIndexChunks<T>(execute_sub_batch, value, right_operand);
    AssertInfo(res.size() == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           res.size(),
                           real_batch_size));
    return std::make_shared<ColumnVector>(std::move(res));
}
|
||||
|
||||
template <typename T>
|
||||
VectorPtr
|
||||
PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() {
|
||||
typedef std::conditional_t<std::is_integral_v<T> &&
|
||||
!std::is_same_v<bool, T>,
|
||||
int64_t,
|
||||
T>
|
||||
HighPrecisionType;
|
||||
auto real_batch_size = GetNextBatchSize();
|
||||
if (real_batch_size == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto value = GetValueFromProto<HighPrecisionType>(expr_->value_);
|
||||
auto right_operand =
|
||||
GetValueFromProto<HighPrecisionType>(expr_->right_operand_);
|
||||
auto res_vec =
|
||||
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
|
||||
bool* res = (bool*)res_vec->GetRawData();
|
||||
|
||||
auto op_type = expr_->op_type_;
|
||||
auto arith_type = expr_->arith_op_type_;
|
||||
auto execute_sub_batch = [op_type, arith_type](
|
||||
const T* data,
|
||||
const int size,
|
||||
bool* res,
|
||||
HighPrecisionType value,
|
||||
HighPrecisionType right_operand) {
|
||||
switch (op_type) {
|
||||
case proto::plan::OpType::Equal: {
|
||||
switch (arith_type) {
|
||||
case proto::plan::ArithOpType::Add: {
|
||||
ArithOpElementFunc<T,
|
||||
proto::plan::OpType::Equal,
|
||||
proto::plan::ArithOpType::Add>
|
||||
func;
|
||||
func(data, size, value, right_operand, res);
|
||||
break;
|
||||
}
|
||||
case proto::plan::ArithOpType::Sub: {
|
||||
ArithOpElementFunc<T,
|
||||
proto::plan::OpType::Equal,
|
||||
proto::plan::ArithOpType::Sub>
|
||||
func;
|
||||
func(data, size, value, right_operand, res);
|
||||
break;
|
||||
}
|
||||
case proto::plan::ArithOpType::Mul: {
|
||||
ArithOpElementFunc<T,
|
||||
proto::plan::OpType::Equal,
|
||||
proto::plan::ArithOpType::Mul>
|
||||
func;
|
||||
func(data, size, value, right_operand, res);
|
||||
break;
|
||||
}
|
||||
case proto::plan::ArithOpType::Div: {
|
||||
ArithOpElementFunc<T,
|
||||
proto::plan::OpType::Equal,
|
||||
proto::plan::ArithOpType::Div>
|
||||
func;
|
||||
func(data, size, value, right_operand, res);
|
||||
break;
|
||||
}
|
||||
case proto::plan::ArithOpType::Mod: {
|
||||
ArithOpElementFunc<T,
|
||||
proto::plan::OpType::Equal,
|
||||
proto::plan::ArithOpType::Mod>
|
||||
func;
|
||||
func(data, size, value, right_operand, res);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
PanicInfo(
|
||||
OpTypeInvalid,
|
||||
fmt::format("unsupported arith type for binary "
|
||||
"arithmetic eval expr: {}",
|
||||
arith_type));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case proto::plan::OpType::NotEqual: {
|
||||
switch (arith_type) {
|
||||
case proto::plan::ArithOpType::Add: {
|
||||
ArithOpElementFunc<T,
|
||||
proto::plan::OpType::NotEqual,
|
||||
proto::plan::ArithOpType::Add>
|
||||
func;
|
||||
func(data, size, value, right_operand, res);
|
||||
break;
|
||||
}
|
||||
case proto::plan::ArithOpType::Sub: {
|
||||
ArithOpElementFunc<T,
|
||||
proto::plan::OpType::NotEqual,
|
||||
proto::plan::ArithOpType::Sub>
|
||||
func;
|
||||
func(data, size, value, right_operand, res);
|
||||
break;
|
||||
}
|
||||
case proto::plan::ArithOpType::Mul: {
|
||||
ArithOpElementFunc<T,
|
||||
proto::plan::OpType::NotEqual,
|
||||
proto::plan::ArithOpType::Mul>
|
||||
func;
|
||||
func(data, size, value, right_operand, res);
|
||||
break;
|
||||
}
|
||||
case proto::plan::ArithOpType::Div: {
|
||||
ArithOpElementFunc<T,
|
||||
proto::plan::OpType::NotEqual,
|
||||
proto::plan::ArithOpType::Div>
|
||||
func;
|
||||
func(data, size, value, right_operand, res);
|
||||
break;
|
||||
}
|
||||
case proto::plan::ArithOpType::Mod: {
|
||||
ArithOpElementFunc<T,
|
||||
proto::plan::OpType::NotEqual,
|
||||
proto::plan::ArithOpType::Mod>
|
||||
func;
|
||||
func(data, size, value, right_operand, res);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
PanicInfo(
|
||||
OpTypeInvalid,
|
||||
fmt::format("unsupported arith type for binary "
|
||||
"arithmetic eval expr: {}",
|
||||
arith_type));
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
PanicInfo(OpTypeInvalid,
|
||||
fmt::format("unsupported operator type for binary "
|
||||
"arithmetic eval expr: {}",
|
||||
op_type));
|
||||
}
|
||||
};
|
||||
int64_t processed_size = ProcessDataChunks<T>(
|
||||
execute_sub_batch, std::nullptr_t{}, res, value, right_operand);
|
||||
AssertInfo(processed_size == real_batch_size,
|
||||
fmt::format("internal error: expr processed rows {} not equal "
|
||||
"expect batch size {}",
|
||||
processed_size,
|
||||
real_batch_size));
|
||||
return res_vec;
|
||||
}
|
||||
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,213 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <fmt/core.h>
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Vector.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
template <typename T,
|
||||
proto::plan::OpType cmp_op,
|
||||
proto::plan::ArithOpType arith_op>
|
||||
struct ArithOpElementFunc {
|
||||
typedef std::conditional_t<std::is_integral_v<T> &&
|
||||
!std::is_same_v<bool, T>,
|
||||
int64_t,
|
||||
T>
|
||||
HighPrecisonType;
|
||||
void
|
||||
operator()(const T* src,
|
||||
size_t size,
|
||||
HighPrecisonType val,
|
||||
HighPrecisonType right_operand,
|
||||
bool* res) {
|
||||
for (int i = 0; i < size; ++i) {
|
||||
if constexpr (cmp_op == proto::plan::OpType::Equal) {
|
||||
if constexpr (arith_op == proto::plan::ArithOpType::Add) {
|
||||
res[i] = (src[i] + right_operand) == val;
|
||||
} else if constexpr (arith_op ==
|
||||
proto::plan::ArithOpType::Sub) {
|
||||
res[i] = (src[i] - right_operand) == val;
|
||||
} else if constexpr (arith_op ==
|
||||
proto::plan::ArithOpType::Mul) {
|
||||
res[i] = (src[i] * right_operand) == val;
|
||||
} else if constexpr (arith_op ==
|
||||
proto::plan::ArithOpType::Div) {
|
||||
res[i] = (src[i] / right_operand) == val;
|
||||
} else if constexpr (arith_op ==
|
||||
proto::plan::ArithOpType::Mod) {
|
||||
res[i] = (fmod(src[i], right_operand)) == val;
|
||||
} else {
|
||||
PanicInfo(
|
||||
OpTypeInvalid,
|
||||
fmt::format(
|
||||
"unsupported arith type:{} for ArithOpElementFunc",
|
||||
arith_op));
|
||||
}
|
||||
} else if constexpr (cmp_op == proto::plan::OpType::NotEqual) {
|
||||
if constexpr (arith_op == proto::plan::ArithOpType::Add) {
|
||||
res[i] = (src[i] + right_operand) != val;
|
||||
} else if constexpr (arith_op ==
|
||||
proto::plan::ArithOpType::Sub) {
|
||||
res[i] = (src[i] - right_operand) != val;
|
||||
} else if constexpr (arith_op ==
|
||||
proto::plan::ArithOpType::Mul) {
|
||||
res[i] = (src[i] * right_operand) != val;
|
||||
} else if constexpr (arith_op ==
|
||||
proto::plan::ArithOpType::Div) {
|
||||
res[i] = (src[i] / right_operand) != val;
|
||||
} else if constexpr (arith_op ==
|
||||
proto::plan::ArithOpType::Mod) {
|
||||
res[i] = (fmod(src[i], right_operand)) != val;
|
||||
} else {
|
||||
PanicInfo(
|
||||
OpTypeInvalid,
|
||||
fmt::format(
|
||||
"unsupported arith type:{} for ArithOpElementFunc",
|
||||
arith_op));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
proto::plan::OpType cmp_op,
|
||||
proto::plan::ArithOpType arith_op>
|
||||
struct ArithOpIndexFunc {
|
||||
typedef std::conditional_t<std::is_integral_v<T> &&
|
||||
!std::is_same_v<bool, T>,
|
||||
int64_t,
|
||||
T>
|
||||
HighPrecisonType;
|
||||
using Index = index::ScalarIndex<T>;
|
||||
FixedVector<bool>
|
||||
operator()(Index* index,
|
||||
size_t size,
|
||||
HighPrecisonType val,
|
||||
HighPrecisonType right_operand) {
|
||||
FixedVector<bool> res_vec(size);
|
||||
bool* res = res_vec.data();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
if constexpr (cmp_op == proto::plan::OpType::Equal) {
|
||||
if constexpr (arith_op == proto::plan::ArithOpType::Add) {
|
||||
res[i] = (index->Reverse_Lookup(i) + right_operand) == val;
|
||||
} else if constexpr (arith_op ==
|
||||
proto::plan::ArithOpType::Sub) {
|
||||
res[i] = (index->Reverse_Lookup(i) - right_operand) == val;
|
||||
} else if constexpr (arith_op ==
|
||||
proto::plan::ArithOpType::Mul) {
|
||||
res[i] = (index->Reverse_Lookup(i) * right_operand) == val;
|
||||
} else if constexpr (arith_op ==
|
||||
proto::plan::ArithOpType::Div) {
|
||||
res[i] = (index->Reverse_Lookup(i) / right_operand) == val;
|
||||
} else if constexpr (arith_op ==
|
||||
proto::plan::ArithOpType::Mod) {
|
||||
res[i] =
|
||||
(fmod(index->Reverse_Lookup(i), right_operand)) == val;
|
||||
} else {
|
||||
PanicInfo(
|
||||
OpTypeInvalid,
|
||||
fmt::format(
|
||||
"unsupported arith type:{} for ArithOpElementFunc",
|
||||
arith_op));
|
||||
}
|
||||
} else if constexpr (cmp_op == proto::plan::OpType::NotEqual) {
|
||||
if constexpr (arith_op == proto::plan::ArithOpType::Add) {
|
||||
res[i] = (index->Reverse_Lookup(i) + right_operand) != val;
|
||||
} else if constexpr (arith_op ==
|
||||
proto::plan::ArithOpType::Sub) {
|
||||
res[i] = (index->Reverse_Lookup(i) - right_operand) != val;
|
||||
} else if constexpr (arith_op ==
|
||||
proto::plan::ArithOpType::Mul) {
|
||||
res[i] = (index->Reverse_Lookup(i) * right_operand) != val;
|
||||
} else if constexpr (arith_op ==
|
||||
proto::plan::ArithOpType::Div) {
|
||||
res[i] = (index->Reverse_Lookup(i) / right_operand) != val;
|
||||
} else if constexpr (arith_op ==
|
||||
proto::plan::ArithOpType::Mod) {
|
||||
res[i] =
|
||||
(fmod(index->Reverse_Lookup(i), right_operand)) != val;
|
||||
} else {
|
||||
PanicInfo(
|
||||
OpTypeInvalid,
|
||||
fmt::format(
|
||||
"unsupported arith type:{} for ArithOpElementFunc",
|
||||
arith_op));
|
||||
}
|
||||
}
|
||||
}
|
||||
return res_vec;
|
||||
}
|
||||
};
|
||||
|
||||
// Physical operator for `(field <arith_op> right_operand) <cmp_op> value`
// predicates (e.g. `a % 10 == 3`). Per call to Eval it produces one boolean
// batch; batching/iteration state lives in the SegmentExpr base.
class PhyBinaryArithOpEvalRangeExpr : public SegmentExpr {
 public:
    PhyBinaryArithOpEvalRangeExpr(
        const std::vector<std::shared_ptr<Expr>>& input,
        const std::shared_ptr<const milvus::expr::BinaryArithOpEvalRangeExpr>&
            expr,
        const std::string& name,
        const segcore::SegmentInternalInterface* segment,
        Timestamp query_timestamp,
        int64_t batch_size)
        : SegmentExpr(std::move(input),
                      name,
                      segment,
                      expr->column_.field_id_,
                      query_timestamp,
                      batch_size),
          expr_(expr) {
    }

    // Evaluates the next batch into `result` (a BOOL ColumnVector).
    void
    Eval(EvalCtx& context, VectorPtr& result) override;

 private:
    // Routes to the index- or data-backed implementation for scalar column
    // type T.
    template <typename T>
    VectorPtr
    ExecRangeVisitorImpl();

    // Evaluates via the column's scalar index (Reverse_Lookup per row).
    template <typename T>
    VectorPtr
    ExecRangeVisitorImplForIndex();

    // Evaluates by scanning raw column data chunk by chunk.
    template <typename T>
    VectorPtr
    ExecRangeVisitorImplForData();

    // Variant for JSON columns; ValueType is the plan-value type of the
    // operands.
    template <typename ValueType>
    VectorPtr
    ExecRangeVisitorImplForJson();

    // Variant for ARRAY columns, addressing one element by index.
    template <typename ValueType>
    VectorPtr
    ExecRangeVisitorImplForArray();

 private:
    // Logical expression node this operator executes.
    std::shared_ptr<const milvus::expr::BinaryArithOpEvalRangeExpr> expr_;
};
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,392 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "BinaryRangeExpr.h"
|
||||
|
||||
#include "query/Utils.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
void
|
||||
PhyBinaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
|
||||
switch (expr_->column_.data_type_) {
|
||||
case DataType::BOOL: {
|
||||
result = ExecRangeVisitorImpl<bool>();
|
||||
break;
|
||||
}
|
||||
case DataType::INT8: {
|
||||
result = ExecRangeVisitorImpl<int8_t>();
|
||||
break;
|
||||
}
|
||||
case DataType::INT16: {
|
||||
result = ExecRangeVisitorImpl<int16_t>();
|
||||
break;
|
||||
}
|
||||
case DataType::INT32: {
|
||||
result = ExecRangeVisitorImpl<int32_t>();
|
||||
break;
|
||||
}
|
||||
case DataType::INT64: {
|
||||
result = ExecRangeVisitorImpl<int64_t>();
|
||||
break;
|
||||
}
|
||||
case DataType::FLOAT: {
|
||||
result = ExecRangeVisitorImpl<float>();
|
||||
break;
|
||||
}
|
||||
case DataType::DOUBLE: {
|
||||
result = ExecRangeVisitorImpl<double>();
|
||||
break;
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
if (segment_->type() == SegmentType::Growing) {
|
||||
result = ExecRangeVisitorImpl<std::string>();
|
||||
} else {
|
||||
result = ExecRangeVisitorImpl<std::string_view>();
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DataType::JSON: {
|
||||
auto value_type = expr_->lower_val_.val_case();
|
||||
switch (value_type) {
|
||||
case proto::plan::GenericValue::ValCase::kInt64Val: {
|
||||
result = ExecRangeVisitorImplForJson<int64_t>();
|
||||
break;
|
||||
}
|
||||
case proto::plan::GenericValue::ValCase::kFloatVal: {
|
||||
result = ExecRangeVisitorImplForJson<double>();
|
||||
break;
|
||||
}
|
||||
case proto::plan::GenericValue::ValCase::kStringVal: {
|
||||
result = ExecRangeVisitorImplForJson<std::string>();
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
PanicInfo(
|
||||
DataTypeInvalid,
|
||||
fmt::format("unsupported value type {} in expression",
|
||||
value_type));
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DataType::ARRAY: {
|
||||
auto value_type = expr_->lower_val_.val_case();
|
||||
switch (value_type) {
|
||||
case proto::plan::GenericValue::ValCase::kInt64Val: {
|
||||
result = ExecRangeVisitorImplForArray<int64_t>();
|
||||
break;
|
||||
}
|
||||
case proto::plan::GenericValue::ValCase::kFloatVal: {
|
||||
result = ExecRangeVisitorImplForArray<double>();
|
||||
break;
|
||||
}
|
||||
case proto::plan::GenericValue::ValCase::kStringVal: {
|
||||
result = ExecRangeVisitorImplForArray<std::string>();
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
PanicInfo(
|
||||
DataTypeInvalid,
|
||||
fmt::format("unsupported value type {} in expression",
|
||||
value_type));
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
PanicInfo(DataTypeInvalid,
|
||||
fmt::format("unsupported data type: {}",
|
||||
expr_->column_.data_type_));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
VectorPtr
|
||||
PhyBinaryRangeFilterExpr::ExecRangeVisitorImpl() {
|
||||
if (is_index_mode_) {
|
||||
return ExecRangeVisitorImplForIndex<T>();
|
||||
} else {
|
||||
return ExecRangeVisitorImplForData<T>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename IndexInnerType, typename HighPrecisionType>
|
||||
ColumnVectorPtr
|
||||
PhyBinaryRangeFilterExpr::PreCheckOverflow(HighPrecisionType& val1,
|
||||
HighPrecisionType& val2,
|
||||
bool& lower_inclusive,
|
||||
bool& upper_inclusive) {
|
||||
lower_inclusive = expr_->lower_inclusive_;
|
||||
upper_inclusive = expr_->upper_inclusive_;
|
||||
val1 = GetValueFromProto<HighPrecisionType>(expr_->lower_val_);
|
||||
val2 = GetValueFromProto<HighPrecisionType>(expr_->upper_val_);
|
||||
auto get_next_overflow_batch = [this]() -> ColumnVectorPtr {
|
||||
int64_t batch_size = overflow_check_pos_ + batch_size_ >= num_rows_
|
||||
? num_rows_ - overflow_check_pos_
|
||||
: batch_size_;
|
||||
overflow_check_pos_ += batch_size;
|
||||
if (cached_overflow_res_ != nullptr &&
|
||||
cached_overflow_res_->size() == batch_size) {
|
||||
return cached_overflow_res_;
|
||||
}
|
||||
auto res = std::make_shared<ColumnVector>(DataType::BOOL, batch_size);
|
||||
return res;
|
||||
};
|
||||
|
||||
if constexpr (std::is_integral_v<T> && !std::is_same_v<bool, T>) {
|
||||
if (milvus::query::gt_ub<T>(val1)) {
|
||||
return get_next_overflow_batch();
|
||||
} else if (milvus::query::lt_lb<T>(val1)) {
|
||||
val1 = std::numeric_limits<T>::min();
|
||||
lower_inclusive = true;
|
||||
}
|
||||
|
||||
if (milvus::query::gt_ub<T>(val2)) {
|
||||
val2 = std::numeric_limits<T>::max();
|
||||
upper_inclusive = true;
|
||||
} else if (milvus::query::lt_lb<T>(val2)) {
|
||||
return get_next_overflow_batch();
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
VectorPtr
|
||||
PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForIndex() {
|
||||
typedef std::
|
||||
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
|
||||
IndexInnerType;
|
||||
using Index = index::ScalarIndex<IndexInnerType>;
|
||||
typedef std::conditional_t<std::is_integral_v<IndexInnerType> &&
|
||||
!std::is_same_v<bool, T>,
|
||||
int64_t,
|
||||
IndexInnerType>
|
||||
HighPrecisionType;
|
||||
|
||||
HighPrecisionType val1;
|
||||
HighPrecisionType val2;
|
||||
bool lower_inclusive = false;
|
||||
bool upper_inclusive = false;
|
||||
if (auto res =
|
||||
PreCheckOverflow<T>(val1, val2, lower_inclusive, upper_inclusive)) {
|
||||
return res;
|
||||
}
|
||||
|
||||
auto real_batch_size = GetNextBatchSize();
|
||||
if (real_batch_size == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto execute_sub_batch =
|
||||
[lower_inclusive, upper_inclusive](
|
||||
Index* index_ptr, HighPrecisionType val1, HighPrecisionType val2) {
|
||||
BinaryRangeIndexFunc<T> func;
|
||||
return std::move(
|
||||
func(index_ptr, val1, val2, lower_inclusive, upper_inclusive));
|
||||
};
|
||||
auto res = ProcessIndexChunks<T>(execute_sub_batch, val1, val2);
|
||||
AssertInfo(res.size() == real_batch_size,
|
||||
fmt::format("internal error: expr processed rows {} not equal "
|
||||
"expect batch size {}",
|
||||
res.size(),
|
||||
real_batch_size));
|
||||
return std::make_shared<ColumnVector>(std::move(res));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
VectorPtr
|
||||
PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForData() {
|
||||
typedef std::
|
||||
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
|
||||
IndexInnerType;
|
||||
using Index = index::ScalarIndex<IndexInnerType>;
|
||||
typedef std::conditional_t<std::is_integral_v<IndexInnerType> &&
|
||||
!std::is_same_v<bool, T>,
|
||||
int64_t,
|
||||
IndexInnerType>
|
||||
HighPrecisionType;
|
||||
auto real_batch_size = GetNextBatchSize();
|
||||
if (real_batch_size == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
HighPrecisionType val1;
|
||||
HighPrecisionType val2;
|
||||
bool lower_inclusive = false;
|
||||
bool upper_inclusive = false;
|
||||
if (auto res =
|
||||
PreCheckOverflow<T>(val1, val2, lower_inclusive, upper_inclusive)) {
|
||||
return res;
|
||||
}
|
||||
auto res_vec =
|
||||
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
|
||||
bool* res = (bool*)res_vec->GetRawData();
|
||||
auto execute_sub_batch = [lower_inclusive, upper_inclusive](
|
||||
const T* data,
|
||||
const int size,
|
||||
bool* res,
|
||||
HighPrecisionType val1,
|
||||
HighPrecisionType val2) {
|
||||
if (lower_inclusive && upper_inclusive) {
|
||||
BinaryRangeElementFunc<T, true, true> func;
|
||||
func(val1, val2, data, size, res);
|
||||
} else if (lower_inclusive && !upper_inclusive) {
|
||||
BinaryRangeElementFunc<T, true, false> func;
|
||||
func(val1, val2, data, size, res);
|
||||
} else if (!lower_inclusive && upper_inclusive) {
|
||||
BinaryRangeElementFunc<T, false, true> func;
|
||||
func(val1, val2, data, size, res);
|
||||
} else {
|
||||
BinaryRangeElementFunc<T, false, false> func;
|
||||
func(val1, val2, data, size, res);
|
||||
}
|
||||
};
|
||||
auto skip_index_func =
|
||||
[val1, val2, lower_inclusive, upper_inclusive](
|
||||
const SkipIndex& skip_index, FieldId field_id, int64_t chunk_id) {
|
||||
if (lower_inclusive && upper_inclusive) {
|
||||
return skip_index.CanSkipBinaryRange<T>(
|
||||
field_id, chunk_id, val1, val2, true, true);
|
||||
} else if (lower_inclusive && !upper_inclusive) {
|
||||
return skip_index.CanSkipBinaryRange<T>(
|
||||
field_id, chunk_id, val1, val2, true, false);
|
||||
} else if (!lower_inclusive && upper_inclusive) {
|
||||
return skip_index.CanSkipBinaryRange<T>(
|
||||
field_id, chunk_id, val1, val2, false, true);
|
||||
} else {
|
||||
return skip_index.CanSkipBinaryRange<T>(
|
||||
field_id, chunk_id, val1, val2, false, false);
|
||||
}
|
||||
};
|
||||
int64_t processed_size = ProcessDataChunks<T>(
|
||||
execute_sub_batch, skip_index_func, res, val1, val2);
|
||||
AssertInfo(processed_size == real_batch_size,
|
||||
fmt::format("internal error: expr processed rows {} not equal "
|
||||
"expect batch size {}",
|
||||
processed_size,
|
||||
real_batch_size));
|
||||
return res_vec;
|
||||
}
|
||||
|
||||
template <typename ValueType>
|
||||
VectorPtr
|
||||
PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForJson() {
|
||||
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
|
||||
std::string_view,
|
||||
ValueType>;
|
||||
auto real_batch_size = GetNextBatchSize();
|
||||
if (real_batch_size == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
auto res_vec =
|
||||
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
|
||||
bool* res = (bool*)res_vec->GetRawData();
|
||||
|
||||
bool lower_inclusive = expr_->lower_inclusive_;
|
||||
bool upper_inclusive = expr_->upper_inclusive_;
|
||||
ValueType val1 = GetValueFromProto<ValueType>(expr_->lower_val_);
|
||||
ValueType val2 = GetValueFromProto<ValueType>(expr_->upper_val_);
|
||||
auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
|
||||
|
||||
auto execute_sub_batch = [lower_inclusive, upper_inclusive, pointer](
|
||||
const milvus::Json* data,
|
||||
const int size,
|
||||
bool* res,
|
||||
ValueType val1,
|
||||
ValueType val2) {
|
||||
if (lower_inclusive && upper_inclusive) {
|
||||
BinaryRangeElementFuncForJson<ValueType, true, true> func;
|
||||
func(val1, val2, pointer, data, size, res);
|
||||
} else if (lower_inclusive && !upper_inclusive) {
|
||||
BinaryRangeElementFuncForJson<ValueType, true, false> func;
|
||||
func(val1, val2, pointer, data, size, res);
|
||||
} else if (!lower_inclusive && upper_inclusive) {
|
||||
BinaryRangeElementFuncForJson<ValueType, false, true> func;
|
||||
func(val1, val2, pointer, data, size, res);
|
||||
} else {
|
||||
BinaryRangeElementFuncForJson<ValueType, false, false> func;
|
||||
func(val1, val2, pointer, data, size, res);
|
||||
}
|
||||
};
|
||||
int64_t processed_size = ProcessDataChunks<milvus::Json>(
|
||||
execute_sub_batch, std::nullptr_t{}, res, val1, val2);
|
||||
AssertInfo(processed_size == real_batch_size,
|
||||
fmt::format("internal error: expr processed rows {} not equal "
|
||||
"expect batch size {}",
|
||||
processed_size,
|
||||
real_batch_size));
|
||||
return res_vec;
|
||||
}
|
||||
|
||||
template <typename ValueType>
|
||||
VectorPtr
|
||||
PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForArray() {
|
||||
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
|
||||
std::string_view,
|
||||
ValueType>;
|
||||
auto real_batch_size = GetNextBatchSize();
|
||||
if (real_batch_size == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
auto res_vec =
|
||||
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
|
||||
bool* res = (bool*)res_vec->GetRawData();
|
||||
|
||||
bool lower_inclusive = expr_->lower_inclusive_;
|
||||
bool upper_inclusive = expr_->upper_inclusive_;
|
||||
ValueType val1 = GetValueFromProto<ValueType>(expr_->lower_val_);
|
||||
ValueType val2 = GetValueFromProto<ValueType>(expr_->upper_val_);
|
||||
int index = -1;
|
||||
if (expr_->column_.nested_path_.size() > 0) {
|
||||
index = std::stoi(expr_->column_.nested_path_[0]);
|
||||
}
|
||||
|
||||
auto execute_sub_batch = [lower_inclusive, upper_inclusive](
|
||||
const milvus::ArrayView* data,
|
||||
const int size,
|
||||
bool* res,
|
||||
ValueType val1,
|
||||
ValueType val2,
|
||||
int index) {
|
||||
if (lower_inclusive && upper_inclusive) {
|
||||
BinaryRangeElementFuncForArray<ValueType, true, true> func;
|
||||
func(val1, val2, index, data, size, res);
|
||||
} else if (lower_inclusive && !upper_inclusive) {
|
||||
BinaryRangeElementFuncForArray<ValueType, true, false> func;
|
||||
func(val1, val2, index, data, size, res);
|
||||
} else if (!lower_inclusive && upper_inclusive) {
|
||||
BinaryRangeElementFuncForArray<ValueType, false, true> func;
|
||||
func(val1, val2, index, data, size, res);
|
||||
} else {
|
||||
BinaryRangeElementFuncForArray<ValueType, false, false> func;
|
||||
func(val1, val2, index, data, size, res);
|
||||
}
|
||||
};
|
||||
int64_t processed_size = ProcessDataChunks<milvus::ArrayView>(
|
||||
execute_sub_batch, std::nullptr_t{}, res, val1, val2, index);
|
||||
AssertInfo(processed_size == real_batch_size,
|
||||
fmt::format("internal error: expr processed rows {} not equal "
|
||||
"expect batch size {}",
|
||||
processed_size,
|
||||
real_batch_size));
|
||||
return res_vec;
|
||||
}
|
||||
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,228 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fmt/core.h>
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Vector.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
template <typename T, bool lower_inclusive, bool upper_inclusive>
struct BinaryRangeElementFunc {
    typedef std::conditional_t<std::is_integral_v<T> &&
                                   !std::is_same_v<bool, T>,
                               int64_t,
                               T>
        HighPrecisionType;
    // Writes, for each of the n source elements, whether it lies inside the
    // [val1, val2] interval; the template flags select closed/open endpoints
    // at compile time so the loop body carries no runtime branch.
    void
    operator()(T val1, T val2, const T* src, size_t n, bool* res) {
        for (size_t pos = 0; pos < n; ++pos) {
            const T& cur = src[pos];
            bool above_lower;
            if constexpr (lower_inclusive) {
                above_lower = val1 <= cur;
            } else {
                above_lower = val1 < cur;
            }
            bool below_upper;
            if constexpr (upper_inclusive) {
                below_upper = cur <= val2;
            } else {
                below_upper = cur < val2;
            }
            res[pos] = above_lower && below_upper;
        }
    }
};
|
||||
|
||||
// Compares one JSON element against the range predicate `cmp`, which must be
// an expression written in terms of a local named `value`.
// - Reads src[i] at `pointer` as GetType; on a lookup/type error, an int64_t
//   target retries the lookup as double (JSON numbers may be stored either
//   way) before marking the row false.
// - The do { } while (false) wrapper makes the internal `break` statements
//   exit this macro only, so an invocation behaves as a single statement.
// (Comments are kept outside the #define: a // comment on a backslash-
// continued line would swallow the continuation.)
#define BinaryRangeJSONCompare(cmp)                                       \
    do {                                                                  \
        auto x = src[i].template at<GetType>(pointer);                    \
        if (x.error()) {                                                  \
            if constexpr (std::is_same_v<GetType, int64_t>) {             \
                auto x = src[i].template at<double>(pointer);             \
                if (!x.error()) {                                         \
                    auto value = x.value();                               \
                    res[i] = (cmp);                                       \
                    break;                                                \
                }                                                         \
            }                                                             \
            res[i] = false;                                               \
            break;                                                        \
        }                                                                 \
        auto value = x.value();                                           \
        res[i] = (cmp);                                                   \
    } while (false)
|
||||
|
||||
// Evaluates val1 (<|<=) doc[pointer] (<|<=) val2 for each of n JSON
// documents; rows whose pointer is missing or of an incompatible type come
// out false (handled inside BinaryRangeJSONCompare).
template <typename ValueType, bool lower_inclusive, bool upper_inclusive>
struct BinaryRangeElementFuncForJson {
    // std::string bounds are compared against string_view extracted from the
    // document to avoid per-row copies.
    using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
                                       std::string_view,
                                       ValueType>;
    void
    operator()(ValueType val1,
               ValueType val2,
               const std::string& pointer,
               const milvus::Json* src,
               size_t n,
               bool* res) {
        for (size_t i = 0; i < n; ++i) {
            // `value` inside each predicate is declared by the macro.
            if constexpr (lower_inclusive && upper_inclusive) {
                BinaryRangeJSONCompare(val1 <= value && value <= val2);
            } else if constexpr (lower_inclusive && !upper_inclusive) {
                BinaryRangeJSONCompare(val1 <= value && value < val2);
            } else if constexpr (!lower_inclusive && upper_inclusive) {
                BinaryRangeJSONCompare(val1 < value && value <= val2);
            } else {
                BinaryRangeJSONCompare(val1 < value && value < val2);
            }
        }
    }
};
|
||||
|
||||
template <typename ValueType, bool lower_inclusive, bool upper_inclusive>
|
||||
struct BinaryRangeElementFuncForArray {
|
||||
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
|
||||
std::string_view,
|
||||
ValueType>;
|
||||
void
|
||||
operator()(ValueType val1,
|
||||
ValueType val2,
|
||||
int index,
|
||||
const milvus::ArrayView* src,
|
||||
size_t n,
|
||||
bool* res) {
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
if constexpr (lower_inclusive && upper_inclusive) {
|
||||
if (index >= src[i].length()) {
|
||||
res[i] = false;
|
||||
continue;
|
||||
}
|
||||
auto value = src[i].get_data<GetType>(index);
|
||||
res[i] = val1 <= value && value <= val2;
|
||||
} else if constexpr (lower_inclusive && !upper_inclusive) {
|
||||
if (index >= src[i].length()) {
|
||||
res[i] = false;
|
||||
continue;
|
||||
}
|
||||
auto value = src[i].get_data<GetType>(index);
|
||||
res[i] = val1 <= value && value < val2;
|
||||
} else if constexpr (!lower_inclusive && upper_inclusive) {
|
||||
if (index >= src[i].length()) {
|
||||
res[i] = false;
|
||||
continue;
|
||||
}
|
||||
auto value = src[i].get_data<GetType>(index);
|
||||
res[i] = val1 < value && value <= val2;
|
||||
} else {
|
||||
if (index >= src[i].length()) {
|
||||
res[i] = false;
|
||||
continue;
|
||||
}
|
||||
auto value = src[i].get_data<GetType>(index);
|
||||
res[i] = val1 < value && value < val2;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Delegates the whole range query to the column's scalar index, which
// returns one match flag per row.
template <typename T>
struct BinaryRangeIndexFunc {
    // string_view columns are indexed as owned std::string.
    typedef std::
        conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
            IndexInnerType;
    using Index = index::ScalarIndex<IndexInnerType>;
    // NOTE(review): HighPrecisionType is declared but unused here —
    // presumably kept for symmetry with the element functors; confirm before
    // removing.
    typedef std::conditional_t<std::is_integral_v<IndexInnerType> &&
                                   !std::is_same_v<bool, T>,
                               int64_t,
                               IndexInnerType>
        HighPrecisionType;
    FixedVector<bool>
    operator()(Index* index,
               IndexInnerType val1,
               IndexInnerType val2,
               bool lower_inclusive,
               bool upper_inclusive) {
        return index->Range(val1, lower_inclusive, val2, upper_inclusive);
    }
};
|
||||
|
||||
// Physical operator for `lower (<|<=) field (<|<=) upper` predicates. Per
// call to Eval it produces one boolean batch; batching/iteration state lives
// in the SegmentExpr base.
class PhyBinaryRangeFilterExpr : public SegmentExpr {
 public:
    PhyBinaryRangeFilterExpr(
        const std::vector<std::shared_ptr<Expr>>& input,
        const std::shared_ptr<const milvus::expr::BinaryRangeFilterExpr>& expr,
        const std::string& name,
        const segcore::SegmentInternalInterface* segment,
        Timestamp query_timestamp,
        int64_t batch_size)
        : SegmentExpr(std::move(input),
                      name,
                      segment,
                      expr->column_.field_id_,
                      query_timestamp,
                      batch_size),
          expr_(expr) {
    }

    // Evaluates the next batch into `result` (a BOOL ColumnVector).
    void
    Eval(EvalCtx& context, VectorPtr& result) override;

 private:
    // Check overflow and cache result for performance
    // Loads bounds/inclusiveness from the plan, clamping integral bounds to
    // T's range; returns a non-null all-false batch when the range is
    // unsatisfiable due to overflow, nullptr otherwise.
    template <
        typename T,
        typename IndexInnerType = std::
            conditional_t<std::is_same_v<T, std::string_view>, std::string, T>,
        typename HighPrecisionType = std::conditional_t<
            std::is_integral_v<IndexInnerType> && !std::is_same_v<bool, T>,
            int64_t,
            IndexInnerType>>
    ColumnVectorPtr
    PreCheckOverflow(HighPrecisionType& val1,
                     HighPrecisionType& val2,
                     bool& lower_inclusive,
                     bool& upper_inclusive);

    // Routes to the index- or data-backed implementation for scalar type T.
    template <typename T>
    VectorPtr
    ExecRangeVisitorImpl();

    // Evaluates via the column's scalar index.
    template <typename T>
    VectorPtr
    ExecRangeVisitorImplForIndex();

    // Evaluates by scanning raw column data chunk by chunk.
    template <typename T>
    VectorPtr
    ExecRangeVisitorImplForData();

    // Variant for JSON columns; ValueType is the plan-value type of the
    // bounds.
    template <typename ValueType>
    VectorPtr
    ExecRangeVisitorImplForJson();

    // Variant for ARRAY columns, addressing one element by index.
    template <typename ValueType>
    VectorPtr
    ExecRangeVisitorImplForArray();

 private:
    // Logical expression node this operator executes.
    std::shared_ptr<const milvus::expr::BinaryRangeFilterExpr> expr_;
    // All-false batch reused when PreCheckOverflow short-circuits.
    ColumnVectorPtr cached_overflow_res_{nullptr};
    // Row cursor for the overflow short-circuit path (independent of the
    // normal chunk cursor).
    int64_t overflow_check_pos_{0};
};
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,319 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "CompareExpr.h"
|
||||
#include "query/Relational.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
bool
|
||||
PhyCompareFilterExpr::IsStringExpr() {
|
||||
return expr_->left_data_type_ == DataType::VARCHAR ||
|
||||
expr_->right_data_type_ == DataType::VARCHAR;
|
||||
}
|
||||
|
||||
int64_t
|
||||
PhyCompareFilterExpr::GetNextBatchSize() {
|
||||
auto current_rows =
|
||||
segment_->type() == SegmentType::Growing
|
||||
? current_chunk_id_ * size_per_chunk_ + current_chunk_pos_
|
||||
: current_chunk_pos_;
|
||||
return current_rows + batch_size_ >= num_rows_ ? num_rows_ - current_rows
|
||||
: batch_size_;
|
||||
}
|
||||
|
||||
// Builds an accessor returning the i-th value of `field_id` within
// `chunk_id` as a boost variant (`number`). Chunks at or beyond
// `data_barrier` only have an index; when that index can reverse-lookup raw
// values, read through it. Otherwise fall through to raw chunk data.
// NOTE(review): if chunk_id >= data_barrier and the index has no raw data,
// the chunk_data() fallback still runs — presumably the caller guarantees
// that case cannot occur; verify.
template <typename T>
ChunkDataAccessor
PhyCompareFilterExpr::GetChunkData(FieldId field_id,
                                   int chunk_id,
                                   int data_barrier) {
    if (chunk_id >= data_barrier) {
        auto& indexing = segment_->chunk_scalar_index<T>(field_id, chunk_id);
        if (indexing.HasRawData()) {
            // The lambda captures the index by reference; the segment owns
            // the index, which must outlive the returned accessor.
            return [&indexing](int i) -> const number {
                return indexing.Reverse_Lookup(i);
            };
        }
    }
    auto chunk_data = segment_->chunk_data<T>(field_id, chunk_id).data();
    return [chunk_data](int i) -> const number { return chunk_data[i]; };
}
|
||||
|
||||
// Specialization for strings: growing segments store owning std::string,
// sealed segments store std::string_view (materialized into std::string when
// yielded, since the `number` variant holds std::string).
template <>
ChunkDataAccessor
PhyCompareFilterExpr::GetChunkData<std::string>(FieldId field_id,
                                                int chunk_id,
                                                int data_barrier) {
    if (chunk_id >= data_barrier) {
        auto& indexing =
            segment_->chunk_scalar_index<std::string>(field_id, chunk_id);
        if (indexing.HasRawData()) {
            // Index reference is owned by the segment and must outlive the
            // returned accessor.
            return [&indexing](int i) -> const std::string {
                return indexing.Reverse_Lookup(i);
            };
        }
    }
    if (segment_->type() == SegmentType::Growing) {
        auto chunk_data =
            segment_->chunk_data<std::string>(field_id, chunk_id).data();
        return [chunk_data](int i) -> const number { return chunk_data[i]; };
    } else {
        auto chunk_data =
            segment_->chunk_data<std::string_view>(field_id, chunk_id).data();
        // Copy the view into an owning string so the variant is safe to hold.
        return [chunk_data](int i) -> const number {
            return std::string(chunk_data[i]);
        };
    }
}
|
||||
|
||||
// Runtime-type dispatcher: routes to the typed GetChunkData<T> accessor for
// the column's DataType. Panics on unsupported types.
ChunkDataAccessor
PhyCompareFilterExpr::GetChunkData(DataType data_type,
                                   FieldId field_id,
                                   int chunk_id,
                                   int data_barrier) {
    switch (data_type) {
        case DataType::BOOL:
            return GetChunkData<bool>(field_id, chunk_id, data_barrier);
        case DataType::INT8:
            return GetChunkData<int8_t>(field_id, chunk_id, data_barrier);
        case DataType::INT16:
            return GetChunkData<int16_t>(field_id, chunk_id, data_barrier);
        case DataType::INT32:
            return GetChunkData<int32_t>(field_id, chunk_id, data_barrier);
        case DataType::INT64:
            return GetChunkData<int64_t>(field_id, chunk_id, data_barrier);
        case DataType::FLOAT:
            return GetChunkData<float>(field_id, chunk_id, data_barrier);
        case DataType::DOUBLE:
            return GetChunkData<double>(field_id, chunk_id, data_barrier);
        case DataType::VARCHAR: {
            return GetChunkData<std::string>(field_id, chunk_id, data_barrier);
        }
        default:
            PanicInfo(DataTypeInvalid,
                      fmt::format("unsupported data type: {}", data_type));
    }
}
|
||||
|
||||
// Row-at-a-time comparison through variant accessors (used when at least one
// side is indexed or is a string column). Resumes from the cursor state
// (current_chunk_id_/current_chunk_pos_) and advances it once a full batch
// has been produced.
template <typename OpType>
VectorPtr
PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) {
    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;  // fully consumed
    }

    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();

    // Chunks at or beyond the barrier are index-only for that field.
    auto left_data_barrier = segment_->num_chunk_data(expr_->left_field_id_);
    auto right_data_barrier = segment_->num_chunk_data(expr_->right_field_id_);

    int64_t processed_rows = 0;
    for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_;
         ++chunk_id) {
        // Last chunk may be partially filled.
        auto chunk_size = chunk_id == num_chunk_ - 1
                              ? num_rows_ - chunk_id * size_per_chunk_
                              : size_per_chunk_;
        auto left = GetChunkData(expr_->left_data_type_,
                                 expr_->left_field_id_,
                                 chunk_id,
                                 left_data_barrier);
        auto right = GetChunkData(expr_->right_data_type_,
                                  expr_->right_field_id_,
                                  chunk_id,
                                  right_data_barrier);

        // Resume mid-chunk only for the chunk the cursor stopped in.
        for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0;
             i < chunk_size;
             ++i) {
            res[processed_rows++] = boost::apply_visitor(
                milvus::query::Relational<decltype(op)>{}, left(i), right(i));

            if (processed_rows >= batch_size_) {
                current_chunk_id_ = chunk_id;
                current_chunk_pos_ = i + 1;
                return res_vec;
            }
        }
    }
    // NOTE(review): reaching here (final partial batch) leaves the cursor
    // unchanged; presumably GetNextBatchSize() returning 0 next call relies
    // on callers not re-entering — confirm.
    return res_vec;
}
|
||||
|
||||
void
|
||||
PhyCompareFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
|
||||
// For segment both fields has no index, can use SIMD to speed up.
|
||||
// Avoiding too much call stack that blocks SIMD.
|
||||
if (!is_left_indexed_ && !is_right_indexed_ && !IsStringExpr()) {
|
||||
result = ExecCompareExprDispatcherForBothDataSegment();
|
||||
return;
|
||||
}
|
||||
result = ExecCompareExprDispatcherForHybridSegment();
|
||||
}
|
||||
|
||||
// Maps the plan's OpType to a comparison functor and runs the row-at-a-time
// dispatcher. Used when at least one side reads through an index or strings
// are involved.
VectorPtr
PhyCompareFilterExpr::ExecCompareExprDispatcherForHybridSegment() {
    switch (expr_->op_type_) {
        case OpType::Equal: {
            return ExecCompareExprDispatcher(std::equal_to<>{});
        }
        case OpType::NotEqual: {
            return ExecCompareExprDispatcher(std::not_equal_to<>{});
        }
        case OpType::GreaterEqual: {
            return ExecCompareExprDispatcher(std::greater_equal<>{});
        }
        case OpType::GreaterThan: {
            return ExecCompareExprDispatcher(std::greater<>{});
        }
        case OpType::LessEqual: {
            return ExecCompareExprDispatcher(std::less_equal<>{});
        }
        case OpType::LessThan: {
            return ExecCompareExprDispatcher(std::less<>{});
        }
        case OpType::PrefixMatch: {
            return ExecCompareExprDispatcher(
                milvus::query::MatchOp<OpType::PrefixMatch>{});
        }
            // PostfixMatch is not supported for compare expressions yet.
            // case OpType::PostfixMatch: {
            // }
        default: {
            PanicInfo(OpTypeInvalid,
                      fmt::format("unsupported optype: {}", expr_->op_type_));
        }
    }
}
|
||||
|
||||
// First level of the two-level type dispatch for the raw-data fast path:
// resolves the left column's element type, then defers to
// ExecCompareLeftType<T> for the right side.
VectorPtr
PhyCompareFilterExpr::ExecCompareExprDispatcherForBothDataSegment() {
    switch (expr_->left_data_type_) {
        case DataType::BOOL:
            return ExecCompareLeftType<bool>();
        case DataType::INT8:
            return ExecCompareLeftType<int8_t>();
        case DataType::INT16:
            return ExecCompareLeftType<int16_t>();
        case DataType::INT32:
            return ExecCompareLeftType<int32_t>();
        case DataType::INT64:
            return ExecCompareLeftType<int64_t>();
        case DataType::FLOAT:
            return ExecCompareLeftType<float>();
        case DataType::DOUBLE:
            return ExecCompareLeftType<double>();
        default:
            PanicInfo(
                DataTypeInvalid,
                fmt::format("unsupported left datatype:{} of compare expr",
                            expr_->left_data_type_));
    }
}
|
||||
|
||||
// Second level of the type dispatch: with the left element type fixed as T,
// resolves the right column's element type and runs the typed kernel.
template <typename T>
VectorPtr
PhyCompareFilterExpr::ExecCompareLeftType() {
    switch (expr_->right_data_type_) {
        case DataType::BOOL:
            return ExecCompareRightType<T, bool>();
        case DataType::INT8:
            return ExecCompareRightType<T, int8_t>();
        case DataType::INT16:
            return ExecCompareRightType<T, int16_t>();
        case DataType::INT32:
            return ExecCompareRightType<T, int32_t>();
        case DataType::INT64:
            return ExecCompareRightType<T, int64_t>();
        case DataType::FLOAT:
            return ExecCompareRightType<T, float>();
        case DataType::DOUBLE:
            return ExecCompareRightType<T, double>();
        default:
            PanicInfo(
                DataTypeInvalid,
                fmt::format("unsupported right datatype:{} of compare expr",
                            expr_->right_data_type_));
    }
}
|
||||
|
||||
template <typename T, typename U>
|
||||
VectorPtr
|
||||
PhyCompareFilterExpr::ExecCompareRightType() {
|
||||
auto real_batch_size = GetNextBatchSize();
|
||||
if (real_batch_size == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
auto res_vec =
|
||||
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
|
||||
bool* res = (bool*)res_vec->GetRawData();
|
||||
auto expr_type = expr_->op_type_;
|
||||
auto execute_sub_batch = [expr_type](const T* left,
|
||||
const U* right,
|
||||
const int size,
|
||||
bool* res) {
|
||||
switch (expr_type) {
|
||||
case proto::plan::GreaterThan: {
|
||||
CompareElementFunc<T, U, proto::plan::GreaterThan> func;
|
||||
func(left, right, size, res);
|
||||
break;
|
||||
}
|
||||
case proto::plan::GreaterEqual: {
|
||||
CompareElementFunc<T, U, proto::plan::GreaterEqual> func;
|
||||
func(left, right, size, res);
|
||||
break;
|
||||
}
|
||||
case proto::plan::LessThan: {
|
||||
CompareElementFunc<T, U, proto::plan::LessThan> func;
|
||||
func(left, right, size, res);
|
||||
break;
|
||||
}
|
||||
case proto::plan::LessEqual: {
|
||||
CompareElementFunc<T, U, proto::plan::LessEqual> func;
|
||||
func(left, right, size, res);
|
||||
break;
|
||||
}
|
||||
case proto::plan::Equal: {
|
||||
CompareElementFunc<T, U, proto::plan::Equal> func;
|
||||
func(left, right, size, res);
|
||||
break;
|
||||
}
|
||||
case proto::plan::NotEqual: {
|
||||
CompareElementFunc<T, U, proto::plan::NotEqual> func;
|
||||
func(left, right, size, res);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
PanicInfo(
|
||||
OpTypeInvalid,
|
||||
fmt::format(
|
||||
"unsupported operator type for compare column expr: {}",
|
||||
expr_type));
|
||||
}
|
||||
};
|
||||
int64_t processed_size =
|
||||
ProcessBothDataChunks<T, U>(execute_sub_batch, res);
|
||||
AssertInfo(processed_size == real_batch_size,
|
||||
fmt::format("internal error: expr processed rows {} not equal "
|
||||
"expect batch size {}",
|
||||
processed_size,
|
||||
real_batch_size));
|
||||
return res_vec;
|
||||
};
|
||||
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,186 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fmt/core.h>
|
||||
#include <boost/variant.hpp>
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Vector.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
using number = boost::variant<bool,
|
||||
int8_t,
|
||||
int16_t,
|
||||
int32_t,
|
||||
int64_t,
|
||||
float,
|
||||
double,
|
||||
std::string>;
|
||||
using ChunkDataAccessor = std::function<const number(int)>;
|
||||
|
||||
// Element-wise comparison kernel: res[i] = left[i] <op> right[i] for
// i in [0, size). The operator is a compile-time parameter so each
// instantiation compiles to a single tight, branch-free loop.
//
// Fix: loop index changed from `int` to `size_t` to match `size` — the old
// signed/unsigned comparison triggered -Wsign-compare and would loop forever
// (UB on signed overflow) for sizes above INT_MAX.
template <typename T, typename U, proto::plan::OpType op>
struct CompareElementFunc {
    void
    operator()(const T* left, const U* right, size_t size, bool* res) {
        for (size_t i = 0; i < size; ++i) {
            if constexpr (op == proto::plan::OpType::Equal) {
                res[i] = left[i] == right[i];
            } else if constexpr (op == proto::plan::OpType::NotEqual) {
                res[i] = left[i] != right[i];
            } else if constexpr (op == proto::plan::OpType::GreaterThan) {
                res[i] = left[i] > right[i];
            } else if constexpr (op == proto::plan::OpType::LessThan) {
                res[i] = left[i] < right[i];
            } else if constexpr (op == proto::plan::OpType::GreaterEqual) {
                res[i] = left[i] >= right[i];
            } else if constexpr (op == proto::plan::OpType::LessEqual) {
                res[i] = left[i] <= right[i];
            } else {
                // Unreachable for supported ops; kept as a runtime panic so
                // new enum values fail loudly instead of silently.
                PanicInfo(
                    OpTypeInvalid,
                    fmt::format("unsupported op_type:{} for CompareElementFunc",
                                op));
            }
        }
    }
};
|
||||
|
||||
// Physical expression comparing two columns of the same segment
// (`left <op> right`), producing a BOOL column batch per Eval call. Keeps a
// (chunk id, offset) cursor so successive Eval calls stream through the
// segment batch by batch.
class PhyCompareFilterExpr : public Expr {
 public:
    PhyCompareFilterExpr(
        const std::vector<std::shared_ptr<Expr>>& input,
        const std::shared_ptr<const milvus::expr::CompareExpr>& expr,
        const std::string& name,
        const segcore::SegmentInternalInterface* segment,
        Timestamp query_timestamp,
        int64_t batch_size)
        : Expr(DataType::BOOL, std::move(input), name),
          left_field_(expr->left_field_id_),
          right_field_(expr->right_field_id_),
          segment_(segment),
          query_timestamp_(query_timestamp),
          batch_size_(batch_size),
          expr_(expr) {
        is_left_indexed_ = segment_->HasIndex(left_field_);
        is_right_indexed_ = segment_->HasIndex(right_field_);
        num_rows_ = segment_->get_active_count(query_timestamp_);
        // Chunk count follows the left field: index chunks when indexed,
        // data chunks otherwise.
        num_chunk_ = is_left_indexed_
                         ? segment_->num_chunk_index(expr_->left_field_id_)
                         : segment_->num_chunk_data(expr_->left_field_id_);
        size_per_chunk_ = segment_->size_per_chunk();
        AssertInfo(
            batch_size_ > 0,
            fmt::format("expr batch size should greater than zero, but now: {}",
                        batch_size_));
    }

    void
    Eval(EvalCtx& context, VectorPtr& result) override;

 private:
    // Rows remaining for the next batch, capped at batch_size_.
    int64_t
    GetNextBatchSize();

    // True when either side is a VARCHAR column.
    bool
    IsStringExpr();

    template <typename T>
    ChunkDataAccessor
    GetChunkData(FieldId field_id, int chunk_id, int data_barrier);

    // Streams matching spans of both columns' raw chunk data into `func`,
    // advancing the cursor; returns the number of rows processed.
    template <typename T, typename U, typename FUNC, typename... ValTypes>
    int64_t
    ProcessBothDataChunks(FUNC func, bool* res, ValTypes... values) {
        int64_t processed_size = 0;

        for (size_t i = current_chunk_id_; i < num_chunk_; i++) {
            auto left_chunk = segment_->chunk_data<T>(left_field_, i);
            auto right_chunk = segment_->chunk_data<U>(right_field_, i);
            // Resume mid-chunk only for the chunk the cursor stopped in.
            auto data_pos = (i == current_chunk_id_) ? current_chunk_pos_ : 0;
            // NOTE(review): for a Growing segment whose row count is an exact
            // multiple of size_per_chunk_, `num_rows_ % size_per_chunk_` is 0
            // and the last-chunk size goes negative — presumably that state
            // cannot occur here; confirm.
            auto size = (i == (num_chunk_ - 1))
                            ? (segment_->type() == SegmentType::Growing
                                   ? num_rows_ % size_per_chunk_ - data_pos
                                   : num_rows_ - data_pos)
                            : size_per_chunk_ - data_pos;

            // Clip the span so the batch never exceeds batch_size_.
            if (processed_size + size >= batch_size_) {
                size = batch_size_ - processed_size;
            }

            const T* left_data = left_chunk.data() + data_pos;
            const U* right_data = right_chunk.data() + data_pos;
            func(left_data, right_data, size, res + processed_size, values...);
            processed_size += size;

            if (processed_size >= batch_size_) {
                current_chunk_id_ = i;
                current_chunk_pos_ = data_pos + size;
                break;
            }
        }

        return processed_size;
    }

    ChunkDataAccessor
    GetChunkData(DataType data_type,
                 FieldId field_id,
                 int chunk_id,
                 int data_barrier);

    // Row-at-a-time path through variant accessors (index/string cases).
    template <typename OpType>
    VectorPtr
    ExecCompareExprDispatcher(OpType op);

    VectorPtr
    ExecCompareExprDispatcherForHybridSegment();

    VectorPtr
    ExecCompareExprDispatcherForBothDataSegment();

    // Two-level element-type dispatch for the raw-data fast path.
    template <typename T>
    VectorPtr
    ExecCompareLeftType();

    template <typename T, typename U>
    VectorPtr
    ExecCompareRightType();

 private:
    const FieldId left_field_;
    const FieldId right_field_;
    bool is_left_indexed_;
    bool is_right_indexed_;
    int64_t num_rows_{0};
    int64_t num_chunk_{0};
    // Streaming cursor: chunk index and offset of the next unprocessed row.
    int64_t current_chunk_id_{0};
    int64_t current_chunk_pos_{0};
    int64_t size_per_chunk_{0};

    const segcore::SegmentInternalInterface* segment_;
    Timestamp query_timestamp_;
    int64_t batch_size_;
    std::shared_ptr<const milvus::expr::CompareExpr> expr_;
};
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,131 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ConjunctExpr.h"
|
||||
#include "simd/hook.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Validates that every child expression yields BOOL and reports the
// conjunction's own result type (always BOOL).
DataType
PhyConjunctFilterExpr::ResolveType(const std::vector<DataType>& inputs) {
    AssertInfo(
        !inputs.empty(),
        fmt::format(
            "Conjunct expressions expect at least one argument, received: {}",
            inputs.size()));

    for (const auto& input_type : inputs) {
        AssertInfo(
            input_type == DataType::BOOL,
            fmt::format("Conjunct expressions expect BOOLEAN, received: {}",
                        input_type));
    }
    return DataType::BOOL;
}
|
||||
|
||||
// Returns true iff every element of the boolean column is true.
// Uses the dynamic-SIMD helper when built with USE_DYNAMIC_SIMD, otherwise a
// scalar scan with early exit.
static bool
AllTrue(ColumnVectorPtr& vec) {
    bool* data = static_cast<bool*>(vec->GetRawData());
#if defined(USE_DYNAMIC_SIMD)
    return milvus::simd::all_true(data, vec->size());
#else
    for (int i = 0; i < vec->size(); ++i) {
        if (!data[i]) {
            return false;
        }
    }
    return true;
#endif
}
|
||||
|
||||
static void
|
||||
AllSet(ColumnVectorPtr& vec) {
|
||||
bool* data = static_cast<bool*>(vec->GetRawData());
|
||||
for (int i = 0; i < vec->size(); ++i) {
|
||||
data[i] = true;
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
AllReset(ColumnVectorPtr& vec) {
|
||||
bool* data = static_cast<bool*>(vec->GetRawData());
|
||||
for (int i = 0; i < vec->size(); ++i) {
|
||||
data[i] = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true iff every element of the boolean column is false.
// Uses the dynamic-SIMD helper when built with USE_DYNAMIC_SIMD, otherwise a
// scalar scan with early exit.
static bool
AllFalse(ColumnVectorPtr& vec) {
    bool* data = static_cast<bool*>(vec->GetRawData());
#if defined(USE_DYNAMIC_SIMD)
    return milvus::simd::all_false(data, vec->size());
#else
    for (int i = 0; i < vec->size(); ++i) {
        if (data[i]) {
            return false;
        }
    }
    return true;
#endif
}
|
||||
|
||||
int64_t
|
||||
PhyConjunctFilterExpr::UpdateResult(ColumnVectorPtr& input_result,
|
||||
EvalCtx& ctx,
|
||||
ColumnVectorPtr& result) {
|
||||
if (is_and_) {
|
||||
ConjunctElementFunc<true> func;
|
||||
return func(input_result, result);
|
||||
} else {
|
||||
ConjunctElementFunc<false> func;
|
||||
return func(input_result, result);
|
||||
}
|
||||
}
|
||||
|
||||
bool
|
||||
PhyConjunctFilterExpr::CanSkipNextExprs(ColumnVectorPtr& vec) {
|
||||
if ((is_and_ && AllFalse(vec)) || (!is_and_ && AllTrue(vec))) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Evaluates children left to right, folding each bitmap into the first
// child's result. Stops early when the outcome is already decided, either
// globally (CanSkipNextExprs after the first child) or because no active
// rows remain after a merge.
void
PhyConjunctFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
    for (int i = 0; i < inputs_.size(); ++i) {
        VectorPtr input_result;
        inputs_[i]->Eval(context, input_result);
        if (i == 0) {
            // The first child's vector becomes the accumulator.
            result = input_result;
            auto all_flat_result = GetColumnVector(result);
            if (CanSkipNextExprs(all_flat_result)) {
                return;
            }
            continue;
        }
        auto input_flat_result = GetColumnVector(input_result);
        auto all_flat_result = GetColumnVector(result);
        auto active_rows =
            UpdateResult(input_flat_result, context, all_flat_result);
        if (active_rows == 0) {
            // Every row is decided; later children cannot change anything.
            return;
        }
    }
}
|
||||
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,89 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fmt/core.h>
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Vector.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Folds `input_result` into `result` element-wise: AND (is_and == true) or
// OR (is_and == false). Returns the number of rows that are still
// "active" — i.e. whose final value could still change: rows left true
// under AND, rows left false under OR. A return of 0 lets the caller stop
// evaluating further children.
template <bool is_and>
struct ConjunctElementFunc {
    int64_t
    operator()(ColumnVectorPtr& input_result, ColumnVectorPtr& result) {
        bool* input_data = static_cast<bool*>(input_result->GetRawData());
        bool* res_data = static_cast<bool*>(result->GetRawData());
        int64_t activate_rows = 0;
        for (int i = 0; i < result->size(); ++i) {
            if constexpr (is_and) {
                res_data[i] &= input_data[i];
                if (res_data[i]) {
                    activate_rows++;
                }
            } else {
                res_data[i] |= input_data[i];
                if (!res_data[i]) {
                    activate_rows++;
                }
            }
        }
        return activate_rows;
    }
};
|
||||
|
||||
// Physical AND/OR expression over an arbitrary number of BOOL children.
// Children are evaluated lazily with short-circuiting (see Eval).
class PhyConjunctFilterExpr : public Expr {
 public:
    PhyConjunctFilterExpr(std::vector<ExprPtr>&& inputs, bool is_and)
        : Expr(DataType::BOOL, std::move(inputs), is_and ? "and" : "or"),
          is_and_(is_and) {
        // Collect the children's result types (from the inputs_ member the
        // base class now owns) and validate they are all BOOL.
        std::vector<DataType> input_types;
        input_types.reserve(inputs_.size());

        std::transform(inputs_.begin(),
                       inputs_.end(),
                       std::back_inserter(input_types),
                       [](const ExprPtr& expr) { return expr->type(); });

        ResolveType(input_types);
    }

    void
    Eval(EvalCtx& context, VectorPtr& result) override;

 private:
    // Merges one child's bitmap into the accumulator; returns rows still
    // able to change the outcome.
    int64_t
    UpdateResult(ColumnVectorPtr& input_result,
                 EvalCtx& ctx,
                 ColumnVectorPtr& result);

    // Asserts all inputs are BOOL; result type is always BOOL.
    static DataType
    ResolveType(const std::vector<DataType>& inputs);

    // True when the accumulated bitmap already decides every row.
    bool
    CanSkipNextExprs(ColumnVectorPtr& vec);
    // true if conjunction (and), false if disjunction (or).
    bool is_and_;
    // Reserved for child reordering; currently unused.
    std::vector<int32_t> input_order_;
};
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,62 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "common/Vector.h"
|
||||
#include "exec/QueryContext.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
class ExprSet;
|
||||
class EvalCtx {
|
||||
public:
|
||||
EvalCtx(ExecContext* exec_ctx, ExprSet* expr_set, RowVector* row)
|
||||
: exec_ctx_(exec_ctx), expr_set_(expr_set_), row_(row) {
|
||||
assert(exec_ctx_ != nullptr);
|
||||
assert(expr_set_ != nullptr);
|
||||
// assert(row_ != nullptr);
|
||||
}
|
||||
|
||||
explicit EvalCtx(ExecContext* exec_ctx)
|
||||
: exec_ctx_(exec_ctx), expr_set_(nullptr), row_(nullptr) {
|
||||
}
|
||||
|
||||
ExecContext*
|
||||
get_exec_context() {
|
||||
return exec_ctx_;
|
||||
}
|
||||
|
||||
std::shared_ptr<QueryConfig>
|
||||
get_query_config() {
|
||||
return exec_ctx_->get_query_config();
|
||||
}
|
||||
|
||||
private:
|
||||
ExecContext* exec_ctx_;
|
||||
ExprSet* expr_set_;
|
||||
RowVector* row_;
|
||||
bool input_no_nulls_;
|
||||
};
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,72 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ExistsExpr.h"
|
||||
#include "common/Json.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Evaluates the `exists(path)` predicate. Only JSON columns are supported,
// and only in raw-data mode; index-backed JSON evaluation panics.
void
PhyExistsFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
    switch (expr_->column_.data_type_) {
        case DataType::JSON: {
            if (is_index_mode_) {
                PanicInfo(ExprInvalid,
                          "exists expr for json index mode not supportted");
            }
            result = EvalJsonExistsForDataSegment();
            break;
        }
        default:
            PanicInfo(DataTypeInvalid,
                      fmt::format("unsupported data type: {}",
                                  expr_->column_.data_type_));
    }
}
|
||||
|
||||
// Produces one batch of BOOL results: res[i] = whether the JSON document at
// row i contains the expression's nested path.
VectorPtr
PhyExistsFilterExpr::EvalJsonExistsForDataSegment() {
    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;  // all rows consumed
    }
    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();

    // JSON-pointer form of the nested path, e.g. "/a/b".
    auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
    // Per-chunk kernel run by ProcessDataChunks over raw JSON data.
    auto execute_sub_batch = [](const milvus::Json* data,
                                const int size,
                                bool* res,
                                const std::string& pointer) {
        for (int i = 0; i < size; ++i) {
            res[i] = data[i].exist(pointer);
        }
    };

    int64_t processed_size = ProcessDataChunks<Json>(
        execute_sub_batch, std::nullptr_t{}, res, pointer);
    AssertInfo(processed_size == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           processed_size,
                           real_batch_size));
    return res_vec;
}
|
||||
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,66 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fmt/core.h>
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Vector.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Placeholder element functor for the exists expression. The JSON path
// check is performed directly in EvalJsonExistsForDataSegment, so this body
// is intentionally empty; it exists to mirror the functor pattern used by
// the other expression headers.
template <typename T>
struct ExistsElementFunc {
    void
    operator()(const T* src, size_t size, T val, bool* res) {
    }
};
|
||||
|
||||
// Physical operator for `exists(json_field.path)`: for each row of the
// target JSON field it reports whether the configured json pointer resolves,
// producing a BOOL column batch-by-batch.
class PhyExistsFilterExpr : public SegmentExpr {
 public:
    PhyExistsFilterExpr(
        const std::vector<std::shared_ptr<Expr>>& input,
        const std::shared_ptr<const milvus::expr::ExistsExpr>& expr,
        const std::string& name,
        const segcore::SegmentInternalInterface* segment,
        Timestamp query_timestamp,
        int64_t batch_size)
        // NOTE(review): `input` is a const lvalue reference, so std::move is
        // a no-op cast here and the vector is still copied into the base.
        : SegmentExpr(std::move(input),
                      name,
                      segment,
                      expr->column_.field_id_,
                      query_timestamp,
                      batch_size),
          expr_(expr) {
    }

    // Evaluates the next batch; `result` receives a BOOL ColumnVector.
    void
    Eval(EvalCtx& context, VectorPtr& result) override;

 private:
    // Raw-data (non-indexed) scan: tests the json pointer per row.
    VectorPtr
    EvalJsonExistsForDataSegment();

 private:
    // Logical plan node this physical operator was compiled from.
    std::shared_ptr<const milvus::expr::ExistsExpr> expr_;
};
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,255 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "Expr.h"
|
||||
|
||||
#include "exec/expression/AlwaysTrueExpr.h"
|
||||
#include "exec/expression/BinaryArithOpEvalRangeExpr.h"
|
||||
#include "exec/expression/BinaryRangeExpr.h"
|
||||
#include "exec/expression/CompareExpr.h"
|
||||
#include "exec/expression/ConjunctExpr.h"
|
||||
#include "exec/expression/ExistsExpr.h"
|
||||
#include "exec/expression/JsonContainsExpr.h"
|
||||
#include "exec/expression/LogicalBinaryExpr.h"
|
||||
#include "exec/expression/LogicalUnaryExpr.h"
|
||||
#include "exec/expression/TermExpr.h"
|
||||
#include "exec/expression/UnaryExpr.h"
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
void
|
||||
ExprSet::Eval(int32_t begin,
|
||||
int32_t end,
|
||||
bool initialize,
|
||||
EvalCtx& context,
|
||||
std::vector<VectorPtr>& results) {
|
||||
results.resize(exprs_.size());
|
||||
|
||||
for (size_t i = begin; i < end; ++i) {
|
||||
exprs_[i]->Eval(context, results[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ExprPtr>
|
||||
CompileExpressions(const std::vector<expr::TypedExprPtr>& sources,
|
||||
ExecContext* context,
|
||||
const std::unordered_set<std::string>& flatten_candidate,
|
||||
bool enable_constant_folding) {
|
||||
std::vector<std::shared_ptr<Expr>> exprs;
|
||||
exprs.reserve(sources.size());
|
||||
|
||||
for (auto& source : sources) {
|
||||
exprs.emplace_back(CompileExpression(source,
|
||||
context->get_query_context(),
|
||||
flatten_candidate,
|
||||
enable_constant_folding));
|
||||
}
|
||||
return exprs;
|
||||
}
|
||||
|
||||
static std::optional<std::string>
|
||||
ShouldFlatten(const expr::TypedExprPtr& expr,
|
||||
const std::unordered_set<std::string>& flat_candidates = {}) {
|
||||
if (auto call =
|
||||
std::dynamic_pointer_cast<const expr::LogicalBinaryExpr>(expr)) {
|
||||
if (call->op_type_ == expr::LogicalBinaryExpr::OpType::And ||
|
||||
call->op_type_ == expr::LogicalBinaryExpr::OpType::Or) {
|
||||
return call->name();
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static bool
|
||||
IsCall(const expr::TypedExprPtr& expr, const std::string& name) {
|
||||
if (auto call =
|
||||
std::dynamic_pointer_cast<const expr::LogicalBinaryExpr>(expr)) {
|
||||
return call->name() == name;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool
|
||||
AllInputTypeEqual(const expr::TypedExprPtr& expr) {
|
||||
const auto& inputs = expr->inputs();
|
||||
for (int i = 1; i < inputs.size(); i++) {
|
||||
if (inputs[0]->type() != inputs[i]->type()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static void
|
||||
FlattenInput(const expr::TypedExprPtr& input,
|
||||
const std::string& flatten_call,
|
||||
std::vector<expr::TypedExprPtr>& flat) {
|
||||
if (IsCall(input, flatten_call) && AllInputTypeEqual(input)) {
|
||||
for (auto& child : input->inputs()) {
|
||||
FlattenInput(child, flatten_call, flat);
|
||||
}
|
||||
} else {
|
||||
flat.emplace_back(input);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ExprPtr>
|
||||
CompileInputs(const expr::TypedExprPtr& expr,
|
||||
QueryContext* context,
|
||||
const std::unordered_set<std::string>& flatten_cadidates) {
|
||||
std::vector<ExprPtr> compiled_inputs;
|
||||
auto flatten = ShouldFlatten(expr);
|
||||
for (auto& input : expr->inputs()) {
|
||||
if (dynamic_cast<const expr::InputTypeExpr*>(input.get())) {
|
||||
AssertInfo(
|
||||
dynamic_cast<const expr::FieldAccessTypeExpr*>(expr.get()),
|
||||
"An InputReference can only occur under a FieldReference");
|
||||
} else {
|
||||
if (flatten.has_value()) {
|
||||
std::vector<expr::TypedExprPtr> flat_exprs;
|
||||
FlattenInput(input, flatten.value(), flat_exprs);
|
||||
for (auto& input : flat_exprs) {
|
||||
compiled_inputs.push_back(CompileExpression(
|
||||
input, context, flatten_cadidates, false));
|
||||
}
|
||||
} else {
|
||||
compiled_inputs.push_back(CompileExpression(
|
||||
input, context, flatten_cadidates, false));
|
||||
}
|
||||
}
|
||||
}
|
||||
return compiled_inputs;
|
||||
}
|
||||
|
||||
// Translates one logical (planner) expression node into its physical
// executor, after compiling its inputs.  Returns a default-constructed
// (null) ExprPtr for node kinds with no physical mapping yet (e.g. generic
// function calls).
ExprPtr
CompileExpression(const expr::TypedExprPtr& expr,
                  QueryContext* context,
                  const std::unordered_set<std::string>& flatten_candidates,
                  bool enable_constant_folding) {
    ExprPtr result;

    // NOTE(review): result_type and input_types are computed but not used
    // below — presumably kept for the upcoming function-registry dispatch;
    // confirm before deleting.
    auto result_type = expr->type();
    auto compiled_inputs = CompileInputs(expr, context, flatten_candidates);

    // Gather the result types of the already-compiled children.
    auto GetTypes = [](const std::vector<ExprPtr>& exprs) {
        std::vector<DataType> types;
        for (auto& expr : exprs) {
            types.push_back(expr->type());
        }
        return types;
    };
    auto input_types = GetTypes(compiled_inputs);

    // Dispatch on the concrete logical node type.  Each physical filter
    // receives the segment, the query timestamp, and the configured expr
    // batch size so it can iterate the segment batch-by-batch.
    if (auto call = dynamic_cast<const expr::CallTypeExpr*>(expr.get())) {
        // TODO: support function register and search mode
    } else if (auto casted_expr = std::dynamic_pointer_cast<
                   const milvus::expr::UnaryRangeFilterExpr>(expr)) {
        result = std::make_shared<PhyUnaryRangeFilterExpr>(
            compiled_inputs,
            casted_expr,
            "PhyUnaryRangeFilterExpr",
            context->get_segment(),
            context->get_query_timestamp(),
            context->query_config()->get_expr_batch_size());
    } else if (auto casted_expr = std::dynamic_pointer_cast<
                   const milvus::expr::LogicalUnaryExpr>(expr)) {
        result = std::make_shared<PhyLogicalUnaryExpr>(
            compiled_inputs, casted_expr, "PhyLogicalUnaryExpr");
    } else if (auto casted_expr = std::dynamic_pointer_cast<
                   const milvus::expr::TermFilterExpr>(expr)) {
        result = std::make_shared<PhyTermFilterExpr>(
            compiled_inputs,
            casted_expr,
            "PhyTermFilterExpr",
            context->get_segment(),
            context->get_query_timestamp(),
            context->query_config()->get_expr_batch_size());
    } else if (auto casted_expr = std::dynamic_pointer_cast<
                   const milvus::expr::LogicalBinaryExpr>(expr)) {
        // AND/OR become an n-ary conjunction node (inputs were flattened in
        // CompileInputs); other logical binaries stay binary.
        if (casted_expr->op_type_ ==
                milvus::expr::LogicalBinaryExpr::OpType::And ||
            casted_expr->op_type_ ==
                milvus::expr::LogicalBinaryExpr::OpType::Or) {
            result = std::make_shared<PhyConjunctFilterExpr>(
                std::move(compiled_inputs),
                casted_expr->op_type_ ==
                    milvus::expr::LogicalBinaryExpr::OpType::And);
        } else {
            result = std::make_shared<PhyLogicalBinaryExpr>(
                compiled_inputs, casted_expr, "PhyLogicalBinaryExpr");
        }
    } else if (auto casted_expr = std::dynamic_pointer_cast<
                   const milvus::expr::BinaryRangeFilterExpr>(expr)) {
        result = std::make_shared<PhyBinaryRangeFilterExpr>(
            compiled_inputs,
            casted_expr,
            "PhyBinaryRangeFilterExpr",
            context->get_segment(),
            context->get_query_timestamp(),
            context->query_config()->get_expr_batch_size());
    } else if (auto casted_expr = std::dynamic_pointer_cast<
                   const milvus::expr::AlwaysTrueExpr>(expr)) {
        result = std::make_shared<PhyAlwaysTrueExpr>(
            compiled_inputs,
            casted_expr,
            "PhyAlwaysTrueExpr",
            context->get_segment(),
            context->get_query_timestamp(),
            context->query_config()->get_expr_batch_size());
    } else if (auto casted_expr = std::dynamic_pointer_cast<
                   const milvus::expr::BinaryArithOpEvalRangeExpr>(expr)) {
        result = std::make_shared<PhyBinaryArithOpEvalRangeExpr>(
            compiled_inputs,
            casted_expr,
            "PhyBinaryArithOpEvalRangeExpr",
            context->get_segment(),
            context->get_query_timestamp(),
            context->query_config()->get_expr_batch_size());
    } else if (auto casted_expr =
                   std::dynamic_pointer_cast<const milvus::expr::CompareExpr>(
                       expr)) {
        result = std::make_shared<PhyCompareFilterExpr>(
            compiled_inputs,
            casted_expr,
            "PhyCompareFilterExpr",
            context->get_segment(),
            context->get_query_timestamp(),
            context->query_config()->get_expr_batch_size());
    } else if (auto casted_expr =
                   std::dynamic_pointer_cast<const milvus::expr::ExistsExpr>(
                       expr)) {
        result = std::make_shared<PhyExistsFilterExpr>(
            compiled_inputs,
            casted_expr,
            "PhyExistsFilterExpr",
            context->get_segment(),
            context->get_query_timestamp(),
            context->query_config()->get_expr_batch_size());
    } else if (auto casted_expr = std::dynamic_pointer_cast<
                   const milvus::expr::JsonContainsExpr>(expr)) {
        result = std::make_shared<PhyJsonContainsFilterExpr>(
            compiled_inputs,
            casted_expr,
            "PhyJsonContainsFilterExpr",
            context->get_segment(),
            context->get_query_timestamp(),
            context->query_config()->get_expr_batch_size());
    }
    return result;
}
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,324 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "common/Types.h"
|
||||
#include "exec/expression/EvalCtx.h"
|
||||
#include "exec/expression/VectorFunction.h"
|
||||
#include "exec/expression/Utils.h"
|
||||
#include "exec/QueryContext.h"
|
||||
#include "expr/ITypeExpr.h"
|
||||
#include "query/PlanProto.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Base class for every physical (executable) expression node.  Stores the
// result data type, the compiled child expressions, a display name, and an
// optional vectorized function implementation.
class Expr {
 public:
    Expr(DataType type,
         const std::vector<std::shared_ptr<Expr>>&& inputs,
         const std::string& name)
        : type_(type),
          inputs_(std::move(inputs)),
          name_(name),
          vector_func_(nullptr) {
    }

    // NOTE(review): `inputs` is a const rvalue reference, so std::move in
    // the initializer list still copies — consider a by-value parameter.
    Expr(DataType type,
         const std::vector<std::shared_ptr<Expr>>&& inputs,
         std::shared_ptr<VectorFunction> vec_func,
         const std::string& name)
        : type_(type),
          inputs_(std::move(inputs)),
          name_(name),
          vector_func_(vec_func) {
    }
    virtual ~Expr() = default;

    // Result data type of this expression.
    const DataType&
    type() const {
        return type_;
    }

    std::string
    get_name() {
        return name_;
    }

    // Evaluates one batch into `result`; the default implementation is a
    // no-op so leaf/auxiliary nodes may omit an override.
    virtual void
    Eval(EvalCtx& context, VectorPtr& result) {
    }

 protected:
    DataType type_;
    const std::vector<std::shared_ptr<Expr>> inputs_;
    std::string name_;
    std::shared_ptr<VectorFunction> vector_func_;
};
|
||||
|
||||
using ExprPtr = std::shared_ptr<milvus::exec::Expr>;
|
||||
|
||||
using SkipFunc = bool (*)(const milvus::SkipIndex&, FieldId, int);
|
||||
|
||||
class SegmentExpr : public Expr {
|
||||
public:
|
||||
SegmentExpr(const std::vector<ExprPtr>&& input,
|
||||
const std::string& name,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const FieldId& field_id,
|
||||
Timestamp query_timestamp,
|
||||
int64_t batch_size)
|
||||
: Expr(DataType::BOOL, std::move(input), name),
|
||||
segment_(segment),
|
||||
field_id_(field_id),
|
||||
query_timestamp_(query_timestamp),
|
||||
batch_size_(batch_size) {
|
||||
num_rows_ = segment_->get_active_count(query_timestamp_);
|
||||
size_per_chunk_ = segment_->size_per_chunk();
|
||||
AssertInfo(
|
||||
batch_size_ > 0,
|
||||
fmt::format("expr batch size should greater than zero, but now: {}",
|
||||
batch_size_));
|
||||
InitSegmentExpr();
|
||||
}
|
||||
|
||||
void
|
||||
InitSegmentExpr() {
|
||||
auto& schema = segment_->get_schema();
|
||||
auto& field_meta = schema[field_id_];
|
||||
|
||||
if (schema.get_primary_field_id().has_value() &&
|
||||
schema.get_primary_field_id().value() == field_id_ &&
|
||||
IsPrimaryKeyDataType(field_meta.get_data_type())) {
|
||||
is_pk_field_ = true;
|
||||
pk_type_ = field_meta.get_data_type();
|
||||
}
|
||||
|
||||
is_index_mode_ = segment_->HasIndex(field_id_);
|
||||
if (is_index_mode_) {
|
||||
num_index_chunk_ = segment_->num_chunk_index(field_id_);
|
||||
} else {
|
||||
num_data_chunk_ = segment_->num_chunk_data(field_id_);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t
|
||||
GetNextBatchSize() {
|
||||
auto current_chunk =
|
||||
is_index_mode_ ? current_index_chunk_ : current_data_chunk_;
|
||||
auto current_chunk_pos =
|
||||
is_index_mode_ ? current_index_chunk_pos_ : current_data_chunk_pos_;
|
||||
auto current_rows = current_chunk * size_per_chunk_ + current_chunk_pos;
|
||||
return current_rows + batch_size_ >= num_rows_
|
||||
? num_rows_ - current_rows
|
||||
: batch_size_;
|
||||
}
|
||||
|
||||
template <typename T, typename FUNC, typename... ValTypes>
|
||||
int64_t
|
||||
ProcessDataChunks(
|
||||
FUNC func,
|
||||
std::function<bool(const milvus::SkipIndex&, FieldId, int)> skip_func,
|
||||
bool* res,
|
||||
ValTypes... values) {
|
||||
int64_t processed_size = 0;
|
||||
|
||||
for (size_t i = current_data_chunk_; i < num_data_chunk_; i++) {
|
||||
auto data_pos =
|
||||
(i == current_data_chunk_) ? current_data_chunk_pos_ : 0;
|
||||
auto size = (i == (num_data_chunk_ - 1))
|
||||
? (segment_->type() == SegmentType::Growing
|
||||
? num_rows_ % size_per_chunk_ - data_pos
|
||||
: num_rows_ - data_pos)
|
||||
: size_per_chunk_ - data_pos;
|
||||
|
||||
size = std::min(size, batch_size_ - processed_size);
|
||||
|
||||
auto& skip_index = segment_->GetSkipIndex();
|
||||
if (!skip_func || !skip_func(skip_index, field_id_, i)) {
|
||||
auto chunk = segment_->chunk_data<T>(field_id_, i);
|
||||
const T* data = chunk.data() + data_pos;
|
||||
func(data, size, res + processed_size, values...);
|
||||
}
|
||||
|
||||
processed_size += size;
|
||||
if (processed_size >= batch_size_) {
|
||||
current_data_chunk_ = i;
|
||||
current_data_chunk_pos_ = data_pos + size;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return processed_size;
|
||||
}
|
||||
|
||||
int
|
||||
ProcessIndexOneChunk(FixedVector<bool>& result,
|
||||
size_t chunk_id,
|
||||
const FixedVector<bool>& chunk_res,
|
||||
int processed_rows) {
|
||||
auto data_pos =
|
||||
chunk_id == current_index_chunk_ ? current_index_chunk_pos_ : 0;
|
||||
auto size = std::min(
|
||||
std::min(size_per_chunk_ - data_pos, batch_size_ - processed_rows),
|
||||
int64_t(chunk_res.size()));
|
||||
|
||||
result.insert(result.end(),
|
||||
chunk_res.begin() + data_pos,
|
||||
chunk_res.begin() + data_pos + size);
|
||||
return size;
|
||||
}
|
||||
|
||||
template <typename T, typename FUNC, typename... ValTypes>
|
||||
FixedVector<bool>
|
||||
ProcessIndexChunks(FUNC func, ValTypes... values) {
|
||||
typedef std::
|
||||
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
|
||||
IndexInnerType;
|
||||
using Index = index::ScalarIndex<IndexInnerType>;
|
||||
FixedVector<bool> result;
|
||||
int processed_rows = 0;
|
||||
|
||||
for (size_t i = current_index_chunk_; i < num_index_chunk_; i++) {
|
||||
// This cache result help getting result for every batch loop.
|
||||
// It avoids indexing execute for evevy batch because indexing
|
||||
// executing costs quite much time.
|
||||
if (cached_index_chunk_id_ != i) {
|
||||
const Index& index =
|
||||
segment_->chunk_scalar_index<IndexInnerType>(field_id_, i);
|
||||
auto* index_ptr = const_cast<Index*>(&index);
|
||||
cached_index_chunk_res_ = std::move(func(index_ptr, values...));
|
||||
cached_index_chunk_id_ = i;
|
||||
}
|
||||
|
||||
auto size = ProcessIndexOneChunk(
|
||||
result, i, cached_index_chunk_res_, processed_rows);
|
||||
|
||||
if (processed_rows + size >= batch_size_) {
|
||||
current_index_chunk_ = i;
|
||||
current_index_chunk_pos_ = i == current_index_chunk_
|
||||
? current_index_chunk_pos_ + size
|
||||
: size;
|
||||
break;
|
||||
}
|
||||
processed_rows += size;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
protected:
|
||||
const segcore::SegmentInternalInterface* segment_;
|
||||
const FieldId field_id_;
|
||||
bool is_pk_field_{false};
|
||||
DataType pk_type_;
|
||||
Timestamp query_timestamp_;
|
||||
int64_t batch_size_;
|
||||
|
||||
// State indicate position that expr computing at
|
||||
// because expr maybe called for every batch.
|
||||
bool is_index_mode_{false};
|
||||
bool is_data_mode_{false};
|
||||
|
||||
int64_t num_rows_{0};
|
||||
int64_t num_data_chunk_{0};
|
||||
int64_t num_index_chunk_{0};
|
||||
int64_t current_data_chunk_{0};
|
||||
int64_t current_data_chunk_pos_{0};
|
||||
int64_t current_index_chunk_{0};
|
||||
int64_t current_index_chunk_pos_{0};
|
||||
int64_t size_per_chunk_{0};
|
||||
|
||||
// Cache for index scan to avoid search index every batch
|
||||
int64_t cached_index_chunk_id_{-1};
|
||||
FixedVector<bool> cached_index_chunk_res_{};
|
||||
};
|
||||
|
||||
// Compiles each logical expression in `logical_exprs` into a physical one.
std::vector<ExprPtr>
CompileExpressions(const std::vector<expr::TypedExprPtr>& logical_exprs,
                   ExecContext* context,
                   const std::unordered_set<std::string>& flatten_cadidates =
                       std::unordered_set<std::string>(),
                   bool enable_constant_folding = false);

// Compiles the children of `expr`, flattening same-typed AND/OR chains into
// one operand list.
std::vector<ExprPtr>
CompileInputs(const expr::TypedExprPtr& expr,
              QueryContext* config,
              const std::unordered_set<std::string>& flatten_cadidates);

// Compiles a single logical expression node into its physical executor.
ExprPtr
CompileExpression(const expr::TypedExprPtr& expr,
                  QueryContext* context,
                  const std::unordered_set<std::string>& flatten_cadidates,
                  bool enable_constant_folding);
|
||||
|
||||
class ExprSet {
|
||||
public:
|
||||
explicit ExprSet(const std::vector<expr::TypedExprPtr>& logical_exprs,
|
||||
ExecContext* exec_ctx) {
|
||||
exprs_ = CompileExpressions(logical_exprs, exec_ctx);
|
||||
}
|
||||
|
||||
virtual ~ExprSet() = default;
|
||||
|
||||
void
|
||||
Eval(EvalCtx& ctx, std::vector<VectorPtr>& results) {
|
||||
Eval(0, exprs_.size(), true, ctx, results);
|
||||
}
|
||||
|
||||
virtual void
|
||||
Eval(int32_t begin,
|
||||
int32_t end,
|
||||
bool initialize,
|
||||
EvalCtx& ctx,
|
||||
std::vector<VectorPtr>& result);
|
||||
|
||||
void
|
||||
Clear() {
|
||||
exprs_.clear();
|
||||
}
|
||||
|
||||
ExecContext*
|
||||
get_exec_context() const {
|
||||
return exec_ctx_;
|
||||
}
|
||||
|
||||
size_t
|
||||
size() const {
|
||||
return exprs_.size();
|
||||
}
|
||||
|
||||
const std::vector<std::shared_ptr<Expr>>&
|
||||
exprs() const {
|
||||
return exprs_;
|
||||
}
|
||||
|
||||
const std::shared_ptr<Expr>&
|
||||
expr(int32_t index) const {
|
||||
return exprs_[index];
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<Expr>> exprs_;
|
||||
ExecContext* exec_ctx_;
|
||||
};
|
||||
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,740 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "JsonContainsExpr.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
void
|
||||
PhyJsonContainsFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
|
||||
switch (expr_->column_.data_type_) {
|
||||
case DataType::ARRAY:
|
||||
case DataType::JSON: {
|
||||
if (is_index_mode_) {
|
||||
PanicInfo(
|
||||
ExprInvalid,
|
||||
"exists expr for json or array index mode not supportted");
|
||||
}
|
||||
result = EvalJsonContainsForDataSegment();
|
||||
break;
|
||||
}
|
||||
default:
|
||||
PanicInfo(DataTypeInvalid,
|
||||
fmt::format("unsupported data type: {}",
|
||||
expr_->column_.data_type_));
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatches the json-contains operator to a typed execution routine based
// on (a) the operator (Contains/ContainsAny vs ContainsAll), (b) the field
// type (ARRAY vs JSON), and (c) the value type of the first literal.
// NOTE(review): expr_->vals_ is indexed without an emptiness check —
// presumably the planner guarantees at least one candidate value; confirm.
VectorPtr
PhyJsonContainsFilterExpr::EvalJsonContainsForDataSegment() {
    auto data_type = expr_->column_.data_type_;
    switch (expr_->op_) {
        // Contains and ContainsAny share semantics: a row matches if ANY
        // candidate value is present.
        case proto::plan::JSONContainsExpr_JSONOp_Contains:
        case proto::plan::JSONContainsExpr_JSONOp_ContainsAny: {
            if (datatype_is_array(data_type)) {
                auto val_type = expr_->vals_[0].val_case();
                switch (val_type) {
                    case proto::plan::GenericValue::kBoolVal: {
                        return ExecArrayContains<bool>();
                    }
                    case proto::plan::GenericValue::kInt64Val: {
                        return ExecArrayContains<int64_t>();
                    }
                    case proto::plan::GenericValue::kFloatVal: {
                        return ExecArrayContains<double>();
                    }
                    case proto::plan::GenericValue::kStringVal: {
                        return ExecArrayContains<std::string>();
                    }
                    default:
                        PanicInfo(
                            DataTypeInvalid,
                            fmt::format("unsupported data type {}", val_type));
                }
            } else {
                // JSON field: a uniform candidate type allows a typed fast
                // path; mixed types fall back to the generic routine.
                if (expr_->same_type_) {
                    auto val_type = expr_->vals_[0].val_case();
                    switch (val_type) {
                        case proto::plan::GenericValue::kBoolVal: {
                            return ExecJsonContains<bool>();
                        }
                        case proto::plan::GenericValue::kInt64Val: {
                            return ExecJsonContains<int64_t>();
                        }
                        case proto::plan::GenericValue::kFloatVal: {
                            return ExecJsonContains<double>();
                        }
                        case proto::plan::GenericValue::kStringVal: {
                            return ExecJsonContains<std::string>();
                        }
                        case proto::plan::GenericValue::kArrayVal: {
                            return ExecJsonContainsArray();
                        }
                        default:
                            PanicInfo(DataTypeInvalid,
                                      fmt::format("unsupported data type:{}",
                                                  val_type));
                    }
                } else {
                    return ExecJsonContainsWithDiffType();
                }
            }
            break;
        }
        case proto::plan::JSONContainsExpr_JSONOp_ContainsAll: {
            if (datatype_is_array(data_type)) {
                auto val_type = expr_->vals_[0].val_case();
                switch (val_type) {
                    case proto::plan::GenericValue::kBoolVal: {
                        return ExecArrayContainsAll<bool>();
                    }
                    case proto::plan::GenericValue::kInt64Val: {
                        return ExecArrayContainsAll<int64_t>();
                    }
                    case proto::plan::GenericValue::kFloatVal: {
                        return ExecArrayContainsAll<double>();
                    }
                    case proto::plan::GenericValue::kStringVal: {
                        return ExecArrayContainsAll<std::string>();
                    }
                    default:
                        PanicInfo(
                            DataTypeInvalid,
                            fmt::format("unsupported data type {}", val_type));
                }
            } else {
                if (expr_->same_type_) {
                    auto val_type = expr_->vals_[0].val_case();
                    switch (val_type) {
                        case proto::plan::GenericValue::kBoolVal: {
                            return ExecJsonContainsAll<bool>();
                        }
                        case proto::plan::GenericValue::kInt64Val: {
                            return ExecJsonContainsAll<int64_t>();
                        }
                        case proto::plan::GenericValue::kFloatVal: {
                            return ExecJsonContainsAll<double>();
                        }
                        case proto::plan::GenericValue::kStringVal: {
                            return ExecJsonContainsAll<std::string>();
                        }
                        case proto::plan::GenericValue::kArrayVal: {
                            return ExecJsonContainsAllArray();
                        }
                        default:
                            PanicInfo(DataTypeInvalid,
                                      fmt::format("unsupported data type:{}",
                                                  val_type));
                    }
                } else {
                    return ExecJsonContainsAllWithDiffType();
                }
            }
            break;
        }
        default:
            PanicInfo(ExprInvalid,
                      fmt::format("unsupported json contains type {}",
                                  proto::plan::JSONContainsExpr_JSONOp_Name(
                                      expr_->op_)));
    }
    // NOTE(review): every case either returns or panics, so the trailing
    // `break`s are unreachable and control never falls off the end.
}
|
||||
|
||||
template <typename ExprValueType>
|
||||
VectorPtr
|
||||
PhyJsonContainsFilterExpr::ExecArrayContains() {
|
||||
using GetType =
|
||||
std::conditional_t<std::is_same_v<ExprValueType, std::string>,
|
||||
std::string_view,
|
||||
ExprValueType>;
|
||||
auto real_batch_size = GetNextBatchSize();
|
||||
if (real_batch_size == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
AssertInfo(expr_->column_.nested_path_.size() == 0,
|
||||
"[ExecArrayContains]nested path must be null");
|
||||
|
||||
auto res_vec =
|
||||
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
|
||||
bool* res = (bool*)res_vec->GetRawData();
|
||||
std::unordered_set<GetType> elements;
|
||||
for (auto const& element : expr_->vals_) {
|
||||
elements.insert(GetValueFromProto<GetType>(element));
|
||||
}
|
||||
auto execute_sub_batch = [](const milvus::ArrayView* data,
|
||||
const int size,
|
||||
bool* res,
|
||||
const std::unordered_set<GetType>& elements) {
|
||||
auto executor = [&](size_t i) {
|
||||
const auto& array = data[i];
|
||||
for (int j = 0; j < array.length(); ++j) {
|
||||
if (elements.count(array.template get_data<GetType>(j)) > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
for (int i = 0; i < size; ++i) {
|
||||
res[i] = executor(i);
|
||||
}
|
||||
};
|
||||
|
||||
int64_t processed_size = ProcessDataChunks<milvus::ArrayView>(
|
||||
execute_sub_batch, std::nullptr_t{}, res, elements);
|
||||
AssertInfo(processed_size == real_batch_size,
|
||||
fmt::format("internal error: expr processed rows {} not equal "
|
||||
"expect batch size {}",
|
||||
processed_size,
|
||||
real_batch_size));
|
||||
return res_vec;
|
||||
}
|
||||
|
||||
// JSON-field contains-any: resolves the json pointer to an array per row and
// matches when any element equals any candidate value of type GetType.
// NOTE(review): for string candidates `elements` holds string_views whose
// backing storage is the proto values inside expr_->vals_ — assumes expr_
// outlives this call; confirm.
template <typename ExprValueType>
VectorPtr
PhyJsonContainsFilterExpr::ExecJsonContains() {
    using GetType =
        std::conditional_t<std::is_same_v<ExprValueType, std::string>,
                           std::string_view,
                           ExprValueType>;
    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;
    }

    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();
    std::unordered_set<GetType> elements;
    auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
    for (auto const& element : expr_->vals_) {
        elements.insert(GetValueFromProto<GetType>(element));
    }
    auto execute_sub_batch = [](const milvus::Json* data,
                                const int size,
                                bool* res,
                                const std::string& pointer,
                                const std::unordered_set<GetType>& elements) {
        auto executor = [&](size_t i) {
            auto doc = data[i].doc();
            auto array = doc.at_pointer(pointer).get_array();
            // Pointer missing or not an array -> row does not match.
            if (array.error()) {
                return false;
            }
            for (auto&& it : array) {
                auto val = it.template get<GetType>();
                // Elements of a different type are skipped, not errors.
                if (val.error()) {
                    continue;
                }
                if (elements.count(val.value()) > 0) {
                    return true;
                }
            }
            return false;
        };
        for (size_t i = 0; i < size; ++i) {
            res[i] = executor(i);
        }
    };

    int64_t processed_size = ProcessDataChunks<Json>(
        execute_sub_batch, std::nullptr_t{}, res, pointer, elements);
    AssertInfo(processed_size == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           processed_size,
                           real_batch_size));
    return res_vec;
}
|
||||
|
||||
// JSON-field contains-any where the candidates are themselves arrays: a row
// matches when any sub-array at the json pointer compares equal (via
// CompareTwoJsonArray) to any candidate proto array.
VectorPtr
PhyJsonContainsFilterExpr::ExecJsonContainsArray() {
    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;
    }
    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();
    auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
    std::vector<proto::plan::Array> elements;
    for (auto const& element : expr_->vals_) {
        elements.emplace_back(GetValueFromProto<proto::plan::Array>(element));
    }
    auto execute_sub_batch =
        [](const milvus::Json* data,
           const int size,
           bool* res,
           const std::string& pointer,
           const std::vector<proto::plan::Array>& elements) {
            auto executor = [&](size_t i) -> bool {
                auto doc = data[i].doc();
                auto array = doc.at_pointer(pointer).get_array();
                if (array.error()) {
                    return false;
                }
                for (auto&& it : array) {
                    auto val = it.get_array();
                    // Non-array elements cannot match an array candidate.
                    if (val.error()) {
                        continue;
                    }
                    // simdjson on-demand arrays are single-pass, so the
                    // sub-array is materialized once and compared against
                    // every candidate.
                    std::vector<
                        simdjson::simdjson_result<simdjson::ondemand::value>>
                        json_array;
                    json_array.reserve(val.count_elements());
                    for (auto&& e : val) {
                        json_array.emplace_back(e);
                    }
                    for (auto const& element : elements) {
                        if (CompareTwoJsonArray(json_array, element)) {
                            return true;
                        }
                    }
                }
                return false;
            };
            for (size_t i = 0; i < size; ++i) {
                res[i] = executor(i);
            }
        };

    int64_t processed_size = ProcessDataChunks<milvus::Json>(
        execute_sub_batch, std::nullptr_t{}, res, pointer, elements);
    AssertInfo(processed_size == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           processed_size,
                           real_batch_size));
    return res_vec;
}
|
||||
|
||||
template <typename ExprValueType>
|
||||
VectorPtr
|
||||
PhyJsonContainsFilterExpr::ExecArrayContainsAll() {
|
||||
using GetType =
|
||||
std::conditional_t<std::is_same_v<ExprValueType, std::string>,
|
||||
std::string_view,
|
||||
ExprValueType>;
|
||||
AssertInfo(expr_->column_.nested_path_.size() == 0,
|
||||
"[ExecArrayContainsAll]nested path must be null");
|
||||
auto real_batch_size = GetNextBatchSize();
|
||||
if (real_batch_size == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
auto res_vec =
|
||||
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
|
||||
bool* res = (bool*)res_vec->GetRawData();
|
||||
|
||||
std::unordered_set<GetType> elements;
|
||||
for (auto const& element : expr_->vals_) {
|
||||
elements.insert(GetValueFromProto<GetType>(element));
|
||||
}
|
||||
|
||||
auto execute_sub_batch = [](const milvus::ArrayView* data,
|
||||
const int size,
|
||||
bool* res,
|
||||
const std::unordered_set<GetType>& elements) {
|
||||
auto executor = [&](size_t i) {
|
||||
std::unordered_set<GetType> tmp_elements(elements);
|
||||
// Note: array can only be iterated once
|
||||
for (int j = 0; j < data[i].length(); ++j) {
|
||||
tmp_elements.erase(data[i].template get_data<GetType>(j));
|
||||
if (tmp_elements.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return tmp_elements.size() == 0;
|
||||
};
|
||||
for (int i = 0; i < size; ++i) {
|
||||
res[i] = executor(i);
|
||||
}
|
||||
};
|
||||
|
||||
int64_t processed_size = ProcessDataChunks<milvus::ArrayView>(
|
||||
execute_sub_batch, std::nullptr_t{}, res, elements);
|
||||
AssertInfo(processed_size == real_batch_size,
|
||||
fmt::format("internal error: expr processed rows {} not equal "
|
||||
"expect batch size {}",
|
||||
processed_size,
|
||||
real_batch_size));
|
||||
return res_vec;
|
||||
}
|
||||
|
||||
// Evaluates json_contains_all(field, vals) with a scalar element type:
// a row matches when the JSON array at the column's nested path contains
// every value in expr_->vals_.
template <typename ExprValueType>
VectorPtr
PhyJsonContainsFilterExpr::ExecJsonContainsAll() {
    // Strings are read from simdjson as string_view to avoid copies.
    using GetType =
        std::conditional_t<std::is_same_v<ExprValueType, std::string>,
                           std::string_view,
                           ExprValueType>;
    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        // No rows left in the current batch window.
        return nullptr;
    }
    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();
    auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
    // Distinct target values that must all appear in a row's array.
    std::unordered_set<GetType> elements;
    for (auto const& element : expr_->vals_) {
        elements.insert(GetValueFromProto<GetType>(element));
    }

    auto execute_sub_batch = [](const milvus::Json* data,
                                const int size,
                                bool* res,
                                const std::string& pointer,
                                const std::unordered_set<GetType>& elements) {
        // Per-row predicate: drain a copy of the target set while walking
        // the JSON array; matched once the set is empty.
        auto executor = [&](const size_t i) -> bool {
            auto doc = data[i].doc();
            auto array = doc.at_pointer(pointer).get_array();
            if (array.error()) {
                // Path missing or not an array -> no match.
                return false;
            }
            std::unordered_set<GetType> tmp_elements(elements);
            // Note: array can only be iterated once
            for (auto&& it : array) {
                auto val = it.template get<GetType>();
                if (val.error()) {
                    // Element has a different JSON type; skip it.
                    continue;
                }
                tmp_elements.erase(val.value());
                if (tmp_elements.size() == 0) {
                    return true;
                }
            }
            return tmp_elements.size() == 0;
        };
        for (size_t i = 0; i < size; ++i) {
            res[i] = executor(i);
        }
    };

    int64_t processed_size = ProcessDataChunks<Json>(
        execute_sub_batch, std::nullptr_t{}, res, pointer, elements);
    AssertInfo(processed_size == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           processed_size,
                           real_batch_size));
    return res_vec;
}
|
||||
|
||||
// Evaluates json_contains_all where expr_->vals_ holds heterogeneous
// GenericValue targets: a row matches when every target value (compared
// with its own declared type) appears in the JSON array at the column's
// nested path.
VectorPtr
PhyJsonContainsFilterExpr::ExecJsonContainsAllWithDiffType() {
    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;
    }
    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();
    auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);

    auto elements = expr_->vals_;
    // Track targets by positional index; a row matches once every index has
    // been erased. (Previously built with a loop whose element variable was
    // unused; a plain index loop states the intent directly.)
    std::unordered_set<int> elements_index;
    for (int i = 0; i < static_cast<int>(elements.size()); ++i) {
        elements_index.insert(i);
    }

    auto execute_sub_batch =
        [](const milvus::Json* data,
           const int size,
           bool* res,
           const std::string& pointer,
           const std::vector<proto::plan::GenericValue>& elements,
           // FIX: pass by const reference; the set used to be copied on
           // every sub-batch invocation.
           const std::unordered_set<int>& elements_index) {
            auto executor = [&](size_t i) -> bool {
                const auto& json = data[i];
                auto doc = json.doc();
                auto array = doc.at_pointer(pointer).get_array();
                if (array.error()) {
                    return false;
                }
                // Remaining (not yet matched) target indices for this row.
                std::unordered_set<int> tmp_elements_index(elements_index);
                for (auto&& it : array) {
                    // FIX: `index` no longer shadows the row index `i`.
                    int index = -1;
                    for (auto& element : elements) {
                        index++;
                        switch (element.val_case()) {
                            case proto::plan::GenericValue::kBoolVal: {
                                auto val = it.template get<bool>();
                                if (val.error()) {
                                    continue;
                                }
                                if (val.value() == element.bool_val()) {
                                    tmp_elements_index.erase(index);
                                }
                                break;
                            }
                            case proto::plan::GenericValue::kInt64Val: {
                                auto val = it.template get<int64_t>();
                                if (val.error()) {
                                    continue;
                                }
                                if (val.value() == element.int64_val()) {
                                    tmp_elements_index.erase(index);
                                }
                                break;
                            }
                            case proto::plan::GenericValue::kFloatVal: {
                                auto val = it.template get<double>();
                                if (val.error()) {
                                    continue;
                                }
                                if (val.value() == element.float_val()) {
                                    tmp_elements_index.erase(index);
                                }
                                break;
                            }
                            case proto::plan::GenericValue::kStringVal: {
                                auto val = it.template get<std::string_view>();
                                if (val.error()) {
                                    continue;
                                }
                                if (val.value() == element.string_val()) {
                                    tmp_elements_index.erase(index);
                                }
                                break;
                            }
                            case proto::plan::GenericValue::kArrayVal: {
                                auto val = it.get_array();
                                if (val.error()) {
                                    continue;
                                }
                                if (CompareTwoJsonArray(val,
                                                        element.array_val())) {
                                    tmp_elements_index.erase(index);
                                }
                                break;
                            }
                            default:
                                PanicInfo(
                                    DataTypeInvalid,
                                    fmt::format("unsupported data type {}",
                                                element.val_case()));
                        }
                        if (tmp_elements_index.size() == 0) {
                            return true;
                        }
                    }
                    if (tmp_elements_index.size() == 0) {
                        return true;
                    }
                }
                return tmp_elements_index.size() == 0;
            };
            for (size_t i = 0; i < size; ++i) {
                res[i] = executor(i);
            }
        };

    int64_t processed_size = ProcessDataChunks<Json>(execute_sub_batch,
                                                     std::nullptr_t{},
                                                     res,
                                                     pointer,
                                                     elements,
                                                     elements_index);
    AssertInfo(processed_size == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           processed_size,
                           real_batch_size));
    return res_vec;
}
|
||||
|
||||
// Evaluates json_contains_all where each target value is itself an array:
// a row matches when every target array appears (element-wise equal) among
// the inner arrays of the JSON array at the column's nested path.
VectorPtr
PhyJsonContainsFilterExpr::ExecJsonContainsAllArray() {
    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;
    }
    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();
    auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);

    // Materialize all target arrays from the plan protos once per batch.
    std::vector<proto::plan::Array> elements;
    for (auto const& element : expr_->vals_) {
        elements.emplace_back(GetValueFromProto<proto::plan::Array>(element));
    }
    auto execute_sub_batch =
        [](const milvus::Json* data,
           const int size,
           bool* res,
           const std::string& pointer,
           const std::vector<proto::plan::Array>& elements) {
            // Per-row predicate: collect indices of target arrays found in
            // this row; matched once all indices were seen.
            auto executor = [&](const size_t i) {
                auto doc = data[i].doc();
                auto array = doc.at_pointer(pointer).get_array();
                if (array.error()) {
                    // Path missing or not an array -> no match.
                    return false;
                }
                std::unordered_set<int> exist_elements_index;
                for (auto&& it : array) {
                    auto val = it.get_array();
                    if (val.error()) {
                        // Skip elements that are not arrays themselves.
                        continue;
                    }
                    // Buffer the inner array: simdjson on-demand values can
                    // only be traversed a single time, but each target must
                    // be compared against it.
                    std::vector<
                        simdjson::simdjson_result<simdjson::ondemand::value>>
                        json_array;
                    json_array.reserve(val.count_elements());
                    for (auto&& e : val) {
                        json_array.emplace_back(e);
                    }
                    for (int index = 0; index < elements.size(); ++index) {
                        if (CompareTwoJsonArray(json_array, elements[index])) {
                            exist_elements_index.insert(index);
                        }
                    }
                    if (exist_elements_index.size() == elements.size()) {
                        return true;
                    }
                }
                return exist_elements_index.size() == elements.size();
            };
            for (size_t i = 0; i < size; ++i) {
                res[i] = executor(i);
            }
        };

    int64_t processed_size = ProcessDataChunks<Json>(
        execute_sub_batch, std::nullptr_t{}, res, pointer, elements);
    AssertInfo(processed_size == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           processed_size,
                           real_batch_size));
    return res_vec;
}
|
||||
|
||||
// Evaluates json_contains (ANY semantics) with heterogeneous GenericValue
// targets: a row matches when at least one target value (compared with its
// own declared type) appears in the JSON array at the column's nested path.
VectorPtr
PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() {
    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;
    }
    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();
    auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);

    auto elements = expr_->vals_;
    // FIX: the previous revision also built an `elements_index` set here,
    // but it was never passed to the kernel; that dead code is removed.

    auto execute_sub_batch =
        [](const milvus::Json* data,
           const int size,
           bool* res,
           const std::string& pointer,
           const std::vector<proto::plan::GenericValue>& elements) {
            // Per-row predicate: true on the first element equal to any
            // target value of matching type.
            auto executor = [&](const size_t i) {
                auto& json = data[i];
                auto doc = json.doc();
                auto array = doc.at_pointer(pointer).get_array();
                if (array.error()) {
                    return false;
                }
                // Note: array can only be iterated once
                for (auto&& it : array) {
                    for (auto const& element : elements) {
                        switch (element.val_case()) {
                            case proto::plan::GenericValue::kBoolVal: {
                                auto val = it.template get<bool>();
                                if (val.error()) {
                                    continue;
                                }
                                if (val.value() == element.bool_val()) {
                                    return true;
                                }
                                break;
                            }
                            case proto::plan::GenericValue::kInt64Val: {
                                auto val = it.template get<int64_t>();
                                if (val.error()) {
                                    continue;
                                }
                                if (val.value() == element.int64_val()) {
                                    return true;
                                }
                                break;
                            }
                            case proto::plan::GenericValue::kFloatVal: {
                                auto val = it.template get<double>();
                                if (val.error()) {
                                    continue;
                                }
                                if (val.value() == element.float_val()) {
                                    return true;
                                }
                                break;
                            }
                            case proto::plan::GenericValue::kStringVal: {
                                auto val = it.template get<std::string_view>();
                                if (val.error()) {
                                    continue;
                                }
                                if (val.value() == element.string_val()) {
                                    return true;
                                }
                                break;
                            }
                            case proto::plan::GenericValue::kArrayVal: {
                                auto val = it.get_array();
                                if (val.error()) {
                                    continue;
                                }
                                if (CompareTwoJsonArray(val,
                                                        element.array_val())) {
                                    return true;
                                }
                                break;
                            }
                            default:
                                PanicInfo(
                                    DataTypeInvalid,
                                    fmt::format("unsupported data type {}",
                                                element.val_case()));
                        }
                    }
                }
                return false;
            };
            for (size_t i = 0; i < size; ++i) {
                res[i] = executor(i);
            }
        };

    int64_t processed_size = ProcessDataChunks<Json>(
        execute_sub_batch, std::nullptr_t{}, res, pointer, elements);
    AssertInfo(processed_size == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           processed_size,
                           real_batch_size));
    return res_vec;
}
|
||||
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,87 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fmt/core.h>
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Vector.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Physical operator for json_contains / json_contains_all / array_contains
// filters. Dispatches per column type and per target-value type to one of
// the Exec* implementations below, producing a BOOL column per batch.
class PhyJsonContainsFilterExpr : public SegmentExpr {
 public:
    // `input` — child expressions (typically empty for a leaf filter);
    // `expr` — the logical JsonContainsExpr plan node being executed;
    // `segment`/`query_timestamp`/`batch_size` configure batched iteration
    // over segment data via the SegmentExpr base.
    // NOTE(review): std::move on a const& parameter is a copy — confirm
    // the SegmentExpr ctor signature before relying on the move.
    PhyJsonContainsFilterExpr(
        const std::vector<std::shared_ptr<Expr>>& input,
        const std::shared_ptr<const milvus::expr::JsonContainsExpr>& expr,
        const std::string& name,
        const segcore::SegmentInternalInterface* segment,
        Timestamp query_timestamp,
        int64_t batch_size)
        : SegmentExpr(std::move(input),
                      name,
                      segment,
                      expr->column_.field_id_,
                      query_timestamp,
                      batch_size),
          expr_(expr) {
    }

    // Produces one batch of filter results into `result` (BOOL column).
    void
    Eval(EvalCtx& context, VectorPtr& result) override;

 private:
    // Top-level dispatcher over data segments.
    VectorPtr
    EvalJsonContainsForDataSegment();

    // contains-any over a JSON column with scalar targets of ExprValueType.
    template <typename ExprValueType>
    VectorPtr
    ExecJsonContains();

    // contains-any over an ARRAY column with scalar targets.
    template <typename ExprValueType>
    VectorPtr
    ExecArrayContains();

    // contains-all over a JSON column with scalar targets.
    template <typename ExprValueType>
    VectorPtr
    ExecJsonContainsAll();

    // contains-all over an ARRAY column with scalar targets.
    template <typename ExprValueType>
    VectorPtr
    ExecArrayContainsAll();

    // contains-any where each target is itself an array.
    VectorPtr
    ExecJsonContainsArray();

    // contains-all where each target is itself an array.
    VectorPtr
    ExecJsonContainsAllArray();

    // contains-all with heterogeneous (mixed-type) targets.
    VectorPtr
    ExecJsonContainsAllWithDiffType();

    // contains-any with heterogeneous (mixed-type) targets.
    VectorPtr
    ExecJsonContainsWithDiffType();

 private:
    // Logical plan node; shared, immutable.
    std::shared_ptr<const milvus::expr::JsonContainsExpr> expr_;
};
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,51 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "LogicalBinaryExpr.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
void
|
||||
PhyLogicalBinaryExpr::Eval(EvalCtx& context, VectorPtr& result) {
|
||||
AssertInfo(inputs_.size() == 2,
|
||||
fmt::format("logical binary expr must has two input, but now {}",
|
||||
inputs_.size()));
|
||||
VectorPtr left;
|
||||
inputs_[0]->Eval(context, left);
|
||||
VectorPtr right;
|
||||
inputs_[1]->Eval(context, right);
|
||||
auto lflat = GetColumnVector(left);
|
||||
auto rflat = GetColumnVector(right);
|
||||
auto size = left->size();
|
||||
bool* ldata = static_cast<bool*>(lflat->GetRawData());
|
||||
bool* rdata = static_cast<bool*>(rflat->GetRawData());
|
||||
if (expr_->op_type_ == expr::LogicalBinaryExpr::OpType::And) {
|
||||
LogicalElementFunc<LogicalOpType::And> func;
|
||||
func(ldata, rdata, size);
|
||||
} else if (expr_->op_type_ == expr::LogicalBinaryExpr::OpType::Or) {
|
||||
LogicalElementFunc<LogicalOpType::Or> func;
|
||||
func(ldata, rdata, size);
|
||||
} else {
|
||||
PanicInfo(OpTypeInvalid,
|
||||
fmt::format("unsupported logical operator: {}",
|
||||
expr_->GetOpTypeString()));
|
||||
}
|
||||
result = std::move(left);
|
||||
}
|
||||
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,78 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fmt/core.h>
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Vector.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
#include "simd/hook.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
enum class LogicalOpType { Invalid = 0, And = 1, Or = 2, Xor = 3, Minus = 4 };
|
||||
|
||||
// Element-wise in-place boolean combinator: left[i] = left[i] OP right[i]
// for i in [0, n). Uses the dynamic SIMD kernels when available, otherwise
// a scalar loop. Unsupported ops panic at runtime (only if instantiated).
template <LogicalOpType op>
struct LogicalElementFunc {
    void
    operator()(bool* left, bool* right, int n) {
#if defined(USE_DYNAMIC_SIMD)
        if constexpr (op == LogicalOpType::And) {
            milvus::simd::and_bool(left, right, n);
        } else if constexpr (op == LogicalOpType::Or) {
            milvus::simd::or_bool(left, right, n);
        } else {
            PanicInfo(OpTypeInvalid,
                      fmt::format("unsupported logical operator: {}", op));
        }
#else
        // FIX: index with int to match `n` — the previous `size_t i < n`
        // mixed signed and unsigned operands in the comparison.
        for (int i = 0; i < n; ++i) {
            if constexpr (op == LogicalOpType::And) {
                left[i] &= right[i];
            } else if constexpr (op == LogicalOpType::Or) {
                left[i] |= right[i];
            } else {
                PanicInfo(OpTypeInvalid,
                          fmt::format("unsupported logical operator: {}", op));
            }
        }
#endif
    }
};
|
||||
|
||||
// Physical operator for a binary logical expression (AND / OR). Combines
// the BOOL outputs of its two children element-wise, in place.
class PhyLogicalBinaryExpr : public Expr {
 public:
    // `input` holds exactly two child expressions; `expr` is the logical
    // plan node this operator executes.
    // NOTE(review): std::move on a const& parameter is a copy — confirm
    // the Expr ctor signature before relying on the move.
    PhyLogicalBinaryExpr(
        const std::vector<std::shared_ptr<Expr>>& input,
        const std::shared_ptr<const milvus::expr::LogicalBinaryExpr>& expr,
        const std::string& name)
        : Expr(DataType::BOOL, std::move(input), name), expr_(expr) {
    }

    // Evaluates both children, combines them with the configured operator,
    // and returns the (reused) left vector in `result`.
    void
    Eval(EvalCtx& context, VectorPtr& result) override;

 private:
    // Logical plan node; shared, immutable.
    std::shared_ptr<const milvus::expr::LogicalBinaryExpr> expr_;
};
|
||||
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,44 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "LogicalUnaryExpr.h"
|
||||
#include "simd/hook.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
void
|
||||
PhyLogicalUnaryExpr::Eval(EvalCtx& context, VectorPtr& result) {
|
||||
AssertInfo(inputs_.size() == 1,
|
||||
fmt::format("logical unary expr must has one input, but now {}",
|
||||
inputs_.size()));
|
||||
|
||||
inputs_[0]->Eval(context, result);
|
||||
if (expr_->op_type_ == milvus::expr::LogicalUnaryExpr::OpType::LogicalNot) {
|
||||
auto flat_vec = GetColumnVector(result);
|
||||
bool* data = static_cast<bool*>(flat_vec->GetRawData());
|
||||
#if defined(USE_DYNAMIC_SIMD)
|
||||
milvus::simd::invert_bool(data, flat_vec->size());
|
||||
#else
|
||||
for (int i = 0; i < flat_vec->size(); ++i) {
|
||||
data[i] = !data[i];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,47 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fmt/core.h>
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Vector.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Physical operator for a unary logical expression (currently LogicalNot).
// Negates the BOOL column produced by its single child, in place.
class PhyLogicalUnaryExpr : public Expr {
 public:
    // `input` holds exactly one child expression; `expr` is the logical
    // plan node this operator executes.
    // NOTE(review): std::move on a const& parameter is a copy — confirm
    // the Expr ctor signature before relying on the move.
    PhyLogicalUnaryExpr(
        const std::vector<std::shared_ptr<Expr>>& input,
        const std::shared_ptr<const milvus::expr::LogicalUnaryExpr>& expr,
        const std::string& name)
        : Expr(DataType::BOOL, std::move(input), name), expr_(expr) {
    }

    // Evaluates the child into `result`, then inverts it for LogicalNot.
    void
    Eval(EvalCtx& context, VectorPtr& result) override;

 private:
    // Logical plan node; shared, immutable.
    std::shared_ptr<const milvus::expr::LogicalUnaryExpr> expr_;
};
|
||||
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,540 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "TermExpr.h"
|
||||
#include "query/Utils.h"
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Produces one batch of term-filter ("field IN (vals...)") results.
// Primary-key columns take a dedicated cached path; all other columns are
// dispatched by data type (and, for JSON/ARRAY, by the type of the first
// target value) to the typed ExecVisitor* implementations.
void
PhyTermFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
    if (is_pk_field_) {
        // PK terms resolve via segment id lookup; see ExecPkTermImpl.
        result = ExecPkTermImpl();
        return;
    }
    switch (expr_->column_.data_type_) {
        case DataType::BOOL: {
            result = ExecVisitorImpl<bool>();
            break;
        }
        case DataType::INT8: {
            result = ExecVisitorImpl<int8_t>();
            break;
        }
        case DataType::INT16: {
            result = ExecVisitorImpl<int16_t>();
            break;
        }
        case DataType::INT32: {
            result = ExecVisitorImpl<int32_t>();
            break;
        }
        case DataType::INT64: {
            result = ExecVisitorImpl<int64_t>();
            break;
        }
        case DataType::FLOAT: {
            result = ExecVisitorImpl<float>();
            break;
        }
        case DataType::DOUBLE: {
            result = ExecVisitorImpl<double>();
            break;
        }
        case DataType::VARCHAR: {
            // Growing segments store owned strings; sealed segments expose
            // views over mmapped data.
            if (segment_->type() == SegmentType::Growing) {
                result = ExecVisitorImpl<std::string>();
            } else {
                result = ExecVisitorImpl<std::string_view>();
            }
            break;
        }
        case DataType::JSON: {
            if (expr_->vals_.size() == 0) {
                // Empty term list: any instantiation yields all-false;
                // bool is used as a cheap placeholder type.
                result = ExecVisitorImplTemplateJson<bool>();
                break;
            }
            // All target values share the type of the first one.
            auto type = expr_->vals_[0].val_case();
            switch (type) {
                case proto::plan::GenericValue::ValCase::kBoolVal:
                    result = ExecVisitorImplTemplateJson<bool>();
                    break;
                case proto::plan::GenericValue::ValCase::kInt64Val:
                    result = ExecVisitorImplTemplateJson<int64_t>();
                    break;
                case proto::plan::GenericValue::ValCase::kFloatVal:
                    result = ExecVisitorImplTemplateJson<double>();
                    break;
                case proto::plan::GenericValue::ValCase::kStringVal:
                    result = ExecVisitorImplTemplateJson<std::string>();
                    break;
                default:
                    PanicInfo(DataTypeInvalid,
                              fmt::format("unknown data type: {}", type));
            }
            break;
        }
        case DataType::ARRAY: {
            if (expr_->vals_.size() == 0) {
                // Empty term list: placeholder instantiation as above.
                result = ExecVisitorImplTemplateArray<bool>();
                break;
            }
            auto type = expr_->vals_[0].val_case();
            switch (type) {
                case proto::plan::GenericValue::ValCase::kBoolVal:
                    result = ExecVisitorImplTemplateArray<bool>();
                    break;
                case proto::plan::GenericValue::ValCase::kInt64Val:
                    result = ExecVisitorImplTemplateArray<int64_t>();
                    break;
                case proto::plan::GenericValue::ValCase::kFloatVal:
                    result = ExecVisitorImplTemplateArray<double>();
                    break;
                case proto::plan::GenericValue::ValCase::kStringVal:
                    result = ExecVisitorImplTemplateArray<std::string>();
                    break;
                default:
                    PanicInfo(DataTypeInvalid,
                              fmt::format("unknown data type: {}", type));
            }
            break;
        }
        default:
            PanicInfo(DataTypeInvalid,
                      fmt::format("unsupported data type: {}",
                                  expr_->column_.data_type_));
    }
}
|
||||
|
||||
void
|
||||
PhyTermFilterExpr::InitPkCacheOffset() {
|
||||
auto id_array = std::make_unique<IdArray>();
|
||||
switch (pk_type_) {
|
||||
case DataType::INT64: {
|
||||
auto dst_ids = id_array->mutable_int_id();
|
||||
for (const auto& id : expr_->vals_) {
|
||||
dst_ids->add_data(GetValueFromProto<int64_t>(id));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
auto dst_ids = id_array->mutable_str_id();
|
||||
for (const auto& id : expr_->vals_) {
|
||||
dst_ids->add_data(GetValueFromProto<std::string>(id));
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
PanicInfo(DataTypeInvalid,
|
||||
fmt::format("unsupported data type {}", pk_type_));
|
||||
}
|
||||
}
|
||||
|
||||
auto [uids, seg_offsets] =
|
||||
segment_->search_ids(*id_array, query_timestamp_);
|
||||
cached_bits_.resize(num_rows_);
|
||||
cached_offsets_ =
|
||||
std::make_shared<ColumnVector>(DataType::INT64, seg_offsets.size());
|
||||
int64_t* cached_offsets_ptr = (int64_t*)cached_offsets_->GetRawData();
|
||||
int i = 0;
|
||||
for (const auto& offset : seg_offsets) {
|
||||
auto _offset = (int64_t)offset.get();
|
||||
cached_bits_[_offset] = true;
|
||||
cached_offsets_ptr[i++] = _offset;
|
||||
}
|
||||
cached_offsets_inited_ = true;
|
||||
}
|
||||
|
||||
// Returns one batch of the cached primary-key filter results together with
// the matched segment offsets, packed as a RowVector of {bool column,
// offset column}.
VectorPtr
PhyTermFilterExpr::ExecPkTermImpl() {
    if (!cached_offsets_inited_) {
        InitPkCacheOffset();
    }

    // Remember where this batch starts before advancing the cursor.
    auto chunk_start = current_data_chunk_pos_;
    auto real_batch_size = current_data_chunk_pos_ + batch_size_ >= num_rows_
                               ? num_rows_ - current_data_chunk_pos_
                               : batch_size_;
    current_data_chunk_pos_ += real_batch_size;

    if (real_batch_size == 0) {
        return nullptr;
    }

    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();

    for (int64_t i = 0; i < real_batch_size; ++i) {
        // BUGFIX: read this batch's window of the bitmap. Previously every
        // batch read cached_bits_[0 .. real_batch_size), so the second and
        // later batches returned the first batch's results.
        res[i] = cached_bits_[chunk_start + i];
    }

    std::vector<VectorPtr> vecs{res_vec, cached_offsets_};
    return std::make_shared<RowVector>(vecs);
}
|
||||
|
||||
template <typename ValueType>
|
||||
VectorPtr
|
||||
PhyTermFilterExpr::ExecVisitorImplTemplateJson() {
|
||||
if (expr_->is_in_field_) {
|
||||
return ExecTermJsonVariableInField<ValueType>();
|
||||
} else {
|
||||
return ExecTermJsonFieldInVariable<ValueType>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ValueType>
|
||||
VectorPtr
|
||||
PhyTermFilterExpr::ExecVisitorImplTemplateArray() {
|
||||
if (expr_->is_in_field_) {
|
||||
return ExecTermArrayVariableInField<ValueType>();
|
||||
} else {
|
||||
return ExecTermArrayFieldInVariable<ValueType>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ValueType>
|
||||
VectorPtr
|
||||
PhyTermFilterExpr::ExecTermArrayVariableInField() {
|
||||
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
|
||||
std::string_view,
|
||||
ValueType>;
|
||||
auto real_batch_size = GetNextBatchSize();
|
||||
if (real_batch_size == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
auto res_vec =
|
||||
std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
|
||||
bool* res = (bool*)res_vec->GetRawData();
|
||||
|
||||
AssertInfo(expr_->vals_.size() == 1,
|
||||
"element length in json array must be one");
|
||||
ValueType target_val = GetValueFromProto<ValueType>(expr_->vals_[0]);
|
||||
|
||||
auto execute_sub_batch = [](const ArrayView* data,
|
||||
const int size,
|
||||
bool* res,
|
||||
const ValueType& target_val) {
|
||||
auto executor = [&](size_t i) {
|
||||
for (int i = 0; i < data[i].length(); i++) {
|
||||
auto val = data[i].template get_data<GetType>(i);
|
||||
if (val == target_val) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
for (int i = 0; i < size; ++i) {
|
||||
executor(i);
|
||||
}
|
||||
};
|
||||
|
||||
int64_t processed_size = ProcessDataChunks<milvus::ArrayView>(
|
||||
execute_sub_batch, std::nullptr_t{}, res, target_val);
|
||||
AssertInfo(processed_size == real_batch_size,
|
||||
fmt::format("internal error: expr processed rows {} not equal "
|
||||
"expect batch size {}",
|
||||
processed_size,
|
||||
real_batch_size));
|
||||
return res_vec;
|
||||
}
|
||||
|
||||
// Evaluates `array_field[index] in (v1, v2, ...)` for one batch: for each row
// the element at a fixed position is tested for membership in the literal set.
// Returns a BOOL ColumnVector of batch size, or nullptr when exhausted.
template <typename ValueType>
VectorPtr
PhyTermFilterExpr::ExecTermArrayFieldInVariable() {
    // Strings are read from storage as string_view to avoid copies.
    using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
                                       std::string_view,
                                       ValueType>;

    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;
    }
    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();

    // The first nested-path component names the element position inside the
    // array. NOTE(review): index stays -1 when no path is given; what
    // get_data(-1) does is defined by ArrayView — confirm callers always
    // supply a path here.
    int index = -1;
    if (expr_->column_.nested_path_.size() > 0) {
        index = std::stoi(expr_->column_.nested_path_[0]);
    }
    std::unordered_set<ValueType> term_set;
    for (const auto& element : expr_->vals_) {
        term_set.insert(GetValueFromProto<ValueType>(element));
    }

    // Empty IN-list matches nothing; answer without scanning any chunk.
    if (term_set.empty()) {
        for (size_t i = 0; i < real_batch_size; ++i) {
            res[i] = false;
        }
        return res_vec;
    }

    auto execute_sub_batch = [](const ArrayView* data,
                                const int size,
                                bool* res,
                                int index,
                                const std::unordered_set<ValueType>& term_set) {
        // NOTE(review): unreachable given the early return above — the
        // lambda is only invoked with a non-empty term_set.
        if (term_set.empty()) {
            for (int i = 0; i < size; ++i) {
                res[i] = false;
            }
        }
        for (int i = 0; i < size; ++i) {
            // Rows whose array is too short cannot match.
            if (index >= data[i].length()) {
                res[i] = false;
                continue;
            }
            auto value = data[i].get_data<GetType>(index);
            res[i] = term_set.find(ValueType(value)) != term_set.end();
        }
    };

    // nullptr_t disables the skip-index fast path for this expression.
    int64_t processed_size = ProcessDataChunks<milvus::ArrayView>(
        execute_sub_batch, std::nullptr_t{}, res, index, term_set);
    AssertInfo(processed_size == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           processed_size,
                           real_batch_size));
    return res_vec;
}
|
||||
|
||||
// Evaluates `target in json_field[path]` for one batch: the JSON pointer must
// resolve to an array, and a row matches when any array element equals the
// single literal target. Returns a BOOL ColumnVector, or nullptr when exhausted.
template <typename ValueType>
VectorPtr
PhyTermFilterExpr::ExecTermJsonVariableInField() {
    // Strings are read from the JSON document as string_view to avoid copies.
    using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
                                       std::string_view,
                                       ValueType>;
    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;
    }
    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();
    AssertInfo(expr_->vals_.size() == 1,
               "element length in json array must be one");
    ValueType val = GetValueFromProto<ValueType>(expr_->vals_[0]);
    // Precompute the JSON pointer ("/a/b") once for the whole batch.
    auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);

    auto execute_sub_batch = [](const Json* data,
                                const int size,
                                bool* res,
                                const std::string pointer,
                                const ValueType& target_val) {
        auto executor = [&](size_t i) {
            // Re-parse the document; non-array targets or type-mismatched
            // elements simply fail to match.
            auto doc = data[i].doc();
            auto array = doc.at_pointer(pointer).get_array();
            if (array.error())
                return false;
            for (auto it = array.begin(); it != array.end(); ++it) {
                auto val = (*it).template get<GetType>();
                if (val.error()) {
                    return false;
                }
                if (val.value() == target_val) {
                    return true;
                }
            }
            return false;
        };
        for (size_t i = 0; i < size; ++i) {
            res[i] = executor(i);
        }
    };
    // nullptr_t disables the skip-index fast path for this expression.
    int64_t processed_size = ProcessDataChunks<milvus::Json>(
        execute_sub_batch, std::nullptr_t{}, res, pointer, val);
    AssertInfo(processed_size == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           processed_size,
                           real_batch_size));
    return res_vec;
}
|
||||
|
||||
// Evaluates `json_field[path] in (v1, v2, ...)` for one batch: the scalar at
// the JSON pointer is tested for membership in the literal set. Returns a
// BOOL ColumnVector of batch size, or nullptr when exhausted.
template <typename ValueType>
VectorPtr
PhyTermFilterExpr::ExecTermJsonFieldInVariable() {
    // Strings are read from the JSON document as string_view to avoid copies.
    using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
                                       std::string_view,
                                       ValueType>;
    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;
    }
    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();
    // Precompute the JSON pointer ("/a/b") once for the whole batch.
    auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);
    std::unordered_set<ValueType> term_set;
    for (const auto& element : expr_->vals_) {
        term_set.insert(GetValueFromProto<ValueType>(element));
    }

    // Empty IN-list matches nothing; answer without scanning any chunk.
    if (term_set.empty()) {
        for (size_t i = 0; i < real_batch_size; ++i) {
            res[i] = false;
        }
        return res_vec;
    }

    auto execute_sub_batch = [](const Json* data,
                                const int size,
                                bool* res,
                                const std::string pointer,
                                const std::unordered_set<ValueType>& terms) {
        auto executor = [&](size_t i) {
            auto x = data[i].template at<GetType>(pointer);
            if (x.error()) {
                // JSON stores all numbers; an int64 lookup can fail on a
                // value stored as double, so retry as double.
                if constexpr (std::is_same_v<GetType, std::int64_t>) {
                    auto x = data[i].template at<double>(pointer);
                    if (x.error()) {
                        return false;
                    }

                    auto value = x.value();
                    // if the term set is {1}, and the value is 1.1, we should not return true.
                    return std::floor(value) == value &&
                           terms.find(ValueType(value)) != terms.end();
                }
                return false;
            }
            return terms.find(ValueType(x.value())) != terms.end();
        };
        for (size_t i = 0; i < size; ++i) {
            res[i] = executor(i);
        }
    };
    // nullptr_t disables the skip-index fast path for this expression.
    int64_t processed_size = ProcessDataChunks<milvus::Json>(
        execute_sub_batch, std::nullptr_t{}, res, pointer, term_set);
    AssertInfo(processed_size == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           processed_size,
                           real_batch_size));
    return res_vec;
}
|
||||
|
||||
template <typename T>
|
||||
VectorPtr
|
||||
PhyTermFilterExpr::ExecVisitorImpl() {
|
||||
if (is_index_mode_) {
|
||||
return ExecVisitorImplForIndex<T>();
|
||||
} else {
|
||||
return ExecVisitorImplForData<T>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
VectorPtr
|
||||
PhyTermFilterExpr::ExecVisitorImplForIndex() {
|
||||
typedef std::
|
||||
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
|
||||
IndexInnerType;
|
||||
using Index = index::ScalarIndex<IndexInnerType>;
|
||||
auto real_batch_size = GetNextBatchSize();
|
||||
if (real_batch_size == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<IndexInnerType> vals;
|
||||
for (auto& val : expr_->vals_) {
|
||||
auto converted_val = GetValueFromProto<T>(val);
|
||||
// Integral overflow process
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
if (milvus::query::out_of_range<T>(converted_val)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
vals.emplace_back(converted_val);
|
||||
}
|
||||
auto execute_sub_batch = [](Index* index_ptr,
|
||||
const std::vector<IndexInnerType>& vals) {
|
||||
TermIndexFunc<T> func;
|
||||
return func(index_ptr, vals.size(), vals.data());
|
||||
};
|
||||
auto res = ProcessIndexChunks<T>(execute_sub_batch, vals);
|
||||
AssertInfo(res.size() == real_batch_size,
|
||||
fmt::format("internal error: expr processed rows {} not equal "
|
||||
"expect batch size {}",
|
||||
res.size(),
|
||||
real_batch_size));
|
||||
return std::make_shared<ColumnVector>(std::move(res));
|
||||
}
|
||||
|
||||
template <>
|
||||
VectorPtr
|
||||
PhyTermFilterExpr::ExecVisitorImplForIndex<bool>() {
|
||||
using Index = index::ScalarIndex<bool>;
|
||||
auto real_batch_size = GetNextBatchSize();
|
||||
if (real_batch_size == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> vals;
|
||||
for (auto& val : expr_->vals_) {
|
||||
vals.emplace_back(GetValueFromProto<bool>(val) ? 1 : 0);
|
||||
}
|
||||
auto execute_sub_batch = [](Index* index_ptr,
|
||||
const std::vector<uint8_t>& vals) {
|
||||
TermIndexFunc<bool> func;
|
||||
return std::move(func(index_ptr, vals.size(), (bool*)vals.data()));
|
||||
};
|
||||
auto res = ProcessIndexChunks<bool>(execute_sub_batch, vals);
|
||||
return std::make_shared<ColumnVector>(std::move(res));
|
||||
}
|
||||
|
||||
// Raw-data TERM evaluation: builds a hash set of the (range-checked) literal
// values and tests every row of every chunk for membership.
// Returns nullptr when the batch is exhausted.
template <typename T>
VectorPtr
PhyTermFilterExpr::ExecVisitorImplForData() {
    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;
    }
    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();
    std::vector<T> vals;
    for (auto& val : expr_->vals_) {
        // Integral overflow process
        // Literals outside T's range can never equal a stored value; drop
        // them instead of comparing with a wrapped value.
        bool overflowed = false;
        auto converted_val = GetValueFromProtoWithOverflow<T>(val, overflowed);
        if (!overflowed) {
            vals.emplace_back(converted_val);
        }
    }
    std::unordered_set<T> vals_set(vals.begin(), vals.end());
    auto execute_sub_batch = [](const T* data,
                                const int size,
                                bool* res,
                                const std::unordered_set<T>& vals) {
        TermElementFuncSet<T> func;
        for (size_t i = 0; i < size; ++i) {
            res[i] = func(vals, data[i]);
        }
    };
    // nullptr_t disables the skip-index fast path for this expression.
    int64_t processed_size = ProcessDataChunks<T>(
        execute_sub_batch, std::nullptr_t{}, res, vals_set);
    AssertInfo(processed_size == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           processed_size,
                           real_batch_size));
    return res_vec;
}
|
||||
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,134 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fmt/core.h>
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Vector.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Linear membership scan: true iff `val` occurs anywhere in src[0, n).
// Used when the term list is kept as a flat array instead of a hash set.
template <typename T>
struct TermElementFuncFlat {
    bool
    operator()(const T* src, size_t n, T val) {
        for (size_t i = 0; i < n; ++i) {
            if (src[i] == val) {
                return true;
            }
        }
        // BUGFIX: the no-match path previously fell off the end of a
        // value-returning function, which is undefined behavior.
        return false;
    }
};
|
||||
|
||||
// Hash-set membership test for a single element value.
template <typename T>
struct TermElementFuncSet {
    bool
    operator()(const std::unordered_set<T>& srcs, T val) {
        return srcs.count(val) > 0;
    }
};
|
||||
|
||||
// Batch membership query against a scalar index: one result bit per indexed
// row, true when the row's value is among the n supplied vals.
template <typename T>
struct TermIndexFunc {
    // string_view has no index specialization, so map it to std::string.
    typedef std::
        conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
            IndexInnerType;
    using Index = index::ScalarIndex<IndexInnerType>;
    FixedVector<bool>
    operator()(Index* index, size_t n, const IndexInnerType* val) {
        return index->In(n, val);
    }
};
|
||||
|
||||
// Physical operator for TERM filters ("field in (v1, v2, ...)"), covering
// scalar, JSON, and array columns, with an index fast path and a pk-cache
// optimization for primary-key IN lists.
class PhyTermFilterExpr : public SegmentExpr {
 public:
    PhyTermFilterExpr(
        const std::vector<std::shared_ptr<Expr>>& input,
        const std::shared_ptr<const milvus::expr::TermFilterExpr>& expr,
        const std::string& name,
        const segcore::SegmentInternalInterface* segment,
        Timestamp query_timestamp,
        int64_t batch_size)
        // NOTE(review): input is a const&, so std::move here degrades to a
        // copy — harmless but misleading; confirm whether the parameter was
        // meant to be taken by value.
        : SegmentExpr(std::move(input),
                      name,
                      segment,
                      expr->column_.field_id_,
                      query_timestamp,
                      batch_size),
          expr_(expr) {
    }

    // Produces the next batch of boolean results into `result`.
    void
    Eval(EvalCtx& context, VectorPtr& result) override;

 private:
    // Builds the cached pk offsets/bits used by the pk-index fast path.
    void
    InitPkCacheOffset();

    // TERM evaluation via the cached primary-key lookup.
    VectorPtr
    ExecPkTermImpl();

    // Dispatches between index-backed and raw-data evaluation.
    template <typename T>
    VectorPtr
    ExecVisitorImpl();

    template <typename T>
    VectorPtr
    ExecVisitorImplForIndex();

    template <typename T>
    VectorPtr
    ExecVisitorImplForData();

    // JSON column: dispatch on is_in_field_ direction.
    template <typename ValueType>
    VectorPtr
    ExecVisitorImplTemplateJson();

    template <typename ValueType>
    VectorPtr
    ExecTermJsonVariableInField();

    template <typename ValueType>
    VectorPtr
    ExecTermJsonFieldInVariable();

    // Array column: dispatch on is_in_field_ direction.
    template <typename ValueType>
    VectorPtr
    ExecVisitorImplTemplateArray();

    template <typename ValueType>
    VectorPtr
    ExecTermArrayVariableInField();

    template <typename ValueType>
    VectorPtr
    ExecTermArrayFieldInVariable();

 private:
    std::shared_ptr<const milvus::expr::TermFilterExpr> expr_;
    // If expr is like "pk in (..)", can use pk index to optimize
    bool cached_offsets_inited_{false};
    ColumnVectorPtr cached_offsets_;
    FixedVector<bool> cached_bits_;
};
|
||||
} //namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,593 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "UnaryExpr.h"
|
||||
#include "common/Json.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Entry point for unary range filters (>, >=, <, <=, ==, !=, prefix-match):
// dispatches on the column's data type, and for JSON/ARRAY columns further
// on the literal value's proto type. Writes the batch result into `result`.
void
PhyUnaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
    switch (expr_->column_.data_type_) {
        case DataType::BOOL: {
            result = ExecRangeVisitorImpl<bool>();
            break;
        }
        case DataType::INT8: {
            result = ExecRangeVisitorImpl<int8_t>();
            break;
        }
        case DataType::INT16: {
            result = ExecRangeVisitorImpl<int16_t>();
            break;
        }
        case DataType::INT32: {
            result = ExecRangeVisitorImpl<int32_t>();
            break;
        }
        case DataType::INT64: {
            result = ExecRangeVisitorImpl<int64_t>();
            break;
        }
        case DataType::FLOAT: {
            result = ExecRangeVisitorImpl<float>();
            break;
        }
        case DataType::DOUBLE: {
            result = ExecRangeVisitorImpl<double>();
            break;
        }
        case DataType::VARCHAR: {
            // Growing segments store owning std::string; sealed segments
            // expose zero-copy string_view.
            if (segment_->type() == SegmentType::Growing) {
                result = ExecRangeVisitorImpl<std::string>();
            } else {
                result = ExecRangeVisitorImpl<std::string_view>();
            }
            break;
        }
        case DataType::JSON: {
            // The comparison type follows the literal, not the column.
            auto val_type = expr_->val_.val_case();
            switch (val_type) {
                case proto::plan::GenericValue::ValCase::kBoolVal:
                    result = ExecRangeVisitorImplJson<bool>();
                    break;
                case proto::plan::GenericValue::ValCase::kInt64Val:
                    result = ExecRangeVisitorImplJson<int64_t>();
                    break;
                case proto::plan::GenericValue::ValCase::kFloatVal:
                    result = ExecRangeVisitorImplJson<double>();
                    break;
                case proto::plan::GenericValue::ValCase::kStringVal:
                    result = ExecRangeVisitorImplJson<std::string>();
                    break;
                case proto::plan::GenericValue::ValCase::kArrayVal:
                    result = ExecRangeVisitorImplJson<proto::plan::Array>();
                    break;
                default:
                    PanicInfo(DataTypeInvalid,
                              fmt::format("unknown data type: {}", val_type));
            }
            break;
        }
        case DataType::ARRAY: {
            // The comparison type follows the literal, not the column.
            auto val_type = expr_->val_.val_case();
            switch (val_type) {
                case proto::plan::GenericValue::ValCase::kBoolVal:
                    result = ExecRangeVisitorImplArray<bool>();
                    break;
                case proto::plan::GenericValue::ValCase::kInt64Val:
                    result = ExecRangeVisitorImplArray<int64_t>();
                    break;
                case proto::plan::GenericValue::ValCase::kFloatVal:
                    result = ExecRangeVisitorImplArray<double>();
                    break;
                case proto::plan::GenericValue::ValCase::kStringVal:
                    result = ExecRangeVisitorImplArray<std::string>();
                    break;
                case proto::plan::GenericValue::ValCase::kArrayVal:
                    result = ExecRangeVisitorImplArray<proto::plan::Array>();
                    break;
                default:
                    PanicInfo(DataTypeInvalid,
                              fmt::format("unknown data type: {}", val_type));
            }
            break;
        }
        default:
            PanicInfo(DataTypeInvalid,
                      fmt::format("unsupported data type: {}",
                                  expr_->column_.data_type_));
    }
}
|
||||
|
||||
// Unary range filter over an array column: compares the element at a fixed
// position (first nested-path component) of each row's array against the
// literal, using the op-specific UnaryElementFuncForArray functor.
// Returns a BOOL ColumnVector of batch size, or nullptr when exhausted.
template <typename ValueType>
VectorPtr
PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArray() {
    // Strings are read from storage as string_view to avoid copies.
    using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
                                       std::string_view,
                                       ValueType>;
    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;
    }
    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();

    ValueType val = GetValueFromProto<ValueType>(expr_->val_);
    auto op_type = expr_->op_type_;
    // index = -1 means "compare the whole array" — handled by the functor.
    // NOTE(review): confirm UnaryElementFuncForArray's index < 0 semantics.
    int index = -1;
    if (expr_->column_.nested_path_.size() > 0) {
        index = std::stoi(expr_->column_.nested_path_[0]);
    }
    auto execute_sub_batch = [op_type](const milvus::ArrayView* data,
                                       const int size,
                                       bool* res,
                                       ValueType val,
                                       int index) {
        switch (op_type) {
            case proto::plan::GreaterThan: {
                UnaryElementFuncForArray<ValueType, proto::plan::GreaterThan>
                    func;
                func(data, size, val, index, res);
                break;
            }
            case proto::plan::GreaterEqual: {
                UnaryElementFuncForArray<ValueType, proto::plan::GreaterEqual>
                    func;
                func(data, size, val, index, res);
                break;
            }
            case proto::plan::LessThan: {
                UnaryElementFuncForArray<ValueType, proto::plan::LessThan> func;
                func(data, size, val, index, res);
                break;
            }
            case proto::plan::LessEqual: {
                UnaryElementFuncForArray<ValueType, proto::plan::LessEqual>
                    func;
                func(data, size, val, index, res);
                break;
            }
            case proto::plan::Equal: {
                UnaryElementFuncForArray<ValueType, proto::plan::Equal> func;
                func(data, size, val, index, res);
                break;
            }
            case proto::plan::NotEqual: {
                UnaryElementFuncForArray<ValueType, proto::plan::NotEqual> func;
                func(data, size, val, index, res);
                break;
            }
            case proto::plan::PrefixMatch: {
                UnaryElementFuncForArray<ValueType, proto::plan::PrefixMatch>
                    func;
                func(data, size, val, index, res);
                break;
            }
            default:
                PanicInfo(
                    OpTypeInvalid,
                    fmt::format("unsupported operator type for unary expr: {}",
                                op_type));
        }
    };
    // nullptr_t disables the skip-index fast path for this expression.
    int64_t processed_size = ProcessDataChunks<milvus::ArrayView>(
        execute_sub_batch, std::nullptr_t{}, res, val, index);
    AssertInfo(processed_size == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           processed_size,
                           real_batch_size));
    return res_vec;
}
|
||||
|
||||
// Unary range filter over a JSON column: looks up the scalar (or array) at
// the precomputed JSON pointer in each row and compares it to the literal.
// For int64 literals, a failed typed lookup falls back to double (JSON stores
// all numbers); for NotEqual, a missing/mistyped value counts as a match.
// Returns a BOOL ColumnVector of batch size, or nullptr when exhausted.
template <typename ExprValueType>
VectorPtr
PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() {
    // Strings are read from the JSON document as string_view to avoid copies.
    using GetType =
        std::conditional_t<std::is_same_v<ExprValueType, std::string>,
                           std::string_view,
                           ExprValueType>;
    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;
    }

    ExprValueType val = GetValueFromProto<ExprValueType>(expr_->val_);
    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();
    auto op_type = expr_->op_type_;
    auto pointer = milvus::Json::pointer(expr_->column_.nested_path_);

// Evaluates `cmp` against the value at `pointer`; a missing or mistyped
// value yields false (after the int64 -> double retry).
// NOTE(review): these macros are not #undef'd after use, so they stay
// defined for the rest of the translation unit — consider #undef at the end.
#define UnaryRangeJSONCompare(cmp)                             \
    do {                                                       \
        auto x = data[i].template at<GetType>(pointer);        \
        if (x.error()) {                                       \
            if constexpr (std::is_same_v<GetType, int64_t>) {  \
                auto x = data[i].template at<double>(pointer); \
                res[i] = !x.error() && (cmp);                  \
                break;                                         \
            }                                                  \
            res[i] = false;                                    \
            break;                                             \
        }                                                      \
        res[i] = (cmp);                                        \
    } while (false)

// NotEqual variant: a missing or mistyped value counts as "not equal".
#define UnaryRangeJSONCompareNotEqual(cmp)                     \
    do {                                                       \
        auto x = data[i].template at<GetType>(pointer);        \
        if (x.error()) {                                       \
            if constexpr (std::is_same_v<GetType, int64_t>) {  \
                auto x = data[i].template at<double>(pointer); \
                res[i] = x.error() || (cmp);                   \
                break;                                         \
            }                                                  \
            res[i] = true;                                     \
            break;                                             \
        }                                                      \
        res[i] = (cmp);                                        \
    } while (false)

    auto execute_sub_batch = [op_type, pointer](const milvus::Json* data,
                                                const int size,
                                                bool* res,
                                                ExprValueType val) {
        switch (op_type) {
            case proto::plan::GreaterThan: {
                for (size_t i = 0; i < size; ++i) {
                    // Ordering comparisons are undefined for array literals.
                    if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
                        res[i] = false;
                    } else {
                        UnaryRangeJSONCompare(x.value() > val);
                    }
                }
                break;
            }
            case proto::plan::GreaterEqual: {
                for (size_t i = 0; i < size; ++i) {
                    if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
                        res[i] = false;
                    } else {
                        UnaryRangeJSONCompare(x.value() >= val);
                    }
                }
                break;
            }
            case proto::plan::LessThan: {
                for (size_t i = 0; i < size; ++i) {
                    if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
                        res[i] = false;
                    } else {
                        UnaryRangeJSONCompare(x.value() < val);
                    }
                }
                break;
            }
            case proto::plan::LessEqual: {
                for (size_t i = 0; i < size; ++i) {
                    if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
                        res[i] = false;
                    } else {
                        UnaryRangeJSONCompare(x.value() <= val);
                    }
                }
                break;
            }
            case proto::plan::Equal: {
                for (size_t i = 0; i < size; ++i) {
                    if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
                        // Array literal: element-wise array comparison.
                        auto doc = data[i].doc();
                        auto array = doc.at_pointer(pointer).get_array();
                        if (array.error()) {
                            res[i] = false;
                            continue;
                        }
                        res[i] = CompareTwoJsonArray(array, val);
                    } else {
                        UnaryRangeJSONCompare(x.value() == val);
                    }
                }
                break;
            }
            case proto::plan::NotEqual: {
                for (size_t i = 0; i < size; ++i) {
                    if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
                        auto doc = data[i].doc();
                        auto array = doc.at_pointer(pointer).get_array();
                        if (array.error()) {
                            res[i] = false;
                            continue;
                        }
                        res[i] = !CompareTwoJsonArray(array, val);
                    } else {
                        UnaryRangeJSONCompareNotEqual(x.value() != val);
                    }
                }
                break;
            }
            case proto::plan::PrefixMatch: {
                for (size_t i = 0; i < size; ++i) {
                    if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
                        res[i] = false;
                    } else {
                        UnaryRangeJSONCompare(milvus::query::Match(
                            ExprValueType(x.value()), val, op_type));
                    }
                }
                break;
            }
            default:
                PanicInfo(
                    OpTypeInvalid,
                    fmt::format("unsupported operator type for unary expr: {}",
                                op_type));
        }
    };
    // nullptr_t disables the skip-index fast path for this expression.
    int64_t processed_size = ProcessDataChunks<milvus::Json>(
        execute_sub_batch, std::nullptr_t{}, res, val);
    AssertInfo(processed_size == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           processed_size,
                           real_batch_size));
    return res_vec;
}
|
||||
|
||||
template <typename T>
|
||||
VectorPtr
|
||||
PhyUnaryRangeFilterExpr::ExecRangeVisitorImpl() {
|
||||
if (is_index_mode_) {
|
||||
return ExecRangeVisitorImplForIndex<T>();
|
||||
} else {
|
||||
return ExecRangeVisitorImplForData<T>();
|
||||
}
|
||||
}
|
||||
|
||||
// Index-backed unary range evaluation: dispatches the operator to the
// matching UnaryIndexFunc, which queries the scalar index and returns a
// per-row bitmap. Integral literals out of T's range are answered by
// PreCheckOverflow without touching the index.
template <typename T>
VectorPtr
PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForIndex() {
    // string_view has no index specialization, so map it to std::string.
    typedef std::
        conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
            IndexInnerType;
    using Index = index::ScalarIndex<IndexInnerType>;
    if (auto res = PreCheckOverflow<T>()) {
        return res;
    }

    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;
    }
    auto op_type = expr_->op_type_;
    auto execute_sub_batch = [op_type](Index* index_ptr, IndexInnerType val) {
        FixedVector<bool> res;
        switch (op_type) {
            case proto::plan::GreaterThan: {
                UnaryIndexFunc<T, proto::plan::GreaterThan> func;
                res = std::move(func(index_ptr, val));
                break;
            }
            case proto::plan::GreaterEqual: {
                UnaryIndexFunc<T, proto::plan::GreaterEqual> func;
                res = std::move(func(index_ptr, val));
                break;
            }
            case proto::plan::LessThan: {
                UnaryIndexFunc<T, proto::plan::LessThan> func;
                res = std::move(func(index_ptr, val));
                break;
            }
            case proto::plan::LessEqual: {
                UnaryIndexFunc<T, proto::plan::LessEqual> func;
                res = std::move(func(index_ptr, val));
                break;
            }
            case proto::plan::Equal: {
                UnaryIndexFunc<T, proto::plan::Equal> func;
                res = std::move(func(index_ptr, val));
                break;
            }
            case proto::plan::NotEqual: {
                UnaryIndexFunc<T, proto::plan::NotEqual> func;
                res = std::move(func(index_ptr, val));
                break;
            }
            case proto::plan::PrefixMatch: {
                UnaryIndexFunc<T, proto::plan::PrefixMatch> func;
                res = std::move(func(index_ptr, val));
                break;
            }
            default:
                PanicInfo(
                    OpTypeInvalid,
                    fmt::format("unsupported operator type for unary expr: {}",
                                op_type));
        }
        return res;
    };
    auto val = GetValueFromProto<IndexInnerType>(expr_->val_);
    auto res = ProcessIndexChunks<T>(execute_sub_batch, val);
    AssertInfo(res.size() == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           res.size(),
                           real_batch_size));
    return std::make_shared<ColumnVector>(std::move(res));
}
|
||||
|
||||
template <typename T>
|
||||
ColumnVectorPtr
|
||||
PhyUnaryRangeFilterExpr::PreCheckOverflow() {
|
||||
if constexpr (std::is_integral_v<T> && !std::is_same_v<T, bool>) {
|
||||
int64_t val = GetValueFromProto<int64_t>(expr_->val_);
|
||||
|
||||
if (milvus::query::out_of_range<T>(val)) {
|
||||
int64_t batch_size = overflow_check_pos_ + batch_size_ >= num_rows_
|
||||
? num_rows_ - overflow_check_pos_
|
||||
: batch_size_;
|
||||
overflow_check_pos_ += batch_size;
|
||||
if (cached_overflow_res_ != nullptr &&
|
||||
cached_overflow_res_->size() == batch_size) {
|
||||
return cached_overflow_res_;
|
||||
}
|
||||
switch (expr_->op_type_) {
|
||||
case proto::plan::GreaterThan:
|
||||
case proto::plan::GreaterEqual: {
|
||||
auto res_vec = std::make_shared<ColumnVector>(
|
||||
DataType::BOOL, batch_size);
|
||||
cached_overflow_res_ = res_vec;
|
||||
bool* res = (bool*)res_vec->GetRawData();
|
||||
if (milvus::query::lt_lb<T>(val)) {
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
res[i] = true;
|
||||
}
|
||||
return res_vec;
|
||||
}
|
||||
return res_vec;
|
||||
}
|
||||
case proto::plan::LessThan:
|
||||
case proto::plan::LessEqual: {
|
||||
auto res_vec = std::make_shared<ColumnVector>(
|
||||
DataType::BOOL, batch_size);
|
||||
cached_overflow_res_ = res_vec;
|
||||
bool* res = (bool*)res_vec->GetRawData();
|
||||
if (milvus::query::gt_ub<T>(val)) {
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
res[i] = true;
|
||||
}
|
||||
return res_vec;
|
||||
}
|
||||
return res_vec;
|
||||
}
|
||||
case proto::plan::Equal: {
|
||||
auto res_vec = std::make_shared<ColumnVector>(
|
||||
DataType::BOOL, batch_size);
|
||||
cached_overflow_res_ = res_vec;
|
||||
bool* res = (bool*)res_vec->GetRawData();
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
res[i] = false;
|
||||
}
|
||||
return res_vec;
|
||||
}
|
||||
case proto::plan::NotEqual: {
|
||||
auto res_vec = std::make_shared<ColumnVector>(
|
||||
DataType::BOOL, batch_size);
|
||||
cached_overflow_res_ = res_vec;
|
||||
bool* res = (bool*)res_vec->GetRawData();
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
res[i] = true;
|
||||
}
|
||||
return res_vec;
|
||||
}
|
||||
default: {
|
||||
PanicInfo(OpTypeInvalid,
|
||||
fmt::format("unsupported range node {}",
|
||||
expr_->op_type_));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Raw-data unary range evaluation: dispatches the operator to the matching
// UnaryElementFunc over each chunk, and supplies a skip-index predicate so
// whole chunks whose min/max cannot satisfy the range are skipped.
template <typename T>
VectorPtr
PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForData() {
    // string_view literals are materialized as std::string.
    typedef std::
        conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
            IndexInnerType;
    // Out-of-range integral literals have a constant answer; short-circuit.
    if (auto res = PreCheckOverflow<T>()) {
        return res;
    }

    auto real_batch_size = GetNextBatchSize();
    if (real_batch_size == 0) {
        return nullptr;
    }
    IndexInnerType val = GetValueFromProto<IndexInnerType>(expr_->val_);
    auto res_vec =
        std::make_shared<ColumnVector>(DataType::BOOL, real_batch_size);
    bool* res = (bool*)res_vec->GetRawData();
    auto expr_type = expr_->op_type_;
    auto execute_sub_batch = [expr_type](const T* data,
                                         const int size,
                                         bool* res,
                                         IndexInnerType val) {
        switch (expr_type) {
            case proto::plan::GreaterThan: {
                UnaryElementFunc<T, proto::plan::GreaterThan> func;
                func(data, size, val, res);
                break;
            }
            case proto::plan::GreaterEqual: {
                UnaryElementFunc<T, proto::plan::GreaterEqual> func;
                func(data, size, val, res);
                break;
            }
            case proto::plan::LessThan: {
                UnaryElementFunc<T, proto::plan::LessThan> func;
                func(data, size, val, res);
                break;
            }
            case proto::plan::LessEqual: {
                UnaryElementFunc<T, proto::plan::LessEqual> func;
                func(data, size, val, res);
                break;
            }
            case proto::plan::Equal: {
                UnaryElementFunc<T, proto::plan::Equal> func;
                func(data, size, val, res);
                break;
            }
            case proto::plan::NotEqual: {
                UnaryElementFunc<T, proto::plan::NotEqual> func;
                func(data, size, val, res);
                break;
            }
            case proto::plan::PrefixMatch: {
                UnaryElementFunc<T, proto::plan::PrefixMatch> func;
                func(data, size, val, res);
                break;
            }
            default:
                PanicInfo(
                    OpTypeInvalid,
                    fmt::format("unsupported operator type for unary expr: {}",
                                expr_type));
        }
    };
    // Lets ProcessDataChunks skip chunks whose skip-index (min/max stats)
    // proves no row can satisfy the range.
    auto skip_index_func = [expr_type, val](const SkipIndex& skip_index,
                                            FieldId field_id,
                                            int64_t chunk_id) {
        return skip_index.CanSkipUnaryRange<T>(
            field_id, chunk_id, expr_type, val);
    };
    int64_t processed_size =
        ProcessDataChunks<T>(execute_sub_batch, skip_index_func, res, val);
    AssertInfo(processed_size == real_batch_size,
               fmt::format("internal error: expr processed rows {} not equal "
                           "expect batch size {}",
                           processed_size,
                           real_batch_size));
    return res_vec;
}
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,220 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fmt/core.h>
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Vector.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "index/Meta.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
#include "query/Utils.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Element-wise evaluator for `src[i] <op> val` over a raw data chunk.
// The comparison operator is a compile-time template parameter, so the
// per-element branch is resolved at instantiation time.
template <typename T, proto::plan::OpType op>
struct UnaryElementFunc {
    // String columns are read as string_view, but the comparison value
    // decoded from the plan proto is an owning std::string.
    typedef std::
        conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
            IndexInnerType;

    // Writes one bool per element into `res`; `res` must hold >= `size`.
    void
    operator()(const T* src, size_t size, IndexInnerType val, bool* res) {
        // size_t index: avoids the signed/unsigned comparison the previous
        // `int i < size` loop performed against the size_t bound.
        for (size_t i = 0; i < size; ++i) {
            if constexpr (op == proto::plan::OpType::Equal) {
                res[i] = src[i] == val;
            } else if constexpr (op == proto::plan::OpType::NotEqual) {
                res[i] = src[i] != val;
            } else if constexpr (op == proto::plan::OpType::GreaterThan) {
                res[i] = src[i] > val;
            } else if constexpr (op == proto::plan::OpType::LessThan) {
                res[i] = src[i] < val;
            } else if constexpr (op == proto::plan::OpType::GreaterEqual) {
                res[i] = src[i] >= val;
            } else if constexpr (op == proto::plan::OpType::LessEqual) {
                res[i] = src[i] <= val;
            } else if constexpr (op == proto::plan::OpType::PrefixMatch) {
                // Only meaningful for string-like T; Match implements it.
                res[i] = milvus::query::Match(
                    src[i], val, proto::plan::OpType::PrefixMatch);
            } else {
                // Unreachable for the op set dispatched by the caller.
                PanicInfo(
                    OpTypeInvalid,
                    fmt::format("unsupported op_type:{} for UnaryElementFunc",
                                op));
            }
        }
    }
};
|
||||
|
||||
#define UnaryArrayCompare(cmp) \
|
||||
do { \
|
||||
if constexpr (std::is_same_v<GetType, proto::plan::Array>) { \
|
||||
res[i] = false; \
|
||||
} else { \
|
||||
if (index >= src[i].length()) { \
|
||||
res[i] = false; \
|
||||
continue; \
|
||||
} \
|
||||
auto array_data = src[i].template get_data<GetType>(index); \
|
||||
res[i] = (cmp); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
template <typename ValueType, proto::plan::OpType op>
|
||||
struct UnaryElementFuncForArray {
|
||||
using GetType = std::conditional_t<std::is_same_v<ValueType, std::string>,
|
||||
std::string_view,
|
||||
ValueType>;
|
||||
void
|
||||
operator()(const ArrayView* src,
|
||||
size_t size,
|
||||
ValueType val,
|
||||
int index,
|
||||
bool* res) {
|
||||
for (int i = 0; i < size; ++i) {
|
||||
if constexpr (op == proto::plan::OpType::Equal) {
|
||||
if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
|
||||
res[i] = src[i].is_same_array(val);
|
||||
} else {
|
||||
if (index >= src[i].length()) {
|
||||
res[i] = false;
|
||||
continue;
|
||||
}
|
||||
auto array_data = src[i].template get_data<GetType>(index);
|
||||
res[i] = array_data == val;
|
||||
}
|
||||
} else if constexpr (op == proto::plan::OpType::NotEqual) {
|
||||
if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
|
||||
res[i] = !src[i].is_same_array(val);
|
||||
} else {
|
||||
if (index >= src[i].length()) {
|
||||
res[i] = false;
|
||||
continue;
|
||||
}
|
||||
auto array_data = src[i].template get_data<GetType>(index);
|
||||
res[i] = array_data != val;
|
||||
}
|
||||
} else if constexpr (op == proto::plan::OpType::GreaterThan) {
|
||||
UnaryArrayCompare(array_data > val);
|
||||
} else if constexpr (op == proto::plan::OpType::LessThan) {
|
||||
UnaryArrayCompare(array_data < val);
|
||||
} else if constexpr (op == proto::plan::OpType::GreaterEqual) {
|
||||
UnaryArrayCompare(array_data >= val);
|
||||
} else if constexpr (op == proto::plan::OpType::LessEqual) {
|
||||
UnaryArrayCompare(array_data <= val);
|
||||
} else if constexpr (op == proto::plan::OpType::PrefixMatch) {
|
||||
UnaryArrayCompare(milvus::query::Match(array_data, val, op));
|
||||
} else {
|
||||
PanicInfo(OpTypeInvalid,
|
||||
fmt::format("unsupported op_type:{} for "
|
||||
"UnaryElementFuncForArray",
|
||||
op));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, proto::plan::OpType op>
|
||||
struct UnaryIndexFunc {
|
||||
typedef std::
|
||||
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
|
||||
IndexInnerType;
|
||||
using Index = index::ScalarIndex<IndexInnerType>;
|
||||
FixedVector<bool>
|
||||
operator()(Index* index, IndexInnerType val) {
|
||||
if constexpr (op == proto::plan::OpType::Equal) {
|
||||
return index->In(1, &val);
|
||||
} else if constexpr (op == proto::plan::OpType::NotEqual) {
|
||||
return index->NotIn(1, &val);
|
||||
} else if constexpr (op == proto::plan::OpType::GreaterThan) {
|
||||
return index->Range(val, OpType::GreaterThan);
|
||||
} else if constexpr (op == proto::plan::OpType::LessThan) {
|
||||
return index->Range(val, OpType::LessThan);
|
||||
} else if constexpr (op == proto::plan::OpType::GreaterEqual) {
|
||||
return index->Range(val, OpType::GreaterEqual);
|
||||
} else if constexpr (op == proto::plan::OpType::LessEqual) {
|
||||
return index->Range(val, OpType::LessEqual);
|
||||
} else if constexpr (op == proto::plan::OpType::PrefixMatch) {
|
||||
auto dataset = std::make_unique<Dataset>();
|
||||
dataset->Set(milvus::index::OPERATOR_TYPE,
|
||||
proto::plan::OpType::PrefixMatch);
|
||||
dataset->Set(milvus::index::PREFIX_VALUE, val);
|
||||
return index->Query(std::move(dataset));
|
||||
} else {
|
||||
PanicInfo(
|
||||
OpTypeInvalid,
|
||||
fmt::format("unsupported op_type:{} for UnaryIndexFunc", op));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Physical executor for a unary range filter (`column <op> value`).
// Depending on what is available for the column, evaluation dispatches to a
// scalar index, raw chunk data, JSON documents, or array fields.
class PhyUnaryRangeFilterExpr : public SegmentExpr {
 public:
    PhyUnaryRangeFilterExpr(
        const std::vector<std::shared_ptr<Expr>>& input,
        const std::shared_ptr<const milvus::expr::UnaryRangeFilterExpr>& expr,
        const std::string& name,
        const segcore::SegmentInternalInterface* segment,
        Timestamp query_timestamp,
        int64_t batch_size)
        // NOTE(review): std::move on a const& parameter still copies;
        // harmless, but the move is a no-op here.
        : SegmentExpr(std::move(input),
                      name,
                      segment,
                      expr->column_.field_id_,
                      query_timestamp,
                      batch_size),
          expr_(expr) {
    }

    // Evaluates the filter over the next batch, writing a bool vector.
    void
    Eval(EvalCtx& context, VectorPtr& result) override;

 private:
    // Dispatches to the index- or data-based implementation for scalar T.
    template <typename T>
    VectorPtr
    ExecRangeVisitorImpl();

    // Evaluates via the column's scalar index.
    template <typename T>
    VectorPtr
    ExecRangeVisitorImplForIndex();

    // Evaluates directly against raw chunk data.
    template <typename T>
    VectorPtr
    ExecRangeVisitorImplForData();

    // Evaluates against a JSON column with a nested path.
    template <typename ExprValueType>
    VectorPtr
    ExecRangeVisitorImplJson();

    // Evaluates against an array column (optionally at a fixed element index).
    template <typename ExprValueType>
    VectorPtr
    ExecRangeVisitorImplArray();

    // Check overflow and cache result for performace
    template <typename T>
    ColumnVectorPtr
    PreCheckOverflow();

 private:
    // The logical expression this operator executes.
    std::shared_ptr<const milvus::expr::UnaryRangeFilterExpr> expr_;
    // Cached result of PreCheckOverflow, reused across batches.
    ColumnVectorPtr cached_overflow_res_{nullptr};
    // Position up to which the overflow pre-check has been performed.
    int64_t overflow_check_pos_{0};
};
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,166 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fmt/core.h>
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Vector.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
#include "query/Utils.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Extracts the ColumnVector carrying an expression result: either the
// result itself, or the first child of a RowVector result. Panics when the
// result has neither shape.
static ColumnVectorPtr
GetColumnVector(const VectorPtr& result) {
    // Plain column result: hand it back directly.
    if (auto col = std::dynamic_pointer_cast<ColumnVector>(result)) {
        return col;
    }
    // Row result: the bitmap lives in the first child column.
    if (auto row = std::dynamic_pointer_cast<RowVector>(result)) {
        auto first = std::dynamic_pointer_cast<ColumnVector>(row->child(0));
        if (first) {
            return first;
        }
        PanicInfo(UnexpectedError,
                  "RowVector result must have a first ColumnVector children");
    }
    PanicInfo(UnexpectedError,
              "expr result must have a ColumnVector or RowVector result");
}
|
||||
|
||||
// Compares a simdjson array against a plan-proto Array, element by element.
// T is either a lazy on-demand array or an already-materialized vector of
// simdjson values. Returns false on any length, decode, or value mismatch.
template <typename T>
bool
CompareTwoJsonArray(T arr1, const proto::plan::Array& arr2) {
    int json_array_length = 0;
    // Lazy on-demand array: count elements via the parser.
    if constexpr (std::is_same_v<
                      T,
                      simdjson::simdjson_result<simdjson::ondemand::array>>) {
        json_array_length = arr1.count_elements();
    }
    // Materialized vector of values: size is known directly.
    if constexpr (std::is_same_v<T,
                                 std::vector<simdjson::simdjson_result<
                                     simdjson::ondemand::value>>>) {
        json_array_length = arr1.size();
    }
    // Different lengths can never be equal.
    if (arr2.array_size() != json_array_length) {
        return false;
    }
    int i = 0;
    for (auto&& it : arr1) {
        // Decode each JSON element with the type the proto element carries;
        // a decode error counts as a mismatch.
        switch (arr2.array(i).val_case()) {
            case proto::plan::GenericValue::kBoolVal: {
                auto val = it.template get<bool>();
                if (val.error() || val.value() != arr2.array(i).bool_val()) {
                    return false;
                }
                break;
            }
            case proto::plan::GenericValue::kInt64Val: {
                auto val = it.template get<int64_t>();
                if (val.error() || val.value() != arr2.array(i).int64_val()) {
                    return false;
                }
                break;
            }
            case proto::plan::GenericValue::kFloatVal: {
                // NOTE(review): exact double comparison — presumably
                // intentional for equality semantics; confirm.
                auto val = it.template get<double>();
                if (val.error() || val.value() != arr2.array(i).float_val()) {
                    return false;
                }
                break;
            }
            case proto::plan::GenericValue::kStringVal: {
                auto val = it.template get<std::string_view>();
                if (val.error() || val.value() != arr2.array(i).string_val()) {
                    return false;
                }
                break;
            }
            default:
                PanicInfo(DataTypeInvalid,
                          fmt::format("unsupported data type {}",
                                      arr2.array(i).val_case()));
        }
        i++;
    }
    return true;
}
|
||||
|
||||
// Extracts a typed value of T from a plan GenericValue, asserting that the
// proto actually carries the matching field. For integer targets narrower
// than int64, `overflowed` is set (and a default T returned) when the
// payload does not fit, instead of silently truncating.
template <typename T>
T
GetValueFromProtoInternal(const milvus::proto::plan::GenericValue& value_proto,
                          bool& overflowed) {
    // bool is tested before is_integral because bool IS an integral type but
    // is carried in a dedicated proto field.
    if constexpr (std::is_same_v<T, bool>) {
        Assert(value_proto.val_case() ==
               milvus::proto::plan::GenericValue::kBoolVal);
        return static_cast<T>(value_proto.bool_val());
    } else if constexpr (std::is_integral_v<T>) {
        Assert(value_proto.val_case() ==
               milvus::proto::plan::GenericValue::kInt64Val);
        auto val = value_proto.int64_val();
        // Reject int64 payloads that do not fit into the narrower target.
        if (milvus::query::out_of_range<T>(val)) {
            overflowed = true;
            return T();
        } else {
            return static_cast<T>(val);
        }
    } else if constexpr (std::is_floating_point_v<T>) {
        Assert(value_proto.val_case() ==
               milvus::proto::plan::GenericValue::kFloatVal);
        return static_cast<T>(value_proto.float_val());
    } else if constexpr (std::is_same_v<T, std::string> ||
                         std::is_same_v<T, std::string_view>) {
        Assert(value_proto.val_case() ==
               milvus::proto::plan::GenericValue::kStringVal);
        return static_cast<T>(value_proto.string_val());
    } else if constexpr (std::is_same_v<T, proto::plan::Array>) {
        Assert(value_proto.val_case() ==
               milvus::proto::plan::GenericValue::kArrayVal);
        return static_cast<T>(value_proto.array_val());
    } else if constexpr (std::is_same_v<T, milvus::proto::plan::GenericValue>) {
        // Pass-through: the caller wants the raw proto value.
        return static_cast<T>(value_proto);
    } else {
        PanicInfo(Unsupported,
                  fmt::format("unsupported generic value {}",
                              value_proto.DebugString()));
    }
}
|
||||
|
||||
template <typename T>
|
||||
T
|
||||
GetValueFromProto(const milvus::proto::plan::GenericValue& value_proto) {
|
||||
bool dummy_overflowed = false;
|
||||
return GetValueFromProtoInternal<T>(value_proto, dummy_overflowed);
|
||||
}
|
||||
|
||||
// Same as GetValueFromProto, but reports integer narrowing through
// `overflowed` (set to true, value defaulted) instead of dropping it.
template <typename T>
T
GetValueFromProtoWithOverflow(
    const milvus::proto::plan::GenericValue& value_proto, bool& overflowed) {
    return GetValueFromProtoInternal<T>(value_proto, overflowed);
}
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,47 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "common/Vector.h"
|
||||
#include "exec/QueryContext.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
// Interface for vectorized scalar functions: consumes a set of argument
// vectors and produces a single result vector.
class VectorFunction {
 public:
    virtual ~VectorFunction() = default;

    // Applies the function to `args` under `context`, writing a vector of
    // `output_type` into `result`.
    virtual void
    Apply(std::vector<VectorPtr>& args,
          DataType output_type,
          EvalCtx& context,
          VectorPtr& result) const = 0;
};

// Looks up a registered vector function by name and argument types.
std::shared_ptr<VectorFunction>
GetVectorFunction(const std::string& name,
                  const std::vector<DataType>& input_types,
                  const QueryConfig& config);
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,89 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "exec/operator/Operator.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
// Terminal operator that forwards every batch to a user-supplied callback.
// The callback may report a blocking reason and fill a future that resolves
// when the consumer is ready for more data.
class CallbackSink : public Operator {
 public:
    CallbackSink(
        int32_t operator_id,
        DriverContext* ctx,
        std::function<BlockingReason(RowVectorPtr, ContinueFuture*)> callback)
        : Operator(ctx, DataType::NONE, operator_id, "N/A", "CallbackSink"),
          callback_(callback) {
    }

    // Hands the batch to the callback and records any blocking it requests.
    void
    AddInput(RowVectorPtr& input) override {
        blocking_reason_ = callback_(input, &future_);
    }

    // Sinks never produce output.
    RowVectorPtr
    GetOutput() override {
        return nullptr;
    }

    // End of stream: notify the callback (with nullptr) and release it.
    void
    NoMoreInput() override {
        Operator::NoMoreInput();
        Close();
    }

    // Accepts input until Close() has released the callback.
    bool
    NeedInput() const override {
        return callback_ != nullptr;
    }

    bool
    IsFilter() override {
        return false;
    }

    bool
    IsFinished() override {
        return no_more_input_;
    }

    // Hands the stored future to the driver and clears the blocked state.
    // NOTE(review): the specific reason returned by the callback is
    // discarded and kWaitForConsumer is always reported — confirm intended.
    BlockingReason
    IsBlocked(ContinueFuture* future) override {
        if (blocking_reason_ != BlockingReason::kNotBlocked) {
            *future = std::move(future_);
            blocking_reason_ = BlockingReason::kNotBlocked;
            return BlockingReason::kWaitForConsumer;
        }
        return BlockingReason::kNotBlocked;
    }

 private:
    // Signals end-of-stream with a null batch, then drops the callback so
    // NeedInput() starts returning false.
    void
    Close() override {
        if (callback_) {
            callback_(nullptr, nullptr);
            callback_ = nullptr;
        }
    }

    ContinueFuture future_;
    BlockingReason blocking_reason_{BlockingReason::kNotBlocked};
    std::function<BlockingReason(RowVectorPtr, ContinueFuture*)> callback_;
};
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,83 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "FilterBits.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
// Builds the filter operator: compiles the plan node's filter expression
// into an ExprSet and records how many active rows the segment holds at the
// query timestamp (the total amount of work to perform).
FilterBits::FilterBits(
    int32_t operator_id,
    DriverContext* driverctx,
    const std::shared_ptr<const plan::FilterBitsNode>& filter)
    : Operator(driverctx,
               filter->output_type(),
               operator_id,
               filter->id(),
               "FilterBits") {
    ExecContext* exec_context = operator_context_->get_exec_context();
    QueryContext* query_context = exec_context->get_query_context();
    std::vector<expr::TypedExprPtr> filters;
    filters.emplace_back(filter->filter());
    // Compile the (single) filter expression once; reused for every batch.
    exprs_ = std::make_unique<ExprSet>(filters, exec_context);
    // Rows visible at the query timestamp == rows that must be filtered.
    need_process_rows_ = query_context->get_segment()->get_active_count(
        query_context->get_query_timestamp());
    num_processed_rows_ = 0;
}
|
||||
|
||||
// Takes ownership of the next input batch; consumed by GetOutput().
void
FilterBits::AddInput(RowVectorPtr& input) {
    input_ = std::move(input);
}
|
||||
|
||||
bool
|
||||
FilterBits::AllInputProcessed() {
|
||||
if (num_processed_rows_ == need_process_rows_) {
|
||||
input_ = nullptr;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Finished exactly when all active rows have been processed.
bool
FilterBits::IsFinished() {
    return AllInputProcessed();
}
|
||||
|
||||
// Evaluates the compiled filter over the current batch and returns the
// result wrapped in a RowVector; returns nullptr once all rows are done.
RowVectorPtr
FilterBits::GetOutput() {
    if (AllInputProcessed()) {
        return nullptr;
    }

    EvalCtx eval_ctx(
        operator_context_->get_exec_context(), exprs_.get(), input_.get());

    // Evaluate the single filter expression; results_ receives one vector.
    exprs_->Eval(0, 1, true, eval_ctx, results_);

    AssertInfo(results_.size() == 1 && results_[0] != nullptr,
               "FilterBits result size should be one and not be nullptr");

    // Advance progress by however many rows this batch covered; a ROW
    // result carries its row count in the first child column.
    auto& filtered = results_[0];
    if (filtered->type() == DataType::ROW) {
        num_processed_rows_ +=
            std::dynamic_pointer_cast<RowVector>(filtered)->child(0)->size();
    } else {
        num_processed_rows_ += filtered->size();
    }
    return std::make_shared<RowVector>(results_);
}
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,74 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "exec/Driver.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "exec/operator/Operator.h"
|
||||
#include "exec/QueryContext.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
// Operator that evaluates a filter expression over a segment, emitting the
// resulting bitmap batches until every active row has been processed.
class FilterBits : public Operator {
 public:
    FilterBits(int32_t operator_id,
               DriverContext* ctx,
               const std::shared_ptr<const plan::FilterBitsNode>& filter);

    bool
    IsFilter() override {
        return true;
    }

    // Ready for a new batch whenever the previous one has been consumed.
    bool
    NeedInput() const override {
        return !input_;
    }

    void
    AddInput(RowVectorPtr& input) override;

    RowVectorPtr
    GetOutput() override;

    bool
    IsFinished() override;

    // Releases operator state and the compiled expressions.
    void
    Close() override {
        Operator::Close();
        exprs_->Clear();
    }

    // Expression evaluation never blocks.
    BlockingReason
    IsBlocked(ContinueFuture* /* unused */) override {
        return BlockingReason::kNotBlocked;
    }

    // True (and drops input_) once all active rows have been filtered.
    bool
    AllInputProcessed();

 private:
    // Compiled filter expression set (holds exactly one expression).
    std::unique_ptr<ExprSet> exprs_;
    // Rows filtered so far.
    int64_t num_processed_rows_;
    // Rows visible at the query timestamp, i.e. the total work to do.
    int64_t need_process_rows_;
};
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,21 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "Operator.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {}
|
||||
} // namespace milvus
|
|
@ -0,0 +1,197 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "common/Types.h"
|
||||
#include "common/Vector.h"
|
||||
#include "exec/Driver.h"
|
||||
#include "exec/Task.h"
|
||||
#include "exec/QueryContext.h"
|
||||
#include "plan/PlanNode.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
class OperatorContext {
|
||||
public:
|
||||
OperatorContext(DriverContext* driverCtx,
|
||||
const plan::PlanNodeId& plannodeid,
|
||||
int32_t operator_id,
|
||||
const std::string& operator_type = "")
|
||||
: driver_context_(driverCtx),
|
||||
plannode_id_(plannodeid),
|
||||
operator_id_(operator_id),
|
||||
operator_type_(operator_type) {
|
||||
}
|
||||
|
||||
ExecContext*
|
||||
get_exec_context() const {
|
||||
if (!exec_context_) {
|
||||
exec_context_ = std::make_unique<ExecContext>(
|
||||
driver_context_->task_->query_context().get());
|
||||
}
|
||||
return exec_context_.get();
|
||||
}
|
||||
|
||||
const std::shared_ptr<Task>&
|
||||
get_task() const {
|
||||
return driver_context_->task_;
|
||||
}
|
||||
|
||||
const std::string&
|
||||
get_task_id() const {
|
||||
return driver_context_->task_->taskid();
|
||||
}
|
||||
|
||||
DriverContext*
|
||||
get_driver_context() const {
|
||||
return driver_context_;
|
||||
}
|
||||
|
||||
const plan::PlanNodeId&
|
||||
get_plannode_id() const {
|
||||
return plannode_id_;
|
||||
}
|
||||
|
||||
const std::string&
|
||||
get_operator_type() const {
|
||||
return operator_type_;
|
||||
}
|
||||
|
||||
const int32_t
|
||||
get_operator_id() const {
|
||||
return operator_id_;
|
||||
}
|
||||
|
||||
private:
|
||||
DriverContext* driver_context_;
|
||||
plan::PlanNodeId plannode_id_;
|
||||
int32_t operator_id_;
|
||||
std::string operator_type_;
|
||||
|
||||
mutable std::unique_ptr<ExecContext> exec_context_;
|
||||
};
|
||||
|
||||
class Operator {
|
||||
public:
|
||||
Operator(DriverContext* ctx,
|
||||
DataType output_type,
|
||||
int32_t operator_id,
|
||||
const std::string& plannode_id,
|
||||
const std::string& operator_type = "")
|
||||
: operator_context_(std::make_unique<OperatorContext>(
|
||||
ctx, plannode_id, operator_id, operator_type)) {
|
||||
}
|
||||
|
||||
virtual ~Operator() = default;
|
||||
|
||||
virtual bool
|
||||
NeedInput() const = 0;
|
||||
|
||||
virtual void
|
||||
AddInput(RowVectorPtr& input) = 0;
|
||||
|
||||
virtual void
|
||||
NoMoreInput() {
|
||||
no_more_input_ = true;
|
||||
}
|
||||
|
||||
virtual RowVectorPtr
|
||||
GetOutput() = 0;
|
||||
|
||||
virtual bool
|
||||
IsFinished() = 0;
|
||||
|
||||
virtual bool
|
||||
IsFilter() = 0;
|
||||
|
||||
virtual BlockingReason
|
||||
IsBlocked(ContinueFuture* future) = 0;
|
||||
|
||||
virtual void
|
||||
Close() {
|
||||
input_ = nullptr;
|
||||
results_.clear();
|
||||
}
|
||||
|
||||
virtual bool
|
||||
PreserveOrder() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
const std::string&
|
||||
get_operator_type() const {
|
||||
return operator_context_->get_operator_type();
|
||||
}
|
||||
|
||||
const int32_t
|
||||
get_operator_id() const {
|
||||
return operator_context_->get_operator_id();
|
||||
}
|
||||
|
||||
const plan::PlanNodeId&
|
||||
get_plannode_id() const {
|
||||
return operator_context_->get_plannode_id();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<OperatorContext> operator_context_;
|
||||
|
||||
DataType output_type_;
|
||||
|
||||
RowVectorPtr input_;
|
||||
|
||||
bool no_more_input_{false};
|
||||
|
||||
std::vector<VectorPtr> results_;
|
||||
};
|
||||
|
||||
// Base for operators that generate data rather than consume it; all
// input-side entry points are disabled and throw if invoked.
class SourceOperator : public Operator {
 public:
    SourceOperator(DriverContext* driver_ctx,
                   DataType out_type,
                   int32_t operator_id,
                   const std::string& plannode_id,
                   const std::string& operator_type)
        : Operator(
              driver_ctx, out_type, operator_id, plannode_id, operator_type) {
    }

    // Sources never accept input.
    bool
    NeedInput() const override {
        return false;
    }

    void
    AddInput(RowVectorPtr& /* unused */) override {
        throw NotImplementedException(
            "SourceOperator does not support addInput()");
    }

    void
    NoMoreInput() override {
        throw NotImplementedException(
            "SourceOperator does not support noMoreInput()");
    }
};
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
|
@ -0,0 +1,557 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <fmt/core.h>
|
||||
|
||||
#include "common/Schema.h"
|
||||
#include "common/Types.h"
|
||||
#include "pb/plan.pb.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace expr {
|
||||
|
||||
struct ColumnInfo {
|
||||
FieldId field_id_;
|
||||
DataType data_type_;
|
||||
std::vector<std::string> nested_path_;
|
||||
|
||||
ColumnInfo(const proto::plan::ColumnInfo& column_info)
|
||||
: field_id_(column_info.field_id()),
|
||||
data_type_(static_cast<DataType>(column_info.data_type())),
|
||||
nested_path_(column_info.nested_path().begin(),
|
||||
column_info.nested_path().end()) {
|
||||
}
|
||||
|
||||
ColumnInfo(FieldId field_id,
|
||||
DataType data_type,
|
||||
std::vector<std::string> nested_path = {})
|
||||
: field_id_(field_id),
|
||||
data_type_(data_type),
|
||||
nested_path_(std::move(nested_path)) {
|
||||
}
|
||||
|
||||
bool
|
||||
operator==(const ColumnInfo& other) {
|
||||
if (field_id_ != other.field_id_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (data_type_ != other.data_type_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < nested_path_.size(); ++i) {
|
||||
if (nested_path_[i] != other.nested_path_[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const {
|
||||
return fmt::format("[FieldId:{}, data_type:{}, nested_path:{}]",
|
||||
std::to_string(field_id_.get()),
|
||||
data_type_,
|
||||
milvus::Join(nested_path_, ","));
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Base class for all exprs
|
||||
* a strongly-typed expression, such as literal, function call, etc...
|
||||
*/
|
||||
class ITypeExpr {
|
||||
public:
|
||||
explicit ITypeExpr(DataType type) : type_(type), inputs_{} {
|
||||
}
|
||||
|
||||
ITypeExpr(DataType type,
|
||||
std::vector<std::shared_ptr<const ITypeExpr>> inputs)
|
||||
: type_(type), inputs_{std::move(inputs)} {
|
||||
}
|
||||
|
||||
virtual ~ITypeExpr() = default;
|
||||
|
||||
const std::vector<std::shared_ptr<const ITypeExpr>>&
|
||||
inputs() const {
|
||||
return inputs_;
|
||||
}
|
||||
|
||||
DataType
|
||||
type() const {
|
||||
return type_;
|
||||
}
|
||||
|
||||
virtual std::string
|
||||
ToString() const = 0;
|
||||
|
||||
const std::vector<std::shared_ptr<const ITypeExpr>>&
|
||||
inputs() {
|
||||
return inputs_;
|
||||
}
|
||||
|
||||
protected:
|
||||
DataType type_;
|
||||
std::vector<std::shared_ptr<const ITypeExpr>> inputs_;
|
||||
};
|
||||
|
||||
using TypedExprPtr = std::shared_ptr<const ITypeExpr>;
|
||||
|
||||
class InputTypeExpr : public ITypeExpr {
|
||||
public:
|
||||
InputTypeExpr(DataType type) : ITypeExpr(type) {
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
return "ROW";
|
||||
}
|
||||
};
|
||||
|
||||
using InputTypeExprPtr = std::shared_ptr<const InputTypeExpr>;
|
||||
|
||||
class CallTypeExpr : public ITypeExpr {
|
||||
public:
|
||||
CallTypeExpr(DataType type,
|
||||
const std::vector<TypedExprPtr>& inputs,
|
||||
std::string fun_name)
|
||||
: ITypeExpr{type, std::move(inputs)} {
|
||||
}
|
||||
|
||||
virtual ~CallTypeExpr() = default;
|
||||
|
||||
virtual const std::string&
|
||||
name() const {
|
||||
return name_;
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
std::string str{};
|
||||
str += name();
|
||||
str += "(";
|
||||
for (size_t i = 0; i < inputs_.size(); ++i) {
|
||||
if (i != 0) {
|
||||
str += ",";
|
||||
}
|
||||
str += inputs_[i]->ToString();
|
||||
}
|
||||
str += ")";
|
||||
return str;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
using CallTypeExprPtr = std::shared_ptr<const CallTypeExpr>;
|
||||
|
||||
class FieldAccessTypeExpr : public ITypeExpr {
|
||||
public:
|
||||
FieldAccessTypeExpr(DataType type, const std::string& name)
|
||||
: ITypeExpr{type}, name_(name), is_input_column_(true) {
|
||||
}
|
||||
|
||||
FieldAccessTypeExpr(DataType type,
|
||||
const TypedExprPtr& input,
|
||||
const std::string& name)
|
||||
: ITypeExpr{type, {std::move(input)}}, name_(name) {
|
||||
is_input_column_ =
|
||||
dynamic_cast<const InputTypeExpr*>(inputs_[0].get()) != nullptr;
|
||||
}
|
||||
|
||||
bool
|
||||
is_input_column() const {
|
||||
return is_input_column_;
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
if (inputs_.empty()) {
|
||||
return fmt::format("{}", name_);
|
||||
}
|
||||
|
||||
return fmt::format("{}[{}]", inputs_[0]->ToString(), name_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
bool is_input_column_;
|
||||
};
|
||||
|
||||
using FieldAccessTypeExprPtr = std::shared_ptr<const FieldAccessTypeExpr>;
|
||||
|
||||
/**
|
||||
* @brief Base class for all milvus filter exprs, output type must be BOOL
|
||||
* a strongly-typed expression, such as literal, function call, etc...
|
||||
*/
|
||||
class ITypeFilterExpr : public ITypeExpr {
|
||||
public:
|
||||
ITypeFilterExpr() : ITypeExpr(DataType::BOOL) {
|
||||
}
|
||||
|
||||
ITypeFilterExpr(std::vector<std::shared_ptr<const ITypeExpr>> inputs)
|
||||
: ITypeExpr(DataType::BOOL, std::move(inputs)) {
|
||||
}
|
||||
|
||||
virtual ~ITypeFilterExpr() = default;
|
||||
};
|
||||
|
||||
class UnaryRangeFilterExpr : public ITypeFilterExpr {
|
||||
public:
|
||||
explicit UnaryRangeFilterExpr(const ColumnInfo& column,
|
||||
proto::plan::OpType op_type,
|
||||
const proto::plan::GenericValue& val)
|
||||
: ITypeFilterExpr(), column_(column), op_type_(op_type), val_(val) {
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
std::stringstream ss;
|
||||
ss << "UnaryRangeFilterExpr: {columnInfo:" << column_.ToString()
|
||||
<< " op_type:" << milvus::proto::plan::OpType_Name(op_type_)
|
||||
<< " val:" << val_.DebugString() << "}";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
public:
|
||||
const ColumnInfo column_;
|
||||
const proto::plan::OpType op_type_;
|
||||
const proto::plan::GenericValue val_;
|
||||
};
|
||||
|
||||
class AlwaysTrueExpr : public ITypeFilterExpr {
|
||||
public:
|
||||
explicit AlwaysTrueExpr() {
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
return "AlwaysTrue expr";
|
||||
}
|
||||
};
|
||||
|
||||
class ExistsExpr : public ITypeFilterExpr {
|
||||
public:
|
||||
explicit ExistsExpr(const ColumnInfo& column)
|
||||
: ITypeFilterExpr(), column_(column) {
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
return "{Exists Expression - Column: " + column_.ToString() + "}";
|
||||
}
|
||||
|
||||
const ColumnInfo column_;
|
||||
};
|
||||
|
||||
class LogicalUnaryExpr : public ITypeFilterExpr {
|
||||
public:
|
||||
enum class OpType { Invalid = 0, LogicalNot = 1 };
|
||||
|
||||
explicit LogicalUnaryExpr(const OpType op_type, const TypedExprPtr& child)
|
||||
: op_type_(op_type) {
|
||||
inputs_.emplace_back(child);
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
std::string opTypeString;
|
||||
|
||||
switch (op_type_) {
|
||||
case OpType::LogicalNot:
|
||||
opTypeString = "Logical NOT";
|
||||
break;
|
||||
default:
|
||||
opTypeString = "Invalid Operator";
|
||||
break;
|
||||
}
|
||||
|
||||
return fmt::format("LogicalUnaryExpr:[{} - Child: {}]",
|
||||
opTypeString,
|
||||
inputs_[0]->ToString());
|
||||
}
|
||||
|
||||
const OpType op_type_;
|
||||
};
|
||||
|
||||
class TermFilterExpr : public ITypeFilterExpr {
|
||||
public:
|
||||
explicit TermFilterExpr(const ColumnInfo& column,
|
||||
const std::vector<proto::plan::GenericValue>& vals,
|
||||
bool is_in_field = false)
|
||||
: ITypeFilterExpr(),
|
||||
column_(column),
|
||||
vals_(vals),
|
||||
is_in_field_(is_in_field) {
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
std::string values;
|
||||
|
||||
for (const auto& val : vals_) {
|
||||
values += val.DebugString() + ", ";
|
||||
}
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "TermFilterExpr:[Column: " << column_.ToString() << ", Values: ["
|
||||
<< values << "]"
|
||||
<< ", Is In Field: " << (is_in_field_ ? "true" : "false") << "]";
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
public:
|
||||
const ColumnInfo column_;
|
||||
const std::vector<proto::plan::GenericValue> vals_;
|
||||
const bool is_in_field_;
|
||||
};
|
||||
|
||||
class LogicalBinaryExpr : public ITypeFilterExpr {
|
||||
public:
|
||||
enum class OpType { Invalid = 0, And = 1, Or = 2 };
|
||||
|
||||
explicit LogicalBinaryExpr(OpType op_type,
|
||||
const TypedExprPtr& left,
|
||||
const TypedExprPtr& right)
|
||||
: ITypeFilterExpr(), op_type_(op_type) {
|
||||
inputs_.emplace_back(left);
|
||||
inputs_.emplace_back(right);
|
||||
}
|
||||
|
||||
std::string
|
||||
GetOpTypeString() const {
|
||||
switch (op_type_) {
|
||||
case OpType::Invalid:
|
||||
return "Invalid";
|
||||
case OpType::And:
|
||||
return "And";
|
||||
case OpType::Or:
|
||||
return "Or";
|
||||
default:
|
||||
return "Unknown"; // Handle the default case if necessary
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
return fmt::format("LogicalBinaryExpr:[{} - Left: {}, Right: {}]",
|
||||
GetOpTypeString(),
|
||||
inputs_[0]->ToString(),
|
||||
inputs_[1]->ToString());
|
||||
}
|
||||
|
||||
std::string
|
||||
name() const {
|
||||
return GetOpTypeString();
|
||||
}
|
||||
|
||||
public:
|
||||
const OpType op_type_;
|
||||
};
|
||||
|
||||
class BinaryRangeFilterExpr : public ITypeFilterExpr {
|
||||
public:
|
||||
BinaryRangeFilterExpr(const ColumnInfo& column,
|
||||
const proto::plan::GenericValue& lower_value,
|
||||
const proto::plan::GenericValue& upper_value,
|
||||
bool lower_inclusive,
|
||||
bool upper_inclusive)
|
||||
: ITypeFilterExpr(),
|
||||
column_(column),
|
||||
lower_val_(lower_value),
|
||||
upper_val_(upper_value),
|
||||
lower_inclusive_(lower_inclusive),
|
||||
upper_inclusive_(upper_inclusive) {
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
std::stringstream ss;
|
||||
ss << "BinaryRangeFilterExpr:[Column: " << column_.ToString()
|
||||
<< ", Lower Value: " << lower_val_.DebugString()
|
||||
<< ", Upper Value: " << upper_val_.DebugString()
|
||||
<< ", Lower Inclusive: " << (lower_inclusive_ ? "true" : "false")
|
||||
<< ", Upper Inclusive: " << (upper_inclusive_ ? "true" : "false")
|
||||
<< "]";
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
const ColumnInfo column_;
|
||||
const proto::plan::GenericValue lower_val_;
|
||||
const proto::plan::GenericValue upper_val_;
|
||||
const bool lower_inclusive_;
|
||||
const bool upper_inclusive_;
|
||||
};
|
||||
|
||||
class BinaryArithOpEvalRangeExpr : public ITypeFilterExpr {
|
||||
public:
|
||||
BinaryArithOpEvalRangeExpr(const ColumnInfo& column,
|
||||
const proto::plan::OpType op_type,
|
||||
const proto::plan::ArithOpType arith_op_type,
|
||||
const proto::plan::GenericValue value,
|
||||
const proto::plan::GenericValue right_operand)
|
||||
: column_(column),
|
||||
op_type_(op_type),
|
||||
arith_op_type_(arith_op_type),
|
||||
right_operand_(right_operand),
|
||||
value_(value) {
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
std::stringstream ss;
|
||||
ss << "BinaryArithOpEvalRangeExpr:[Column: " << column_.ToString()
|
||||
<< ", Operator Type: " << milvus::proto::plan::OpType_Name(op_type_)
|
||||
<< ", Arith Operator Type: "
|
||||
<< milvus::proto::plan::ArithOpType_Name(arith_op_type_)
|
||||
<< ", Value: " << value_.DebugString()
|
||||
<< ", Right Operand: " << right_operand_.DebugString() << "]";
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
public:
|
||||
const ColumnInfo column_;
|
||||
const proto::plan::OpType op_type_;
|
||||
const proto::plan::ArithOpType arith_op_type_;
|
||||
const proto::plan::GenericValue right_operand_;
|
||||
const proto::plan::GenericValue value_;
|
||||
};
|
||||
|
||||
class CompareExpr : public ITypeFilterExpr {
|
||||
public:
|
||||
CompareExpr(const FieldId& left_field,
|
||||
const FieldId& right_field,
|
||||
DataType left_data_type,
|
||||
DataType right_data_type,
|
||||
proto::plan::OpType op_type)
|
||||
: left_field_id_(left_field),
|
||||
right_field_id_(right_field),
|
||||
left_data_type_(left_data_type),
|
||||
right_data_type_(right_data_type),
|
||||
op_type_(op_type) {
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
std::string opTypeString;
|
||||
|
||||
return fmt::format(
|
||||
"CompareExpr:[Left Field ID: {}, Right Field ID: {}, Left Data "
|
||||
"Type: {}, "
|
||||
"Operator: {}, Right "
|
||||
"Data Type: {}]",
|
||||
left_field_id_.get(),
|
||||
right_field_id_.get(),
|
||||
milvus::proto::plan::OpType_Name(op_type_),
|
||||
left_data_type_,
|
||||
right_data_type_);
|
||||
}
|
||||
|
||||
public:
|
||||
const FieldId left_field_id_;
|
||||
const FieldId right_field_id_;
|
||||
const DataType left_data_type_;
|
||||
const DataType right_data_type_;
|
||||
const proto::plan::OpType op_type_;
|
||||
};
|
||||
|
||||
class JsonContainsExpr : public ITypeFilterExpr {
|
||||
public:
|
||||
JsonContainsExpr(ColumnInfo column,
|
||||
ContainsType op,
|
||||
const bool same_type,
|
||||
const std::vector<proto::plan::GenericValue>& vals)
|
||||
: column_(column),
|
||||
op_(op),
|
||||
same_type_(same_type),
|
||||
vals_(std::move(vals)) {
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
std::string values;
|
||||
for (const auto& val : vals_) {
|
||||
values += val.DebugString() + ", ";
|
||||
}
|
||||
return fmt::format(
|
||||
"JsonContainsExpr:[Column: {}, Operator: {}, Same Type: {}, "
|
||||
"Values: [{}]]",
|
||||
column_.ToString(),
|
||||
JSONContainsExpr_JSONOp_Name(op_),
|
||||
(same_type_ ? "true" : "false"),
|
||||
values);
|
||||
}
|
||||
|
||||
public:
|
||||
const ColumnInfo column_;
|
||||
ContainsType op_;
|
||||
bool same_type_;
|
||||
const std::vector<proto::plan::GenericValue> vals_;
|
||||
};
|
||||
} // namespace expr
|
||||
} // namespace milvus
|
||||
|
||||
template <>
|
||||
struct fmt::formatter<milvus::proto::plan::ArithOpType>
|
||||
: formatter<string_view> {
|
||||
auto
|
||||
format(milvus::proto::plan::ArithOpType c, format_context& ctx) const {
|
||||
using namespace milvus::proto::plan;
|
||||
string_view name = "unknown";
|
||||
switch (c) {
|
||||
case ArithOpType::Unknown:
|
||||
name = "Unknown";
|
||||
break;
|
||||
case ArithOpType::Add:
|
||||
name = "Add";
|
||||
break;
|
||||
case ArithOpType::Sub:
|
||||
name = "Sub";
|
||||
break;
|
||||
case ArithOpType::Mul:
|
||||
name = "Mul";
|
||||
break;
|
||||
case ArithOpType::Div:
|
||||
name = "Div";
|
||||
break;
|
||||
case ArithOpType::Mod:
|
||||
name = "Mod";
|
||||
break;
|
||||
case ArithOpType::ArrayLength:
|
||||
name = "ArrayLength";
|
||||
break;
|
||||
case ArithOpType::ArithOpType_INT_MIN_SENTINEL_DO_NOT_USE_:
|
||||
name = "ArithOpType_INT_MIN_SENTINEL_DO_NOT_USE_";
|
||||
break;
|
||||
case ArithOpType::ArithOpType_INT_MAX_SENTINEL_DO_NOT_USE_:
|
||||
name = "ArithOpType_INT_MAX_SENTINEL_DO_NOT_USE_";
|
||||
break;
|
||||
}
|
||||
return formatter<string_view>::format(name, ctx);
|
||||
}
|
||||
};
|
|
@ -13,7 +13,6 @@
|
|||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common/Types.h"
|
||||
|
|
|
@ -68,7 +68,7 @@ ScalarIndexSort<T>::BuildV2(const Config& config) {
|
|||
PanicInfo(S3Error, "failed to create scan iterator");
|
||||
}
|
||||
auto reader = res.value();
|
||||
std::vector<storage::FieldDataPtr> field_datas;
|
||||
std::vector<FieldDataPtr> field_datas;
|
||||
for (auto rec = reader->Next(); rec != nullptr; rec = reader->Next()) {
|
||||
if (!rec.ok()) {
|
||||
PanicInfo(DataFormatBroken, "failed to read data");
|
||||
|
@ -280,7 +280,7 @@ ScalarIndexSort<T>::LoadV2(const Config& config) {
|
|||
index_files.push_back(b.name);
|
||||
}
|
||||
}
|
||||
std::map<std::string, storage::FieldDataPtr> index_datas{};
|
||||
std::map<std::string, FieldDataPtr> index_datas{};
|
||||
for (auto& file_name : index_files) {
|
||||
auto res = space_->GetBlobByteSize(file_name);
|
||||
if (!res.ok()) {
|
||||
|
|
|
@ -24,11 +24,12 @@
|
|||
|
||||
#include "common/Types.h"
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Exception.h"
|
||||
#include "common/Utils.h"
|
||||
#include "common/Slice.h"
|
||||
#include "index/StringIndexMarisa.h"
|
||||
#include "index/Utils.h"
|
||||
#include "index/Index.h"
|
||||
#include "common/Utils.h"
|
||||
#include "common/Slice.h"
|
||||
#include "storage/Util.h"
|
||||
#include "storage/space.h"
|
||||
|
||||
|
@ -73,7 +74,7 @@ StringIndexMarisa::BuildV2(const Config& config) {
|
|||
PanicInfo(S3Error, "failed to create scan iterator");
|
||||
}
|
||||
auto reader = res.value();
|
||||
std::vector<storage::FieldDataPtr> field_datas;
|
||||
std::vector<FieldDataPtr> field_datas;
|
||||
for (auto rec = reader->Next(); rec != nullptr; rec = reader->Next()) {
|
||||
if (!rec.ok()) {
|
||||
PanicInfo(DataFormatBroken, "failed to read data");
|
||||
|
@ -315,7 +316,7 @@ StringIndexMarisa::LoadV2(const Config& config) {
|
|||
index_files.push_back(b.name);
|
||||
}
|
||||
}
|
||||
std::map<std::string, storage::FieldDataPtr> index_datas{};
|
||||
std::map<std::string, FieldDataPtr> index_datas{};
|
||||
for (auto& file_name : index_files) {
|
||||
auto res = space_->GetBlobByteSize(file_name);
|
||||
if (!res.ok()) {
|
||||
|
|
|
@ -24,17 +24,18 @@
|
|||
#include <vector>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <unistd.h>
|
||||
#include <google/protobuf/text_format.h>
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Exception.h"
|
||||
#include "common/File.h"
|
||||
#include "common/FieldData.h"
|
||||
#include "common/Slice.h"
|
||||
#include "index/Utils.h"
|
||||
#include "index/Meta.h"
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <unistd.h>
|
||||
#include "common/EasyAssert.h"
|
||||
#include "knowhere/comp/index_param.h"
|
||||
#include "common/Slice.h"
|
||||
#include "storage/FieldData.h"
|
||||
#include "storage/Util.h"
|
||||
#include "common/File.h"
|
||||
#include "knowhere/comp/index_param.h"
|
||||
|
||||
namespace milvus::index {
|
||||
|
||||
|
@ -205,7 +206,7 @@ ParseConfigFromIndexParams(
|
|||
}
|
||||
|
||||
void
|
||||
AssembleIndexDatas(std::map<std::string, storage::FieldDataPtr>& index_datas) {
|
||||
AssembleIndexDatas(std::map<std::string, FieldDataPtr>& index_datas) {
|
||||
if (index_datas.find(INDEX_FILE_SLICE_META) != index_datas.end()) {
|
||||
auto slice_meta = index_datas.at(INDEX_FILE_SLICE_META);
|
||||
Config meta_data = Config::parse(std::string(
|
||||
|
@ -237,9 +238,8 @@ AssembleIndexDatas(std::map<std::string, storage::FieldDataPtr>& index_datas) {
|
|||
}
|
||||
|
||||
void
|
||||
AssembleIndexDatas(
|
||||
std::map<std::string, storage::FieldDataChannelPtr>& index_datas,
|
||||
std::unordered_map<std::string, storage::FieldDataPtr>& result) {
|
||||
AssembleIndexDatas(std::map<std::string, FieldDataChannelPtr>& index_datas,
|
||||
std::unordered_map<std::string, FieldDataPtr>& result) {
|
||||
if (auto meta_iter = index_datas.find(INDEX_FILE_SLICE_META);
|
||||
meta_iter != index_datas.end()) {
|
||||
auto raw_metadata_array =
|
||||
|
|
|
@ -28,9 +28,9 @@
|
|||
#include <string>
|
||||
|
||||
#include "common/Types.h"
|
||||
#include "common/FieldData.h"
|
||||
#include "index/IndexInfo.h"
|
||||
#include "storage/Types.h"
|
||||
#include "storage/FieldData.h"
|
||||
|
||||
namespace milvus::index {
|
||||
|
||||
|
@ -114,12 +114,11 @@ ParseConfigFromIndexParams(
|
|||
const std::map<std::string, std::string>& index_params);
|
||||
|
||||
void
|
||||
AssembleIndexDatas(std::map<std::string, storage::FieldDataPtr>& index_datas);
|
||||
AssembleIndexDatas(std::map<std::string, FieldDataPtr>& index_datas);
|
||||
|
||||
void
|
||||
AssembleIndexDatas(
|
||||
std::map<std::string, storage::FieldDataChannelPtr>& index_datas,
|
||||
std::unordered_map<std::string, storage::FieldDataPtr>& result);
|
||||
AssembleIndexDatas(std::map<std::string, FieldDataChannelPtr>& index_datas,
|
||||
std::unordered_map<std::string, FieldDataPtr>& result);
|
||||
|
||||
// On Linux, read() (and similar system calls) will transfer at most 0x7ffff000 (2,147,479,552) bytes once
|
||||
void
|
||||
|
|
|
@ -38,20 +38,20 @@
|
|||
#include "knowhere/factory.h"
|
||||
#include "knowhere/comp/time_recorder.h"
|
||||
#include "common/BitsetView.h"
|
||||
#include "common/Slice.h"
|
||||
#include "common/Consts.h"
|
||||
#include "common/FieldData.h"
|
||||
#include "common/File.h"
|
||||
#include "common/Slice.h"
|
||||
#include "common/Tracer.h"
|
||||
#include "common/RangeSearchHelper.h"
|
||||
#include "common/Utils.h"
|
||||
#include "log/Log.h"
|
||||
#include "mmap/Types.h"
|
||||
#include "storage/DataCodec.h"
|
||||
#include "storage/FieldData.h"
|
||||
#include "storage/MemFileManagerImpl.h"
|
||||
#include "storage/ThreadPools.h"
|
||||
#include "storage/Util.h"
|
||||
#include "common/File.h"
|
||||
#include "common/Tracer.h"
|
||||
#include "storage/space.h"
|
||||
#include "storage/Util.h"
|
||||
|
||||
namespace milvus::index {
|
||||
|
||||
|
@ -189,7 +189,7 @@ VectorMemIndex<T>::LoadV2(const Config& config) {
|
|||
|
||||
auto slice_meta_file = index_prefix + "/" + INDEX_FILE_SLICE_META;
|
||||
auto res = space_->GetBlobByteSize(std::string(slice_meta_file));
|
||||
std::map<std::string, storage::FieldDataPtr> index_datas{};
|
||||
std::map<std::string, FieldDataPtr> index_datas{};
|
||||
|
||||
if (!res.ok() && !res.status().IsFileNotFound()) {
|
||||
PanicInfo(DataFormatBroken, "failed to read blob");
|
||||
|
@ -289,7 +289,7 @@ VectorMemIndex<T>::Load(const Config& config) {
|
|||
|
||||
auto parallel_degree =
|
||||
static_cast<uint64_t>(DEFAULT_FIELD_MAX_MEMORY_LIMIT / FILE_SLICE_SIZE);
|
||||
std::map<std::string, storage::FieldDataPtr> index_datas{};
|
||||
std::map<std::string, FieldDataPtr> index_datas{};
|
||||
|
||||
// try to read slice meta first
|
||||
std::string slice_meta_filepath;
|
||||
|
@ -414,7 +414,7 @@ VectorMemIndex<T>::BuildV2(const Config& config) {
|
|||
}
|
||||
|
||||
auto reader = res.value();
|
||||
std::vector<storage::FieldDataPtr> field_datas;
|
||||
std::vector<FieldDataPtr> field_datas;
|
||||
for (auto rec : *reader) {
|
||||
if (!rec.ok()) {
|
||||
PanicInfo(IndexBuildError,
|
||||
|
|
|
@ -53,8 +53,13 @@
|
|||
__FUNCTION__, \
|
||||
GetThreadName().c_str())
|
||||
|
||||
#define LOG_SEGCORE_TRACE_ DLOG(INFO) << SEGCORE_MODULE_FUNCTION
|
||||
#define LOG_SEGCORE_DEBUG_ DLOG(INFO) << SEGCORE_MODULE_FUNCTION
|
||||
// GLOG has no debug and trace level,
|
||||
// Using VLOG to implement it.
|
||||
#define GLOG_DEBUG 5
|
||||
#define GLOG_TRACE 6
|
||||
|
||||
#define LOG_SEGCORE_TRACE_ VLOG(GLOG_TRACE) << SEGCORE_MODULE_FUNCTION
|
||||
#define LOG_SEGCORE_DEBUG_ VLOG(GLOG_DEBUG) << SEGCORE_MODULE_FUNCTION
|
||||
#define LOG_SEGCORE_INFO_ LOG(INFO) << SEGCORE_MODULE_FUNCTION
|
||||
#define LOG_SEGCORE_WARNING_ LOG(WARNING) << SEGCORE_MODULE_FUNCTION
|
||||
#define LOG_SEGCORE_ERROR_ LOG(ERROR) << SEGCORE_MODULE_FUNCTION
|
||||
|
|
|
@ -21,15 +21,15 @@
|
|||
#include <cstring>
|
||||
#include <filesystem>
|
||||
|
||||
#include "common/FieldMeta.h"
|
||||
#include "common/Span.h"
|
||||
#include "common/Array.h"
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/File.h"
|
||||
#include "common/FieldMeta.h"
|
||||
#include "common/FieldData.h"
|
||||
#include "common/Span.h"
|
||||
#include "fmt/format.h"
|
||||
#include "log/Log.h"
|
||||
#include "mmap/Utils.h"
|
||||
#include "storage/FieldData.h"
|
||||
#include "common/Array.h"
|
||||
|
||||
namespace milvus {
|
||||
|
||||
|
@ -156,7 +156,7 @@ class ColumnBase {
|
|||
Span() const = 0;
|
||||
|
||||
void
|
||||
AppendBatch(const storage::FieldDataPtr& data) {
|
||||
AppendBatch(const FieldDataPtr& data) {
|
||||
size_t required_size = size_ + data->Size();
|
||||
if (required_size > cap_size_) {
|
||||
Expand(required_size * 2 + padding_);
|
||||
|
|
|
@ -19,13 +19,13 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "storage/FieldData.h"
|
||||
#include "common/FieldData.h"
|
||||
|
||||
namespace milvus {
|
||||
|
||||
struct FieldDataInfo {
|
||||
FieldDataInfo() {
|
||||
channel = std::make_shared<storage::FieldDataChannel>();
|
||||
channel = std::make_shared<FieldDataChannel>();
|
||||
}
|
||||
|
||||
FieldDataInfo(int64_t field_id,
|
||||
|
@ -34,12 +34,12 @@ struct FieldDataInfo {
|
|||
: field_id(field_id),
|
||||
row_count(row_count),
|
||||
mmap_dir_path(std::move(mmap_dir_path)) {
|
||||
channel = std::make_shared<storage::FieldDataChannel>();
|
||||
channel = std::make_shared<FieldDataChannel>();
|
||||
}
|
||||
|
||||
FieldDataInfo(int64_t field_id,
|
||||
size_t row_count,
|
||||
storage::FieldDataChannelPtr channel)
|
||||
FieldDataChannelPtr channel)
|
||||
: field_id(field_id),
|
||||
row_count(row_count),
|
||||
channel(std::move(channel)) {
|
||||
|
@ -48,7 +48,7 @@ struct FieldDataInfo {
|
|||
FieldDataInfo(int64_t field_id,
|
||||
size_t row_count,
|
||||
std::string mmap_dir_path,
|
||||
storage::FieldDataChannelPtr channel)
|
||||
FieldDataChannelPtr channel)
|
||||
: field_id(field_id),
|
||||
row_count(row_count),
|
||||
mmap_dir_path(std::move(mmap_dir_path)),
|
||||
|
@ -57,9 +57,9 @@ struct FieldDataInfo {
|
|||
|
||||
FieldDataInfo(int64_t field_id,
|
||||
size_t row_count,
|
||||
const std::vector<storage::FieldDataPtr>& batch)
|
||||
const std::vector<FieldDataPtr>& batch)
|
||||
: field_id(field_id), row_count(row_count) {
|
||||
channel = std::make_shared<storage::FieldDataChannel>();
|
||||
channel = std::make_shared<FieldDataChannel>();
|
||||
for (auto& data : batch) {
|
||||
channel->push(data);
|
||||
}
|
||||
|
@ -69,11 +69,11 @@ struct FieldDataInfo {
|
|||
FieldDataInfo(int64_t field_id,
|
||||
size_t row_count,
|
||||
std::string mmap_dir_path,
|
||||
const std::vector<storage::FieldDataPtr>& batch)
|
||||
const std::vector<FieldDataPtr>& batch)
|
||||
: field_id(field_id),
|
||||
row_count(row_count),
|
||||
mmap_dir_path(std::move(mmap_dir_path)) {
|
||||
channel = std::make_shared<storage::FieldDataChannel>();
|
||||
channel = std::make_shared<FieldDataChannel>();
|
||||
for (auto& data : batch) {
|
||||
channel->push(data);
|
||||
}
|
||||
|
@ -83,6 +83,6 @@ struct FieldDataInfo {
|
|||
int64_t field_id;
|
||||
size_t row_count;
|
||||
std::string mmap_dir_path;
|
||||
storage::FieldDataChannelPtr channel;
|
||||
FieldDataChannelPtr channel;
|
||||
};
|
||||
} // namespace milvus
|
||||
|
|
|
@ -32,7 +32,7 @@
|
|||
namespace milvus {
|
||||
|
||||
inline size_t
|
||||
GetDataSize(const std::vector<storage::FieldDataPtr>& datas) {
|
||||
GetDataSize(const std::vector<FieldDataPtr>& datas) {
|
||||
size_t total_size{0};
|
||||
for (auto data : datas) {
|
||||
total_size += data->Size();
|
||||
|
@ -42,7 +42,7 @@ GetDataSize(const std::vector<storage::FieldDataPtr>& datas) {
|
|||
}
|
||||
|
||||
inline void*
|
||||
FillField(DataType data_type, const storage::FieldDataPtr data, void* dst) {
|
||||
FillField(DataType data_type, const FieldDataPtr data, void* dst) {
|
||||
char* dest = reinterpret_cast<char*>(dst);
|
||||
if (datatype_is_variable(data_type)) {
|
||||
switch (data_type) {
|
||||
|
@ -80,7 +80,7 @@ FillField(DataType data_type, const storage::FieldDataPtr data, void* dst) {
|
|||
inline size_t
|
||||
WriteFieldData(File& file,
|
||||
DataType data_type,
|
||||
const storage::FieldDataPtr& data,
|
||||
const FieldDataPtr& data,
|
||||
std::vector<std::vector<uint64_t>>& element_indices) {
|
||||
size_t total_written{0};
|
||||
if (datatype_is_variable(data_type)) {
|
||||
|
|
|
@ -0,0 +1,287 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "common/Types.h"
|
||||
#include "common/Vector.h"
|
||||
#include "expr/ITypeExpr.h"
|
||||
#include "common/EasyAssert.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace plan {
|
||||
|
||||
typedef std::string PlanNodeId;
|
||||
/**
|
||||
* @brief Base class for all logic plan node
|
||||
*
|
||||
*/
|
||||
/**
 * @brief Base class for all logic plan nodes. A plan is a DAG of PlanNode
 * instances; each node exposes its output type and upstream sources.
 */
class PlanNode {
 public:
    explicit PlanNode(const PlanNodeId& id) : id_(id) {
    }

    virtual ~PlanNode() = default;

    // Unique identifier of this node within the plan.
    const PlanNodeId&
    id() const {
        return id_;
    }

    // Data type of the rows/values this node produces.
    virtual DataType
    output_type() const = 0;

    // Upstream nodes feeding this one; empty for leaf (source) nodes.
    virtual std::vector<std::shared_ptr<PlanNode>>
    sources() const = 0;

    // Whether the node requires input splits; defaults to false so only
    // split-aware nodes need to override.
    virtual bool
    RequireSplits() const {
        return false;
    }

    // Human-readable representation for logging/debugging.
    virtual std::string
    ToString() const = 0;

    // Short stable name of the node kind (e.g. "Filter").
    virtual std::string_view
    name() const = 0;

 private:
    PlanNodeId id_;
};
|
||||
|
||||
using PlanNodePtr = std::shared_ptr<PlanNode>;
|
||||
|
||||
// Leaf plan node wrapping a segment: the scan source of a query plan.
// Produces whole rows (DataType::ROW) and has no upstream sources.
class SegmentNode : public PlanNode {
 public:
    SegmentNode(
        const PlanNodeId& id,
        const std::shared_ptr<milvus::segcore::SegmentInternalInterface>&
            segment)
        : PlanNode(id), segment_(segment) {
    }

    DataType
    output_type() const override {
        return DataType::ROW;
    }

    // Leaf node: no upstream sources.
    std::vector<std::shared_ptr<PlanNode>>
    sources() const override {
        return {};
    }

    std::string_view
    name() const override {
        return "SegmentNode";
    }

    std::string
    ToString() const override {
        return "SegmentNode";
    }

 private:
    std::shared_ptr<milvus::segcore::SegmentInternalInterface> segment_;
};
|
||||
|
||||
class ValuesNode : public PlanNode {
|
||||
public:
|
||||
ValuesNode(const PlanNodeId& id,
|
||||
const std::vector<RowVectorPtr>& values,
|
||||
bool parallelizeable = false)
|
||||
: PlanNode(id),
|
||||
values_{std::move(values)},
|
||||
output_type_(values[0]->type()) {
|
||||
AssertInfo(!values.empty(), "ValueNode must has value");
|
||||
}
|
||||
|
||||
ValuesNode(const PlanNodeId& id,
|
||||
std::vector<RowVectorPtr>&& values,
|
||||
bool parallelizeable = false)
|
||||
: PlanNode(id),
|
||||
values_{std::move(values)},
|
||||
output_type_(values[0]->type()) {
|
||||
AssertInfo(!values.empty(), "ValueNode must has value");
|
||||
}
|
||||
|
||||
DataType
|
||||
output_type() const override {
|
||||
return output_type_;
|
||||
}
|
||||
|
||||
const std::vector<RowVectorPtr>&
|
||||
values() const {
|
||||
return values_;
|
||||
}
|
||||
|
||||
std::vector<PlanNodePtr>
|
||||
sources() const override {
|
||||
return {};
|
||||
}
|
||||
|
||||
bool
|
||||
parallelizable() {
|
||||
return parallelizable_;
|
||||
}
|
||||
|
||||
std::string_view
|
||||
name() const override {
|
||||
return "Values";
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
return "Values";
|
||||
}
|
||||
|
||||
private:
|
||||
DataType output_type_;
|
||||
const std::vector<RowVectorPtr> values_;
|
||||
bool parallelizable_;
|
||||
};
|
||||
|
||||
class FilterNode : public PlanNode {
|
||||
public:
|
||||
FilterNode(const PlanNodeId& id,
|
||||
expr::TypedExprPtr filter,
|
||||
std::vector<PlanNodePtr> sources)
|
||||
: PlanNode(id),
|
||||
sources_{std::move(sources)},
|
||||
filter_(std::move(filter)) {
|
||||
AssertInfo(
|
||||
filter_->type() == DataType::BOOL,
|
||||
fmt::format("Filter expression must be of type BOOLEAN, Got {}",
|
||||
filter_->type()));
|
||||
}
|
||||
|
||||
DataType
|
||||
output_type() const override {
|
||||
return sources_[0]->output_type();
|
||||
}
|
||||
|
||||
std::vector<PlanNodePtr>
|
||||
sources() const override {
|
||||
return sources_;
|
||||
}
|
||||
|
||||
const expr::TypedExprPtr&
|
||||
filter() const {
|
||||
return filter_;
|
||||
}
|
||||
|
||||
std::string_view
|
||||
name() const override {
|
||||
return "Filter";
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
return "";
|
||||
}
|
||||
|
||||
private:
|
||||
const std::vector<PlanNodePtr> sources_;
|
||||
const expr::TypedExprPtr filter_;
|
||||
};
|
||||
|
||||
// Filter node used by segcore: evaluates a boolean expression over a segment
// and emits the result as a bitmap (one BOOL per row) rather than filtered
// rows. Typically the leaf of an expression plan, hence sources defaults to
// empty.
class FilterBitsNode : public PlanNode {
 public:
    FilterBitsNode(
        const PlanNodeId& id,
        expr::TypedExprPtr filter,
        std::vector<PlanNodePtr> sources = std::vector<PlanNodePtr>{})
        : PlanNode(id),
          sources_{std::move(sources)},
          filter_(std::move(filter)) {
        // The expression must produce a boolean per row.
        AssertInfo(
            filter_->type() == DataType::BOOL,
            fmt::format("Filter expression must be of type BOOLEAN, Got {}",
                        filter_->type()));
    }

    // Output is the filter bitmap itself, not the filtered rows.
    DataType
    output_type() const override {
        return DataType::BOOL;
    }

    std::vector<PlanNodePtr>
    sources() const override {
        return sources_;
    }

    const expr::TypedExprPtr&
    filter() const {
        return filter_;
    }

    std::string_view
    name() const override {
        return "FilterBits";
    }

    std::string
    ToString() const override {
        return fmt::format("FilterBitsNode:[filter_expr:{}]",
                           filter_->ToString());
    }

 private:
    const std::vector<PlanNodePtr> sources_;
    const expr::TypedExprPtr filter_;
};
|
||||
|
||||
// How task drivers consume input splits during execution.
enum class ExecutionStrategy {
    // Process splits as they come in any available driver.
    kUngrouped,
    // Process splits from each split group only in one driver.
    // It is used when split groups represent separate partitions of the data on
    // the grouping keys or join keys. In that case it is sufficient to keep only
    // the keys from a single split group in a hash table used by group-by or
    // join.
    kGrouped,
};
|
||||
// A rooted sub-tree of a plan handed to a Task for execution, together with
// the strategy describing how its splits are distributed across drivers.
struct PlanFragment {
    // Root node of this fragment's plan tree.
    std::shared_ptr<const PlanNode> plan_node_;
    // Default: any driver may pick up any split.
    ExecutionStrategy execution_strategy_{ExecutionStrategy::kUngrouped};
    // Number of split groups; only meaningful under grouped execution.
    int32_t num_splitgroups_{0};

    PlanFragment() = default;

    inline bool
    IsGroupedExecution() const {
        return execution_strategy_ == ExecutionStrategy::kGrouped;
    }

    explicit PlanFragment(std::shared_ptr<const PlanNode> top_node,
                          ExecutionStrategy strategy,
                          int32_t num_splitgroups)
        : plan_node_(std::move(top_node)),
          execution_strategy_(strategy),
          num_splitgroups_(num_splitgroups) {
    }

    explicit PlanFragment(std::shared_ptr<const PlanNode> top_node)
        : plan_node_(std::move(top_node)) {
    }
};
|
||||
|
||||
} // namespace plan
|
||||
} // namespace milvus
|
|
@ -20,10 +20,12 @@
|
|||
#include "common/QueryInfo.h"
|
||||
#include "query/Expr.h"
|
||||
|
||||
namespace milvus::plan {
|
||||
class PlanNode;
|
||||
};
|
||||
namespace milvus::query {
|
||||
|
||||
class PlanNodeVisitor;
|
||||
|
||||
// Base of all Nodes
|
||||
struct PlanNode {
|
||||
public:
|
||||
|
@ -36,6 +38,7 @@ using PlanNodePtr = std::unique_ptr<PlanNode>;
|
|||
|
||||
struct VectorPlanNode : PlanNode {
|
||||
std::optional<ExprPtr> predicate_;
|
||||
std::optional<std::shared_ptr<milvus::plan::PlanNode>> filter_plannode_;
|
||||
SearchInfo search_info_;
|
||||
std::string placeholder_tag_;
|
||||
};
|
||||
|
@ -64,6 +67,7 @@ struct RetrievePlanNode : PlanNode {
|
|||
accept(PlanNodeVisitor&) override;
|
||||
|
||||
std::optional<ExprPtr> predicate_;
|
||||
std::optional<std::shared_ptr<milvus::plan::PlanNode>> filter_plannode_;
|
||||
bool is_count_;
|
||||
int64_t limit_;
|
||||
};
|
||||
|
|
|
@ -185,6 +185,12 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
|
|||
}
|
||||
}();
|
||||
|
||||
auto expr_parser = [&]() -> plan::PlanNodePtr {
|
||||
auto expr = ParseExprs(anns_proto.predicates());
|
||||
return std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID,
|
||||
expr);
|
||||
};
|
||||
|
||||
auto& query_info_proto = anns_proto.query_info();
|
||||
|
||||
SearchInfo search_info;
|
||||
|
@ -210,6 +216,9 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
|
|||
}();
|
||||
plan_node->placeholder_tag_ = anns_proto.placeholder_tag();
|
||||
plan_node->predicate_ = std::move(expr_opt);
|
||||
if (anns_proto.has_predicates()) {
|
||||
plan_node->filter_plannode_ = std::move(expr_parser());
|
||||
}
|
||||
plan_node->search_info_ = std::move(search_info);
|
||||
return plan_node;
|
||||
}
|
||||
|
@ -227,7 +236,13 @@ ProtoParser::RetrievePlanNodeFromProto(
|
|||
auto expr_opt = [&]() -> ExprPtr {
|
||||
return ParseExpr(predicate_proto);
|
||||
}();
|
||||
auto expr_parser = [&]() -> plan::PlanNodePtr {
|
||||
auto expr = ParseExprs(predicate_proto);
|
||||
return std::make_shared<plan::FilterBitsNode>(
|
||||
DEFAULT_PLANNODE_ID, expr);
|
||||
}();
|
||||
node->predicate_ = std::move(expr_opt);
|
||||
node->filter_plannode_ = std::move(expr_parser);
|
||||
} else {
|
||||
auto& query = plan_node_proto.query();
|
||||
if (query.has_predicates()) {
|
||||
|
@ -235,7 +250,13 @@ ProtoParser::RetrievePlanNodeFromProto(
|
|||
auto expr_opt = [&]() -> ExprPtr {
|
||||
return ParseExpr(predicate_proto);
|
||||
}();
|
||||
auto expr_parser = [&]() -> plan::PlanNodePtr {
|
||||
auto expr = ParseExprs(predicate_proto);
|
||||
return std::make_shared<plan::FilterBitsNode>(
|
||||
DEFAULT_PLANNODE_ID, expr);
|
||||
}();
|
||||
node->predicate_ = std::move(expr_opt);
|
||||
node->filter_plannode_ = std::move(expr_parser);
|
||||
}
|
||||
node->is_count_ = query.is_count();
|
||||
node->limit_ = query.limit();
|
||||
|
@ -284,6 +305,16 @@ ProtoParser::CreateRetrievePlan(const proto::plan::PlanNode& plan_node_proto) {
|
|||
return retrieve_plan;
|
||||
}
|
||||
|
||||
// Builds a new-framework unary-range filter (col OP value) from its proto
// form, checking that the proto's declared column type matches the schema.
expr::TypedExprPtr
ProtoParser::ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb) {
    const auto& col_pb = expr_pb.column_info();
    const auto fid = FieldId(col_pb.field_id());
    const auto schema_type = schema[fid].get_data_type();
    Assert(schema_type == static_cast<DataType>(col_pb.data_type()));
    return std::make_shared<milvus::expr::UnaryRangeFilterExpr>(
        expr::ColumnInfo(col_pb), expr_pb.op(), expr_pb.value());
}
|
||||
|
||||
ExprPtr
|
||||
ProtoParser::ParseUnaryRangeExpr(const proto::plan::UnaryRangeExpr& expr_pb) {
|
||||
auto& column_info = expr_pb.column_info();
|
||||
|
@ -352,6 +383,21 @@ ProtoParser::ParseUnaryRangeExpr(const proto::plan::UnaryRangeExpr& expr_pb) {
|
|||
return result;
|
||||
}
|
||||
|
||||
// Builds a new-framework binary-range filter (lower OP col OP upper) from
// proto, verifying the declared column type against the schema.
expr::TypedExprPtr
ProtoParser::ParseBinaryRangeExprs(
    const proto::plan::BinaryRangeExpr& expr_pb) {
    auto& column_info = expr_pb.column_info();
    auto field_id = FieldId(column_info.field_id());
    auto data_type = schema[field_id].get_data_type();
    // static_cast instead of a C-style cast, consistent with the sibling
    // Parse*Exprs helpers.
    Assert(data_type == static_cast<DataType>(column_info.data_type()));
    return std::make_shared<expr::BinaryRangeFilterExpr>(
        column_info,
        expr_pb.lower_value(),
        expr_pb.upper_value(),
        expr_pb.lower_inclusive(),
        expr_pb.upper_inclusive());
}
|
||||
|
||||
ExprPtr
|
||||
ProtoParser::ParseBinaryRangeExpr(const proto::plan::BinaryRangeExpr& expr_pb) {
|
||||
auto& columnInfo = expr_pb.column_info();
|
||||
|
@ -436,6 +482,27 @@ ProtoParser::ParseBinaryRangeExpr(const proto::plan::BinaryRangeExpr& expr_pb) {
|
|||
return result;
|
||||
}
|
||||
|
||||
// Builds a new-framework column-vs-column compare expression from proto,
// verifying both columns' declared types against the schema.
expr::TypedExprPtr
ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) {
    // Check a column's declared proto type against the schema and return
    // its (field id, data type) pair.
    auto verify_column = [this](const auto& info) {
        auto fid = FieldId(info.field_id());
        auto dtype = schema[fid].get_data_type();
        Assert(dtype == static_cast<DataType>(info.data_type()));
        return std::make_pair(fid, dtype);
    };

    auto [left_field_id, left_data_type] =
        verify_column(expr_pb.left_column_info());
    auto [right_field_id, right_data_type] =
        verify_column(expr_pb.right_column_info());

    return std::make_shared<expr::CompareExpr>(left_field_id,
                                               right_field_id,
                                               left_data_type,
                                               right_data_type,
                                               expr_pb.op());
}
|
||||
|
||||
ExprPtr
|
||||
ProtoParser::ParseCompareExpr(const proto::plan::CompareExpr& expr_pb) {
|
||||
auto& left_column_info = expr_pb.left_column_info();
|
||||
|
@ -461,6 +528,20 @@ ProtoParser::ParseCompareExpr(const proto::plan::CompareExpr& expr_pb) {
|
|||
}();
|
||||
}
|
||||
|
||||
// Builds a new-framework term (IN-list) filter from proto, verifying the
// declared column type against the schema.
expr::TypedExprPtr
ProtoParser::ParseTermExprs(const proto::plan::TermExpr& expr_pb) {
    auto& column_info = expr_pb.column_info();
    auto field_id = FieldId(column_info.field_id());
    auto data_type = schema[field_id].get_data_type();
    // static_cast instead of a C-style cast, consistent with siblings.
    Assert(data_type == static_cast<DataType>(column_info.data_type()));
    std::vector<::milvus::proto::plan::GenericValue> values;
    // values_size() returns int (protobuf API); use int for the index to
    // avoid the signed/unsigned mix, and reserve to avoid reallocations.
    values.reserve(expr_pb.values_size());
    for (int i = 0; i < expr_pb.values_size(); i++) {
        values.emplace_back(expr_pb.values(i));
    }
    return std::make_shared<expr::TermFilterExpr>(
        column_info, values, expr_pb.is_in_field());
}
|
||||
|
||||
ExprPtr
|
||||
ProtoParser::ParseTermExpr(const proto::plan::TermExpr& expr_pb) {
|
||||
auto& columnInfo = expr_pb.column_info();
|
||||
|
@ -568,6 +649,14 @@ ProtoParser::ParseUnaryExpr(const proto::plan::UnaryExpr& expr_pb) {
|
|||
return std::make_unique<LogicalUnaryExpr>(op, expr);
|
||||
}
|
||||
|
||||
// Parses a logical unary expression into the new framework. LogicalNot is
// the only unary op supported here; anything else trips the assert.
expr::TypedExprPtr
ProtoParser::ParseUnaryExprs(const proto::plan::UnaryExpr& expr_pb) {
    auto op = static_cast<expr::LogicalUnaryExpr::OpType>(expr_pb.op());
    Assert(op == expr::LogicalUnaryExpr::OpType::LogicalNot);
    // Recursively convert the operand.
    auto child_expr = this->ParseExprs(expr_pb.child());
    return std::make_shared<expr::LogicalUnaryExpr>(op, child_expr);
}
|
||||
|
||||
ExprPtr
|
||||
ProtoParser::ParseBinaryExpr(const proto::plan::BinaryExpr& expr_pb) {
|
||||
auto op = static_cast<LogicalBinaryExpr::OpType>(expr_pb.op());
|
||||
|
@ -576,6 +665,14 @@ ProtoParser::ParseBinaryExpr(const proto::plan::BinaryExpr& expr_pb) {
|
|||
return std::make_unique<LogicalBinaryExpr>(op, left_expr, right_expr);
|
||||
}
|
||||
|
||||
// Parses a logical binary expression (AND/OR) into the new framework,
// recursively converting both operands.
expr::TypedExprPtr
ProtoParser::ParseBinaryExprs(const proto::plan::BinaryExpr& expr_pb) {
    auto op = static_cast<expr::LogicalBinaryExpr::OpType>(expr_pb.op());
    auto left_expr = this->ParseExprs(expr_pb.left());
    auto right_expr = this->ParseExprs(expr_pb.right());
    return std::make_shared<expr::LogicalBinaryExpr>(op, left_expr, right_expr);
}
|
||||
|
||||
ExprPtr
|
||||
ProtoParser::ParseBinaryArithOpEvalRangeExpr(
|
||||
const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) {
|
||||
|
@ -642,11 +739,35 @@ ProtoParser::ParseBinaryArithOpEvalRangeExpr(
|
|||
return result;
|
||||
}
|
||||
|
||||
// Parses `(col ARITH_OP right_operand) CMP value` into the new framework,
// verifying the column's declared type against the schema.
expr::TypedExprPtr
ProtoParser::ParseBinaryArithOpEvalRangeExprs(
    const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) {
    auto& column_info = expr_pb.column_info();
    auto field_id = FieldId(column_info.field_id());
    auto data_type = schema[field_id].get_data_type();
    Assert(data_type == static_cast<DataType>(column_info.data_type()));
    return std::make_shared<expr::BinaryArithOpEvalRangeExpr>(
        column_info,
        expr_pb.op(),
        expr_pb.arith_op(),
        expr_pb.value(),
        expr_pb.right_operand());
}
|
||||
|
||||
// Wraps a proto ExistsExpr into the legacy (old-framework) ExistsExprImpl.
std::unique_ptr<ExistsExprImpl>
ExtractExistsExprImpl(const proto::plan::ExistsExpr& expr_proto) {
    return std::make_unique<ExistsExprImpl>(expr_proto.info());
}
|
||||
|
||||
// Parses a JSON `exists` predicate into the new framework, verifying the
// column's declared type against the schema.
expr::TypedExprPtr
ProtoParser::ParseExistExprs(const proto::plan::ExistsExpr& expr_pb) {
    auto& column_info = expr_pb.info();
    auto field_id = FieldId(column_info.field_id());
    auto data_type = schema[field_id].get_data_type();
    Assert(data_type == static_cast<DataType>(column_info.data_type()));
    return std::make_shared<expr::ExistsExpr>(column_info);
}
|
||||
|
||||
ExprPtr
|
||||
ProtoParser::ParseExistExpr(const proto::plan::ExistsExpr& expr_pb) {
|
||||
auto& column_info = expr_pb.info();
|
||||
|
@ -718,6 +839,24 @@ ExtractJsonContainsExprImpl(const proto::plan::JSONContainsExpr& expr_proto) {
|
|||
val_case);
|
||||
}
|
||||
|
||||
// Builds a new-framework json_contains/json_contains_all/json_contains_any
// filter from proto, verifying the declared column type against the schema.
expr::TypedExprPtr
ProtoParser::ParseJsonContainsExprs(
    const proto::plan::JSONContainsExpr& expr_pb) {
    auto& column_info = expr_pb.column_info();
    auto field_id = FieldId(column_info.field_id());
    auto data_type = schema[field_id].get_data_type();
    // static_cast instead of a C-style cast, consistent with siblings.
    Assert(data_type == static_cast<DataType>(column_info.data_type()));
    std::vector<::milvus::proto::plan::GenericValue> values;
    // elements_size() returns int (protobuf API); use int for the index to
    // avoid the signed/unsigned mix, and reserve to avoid reallocations.
    values.reserve(expr_pb.elements_size());
    for (int i = 0; i < expr_pb.elements_size(); i++) {
        values.emplace_back(expr_pb.elements(i));
    }
    return std::make_shared<expr::JsonContainsExpr>(
        column_info,
        expr_pb.op(),
        expr_pb.elements_same_type(),
        std::move(values));
}
|
||||
|
||||
ExprPtr
|
||||
ProtoParser::ParseJsonContainsExpr(
|
||||
const proto::plan::JSONContainsExpr& expr_pb) {
|
||||
|
@ -755,6 +894,55 @@ ProtoParser::ParseJsonContainsExpr(
|
|||
return result;
|
||||
}
|
||||
|
||||
// Returns a constant-true expression, used when a plan node carries no
// predicate of its own.
expr::TypedExprPtr
ProtoParser::CreateAlwaysTrueExprs() {
    return std::make_shared<expr::AlwaysTrueExpr>();
}
|
||||
|
||||
// Dispatch entry point: converts any proto expression node into the new
// (TypedExprPtr) expression framework, recursing through the per-kind
// Parse*Exprs helpers. Panics with the offending proto text for expression
// kinds not (yet) ported to the new framework.
expr::TypedExprPtr
ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb) {
    using ppe = proto::plan::Expr;
    switch (expr_pb.expr_case()) {
        case ppe::kUnaryRangeExpr: {
            return ParseUnaryRangeExprs(expr_pb.unary_range_expr());
        }
        case ppe::kBinaryExpr: {
            return ParseBinaryExprs(expr_pb.binary_expr());
        }
        case ppe::kUnaryExpr: {
            return ParseUnaryExprs(expr_pb.unary_expr());
        }
        case ppe::kTermExpr: {
            return ParseTermExprs(expr_pb.term_expr());
        }
        case ppe::kBinaryRangeExpr: {
            return ParseBinaryRangeExprs(expr_pb.binary_range_expr());
        }
        case ppe::kCompareExpr: {
            return ParseCompareExprs(expr_pb.compare_expr());
        }
        case ppe::kBinaryArithOpEvalRangeExpr: {
            return ParseBinaryArithOpEvalRangeExprs(
                expr_pb.binary_arith_op_eval_range_expr());
        }
        case ppe::kExistsExpr: {
            return ParseExistExprs(expr_pb.exists_expr());
        }
        case ppe::kAlwaysTrueExpr: {
            return CreateAlwaysTrueExprs();
        }
        case ppe::kJsonContainsExpr: {
            return ParseJsonContainsExprs(expr_pb.json_contains_expr());
        }
        default: {
            // Dump the unsupported proto in text form for diagnosis.
            std::string s;
            google::protobuf::TextFormat::PrintToString(expr_pb, &s);
            PanicInfo(ExprInvalid,
                      std::string("unsupported expr proto node: ") + s);
        }
    }
}
|
||||
|
||||
ExprPtr
|
||||
ProtoParser::ParseExpr(const proto::plan::Expr& expr_pb) {
|
||||
using ppe = proto::plan::Expr;
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "PlanNode.h"
|
||||
#include "common/Schema.h"
|
||||
#include "pb/plan.pb.h"
|
||||
#include "plan/PlanNode.h"
|
||||
|
||||
namespace milvus::query {
|
||||
|
||||
|
@ -72,6 +73,40 @@ class ProtoParser {
|
|||
std::unique_ptr<RetrievePlan>
|
||||
CreateRetrievePlan(const proto::plan::PlanNode& plan_node_proto);
|
||||
|
||||
expr::TypedExprPtr
|
||||
ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb);
|
||||
|
||||
expr::TypedExprPtr
|
||||
ParseExprs(const proto::plan::Expr& expr_pb);
|
||||
|
||||
expr::TypedExprPtr
|
||||
ParseBinaryArithOpEvalRangeExprs(
|
||||
const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb);
|
||||
|
||||
expr::TypedExprPtr
|
||||
ParseBinaryRangeExprs(const proto::plan::BinaryRangeExpr& expr_pb);
|
||||
|
||||
expr::TypedExprPtr
|
||||
ParseCompareExprs(const proto::plan::CompareExpr& expr_pb);
|
||||
|
||||
expr::TypedExprPtr
|
||||
ParseTermExprs(const proto::plan::TermExpr& expr_pb);
|
||||
|
||||
expr::TypedExprPtr
|
||||
ParseUnaryExprs(const proto::plan::UnaryExpr& expr_pb);
|
||||
|
||||
expr::TypedExprPtr
|
||||
ParseBinaryExprs(const proto::plan::BinaryExpr& expr_pb);
|
||||
|
||||
expr::TypedExprPtr
|
||||
ParseExistExprs(const proto::plan::ExistsExpr& expr_pb);
|
||||
|
||||
expr::TypedExprPtr
|
||||
ParseJsonContainsExprs(const proto::plan::JSONContainsExpr& expr_pb);
|
||||
|
||||
expr::TypedExprPtr
|
||||
CreateAlwaysTrueExprs();
|
||||
|
||||
private:
|
||||
const Schema& schema;
|
||||
};
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "query/Expr.h"
|
||||
#include "common/Utils.h"
|
||||
#include "simd/hook.h"
|
||||
|
||||
namespace milvus::query {
|
||||
|
||||
|
@ -70,4 +71,61 @@ inline bool
|
|||
out_of_range(int64_t t) {
|
||||
return gt_ub<T>(t) || lt_lb<T>(t);
|
||||
}
|
||||
|
||||
inline void
|
||||
AppendOneChunk(BitsetType& result, const bool* chunk_ptr, size_t chunk_len) {
|
||||
// Append a value once instead of BITSET_BLOCK_BIT_SIZE times.
|
||||
auto AppendBlock = [&result](const bool* ptr, int n) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
#if defined(USE_DYNAMIC_SIMD)
|
||||
auto val = milvus::simd::get_bitset_block(ptr);
|
||||
#else
|
||||
BitsetBlockType val = 0;
|
||||
// This can use CPU SIMD optimzation
|
||||
uint8_t vals[BITSET_BLOCK_SIZE] = {0};
|
||||
for (size_t j = 0; j < 8; ++j) {
|
||||
for (size_t k = 0; k < BITSET_BLOCK_SIZE; ++k) {
|
||||
vals[k] |= uint8_t(*(ptr + k * 8 + j)) << j;
|
||||
}
|
||||
}
|
||||
for (size_t j = 0; j < BITSET_BLOCK_SIZE; ++j) {
|
||||
val |= BitsetBlockType(vals[j]) << (8 * j);
|
||||
}
|
||||
#endif
|
||||
result.append(val);
|
||||
ptr += BITSET_BLOCK_SIZE * 8;
|
||||
}
|
||||
};
|
||||
// Append bit for these bits that can not be union as a block
|
||||
// Usually n less than BITSET_BLOCK_BIT_SIZE.
|
||||
auto AppendBit = [&result](const bool* ptr, int n) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
bool bit = *ptr++;
|
||||
result.push_back(bit);
|
||||
}
|
||||
};
|
||||
|
||||
size_t res_len = result.size();
|
||||
|
||||
int n_prefix =
|
||||
res_len % BITSET_BLOCK_BIT_SIZE == 0
|
||||
? 0
|
||||
: std::min(BITSET_BLOCK_BIT_SIZE - res_len % BITSET_BLOCK_BIT_SIZE,
|
||||
chunk_len);
|
||||
|
||||
AppendBit(chunk_ptr, n_prefix);
|
||||
|
||||
if (n_prefix == chunk_len)
|
||||
return;
|
||||
|
||||
size_t n_block = (chunk_len - n_prefix) / BITSET_BLOCK_BIT_SIZE;
|
||||
size_t n_suffix = (chunk_len - n_prefix) % BITSET_BLOCK_BIT_SIZE;
|
||||
|
||||
AppendBlock(chunk_ptr + n_prefix, n_block);
|
||||
|
||||
AppendBit(chunk_ptr + n_prefix + n_block * BITSET_BLOCK_BIT_SIZE, n_suffix);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "PlanNodeVisitor.h"
|
||||
|
||||
namespace milvus::query {
|
||||
|
||||
class ExecPlanNodeVisitor : public PlanNodeVisitor {
|
||||
public:
|
||||
void
|
||||
|
@ -96,6 +97,24 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor {
|
|||
return expr_use_pk_index_;
|
||||
}
|
||||
|
||||
void
|
||||
ExecuteExprNodeInternal(
|
||||
const std::shared_ptr<milvus::plan::PlanNode>& plannode,
|
||||
const milvus::segcore::SegmentInternalInterface* segment,
|
||||
BitsetType& result,
|
||||
bool& cache_offset_getted,
|
||||
std::vector<int64_t>& cache_offset);
|
||||
|
||||
void
|
||||
ExecuteExprNode(const std::shared_ptr<milvus::plan::PlanNode>& plannode,
|
||||
const milvus::segcore::SegmentInternalInterface* segment,
|
||||
BitsetType& result) {
|
||||
bool get_cache_offset;
|
||||
std::vector<int64_t> cache_offsets;
|
||||
ExecuteExprNodeInternal(
|
||||
plannode, segment, result, get_cache_offset, cache_offsets);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename VectorType>
|
||||
void
|
||||
|
|
|
@ -16,9 +16,12 @@
|
|||
#include "query/PlanImpl.h"
|
||||
#include "query/SubSearchResult.h"
|
||||
#include "query/generated/ExecExprVisitor.h"
|
||||
#include "query/Utils.h"
|
||||
#include "segcore/SegmentGrowing.h"
|
||||
#include "common/Json.h"
|
||||
#include "log/Log.h"
|
||||
#include "plan/PlanNode.h"
|
||||
#include "exec/Task.h"
|
||||
|
||||
namespace milvus::query {
|
||||
|
||||
|
@ -73,6 +76,63 @@ empty_search_result(int64_t num_queries, SearchInfo& search_info) {
|
|||
return final_result;
|
||||
}
|
||||
|
||||
// Runs an expression plan over `segment` and accumulates the boolean results
// of every output batch into `bitset_holder`.
//
// Batches come in two shapes:
//   - ColumnVector: the filter bitmap only.
//   - RowVector: child(0) is the bitmap, child(1) carries cached row offsets.
//     Offsets are harvested only from the first such batch; the caller must
//     initialize `cache_offset_getted` to false, and it is set to true here
//     once offsets have been collected.
void
ExecPlanNodeVisitor::ExecuteExprNodeInternal(
    const std::shared_ptr<milvus::plan::PlanNode>& plannode,
    const milvus::segcore::SegmentInternalInterface* segment,
    BitsetType& bitset_holder,
    bool& cache_offset_getted,
    std::vector<int64_t>& cache_offset) {
    // Start from an empty bitmap; batches are appended below.
    bitset_holder.clear();
    LOG_SEGCORE_INFO_ << "plannode:" << plannode->ToString();
    auto plan = plan::PlanFragment(plannode);
    // TODO: get query id from proxy
    auto query_context = std::make_shared<milvus::exec::QueryContext>(
        DEAFULT_QUERY_ID, segment, timestamp_);

    auto task =
        milvus::exec::Task::Create(DEFAULT_TASK_ID, plan, 0, query_context);
    // Drain the task batch by batch until it signals completion (nullptr).
    for (;;) {
        auto result = task->Next();
        if (!result) {
            break;
        }
        auto childrens = result->childrens();
        AssertInfo(childrens.size() == 1,
                   "expr result vector's children size not equal one");
        LOG_SEGCORE_DEBUG_ << "output result length:" << childrens[0]->size()
                           << std::endl;
        if (auto vec = std::dynamic_pointer_cast<ColumnVector>(childrens[0])) {
            // Plain bitmap batch: append directly.
            AppendOneChunk(bitset_holder,
                           static_cast<bool*>(vec->GetRawData()),
                           vec->size());
        } else if (auto row =
                       std::dynamic_pointer_cast<RowVector>(childrens[0])) {
            // Bitmap plus offset-cache batch.
            auto bit_vec =
                std::dynamic_pointer_cast<ColumnVector>(row->child(0));
            AppendOneChunk(bitset_holder,
                           static_cast<bool*>(bit_vec->GetRawData()),
                           bit_vec->size());
            if (!cache_offset_getted) {
                // offset cache only get once because not support iterator batch
                auto cache_offset_vec =
                    std::dynamic_pointer_cast<ColumnVector>(row->child(1));
                auto cache_offset_vec_ptr =
                    (int64_t*)(cache_offset_vec->GetRawData());
                for (size_t i = 0; i < cache_offset_vec->size(); ++i) {
                    cache_offset.push_back(cache_offset_vec_ptr[i]);
                }
                cache_offset_getted = true;
            }
        } else {
            PanicInfo(UnexpectedError, "expr return type not matched");
        }
    }
    // std::string s;
    // boost::to_string(*bitset_holder, s);
    // std::cout << bitset_holder->size() << " . " << s << std::endl;
}
|
||||
|
||||
template <typename VectorType>
|
||||
void
|
||||
ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
|
||||
|
@ -98,10 +158,10 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
|
|||
}
|
||||
|
||||
std::unique_ptr<BitsetType> bitset_holder;
|
||||
if (node.predicate_.has_value()) {
|
||||
bitset_holder = std::make_unique<BitsetType>(
|
||||
ExecExprVisitor(*segment, this, active_count, timestamp_)
|
||||
.call_child(*node.predicate_.value()));
|
||||
if (node.filter_plannode_.has_value()) {
|
||||
BitsetType expr_res;
|
||||
ExecuteExprNode(node.filter_plannode_.value(), segment, expr_res);
|
||||
bitset_holder = std::make_unique<BitsetType>(expr_res);
|
||||
bitset_holder->flip();
|
||||
} else {
|
||||
bitset_holder = std::make_unique<BitsetType>(active_count, false);
|
||||
|
@ -165,10 +225,16 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
|
|||
bitset_holder.resize(active_count);
|
||||
}
|
||||
|
||||
if (node.predicate_.has_value() && node.predicate_.value() != nullptr) {
|
||||
bitset_holder =
|
||||
ExecExprVisitor(*segment, this, active_count, timestamp_)
|
||||
.call_child(*(node.predicate_.value()));
|
||||
// This flag used to indicate whether to get offset from expr module that
|
||||
// speeds up mvcc filter in the next interface: "timestamp_filter"
|
||||
bool get_cache_offset = false;
|
||||
std::vector<int64_t> cache_offsets;
|
||||
if (node.filter_plannode_.has_value()) {
|
||||
ExecuteExprNodeInternal(node.filter_plannode_.value(),
|
||||
segment,
|
||||
bitset_holder,
|
||||
get_cache_offset,
|
||||
cache_offsets);
|
||||
bitset_holder.flip();
|
||||
}
|
||||
|
||||
|
@ -189,9 +255,8 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
|
|||
}
|
||||
|
||||
bool false_filtered_out = false;
|
||||
if (GetExprUsePkIndex() && IsTermExpr(node.predicate_.value().get())) {
|
||||
segment->timestamp_filter(
|
||||
bitset_holder, expr_cached_pk_id_offsets_, timestamp_);
|
||||
if (get_cache_offset) {
|
||||
segment->timestamp_filter(bitset_holder, cache_offsets, timestamp_);
|
||||
} else {
|
||||
bitset_holder.flip();
|
||||
false_filtered_out = true;
|
||||
|
|
|
@ -42,6 +42,6 @@ set(SEGCORE_FILES
|
|||
SkipIndex.cpp)
|
||||
add_library(milvus_segcore SHARED ${SEGCORE_FILES})
|
||||
|
||||
target_link_libraries(milvus_segcore milvus_query ${OpenMP_CXX_FLAGS} milvus-storage)
|
||||
target_link_libraries(milvus_segcore milvus_query milvus_exec ${OpenMP_CXX_FLAGS} milvus-storage)
|
||||
|
||||
install(TARGETS milvus_segcore DESTINATION "${CMAKE_INSTALL_LIBDIR}")
|
||||
|
|
|
@ -25,13 +25,13 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/FieldMeta.h"
|
||||
#include "common/FieldData.h"
|
||||
#include "common/Json.h"
|
||||
#include "common/Span.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Utils.h"
|
||||
#include "common/EasyAssert.h"
|
||||
#include "storage/FieldData.h"
|
||||
|
||||
namespace milvus::segcore {
|
||||
|
||||
|
@ -103,7 +103,7 @@ class VectorBase {
|
|||
|
||||
virtual void
|
||||
set_data_raw(ssize_t element_offset,
|
||||
const std::vector<storage::FieldDataPtr>& data) = 0;
|
||||
const std::vector<FieldDataPtr>& data) = 0;
|
||||
|
||||
void
|
||||
set_data_raw(ssize_t element_offset,
|
||||
|
@ -112,7 +112,7 @@ class VectorBase {
|
|||
const FieldMeta& field_meta);
|
||||
|
||||
virtual void
|
||||
fill_chunk_data(const std::vector<storage::FieldDataPtr>& data) = 0;
|
||||
fill_chunk_data(const std::vector<FieldDataPtr>& data) = 0;
|
||||
|
||||
virtual SpanBase
|
||||
get_span_base(int64_t chunk_id) const = 0;
|
||||
|
@ -197,7 +197,7 @@ class ConcurrentVectorImpl : public VectorBase {
|
|||
}
|
||||
|
||||
void
|
||||
fill_chunk_data(const std::vector<storage::FieldDataPtr>& datas)
|
||||
fill_chunk_data(const std::vector<FieldDataPtr>& datas)
|
||||
override { // used only for sealed segment
|
||||
AssertInfo(chunks_.size() == 0, "no empty concurrent vector");
|
||||
|
||||
|
@ -217,7 +217,7 @@ class ConcurrentVectorImpl : public VectorBase {
|
|||
|
||||
void
|
||||
set_data_raw(ssize_t element_offset,
|
||||
const std::vector<storage::FieldDataPtr>& datas) override {
|
||||
const std::vector<FieldDataPtr>& datas) override {
|
||||
for (auto& field_data : datas) {
|
||||
auto num_rows = field_data->get_num_rows();
|
||||
set_data_raw(element_offset, field_data->Data(), num_rows);
|
||||
|
|
|
@ -306,7 +306,7 @@ class IndexingRecord {
|
|||
AppendingIndex(int64_t reserved_offset,
|
||||
int64_t size,
|
||||
FieldId fieldId,
|
||||
const storage::FieldDataPtr data,
|
||||
const FieldDataPtr data,
|
||||
const InsertRecord<is_sealed>& record) {
|
||||
if (is_in(fieldId)) {
|
||||
auto& indexing = field_indexings_.at(fieldId);
|
||||
|
|
|
@ -424,7 +424,7 @@ struct InsertRecord {
|
|||
}
|
||||
|
||||
void
|
||||
insert_pks(const std::vector<storage::FieldDataPtr>& field_datas) {
|
||||
insert_pks(const std::vector<FieldDataPtr>& field_datas) {
|
||||
std::lock_guard lck(shared_mutex_);
|
||||
int64_t offset = 0;
|
||||
for (auto& data : field_datas) {
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "common/Consts.h"
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/FieldData.h"
|
||||
#include "common/Types.h"
|
||||
#include "fmt/format.h"
|
||||
#include "log/Log.h"
|
||||
|
@ -29,7 +30,6 @@
|
|||
#include "query/SearchOnSealed.h"
|
||||
#include "segcore/SegmentGrowingImpl.h"
|
||||
#include "segcore/Utils.h"
|
||||
#include "storage/FieldData.h"
|
||||
#include "storage/RemoteChunkManagerSingleton.h"
|
||||
#include "storage/Util.h"
|
||||
#include "storage/ThreadPools.h"
|
||||
|
@ -58,8 +58,12 @@ SegmentGrowingImpl::mask_with_delete(BitsetType& bitset,
|
|||
return;
|
||||
}
|
||||
auto& delete_bitset = *bitmap_holder->bitmap_ptr;
|
||||
AssertInfo(delete_bitset.size() == bitset.size(),
|
||||
"Deleted bitmap size not equal to filtered bitmap size");
|
||||
AssertInfo(
|
||||
delete_bitset.size() == bitset.size(),
|
||||
fmt::format(
|
||||
"Deleted bitmap size:{} not equal to filtered bitmap size:{}",
|
||||
delete_bitset.size(),
|
||||
bitset.size()));
|
||||
bitset |= delete_bitset;
|
||||
}
|
||||
|
||||
|
@ -177,12 +181,12 @@ SegmentGrowingImpl::LoadFieldData(const LoadFieldDataInfo& infos) {
|
|||
for (auto& [id, info] : infos.field_infos) {
|
||||
auto field_id = FieldId(id);
|
||||
auto insert_files = info.insert_files;
|
||||
auto channel = std::make_shared<storage::FieldDataChannel>();
|
||||
auto channel = std::make_shared<FieldDataChannel>();
|
||||
auto& pool =
|
||||
ThreadPools::GetThreadPool(milvus::ThreadPoolPriority::MIDDLE);
|
||||
auto load_future =
|
||||
pool.Submit(LoadFieldDatasFromRemote, insert_files, channel);
|
||||
auto field_data = CollectFieldDataChannel(channel);
|
||||
auto field_data = storage::CollectFieldDataChannel(channel);
|
||||
if (field_id == TimestampFieldID) {
|
||||
// step 2: sort timestamp
|
||||
// query node already guarantees that the timestamp is ordered, avoid field data copy in c++
|
||||
|
@ -263,7 +267,8 @@ SegmentGrowingImpl::LoadFieldDataV2(const LoadFieldDataInfo& infos) {
|
|||
std::shared_ptr<milvus_storage::Space> space = std::move(res.value());
|
||||
auto load_future = pool.Submit(
|
||||
LoadFieldDatasFromRemote2, space, schema_, field_data_info);
|
||||
auto field_data = CollectFieldDataChannel(field_data_info.channel);
|
||||
auto field_data =
|
||||
milvus::storage::CollectFieldDataChannel(field_data_info.channel);
|
||||
if (field_id == TimestampFieldID) {
|
||||
// step 2: sort timestamp
|
||||
// query node already guarantees that the timestamp is ordered, avoid field data copy in c++
|
||||
|
|
|
@ -235,7 +235,13 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
|||
|
||||
bool
|
||||
HasIndex(FieldId field_id) const override {
|
||||
return true;
|
||||
auto& field_meta = schema_->operator[](field_id);
|
||||
if (datatype_is_vector(field_meta.get_data_type()) &&
|
||||
indexing_record_.SyncDataWithIndex(field_id)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool
|
||||
|
|
|
@ -47,6 +47,7 @@ class SegmentSealed : public SegmentInternalInterface {
|
|||
}
|
||||
};
|
||||
|
||||
using SegmentSealedPtr = std::unique_ptr<SegmentSealed>;
|
||||
using SegmentSealedSPtr = std::shared_ptr<SegmentSealed>;
|
||||
using SegmentSealedUPtr = std::unique_ptr<SegmentSealed>;
|
||||
|
||||
} // namespace milvus::segcore
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
#include "mmap/Column.h"
|
||||
#include "common/Consts.h"
|
||||
#include "common/FieldMeta.h"
|
||||
#include "common/FieldData.h"
|
||||
#include "common/Types.h"
|
||||
#include "log/Log.h"
|
||||
#include "pb/schema.pb.h"
|
||||
|
@ -40,7 +41,6 @@
|
|||
#include "query/ScalarIndex.h"
|
||||
#include "query/SearchBruteForce.h"
|
||||
#include "query/SearchOnSealed.h"
|
||||
#include "storage/FieldData.h"
|
||||
#include "storage/Util.h"
|
||||
#include "storage/ThreadPools.h"
|
||||
#include "storage/ChunkCacheSingleton.h"
|
||||
|
@ -279,7 +279,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
|
|||
if (system_field_type == SystemFieldType::Timestamp) {
|
||||
std::vector<Timestamp> timestamps(num_rows);
|
||||
int64_t offset = 0;
|
||||
auto field_data = CollectFieldDataChannel(data.channel);
|
||||
auto field_data = storage::CollectFieldDataChannel(data.channel);
|
||||
for (auto& data : field_data) {
|
||||
int64_t row_count = data->get_num_rows();
|
||||
std::copy_n(static_cast<const Timestamp*>(data->Data()),
|
||||
|
@ -307,7 +307,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
|
|||
AssertInfo(system_field_type == SystemFieldType::RowId,
|
||||
"System field type of id column is not RowId");
|
||||
|
||||
auto field_data = CollectFieldDataChannel(data.channel);
|
||||
auto field_data = storage::CollectFieldDataChannel(data.channel);
|
||||
|
||||
// write data under lock
|
||||
std::unique_lock lck(mutex_);
|
||||
|
@ -335,7 +335,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
|
|||
auto var_column =
|
||||
std::make_shared<VariableColumn<std::string>>(
|
||||
num_rows, field_meta);
|
||||
storage::FieldDataPtr field_data;
|
||||
FieldDataPtr field_data;
|
||||
while (data.channel->pop(field_data)) {
|
||||
for (auto i = 0; i < field_data->get_num_rows(); i++) {
|
||||
auto str = static_cast<const std::string*>(
|
||||
|
@ -354,7 +354,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
|
|||
auto var_column =
|
||||
std::make_shared<VariableColumn<milvus::Json>>(
|
||||
num_rows, field_meta);
|
||||
storage::FieldDataPtr field_data;
|
||||
FieldDataPtr field_data;
|
||||
while (data.channel->pop(field_data)) {
|
||||
for (auto i = 0; i < field_data->get_num_rows(); i++) {
|
||||
auto padded_string =
|
||||
|
@ -374,7 +374,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
|
|||
case milvus::DataType::ARRAY: {
|
||||
auto var_column =
|
||||
std::make_shared<ArrayColumn>(num_rows, field_meta);
|
||||
storage::FieldDataPtr field_data;
|
||||
FieldDataPtr field_data;
|
||||
while (data.channel->pop(field_data)) {
|
||||
for (auto i = 0; i < field_data->get_num_rows(); i++) {
|
||||
auto rawValue = field_data->RawValue(i);
|
||||
|
@ -398,7 +398,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
|
|||
field_id, num_rows, field_data_size);
|
||||
} else {
|
||||
column = std::make_shared<Column>(num_rows, field_meta);
|
||||
storage::FieldDataPtr field_data;
|
||||
FieldDataPtr field_data;
|
||||
while (data.channel->pop(field_data)) {
|
||||
column->AppendBatch(field_data);
|
||||
}
|
||||
|
@ -469,7 +469,7 @@ SegmentSealedImpl::MapFieldData(const FieldId field_id, FieldDataInfo& data) {
|
|||
auto data_size = 0;
|
||||
std::vector<uint64_t> indices{};
|
||||
std::vector<std::vector<uint64_t>> element_indices{};
|
||||
storage::FieldDataPtr field_data;
|
||||
FieldDataPtr field_data;
|
||||
while (data.channel->pop(field_data)) {
|
||||
data_size += field_data->Size();
|
||||
auto written =
|
||||
|
@ -669,8 +669,12 @@ SegmentSealedImpl::mask_with_delete(BitsetType& bitset,
|
|||
return;
|
||||
}
|
||||
auto& delete_bitset = *bitmap_holder->bitmap_ptr;
|
||||
AssertInfo(delete_bitset.size() == bitset.size(),
|
||||
"Deleted bitmap size not equal to filtered bitmap size");
|
||||
AssertInfo(
|
||||
delete_bitset.size() == bitset.size(),
|
||||
fmt::format(
|
||||
"Deleted bitmap size:{} not equal to filtered bitmap size:{}",
|
||||
delete_bitset.size(),
|
||||
bitset.size()));
|
||||
bitset |= delete_bitset;
|
||||
}
|
||||
|
||||
|
|
|
@ -299,7 +299,7 @@ class SegmentSealedImpl : public SegmentSealed {
|
|||
vec_binlog_config_;
|
||||
};
|
||||
|
||||
inline SegmentSealedPtr
|
||||
inline SegmentSealedUPtr
|
||||
CreateSealedSegment(
|
||||
SchemaPtr schema,
|
||||
IndexMetaPtr index_meta = nullptr,
|
||||
|
|
|
@ -14,14 +14,14 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "common/Common.h"
|
||||
#include "common/FieldData.h"
|
||||
#include "index/ScalarIndex.h"
|
||||
#include "log/Log.h"
|
||||
#include "storage/FieldData.h"
|
||||
#include "storage/RemoteChunkManagerSingleton.h"
|
||||
#include "common/Common.h"
|
||||
#include "storage/ThreadPool.h"
|
||||
#include "storage/Util.h"
|
||||
#include "mmap/Utils.h"
|
||||
#include "storage/ThreadPool.h"
|
||||
#include "storage/RemoteChunkManagerSingleton.h"
|
||||
#include "storage/Util.h"
|
||||
|
||||
namespace milvus::segcore {
|
||||
|
||||
|
@ -50,7 +50,7 @@ ParsePksFromFieldData(std::vector<PkType>& pks, const DataArray& data) {
|
|||
void
|
||||
ParsePksFromFieldData(DataType data_type,
|
||||
std::vector<PkType>& pks,
|
||||
const std::vector<storage::FieldDataPtr>& datas) {
|
||||
const std::vector<FieldDataPtr>& datas) {
|
||||
int64_t offset = 0;
|
||||
|
||||
for (auto& field_data : datas) {
|
||||
|
@ -737,7 +737,7 @@ LoadFieldDatasFromRemote2(std::shared_ptr<milvus_storage::Space> space,
|
|||
// segcore use default remote chunk manager to load data from minio/s3
|
||||
void
|
||||
LoadFieldDatasFromRemote(std::vector<std::string>& remote_files,
|
||||
storage::FieldDataChannelPtr channel) {
|
||||
FieldDataChannelPtr channel) {
|
||||
try {
|
||||
auto parallel_degree = static_cast<uint64_t>(
|
||||
DEFAULT_FIELD_MAX_MEMORY_LIMIT / FILE_SLICE_SIZE);
|
||||
|
|
|
@ -20,13 +20,13 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "common/FieldData.h"
|
||||
#include "common/QueryResult.h"
|
||||
// #include "common/Schema.h"
|
||||
#include "common/Types.h"
|
||||
#include "index/Index.h"
|
||||
#include "segcore/DeletedRecord.h"
|
||||
#include "segcore/InsertRecord.h"
|
||||
#include "index/Index.h"
|
||||
#include "storage/FieldData.h"
|
||||
#include "storage/space.h"
|
||||
|
||||
namespace milvus::segcore {
|
||||
|
@ -37,7 +37,7 @@ ParsePksFromFieldData(std::vector<PkType>& pks, const DataArray& data);
|
|||
void
|
||||
ParsePksFromFieldData(DataType data_type,
|
||||
std::vector<PkType>& pks,
|
||||
const std::vector<storage::FieldDataPtr>& datas);
|
||||
const std::vector<FieldDataPtr>& datas);
|
||||
|
||||
void
|
||||
ParsePksFromIDs(std::vector<PkType>& pks,
|
||||
|
@ -159,7 +159,7 @@ ReverseDataFromIndex(const index::IndexBase* index,
|
|||
|
||||
void
|
||||
LoadFieldDatasFromRemote(std::vector<std::string>& remote_files,
|
||||
storage::FieldDataChannelPtr channel);
|
||||
FieldDataChannelPtr channel);
|
||||
|
||||
void
|
||||
LoadFieldDatasFromRemote2(std::shared_ptr<milvus_storage::Space> space,
|
||||
|
|
|
@ -10,21 +10,22 @@
|
|||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "segcore/segment_c.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "common/FieldData.h"
|
||||
#include "common/LoadInfo.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Tracer.h"
|
||||
#include "common/type_c.h"
|
||||
#include "google/protobuf/text_format.h"
|
||||
#include "log/Log.h"
|
||||
#include "mmap/Types.h"
|
||||
#include "segcore/Collection.h"
|
||||
#include "segcore/SegmentGrowingImpl.h"
|
||||
#include "segcore/SegmentSealedImpl.h"
|
||||
#include "segcore/Utils.h"
|
||||
#include "storage/FieldData.h"
|
||||
#include "storage/Util.h"
|
||||
#include "mmap/Types.h"
|
||||
#include "storage/space.h"
|
||||
|
||||
////////////////////////////// common interfaces //////////////////////////////
|
||||
|
@ -292,8 +293,8 @@ LoadFieldRawData(CSegmentInterface c_segment,
|
|||
}
|
||||
auto field_data = milvus::storage::CreateFieldData(data_type, dim);
|
||||
field_data->FillFieldData(data, row_count);
|
||||
milvus::storage::FieldDataChannelPtr channel =
|
||||
std::make_shared<milvus::storage::FieldDataChannel>();
|
||||
milvus::FieldDataChannelPtr channel =
|
||||
std::make_shared<milvus::FieldDataChannel>();
|
||||
channel->push(field_data);
|
||||
channel->close();
|
||||
auto field_data_info = milvus::FieldDataInfo(
|
||||
|
|
|
@ -28,6 +28,10 @@ if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64")
|
|||
set_source_files_properties(avx512.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512dq -mavx512bw")
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm*")
|
||||
# TODO: add arm cpu simd
|
||||
message ("simd using arm mode")
|
||||
list(APPEND MILVUS_SIMD_SRCS
|
||||
neon.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
add_library(milvus_simd ${MILVUS_SIMD_SRCS})
|
||||
|
|
|
@ -39,7 +39,7 @@ GetBitsetBlockAVX2(const bool* src) {
|
|||
BitsetBlockType res[4];
|
||||
_mm256_storeu_si256((__m256i*)res, tmpvec);
|
||||
return res[0];
|
||||
// __m128i tmpvec = _mm_loadu_si64(tmp);
|
||||
// __m256i tmpvec = _mm_loadu_si64(tmp);
|
||||
// BitsetBlockType res;
|
||||
// _mm_storeu_si64(&res, tmpvec);
|
||||
// return res;
|
||||
|
@ -231,6 +231,80 @@ FindTermAVX2(const double* src, size_t vec_size, double val) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool
|
||||
AllFalseAVX2(const bool* src, int64_t size) {
|
||||
int num_chunk = size / 32;
|
||||
__m256i highbit = _mm256_set1_epi8(0x7F);
|
||||
for (size_t i = 0; i < num_chunk * 16; i += 16) {
|
||||
__m256i data =
|
||||
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
|
||||
__m256i highbits = _mm256_add_epi8(data, highbit);
|
||||
if (_mm256_movemask_epi8(highbits) != 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = num_chunk * 16; i < size; ++i) {
|
||||
if (src[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
AllTrueAVX2(const bool* src, int64_t size) {
|
||||
int num_chunk = size / 16;
|
||||
__m256i highbit = _mm256_set1_epi8(0x7F);
|
||||
for (size_t i = 0; i < num_chunk * 16; i += 16) {
|
||||
__m256i data =
|
||||
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
|
||||
__m256i highbits = _mm256_add_epi8(data, highbit);
|
||||
if (_mm256_movemask_epi8(highbits) != 0xFFFF) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = num_chunk * 16; i < size; ++i) {
|
||||
if (!src[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void
|
||||
AndBoolAVX2(bool* left, bool* right, int64_t size) {
|
||||
int num_chunk = size / 32;
|
||||
for (size_t i = 0; i < num_chunk * 32; i += 32) {
|
||||
__m256i l_reg =
|
||||
_mm256_loadu_si256(reinterpret_cast<__m256i*>(left + i));
|
||||
__m256i r_reg =
|
||||
_mm256_loadu_si256(reinterpret_cast<__m256i*>(right + i));
|
||||
__m256i res = _mm256_and_si256(l_reg, r_reg);
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(left + i), res);
|
||||
}
|
||||
for (size_t i = num_chunk * 32; i < size; ++i) {
|
||||
left[i] &= right[i];
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
OrBoolAVX2(bool* left, bool* right, int64_t size) {
|
||||
int num_chunk = size / 32;
|
||||
for (size_t i = 0; i < num_chunk * 32; i += 32) {
|
||||
__m256i l_reg =
|
||||
_mm256_loadu_si256(reinterpret_cast<__m256i*>(left + i));
|
||||
__m256i r_reg =
|
||||
_mm256_loadu_si256(reinterpret_cast<__m256i*>(right + i));
|
||||
__m256i res = _mm256_or_si256(l_reg, r_reg);
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(left + i), res);
|
||||
}
|
||||
for (size_t i = num_chunk * 32; i < size; ++i) {
|
||||
left[i] |= right[i];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace simd
|
||||
} // namespace milvus
|
||||
|
||||
|
|
|
@ -58,5 +58,17 @@ template <>
|
|||
bool
|
||||
FindTermAVX2(const double* src, size_t vec_size, double val);
|
||||
|
||||
bool
|
||||
AllFalseAVX2(const bool* src, int64_t size);
|
||||
|
||||
bool
|
||||
AllTrueAVX2(const bool* src, int64_t size);
|
||||
|
||||
void
|
||||
AndBoolAVX2(bool* left, bool* right, int64_t size);
|
||||
|
||||
void
|
||||
OrBoolAVX2(bool* left, bool* right, int64_t size);
|
||||
|
||||
} // namespace simd
|
||||
} // namespace milvus
|
||||
|
|
|
@ -183,6 +183,39 @@ FindTermAVX512(const double* src, size_t vec_size, double val) {
|
|||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void
|
||||
AndBoolAVX512(bool* left, bool* right, int64_t size) {
|
||||
int num_chunk = size / 64;
|
||||
for (size_t i = 0; i < num_chunk * 64; i += 64) {
|
||||
__m512i l_reg =
|
||||
_mm512_loadu_si512(reinterpret_cast<__m512i*>(left + i));
|
||||
__m512i r_reg =
|
||||
_mm512_loadu_si512(reinterpret_cast<__m512i*>(right + i));
|
||||
__m512i res = _mm512_and_si512(l_reg, r_reg);
|
||||
_mm512_storeu_si512(reinterpret_cast<__m512i*>(left + i), res);
|
||||
}
|
||||
for (size_t i = num_chunk * 64; i < size; ++i) {
|
||||
left[i] &= right[i];
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
OrBoolAVX512(bool* left, bool* right, int64_t size) {
|
||||
int num_chunk = size / 64;
|
||||
for (size_t i = 0; i < num_chunk * 64; i += 64) {
|
||||
__m512i l_reg =
|
||||
_mm512_loadu_si512(reinterpret_cast<__m512i*>(left + i));
|
||||
__m512i r_reg =
|
||||
_mm512_loadu_si512(reinterpret_cast<__m512i*>(right + i));
|
||||
__m512i res = _mm512_or_si512(l_reg, r_reg);
|
||||
_mm512_storeu_si512(reinterpret_cast<__m512i*>(left + i), res);
|
||||
}
|
||||
for (size_t i = num_chunk * 64; i < size; ++i) {
|
||||
left[i] |= right[i];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace simd
|
||||
} // namespace milvus
|
||||
#endif
|
||||
|
|
|
@ -55,5 +55,11 @@ template <>
|
|||
bool
|
||||
FindTermAVX512(const double* src, size_t vec_size, double val);
|
||||
|
||||
void
|
||||
AndBoolAVX512(bool* left, bool* right, int64_t size);
|
||||
|
||||
void
|
||||
OrBoolAVX512(bool* left, bool* right, int64_t size);
|
||||
|
||||
} // namespace simd
|
||||
} // namespace milvus
|
||||
|
|
|
@ -25,6 +25,8 @@
|
|||
#include "sse2.h"
|
||||
#include "sse4.h"
|
||||
#include "instruction_set.h"
|
||||
#elif defined(__ARM_NEON)
|
||||
#include "neon.h"
|
||||
#endif
|
||||
|
||||
namespace milvus {
|
||||
|
@ -44,6 +46,12 @@ bool use_find_term_avx512;
|
|||
#endif
|
||||
|
||||
decltype(get_bitset_block) get_bitset_block = GetBitsetBlockRef;
|
||||
decltype(all_false) all_false = AllFalseRef;
|
||||
decltype(all_true) all_true = AllTrueRef;
|
||||
decltype(invert_bool) invert_bool = InvertBoolRef;
|
||||
decltype(and_bool) and_bool = AndBoolRef;
|
||||
decltype(or_bool) or_bool = OrBoolRef;
|
||||
|
||||
FindTermPtr<bool> find_term_bool = FindTermRef<bool>;
|
||||
FindTermPtr<int8_t> find_term_int8 = FindTermRef<int8_t>;
|
||||
FindTermPtr<int16_t> find_term_int16 = FindTermRef<int16_t>;
|
||||
|
@ -161,9 +169,82 @@ find_term_hook() {
|
|||
LOG_SEGCORE_INFO_ << "find term hook simd type: " << simd_type;
|
||||
}
|
||||
|
||||
void
|
||||
all_boolean_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
if (use_sse2 && cpu_support_sse2()) {
|
||||
simd_type = "SSE2";
|
||||
all_false = AllFalseSSE2;
|
||||
all_true = AllTrueSSE2;
|
||||
}
|
||||
#elif defined(__ARM_NEON)
|
||||
simd_type = "NEON";
|
||||
all_false = AllFalseNEON;
|
||||
all_true = AllTrueNEON;
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_SEGCORE_INFO_ << "AllFalse/AllTrue hook simd type: " << simd_type;
|
||||
}
|
||||
|
||||
void
|
||||
invert_boolean_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
if (use_sse2 && cpu_support_sse2()) {
|
||||
simd_type = "SSE2";
|
||||
invert_bool = InvertBoolSSE2;
|
||||
}
|
||||
#elif defined(__ARM_NEON)
|
||||
simd_type = "NEON";
|
||||
invert_bool = InvertBoolNEON;
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_SEGCORE_INFO_ << "InvertBoolean hook simd type: " << simd_type;
|
||||
}
|
||||
|
||||
void
|
||||
logical_boolean_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
if (use_avx512 && cpu_support_avx512()) {
|
||||
simd_type = "AVX512";
|
||||
and_bool = AndBoolAVX512;
|
||||
or_bool = OrBoolAVX512;
|
||||
} else if (use_avx2 && cpu_support_avx2()) {
|
||||
simd_type = "AVX2";
|
||||
and_bool = AndBoolAVX2;
|
||||
or_bool = OrBoolAVX2;
|
||||
} else if (use_sse2 && cpu_support_sse2()) {
|
||||
simd_type = "SSE2";
|
||||
and_bool = AndBoolSSE2;
|
||||
or_bool = OrBoolSSE2;
|
||||
}
|
||||
#elif defined(__ARM_NEON)
|
||||
simd_type = "NEON";
|
||||
and_bool = AndBoolNEON;
|
||||
or_bool = OrBoolNEON;
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_SEGCORE_INFO_ << "InvertBoolean hook simd type: " << simd_type;
|
||||
}
|
||||
void
|
||||
boolean_hook() {
|
||||
all_boolean_hook();
|
||||
invert_boolean_hook();
|
||||
logical_boolean_hook();
|
||||
}
|
||||
|
||||
static int init_hook_ = []() {
|
||||
bitset_hook();
|
||||
find_term_hook();
|
||||
boolean_hook();
|
||||
return 0;
|
||||
}();
|
||||
|
||||
|
|
|
@ -19,6 +19,11 @@ namespace milvus {
|
|||
namespace simd {
|
||||
|
||||
extern BitsetBlockType (*get_bitset_block)(const bool* src);
|
||||
extern bool (*all_false)(const bool* src, int64_t size);
|
||||
extern bool (*all_true)(const bool* src, int64_t size);
|
||||
extern void (*invert_bool)(bool* src, int64_t size);
|
||||
extern void (*and_bool)(bool* left, bool* right, int64_t size);
|
||||
extern void (*or_bool)(bool* left, bool* right, int64_t size);
|
||||
|
||||
template <typename T>
|
||||
using FindTermPtr = bool (*)(const T* src, size_t size, T val);
|
||||
|
@ -63,6 +68,18 @@ bitset_hook();
|
|||
void
|
||||
find_term_hook();
|
||||
|
||||
void
|
||||
boolean_hook();
|
||||
|
||||
void
|
||||
all_boolean_hook();
|
||||
|
||||
void
|
||||
invert_boolean_hook();
|
||||
|
||||
void
|
||||
logical_boolean_hook();
|
||||
|
||||
template <typename T>
|
||||
bool
|
||||
find_term_func(const T* data, size_t size, T val) {
|
||||
|
|
|
@ -0,0 +1,123 @@
|
|||
// Copyright (C) 2019-2023 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#if defined(__ARM_NEON)
|
||||
|
||||
#include "neon.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <arm_neon.h>
|
||||
|
||||
namespace milvus {
|
||||
namespace simd {
|
||||
|
||||
bool
|
||||
AllFalseNEON(const bool* src, int64_t size) {
|
||||
int num_chunk = size / 16;
|
||||
|
||||
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(src);
|
||||
for (size_t i = 0; i < num_chunk * 16; i += 16) {
|
||||
uint8x16_t data = vld1q_u8(ptr + i);
|
||||
if (vmaxvq_u8(data) != 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = num_chunk * 16; i < size; ++i) {
|
||||
if (src[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
AllTrueNEON(const bool* src, int64_t size) {
|
||||
int num_chunk = size / 16;
|
||||
|
||||
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(src);
|
||||
for (size_t i = 0; i < num_chunk * 16; i += 16) {
|
||||
uint8x16_t data = vld1q_u8(ptr + i);
|
||||
if (vminvq_u8(data) == 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = num_chunk * 16; i < size; ++i) {
|
||||
if (!src[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void
|
||||
InvertBoolNEON(bool* src, int64_t size) {
|
||||
int num_chunk = size / 16;
|
||||
uint8x16_t mask = vdupq_n_u8(0x01);
|
||||
uint8_t* ptr = reinterpret_cast<uint8_t*>(src);
|
||||
for (size_t i = 0; i < num_chunk * 16; i += 16) {
|
||||
uint8x16_t data = vld1q_u8(ptr + i);
|
||||
|
||||
uint8x16_t flipped = veorq_u8(data, mask);
|
||||
|
||||
vst1q_u8(ptr + i, flipped);
|
||||
}
|
||||
|
||||
for (size_t i = num_chunk * 16; i < size; ++i) {
|
||||
src[i] = !src[i];
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
AndBoolNEON(bool* left, bool* right, int64_t size) {
|
||||
int num_chunk = size / 16;
|
||||
uint8_t* lptr = reinterpret_cast<uint8_t*>(left);
|
||||
uint8_t* rptr = reinterpret_cast<uint8_t*>(right);
|
||||
for (size_t i = 0; i < num_chunk * 16; i += 16) {
|
||||
uint8x16_t l_reg = vld1q_u8(lptr + i);
|
||||
uint8x16_t r_reg = vld1q_u8(rptr + i);
|
||||
|
||||
uint8x16_t res = vandq_u8(l_reg, r_reg);
|
||||
|
||||
vst1q_u8(lptr + i, res);
|
||||
}
|
||||
|
||||
for (size_t i = num_chunk * 16; i < size; ++i) {
|
||||
left[i] &= right[i];
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
OrBoolNEON(bool* left, bool* right, int64_t size) {
|
||||
int num_chunk = size / 16;
|
||||
uint8_t* lptr = reinterpret_cast<uint8_t*>(left);
|
||||
uint8_t* rptr = reinterpret_cast<uint8_t*>(right);
|
||||
for (size_t i = 0; i < num_chunk * 16; i += 16) {
|
||||
uint8x16_t l_reg = vld1q_u8(lptr + i);
|
||||
uint8x16_t r_reg = vld1q_u8(rptr + i);
|
||||
|
||||
uint8x16_t res = vorrq_u8(l_reg, r_reg);
|
||||
|
||||
vst1q_u8(lptr + i, res);
|
||||
}
|
||||
|
||||
for (size_t i = num_chunk * 16; i < size; ++i) {
|
||||
left[i] |= right[i];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace simd
|
||||
} // namespace milvus
|
||||
|
||||
#endif
|
|
@ -0,0 +1,37 @@
|
|||
// Copyright (C) 2019-2023 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include <cstdint>
|
||||
#include "common.h"
|
||||
namespace milvus {
|
||||
namespace simd {
|
||||
|
||||
BitsetBlockType
|
||||
GetBitsetBlockSSE2(const bool* src);
|
||||
|
||||
bool
|
||||
AllFalseNEON(const bool* src, int64_t size);
|
||||
|
||||
bool
|
||||
AllTrueNEON(const bool* src, int64_t size);
|
||||
|
||||
void
|
||||
InvertBoolNEON(bool* src, int64_t size);
|
||||
|
||||
void
|
||||
AndBoolNEON(bool* left, bool* right, int64_t size);
|
||||
|
||||
void
|
||||
OrBoolNEON(bool* left, bool* right, int64_t size);
|
||||
|
||||
} // namespace simd
|
||||
} // namespace milvus
|
|
@ -29,5 +29,46 @@ GetBitsetBlockRef(const bool* src) {
|
|||
return val;
|
||||
}
|
||||
|
||||
bool
|
||||
AllTrueRef(const bool* src, int64_t size) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
if (!src[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
AllFalseRef(const bool* src, int64_t size) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
if (src[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void
|
||||
InvertBoolRef(bool* src, int64_t size) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
src[i] = !src[i];
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
AndBoolRef(bool* left, bool* right, int64_t size) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
left[i] &= right[i];
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
OrBoolRef(bool* left, bool* right, int64_t size) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
left[i] |= right[i];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace simd
|
||||
} // namespace milvus
|
||||
|
|
|
@ -19,6 +19,21 @@ namespace simd {
|
|||
BitsetBlockType
|
||||
GetBitsetBlockRef(const bool* src);
|
||||
|
||||
bool
|
||||
AllTrueRef(const bool* src, int64_t size);
|
||||
|
||||
bool
|
||||
AllFalseRef(const bool* src, int64_t size);
|
||||
|
||||
void
|
||||
InvertBoolRef(bool* src, int64_t size);
|
||||
|
||||
void
|
||||
AndBoolRef(bool* left, bool* right, int64_t size);
|
||||
|
||||
void
|
||||
OrBoolRef(bool* left, bool* right, int64_t size);
|
||||
|
||||
template <typename T>
|
||||
bool
|
||||
FindTermRef(const T* src, size_t size, T val) {
|
||||
|
|
|
@ -256,6 +256,102 @@ FindTermSSE2(const double* src, size_t vec_size, double val) {
|
|||
return false;
|
||||
}
|
||||
|
||||
void
|
||||
print_m128i(__m128i v) {
|
||||
alignas(16) int result[4];
|
||||
_mm_store_si128(reinterpret_cast<__m128i*>(result), v);
|
||||
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
std::cout << std::hex << result[i] << " ";
|
||||
}
|
||||
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
bool
|
||||
AllFalseSSE2(const bool* src, int64_t size) {
|
||||
int num_chunk = size / 16;
|
||||
__m128i highbit = _mm_set1_epi8(0x7F);
|
||||
for (size_t i = 0; i < num_chunk * 16; i += 16) {
|
||||
__m128i data =
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
|
||||
__m128i highbits = _mm_add_epi8(data, highbit);
|
||||
if (_mm_movemask_epi8(highbits) != 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = num_chunk * 16; i < size; ++i) {
|
||||
if (src[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
AllTrueSSE2(const bool* src, int64_t size) {
|
||||
int num_chunk = size / 16;
|
||||
__m128i highbit = _mm_set1_epi8(0x7F);
|
||||
for (size_t i = 0; i < num_chunk * 16; i += 16) {
|
||||
__m128i data =
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
|
||||
__m128i highbits = _mm_add_epi8(data, highbit);
|
||||
if (_mm_movemask_epi8(highbits) != 0xFFFF) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = num_chunk * 16; i < size; ++i) {
|
||||
if (!src[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void
|
||||
InvertBoolSSE2(bool* src, int64_t size) {
|
||||
int num_chunk = size / 16;
|
||||
__m128i mask = _mm_set1_epi8(0x01);
|
||||
for (size_t i = 0; i < num_chunk * 16; i += 16) {
|
||||
__m128i data = _mm_loadu_si128(reinterpret_cast<__m128i*>(src + i));
|
||||
__m128i flipped = _mm_xor_si128(data, mask);
|
||||
_mm_storeu_si128(reinterpret_cast<__m128i*>(src + i), flipped);
|
||||
}
|
||||
for (size_t i = num_chunk * 16; i < size; ++i) {
|
||||
src[i] = !src[i];
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
AndBoolSSE2(bool* left, bool* right, int64_t size) {
|
||||
int num_chunk = size / 16;
|
||||
for (size_t i = 0; i < num_chunk * 16; i += 16) {
|
||||
__m128i l_reg = _mm_loadu_si128(reinterpret_cast<__m128i*>(left + i));
|
||||
__m128i r_reg = _mm_loadu_si128(reinterpret_cast<__m128i*>(right + i));
|
||||
__m128i res = _mm_and_si128(l_reg, r_reg);
|
||||
_mm_storeu_si128(reinterpret_cast<__m128i*>(left + i), res);
|
||||
}
|
||||
for (size_t i = num_chunk * 16; i < size; ++i) {
|
||||
left[i] &= right[i];
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
OrBoolSSE2(bool* left, bool* right, int64_t size) {
|
||||
int num_chunk = size / 16;
|
||||
for (size_t i = 0; i < num_chunk * 16; i += 16) {
|
||||
__m128i l_reg = _mm_loadu_si128(reinterpret_cast<__m128i*>(left + i));
|
||||
__m128i r_reg = _mm_loadu_si128(reinterpret_cast<__m128i*>(right + i));
|
||||
__m128i res = _mm_or_si128(l_reg, r_reg);
|
||||
_mm_storeu_si128(reinterpret_cast<__m128i*>(left + i), res);
|
||||
}
|
||||
for (size_t i = num_chunk * 16; i < size; ++i) {
|
||||
left[i] |= right[i];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace simd
|
||||
} // namespace milvus
|
||||
|
||||
|
|
|
@ -24,6 +24,21 @@ namespace simd {
|
|||
BitsetBlockType
|
||||
GetBitsetBlockSSE2(const bool* src);
|
||||
|
||||
bool
|
||||
AllFalseSSE2(const bool* src, int64_t size);
|
||||
|
||||
bool
|
||||
AllTrueSSE2(const bool* src, int64_t size);
|
||||
|
||||
void
|
||||
InvertBoolSSE2(bool* src, int64_t size);
|
||||
|
||||
void
|
||||
AndBoolSSE2(bool* left, bool* right, int64_t size);
|
||||
|
||||
void
|
||||
OrBoolSSE2(bool* left, bool* right, int64_t size);
|
||||
|
||||
template <typename T>
|
||||
bool
|
||||
FindTermSSE2(const T* src, size_t vec_size, T va) {
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue