mirror of https://github.com/milvus-io/milvus.git
Remove outdated searchplan (#25282)
Signed-off-by: Enwei Jiao <enwei.jiao@zilliz.com>
pull/25337/head
parent
4a87b9f60a
commit
816158e4af
|
@ -21,7 +21,6 @@ set(MILVUS_QUERY_SRCS
|
|||
visitors/VerifyExprVisitor.cpp
|
||||
visitors/ExtractInfoPlanNodeVisitor.cpp
|
||||
visitors/ExtractInfoExprVisitor.cpp
|
||||
Parser.cpp
|
||||
Plan.cpp
|
||||
SearchOnGrowing.cpp
|
||||
SearchOnSealed.cpp
|
||||
|
|
|
@ -1,514 +0,0 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
#include <boost/algorithm/string.hpp>
|
||||
|
||||
#include "ExprImpl.h"
|
||||
#include "Parser.h"
|
||||
#include "Plan.h"
|
||||
#include "generated/ExtractInfoPlanNodeVisitor.h"
|
||||
#include "generated/VerifyPlanNodeVisitor.h"
|
||||
#include "pb/plan.pb.h"
|
||||
#include "query/Expr.h"
|
||||
|
||||
namespace milvus::query {
|
||||
|
||||
template <typename Merger>
|
||||
static ExprPtr
|
||||
ConstructTree(Merger merger, std::vector<ExprPtr> item_list) {
|
||||
if (item_list.size() == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (item_list.size() == 1) {
|
||||
return std::move(item_list[0]);
|
||||
}
|
||||
|
||||
// Note: use deque to construct a binary tree
|
||||
// Op
|
||||
// / \
|
||||
// Op Op
|
||||
// | \ | \
|
||||
// A B C D
|
||||
std::deque<ExprPtr> binary_queue;
|
||||
for (auto& item : item_list) {
|
||||
Assert(item != nullptr);
|
||||
binary_queue.push_back(std::move(item));
|
||||
}
|
||||
while (binary_queue.size() > 1) {
|
||||
auto left = std::move(binary_queue.front());
|
||||
binary_queue.pop_front();
|
||||
auto right = std::move(binary_queue.front());
|
||||
binary_queue.pop_front();
|
||||
binary_queue.push_back(merger(std::move(left), std::move(right)));
|
||||
}
|
||||
Assert(binary_queue.size() == 1);
|
||||
return std::move(binary_queue.front());
|
||||
}
|
||||
|
||||
// Parses the value of a "compare" entry: a single-key object whose key is the
// comparison operator (case-insensitive) and whose value is a two-element
// array of field names, `[left_field, right_field]`.
ExprPtr
Parser::ParseCompareNode(const Json& out_body) {
    Assert(out_body.is_object());
    Assert(out_body.size() == 1);

    const auto entry = out_body.begin();
    const auto op_name =
        boost::algorithm::to_lower_copy(std::string(entry.key()));
    AssertInfo(mapping_.count(op_name), "op(" + op_name + ") not found");

    const auto& operands = entry.value();
    Assert(operands.is_array());
    Assert(operands.size() == 2);

    auto compare = std::make_unique<CompareExpr>();
    compare->op_type_ = mapping_.at(op_name);

    // Left operand.
    const auto& lhs = operands[0];
    Assert(lhs.is_string());
    const auto lhs_name = FieldName(lhs.get<std::string>());
    compare->left_data_type_ = schema[lhs_name].get_data_type();
    compare->left_field_id_ = schema.get_field_id(lhs_name);

    // Right operand.
    const auto& rhs = operands[1];
    Assert(rhs.is_string());
    const auto rhs_name = FieldName(rhs.get<std::string>());
    compare->right_data_type_ = schema[rhs_name].get_data_type();
    compare->right_field_id_ = schema.get_field_id(rhs_name);

    return compare;
}
|
||||
|
||||
// Parses the value of a "range" entry: a single-key object mapping a scalar
// field name to its range description. Dispatches to the typed implementation
// that matches the field's data type; vector fields are rejected.
ExprPtr
Parser::ParseRangeNode(const Json& out_body) {
    Assert(out_body.is_object());
    Assert(out_body.size() == 1);

    const auto entry = out_body.begin();
    const auto field_name = FieldName(entry.key());
    const auto& range_spec = entry.value();

    const auto data_type = schema[field_name].get_data_type();
    Assert(!datatype_is_vector(data_type));

    switch (data_type) {
        case DataType::BOOL: {
            return ParseRangeNodeImpl<bool>(field_name, range_spec);
        }
        case DataType::INT8: {
            return ParseRangeNodeImpl<int8_t>(field_name, range_spec);
        }
        case DataType::INT16: {
            return ParseRangeNodeImpl<int16_t>(field_name, range_spec);
        }
        case DataType::INT32: {
            return ParseRangeNodeImpl<int32_t>(field_name, range_spec);
        }
        case DataType::INT64: {
            return ParseRangeNodeImpl<int64_t>(field_name, range_spec);
        }
        case DataType::FLOAT: {
            return ParseRangeNodeImpl<float>(field_name, range_spec);
        }
        case DataType::DOUBLE: {
            return ParseRangeNodeImpl<double>(field_name, range_spec);
        }
        default: {
            PanicInfo("unsupported");
        }
    }
}
|
||||
|
||||
std::unique_ptr<Plan>
|
||||
Parser::CreatePlanImpl(const Json& dsl) {
|
||||
auto bool_dsl = dsl.at("bool");
|
||||
auto predicate = ParseAnyNode(bool_dsl);
|
||||
Assert(vector_node_opt_.has_value());
|
||||
auto vec_node = std::move(vector_node_opt_).value();
|
||||
if (predicate != nullptr) {
|
||||
vec_node->predicate_ = std::move(predicate);
|
||||
}
|
||||
VerifyPlanNodeVisitor verifier;
|
||||
vec_node->accept(verifier);
|
||||
|
||||
ExtractedPlanInfo plan_info(schema.size());
|
||||
ExtractInfoPlanNodeVisitor extractor(plan_info);
|
||||
vec_node->accept(extractor);
|
||||
|
||||
auto plan = std::make_unique<Plan>(schema);
|
||||
plan->tag2field_ = std::move(tag2field_);
|
||||
plan->plan_node_ = std::move(vec_node);
|
||||
plan->extra_info_opt_ = std::move(plan_info);
|
||||
return plan;
|
||||
}
|
||||
|
||||
// Parses the value of a "term" entry: a single-key object mapping a scalar
// field name to its term description. Dispatches to the typed implementation
// that matches the field's data type; vector fields are rejected.
ExprPtr
Parser::ParseTermNode(const Json& out_body) {
    // Same guard as the other node parsers: key() below is only meaningful
    // for an object, and size() == 1 alone also holds for scalars.
    Assert(out_body.is_object());
    Assert(out_body.size() == 1);
    auto out_iter = out_body.begin();
    auto field_name = FieldName(out_iter.key());
    // Bind by reference; the typed implementation only reads the body.
    const auto& body = out_iter.value();
    auto data_type = schema[field_name].get_data_type();
    Assert(!datatype_is_vector(data_type));
    switch (data_type) {
        case DataType::BOOL: {
            return ParseTermNodeImpl<bool>(field_name, body);
        }
        case DataType::INT8: {
            return ParseTermNodeImpl<int8_t>(field_name, body);
        }
        case DataType::INT16: {
            return ParseTermNodeImpl<int16_t>(field_name, body);
        }
        case DataType::INT32: {
            return ParseTermNodeImpl<int32_t>(field_name, body);
        }
        case DataType::INT64: {
            return ParseTermNodeImpl<int64_t>(field_name, body);
        }
        case DataType::FLOAT: {
            return ParseTermNodeImpl<float>(field_name, body);
        }
        case DataType::DOUBLE: {
            return ParseTermNodeImpl<double>(field_name, body);
        }
        default: {
            PanicInfo("unsupported data_type");
        }
    }
}
|
||||
|
||||
std::unique_ptr<VectorPlanNode>
|
||||
Parser::ParseVecNode(const Json& out_body) {
|
||||
Assert(out_body.is_object());
|
||||
Assert(out_body.size() == 1);
|
||||
auto iter = out_body.begin();
|
||||
auto field_name = FieldName(iter.key());
|
||||
|
||||
auto& vec_info = iter.value();
|
||||
Assert(vec_info.is_object());
|
||||
auto topk = vec_info["topk"];
|
||||
AssertInfo(topk > 0, "topk must greater than 0");
|
||||
AssertInfo(topk < 16384, "topk is too large");
|
||||
|
||||
auto field_id = schema.get_field_id(field_name);
|
||||
|
||||
auto vec_node = [&]() -> std::unique_ptr<VectorPlanNode> {
|
||||
auto& field_meta = schema.operator[](field_name);
|
||||
auto data_type = field_meta.get_data_type();
|
||||
if (data_type == DataType::VECTOR_FLOAT) {
|
||||
return std::make_unique<FloatVectorANNS>();
|
||||
} else {
|
||||
return std::make_unique<BinaryVectorANNS>();
|
||||
}
|
||||
}();
|
||||
vec_node->search_info_.topk_ = topk;
|
||||
vec_node->search_info_.metric_type_ = vec_info.at("metric_type");
|
||||
vec_node->search_info_.search_params_ = vec_info.at("params");
|
||||
vec_node->search_info_.field_id_ = field_id;
|
||||
vec_node->search_info_.round_decimal_ = vec_info.at("round_decimal");
|
||||
vec_node->placeholder_tag_ = vec_info.at("query");
|
||||
auto tag = vec_node->placeholder_tag_;
|
||||
AssertInfo(!tag2field_.count(tag), "duplicated placeholder tag");
|
||||
tag2field_.emplace(tag, field_id);
|
||||
return vec_node;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ExprPtr
|
||||
Parser::ParseTermNodeImpl(const FieldName& field_name, const Json& body) {
|
||||
Assert(body.is_object());
|
||||
auto values = body["values"];
|
||||
auto is_in_field = body["is_in_field"];
|
||||
|
||||
std::vector<T> terms(values.size());
|
||||
auto val_case = proto::plan::GenericValue::ValCase::VAL_NOT_SET;
|
||||
for (int i = 0; i < values.size(); i++) {
|
||||
auto value = values[i];
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
Assert(value.is_boolean());
|
||||
val_case = proto::plan::GenericValue::ValCase::kBoolVal;
|
||||
} else if constexpr (std::is_integral_v<T>) {
|
||||
Assert(value.is_number_integer());
|
||||
val_case = proto::plan::GenericValue::ValCase::kInt64Val;
|
||||
} else if constexpr (std::is_floating_point_v<T>) {
|
||||
Assert(value.is_number());
|
||||
val_case = proto::plan::GenericValue::ValCase::kFloatVal;
|
||||
} else {
|
||||
static_assert(always_false<T>, "unsupported type");
|
||||
}
|
||||
terms[i] = value;
|
||||
}
|
||||
std::sort(terms.begin(), terms.end());
|
||||
return std::make_unique<TermExprImpl<T>>(
|
||||
ColumnInfo(schema.get_field_id(field_name),
|
||||
schema[field_name].get_data_type()),
|
||||
terms,
|
||||
val_case,
|
||||
is_in_field);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ExprPtr
|
||||
Parser::ParseRangeNodeImpl(const FieldName& field_name, const Json& body) {
|
||||
Assert(body.is_object());
|
||||
if (body.size() == 1) {
|
||||
auto item = body.begin();
|
||||
auto op_name = boost::algorithm::to_lower_copy(std::string(item.key()));
|
||||
AssertInfo(mapping_.count(op_name), "op(" + op_name + ") not found");
|
||||
|
||||
// This is an expression with an arithmetic operation
|
||||
if (item.value().is_object()) {
|
||||
/* // This is the expected DSL expression
|
||||
{
|
||||
range: {
|
||||
field_name: {
|
||||
op: {
|
||||
arith_op: {
|
||||
right_operand: operand,
|
||||
value: value
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
EXAMPLE:
|
||||
{
|
||||
range: {
|
||||
field_name: {
|
||||
"EQ": {
|
||||
"ADD": {
|
||||
right_operand: 10,
|
||||
value: 25
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
auto arith = item.value();
|
||||
auto arith_body = arith.begin();
|
||||
|
||||
auto arith_op_name =
|
||||
boost::algorithm::to_lower_copy(std::string(arith_body.key()));
|
||||
AssertInfo(arith_op_mapping_.count(arith_op_name),
|
||||
"arith op(" + arith_op_name + ") not found");
|
||||
|
||||
auto& arith_op_body = arith_body.value();
|
||||
Assert(arith_op_body.is_object());
|
||||
|
||||
auto right_operand = arith_op_body["right_operand"];
|
||||
auto value = arith_op_body["value"];
|
||||
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
throw std::runtime_error("bool type is not supported");
|
||||
} else if constexpr (std::is_integral_v<T>) {
|
||||
Assert(right_operand.is_number_integer());
|
||||
Assert(value.is_number_integer());
|
||||
// see also: https://github.com/milvus-io/milvus/issues/23646.
|
||||
return std::make_unique<
|
||||
BinaryArithOpEvalRangeExprImpl<int64_t>>(
|
||||
ColumnInfo(schema.get_field_id(field_name),
|
||||
schema[field_name].get_data_type()),
|
||||
proto::plan::GenericValue::ValCase::kInt64Val,
|
||||
arith_op_mapping_.at(arith_op_name),
|
||||
right_operand,
|
||||
mapping_.at(op_name),
|
||||
value);
|
||||
} else if constexpr (std::is_floating_point_v<T>) {
|
||||
Assert(right_operand.is_number());
|
||||
Assert(value.is_number());
|
||||
return std::make_unique<BinaryArithOpEvalRangeExprImpl<T>>(
|
||||
ColumnInfo(schema.get_field_id(field_name),
|
||||
schema[field_name].get_data_type()),
|
||||
proto::plan::GenericValue::ValCase::kFloatVal,
|
||||
arith_op_mapping_.at(arith_op_name),
|
||||
right_operand,
|
||||
mapping_.at(op_name),
|
||||
value);
|
||||
} else {
|
||||
static_assert(always_false<T>, "unsupported type");
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
Assert(item.value().is_boolean());
|
||||
return std::make_unique<UnaryRangeExprImpl<T>>(
|
||||
ColumnInfo(schema.get_field_id(field_name),
|
||||
schema[field_name].get_data_type()),
|
||||
mapping_.at(op_name),
|
||||
item.value(),
|
||||
proto::plan::GenericValue::ValCase::kBoolVal);
|
||||
} else if constexpr (std::is_integral_v<T>) {
|
||||
Assert(item.value().is_number_integer());
|
||||
// see also: https://github.com/milvus-io/milvus/issues/23646.
|
||||
return std::make_unique<UnaryRangeExprImpl<int64_t>>(
|
||||
ColumnInfo(schema.get_field_id(field_name),
|
||||
schema[field_name].get_data_type()),
|
||||
mapping_.at(op_name),
|
||||
item.value(),
|
||||
proto::plan::GenericValue::ValCase::kInt64Val);
|
||||
} else if constexpr (std::is_floating_point_v<T>) {
|
||||
Assert(item.value().is_number());
|
||||
return std::make_unique<UnaryRangeExprImpl<T>>(
|
||||
ColumnInfo(schema.get_field_id(field_name),
|
||||
schema[field_name].get_data_type()),
|
||||
mapping_.at(op_name),
|
||||
item.value(),
|
||||
proto::plan::GenericValue::ValCase::kFloatVal);
|
||||
} else {
|
||||
static_assert(always_false<T>, "unsupported type");
|
||||
}
|
||||
} else if (body.size() == 2) {
|
||||
bool has_lower_value = false;
|
||||
bool has_upper_value = false;
|
||||
bool lower_inclusive = false;
|
||||
bool upper_inclusive = false;
|
||||
T lower_value;
|
||||
T upper_value;
|
||||
for (auto& item : body.items()) {
|
||||
auto op_name =
|
||||
boost::algorithm::to_lower_copy(std::string(item.key()));
|
||||
AssertInfo(mapping_.count(op_name),
|
||||
"op(" + op_name + ") not found");
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
Assert(item.value().is_boolean());
|
||||
} else if constexpr (std::is_integral_v<T>) {
|
||||
Assert(item.value().is_number_integer());
|
||||
} else if constexpr (std::is_floating_point_v<T>) {
|
||||
Assert(item.value().is_number());
|
||||
} else {
|
||||
static_assert(always_false<T>, "unsupported type");
|
||||
}
|
||||
auto op = mapping_.at(op_name);
|
||||
switch (op) {
|
||||
case OpType::GreaterEqual:
|
||||
lower_inclusive = true;
|
||||
case OpType::GreaterThan:
|
||||
lower_value = item.value();
|
||||
has_lower_value = true;
|
||||
break;
|
||||
case OpType::LessEqual:
|
||||
upper_inclusive = true;
|
||||
case OpType::LessThan:
|
||||
upper_value = item.value();
|
||||
has_upper_value = true;
|
||||
break;
|
||||
default:
|
||||
PanicInfo("unsupported operator in binary-range node");
|
||||
}
|
||||
}
|
||||
AssertInfo(has_lower_value && has_upper_value,
|
||||
"illegal binary-range node");
|
||||
return std::make_unique<BinaryRangeExprImpl<T>>(
|
||||
ColumnInfo(schema.get_field_id(field_name),
|
||||
schema[field_name].get_data_type()),
|
||||
proto::plan::GenericValue::ValCase::VAL_NOT_SET,
|
||||
lower_inclusive,
|
||||
upper_inclusive,
|
||||
lower_value,
|
||||
upper_value);
|
||||
} else {
|
||||
PanicInfo("illegal range node, too more or too few ops");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ExprPtr>
|
||||
Parser::ParseItemList(const Json& body) {
|
||||
std::vector<ExprPtr> results;
|
||||
if (body.is_object()) {
|
||||
// only one item;
|
||||
auto new_expr = ParseAnyNode(body);
|
||||
results.emplace_back(std::move(new_expr));
|
||||
} else {
|
||||
// item array
|
||||
Assert(body.is_array());
|
||||
for (auto& item : body) {
|
||||
auto new_expr = ParseAnyNode(item);
|
||||
results.emplace_back(std::move(new_expr));
|
||||
}
|
||||
}
|
||||
auto old_size = results.size();
|
||||
|
||||
auto new_end =
|
||||
std::remove_if(results.begin(), results.end(), [](const ExprPtr& x) {
|
||||
return x == nullptr;
|
||||
});
|
||||
|
||||
results.resize(new_end - results.begin());
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
// Dispatcher for every node kind. `out_body` must be a single-key object; the
// key selects the parser applied to the value.
//
// A "vector" node yields no expression: it is stored in vector_node_opt_
// (at most one is allowed) and nullptr is returned. An unknown key panics.
ExprPtr
Parser::ParseAnyNode(const Json& out_body) {
    Assert(out_body.is_object());
    Assert(out_body.size() == 1);

    const auto entry = out_body.begin();
    const auto& key = entry.key();
    const auto& body = entry.value();

    if (key == "must") {
        return ParseMustNode(body);
    }
    if (key == "should") {
        return ParseShouldNode(body);
    }
    if (key == "must_not") {
        return ParseMustNotNode(body);
    }
    if (key == "range") {
        return ParseRangeNode(body);
    }
    if (key == "term") {
        return ParseTermNode(body);
    }
    if (key == "compare") {
        return ParseCompareNode(body);
    }
    if (key == "vector") {
        // Parse first, then check that no vector node was recorded before.
        auto vec_node = ParseVecNode(body);
        Assert(!vector_node_opt_.has_value());
        vector_node_opt_ = std::move(vec_node);
        return nullptr;
    }
    PanicInfo("unsupported key: " + key);
}
|
||||
|
||||
// Parses a "must" entry: all items are combined with logical AND.
// Yields nullptr when the item list contains no expression.
ExprPtr
Parser::ParseMustNode(const Json& body) {
    auto conjuncts = ParseItemList(body);
    const auto conjoin = [](ExprPtr lhs, ExprPtr rhs) {
        return std::make_unique<LogicalBinaryExpr>(
            LogicalBinaryExpr::OpType::LogicalAnd, lhs, rhs);
    };
    return ConstructTree(conjoin, std::move(conjuncts));
}
|
||||
|
||||
// Parses a "should" entry: all items are combined with logical OR.
// At least one expression is required.
ExprPtr
Parser::ParseShouldNode(const Json& body) {
    auto alternatives = ParseItemList(body);
    Assert(alternatives.size() >= 1);
    const auto disjoin = [](ExprPtr lhs, ExprPtr rhs) {
        return std::make_unique<LogicalBinaryExpr>(
            LogicalBinaryExpr::OpType::LogicalOr, lhs, rhs);
    };
    return ConstructTree(disjoin, std::move(alternatives));
}
|
||||
|
||||
// Parses a "must_not" entry: the items are combined with logical AND and the
// result is wrapped in a logical NOT. At least one expression is required.
ExprPtr
Parser::ParseMustNotNode(const Json& body) {
    auto excluded = ParseItemList(body);
    Assert(excluded.size() >= 1);
    const auto conjoin = [](ExprPtr lhs, ExprPtr rhs) {
        return std::make_unique<LogicalBinaryExpr>(
            LogicalBinaryExpr::OpType::LogicalAnd, lhs, rhs);
    };
    auto conjunction = ConstructTree(conjoin, std::move(excluded));

    return std::make_unique<LogicalUnaryExpr>(
        LogicalUnaryExpr::OpType::LogicalNot, conjunction);
}
|
||||
|
||||
} // namespace milvus::query
|
|
@ -1,92 +0,0 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "Plan.h"
|
||||
|
||||
namespace milvus::query {
|
||||
|
||||
class Parser {
|
||||
public:
|
||||
friend std::unique_ptr<Plan>
|
||||
CreatePlan(const Schema& schema, const std::string_view dsl_str);
|
||||
|
||||
private:
|
||||
std::unique_ptr<Plan>
|
||||
CreatePlanImpl(const Json& dsl);
|
||||
|
||||
explicit Parser(const Schema& schema) : schema(schema) {
|
||||
}
|
||||
|
||||
// vector node parser, should be called exactly once per pass.
|
||||
std::unique_ptr<VectorPlanNode>
|
||||
ParseVecNode(const Json& out_body);
|
||||
|
||||
// Dispatcher of all parse function
|
||||
// NOTE: when nullptr, it is a pure vector node
|
||||
ExprPtr
|
||||
ParseAnyNode(const Json& body);
|
||||
|
||||
ExprPtr
|
||||
ParseMustNode(const Json& body);
|
||||
|
||||
ExprPtr
|
||||
ParseShouldNode(const Json& body);
|
||||
|
||||
ExprPtr
|
||||
ParseMustNotNode(const Json& body);
|
||||
|
||||
// parse the value of "should"/"must"/"must_not" entry
|
||||
std::vector<ExprPtr>
|
||||
ParseItemList(const Json& body);
|
||||
|
||||
// parse the value of "range" entry
|
||||
ExprPtr
|
||||
ParseRangeNode(const Json& out_body);
|
||||
|
||||
// parse the value of "term" entry
|
||||
ExprPtr
|
||||
ParseTermNode(const Json& out_body);
|
||||
|
||||
// parse the value of "term" entry
|
||||
ExprPtr
|
||||
ParseCompareNode(const Json& out_body);
|
||||
|
||||
private:
|
||||
// template implementation of leaf parser
|
||||
// used by corresponding parser
|
||||
|
||||
template <typename T>
|
||||
ExprPtr
|
||||
ParseRangeNodeImpl(const FieldName& field_name, const Json& body);
|
||||
|
||||
template <typename T>
|
||||
ExprPtr
|
||||
ParseTermNodeImpl(const FieldName& field_name, const Json& body);
|
||||
|
||||
private:
|
||||
const Schema& schema;
|
||||
std::map<std::string, FieldId> tag2field_; // PlaceholderName -> field id
|
||||
std::optional<std::unique_ptr<VectorPlanNode>> vector_node_opt_;
|
||||
};
|
||||
|
||||
} // namespace milvus::query
|
|
@ -14,7 +14,6 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "Parser.h"
|
||||
#include "Plan.h"
|
||||
#include "PlanProto.h"
|
||||
#include "generated/ShowPlanNodeVisitor.h"
|
||||
|
@ -63,14 +62,6 @@ ParsePlaceholderGroup(const Plan* plan,
|
|||
return result;
|
||||
}
|
||||
|
||||
std::unique_ptr<Plan>
|
||||
CreatePlan(const Schema& schema, const std::string_view dsl_str) {
|
||||
Json dsl;
|
||||
dsl = json::parse(dsl_str);
|
||||
auto plan = Parser(schema).CreatePlanImpl(dsl);
|
||||
return plan;
|
||||
}
|
||||
|
||||
std::unique_ptr<Plan>
|
||||
CreateSearchPlanByExpr(const Schema& schema,
|
||||
const void* serialized_expr_plan,
|
||||
|
|
|
@ -26,9 +26,6 @@ struct Plan;
|
|||
struct PlaceholderGroup;
|
||||
struct RetrievePlan;
|
||||
|
||||
std::unique_ptr<Plan>
|
||||
CreatePlan(const Schema& schema, const std::string_view dsl);
|
||||
|
||||
// Note: serialized_expr_plan is of binary format
|
||||
std::unique_ptr<Plan>
|
||||
CreateSearchPlanByExpr(const Schema& schema,
|
||||
|
|
|
@ -15,34 +15,6 @@
|
|||
#include "segcore/Collection.h"
|
||||
#include "segcore/plan_c.h"
|
||||
|
||||
// C entry point: builds a search plan for the collection from a DSL string.
//
// On success *res_plan receives the released plan pointer and the status
// carries Success with an empty message. On failure *res_plan is set to
// nullptr and the status carries the error code plus a strdup'ed copy of the
// exception message.
CStatus
CreateSearchPlan(CCollection c_col, const char* dsl, CSearchPlan* res_plan) {
    auto collection = (milvus::segcore::Collection*)c_col;

    // Shared failure path for both exception handlers.
    auto fail = [res_plan](auto code, const char* what) {
        auto failed = CStatus();
        failed.error_code = code;
        failed.error_msg = strdup(what);
        *res_plan = nullptr;
        return failed;
    };

    try {
        auto parsed = milvus::query::CreatePlan(*collection->get_schema(), dsl);

        auto ok = CStatus();
        ok.error_code = Success;
        ok.error_msg = "";
        *res_plan = (CSearchPlan)parsed.release();
        return ok;
    } catch (milvus::SegcoreError& e) {
        return fail(e.get_error_code(), e.what());
    } catch (std::exception& e) {
        return fail(UnexpectedError, e.what());
    }
}
|
||||
|
||||
// Note: serialized_expr_plan is of binary format
|
||||
CStatus
|
||||
CreateSearchPlanByExpr(CCollection c_col,
|
||||
|
|
|
@ -23,9 +23,6 @@ typedef void* CSearchPlan;
|
|||
typedef void* CPlaceholderGroup;
|
||||
typedef void* CRetrievePlan;
|
||||
|
||||
CStatus
|
||||
CreateSearchPlan(CCollection col, const char* dsl, CSearchPlan* res_plan);
|
||||
|
||||
// Note: serialized_expr_plan is of binary format
|
||||
CStatus
|
||||
CreateSearchPlanByExpr(CCollection col,
|
||||
|
|
|
@ -32,7 +32,6 @@ set(MILVUS_TEST_FILES
|
|||
test_index_wrapper.cpp
|
||||
test_init.cpp
|
||||
test_parquet_c.cpp
|
||||
test_plan_proto.cpp
|
||||
test_query.cpp
|
||||
test_reduce.cpp
|
||||
test_reduce_c.cpp
|
||||
|
|
|
@ -32,26 +32,19 @@ const auto schema = []() {
|
|||
}();
|
||||
|
||||
const auto plan = [] {
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": -1
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: -1
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
return plan;
|
||||
}();
|
||||
auto ph_group = [] {
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "pb/plan.pb.h"
|
||||
#include "segcore/SegmentGrowing.h"
|
||||
#include "segcore/SegmentGrowingImpl.h"
|
||||
#include "pb/schema.pb.h"
|
||||
|
@ -42,25 +43,36 @@ TEST(GrowingIndex, Correctness) {
|
|||
std::make_shared<CollectionIndexMeta>(226985, std::move(filedMap));
|
||||
auto segment = CreateGrowingSegment(schema, metaPtr);
|
||||
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"vector": {
|
||||
"embeddings": {
|
||||
"metric_type": "l2",
|
||||
"params": {
|
||||
"nprobe": 16
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal":3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "vector": {
|
||||
// "embeddings": {
|
||||
// "metric_type": "l2",
|
||||
// "params": {
|
||||
// "nprobe": 16
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal":3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
milvus::proto::plan::PlanNode plan_node;
|
||||
auto vector_anns = plan_node.mutable_vector_anns();
|
||||
vector_anns->set_is_binary(false);
|
||||
vector_anns->set_placeholder_tag("$0");
|
||||
vector_anns->set_field_id(102);
|
||||
auto query_info = vector_anns->mutable_query_info();
|
||||
query_info->set_topk(5);
|
||||
query_info->set_round_decimal(3);
|
||||
query_info->set_metric_type("l2");
|
||||
query_info->set_search_params(R"({"nprobe": 16})");
|
||||
auto plan_str = plan_node.SerializeAsString();
|
||||
|
||||
int64_t per_batch = 10000;
|
||||
int64_t n_batch = 20;
|
||||
|
@ -75,7 +87,8 @@ TEST(GrowingIndex, Correctness) {
|
|||
dataset.timestamps_.data(),
|
||||
dataset.raw_);
|
||||
|
||||
auto plan = milvus::query::CreatePlan(*schema, dsl);
|
||||
auto plan = milvus::query::CreateSearchPlanByExpr(
|
||||
*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 128, 1024);
|
||||
auto ph_group =
|
||||
|
|
|
@ -1,692 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <boost/format.hpp>
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <queue>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "pb/plan.pb.h"
|
||||
#include "query/PlanProto.h"
|
||||
#include "query/generated/ShowPlanNodeVisitor.h"
|
||||
|
||||
using namespace milvus;
|
||||
using namespace milvus::query;
|
||||
namespace planpb = proto::plan;
|
||||
using std::string;
|
||||
|
||||
namespace spb = proto::schema;
|
||||
// Builds the schema shared by the tests in this file: one float vector field,
// one binary vector field and six scalar fields, added in a fixed order.
static SchemaPtr
getStandardSchema() {
    auto standard = std::make_shared<Schema>();

    standard->AddDebugField(
        "FloatVectorField", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
    standard->AddDebugField("BinaryVectorField",
                            DataType::VECTOR_BINARY,
                            16,
                            knowhere::metric::JACCARD);

    // Scalar fields, in the order they are added.
    struct ScalarField {
        const char* name;
        DataType type;
    };
    const ScalarField scalar_fields[] = {
        {"Int64Field", DataType::INT64},
        {"Int32Field", DataType::INT32},
        {"Int16Field", DataType::INT16},
        {"Int8Field", DataType::INT8},
        {"DoubleField", DataType::DOUBLE},
        {"FloatField", DataType::FLOAT},
    };
    for (const auto& field : scalar_fields) {
        standard->AddDebugField(field.name, field.type);
    }
    return standard;
}
|
||||
|
||||
// Fixture parameterized by a scalar field name; each test instance gets its
// own copy of the standard schema.
class PlanProtoTest : public ::testing::TestWithParam<std::tuple<std::string>> {
 public:
    PlanProtoTest() : schema(getStandardSchema()) {
    }

 protected:
    SchemaPtr schema;
};
|
||||
|
||||
// Instantiates every PlanProtoTest case once per scalar field name.
INSTANTIATE_TEST_CASE_P(InstName,
                        PlanProtoTest,
                        ::testing::Values(std::make_tuple("DoubleField"),
                                          std::make_tuple("FloatField"),
                                          std::make_tuple("Int64Field"),
                                          std::make_tuple("Int32Field"),
                                          std::make_tuple("Int16Field"),
                                          std::make_tuple("Int8Field")));
|
||||
|
||||
// Builds the same "field > 3" search plan twice, once from a text-format
// PlanNode proto and once from the JSON DSL, and checks the two plans are
// identical.
TEST_P(PlanProtoTest, Range) {
    // xxx.query(predicates = "int64field > 3", topk = 10, ...)
    FieldName vec_field_name = FieldName("FloatVectorField");
    FieldId vec_float_field_id = schema->get_field_id(vec_field_name);

    auto field_name = std::get<0>(GetParam());
    auto field_id = schema->get_field_id(FieldName(field_name));
    auto data_type = schema->operator[](field_id).get_data_type();
    auto data_type_str = spb::DataType_Name(int(data_type));

    // Name of the GenericValue member that carries the literal.
    string value_tag = "bool_val";
    if (datatype_is_floating(data_type)) {
        value_tag = "float_val";
    } else if (datatype_is_integer(data_type)) {
        value_tag = "int64_val";
    }

    auto fmt1 = boost::format(R"(
vector_anns: <
field_id: %1%
predicates: <
unary_range_expr: <
column_info: <
field_id: %2%
data_type: %3%
>
op: GreaterThan
value: <
%4%: 3
>
>
>
query_info: <
topk: 10
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>
)") % vec_float_field_id.get() %
                field_id.get() % data_type_str % value_tag;

    auto proto_text = fmt1.str();
    planpb::PlanNode node_proto;
    // A malformed fixture must fail here instead of silently producing an
    // empty or partial plan node.
    ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(proto_text,
                                                              &node_proto));
    // std::cout << node_proto.DebugString();
    auto plan = ProtoParser(*schema).CreatePlan(node_proto);

    ShowPlanNodeVisitor visitor;
    auto json = visitor.call_child(*plan->plan_node_);
    // std::cout << json.dump(2);
    auto extra_info = plan->extra_info_opt_.value();

    std::string dsl_text = boost::str(boost::format(R"(
{
"bool": {
"must": [
{
"range": {
"%1%": {
"GT": 3
}
}
},
{
"vector": {
"FloatVectorField": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
]
}
}
)") % field_name);

    auto ref_plan = CreatePlan(*schema, dsl_text);
    plan->check_identical(*ref_plan);
}
|
||||
|
||||
TEST_P(PlanProtoTest, TermExpr) {
|
||||
// xxx.query(predicates = "int64field in [1, 2, 3]", topk = 10, ...)
|
||||
FieldName vec_field_name = FieldName("FloatVectorField");
|
||||
FieldId vec_float_field_id = schema->get_field_id(vec_field_name);
|
||||
|
||||
auto field_name = std::get<0>(GetParam());
|
||||
auto field_id = schema->get_field_id(FieldName(field_name));
|
||||
auto data_type = schema->operator[](field_id).get_data_type();
|
||||
auto data_type_str = spb::DataType_Name(int(data_type));
|
||||
|
||||
string value_tag = "bool_val";
|
||||
if (datatype_is_floating(data_type)) {
|
||||
value_tag = "float_val";
|
||||
} else if (datatype_is_integer(data_type)) {
|
||||
value_tag = "int64_val";
|
||||
}
|
||||
|
||||
auto fmt1 = boost::format(R"(
|
||||
vector_anns: <
|
||||
field_id: %1%
|
||||
predicates: <
|
||||
term_expr: <
|
||||
column_info: <
|
||||
field_id: %2%
|
||||
data_type: %3%
|
||||
>
|
||||
values: <
|
||||
%4%: 1
|
||||
>
|
||||
values: <
|
||||
%4%: 2
|
||||
>
|
||||
values: <
|
||||
%4%: 3
|
||||
>
|
||||
is_in_field : false
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 10
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>
|
||||
)") % vec_float_field_id.get() %
|
||||
field_id.get() % data_type_str % value_tag;
|
||||
|
||||
auto proto_text = fmt1.str();
|
||||
planpb::PlanNode node_proto;
|
||||
google::protobuf::TextFormat::ParseFromString(proto_text, &node_proto);
|
||||
// std::cout << node_proto.DebugString();
|
||||
auto plan = ProtoParser(*schema).CreatePlan(node_proto);
|
||||
|
||||
ShowPlanNodeVisitor visitor;
|
||||
auto json = visitor.call_child(*plan->plan_node_);
|
||||
// std::cout << json.dump(2);
|
||||
auto extra_info = plan->extra_info_opt_.value();
|
||||
|
||||
std::string dsl_text = boost::str(boost::format(R"(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"term": {
|
||||
"%1%": {
|
||||
"values": [1,2,3],
|
||||
"is_in_field" : false
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"FloatVectorField": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
)") % field_name);
|
||||
|
||||
auto ref_plan = CreatePlan(*schema, dsl_text);
|
||||
plan->check_identical(*ref_plan);
|
||||
}
|
||||
|
||||
TEST(PlanProtoTest, NotExpr) {
|
||||
auto schema = getStandardSchema();
|
||||
// xxx.query(predicates = "not (int64field > 3)", topk = 10, ...)
|
||||
FieldName vec_field_name = FieldName("FloatVectorField");
|
||||
FieldId vec_float_field_id = schema->get_field_id(vec_field_name);
|
||||
|
||||
FieldName int64_field_name = FieldName("Int64Field");
|
||||
FieldId int64_field_id = schema->get_field_id(int64_field_name);
|
||||
string value_tag = "int64_val";
|
||||
|
||||
auto data_type = spb::DataType::Int64;
|
||||
auto data_type_str = spb::DataType_Name(int(data_type));
|
||||
|
||||
auto fmt1 = boost::format(R"(
|
||||
vector_anns: <
|
||||
field_id: %1%
|
||||
predicates: <
|
||||
unary_expr: <
|
||||
op: Not
|
||||
child: <
|
||||
unary_range_expr: <
|
||||
column_info: <
|
||||
field_id: %2%
|
||||
data_type: %3%
|
||||
>
|
||||
op: GreaterThan
|
||||
value: <
|
||||
%4%: 3
|
||||
>
|
||||
>
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 10
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>
|
||||
)") % vec_float_field_id.get() %
|
||||
int64_field_id.get() % data_type_str % value_tag;
|
||||
|
||||
auto proto_text = fmt1.str();
|
||||
planpb::PlanNode node_proto;
|
||||
google::protobuf::TextFormat::ParseFromString(proto_text, &node_proto);
|
||||
// std::cout << node_proto.DebugString();
|
||||
auto plan = ProtoParser(*schema).CreatePlan(node_proto);
|
||||
|
||||
ShowPlanNodeVisitor visitor;
|
||||
auto json = visitor.call_child(*plan->plan_node_);
|
||||
// std::cout << json.dump(2);
|
||||
auto extra_info = plan->extra_info_opt_.value();
|
||||
|
||||
std::string dsl_text = boost::str(boost::format(R"(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"must_not": [{
|
||||
"range": {
|
||||
"%1%": {
|
||||
"GT": 3
|
||||
}
|
||||
}
|
||||
}]
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"FloatVectorField": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
)") % int64_field_name.get());
|
||||
|
||||
auto ref_plan = CreatePlan(*schema, dsl_text);
|
||||
auto ref_json = ShowPlanNodeVisitor().call_child(*ref_plan->plan_node_);
|
||||
EXPECT_EQ(json.dump(2), ref_json.dump(2));
|
||||
plan->check_identical(*ref_plan);
|
||||
}
|
||||
|
||||
TEST(PlanProtoTest, AndOrExpr) {
|
||||
auto schema = getStandardSchema();
|
||||
// xxx.query(predicates = "(int64field < 3) && (int64field > 2 || int64field == 1)", topk = 10, ...)
|
||||
FieldName vec_field_name = FieldName("FloatVectorField");
|
||||
FieldId vec_float_field_id = schema->get_field_id(vec_field_name);
|
||||
|
||||
FieldName int64_field_name = FieldName("Int64Field");
|
||||
FieldId int64_field_id = schema->get_field_id(int64_field_name);
|
||||
string value_tag = "int64_val";
|
||||
|
||||
auto data_type = spb::DataType::Int64;
|
||||
auto data_type_str = spb::DataType_Name(int(data_type));
|
||||
|
||||
auto fmt1 = boost::format(R"(
|
||||
vector_anns: <
|
||||
field_id: %1%
|
||||
predicates: <
|
||||
binary_expr: <
|
||||
op: LogicalAnd
|
||||
left: <
|
||||
unary_range_expr: <
|
||||
column_info: <
|
||||
field_id: %2%
|
||||
data_type: %3%
|
||||
>
|
||||
op: LessThan
|
||||
value: <
|
||||
%4%: 3
|
||||
>
|
||||
>
|
||||
>
|
||||
right: <
|
||||
binary_expr: <
|
||||
op: LogicalOr
|
||||
left: <
|
||||
unary_range_expr: <
|
||||
column_info: <
|
||||
field_id: %2%
|
||||
data_type: %3%
|
||||
>
|
||||
op: GreaterThan
|
||||
value: <
|
||||
%4%: 2
|
||||
>
|
||||
>
|
||||
>
|
||||
right: <
|
||||
unary_range_expr: <
|
||||
column_info: <
|
||||
field_id: %2%
|
||||
data_type: %3%
|
||||
>
|
||||
op: Equal
|
||||
value: <
|
||||
%4%: 1
|
||||
>
|
||||
>
|
||||
>
|
||||
>
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 10
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>
|
||||
)") % vec_float_field_id.get() %
|
||||
int64_field_id.get() % data_type_str % value_tag;
|
||||
|
||||
auto proto_text = fmt1.str();
|
||||
planpb::PlanNode node_proto;
|
||||
google::protobuf::TextFormat::ParseFromString(proto_text, &node_proto);
|
||||
// std::cout << node_proto.DebugString();
|
||||
auto plan = ProtoParser(*schema).CreatePlan(node_proto);
|
||||
|
||||
ShowPlanNodeVisitor visitor;
|
||||
auto json = visitor.call_child(*plan->plan_node_);
|
||||
// std::cout << json.dump(2);
|
||||
auto extra_info = plan->extra_info_opt_.value();
|
||||
|
||||
std::string dsl_text = boost::str(boost::format(R"(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"%1%": {
|
||||
"LT": 3
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
{
|
||||
"should": [
|
||||
{
|
||||
"range": {
|
||||
"%1%": {
|
||||
"GT": 2
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"range": {
|
||||
"%1%": {
|
||||
"EQ": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"FloatVectorField": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
)") % int64_field_name.get());
|
||||
|
||||
auto ref_plan = CreatePlan(*schema, dsl_text);
|
||||
auto ref_json = ShowPlanNodeVisitor().call_child(*ref_plan->plan_node_);
|
||||
EXPECT_EQ(json.dump(2), ref_json.dump(2));
|
||||
plan->check_identical(*ref_plan);
|
||||
}
|
||||
|
||||
TEST_P(PlanProtoTest, CompareExpr) {
|
||||
auto schema = getStandardSchema();
|
||||
auto age_fid = schema->AddDebugField("age1", DataType::INT64);
|
||||
// xxx.query(predicates = "int64field < int64field", topk = 10, ...)
|
||||
|
||||
FieldName vec_field_name = FieldName("FloatVectorField");
|
||||
FieldId vec_float_field_id = schema->get_field_id(vec_field_name);
|
||||
|
||||
auto field_name = std::get<0>(GetParam());
|
||||
auto field_id = schema->get_field_id(FieldName(field_name));
|
||||
auto data_type = schema->operator[](field_id).get_data_type();
|
||||
auto data_type_str = spb::DataType_Name(int(data_type));
|
||||
|
||||
auto fmt1 = boost::format(R"(
|
||||
vector_anns: <
|
||||
field_id: %1%
|
||||
predicates: <
|
||||
compare_expr: <
|
||||
left_column_info: <
|
||||
field_id: %2%
|
||||
data_type: Int64
|
||||
>
|
||||
right_column_info: <
|
||||
field_id: %3%
|
||||
data_type: %4%
|
||||
>
|
||||
op: LessThan
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 10
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>
|
||||
)") % vec_float_field_id.get() %
|
||||
age_fid.get() % field_id.get() % data_type_str;
|
||||
|
||||
auto proto_text = fmt1.str();
|
||||
planpb::PlanNode node_proto;
|
||||
google::protobuf::TextFormat::ParseFromString(proto_text, &node_proto);
|
||||
// std::cout << node_proto.DebugString();
|
||||
auto plan = ProtoParser(*schema).CreatePlan(node_proto);
|
||||
|
||||
ShowPlanNodeVisitor visitor;
|
||||
auto json = visitor.call_child(*plan->plan_node_);
|
||||
// std::cout << json.dump(2);
|
||||
auto extra_info = plan->extra_info_opt_.value();
|
||||
|
||||
std::string dsl_text = boost::str(boost::format(R"(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"compare": {
|
||||
"LT": [
|
||||
"age1",
|
||||
"%1%"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"FloatVectorField": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
)") % field_name);
|
||||
|
||||
auto ref_plan = CreatePlan(*schema, dsl_text);
|
||||
plan->check_identical(*ref_plan);
|
||||
}
|
||||
|
||||
TEST_P(PlanProtoTest, BinaryArithOpEvalRange) {
|
||||
// xxx.query(predicates = "int64field > 3", topk = 10, ...)
|
||||
// auto data_type = std::get<0>(GetParam());
|
||||
// auto data_type_str = spb::DataType_Name(data_type);
|
||||
// auto field_id = 100 + (int)data_type;
|
||||
// auto field_name = data_type_str + "Field";
|
||||
// string value_tag = "bool_val";
|
||||
// if (datatype_is_floating((DataType)data_type)) {
|
||||
// value_tag = "float_val";
|
||||
// } else if (datatype_is_integer((DataType)data_type)) {
|
||||
// value_tag = "int64_val";
|
||||
// }
|
||||
|
||||
FieldName vec_field_name = FieldName("FloatVectorField");
|
||||
FieldId vec_float_field_id = schema->get_field_id(vec_field_name);
|
||||
|
||||
auto field_name = std::get<0>(GetParam());
|
||||
auto field_id = schema->get_field_id(FieldName(field_name));
|
||||
auto data_type = schema->operator[](field_id).get_data_type();
|
||||
auto data_type_str = spb::DataType_Name(int(data_type));
|
||||
|
||||
string value_tag = "bool_val";
|
||||
if (datatype_is_floating(data_type)) {
|
||||
value_tag = "float_val";
|
||||
} else if (datatype_is_integer(data_type)) {
|
||||
value_tag = "int64_val";
|
||||
}
|
||||
|
||||
auto fmt1 = boost::format(R"(
|
||||
vector_anns: <
|
||||
field_id: %1%
|
||||
predicates: <
|
||||
binary_arith_op_eval_range_expr: <
|
||||
column_info: <
|
||||
field_id: %2%
|
||||
data_type: %3%
|
||||
>
|
||||
arith_op: Add
|
||||
right_operand: <
|
||||
%4%: 1029
|
||||
>
|
||||
op: Equal
|
||||
value: <
|
||||
%4%: 2016
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 10
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>
|
||||
)") % vec_float_field_id.get() %
|
||||
field_id.get() % data_type_str % value_tag;
|
||||
|
||||
auto proto_text = fmt1.str();
|
||||
planpb::PlanNode node_proto;
|
||||
google::protobuf::TextFormat::ParseFromString(proto_text, &node_proto);
|
||||
// std::cout << node_proto.DebugString();
|
||||
auto plan = ProtoParser(*schema).CreatePlan(node_proto);
|
||||
|
||||
ShowPlanNodeVisitor visitor;
|
||||
auto json = visitor.call_child(*plan->plan_node_);
|
||||
// std::cout << json.dump(2);
|
||||
auto extra_info = plan->extra_info_opt_.value();
|
||||
|
||||
std::string dsl_text = boost::str(boost::format(R"(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"%1%": {
|
||||
"EQ": {
|
||||
"ADD": {
|
||||
"right_operand": 1029,
|
||||
"value": 2016
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"FloatVectorField": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
)") % field_name);
|
||||
|
||||
auto ref_plan = CreatePlan(*schema, dsl_text);
|
||||
plan->check_identical(*ref_plan);
|
||||
}
|
||||
|
||||
TEST(PlanProtoTest, Predicates) {
|
||||
auto schema = getStandardSchema();
|
||||
auto age_fid = schema->AddDebugField("age1", DataType::INT64);
|
||||
|
||||
planpb::PlanNode plan_node_proto;
|
||||
auto expr =
|
||||
plan_node_proto.mutable_predicates()->mutable_unary_range_expr();
|
||||
expr->set_op(planpb::Equal);
|
||||
auto column_info = expr->mutable_column_info();
|
||||
column_info->set_data_type(proto::schema::DataType::Int64);
|
||||
column_info->set_field_id(age_fid.get());
|
||||
auto value = expr->mutable_value();
|
||||
value->set_int64_val(1000);
|
||||
|
||||
std::string binary;
|
||||
plan_node_proto.SerializeToString(&binary);
|
||||
|
||||
auto plan =
|
||||
CreateRetrievePlanByExpr(*schema, binary.c_str(), binary.size());
|
||||
ASSERT_TRUE(plan->plan_node_->predicate_.has_value());
|
||||
ASSERT_FALSE(plan->plan_node_->is_count);
|
||||
}
|
|
@ -53,84 +53,40 @@ TEST(Query, ShowExecutor) {
|
|||
std::cout << dup.dump(4);
|
||||
}
|
||||
|
||||
TEST(Query, DSL) {
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
ShowPlanNodeVisitor shower;
|
||||
|
||||
std::string dsl_string = R"(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddDebugField(
|
||||
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
|
||||
|
||||
auto plan = CreatePlan(*schema, dsl_string);
|
||||
auto res = shower.call_child(*plan->plan_node_);
|
||||
std::cout << res.dump(4) << std::endl;
|
||||
|
||||
std::string dsl_string2 = R"(
|
||||
{
|
||||
"bool": {
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
})";
|
||||
auto plan2 = CreatePlan(*schema, dsl_string2);
|
||||
auto res2 = shower.call_child(*plan2->plan_node_);
|
||||
std::cout << res2.dump(4) << std::endl;
|
||||
ASSERT_EQ(res, res2);
|
||||
}
|
||||
|
||||
TEST(Query, ParsePlaceholderGroup) {
|
||||
std::string dsl_string = R"(
|
||||
{
|
||||
"bool": {
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10,
|
||||
"round_decimal":3
|
||||
}
|
||||
}
|
||||
}
|
||||
})";
|
||||
// std::string dsl_string = R"(
|
||||
// {
|
||||
// "bool": {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "L2",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 10,
|
||||
// "round_decimal":3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
query_info: <
|
||||
topk: 10
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddDebugField(
|
||||
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
|
||||
auto plan = CreatePlan(*schema, dsl_string);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
int64_t num_queries = 100000;
|
||||
int dim = 16;
|
||||
auto raw_group = CreatePlaceholderGroup(num_queries, dim);
|
||||
|
@ -147,33 +103,59 @@ TEST(Query, ExecWithPredicateLoader) {
|
|||
schema->AddDebugField("age", DataType::FLOAT);
|
||||
auto counter_fid = schema->AddDebugField("counter", DataType::INT64);
|
||||
schema->set_primary_field_id(counter_fid);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"age": {
|
||||
"GE": -1,
|
||||
"LT": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "range": {
|
||||
// "age": {
|
||||
// "GE": -1,
|
||||
// "LT": 1
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "L2",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal": 3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
predicates: <
|
||||
binary_range_expr: <
|
||||
column_info: <
|
||||
field_id: 101
|
||||
data_type: Float
|
||||
>
|
||||
lower_inclusive: true,
|
||||
upper_inclusive: false,
|
||||
lower_value: <
|
||||
float_val: -1
|
||||
>
|
||||
upper_value: <
|
||||
float_val: 1
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
int64_t N = ROW_COUNT;
|
||||
auto dataset = DataGen(schema, N);
|
||||
auto segment = CreateGrowingSegment(schema, empty_index_meta);
|
||||
|
@ -184,7 +166,9 @@ TEST(Query, ExecWithPredicateLoader) {
|
|||
dataset.timestamps_.data(),
|
||||
dataset.raw_);
|
||||
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
|
||||
auto ph_group =
|
||||
|
@ -231,33 +215,59 @@ TEST(Query, ExecWithPredicateSmallN) {
|
|||
schema->AddDebugField("age", DataType::FLOAT);
|
||||
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
|
||||
schema->set_primary_field_id(i64_fid);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"age": {
|
||||
"GE": -1,
|
||||
"LT": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "range": {
|
||||
// "age": {
|
||||
// "GE": -1,
|
||||
// "LT": 1
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "L2",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal": 3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
predicates: <
|
||||
binary_range_expr: <
|
||||
column_info: <
|
||||
field_id: 101
|
||||
data_type: Float
|
||||
>
|
||||
lower_inclusive: true,
|
||||
upper_inclusive: false,
|
||||
lower_value: <
|
||||
float_val: -1
|
||||
>
|
||||
upper_value: <
|
||||
float_val: 1
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
int64_t N = 177;
|
||||
auto dataset = DataGen(schema, N);
|
||||
auto segment = CreateGrowingSegment(schema, empty_index_meta);
|
||||
|
@ -268,7 +278,9 @@ TEST(Query, ExecWithPredicateSmallN) {
|
|||
dataset.timestamps_.data(),
|
||||
dataset.raw_);
|
||||
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 7, 1024);
|
||||
auto ph_group =
|
||||
|
@ -291,33 +303,59 @@ TEST(Query, ExecWithPredicate) {
|
|||
schema->AddDebugField("age", DataType::FLOAT);
|
||||
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
|
||||
schema->set_primary_field_id(i64_fid);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"age": {
|
||||
"GE": -1,
|
||||
"LT": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "range": {
|
||||
// "age": {
|
||||
// "GE": -1,
|
||||
// "LT": 1
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "L2",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal": 3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
predicates: <
|
||||
binary_range_expr: <
|
||||
column_info: <
|
||||
field_id: 101
|
||||
data_type: Float
|
||||
>
|
||||
lower_inclusive: true,
|
||||
upper_inclusive: false,
|
||||
lower_value: <
|
||||
float_val: -1
|
||||
>
|
||||
upper_value: <
|
||||
float_val: 1
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
int64_t N = ROW_COUNT;
|
||||
auto dataset = DataGen(schema, N);
|
||||
auto segment = CreateGrowingSegment(schema, empty_index_meta);
|
||||
|
@ -328,7 +366,9 @@ TEST(Query, ExecWithPredicate) {
|
|||
dataset.timestamps_.data(),
|
||||
dataset.raw_);
|
||||
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
|
||||
auto ph_group =
|
||||
|
@ -375,33 +415,51 @@ TEST(Query, ExecTerm) {
|
|||
schema->AddDebugField("age", DataType::FLOAT);
|
||||
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
|
||||
schema->set_primary_field_id(i64_fid);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"term": {
|
||||
"age": {
|
||||
"values": [],
|
||||
"is_in_field": false
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "term": {
|
||||
// "age": {
|
||||
// "values": [],
|
||||
// "is_in_field": false
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "L2",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal": 3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
predicates: <
|
||||
term_expr: <
|
||||
column_info: <
|
||||
field_id: 101
|
||||
data_type: Float
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
int64_t N = ROW_COUNT;
|
||||
auto dataset = DataGen(schema, N);
|
||||
auto segment = CreateGrowingSegment(schema, empty_index_meta);
|
||||
|
@ -412,7 +470,9 @@ TEST(Query, ExecTerm) {
|
|||
dataset.timestamps_.data(),
|
||||
dataset.raw_);
|
||||
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 3;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
|
||||
auto ph_group =
|
||||
|
@ -435,28 +495,40 @@ TEST(Query, ExecEmpty) {
|
|||
schema->AddDebugField("age", DataType::FLOAT);
|
||||
schema->AddDebugField(
|
||||
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "L2",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal": 3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 101
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
int64_t N = ROW_COUNT;
|
||||
auto segment = CreateGrowingSegment(schema, empty_index_meta);
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
|
||||
auto ph_group =
|
||||
|
@ -483,26 +555,38 @@ TEST(Query, ExecWithoutPredicateFlat) {
|
|||
schema->AddDebugField("age", DataType::FLOAT);
|
||||
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
|
||||
schema->set_primary_field_id(i64_fid);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "L2",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal": 3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
int64_t N = ROW_COUNT;
|
||||
auto dataset = DataGen(schema, N);
|
||||
auto segment = CreateGrowingSegment(schema, empty_index_meta);
|
||||
|
@ -535,26 +619,38 @@ TEST(Query, ExecWithoutPredicate) {
|
|||
schema->AddDebugField("age", DataType::FLOAT);
|
||||
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
|
||||
schema->set_primary_field_id(i64_fid);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "l2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal":3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "l2",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal":3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
int64_t N = ROW_COUNT;
|
||||
auto dataset = DataGen(schema, N);
|
||||
auto segment = CreateGrowingSegment(schema, empty_index_meta);
|
||||
|
@ -609,32 +705,44 @@ TEST(Query, InnerProduct) {
|
|||
constexpr auto topk = 10;
|
||||
auto num_queries = 5;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"vector": {
|
||||
"normalized": {
|
||||
"metric_type": "ip",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal":3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "vector": {
|
||||
// "normalized": {
|
||||
// "metric_type": "ip",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal":3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "IP"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
auto vec_fid = schema->AddDebugField(
|
||||
"normalized", DataType::VECTOR_FLOAT, dim, knowhere::metric::IP);
|
||||
auto i64_fid = schema->AddDebugField("age", DataType::INT64);
|
||||
schema->set_primary_field_id(i64_fid);
|
||||
auto dataset = DataGen(schema, N);
|
||||
auto segment = CreateGrowingSegment(schema, empty_index_meta);
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
segment->PreInsert(N);
|
||||
segment->Insert(0,
|
||||
N,
|
||||
|
@ -734,26 +842,38 @@ TEST(Query, FillSegment) {
|
|||
return segment;
|
||||
}());
|
||||
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "L2",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal": 3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto ph_proto = CreatePlaceholderGroup(10, 16, 443);
|
||||
auto ph = ParsePlaceholderGroup(plan.get(), ph_proto.SerializeAsString());
|
||||
Timestamp ts = N * 2UL;
|
||||
|
@ -826,33 +946,59 @@ TEST(Query, ExecWithPredicateBinary) {
|
|||
auto float_fid = schema->AddDebugField("age", DataType::FLOAT);
|
||||
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
|
||||
schema->set_primary_field_id(i64_fid);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"age": {
|
||||
"GE": -1,
|
||||
"LT": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "JACCARD",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "range": {
|
||||
// "age": {
|
||||
// "GE": -1,
|
||||
// "LT": 1
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "JACCARD",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal": 3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
predicates: <
|
||||
binary_range_expr: <
|
||||
column_info: <
|
||||
field_id: 101
|
||||
data_type: Float
|
||||
>
|
||||
lower_inclusive: true,
|
||||
upper_inclusive: false,
|
||||
lower_value: <
|
||||
float_val: -1
|
||||
>
|
||||
upper_value: <
|
||||
float_val: 1
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "JACCARD"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
int64_t N = ROW_COUNT;
|
||||
auto dataset = DataGen(schema, N);
|
||||
auto segment = CreateGrowingSegment(schema, empty_index_meta);
|
||||
|
@ -864,7 +1010,9 @@ TEST(Query, ExecWithPredicateBinary) {
|
|||
dataset.raw_);
|
||||
auto vec_ptr = dataset.get_col<uint8_t>(vec_fid);
|
||||
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreateBinaryPlaceholderGroupFromBlob(
|
||||
num_queries, 512, vec_ptr.data() + 1024 * 512 / 8);
|
||||
|
|
|
@ -37,25 +37,17 @@ TEST(Sealed, without_predicate) {
|
|||
auto float_fid = schema->AddDebugField("age", DataType::FLOAT);
|
||||
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
|
||||
schema->set_primary_field_id(i64_fid);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
|
||||
auto N = ROW_COUNT;
|
||||
|
||||
|
@ -73,7 +65,9 @@ TEST(Sealed, without_predicate) {
|
|||
dataset.timestamps_.data(),
|
||||
dataset.raw_);
|
||||
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw =
|
||||
CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr);
|
||||
|
@ -149,33 +143,59 @@ TEST(Sealed, with_predicate) {
|
|||
"fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
|
||||
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
|
||||
schema->set_primary_field_id(i64_fid);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"counter": {
|
||||
"GE": 4200,
|
||||
"LT": 4205
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 6
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "range": {
|
||||
// "counter": {
|
||||
// "GE": 4200,
|
||||
// "LT": 4205
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "L2",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal": 6
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
predicates: <
|
||||
binary_range_expr: <
|
||||
column_info: <
|
||||
field_id: 101
|
||||
data_type: Int64
|
||||
>
|
||||
lower_inclusive: true,
|
||||
upper_inclusive: false,
|
||||
lower_value: <
|
||||
int64_val: 4200
|
||||
>
|
||||
upper_value: <
|
||||
int64_val: 4205
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 6
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
|
||||
auto N = ROW_COUNT;
|
||||
|
||||
|
@ -190,7 +210,9 @@ TEST(Sealed, with_predicate) {
|
|||
dataset.timestamps_.data(),
|
||||
dataset.raw_);
|
||||
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw =
|
||||
CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr);
|
||||
|
@ -262,40 +284,68 @@ TEST(Sealed, with_predicate_filter_all) {
|
|||
"fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
|
||||
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
|
||||
schema->set_primary_field_id(i64_fid);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"counter": {
|
||||
"GE": 4200,
|
||||
"LT": 4199
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 6
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "range": {
|
||||
// "counter": {
|
||||
// "GE": 4200,
|
||||
// "LT": 4199
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "L2",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal": 6
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
predicates: <
|
||||
binary_range_expr: <
|
||||
column_info: <
|
||||
field_id: 101
|
||||
data_type: Int64
|
||||
>
|
||||
lower_inclusive: true,
|
||||
upper_inclusive: false,
|
||||
lower_value: <
|
||||
int64_val: 4200
|
||||
>
|
||||
upper_value: <
|
||||
int64_val: 4199
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 6
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
|
||||
auto N = ROW_COUNT;
|
||||
|
||||
auto dataset = DataGen(schema, N);
|
||||
auto vec_col = dataset.get_col<float>(fake_id);
|
||||
auto query_ptr = vec_col.data() + BIAS * dim;
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw =
|
||||
CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr);
|
||||
|
@ -395,36 +445,64 @@ TEST(Sealed, LoadFieldData) {
|
|||
auto indexing = GenVecIndexing(N, dim, fakevec.data());
|
||||
|
||||
auto segment = CreateSealedSegment(schema);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"double": {
|
||||
"GE": -1,
|
||||
"LT": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "range": {
|
||||
// "double": {
|
||||
// "GE": -1,
|
||||
// "LT": 1
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "L2",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal": 3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
predicates: <
|
||||
binary_range_expr: <
|
||||
column_info: <
|
||||
field_id: 102
|
||||
data_type: Double
|
||||
>
|
||||
lower_inclusive: true,
|
||||
upper_inclusive: false,
|
||||
lower_value: <
|
||||
float_val: -1
|
||||
>
|
||||
upper_value: <
|
||||
float_val: 1
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
|
||||
Timestamp time = 1000000;
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
|
||||
auto ph_group =
|
||||
|
@ -492,36 +570,64 @@ TEST(Sealed, LoadFieldDataMmap) {
|
|||
auto indexing = GenVecIndexing(N, dim, fakevec.data());
|
||||
|
||||
auto segment = CreateSealedSegment(schema);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"double": {
|
||||
"GE": -1,
|
||||
"LT": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "range": {
|
||||
// "double": {
|
||||
// "GE": -1,
|
||||
// "LT": 1
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "L2",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal": 3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
predicates: <
|
||||
binary_range_expr: <
|
||||
column_info: <
|
||||
field_id: 102
|
||||
data_type: Double
|
||||
>
|
||||
lower_inclusive: true,
|
||||
upper_inclusive: false,
|
||||
lower_value: <
|
||||
float_val: -1
|
||||
>
|
||||
upper_value: <
|
||||
float_val: 1
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
|
||||
Timestamp time = 1000000;
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
|
||||
auto ph_group =
|
||||
|
@ -583,36 +689,64 @@ TEST(Sealed, LoadScalarIndex) {
|
|||
auto indexing = GenVecIndexing(N, dim, fakevec.data());
|
||||
|
||||
auto segment = CreateSealedSegment(schema);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"double": {
|
||||
"GE": -1,
|
||||
"LT": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "range": {
|
||||
// "double": {
|
||||
// "GE": -1,
|
||||
// "LT": 1
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "L2",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal": 3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
predicates: <
|
||||
binary_range_expr: <
|
||||
column_info: <
|
||||
field_id: 102
|
||||
data_type: Double
|
||||
>
|
||||
lower_inclusive: true,
|
||||
upper_inclusive: false,
|
||||
lower_value: <
|
||||
float_val: -1
|
||||
>
|
||||
upper_value: <
|
||||
float_val: 1
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
|
||||
Timestamp time = 1000000;
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
|
||||
auto ph_group =
|
||||
|
@ -697,36 +831,64 @@ TEST(Sealed, Delete) {
|
|||
auto fakevec = dataset.get_col<float>(fakevec_id);
|
||||
|
||||
auto segment = CreateSealedSegment(schema);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"double": {
|
||||
"GE": -1,
|
||||
"LT": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "range": {
|
||||
// "double": {
|
||||
// "GE": -1,
|
||||
// "LT": 1
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "L2",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal": 3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
predicates: <
|
||||
binary_range_expr: <
|
||||
column_info: <
|
||||
field_id: 102
|
||||
data_type: Double
|
||||
>
|
||||
lower_inclusive: true,
|
||||
upper_inclusive: false,
|
||||
lower_value: <
|
||||
float_val: -1
|
||||
>
|
||||
upper_value: <
|
||||
float_val: 1
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
|
||||
Timestamp time = 1000000;
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
|
||||
auto ph_group =
|
||||
|
@ -781,36 +943,64 @@ TEST(Sealed, OverlapDelete) {
|
|||
auto fakevec = dataset.get_col<float>(fakevec_id);
|
||||
|
||||
auto segment = CreateSealedSegment(schema);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"double": {
|
||||
"GE": -1,
|
||||
"LT": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
// std::string dsl = R"({
|
||||
// "bool": {
|
||||
// "must": [
|
||||
// {
|
||||
// "range": {
|
||||
// "double": {
|
||||
// "GE": -1,
|
||||
// "LT": 1
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// {
|
||||
// "vector": {
|
||||
// "fakevec": {
|
||||
// "metric_type": "L2",
|
||||
// "params": {
|
||||
// "nprobe": 10
|
||||
// },
|
||||
// "query": "$0",
|
||||
// "topk": 5,
|
||||
// "round_decimal": 3
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// })";
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
predicates: <
|
||||
binary_range_expr: <
|
||||
column_info: <
|
||||
field_id: 102
|
||||
data_type: Double
|
||||
>
|
||||
lower_inclusive: true,
|
||||
upper_inclusive: false,
|
||||
lower_value: <
|
||||
float_val: -1
|
||||
>
|
||||
upper_value: <
|
||||
float_val: 1
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)";
|
||||
|
||||
Timestamp time = 1000000;
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
|
||||
auto ph_group =
|
||||
|
|
|
@ -381,7 +381,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
|
||||
}
|
||||
|
||||
t.SearchRequest.Dsl = t.request.Dsl
|
||||
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
|
||||
|
||||
// Set username of this search request for feature like task scheduling.
|
||||
|
|
|
@ -59,70 +59,17 @@ const (
|
|||
|
||||
// ---------- unittest util functions ----------
|
||||
// functions of messages and requests
|
||||
func genIVFFlatDSL(schema *schemapb.CollectionSchema, nProb int, topK int64, roundDecimal int64) (string, error) {
|
||||
var vecFieldName string
|
||||
var metricType string
|
||||
nProbStr := strconv.Itoa(nProb)
|
||||
topKStr := strconv.FormatInt(topK, 10)
|
||||
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
|
||||
for _, f := range schema.Fields {
|
||||
if f.DataType == schemapb.DataType_FloatVector {
|
||||
vecFieldName = f.Name
|
||||
for _, p := range f.IndexParams {
|
||||
if p.Key == metricTypeKey {
|
||||
metricType = p.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if vecFieldName == "" || metricType == "" {
|
||||
err := errors.New("invalid vector field name or metric type")
|
||||
return "", err
|
||||
}
|
||||
|
||||
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
|
||||
"\": {\n \"metric_type\": \"" + metricType +
|
||||
"\", \n \"params\": {\n \"nprobe\": " + nProbStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
|
||||
" \n,\"round_decimal\": " + roundDecimalStr +
|
||||
"\n } \n } \n } \n }", nil
|
||||
}
|
||||
|
||||
func genHNSWDSL(schema *schemapb.CollectionSchema, ef int, topK int64, roundDecimal int64) (string, error) {
|
||||
var vecFieldName string
|
||||
var metricType string
|
||||
efStr := strconv.Itoa(ef)
|
||||
topKStr := strconv.FormatInt(topK, 10)
|
||||
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
|
||||
for _, f := range schema.Fields {
|
||||
if f.DataType == schemapb.DataType_FloatVector {
|
||||
vecFieldName = f.Name
|
||||
for _, p := range f.IndexParams {
|
||||
if p.Key == metricTypeKey {
|
||||
metricType = p.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if vecFieldName == "" || metricType == "" {
|
||||
err := errors.New("invalid vector field name or metric type")
|
||||
return "", err
|
||||
}
|
||||
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
|
||||
"\": {\n \"metric_type\": \"" + metricType +
|
||||
"\", \n \"params\": {\n \"ef\": " + efStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
|
||||
" \n,\"round_decimal\": " + roundDecimalStr +
|
||||
"\n } \n } \n } \n }", nil
|
||||
}
|
||||
|
||||
func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecimal int64) (string, error) {
|
||||
var vecFieldName string
|
||||
var metricType string
|
||||
topKStr := strconv.FormatInt(topK, 10)
|
||||
nProbStr := strconv.Itoa(defaultNProb)
|
||||
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
|
||||
var fieldID int64
|
||||
for _, f := range schema.Fields {
|
||||
if f.DataType == schemapb.DataType_FloatVector {
|
||||
vecFieldName = f.Name
|
||||
fieldID = f.FieldID
|
||||
for _, p := range f.IndexParams {
|
||||
if p.Key == metricTypeKey {
|
||||
metricType = p.Value
|
||||
|
@ -134,24 +81,21 @@ func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecima
|
|||
err := errors.New("invalid vector field name or metric type")
|
||||
return "", err
|
||||
}
|
||||
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
|
||||
"\": {\n \"metric_type\": \"" + metricType +
|
||||
"\", \n \"params\": {\n \"nprobe\": " + nProbStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
|
||||
" \n,\"round_decimal\": " + roundDecimalStr +
|
||||
"\n } \n } \n } \n }", nil
|
||||
return `vector_anns: <
|
||||
field_id: ` + fmt.Sprintf("%d", fieldID) + `
|
||||
query_info: <
|
||||
topk: ` + topKStr + `
|
||||
round_decimal: ` + roundDecimalStr + `
|
||||
metric_type: "` + metricType + `"
|
||||
search_params: "{\"nprobe\": ` + nProbStr + `}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>`, nil
|
||||
}
|
||||
|
||||
func genDSLByIndexType(schema *schemapb.CollectionSchema, indexType string) (string, error) {
|
||||
if indexType == IndexFaissIDMap { // float vector
|
||||
return genBruteForceDSL(schema, defaultTopK, defaultRoundDecimal)
|
||||
} else if indexType == IndexFaissBinIDMap {
|
||||
return genBruteForceDSL(schema, defaultTopK, defaultRoundDecimal)
|
||||
} else if indexType == IndexFaissIVFFlat {
|
||||
return genIVFFlatDSL(schema, defaultNProb, defaultTopK, defaultRoundDecimal)
|
||||
} else if indexType == IndexFaissBinIVFFlat { // binary vector
|
||||
return genIVFFlatDSL(schema, defaultNProb, defaultTopK, defaultRoundDecimal)
|
||||
} else if indexType == IndexHNSW {
|
||||
return genHNSWDSL(schema, defaultEf, defaultTopK, defaultRoundDecimal)
|
||||
}
|
||||
return "", fmt.Errorf("Invalid indexType")
|
||||
}
|
||||
|
|
|
@ -955,18 +955,24 @@ func genSearchRequest(nq int64, indexType string, collection *Collection) (*inte
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
simpleDSL, err2 := genDSLByIndexType(collection.Schema(), indexType)
|
||||
planStr, err2 := genDSLByIndexType(collection.Schema(), indexType)
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
var planpb planpb.PlanNode
|
||||
proto.UnmarshalText(planStr, &planpb)
|
||||
serializedPlan, err3 := proto.Marshal(&planpb)
|
||||
if err3 != nil {
|
||||
return nil, err3
|
||||
}
|
||||
return &internalpb.SearchRequest{
|
||||
Base: genCommonMsgBase(commonpb.MsgType_Search, 0),
|
||||
CollectionID: collection.ID(),
|
||||
PartitionIDs: collection.GetPartitions(),
|
||||
Dsl: simpleDSL,
|
||||
PlaceholderGroup: placeHolder,
|
||||
DslType: commonpb.DslType_Dsl,
|
||||
Nq: nq,
|
||||
Base: genCommonMsgBase(commonpb.MsgType_Search, 0),
|
||||
CollectionID: collection.ID(),
|
||||
PartitionIDs: collection.GetPartitions(),
|
||||
PlaceholderGroup: placeHolder,
|
||||
SerializedExprPlan: serializedPlan,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
Nq: nq,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -1012,12 +1018,6 @@ func genPlaceHolderGroup(nq int64) ([]byte, error) {
|
|||
func genDSLByIndexType(schema *schemapb.CollectionSchema, indexType string) (string, error) {
|
||||
if indexType == IndexFaissIDMap { // float vector
|
||||
return genBruteForceDSL(schema, defaultTopK, defaultRoundDecimal)
|
||||
} else if indexType == IndexFaissBinIDMap {
|
||||
return genBruteForceDSL(schema, defaultTopK, defaultRoundDecimal)
|
||||
} else if indexType == IndexFaissIVFFlat {
|
||||
return genIVFFlatDSL(schema, defaultNProb, defaultTopK, defaultRoundDecimal)
|
||||
} else if indexType == IndexFaissBinIVFFlat { // binary vector
|
||||
return genIVFFlatDSL(schema, defaultNProb, defaultTopK, defaultRoundDecimal)
|
||||
} else if indexType == IndexHNSW {
|
||||
return genHNSWDSL(schema, ef, defaultTopK, defaultRoundDecimal)
|
||||
}
|
||||
|
@ -1030,9 +1030,11 @@ func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecima
|
|||
topKStr := strconv.FormatInt(topK, 10)
|
||||
nProbStr := strconv.Itoa(defaultNProb)
|
||||
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
|
||||
var fieldID int64
|
||||
for _, f := range schema.Fields {
|
||||
if f.DataType == schemapb.DataType_FloatVector {
|
||||
vecFieldName = f.Name
|
||||
fieldID = f.FieldID
|
||||
for _, p := range f.IndexParams {
|
||||
if p.Key == metricTypeKey {
|
||||
metricType = p.Value
|
||||
|
@ -1044,39 +1046,16 @@ func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecima
|
|||
err := errors.New("invalid vector field name or metric type")
|
||||
return "", err
|
||||
}
|
||||
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
|
||||
"\": {\n \"metric_type\": \"" + metricType +
|
||||
"\", \n \"params\": {\n \"nprobe\": " + nProbStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
|
||||
" \n,\"round_decimal\": " + roundDecimalStr +
|
||||
"\n } \n } \n } \n }", nil
|
||||
}
|
||||
|
||||
func genIVFFlatDSL(schema *schemapb.CollectionSchema, nProb int, topK int64, roundDecimal int64) (string, error) {
|
||||
var vecFieldName string
|
||||
var metricType string
|
||||
nProbStr := strconv.Itoa(nProb)
|
||||
topKStr := strconv.FormatInt(topK, 10)
|
||||
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
|
||||
for _, f := range schema.Fields {
|
||||
if f.DataType == schemapb.DataType_FloatVector {
|
||||
vecFieldName = f.Name
|
||||
for _, p := range f.IndexParams {
|
||||
if p.Key == metricTypeKey {
|
||||
metricType = p.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if vecFieldName == "" || metricType == "" {
|
||||
err := errors.New("invalid vector field name or metric type")
|
||||
return "", err
|
||||
}
|
||||
|
||||
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
|
||||
"\": {\n \"metric_type\": \"" + metricType +
|
||||
"\", \n \"params\": {\n \"nprobe\": " + nProbStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
|
||||
" \n,\"round_decimal\": " + roundDecimalStr +
|
||||
"\n } \n } \n } \n }", nil
|
||||
return `vector_anns: <
|
||||
field_id: ` + fmt.Sprintf("%d", fieldID) + `
|
||||
query_info: <
|
||||
topk: ` + topKStr + `
|
||||
round_decimal: ` + roundDecimalStr + `
|
||||
metric_type: "` + metricType + `"
|
||||
search_params: "{\"nprobe\": ` + nProbStr + `}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>`, nil
|
||||
}
|
||||
|
||||
func genHNSWDSL(schema *schemapb.CollectionSchema, ef int, topK int64, roundDecimal int64) (string, error) {
|
||||
|
@ -1085,9 +1064,11 @@ func genHNSWDSL(schema *schemapb.CollectionSchema, ef int, topK int64, roundDeci
|
|||
efStr := strconv.Itoa(ef)
|
||||
topKStr := strconv.FormatInt(topK, 10)
|
||||
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
|
||||
var fieldID int64
|
||||
for _, f := range schema.Fields {
|
||||
if f.DataType == schemapb.DataType_FloatVector {
|
||||
vecFieldName = f.Name
|
||||
fieldID = f.FieldID
|
||||
for _, p := range f.IndexParams {
|
||||
if p.Key == metricTypeKey {
|
||||
metricType = p.Value
|
||||
|
@ -1099,11 +1080,16 @@ func genHNSWDSL(schema *schemapb.CollectionSchema, ef int, topK int64, roundDeci
|
|||
err := errors.New("invalid vector field name or metric type")
|
||||
return "", err
|
||||
}
|
||||
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
|
||||
"\": {\n \"metric_type\": \"" + metricType +
|
||||
"\", \n \"params\": {\n \"ef\": " + efStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
|
||||
" \n,\"round_decimal\": " + roundDecimalStr +
|
||||
"\n } \n } \n } \n }", nil
|
||||
return `vector_anns: <
|
||||
field_id: ` + fmt.Sprintf("%d", fieldID) + `
|
||||
query_info: <
|
||||
topk: ` + topKStr + `
|
||||
round_decimal: ` + roundDecimalStr + `
|
||||
metric_type: "` + metricType + `"
|
||||
search_params: "{\"ef\": ` + efStr + `}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>`, nil
|
||||
}
|
||||
|
||||
func checkSearchResult(nq int64, plan *SearchPlan, searchResult *SearchResult) error {
|
||||
|
|
|
@ -31,7 +31,6 @@ import (
|
|||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
. "github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
@ -41,31 +40,6 @@ type SearchPlan struct {
|
|||
cSearchPlan C.CSearchPlan
|
||||
}
|
||||
|
||||
// createSearchPlan returns a new SearchPlan and error
|
||||
func createSearchPlan(col *Collection, dsl string, metricType string) (*SearchPlan, error) {
|
||||
if col.collectionPtr == nil {
|
||||
return nil, errors.New("nil collection ptr, collectionID = " + fmt.Sprintln(col.id))
|
||||
}
|
||||
|
||||
cDsl := C.CString(dsl)
|
||||
defer C.free(unsafe.Pointer(cDsl))
|
||||
var cPlan C.CSearchPlan
|
||||
status := C.CreateSearchPlan(col.collectionPtr, cDsl, &cPlan)
|
||||
|
||||
err1 := HandleCStatus(&status, "Create Plan failed")
|
||||
if err1 != nil {
|
||||
return nil, err1
|
||||
}
|
||||
|
||||
var newPlan = &SearchPlan{cSearchPlan: cPlan}
|
||||
if len(metricType) != 0 {
|
||||
newPlan.setMetricType(metricType)
|
||||
} else {
|
||||
newPlan.setMetricType(col.GetMetricType())
|
||||
}
|
||||
return newPlan, nil
|
||||
}
|
||||
|
||||
func createSearchPlanByExpr(col *Collection, expr []byte, metricType string) (*SearchPlan, error) {
|
||||
if col.collectionPtr == nil {
|
||||
return nil, errors.New("nil collection ptr, collectionID = " + fmt.Sprintln(col.id))
|
||||
|
@ -121,18 +95,10 @@ func NewSearchRequest(collection *Collection, req *querypb.SearchRequest, placeh
|
|||
var err error
|
||||
var plan *SearchPlan
|
||||
metricType := req.GetReq().GetMetricType()
|
||||
if req.Req.GetDslType() == commonpb.DslType_BoolExprV1 {
|
||||
expr := req.Req.SerializedExprPlan
|
||||
plan, err = createSearchPlanByExpr(collection, expr, metricType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
dsl := req.Req.GetDsl()
|
||||
plan, err = createSearchPlan(collection, dsl, metricType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expr := req.Req.SerializedExprPlan
|
||||
plan, err = createSearchPlanByExpr(collection, expr, metricType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(placeholderGrp) == 0 {
|
||||
|
|
|
@ -17,17 +17,14 @@
|
|||
package segments
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
)
|
||||
|
||||
type PlanSuite struct {
|
||||
|
@ -53,19 +50,6 @@ func (suite *PlanSuite) TearDownTest() {
|
|||
DeleteCollection(suite.collection)
|
||||
}
|
||||
|
||||
func (suite *PlanSuite) TestPlanDSL() {
|
||||
dslString := "{\"bool\": { \n\"vector\": {\n \"floatVectorField\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\n \"topk\": 10 \n,\"round_decimal\": 6\n } \n } \n } \n }"
|
||||
|
||||
plan, err := createSearchPlan(suite.collection, dslString, "")
|
||||
defer plan.delete()
|
||||
suite.NoError(err)
|
||||
suite.NotEqual(plan, nil)
|
||||
topk := plan.getTopK()
|
||||
suite.Equal(int(topk), 10)
|
||||
metricType := plan.getMetricType()
|
||||
suite.Equal(metricType, "L2")
|
||||
}
|
||||
|
||||
func (suite *PlanSuite) TestPlanCreateByExpr() {
|
||||
planNode := &planpb.PlanNode{
|
||||
OutputFieldIds: []int64{rowIDFieldID},
|
||||
|
@ -82,70 +66,8 @@ func (suite *PlanSuite) TestPlanFail() {
|
|||
id: -1,
|
||||
}
|
||||
|
||||
_, err := createSearchPlan(collection, "", "")
|
||||
_, err := createSearchPlanByExpr(collection, nil, "")
|
||||
suite.Error(err)
|
||||
|
||||
_, err = createSearchPlanByExpr(collection, nil, "")
|
||||
suite.Error(err)
|
||||
}
|
||||
|
||||
func (suite *PlanSuite) TestPlanPlaceholderGroup() {
|
||||
dslString := "{\"bool\": { \n\"vector\": {\n \"floatVectorField\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\n \"topk\": 10 \n,\"round_decimal\": 6\n } \n } \n } \n }"
|
||||
plan, err := createSearchPlan(suite.collection, dslString, "")
|
||||
suite.NoError(err)
|
||||
suite.NotNil(plan)
|
||||
|
||||
var searchRawData1 []byte
|
||||
var searchRawData2 []byte
|
||||
var vec = generateFloatVectors(1, defaultDim)
|
||||
for i, ele := range vec {
|
||||
buf := make([]byte, 4)
|
||||
common.Endian.PutUint32(buf, math.Float32bits(ele+float32(i*2)))
|
||||
searchRawData1 = append(searchRawData1, buf...)
|
||||
}
|
||||
for i, ele := range vec {
|
||||
buf := make([]byte, 4)
|
||||
common.Endian.PutUint32(buf, math.Float32bits(ele+float32(i*4)))
|
||||
searchRawData2 = append(searchRawData2, buf...)
|
||||
}
|
||||
placeholderValue := commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Type: commonpb.PlaceholderType_FloatVector,
|
||||
Values: [][]byte{searchRawData1, searchRawData2},
|
||||
}
|
||||
|
||||
placeholderGroup := commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{&placeholderValue},
|
||||
}
|
||||
|
||||
placeGroupByte, err := proto.Marshal(&placeholderGroup)
|
||||
suite.Nil(err)
|
||||
holder, err := parseSearchRequest(plan, placeGroupByte)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(holder)
|
||||
numQueries := holder.getNumOfQuery()
|
||||
suite.Equal(int(numQueries), 2)
|
||||
|
||||
holder.Delete()
|
||||
}
|
||||
|
||||
func (suite *PlanSuite) TestPlanNewSearchRequest() {
|
||||
nq := int64(10)
|
||||
|
||||
iReq, _ := genSearchRequest(nq, IndexHNSW, suite.collection)
|
||||
req := &querypb.SearchRequest{
|
||||
Req: iReq,
|
||||
DmlChannels: []string{"dml"},
|
||||
SegmentIDs: []int64{suite.segmentID},
|
||||
FromShardLeader: true,
|
||||
Scope: querypb.DataScope_Historical,
|
||||
}
|
||||
searchReq, err := NewSearchRequest(suite.collection, req, req.Req.GetPlaceholderGroup())
|
||||
suite.NoError(err)
|
||||
|
||||
suite.EqualValues(nq, searchReq.getNumOfQuery())
|
||||
|
||||
searchReq.Delete()
|
||||
}
|
||||
|
||||
func TestPlan(t *testing.T) {
|
||||
|
|
|
@ -27,6 +27,7 @@ import (
|
|||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
storage "github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/initcore"
|
||||
|
@ -145,9 +146,21 @@ func (suite *ReduceSuite) TestReduceAllFunc() {
|
|||
log.Print("marshal placeholderGroup failed")
|
||||
}
|
||||
|
||||
dslString := "{\"bool\": { \n\"vector\": {\n \"floatVectorField\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\n \"topk\": 10 \n,\"round_decimal\": 6\n } \n } \n } \n }"
|
||||
|
||||
plan, err := createSearchPlan(suite.collection, dslString, "")
|
||||
planStr := `vector_anns: <
|
||||
field_id: 107
|
||||
query_info: <
|
||||
topk: 10
|
||||
round_decimal: 6
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>`
|
||||
var planpb planpb.PlanNode
|
||||
proto.UnmarshalText(planStr, &planpb)
|
||||
serializedPlan, err := proto.Marshal(&planpb)
|
||||
suite.NoError(err)
|
||||
plan, err := createSearchPlanByExpr(suite.collection, serializedPlan, "")
|
||||
suite.NoError(err)
|
||||
searchReq, err := parseSearchRequest(plan, placeGroupByte)
|
||||
searchReq.timestamp = 0
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
|
@ -33,6 +34,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
|
@ -910,7 +912,13 @@ func (suite *ServiceSuite) genCSearchRequest(nq int64, indexType string, schema
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
simpleDSL, err2 := genDSLByIndexType(schema, indexType)
|
||||
planStr, err := genDSLByIndexType(schema, indexType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var planpb planpb.PlanNode
|
||||
proto.UnmarshalText(planStr, &planpb)
|
||||
serializedPlan, err2 := proto.Marshal(&planpb)
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
|
@ -920,12 +928,12 @@ func (suite *ServiceSuite) genCSearchRequest(nq int64, indexType string, schema
|
|||
MsgID: rand.Int63(),
|
||||
TargetID: suite.node.session.ServerID,
|
||||
},
|
||||
CollectionID: suite.collectionID,
|
||||
PartitionIDs: suite.partitionIDs,
|
||||
Dsl: simpleDSL,
|
||||
PlaceholderGroup: placeHolder,
|
||||
DslType: commonpb.DslType_Dsl,
|
||||
Nq: nq,
|
||||
CollectionID: suite.collectionID,
|
||||
PartitionIDs: suite.partitionIDs,
|
||||
SerializedExprPlan: serializedPlan,
|
||||
PlaceholderGroup: placeHolder,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
Nq: nq,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue