Add SyntaxTree of QueryNode and Expr

Signed-off-by: FluorineDog <guilin.gou@zilliz.com>
pull/4973/head^2
FluorineDog 2020-11-03 11:45:48 +08:00 committed by yefu.chen
parent 9d212505d8
commit 9d2fa4e430
11 changed files with 303 additions and 16 deletions

View File

@ -1,7 +1,7 @@
# TODO
set(MILVUS_QUERY_SRCS
BinaryQuery.cpp
Parser.cpp
)
add_library(milvus_query ${MILVUS_QUERY_SRCS})
target_link_libraries(milvus_query libprotobuf)

View File

@ -1,19 +1,20 @@
#include <iostream>
#include "pb/message.pb.h"
#include "query/BooleanQuery.h"
#include "query/BinaryQuery.h"
#include "query/GeneralQuery.h"
#include "segcore/SegmentBase.h"
#include <random>
#include "Parser.h"
namespace milvus::wtf {
using google::protobuf::RepeatedPtrField;
using google::protobuf::RepeatedField;
#if 0
#if 0
void
CopyRowRecords(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorRowRecord>& grpc_records,
const google::protobuf::RepeatedField<google::protobuf::int64>& grpc_id_array,
engine::VectorsData& vectors) {
CopyRowRecords(const RepeatedPtrField<proto::service::PlaceholderValue>& grpc_records,
const RepeatedField<int64_t>& grpc_id_array,
engine::VectorsData& vectors
) {
// step 1: copy vector data
int64_t float_data_size = 0, binary_data_size = 0;
for (auto& record : grpc_records) {
float_data_size += record.float_data_size();
binary_data_size += record.binary_data().size();
@ -47,9 +48,11 @@ CopyRowRecords(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorRo
vectors.binary_data_.swap(binary_array);
vectors.id_array_.swap(id_array);
}
#endif
Status
ProcessLeafQueryJson(const milvus::json& query_json, query_old::BooleanQueryPtr& query, std::string& field_name) {
#if 0
if (query_json.contains("term")) {
auto leaf_query = std::make_shared<query_old::LeafQuery>();
auto term_query = std::make_shared<query_old::TermQuery>();
@ -59,7 +62,6 @@ ProcessLeafQueryJson(const milvus::json& query_json, query_old::BooleanQueryPtr&
term_query->json_obj = json_obj;
milvus::json::iterator json_it = json_obj.begin();
field_name = json_it.key();
leaf_query->term_query = term_query;
query->AddLeafQuery(leaf_query);
} else if (query_json.contains("range")) {
@ -84,6 +86,7 @@ ProcessLeafQueryJson(const milvus::json& query_json, query_old::BooleanQueryPtr&
} else {
return Status{SERVER_INVALID_ARGUMENT, "Leaf query get wrong key"};
}
#endif
return Status::OK();
}
@ -91,6 +94,7 @@ Status
ProcessBooleanQueryJson(const milvus::json& query_json,
query_old::BooleanQueryPtr& boolean_query,
query_old::QueryPtr& query_ptr) {
#if 0
if (query_json.empty()) {
return Status{SERVER_INVALID_ARGUMENT, "BoolQuery is null"};
}
@ -163,15 +167,16 @@ ProcessBooleanQueryJson(const milvus::json& query_json,
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
}
}
#endif
return Status::OK();
}
Status
test(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorParam>& vector_params,
DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorParam>& vector_params,
const std::string& dsl_string,
query_old::BooleanQueryPtr& boolean_query,
query_old::QueryPtr& query_ptr) {
#if 0
try {
milvus::json dsl_json = json::parse(dsl_string);
@ -231,5 +236,24 @@ test(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorParam>& vect
} catch (std::exception& e) {
return Status(SERVER_INVALID_DSL_PARAMETER, e.what());
}
#endif
return Status::OK();
}
#endif
query_old::QueryPtr tester(proto::service::Query* request) {
query_old::BooleanQueryPtr boolean_query = std::make_shared<query_old::BooleanQuery>();
query_old::QueryPtr query_ptr = std::make_shared<query_old::Query>();
#if 0
query_ptr->collection_id = request->collection_name();
auto status = DeserializeJsonToBoolQuery(request->placeholders(), request->dsl(), boolean_query, query_ptr);
status = query_old::ValidateBooleanQuery(boolean_query);
query_old::GeneralQueryPtr general_query = std::make_shared<query_old::GeneralQuery>();
query_old::GenBinaryQuery(boolean_query, general_query->bin);
query_ptr->root = general_query;
#endif
return query_ptr;
}
} // namespace milvus::wtf

View File

@ -0,0 +1,15 @@
#pragma once
//#include "pb/message.pb.h"
#include "pb/service_msg.pb.h"
#include "query/BooleanQuery.h"
#include "query/BinaryQuery.h"
#include "query/GeneralQuery.h"
namespace milvus::wtf {
query_old::QueryPtr
tester(proto::service::Query* query);
} // namespace milvus::wtf

View File

@ -0,0 +1,82 @@
#pragma once
#include <memory>
#include <vector>
#include <any>
#include <string>
#include <optional>
namespace milvus::query {
class ExprVisitor;
// Base of all Exprs
struct Expr {
public:
virtual ~Expr() = default;
virtual void
accept(ExprVisitor&) = 0;
};
using ExprPtr = std::unique_ptr<Expr>;
struct BinaryExpr : Expr {
ExprPtr left_;
ExprPtr right_;
public:
void
accept(ExprVisitor&) = 0;
};
struct UnaryExpr : Expr {
ExprPtr child_;
public:
void
accept(ExprVisitor&) = 0;
};
// TODO: not enabled in sprint 1
struct BoolUnaryExpr: UnaryExpr {
enum class OpType { LogicalNot };
OpType op_type_;
public:
void
accept(ExprVisitor&) override;
};
// TODO: not enabled in sprint 1
struct BoolBinaryExpr : BinaryExpr {
enum class OpType { LogicalAnd, LogicalOr, LogicalXor };
OpType op_type_;
public:
void
accept(ExprVisitor&) override;
};
// // TODO: not enabled in sprint 1
// struct ArthmeticBinaryOpExpr : BinaryExpr {
// enum class OpType { Add, Sub, Multiply, Divide };
// OpType op_type_;
// public:
// void
// accept(ExprVisitor&) override;
// };
using FieldId = int64_t;
struct TermExpr : Expr {
FieldId field_id_;
std::vector<std::any> terms_; //
public:
void
accept(ExprVisitor&) override;
};
struct RangeExpr : Expr {
FieldId field_id_;
enum class OpType { LessThan, LessEqual, GreaterThan, GreaterEqual, Equal, NotEqual };
std::vector<std::tuple<OpType, std::any>> conditions_;
public:
void
accept(ExprVisitor&) override;
};
} // namespace milvus::query

View File

@ -0,0 +1,55 @@
#pragma once
#include <memory>
#include <vector>
#include <any>
#include <string>
#include <optional>
#include "Predicate.h"
namespace milvus::query {
class QueryNodeVisitor;
enum class QueryNodeType {
kInvalid = 0,
kScan,
kANNS,
};
// Base of all Nodes
struct QueryNode {
QueryNodeType node_type;
public:
virtual ~QueryNode() = default;
virtual void
accept(QueryNodeVisitor&) = 0;
};
using QueryNodePtr = std::unique_ptr<QueryNode>;
struct VectorQueryNode : QueryNode {
std::optional<QueryNodePtr> child_;
int64_t num_queries_;
int64_t dim_;
FieldId field_id_;
public:
virtual void
accept(QueryNodeVisitor&) = 0;
};
struct FloatVectorANNS: VectorQueryNode {
std::shared_ptr<float> data;
std::string metric_type_; // TODO: use enum
public:
void
accept(QueryNodeVisitor&) override;
};
struct BinaryVectorANNS: VectorQueryNode {
std::shared_ptr<uint8_t> data;
std::string metric_type_; // TODO: use enum
public:
void
accept(QueryNodeVisitor&) override;
};
} // namespace milvus::query

View File

@ -0,0 +1 @@
#pragma once

View File

@ -5,6 +5,7 @@
#include "pb/message.pb.h"
#include <google/protobuf/text_format.h>
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
#include <cstring>
namespace milvus::segcore {
@ -132,7 +133,7 @@ Collection::parse() {
int dim = 16;
for (const auto& type_param : type_params) {
if (type_param.key() == "dim") {
// dim = type_param.value();
dim = strtoll(type_param.value().c_str(), nullptr, 10);
}
}
std::cout << "add Field, name :" << child.name() << ", datatype :" << child.data_type() << ", dim :" << dim

View File

@ -8,6 +8,7 @@ set(MILVUS_TEST_FILES
test_concurrent_vector.cpp
test_c_api.cpp
test_indexing.cpp
test_query.cpp
)
add_executable(all_tests
${MILVUS_TEST_FILES}

View File

@ -0,0 +1,64 @@
#!python
import random
import copy
def show_dsl(query_entities):
if not isinstance(query_entities, (dict,)):
raise ParamError("Invalid query format. 'query_entities' must be a dict")
duplicated_entities = copy.deepcopy(query_entities)
vector_placeholders = dict()
def extract_vectors_param(param, placeholders):
if not isinstance(param, (dict, list)):
return
if isinstance(param, dict):
if "vector" in param:
# TODO: Here may not replace ph
ph = "$" + str(len(placeholders))
for pk, pv in param["vector"].items():
if "query" not in pv:
raise ParamError("param vector must contain 'query'")
placeholders[ph] = pv["query"]
param["vector"][pk]["query"] = ph
return
else:
for _, v in param.items():
extract_vectors_param(v, placeholders)
if isinstance(param, list):
for item in param:
extract_vectors_param(item, placeholders)
extract_vectors_param(duplicated_entities, vector_placeholders)
print(duplicated_entities)
for tag, vectors in vector_placeholders.items():
print("tag: ", tag)
if __name__ == "__main__":
num = 5
dimension = 4
vectors = [[random.random() for _ in range(4)] for _ in range(num)]
dsl = {
"bool": {
"must":[
{
"term": {"A": [1, 2, 5]}
},
{
"range": {"B": {"GT": 1, "LT": 100}}
},
{
"vector": {
"Vec": {"topk": 10, "query": vectors[:1], "metric_type": "L2", "params": {"nprobe": 10}}
}
}
]
}
}
show_dsl(dsl)

View File

@ -1,5 +1,3 @@
#include <gtest/gtest.h>
TEST(TestNaive, Naive) {

View File

@ -0,0 +1,46 @@
#include <gtest/gtest.h>
#include "query/Parser.h"
#include "query/Predicate.h"
#include "query/QueryNode.h"
TEST(Query, Naive) {
SUCCEED();
using namespace milvus::wtf;
std::string dsl_string = R"(
{
"bool": {
"must": [
{
"term": {
"A": [
1,
2,
5
]
}
},
{
"range": {
"B": {
"GT": 1,
"LT": 100
}
}
},
{
"vector": {
"Vec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
]
}
})";
}