Remove outdated searchplan (#25282)

Signed-off-by: Enwei Jiao <enwei.jiao@zilliz.com>
pull/25337/head
Enwei Jiao 2023-07-04 18:30:25 +08:00 committed by GitHub
parent 4a87b9f60a
commit 816158e4af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 2414 additions and 3170 deletions

View File

@ -21,7 +21,6 @@ set(MILVUS_QUERY_SRCS
visitors/VerifyExprVisitor.cpp
visitors/ExtractInfoPlanNodeVisitor.cpp
visitors/ExtractInfoExprVisitor.cpp
Parser.cpp
Plan.cpp
SearchOnGrowing.cpp
SearchOnSealed.cpp

View File

@ -1,514 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <deque>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <boost/algorithm/string.hpp>
#include "ExprImpl.h"
#include "Parser.h"
#include "Plan.h"
#include "generated/ExtractInfoPlanNodeVisitor.h"
#include "generated/VerifyPlanNodeVisitor.h"
#include "pb/plan.pb.h"
#include "query/Expr.h"
namespace milvus::query {
// Folds a list of expressions into a single tree using `merger`.
//
// Returns nullptr for an empty list and the sole element for a list of one.
// Otherwise the items are placed in a queue; the two front entries are merged
// and the result is appended at the back, until a single root remains.
// For four items A, B, C, D this yields merger(merger(A, B), merger(C, D)).
// Every item must be non-null.
template <typename Merger>
static ExprPtr
ConstructTree(Merger merger, std::vector<ExprPtr> item_list) {
    if (item_list.empty()) {
        return nullptr;
    }
    if (item_list.size() == 1) {
        return std::move(item_list.front());
    }

    std::deque<ExprPtr> binary_queue;
    for (auto& item : item_list) {
        Assert(item != nullptr);
        binary_queue.push_back(std::move(item));
    }

    // Each round consumes two nodes and produces one, so the loop terminates.
    while (binary_queue.size() > 1) {
        auto lhs = std::move(binary_queue.front());
        binary_queue.pop_front();
        auto rhs = std::move(binary_queue.front());
        binary_queue.pop_front();
        binary_queue.push_back(merger(std::move(lhs), std::move(rhs)));
    }

    Assert(binary_queue.size() == 1);
    return std::move(binary_queue.front());
}
// Parses the value of a "compare" entry, expected in the form
//   { "<op>": ["<left field name>", "<right field name>"] }
// The operator name is matched case-insensitively against `mapping_`.
// Returns a CompareExpr carrying the operator plus the field id and data type
// of both columns as looked up in the schema.
ExprPtr
Parser::ParseCompareNode(const Json& out_body) {
    Assert(out_body.is_object());
    Assert(out_body.size() == 1);
    auto out_iter = out_body.begin();
    auto op_name = boost::algorithm::to_lower_copy(std::string(out_iter.key()));
    AssertInfo(mapping_.count(op_name), "op(" + op_name + ") not found");

    // Bind by reference; copying here would duplicate the operand array.
    const auto& body = out_iter.value();
    Assert(body.is_array());
    Assert(body.size() == 2);

    auto expr = std::make_unique<CompareExpr>();
    expr->op_type_ = mapping_.at(op_name);

    const auto& item0 = body[0];
    Assert(item0.is_string());
    auto left_field_name = FieldName(item0.get<std::string>());
    expr->left_data_type_ = schema[left_field_name].get_data_type();
    expr->left_field_id_ = schema.get_field_id(left_field_name);

    const auto& item1 = body[1];
    Assert(item1.is_string());
    auto right_field_name = FieldName(item1.get<std::string>());
    expr->right_data_type_ = schema[right_field_name].get_data_type();
    expr->right_field_id_ = schema.get_field_id(right_field_name);

    return expr;
}
// Parses the value of a "range" entry: a single-key object mapping a field
// name to its range description. The field's schema data type selects which
// ParseRangeNodeImpl<T> instantiation handles the description; vector fields
// are rejected and any other unlisted type panics.
ExprPtr
Parser::ParseRangeNode(const Json& out_body) {
    Assert(out_body.is_object());
    Assert(out_body.size() == 1);

    auto entry = out_body.begin();
    auto field_name = FieldName(entry.key());
    auto& range_spec = entry.value();

    auto data_type = schema[field_name].get_data_type();
    Assert(!datatype_is_vector(data_type));

    switch (data_type) {
        case DataType::BOOL: {
            return ParseRangeNodeImpl<bool>(field_name, range_spec);
        }
        case DataType::INT8: {
            return ParseRangeNodeImpl<int8_t>(field_name, range_spec);
        }
        case DataType::INT16: {
            return ParseRangeNodeImpl<int16_t>(field_name, range_spec);
        }
        case DataType::INT32: {
            return ParseRangeNodeImpl<int32_t>(field_name, range_spec);
        }
        case DataType::INT64: {
            return ParseRangeNodeImpl<int64_t>(field_name, range_spec);
        }
        case DataType::FLOAT: {
            return ParseRangeNodeImpl<float>(field_name, range_spec);
        }
        case DataType::DOUBLE: {
            return ParseRangeNodeImpl<double>(field_name, range_spec);
        }
        default: {
            PanicInfo("unsupported");
        }
    }
}
// Builds a Plan from a parsed DSL document.
//
// The document must contain a top-level "bool" entry. Parsing it yields the
// scalar predicate (possibly nullptr) and, as a side effect, stores the vector
// node in `vector_node_opt_`; exactly one vector node must have been seen.
// The resulting plan node is verified, its extra info is extracted, and the
// placeholder-tag map collected during parsing is moved into the plan.
std::unique_ptr<Plan>
Parser::CreatePlanImpl(const Json& dsl) {
    // Bind by reference; copying would duplicate the entire DSL tree.
    const auto& bool_dsl = dsl.at("bool");
    auto predicate = ParseAnyNode(bool_dsl);

    Assert(vector_node_opt_.has_value());
    auto vec_node = std::move(vector_node_opt_).value();
    if (predicate != nullptr) {
        vec_node->predicate_ = std::move(predicate);
    }

    VerifyPlanNodeVisitor verifier;
    vec_node->accept(verifier);

    ExtractedPlanInfo plan_info(schema.size());
    ExtractInfoPlanNodeVisitor extractor(plan_info);
    vec_node->accept(extractor);

    auto plan = std::make_unique<Plan>(schema);
    plan->tag2field_ = std::move(tag2field_);
    plan->plan_node_ = std::move(vec_node);
    plan->extra_info_opt_ = std::move(plan_info);
    return plan;
}
// Parses the value of a "term" entry: a single-key object mapping a field name
// to its term description. The field's schema data type selects which
// ParseTermNodeImpl<T> instantiation handles the description; vector fields
// are rejected and any other unlisted type panics.
ExprPtr
Parser::ParseTermNode(const Json& out_body) {
    // Same shape check as the sibling range/compare parsers.
    Assert(out_body.is_object());
    Assert(out_body.size() == 1);
    auto out_iter = out_body.begin();
    auto field_name = FieldName(out_iter.key());
    // Bind by reference; the description is only read, never modified.
    const auto& body = out_iter.value();
    auto data_type = schema[field_name].get_data_type();
    Assert(!datatype_is_vector(data_type));
    switch (data_type) {
        case DataType::BOOL: {
            return ParseTermNodeImpl<bool>(field_name, body);
        }
        case DataType::INT8: {
            return ParseTermNodeImpl<int8_t>(field_name, body);
        }
        case DataType::INT16: {
            return ParseTermNodeImpl<int16_t>(field_name, body);
        }
        case DataType::INT32: {
            return ParseTermNodeImpl<int32_t>(field_name, body);
        }
        case DataType::INT64: {
            return ParseTermNodeImpl<int64_t>(field_name, body);
        }
        case DataType::FLOAT: {
            return ParseTermNodeImpl<float>(field_name, body);
        }
        case DataType::DOUBLE: {
            return ParseTermNodeImpl<double>(field_name, body);
        }
        default: {
            PanicInfo("unsupported data_type");
        }
    }
}
std::unique_ptr<VectorPlanNode>
Parser::ParseVecNode(const Json& out_body) {
Assert(out_body.is_object());
Assert(out_body.size() == 1);
auto iter = out_body.begin();
auto field_name = FieldName(iter.key());
auto& vec_info = iter.value();
Assert(vec_info.is_object());
auto topk = vec_info["topk"];
AssertInfo(topk > 0, "topk must greater than 0");
AssertInfo(topk < 16384, "topk is too large");
auto field_id = schema.get_field_id(field_name);
auto vec_node = [&]() -> std::unique_ptr<VectorPlanNode> {
auto& field_meta = schema.operator[](field_name);
auto data_type = field_meta.get_data_type();
if (data_type == DataType::VECTOR_FLOAT) {
return std::make_unique<FloatVectorANNS>();
} else {
return std::make_unique<BinaryVectorANNS>();
}
}();
vec_node->search_info_.topk_ = topk;
vec_node->search_info_.metric_type_ = vec_info.at("metric_type");
vec_node->search_info_.search_params_ = vec_info.at("params");
vec_node->search_info_.field_id_ = field_id;
vec_node->search_info_.round_decimal_ = vec_info.at("round_decimal");
vec_node->placeholder_tag_ = vec_info.at("query");
auto tag = vec_node->placeholder_tag_;
AssertInfo(!tag2field_.count(tag), "duplicated placeholder tag");
tag2field_.emplace(tag, field_id);
return vec_node;
}
// Typed implementation behind ParseTermNode.
//
// `body` must be an object with a "values" array and an "is_in_field" entry.
// Each value is type-checked against T (boolean, integer or number), converted
// to T, and the converted terms are sorted before being handed to
// TermExprImpl<T>. The value case stays VAL_NOT_SET when "values" is empty.
template <typename T>
ExprPtr
Parser::ParseTermNodeImpl(const FieldName& field_name, const Json& body) {
    Assert(body.is_object());
    // `body` is const: operator[] on a missing key is undefined there, so use
    // at(), which reports the missing key instead.
    const auto& values = body.at("values");
    const auto& is_in_field = body.at("is_in_field");
    Assert(values.is_array());

    std::vector<T> terms(values.size());
    auto val_case = proto::plan::GenericValue::ValCase::VAL_NOT_SET;
    for (size_t i = 0; i < values.size(); ++i) {
        const auto& value = values[i];
        if constexpr (std::is_same_v<T, bool>) {
            Assert(value.is_boolean());
            val_case = proto::plan::GenericValue::ValCase::kBoolVal;
        } else if constexpr (std::is_integral_v<T>) {
            Assert(value.is_number_integer());
            val_case = proto::plan::GenericValue::ValCase::kInt64Val;
        } else if constexpr (std::is_floating_point_v<T>) {
            Assert(value.is_number());
            val_case = proto::plan::GenericValue::ValCase::kFloatVal;
        } else {
            static_assert(always_false<T>, "unsupported type");
        }
        terms[i] = value;
    }
    std::sort(terms.begin(), terms.end());

    return std::make_unique<TermExprImpl<T>>(
        ColumnInfo(schema.get_field_id(field_name),
                   schema[field_name].get_data_type()),
        terms,
        val_case,
        is_in_field);
}
// Typed implementation behind ParseRangeNode.
//
// `body` must be an object whose keys are operator names (matched
// case-insensitively against `mapping_`):
//   - one key with an object value  -> arithmetic range expression
//     (BinaryArithOpEvalRangeExprImpl); throws std::runtime_error for bool.
//   - one key with a scalar value   -> unary range (UnaryRangeExprImpl).
//   - two keys                      -> binary range (BinaryRangeExprImpl),
//     which needs one lower bound (GT/GE) and one upper bound (LT/LE).
//   - any other key count           -> PanicInfo.
// For integral T the unary and arithmetic expressions are instantiated with
// int64_t rather than T (see the linked issue); the binary range uses T.
template <typename T>
ExprPtr
Parser::ParseRangeNodeImpl(const FieldName& field_name, const Json& body) {
Assert(body.is_object());
if (body.size() == 1) {
auto item = body.begin();
auto op_name = boost::algorithm::to_lower_copy(std::string(item.key()));
AssertInfo(mapping_.count(op_name), "op(" + op_name + ") not found");
// This is an expression with an arithmetic operation
if (item.value().is_object()) {
/* // This is the expected DSL expression
{
range: {
field_name: {
op: {
arith_op: {
right_operand: operand,
value: value
},
}
}
}
}
EXAMPLE:
{
range: {
field_name: {
"EQ": {
"ADD": {
right_operand: 10,
value: 25
},
}
}
}
}
*/
// Only the first key of the inner object is read as the arithmetic operator.
auto arith = item.value();
auto arith_body = arith.begin();
auto arith_op_name =
boost::algorithm::to_lower_copy(std::string(arith_body.key()));
AssertInfo(arith_op_mapping_.count(arith_op_name),
"arith op(" + arith_op_name + ") not found");
auto& arith_op_body = arith_body.value();
Assert(arith_op_body.is_object());
auto right_operand = arith_op_body["right_operand"];
auto value = arith_op_body["value"];
if constexpr (std::is_same_v<T, bool>) {
throw std::runtime_error("bool type is not supported");
} else if constexpr (std::is_integral_v<T>) {
Assert(right_operand.is_number_integer());
Assert(value.is_number_integer());
// see also: https://github.com/milvus-io/milvus/issues/23646.
return std::make_unique<
BinaryArithOpEvalRangeExprImpl<int64_t>>(
ColumnInfo(schema.get_field_id(field_name),
schema[field_name].get_data_type()),
proto::plan::GenericValue::ValCase::kInt64Val,
arith_op_mapping_.at(arith_op_name),
right_operand,
mapping_.at(op_name),
value);
} else if constexpr (std::is_floating_point_v<T>) {
Assert(right_operand.is_number());
Assert(value.is_number());
return std::make_unique<BinaryArithOpEvalRangeExprImpl<T>>(
ColumnInfo(schema.get_field_id(field_name),
schema[field_name].get_data_type()),
proto::plan::GenericValue::ValCase::kFloatVal,
arith_op_mapping_.at(arith_op_name),
right_operand,
mapping_.at(op_name),
value);
} else {
static_assert(always_false<T>, "unsupported type");
}
}
// Plain unary range: the single key is the comparison operator and its
// value is the scalar operand.
if constexpr (std::is_same_v<T, bool>) {
Assert(item.value().is_boolean());
return std::make_unique<UnaryRangeExprImpl<T>>(
ColumnInfo(schema.get_field_id(field_name),
schema[field_name].get_data_type()),
mapping_.at(op_name),
item.value(),
proto::plan::GenericValue::ValCase::kBoolVal);
} else if constexpr (std::is_integral_v<T>) {
Assert(item.value().is_number_integer());
// see also: https://github.com/milvus-io/milvus/issues/23646.
return std::make_unique<UnaryRangeExprImpl<int64_t>>(
ColumnInfo(schema.get_field_id(field_name),
schema[field_name].get_data_type()),
mapping_.at(op_name),
item.value(),
proto::plan::GenericValue::ValCase::kInt64Val);
} else if constexpr (std::is_floating_point_v<T>) {
Assert(item.value().is_number());
return std::make_unique<UnaryRangeExprImpl<T>>(
ColumnInfo(schema.get_field_id(field_name),
schema[field_name].get_data_type()),
mapping_.at(op_name),
item.value(),
proto::plan::GenericValue::ValCase::kFloatVal);
} else {
static_assert(always_false<T>, "unsupported type");
}
} else if (body.size() == 2) {
// Binary range: collect one lower and one upper bound from the two keys.
bool has_lower_value = false;
bool has_upper_value = false;
bool lower_inclusive = false;
bool upper_inclusive = false;
// Left uninitialized on purpose; both are only read after the
// has_lower_value && has_upper_value check below.
T lower_value;
T upper_value;
for (auto& item : body.items()) {
auto op_name =
boost::algorithm::to_lower_copy(std::string(item.key()));
AssertInfo(mapping_.count(op_name),
"op(" + op_name + ") not found");
if constexpr (std::is_same_v<T, bool>) {
Assert(item.value().is_boolean());
} else if constexpr (std::is_integral_v<T>) {
Assert(item.value().is_number_integer());
} else if constexpr (std::is_floating_point_v<T>) {
Assert(item.value().is_number());
} else {
static_assert(always_false<T>, "unsupported type");
}
auto op = mapping_.at(op_name);
switch (op) {
// Deliberate fallthrough: the inclusive operator sets its flag and then
// shares the bound-recording code of the exclusive operator below it.
case OpType::GreaterEqual:
lower_inclusive = true;
case OpType::GreaterThan:
lower_value = item.value();
has_lower_value = true;
break;
case OpType::LessEqual:
upper_inclusive = true;
case OpType::LessThan:
upper_value = item.value();
has_upper_value = true;
break;
default:
PanicInfo("unsupported operator in binary-range node");
}
}
AssertInfo(has_lower_value && has_upper_value,
"illegal binary-range node");
// NOTE(review): unlike the unary branches, the value case passed here is
// VAL_NOT_SET for every T; confirm this is intended.
return std::make_unique<BinaryRangeExprImpl<T>>(
ColumnInfo(schema.get_field_id(field_name),
schema[field_name].get_data_type()),
proto::plan::GenericValue::ValCase::VAL_NOT_SET,
lower_inclusive,
upper_inclusive,
lower_value,
upper_value);
} else {
PanicInfo("illegal range node, too more or too few ops");
}
}
std::vector<ExprPtr>
Parser::ParseItemList(const Json& body) {
std::vector<ExprPtr> results;
if (body.is_object()) {
// only one item;
auto new_expr = ParseAnyNode(body);
results.emplace_back(std::move(new_expr));
} else {
// item array
Assert(body.is_array());
for (auto& item : body) {
auto new_expr = ParseAnyNode(item);
results.emplace_back(std::move(new_expr));
}
}
auto old_size = results.size();
auto new_end =
std::remove_if(results.begin(), results.end(), [](const ExprPtr& x) {
return x == nullptr;
});
results.resize(new_end - results.begin());
return results;
}
// Dispatches a single-key node object to the parser matching its key
// ("must", "should", "must_not", "range", "term", "compare" or "vector").
//
// Returns the parsed expression. A "vector" node is stored in
// `vector_node_opt_` (at most one is allowed) and yields nullptr because it is
// not part of the scalar predicate. Any other key panics.
ExprPtr
Parser::ParseAnyNode(const Json& out_body) {
    Assert(out_body.is_object());
    Assert(out_body.size() == 1);

    auto out_iter = out_body.begin();
    // Bind by reference: this function recurses through the whole DSL, so
    // copying the key and sub-tree here would duplicate it at every level.
    const auto& key = out_iter.key();
    const auto& body = out_iter.value();

    if (key == "must") {
        return ParseMustNode(body);
    } else if (key == "should") {
        return ParseShouldNode(body);
    } else if (key == "must_not") {
        return ParseMustNotNode(body);
    } else if (key == "range") {
        return ParseRangeNode(body);
    } else if (key == "term") {
        return ParseTermNode(body);
    } else if (key == "compare") {
        return ParseCompareNode(body);
    } else if (key == "vector") {
        auto vec_node = ParseVecNode(body);
        Assert(!vector_node_opt_.has_value());
        vector_node_opt_ = std::move(vec_node);
        return nullptr;
    } else {
        PanicInfo("unsupported key: " + key);
    }
}
// Parses a "must" entry: all child expressions are combined with LogicalAnd.
// Yields nullptr when the entry contributes no scalar expression.
ExprPtr
Parser::ParseMustNode(const Json& body) {
    auto item_list = ParseItemList(body);

    // Joins two sub-expressions under a logical AND node.
    auto and_merger = [](ExprPtr lhs, ExprPtr rhs) {
        return std::make_unique<LogicalBinaryExpr>(
            LogicalBinaryExpr::OpType::LogicalAnd, lhs, rhs);
    };

    return ConstructTree(and_merger, std::move(item_list));
}
// Parses a "should" entry: all child expressions are combined with LogicalOr.
// At least one scalar child expression is required.
ExprPtr
Parser::ParseShouldNode(const Json& body) {
    auto item_list = ParseItemList(body);
    Assert(item_list.size() >= 1);

    // Joins two sub-expressions under a logical OR node.
    auto or_merger = [](ExprPtr lhs, ExprPtr rhs) {
        return std::make_unique<LogicalBinaryExpr>(
            LogicalBinaryExpr::OpType::LogicalOr, lhs, rhs);
    };

    return ConstructTree(or_merger, std::move(item_list));
}
// Parses a "must_not" entry: the child expressions are combined with
// LogicalAnd and the combined expression is wrapped in a LogicalNot node.
// At least one scalar child expression is required.
ExprPtr
Parser::ParseMustNotNode(const Json& body) {
    auto item_list = ParseItemList(body);
    Assert(item_list.size() >= 1);

    // Joins two sub-expressions under a logical AND node.
    auto and_merger = [](ExprPtr lhs, ExprPtr rhs) {
        return std::make_unique<LogicalBinaryExpr>(
            LogicalBinaryExpr::OpType::LogicalAnd, lhs, rhs);
    };
    auto conjunction = ConstructTree(and_merger, std::move(item_list));

    return std::make_unique<LogicalUnaryExpr>(
        LogicalUnaryExpr::OpType::LogicalNot, conjunction);
}
} // namespace milvus::query

View File

@ -1,92 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "Plan.h"
namespace milvus::query {
// Parser for the JSON search DSL.
//
// The constructor and every parse method are private; the friend function
// CreatePlan is the only way to run a parse from outside the class. A Parser
// accumulates state while parsing (the placeholder-tag map and the single
// vector node), which CreatePlanImpl moves into the resulting Plan.
class Parser {
public:
friend std::unique_ptr<Plan>
CreatePlan(const Schema& schema, const std::string_view dsl_str);
private:
// builds the Plan from an already parsed DSL document
std::unique_ptr<Plan>
CreatePlanImpl(const Json& dsl);
// keeps a reference to `schema`; the schema must outlive the parser
explicit Parser(const Schema& schema) : schema(schema) {
}
// vector node parser, should be called exactly once per pass.
std::unique_ptr<VectorPlanNode>
ParseVecNode(const Json& out_body);
// Dispatcher of all parse function
// NOTE: when nullptr, it is a pure vector node
ExprPtr
ParseAnyNode(const Json& body);
// parse the value of "must" entry
ExprPtr
ParseMustNode(const Json& body);
// parse the value of "should" entry
ExprPtr
ParseShouldNode(const Json& body);
// parse the value of "must_not" entry
ExprPtr
ParseMustNotNode(const Json& body);
// parse the value of "should"/"must"/"must_not" entry
std::vector<ExprPtr>
ParseItemList(const Json& body);
// parse the value of "range" entry
ExprPtr
ParseRangeNode(const Json& out_body);
// parse the value of "term" entry
ExprPtr
ParseTermNode(const Json& out_body);
// parse the value of "compare" entry
ExprPtr
ParseCompareNode(const Json& out_body);
private:
// template implementation of leaf parser
// used by corresponding parser
template <typename T>
ExprPtr
ParseRangeNodeImpl(const FieldName& field_name, const Json& body);
template <typename T>
ExprPtr
ParseTermNodeImpl(const FieldName& field_name, const Json& body);
private:
// schema used to resolve field names to ids and data types
const Schema& schema;
std::map<std::string, FieldId> tag2field_; // PlaceholderName -> field id
// set when the "vector" node is parsed; empty until then
std::optional<std::unique_ptr<VectorPlanNode>> vector_node_opt_;
};
} // namespace milvus::query

View File

@ -14,7 +14,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "Parser.h"
#include "Plan.h"
#include "PlanProto.h"
#include "generated/ShowPlanNodeVisitor.h"
@ -63,14 +62,6 @@ ParsePlaceholderGroup(const Plan* plan,
return result;
}
std::unique_ptr<Plan>
CreatePlan(const Schema& schema, const std::string_view dsl_str) {
Json dsl;
dsl = json::parse(dsl_str);
auto plan = Parser(schema).CreatePlanImpl(dsl);
return plan;
}
std::unique_ptr<Plan>
CreateSearchPlanByExpr(const Schema& schema,
const void* serialized_expr_plan,

View File

@ -26,9 +26,6 @@ struct Plan;
struct PlaceholderGroup;
struct RetrievePlan;
// Builds a search plan for `schema` from a JSON DSL string.
std::unique_ptr<Plan>
CreatePlan(const Schema& schema, const std::string_view dsl);
// Note: serialized_expr_plan is of binary format
std::unique_ptr<Plan>
CreateSearchPlanByExpr(const Schema& schema,

View File

@ -15,34 +15,6 @@
#include "segcore/Collection.h"
#include "segcore/plan_c.h"
// C entry point that builds a search plan from a DSL string.
//
// On success the status code is Success with an empty message and *res_plan
// receives the released plan pointer. On failure *res_plan is set to nullptr
// and the status carries the error code plus a strdup'ed copy of the message;
// a SegcoreError supplies its own code, any other std::exception maps to
// UnexpectedError.
CStatus
CreateSearchPlan(CCollection c_col, const char* dsl, CSearchPlan* res_plan) {
    auto col = (milvus::segcore::Collection*)c_col;

    // Shared failure path for both exception handlers.
    auto make_failure = [res_plan](auto code, const char* message) {
        auto status = CStatus();
        status.error_code = code;
        status.error_msg = strdup(message);
        *res_plan = nullptr;
        return status;
    };

    try {
        auto res = milvus::query::CreatePlan(*col->get_schema(), dsl);
        auto status = CStatus();
        status.error_code = Success;
        status.error_msg = "";
        *res_plan = (CSearchPlan)res.release();
        return status;
    } catch (milvus::SegcoreError& e) {
        return make_failure(e.get_error_code(), e.what());
    } catch (std::exception& e) {
        return make_failure(UnexpectedError, e.what());
    }
}
// Note: serialized_expr_plan is of binary format
CStatus
CreateSearchPlanByExpr(CCollection c_col,

View File

@ -23,9 +23,6 @@ typedef void* CSearchPlan;
typedef void* CPlaceholderGroup;
typedef void* CRetrievePlan;
// Builds a search plan for collection `col` from the DSL string `dsl` and
// stores the resulting handle in `*res_plan`.
CStatus
CreateSearchPlan(CCollection col, const char* dsl, CSearchPlan* res_plan);
// Note: serialized_expr_plan is of binary format
CStatus
CreateSearchPlanByExpr(CCollection col,

View File

@ -32,7 +32,6 @@ set(MILVUS_TEST_FILES
test_index_wrapper.cpp
test_init.cpp
test_parquet_c.cpp
test_plan_proto.cpp
test_query.cpp
test_reduce.cpp
test_reduce_c.cpp

View File

@ -32,26 +32,19 @@ const auto schema = []() {
}();
const auto plan = [] {
std::string dsl = R"({
"bool": {
"must": [
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": -1
}
}
}
]
}
})";
auto plan = CreatePlan(*schema, dsl);
const char* raw_plan = R"(vector_anns: <
field_id: 100
query_info: <
topk: 5
round_decimal: -1
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
return plan;
}();
auto ph_group = [] {

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -11,6 +11,7 @@
#include <gtest/gtest.h>
#include "pb/plan.pb.h"
#include "segcore/SegmentGrowing.h"
#include "segcore/SegmentGrowingImpl.h"
#include "pb/schema.pb.h"
@ -42,25 +43,36 @@ TEST(GrowingIndex, Correctness) {
std::make_shared<CollectionIndexMeta>(226985, std::move(filedMap));
auto segment = CreateGrowingSegment(schema, metaPtr);
std::string dsl = R"({
"bool": {
"must": [
{
"vector": {
"embeddings": {
"metric_type": "l2",
"params": {
"nprobe": 16
},
"query": "$0",
"topk": 5,
"round_decimal":3
}
}
}
]
}
})";
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "vector": {
// "embeddings": {
// "metric_type": "l2",
// "params": {
// "nprobe": 16
// },
// "query": "$0",
// "topk": 5,
// "round_decimal":3
// }
// }
// }
// ]
// }
// })";
milvus::proto::plan::PlanNode plan_node;
auto vector_anns = plan_node.mutable_vector_anns();
vector_anns->set_is_binary(false);
vector_anns->set_placeholder_tag("$0");
vector_anns->set_field_id(102);
auto query_info = vector_anns->mutable_query_info();
query_info->set_topk(5);
query_info->set_round_decimal(3);
query_info->set_metric_type("l2");
query_info->set_search_params(R"({"nprobe": 16})");
auto plan_str = plan_node.SerializeAsString();
int64_t per_batch = 10000;
int64_t n_batch = 20;
@ -75,7 +87,8 @@ TEST(GrowingIndex, Correctness) {
dataset.timestamps_.data(),
dataset.raw_);
auto plan = milvus::query::CreatePlan(*schema, dsl);
auto plan = milvus::query::CreateSearchPlanByExpr(
*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 128, 1024);
auto ph_group =

View File

@ -1,692 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <boost/format.hpp>
#include <google/protobuf/text_format.h>
#include <gtest/gtest.h>
#include <queue>
#include <random>
#include <vector>
#include "pb/plan.pb.h"
#include "query/PlanProto.h"
#include "query/generated/ShowPlanNodeVisitor.h"
using namespace milvus;
using namespace milvus::query;
namespace planpb = proto::plan;
using std::string;
namespace spb = proto::schema;
// Creates the schema shared by the plan-proto tests: a float vector field, a
// binary vector field (both of dimension 16) and six scalar fields. Fields are
// added in a fixed order.
static SchemaPtr
getStandardSchema() {
    auto schema = std::make_shared<Schema>();

    schema->AddDebugField(
        "FloatVectorField", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
    schema->AddDebugField("BinaryVectorField",
                          DataType::VECTOR_BINARY,
                          16,
                          knowhere::metric::JACCARD);

    // Scalar fields, in the order they are registered.
    static const struct {
        const char* name;
        DataType type;
    } kScalarFields[] = {
        {"Int64Field", DataType::INT64},
        {"Int32Field", DataType::INT32},
        {"Int16Field", DataType::INT16},
        {"Int8Field", DataType::INT8},
        {"DoubleField", DataType::DOUBLE},
        {"FloatField", DataType::FLOAT},
    };
    for (const auto& field : kScalarFields) {
        schema->AddDebugField(field.name, field.type);
    }

    return schema;
}
// Parameterized fixture whose single parameter is the name of the scalar field
// under test. Every test instance gets its own standard schema.
class PlanProtoTest : public ::testing::TestWithParam<std::tuple<std::string>> {
 public:
    PlanProtoTest() : schema(getStandardSchema()) {
    }

 protected:
    SchemaPtr schema;
};
// Runs every TEST_P of PlanProtoTest once per scalar field name listed below.
// NOTE(review): newer GoogleTest releases name this macro
// INSTANTIATE_TEST_SUITE_P; confirm against the GoogleTest version in use
// before renaming.
INSTANTIATE_TEST_CASE_P(InstName,
PlanProtoTest,
::testing::Values( //
std::make_tuple("DoubleField"), //
std::make_tuple("FloatField"), //
std::make_tuple("Int64Field"), //
std::make_tuple("Int32Field"), //
std::make_tuple("Int16Field"), //
std::make_tuple("Int8Field") //
));
// Builds the same search (unary range predicate `field > 3` plus a float
// vector ANN query) once from a text-format proto and once from the JSON DSL,
// and checks that both plans are identical.
TEST_P(PlanProtoTest, Range) {
    // xxx.query(predicates = "int64field > 3", topk = 10, ...)
    FieldName vec_field_name = FieldName("FloatVectorField");
    FieldId vec_float_field_id = schema->get_field_id(vec_field_name);

    auto field_name = std::get<0>(GetParam());
    auto field_id = schema->get_field_id(FieldName(field_name));
    auto data_type = schema->operator[](field_id).get_data_type();
    auto data_type_str = spb::DataType_Name(int(data_type));

    // Name of the GenericValue member that matches the field's data type.
    string value_tag = "bool_val";
    if (datatype_is_floating(data_type)) {
        value_tag = "float_val";
    } else if (datatype_is_integer(data_type)) {
        value_tag = "int64_val";
    }

    auto fmt1 = boost::format(R"(
vector_anns: <
field_id: %1%
predicates: <
unary_range_expr: <
column_info: <
field_id: %2%
data_type: %3%
>
op: GreaterThan
value: <
%4%: 3
>
>
>
query_info: <
topk: 10
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>
)") % vec_float_field_id.get() %
                field_id.get() % data_type_str % value_tag;

    auto proto_text = fmt1.str();
    planpb::PlanNode node_proto;
    // A malformed text proto must fail the test instead of silently producing
    // an empty plan node.
    ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(proto_text,
                                                              &node_proto));
    auto plan = ProtoParser(*schema).CreatePlan(node_proto);

    // Exercise the plan printer and require the extra info to be present.
    ShowPlanNodeVisitor visitor;
    auto json = visitor.call_child(*plan->plan_node_);
    auto extra_info = plan->extra_info_opt_.value();

    std::string dsl_text = boost::str(boost::format(R"(
{
"bool": {
"must": [
{
"range": {
"%1%": {
"GT": 3
}
}
},
{
"vector": {
"FloatVectorField": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
]
}
}
)") % field_name);

    auto ref_plan = CreatePlan(*schema, dsl_text);
    plan->check_identical(*ref_plan);
}
// Builds the same search (term predicate `field in [1, 2, 3]` plus a float
// vector ANN query) once from a text-format proto and once from the JSON DSL,
// and checks that both plans are identical.
TEST_P(PlanProtoTest, TermExpr) {
    // xxx.query(predicates = "int64field in [1, 2, 3]", topk = 10, ...)
    FieldName vec_field_name = FieldName("FloatVectorField");
    FieldId vec_float_field_id = schema->get_field_id(vec_field_name);

    auto field_name = std::get<0>(GetParam());
    auto field_id = schema->get_field_id(FieldName(field_name));
    auto data_type = schema->operator[](field_id).get_data_type();
    auto data_type_str = spb::DataType_Name(int(data_type));

    // Name of the GenericValue member that matches the field's data type.
    string value_tag = "bool_val";
    if (datatype_is_floating(data_type)) {
        value_tag = "float_val";
    } else if (datatype_is_integer(data_type)) {
        value_tag = "int64_val";
    }

    auto fmt1 = boost::format(R"(
vector_anns: <
field_id: %1%
predicates: <
term_expr: <
column_info: <
field_id: %2%
data_type: %3%
>
values: <
%4%: 1
>
values: <
%4%: 2
>
values: <
%4%: 3
>
is_in_field : false
>
>
query_info: <
topk: 10
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>
)") % vec_float_field_id.get() %
                field_id.get() % data_type_str % value_tag;

    auto proto_text = fmt1.str();
    planpb::PlanNode node_proto;
    // A malformed text proto must fail the test instead of silently producing
    // an empty plan node.
    ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(proto_text,
                                                              &node_proto));
    auto plan = ProtoParser(*schema).CreatePlan(node_proto);

    // Exercise the plan printer and require the extra info to be present.
    ShowPlanNodeVisitor visitor;
    auto json = visitor.call_child(*plan->plan_node_);
    auto extra_info = plan->extra_info_opt_.value();

    std::string dsl_text = boost::str(boost::format(R"(
{
"bool": {
"must": [
{
"term": {
"%1%": {
"values": [1,2,3],
"is_in_field" : false
}
}
},
{
"vector": {
"FloatVectorField": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
]
}
}
)") % field_name);

    auto ref_plan = CreatePlan(*schema, dsl_text);
    plan->check_identical(*ref_plan);
}
// Builds the same search (negated range predicate `not (int64field > 3)` plus
// a float vector ANN query) once from a text-format proto and once from the
// JSON DSL, and checks that both the printed plans and the plans themselves
// are identical.
TEST(PlanProtoTest, NotExpr) {
    auto schema = getStandardSchema();
    // xxx.query(predicates = "not (int64field > 3)", topk = 10, ...)
    FieldName vec_field_name = FieldName("FloatVectorField");
    FieldId vec_float_field_id = schema->get_field_id(vec_field_name);

    FieldName int64_field_name = FieldName("Int64Field");
    FieldId int64_field_id = schema->get_field_id(int64_field_name);
    string value_tag = "int64_val";

    auto data_type = spb::DataType::Int64;
    auto data_type_str = spb::DataType_Name(int(data_type));

    auto fmt1 = boost::format(R"(
vector_anns: <
field_id: %1%
predicates: <
unary_expr: <
op: Not
child: <
unary_range_expr: <
column_info: <
field_id: %2%
data_type: %3%
>
op: GreaterThan
value: <
%4%: 3
>
>
>
>
>
query_info: <
topk: 10
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>
)") % vec_float_field_id.get() %
                int64_field_id.get() % data_type_str % value_tag;

    auto proto_text = fmt1.str();
    planpb::PlanNode node_proto;
    // A malformed text proto must fail the test instead of silently producing
    // an empty plan node.
    ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(proto_text,
                                                              &node_proto));
    auto plan = ProtoParser(*schema).CreatePlan(node_proto);

    ShowPlanNodeVisitor visitor;
    auto json = visitor.call_child(*plan->plan_node_);
    // Require the extra info to be present.
    auto extra_info = plan->extra_info_opt_.value();

    std::string dsl_text = boost::str(boost::format(R"(
{
"bool": {
"must": [
{
"must_not": [{
"range": {
"%1%": {
"GT": 3
}
}
}]
},
{
"vector": {
"FloatVectorField": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
]
}
}
)") % int64_field_name.get());

    auto ref_plan = CreatePlan(*schema, dsl_text);
    auto ref_json = ShowPlanNodeVisitor().call_child(*ref_plan->plan_node_);
    EXPECT_EQ(json.dump(2), ref_json.dump(2));
    plan->check_identical(*ref_plan);
}
// Builds the same search (predicate
// `(int64field < 3) && (int64field > 2 || int64field == 1)` plus a float
// vector ANN query) once from a text-format proto and once from the JSON DSL,
// and checks that both the printed plans and the plans themselves are
// identical.
TEST(PlanProtoTest, AndOrExpr) {
    auto schema = getStandardSchema();
    // xxx.query(predicates = "(int64field < 3) && (int64field > 2 || int64field == 1)", topk = 10, ...)
    FieldName vec_field_name = FieldName("FloatVectorField");
    FieldId vec_float_field_id = schema->get_field_id(vec_field_name);

    FieldName int64_field_name = FieldName("Int64Field");
    FieldId int64_field_id = schema->get_field_id(int64_field_name);
    string value_tag = "int64_val";

    auto data_type = spb::DataType::Int64;
    auto data_type_str = spb::DataType_Name(int(data_type));

    auto fmt1 = boost::format(R"(
vector_anns: <
field_id: %1%
predicates: <
binary_expr: <
op: LogicalAnd
left: <
unary_range_expr: <
column_info: <
field_id: %2%
data_type: %3%
>
op: LessThan
value: <
%4%: 3
>
>
>
right: <
binary_expr: <
op: LogicalOr
left: <
unary_range_expr: <
column_info: <
field_id: %2%
data_type: %3%
>
op: GreaterThan
value: <
%4%: 2
>
>
>
right: <
unary_range_expr: <
column_info: <
field_id: %2%
data_type: %3%
>
op: Equal
value: <
%4%: 1
>
>
>
>
>
>
>
query_info: <
topk: 10
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>
)") % vec_float_field_id.get() %
                int64_field_id.get() % data_type_str % value_tag;

    auto proto_text = fmt1.str();
    planpb::PlanNode node_proto;
    // A malformed text proto must fail the test instead of silently producing
    // an empty plan node.
    ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(proto_text,
                                                              &node_proto));
    auto plan = ProtoParser(*schema).CreatePlan(node_proto);

    ShowPlanNodeVisitor visitor;
    auto json = visitor.call_child(*plan->plan_node_);
    // Require the extra info to be present.
    auto extra_info = plan->extra_info_opt_.value();

    std::string dsl_text = boost::str(boost::format(R"(
{
"bool": {
"must": [
{
"range": {
"%1%": {
"LT": 3
}
}
},
{
"should": [
{
"range": {
"%1%": {
"GT": 2
}
}
},
{
"range": {
"%1%": {
"EQ": 1
}
}
}
]
},
{
"vector": {
"FloatVectorField": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
]
}
}
)") % int64_field_name.get());

    auto ref_plan = CreatePlan(*schema, dsl_text);
    auto ref_json = ShowPlanNodeVisitor().call_child(*ref_plan->plan_node_);
    EXPECT_EQ(json.dump(2), ref_json.dump(2));
    plan->check_identical(*ref_plan);
}
// Builds the same search (column comparison `age1 < <field under test>` plus a
// float vector ANN query) once from a text-format proto and once from the JSON
// DSL, and checks that both plans are identical.
TEST_P(PlanProtoTest, CompareExpr) {
    // A local schema (shadowing the fixture member) so the extra "age1" field
    // can be added for this test.
    auto schema = getStandardSchema();
    auto age_fid = schema->AddDebugField("age1", DataType::INT64);
    // xxx.query(predicates = "int64field < int64field", topk = 10, ...)

    FieldName vec_field_name = FieldName("FloatVectorField");
    FieldId vec_float_field_id = schema->get_field_id(vec_field_name);

    auto field_name = std::get<0>(GetParam());
    auto field_id = schema->get_field_id(FieldName(field_name));
    auto data_type = schema->operator[](field_id).get_data_type();
    auto data_type_str = spb::DataType_Name(int(data_type));

    auto fmt1 = boost::format(R"(
vector_anns: <
field_id: %1%
predicates: <
compare_expr: <
left_column_info: <
field_id: %2%
data_type: Int64
>
right_column_info: <
field_id: %3%
data_type: %4%
>
op: LessThan
>
>
query_info: <
topk: 10
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>
)") % vec_float_field_id.get() %
                age_fid.get() % field_id.get() % data_type_str;

    auto proto_text = fmt1.str();
    planpb::PlanNode node_proto;
    // A malformed text proto must fail the test instead of silently producing
    // an empty plan node.
    ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(proto_text,
                                                              &node_proto));
    auto plan = ProtoParser(*schema).CreatePlan(node_proto);

    // Exercise the plan printer and require the extra info to be present.
    ShowPlanNodeVisitor visitor;
    auto json = visitor.call_child(*plan->plan_node_);
    auto extra_info = plan->extra_info_opt_.value();

    std::string dsl_text = boost::str(boost::format(R"(
{
"bool": {
"must": [
{
"compare": {
"LT": [
"age1",
"%1%"
]
}
},
{
"vector": {
"FloatVectorField": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
]
}
}
)") % field_name);

    auto ref_plan = CreatePlan(*schema, dsl_text);
    plan->check_identical(*ref_plan);
}
// Builds the same search plan twice -- once from a text-format protobuf plan
// and once from the JSON DSL -- and checks that both plans are identical.
// The predicate is an arithmetic range expression on the parameterized
// field: (field + 1029) == 2016.
TEST_P(PlanProtoTest, BinaryArithOpEvalRange) {
    // xxx.query(predicates = "field + 1029 == 2016", topk = 10, ...)
    // NOTE(review): `schema` is not declared in this body; presumably it is
    // a member of the PlanProtoTest fixture -- confirm.
    FieldName vec_field_name = FieldName("FloatVectorField");
    FieldId vec_float_field_id = schema->get_field_id(vec_field_name);

    // Scalar column under test is selected by the test parameter.
    auto field_name = std::get<0>(GetParam());
    auto field_id = schema->get_field_id(FieldName(field_name));
    auto data_type = (*schema)[field_id].get_data_type();
    auto data_type_str = spb::DataType_Name(static_cast<int>(data_type));

    // Pick the proto value field that matches the column type; anything
    // that is neither floating point nor integer falls back to bool_val.
    std::string value_tag = "bool_val";
    if (datatype_is_floating(data_type)) {
        value_tag = "float_val";
    } else if (datatype_is_integer(data_type)) {
        value_tag = "int64_val";
    }

    // %1% vector field id, %2% scalar field id,
    // %3% scalar data type name, %4% value field name.
    auto fmt1 = boost::format(R"(
vector_anns: <
field_id: %1%
predicates: <
binary_arith_op_eval_range_expr: <
column_info: <
field_id: %2%
data_type: %3%
>
arith_op: Add
right_operand: <
%4%: 1029
>
op: Equal
value: <
%4%: 2016
>
>
>
query_info: <
topk: 10
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>
)") % vec_float_field_id.get() %
                field_id.get() % data_type_str % value_tag;
    auto proto_text = fmt1.str();

    planpb::PlanNode node_proto;
    // ParseFromString reports malformed input through its return value;
    // stop here instead of comparing against a partially filled message.
    ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(proto_text,
                                                              &node_proto));
    // std::cout << node_proto.DebugString();
    auto plan = ProtoParser(*schema).CreatePlan(node_proto);

    // The rendered JSON is not compared in this test; the call only
    // exercises the visitor on the proto-built plan.
    ShowPlanNodeVisitor visitor;
    auto json = visitor.call_child(*plan->plan_node_);
    // std::cout << json.dump(2);

    // NOTE(review): extra_info_opt_ looks like an optional; .value()
    // presumably throws when it is empty, which makes this line an implicit
    // presence check -- confirm.
    auto extra_info = plan->extra_info_opt_.value();

    // Reference plan: the same query expressed in the JSON DSL.
    std::string dsl_text = boost::str(boost::format(R"(
{
"bool": {
"must": [
{
"range": {
"%1%": {
"EQ": {
"ADD": {
"right_operand": 1029,
"value": 2016
}
}
}
}
},
{
"vector": {
"FloatVectorField": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
]
}
}
)") % field_name);
    auto ref_plan = CreatePlan(*schema, dsl_text);
    plan->check_identical(*ref_plan);
}
// Serializes a PlanNode proto that carries only a predicate ("age1" == 1000,
// expressed as a unary range expression) and checks that the retrieve plan
// created from the bytes has a predicate and is not flagged as a count query.
TEST(PlanProtoTest, Predicates) {
    auto schema = getStandardSchema();
    auto age_field = schema->AddDebugField("age1", DataType::INT64);

    // Assemble the predicate: Int64 column "age1" Equal 1000.
    planpb::PlanNode proto_plan;
    auto* range_expr =
        proto_plan.mutable_predicates()->mutable_unary_range_expr();
    range_expr->set_op(planpb::Equal);

    auto* column = range_expr->mutable_column_info();
    column->set_data_type(proto::schema::DataType::Int64);
    column->set_field_id(age_field.get());

    auto* literal = range_expr->mutable_value();
    literal->set_int64_val(1000);

    // The serialized bytes may contain NULs, so the length is passed
    // explicitly alongside the pointer.
    std::string serialized;
    proto_plan.SerializeToString(&serialized);

    auto retrieve_plan = CreateRetrievePlanByExpr(
        *schema, serialized.c_str(), serialized.size());

    ASSERT_TRUE(retrieve_plan->plan_node_->predicate_.has_value());
    ASSERT_FALSE(retrieve_plan->plan_node_->is_count);
}

View File

@ -53,84 +53,40 @@ TEST(Query, ShowExecutor) {
std::cout << dup.dump(4);
}
// Parses two spellings of the same vector-only query -- the vector clause
// wrapped in a "must" list, and the vector clause placed directly under
// "bool" -- and checks that ShowPlanNodeVisitor renders both plans the same.
TEST(Query, DSL) {
    using namespace milvus::query;
    using namespace milvus::segcore;

    ShowPlanNodeVisitor shower;

    // Form 1: vector clause inside a "must" array.
    std::string dsl_string = R"(
{
"bool": {
"must": [
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
]
}
})";

    // Form 2: vector clause directly under "bool".
    std::string dsl_string2 = R"(
{
"bool": {
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
})";

    auto schema = std::make_shared<Schema>();
    schema->AddDebugField(
        "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);

    auto must_plan = CreatePlan(*schema, dsl_string);
    auto must_json = shower.call_child(*must_plan->plan_node_);
    std::cout << must_json.dump(4) << std::endl;

    auto bare_plan = CreatePlan(*schema, dsl_string2);
    auto bare_json = shower.call_child(*bare_plan->plan_node_);
    std::cout << bare_json.dump(4) << std::endl;

    ASSERT_EQ(must_json, bare_json);
}
TEST(Query, ParsePlaceholderGroup) {
std::string dsl_string = R"(
{
"bool": {
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10,
"round_decimal":3
}
}
}
})";
// std::string dsl_string = R"(
// {
// "bool": {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 10,
// "round_decimal":3
// }
// }
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
query_info: <
topk: 10
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto schema = std::make_shared<Schema>();
schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto plan = CreatePlan(*schema, dsl_string);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
int64_t num_queries = 100000;
int dim = 16;
auto raw_group = CreatePlaceholderGroup(num_queries, dim);
@ -147,33 +103,59 @@ TEST(Query, ExecWithPredicateLoader) {
schema->AddDebugField("age", DataType::FLOAT);
auto counter_fid = schema->AddDebugField("counter", DataType::INT64);
schema->set_primary_field_id(counter_fid);
std::string dsl = R"({
"bool": {
"must": [
{
"range": {
"age": {
"GE": -1,
"LT": 1
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 3
}
}
}
]
}
})";
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "range": {
// "age": {
// "GE": -1,
// "LT": 1
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_range_expr: <
column_info: <
field_id: 101
data_type: Float
>
lower_inclusive: true,
upper_inclusive: false,
lower_value: <
float_val: -1
>
upper_value: <
float_val: 1
>
>
>
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
int64_t N = ROW_COUNT;
auto dataset = DataGen(schema, N);
auto segment = CreateGrowingSegment(schema, empty_index_meta);
@ -184,7 +166,9 @@ TEST(Query, ExecWithPredicateLoader) {
dataset.timestamps_.data(),
dataset.raw_);
auto plan = CreatePlan(*schema, dsl);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
auto ph_group =
@ -231,33 +215,59 @@ TEST(Query, ExecWithPredicateSmallN) {
schema->AddDebugField("age", DataType::FLOAT);
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
schema->set_primary_field_id(i64_fid);
std::string dsl = R"({
"bool": {
"must": [
{
"range": {
"age": {
"GE": -1,
"LT": 1
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 3
}
}
}
]
}
})";
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "range": {
// "age": {
// "GE": -1,
// "LT": 1
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_range_expr: <
column_info: <
field_id: 101
data_type: Float
>
lower_inclusive: true,
upper_inclusive: false,
lower_value: <
float_val: -1
>
upper_value: <
float_val: 1
>
>
>
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
int64_t N = 177;
auto dataset = DataGen(schema, N);
auto segment = CreateGrowingSegment(schema, empty_index_meta);
@ -268,7 +278,9 @@ TEST(Query, ExecWithPredicateSmallN) {
dataset.timestamps_.data(),
dataset.raw_);
auto plan = CreatePlan(*schema, dsl);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 7, 1024);
auto ph_group =
@ -291,33 +303,59 @@ TEST(Query, ExecWithPredicate) {
schema->AddDebugField("age", DataType::FLOAT);
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
schema->set_primary_field_id(i64_fid);
std::string dsl = R"({
"bool": {
"must": [
{
"range": {
"age": {
"GE": -1,
"LT": 1
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 3
}
}
}
]
}
})";
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "range": {
// "age": {
// "GE": -1,
// "LT": 1
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_range_expr: <
column_info: <
field_id: 101
data_type: Float
>
lower_inclusive: true,
upper_inclusive: false,
lower_value: <
float_val: -1
>
upper_value: <
float_val: 1
>
>
>
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
int64_t N = ROW_COUNT;
auto dataset = DataGen(schema, N);
auto segment = CreateGrowingSegment(schema, empty_index_meta);
@ -328,7 +366,9 @@ TEST(Query, ExecWithPredicate) {
dataset.timestamps_.data(),
dataset.raw_);
auto plan = CreatePlan(*schema, dsl);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
auto ph_group =
@ -375,33 +415,51 @@ TEST(Query, ExecTerm) {
schema->AddDebugField("age", DataType::FLOAT);
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
schema->set_primary_field_id(i64_fid);
std::string dsl = R"({
"bool": {
"must": [
{
"term": {
"age": {
"values": [],
"is_in_field": false
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 3
}
}
}
]
}
})";
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "term": {
// "age": {
// "values": [],
// "is_in_field": false
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
term_expr: <
column_info: <
field_id: 101
data_type: Float
>
>
>
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
int64_t N = ROW_COUNT;
auto dataset = DataGen(schema, N);
auto segment = CreateGrowingSegment(schema, empty_index_meta);
@ -412,7 +470,9 @@ TEST(Query, ExecTerm) {
dataset.timestamps_.data(),
dataset.raw_);
auto plan = CreatePlan(*schema, dsl);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 3;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
auto ph_group =
@ -435,28 +495,40 @@ TEST(Query, ExecEmpty) {
schema->AddDebugField("age", DataType::FLOAT);
schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
std::string dsl = R"({
"bool": {
"must": [
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 3
}
}
}
]
}
})";
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 101
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
int64_t N = ROW_COUNT;
auto segment = CreateGrowingSegment(schema, empty_index_meta);
auto plan = CreatePlan(*schema, dsl);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
auto ph_group =
@ -483,26 +555,38 @@ TEST(Query, ExecWithoutPredicateFlat) {
schema->AddDebugField("age", DataType::FLOAT);
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
schema->set_primary_field_id(i64_fid);
std::string dsl = R"({
"bool": {
"must": [
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 3
}
}
}
]
}
})";
auto plan = CreatePlan(*schema, dsl);
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
int64_t N = ROW_COUNT;
auto dataset = DataGen(schema, N);
auto segment = CreateGrowingSegment(schema, empty_index_meta);
@ -535,26 +619,38 @@ TEST(Query, ExecWithoutPredicate) {
schema->AddDebugField("age", DataType::FLOAT);
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
schema->set_primary_field_id(i64_fid);
std::string dsl = R"({
"bool": {
"must": [
{
"vector": {
"fakevec": {
"metric_type": "l2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal":3
}
}
}
]
}
})";
auto plan = CreatePlan(*schema, dsl);
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "vector": {
// "fakevec": {
// "metric_type": "l2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal":3
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
int64_t N = ROW_COUNT;
auto dataset = DataGen(schema, N);
auto segment = CreateGrowingSegment(schema, empty_index_meta);
@ -609,32 +705,44 @@ TEST(Query, InnerProduct) {
constexpr auto topk = 10;
auto num_queries = 5;
auto schema = std::make_shared<Schema>();
std::string dsl = R"({
"bool": {
"must": [
{
"vector": {
"normalized": {
"metric_type": "ip",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal":3
}
}
}
]
}
})";
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "vector": {
// "normalized": {
// "metric_type": "ip",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal":3
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
query_info: <
topk: 5
round_decimal: 3
metric_type: "IP"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
auto vec_fid = schema->AddDebugField(
"normalized", DataType::VECTOR_FLOAT, dim, knowhere::metric::IP);
auto i64_fid = schema->AddDebugField("age", DataType::INT64);
schema->set_primary_field_id(i64_fid);
auto dataset = DataGen(schema, N);
auto segment = CreateGrowingSegment(schema, empty_index_meta);
auto plan = CreatePlan(*schema, dsl);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
segment->PreInsert(N);
segment->Insert(0,
N,
@ -734,26 +842,38 @@ TEST(Query, FillSegment) {
return segment;
}());
std::string dsl = R"({
"bool": {
"must": [
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 3
}
}
}
]
}
})";
auto plan = CreatePlan(*schema, dsl);
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto ph_proto = CreatePlaceholderGroup(10, 16, 443);
auto ph = ParsePlaceholderGroup(plan.get(), ph_proto.SerializeAsString());
Timestamp ts = N * 2UL;
@ -826,33 +946,59 @@ TEST(Query, ExecWithPredicateBinary) {
auto float_fid = schema->AddDebugField("age", DataType::FLOAT);
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
schema->set_primary_field_id(i64_fid);
std::string dsl = R"({
"bool": {
"must": [
{
"range": {
"age": {
"GE": -1,
"LT": 1
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "JACCARD",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 3
}
}
}
]
}
})";
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "range": {
// "age": {
// "GE": -1,
// "LT": 1
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "JACCARD",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_range_expr: <
column_info: <
field_id: 101
data_type: Float
>
lower_inclusive: true,
upper_inclusive: false,
lower_value: <
float_val: -1
>
upper_value: <
float_val: 1
>
>
>
query_info: <
topk: 5
round_decimal: 3
metric_type: "JACCARD"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
int64_t N = ROW_COUNT;
auto dataset = DataGen(schema, N);
auto segment = CreateGrowingSegment(schema, empty_index_meta);
@ -864,7 +1010,9 @@ TEST(Query, ExecWithPredicateBinary) {
dataset.raw_);
auto vec_ptr = dataset.get_col<uint8_t>(vec_fid);
auto plan = CreatePlan(*schema, dsl);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw = CreateBinaryPlaceholderGroupFromBlob(
num_queries, 512, vec_ptr.data() + 1024 * 512 / 8);

View File

@ -37,25 +37,17 @@ TEST(Sealed, without_predicate) {
auto float_fid = schema->AddDebugField("age", DataType::FLOAT);
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
schema->set_primary_field_id(i64_fid);
std::string dsl = R"({
"bool": {
"must": [
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 3
}
}
}
]
}
})";
const char* raw_plan = R"(vector_anns: <
field_id: 100
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
auto N = ROW_COUNT;
@ -73,7 +65,9 @@ TEST(Sealed, without_predicate) {
dataset.timestamps_.data(),
dataset.raw_);
auto plan = CreatePlan(*schema, dsl);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw =
CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr);
@ -149,33 +143,59 @@ TEST(Sealed, with_predicate) {
"fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
schema->set_primary_field_id(i64_fid);
std::string dsl = R"({
"bool": {
"must": [
{
"range": {
"counter": {
"GE": 4200,
"LT": 4205
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 6
}
}
}
]
}
})";
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "range": {
// "counter": {
// "GE": 4200,
// "LT": 4205
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal": 6
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_range_expr: <
column_info: <
field_id: 101
data_type: Int64
>
lower_inclusive: true,
upper_inclusive: false,
lower_value: <
int64_val: 4200
>
upper_value: <
int64_val: 4205
>
>
>
query_info: <
topk: 5
round_decimal: 6
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
auto N = ROW_COUNT;
@ -190,7 +210,9 @@ TEST(Sealed, with_predicate) {
dataset.timestamps_.data(),
dataset.raw_);
auto plan = CreatePlan(*schema, dsl);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw =
CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr);
@ -262,40 +284,68 @@ TEST(Sealed, with_predicate_filter_all) {
"fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
schema->set_primary_field_id(i64_fid);
std::string dsl = R"({
"bool": {
"must": [
{
"range": {
"counter": {
"GE": 4200,
"LT": 4199
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 6
}
}
}
]
}
})";
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "range": {
// "counter": {
// "GE": 4200,
// "LT": 4199
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal": 6
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_range_expr: <
column_info: <
field_id: 101
data_type: Int64
>
lower_inclusive: true,
upper_inclusive: false,
lower_value: <
int64_val: 4200
>
upper_value: <
int64_val: 4199
>
>
>
query_info: <
topk: 5
round_decimal: 6
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
auto N = ROW_COUNT;
auto dataset = DataGen(schema, N);
auto vec_col = dataset.get_col<float>(fake_id);
auto query_ptr = vec_col.data() + BIAS * dim;
auto plan = CreatePlan(*schema, dsl);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw =
CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr);
@ -395,36 +445,64 @@ TEST(Sealed, LoadFieldData) {
auto indexing = GenVecIndexing(N, dim, fakevec.data());
auto segment = CreateSealedSegment(schema);
std::string dsl = R"({
"bool": {
"must": [
{
"range": {
"double": {
"GE": -1,
"LT": 1
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 3
}
}
}
]
}
})";
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "range": {
// "double": {
// "GE": -1,
// "LT": 1
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_range_expr: <
column_info: <
field_id: 102
data_type: Double
>
lower_inclusive: true,
upper_inclusive: false,
lower_value: <
float_val: -1
>
upper_value: <
float_val: 1
>
>
>
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
Timestamp time = 1000000;
auto plan = CreatePlan(*schema, dsl);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
auto ph_group =
@ -492,36 +570,64 @@ TEST(Sealed, LoadFieldDataMmap) {
auto indexing = GenVecIndexing(N, dim, fakevec.data());
auto segment = CreateSealedSegment(schema);
std::string dsl = R"({
"bool": {
"must": [
{
"range": {
"double": {
"GE": -1,
"LT": 1
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 3
}
}
}
]
}
})";
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "range": {
// "double": {
// "GE": -1,
// "LT": 1
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_range_expr: <
column_info: <
field_id: 102
data_type: Double
>
lower_inclusive: true,
upper_inclusive: false,
lower_value: <
float_val: -1
>
upper_value: <
float_val: 1
>
>
>
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
Timestamp time = 1000000;
auto plan = CreatePlan(*schema, dsl);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
auto ph_group =
@ -583,36 +689,64 @@ TEST(Sealed, LoadScalarIndex) {
auto indexing = GenVecIndexing(N, dim, fakevec.data());
auto segment = CreateSealedSegment(schema);
std::string dsl = R"({
"bool": {
"must": [
{
"range": {
"double": {
"GE": -1,
"LT": 1
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 3
}
}
}
]
}
})";
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "range": {
// "double": {
// "GE": -1,
// "LT": 1
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_range_expr: <
column_info: <
field_id: 102
data_type: Double
>
lower_inclusive: true,
upper_inclusive: false,
lower_value: <
float_val: -1
>
upper_value: <
float_val: 1
>
>
>
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
Timestamp time = 1000000;
auto plan = CreatePlan(*schema, dsl);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
auto ph_group =
@ -697,36 +831,64 @@ TEST(Sealed, Delete) {
auto fakevec = dataset.get_col<float>(fakevec_id);
auto segment = CreateSealedSegment(schema);
std::string dsl = R"({
"bool": {
"must": [
{
"range": {
"double": {
"GE": -1,
"LT": 1
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 3
}
}
}
]
}
})";
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "range": {
// "double": {
// "GE": -1,
// "LT": 1
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_range_expr: <
column_info: <
field_id: 102
data_type: Double
>
lower_inclusive: true,
upper_inclusive: false,
lower_value: <
float_val: -1
>
upper_value: <
float_val: 1
>
>
>
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
Timestamp time = 1000000;
auto plan = CreatePlan(*schema, dsl);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
auto ph_group =
@ -781,36 +943,64 @@ TEST(Sealed, OverlapDelete) {
auto fakevec = dataset.get_col<float>(fakevec_id);
auto segment = CreateSealedSegment(schema);
std::string dsl = R"({
"bool": {
"must": [
{
"range": {
"double": {
"GE": -1,
"LT": 1
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 3
}
}
}
]
}
})";
// std::string dsl = R"({
// "bool": {
// "must": [
// {
// "range": {
// "double": {
// "GE": -1,
// "LT": 1
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 5,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_range_expr: <
column_info: <
field_id: 102
data_type: Double
>
lower_inclusive: true,
upper_inclusive: false,
lower_value: <
float_val: -1
>
upper_value: <
float_val: 1
>
>
>
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
Timestamp time = 1000000;
auto plan = CreatePlan(*schema, dsl);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
auto ph_group =

View File

@ -381,7 +381,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
}
t.SearchRequest.Dsl = t.request.Dsl
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
// Set username of this search request for feature like task scheduling.

View File

@ -59,70 +59,17 @@ const (
// ---------- unittest util functions ----------
// functions of messages and requests
// genIVFFlatDSL renders a JSON search DSL with an "nprobe" search parameter
// for the float-vector field found in schema. The metric type is read from
// that field's index params (key metricTypeKey). When several fields or
// params match, the last one seen wins. An error is returned if no
// float-vector field name or no metric type was found.
func genIVFFlatDSL(schema *schemapb.CollectionSchema, nProb int, topK int64, roundDecimal int64) (string, error) {
	vecFieldName, metricType := "", ""
	for _, field := range schema.Fields {
		if field.DataType != schemapb.DataType_FloatVector {
			continue
		}
		vecFieldName = field.Name
		for _, param := range field.IndexParams {
			if param.Key == metricTypeKey {
				metricType = param.Value
			}
		}
	}
	if vecFieldName == "" || metricType == "" {
		return "", errors.New("invalid vector field name or metric type")
	}

	// Assemble the DSL piece by piece; the literal fragments are kept
	// exactly as the consumers of this helper expect them.
	dsl := "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName
	dsl += "\": {\n \"metric_type\": \"" + metricType
	dsl += "\", \n \"params\": {\n \"nprobe\": " + strconv.Itoa(nProb)
	dsl += " \n},\n \"query\": \"$0\",\n \"topk\": " + strconv.FormatInt(topK, 10)
	dsl += " \n,\"round_decimal\": " + strconv.FormatInt(roundDecimal, 10)
	dsl += "\n } \n } \n } \n }"
	return dsl, nil
}
// genHNSWDSL renders the legacy JSON search DSL carrying an "ef" search
// parameter.
//
// The vector field name and metric type are taken from schema: every field
// whose DataType is FloatVector overwrites the recorded field name, and every
// index param keyed by metricTypeKey on such a field overwrites the recorded
// metric type, so the last match wins for each value independently.
// An error is returned when either value is still empty after the scan.
func genHNSWDSL(schema *schemapb.CollectionSchema, ef int, topK int64, roundDecimal int64) (string, error) {
	var (
		vecFieldName string
		metricType   string
	)
	for _, field := range schema.Fields {
		if field.DataType != schemapb.DataType_FloatVector {
			continue
		}
		vecFieldName = field.Name
		for _, param := range field.IndexParams {
			if param.Key == metricTypeKey {
				metricType = param.Value
			}
		}
	}
	if vecFieldName == "" || metricType == "" {
		return "", errors.New("invalid vector field name or metric type")
	}

	// Numeric arguments are spliced into the JSON template as decimal text.
	efStr := strconv.Itoa(ef)
	topKStr := strconv.FormatInt(topK, 10)
	roundDecimalStr := strconv.FormatInt(roundDecimal, 10)

	dsl := "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
		"\": {\n \"metric_type\": \"" + metricType +
		"\", \n \"params\": {\n \"ef\": " + efStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
		" \n,\"round_decimal\": " + roundDecimalStr +
		"\n } \n } \n } \n }"
	return dsl, nil
}
func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecimal int64) (string, error) {
var vecFieldName string
var metricType string
topKStr := strconv.FormatInt(topK, 10)
nProbStr := strconv.Itoa(defaultNProb)
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
var fieldID int64
for _, f := range schema.Fields {
if f.DataType == schemapb.DataType_FloatVector {
vecFieldName = f.Name
fieldID = f.FieldID
for _, p := range f.IndexParams {
if p.Key == metricTypeKey {
metricType = p.Value
@ -134,24 +81,21 @@ func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecima
err := errors.New("invalid vector field name or metric type")
return "", err
}
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
"\": {\n \"metric_type\": \"" + metricType +
"\", \n \"params\": {\n \"nprobe\": " + nProbStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
" \n,\"round_decimal\": " + roundDecimalStr +
"\n } \n } \n } \n }", nil
return `vector_anns: <
field_id: ` + fmt.Sprintf("%d", fieldID) + `
query_info: <
topk: ` + topKStr + `
round_decimal: ` + roundDecimalStr + `
metric_type: "` + metricType + `"
search_params: "{\"nprobe\": ` + nProbStr + `}"
>
placeholder_tag: "$0"
>`, nil
}
// genDSLByIndexType picks the DSL generator matching indexType and invokes it
// with the package default search parameters. Index types that match none of
// the cases below yield an "Invalid indexType" error.
func genDSLByIndexType(schema *schemapb.CollectionSchema, indexType string) (string, error) {
	switch {
	// Flat indexes (float and binary variants) share the brute-force DSL.
	case indexType == IndexFaissIDMap || indexType == IndexFaissBinIDMap:
		return genBruteForceDSL(schema, defaultTopK, defaultRoundDecimal)
	// IVF-flat indexes (float and binary variants) share the nprobe DSL.
	case indexType == IndexFaissIVFFlat || indexType == IndexFaissBinIVFFlat:
		return genIVFFlatDSL(schema, defaultNProb, defaultTopK, defaultRoundDecimal)
	case indexType == IndexHNSW:
		return genHNSWDSL(schema, defaultEf, defaultTopK, defaultRoundDecimal)
	}
	return "", fmt.Errorf("Invalid indexType")
}

View File

@ -955,18 +955,24 @@ func genSearchRequest(nq int64, indexType string, collection *Collection) (*inte
if err != nil {
return nil, err
}
simpleDSL, err2 := genDSLByIndexType(collection.Schema(), indexType)
planStr, err2 := genDSLByIndexType(collection.Schema(), indexType)
if err2 != nil {
return nil, err2
}
var planpb planpb.PlanNode
proto.UnmarshalText(planStr, &planpb)
serializedPlan, err3 := proto.Marshal(&planpb)
if err3 != nil {
return nil, err3
}
return &internalpb.SearchRequest{
Base: genCommonMsgBase(commonpb.MsgType_Search, 0),
CollectionID: collection.ID(),
PartitionIDs: collection.GetPartitions(),
Dsl: simpleDSL,
PlaceholderGroup: placeHolder,
DslType: commonpb.DslType_Dsl,
Nq: nq,
Base: genCommonMsgBase(commonpb.MsgType_Search, 0),
CollectionID: collection.ID(),
PartitionIDs: collection.GetPartitions(),
PlaceholderGroup: placeHolder,
SerializedExprPlan: serializedPlan,
DslType: commonpb.DslType_BoolExprV1,
Nq: nq,
}, nil
}
@ -1012,12 +1018,6 @@ func genPlaceHolderGroup(nq int64) ([]byte, error) {
func genDSLByIndexType(schema *schemapb.CollectionSchema, indexType string) (string, error) {
if indexType == IndexFaissIDMap { // float vector
return genBruteForceDSL(schema, defaultTopK, defaultRoundDecimal)
} else if indexType == IndexFaissBinIDMap {
return genBruteForceDSL(schema, defaultTopK, defaultRoundDecimal)
} else if indexType == IndexFaissIVFFlat {
return genIVFFlatDSL(schema, defaultNProb, defaultTopK, defaultRoundDecimal)
} else if indexType == IndexFaissBinIVFFlat { // binary vector
return genIVFFlatDSL(schema, defaultNProb, defaultTopK, defaultRoundDecimal)
} else if indexType == IndexHNSW {
return genHNSWDSL(schema, ef, defaultTopK, defaultRoundDecimal)
}
@ -1030,9 +1030,11 @@ func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecima
topKStr := strconv.FormatInt(topK, 10)
nProbStr := strconv.Itoa(defaultNProb)
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
var fieldID int64
for _, f := range schema.Fields {
if f.DataType == schemapb.DataType_FloatVector {
vecFieldName = f.Name
fieldID = f.FieldID
for _, p := range f.IndexParams {
if p.Key == metricTypeKey {
metricType = p.Value
@ -1044,39 +1046,16 @@ func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecima
err := errors.New("invalid vector field name or metric type")
return "", err
}
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
"\": {\n \"metric_type\": \"" + metricType +
"\", \n \"params\": {\n \"nprobe\": " + nProbStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
" \n,\"round_decimal\": " + roundDecimalStr +
"\n } \n } \n } \n }", nil
}
func genIVFFlatDSL(schema *schemapb.CollectionSchema, nProb int, topK int64, roundDecimal int64) (string, error) {
var vecFieldName string
var metricType string
nProbStr := strconv.Itoa(nProb)
topKStr := strconv.FormatInt(topK, 10)
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
for _, f := range schema.Fields {
if f.DataType == schemapb.DataType_FloatVector {
vecFieldName = f.Name
for _, p := range f.IndexParams {
if p.Key == metricTypeKey {
metricType = p.Value
}
}
}
}
if vecFieldName == "" || metricType == "" {
err := errors.New("invalid vector field name or metric type")
return "", err
}
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
"\": {\n \"metric_type\": \"" + metricType +
"\", \n \"params\": {\n \"nprobe\": " + nProbStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
" \n,\"round_decimal\": " + roundDecimalStr +
"\n } \n } \n } \n }", nil
return `vector_anns: <
field_id: ` + fmt.Sprintf("%d", fieldID) + `
query_info: <
topk: ` + topKStr + `
round_decimal: ` + roundDecimalStr + `
metric_type: "` + metricType + `"
search_params: "{\"nprobe\": ` + nProbStr + `}"
>
placeholder_tag: "$0"
>`, nil
}
func genHNSWDSL(schema *schemapb.CollectionSchema, ef int, topK int64, roundDecimal int64) (string, error) {
@ -1085,9 +1064,11 @@ func genHNSWDSL(schema *schemapb.CollectionSchema, ef int, topK int64, roundDeci
efStr := strconv.Itoa(ef)
topKStr := strconv.FormatInt(topK, 10)
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
var fieldID int64
for _, f := range schema.Fields {
if f.DataType == schemapb.DataType_FloatVector {
vecFieldName = f.Name
fieldID = f.FieldID
for _, p := range f.IndexParams {
if p.Key == metricTypeKey {
metricType = p.Value
@ -1099,11 +1080,16 @@ func genHNSWDSL(schema *schemapb.CollectionSchema, ef int, topK int64, roundDeci
err := errors.New("invalid vector field name or metric type")
return "", err
}
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
"\": {\n \"metric_type\": \"" + metricType +
"\", \n \"params\": {\n \"ef\": " + efStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
" \n,\"round_decimal\": " + roundDecimalStr +
"\n } \n } \n } \n }", nil
return `vector_anns: <
field_id: ` + fmt.Sprintf("%d", fieldID) + `
query_info: <
topk: ` + topKStr + `
round_decimal: ` + roundDecimalStr + `
metric_type: "` + metricType + `"
search_params: "{\"ef\": ` + efStr + `}"
>
placeholder_tag: "$0"
>`, nil
}
func checkSearchResult(nq int64, plan *SearchPlan, searchResult *SearchResult) error {

View File

@ -31,7 +31,6 @@ import (
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
. "github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -41,31 +40,6 @@ type SearchPlan struct {
cSearchPlan C.CSearchPlan
}
// createSearchPlan builds a SearchPlan for col from a DSL string by calling
// C.CreateSearchPlan.
//
// It fails when col has no underlying collection pointer or when the C call
// reports a non-success status. On success the plan's metric type is set to
// metricType when that is non-empty, otherwise to col.GetMetricType().
func createSearchPlan(col *Collection, dsl string, metricType string) (*SearchPlan, error) {
	if col.collectionPtr == nil {
		return nil, errors.New("nil collection ptr, collectionID = " + fmt.Sprintln(col.id))
	}

	// The C string is allocated by C.CString and must be freed explicitly.
	dslCStr := C.CString(dsl)
	defer C.free(unsafe.Pointer(dslCStr))

	var rawPlan C.CSearchPlan
	cStatus := C.CreateSearchPlan(col.collectionPtr, dslCStr, &rawPlan)
	if err := HandleCStatus(&cStatus, "Create Plan failed"); err != nil {
		return nil, err
	}

	plan := &SearchPlan{cSearchPlan: rawPlan}

	// An explicitly requested metric type wins over the collection's own.
	effectiveMetric := metricType
	if len(effectiveMetric) == 0 {
		effectiveMetric = col.GetMetricType()
	}
	plan.setMetricType(effectiveMetric)

	return plan, nil
}
func createSearchPlanByExpr(col *Collection, expr []byte, metricType string) (*SearchPlan, error) {
if col.collectionPtr == nil {
return nil, errors.New("nil collection ptr, collectionID = " + fmt.Sprintln(col.id))
@ -121,18 +95,10 @@ func NewSearchRequest(collection *Collection, req *querypb.SearchRequest, placeh
var err error
var plan *SearchPlan
metricType := req.GetReq().GetMetricType()
if req.Req.GetDslType() == commonpb.DslType_BoolExprV1 {
expr := req.Req.SerializedExprPlan
plan, err = createSearchPlanByExpr(collection, expr, metricType)
if err != nil {
return nil, err
}
} else {
dsl := req.Req.GetDsl()
plan, err = createSearchPlan(collection, dsl, metricType)
if err != nil {
return nil, err
}
expr := req.Req.SerializedExprPlan
plan, err = createSearchPlanByExpr(collection, expr, metricType)
if err != nil {
return nil, err
}
if len(placeholderGrp) == 0 {

View File

@ -17,17 +17,14 @@
package segments
import (
"math"
"testing"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/common"
)
type PlanSuite struct {
@ -53,19 +50,6 @@ func (suite *PlanSuite) TearDownTest() {
DeleteCollection(suite.collection)
}
// TestPlanDSL checks that a search plan created from a legacy JSON DSL string
// reports the top-k and metric type written in that DSL.
func (suite *PlanSuite) TestPlanDSL() {
	dslString := "{\"bool\": { \n\"vector\": {\n \"floatVectorField\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\n \"topk\": 10 \n,\"round_decimal\": 6\n } \n } \n } \n }"

	plan, err := createSearchPlan(suite.collection, dslString, "")
	// Abort on failure: createSearchPlan returns a nil plan together with the
	// error, and the cleanup below must not run against a nil plan.
	suite.Require().NoError(err)
	// NotNil instead of NotEqual(plan, nil): a nil *SearchPlan stored in an
	// interface value never compares equal to untyped nil, so the previous
	// assertion could not fail.
	suite.Require().NotNil(plan)
	// Register cleanup only once the plan is known to be valid.
	defer plan.delete()

	topk := plan.getTopK()
	suite.Equal(int(topk), 10)
	metricType := plan.getMetricType()
	suite.Equal(metricType, "L2")
}
func (suite *PlanSuite) TestPlanCreateByExpr() {
planNode := &planpb.PlanNode{
OutputFieldIds: []int64{rowIDFieldID},
@ -82,70 +66,8 @@ func (suite *PlanSuite) TestPlanFail() {
id: -1,
}
_, err := createSearchPlan(collection, "", "")
_, err := createSearchPlanByExpr(collection, nil, "")
suite.Error(err)
_, err = createSearchPlanByExpr(collection, nil, "")
suite.Error(err)
}
// TestPlanPlaceholderGroup builds a placeholder group holding two serialized
// float vectors under tag "$0", parses it against a DSL-created plan, and
// expects the parsed request to report two queries.
func (suite *PlanSuite) TestPlanPlaceholderGroup() {
	dslString := "{\"bool\": { \n\"vector\": {\n \"floatVectorField\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\n \"topk\": 10 \n,\"round_decimal\": 6\n } \n } \n } \n }"
	plan, err := createSearchPlan(suite.collection, dslString, "")
	suite.NoError(err)
	suite.NotNil(plan)

	baseVec := generateFloatVectors(1, defaultDim)
	// encode serializes baseVec as 4-byte words using common.Endian, adding
	// idx*step to the element at position idx so the two query vectors differ.
	encode := func(step int) []byte {
		var raw []byte
		for idx, val := range baseVec {
			word := make([]byte, 4)
			common.Endian.PutUint32(word, math.Float32bits(val+float32(idx*step)))
			raw = append(raw, word...)
		}
		return raw
	}
	searchRawData1 := encode(2)
	searchRawData2 := encode(4)

	placeholderGroup := commonpb.PlaceholderGroup{
		Placeholders: []*commonpb.PlaceholderValue{
			{
				Tag:    "$0",
				Type:   commonpb.PlaceholderType_FloatVector,
				Values: [][]byte{searchRawData1, searchRawData2},
			},
		},
	}
	placeGroupByte, err := proto.Marshal(&placeholderGroup)
	suite.Nil(err)

	holder, err := parseSearchRequest(plan, placeGroupByte)
	suite.NoError(err)
	suite.NotNil(holder)

	// One query per serialized vector in the placeholder value.
	suite.Equal(int(holder.getNumOfQuery()), 2)
	holder.Delete()
}
// TestPlanNewSearchRequest checks that NewSearchRequest accepts a generated
// HNSW search request and reports the same number of queries that was asked
// for.
func (suite *PlanSuite) TestPlanNewSearchRequest() {
	nq := int64(10)

	// Do not discard the generator's error: a failure here would leave iReq
	// nil and surface later as an unrelated failure inside NewSearchRequest.
	iReq, err := genSearchRequest(nq, IndexHNSW, suite.collection)
	suite.Require().NoError(err)

	req := &querypb.SearchRequest{
		Req:             iReq,
		DmlChannels:     []string{"dml"},
		SegmentIDs:      []int64{suite.segmentID},
		FromShardLeader: true,
		Scope:           querypb.DataScope_Historical,
	}

	searchReq, err := NewSearchRequest(suite.collection, req, req.Req.GetPlaceholderGroup())
	// Stop before touching searchReq when construction failed.
	suite.Require().NoError(err)
	defer searchReq.Delete()

	suite.EqualValues(nq, searchReq.getNumOfQuery())
}
func TestPlan(t *testing.T) {

View File

@ -27,6 +27,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
storage "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/initcore"
@ -145,9 +146,21 @@ func (suite *ReduceSuite) TestReduceAllFunc() {
log.Print("marshal placeholderGroup failed")
}
dslString := "{\"bool\": { \n\"vector\": {\n \"floatVectorField\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\n \"topk\": 10 \n,\"round_decimal\": 6\n } \n } \n } \n }"
plan, err := createSearchPlan(suite.collection, dslString, "")
planStr := `vector_anns: <
field_id: 107
query_info: <
topk: 10
round_decimal: 6
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>`
var planpb planpb.PlanNode
proto.UnmarshalText(planStr, &planpb)
serializedPlan, err := proto.Marshal(&planpb)
suite.NoError(err)
plan, err := createSearchPlanByExpr(suite.collection, serializedPlan, "")
suite.NoError(err)
searchReq, err := parseSearchRequest(plan, placeGroupByte)
searchReq.timestamp = 0

View File

@ -22,6 +22,7 @@ import (
"testing"
"github.com/cockroachdb/errors"
"github.com/gogo/protobuf/proto"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
clientv3 "go.etcd.io/etcd/client/v3"
@ -33,6 +34,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
@ -910,7 +912,13 @@ func (suite *ServiceSuite) genCSearchRequest(nq int64, indexType string, schema
if err != nil {
return nil, err
}
simpleDSL, err2 := genDSLByIndexType(schema, indexType)
planStr, err := genDSLByIndexType(schema, indexType)
if err != nil {
return nil, err
}
var planpb planpb.PlanNode
proto.UnmarshalText(planStr, &planpb)
serializedPlan, err2 := proto.Marshal(&planpb)
if err2 != nil {
return nil, err2
}
@ -920,12 +928,12 @@ func (suite *ServiceSuite) genCSearchRequest(nq int64, indexType string, schema
MsgID: rand.Int63(),
TargetID: suite.node.session.ServerID,
},
CollectionID: suite.collectionID,
PartitionIDs: suite.partitionIDs,
Dsl: simpleDSL,
PlaceholderGroup: placeHolder,
DslType: commonpb.DslType_Dsl,
Nq: nq,
CollectionID: suite.collectionID,
PartitionIDs: suite.partitionIDs,
SerializedExprPlan: serializedPlan,
PlaceholderGroup: placeHolder,
DslType: commonpb.DslType_BoolExprV1,
Nq: nq,
}, nil
}