mirror of https://github.com/milvus-io/milvus.git
Add SyntaxTree of QueryNode and Expr
Signed-off-by: FluorineDog <guilin.gou@zilliz.com>pull/4973/head^2
parent
9d212505d8
commit
9d2fa4e430
|
@ -1,7 +1,7 @@
|
|||
# TODO
|
||||
# Source files compiled into the milvus_query library.
set(MILVUS_QUERY_SRCS
    BinaryQuery.cpp

    Parser.cpp
    )
# Build the library from the list above; no STATIC/SHARED keyword is given,
# so the library type follows CMake's default for this project.
add_library(milvus_query ${MILVUS_QUERY_SRCS})
# The query sources use protobuf-generated message types, so link libprotobuf.
target_link_libraries(milvus_query libprotobuf)
|
||||
|
|
|
@ -1,19 +1,20 @@
|
|||
#include <iostream>
|
||||
#include "pb/message.pb.h"
|
||||
#include "query/BooleanQuery.h"
|
||||
#include "query/BinaryQuery.h"
|
||||
#include "query/GeneralQuery.h"
|
||||
#include "segcore/SegmentBase.h"
|
||||
#include <random>
|
||||
#include "Parser.h"
|
||||
|
||||
namespace milvus::wtf {
|
||||
|
||||
using google::protobuf::RepeatedPtrField;
|
||||
using google::protobuf::RepeatedField;
|
||||
#if 0
|
||||
#if 0
|
||||
void
|
||||
CopyRowRecords(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorRowRecord>& grpc_records,
|
||||
const google::protobuf::RepeatedField<google::protobuf::int64>& grpc_id_array,
|
||||
engine::VectorsData& vectors) {
|
||||
CopyRowRecords(const RepeatedPtrField<proto::service::PlaceholderValue>& grpc_records,
|
||||
const RepeatedField<int64_t>& grpc_id_array,
|
||||
engine::VectorsData& vectors
|
||||
) {
|
||||
// step 1: copy vector data
|
||||
int64_t float_data_size = 0, binary_data_size = 0;
|
||||
|
||||
for (auto& record : grpc_records) {
|
||||
float_data_size += record.float_data_size();
|
||||
binary_data_size += record.binary_data().size();
|
||||
|
@ -47,9 +48,11 @@ CopyRowRecords(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorRo
|
|||
vectors.binary_data_.swap(binary_array);
|
||||
vectors.id_array_.swap(id_array);
|
||||
}
|
||||
#endif
|
||||
|
||||
Status
|
||||
ProcessLeafQueryJson(const milvus::json& query_json, query_old::BooleanQueryPtr& query, std::string& field_name) {
|
||||
#if 0
|
||||
if (query_json.contains("term")) {
|
||||
auto leaf_query = std::make_shared<query_old::LeafQuery>();
|
||||
auto term_query = std::make_shared<query_old::TermQuery>();
|
||||
|
@ -59,7 +62,6 @@ ProcessLeafQueryJson(const milvus::json& query_json, query_old::BooleanQueryPtr&
|
|||
term_query->json_obj = json_obj;
|
||||
milvus::json::iterator json_it = json_obj.begin();
|
||||
field_name = json_it.key();
|
||||
|
||||
leaf_query->term_query = term_query;
|
||||
query->AddLeafQuery(leaf_query);
|
||||
} else if (query_json.contains("range")) {
|
||||
|
@ -84,6 +86,7 @@ ProcessLeafQueryJson(const milvus::json& query_json, query_old::BooleanQueryPtr&
|
|||
} else {
|
||||
return Status{SERVER_INVALID_ARGUMENT, "Leaf query get wrong key"};
|
||||
}
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -91,6 +94,7 @@ Status
|
|||
ProcessBooleanQueryJson(const milvus::json& query_json,
|
||||
query_old::BooleanQueryPtr& boolean_query,
|
||||
query_old::QueryPtr& query_ptr) {
|
||||
#if 0
|
||||
if (query_json.empty()) {
|
||||
return Status{SERVER_INVALID_ARGUMENT, "BoolQuery is null"};
|
||||
}
|
||||
|
@ -163,15 +167,16 @@ ProcessBooleanQueryJson(const milvus::json& query_json,
|
|||
return Status{SERVER_INVALID_DSL_PARAMETER, msg};
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
test(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorParam>& vector_params,
|
||||
DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorParam>& vector_params,
|
||||
const std::string& dsl_string,
|
||||
query_old::BooleanQueryPtr& boolean_query,
|
||||
query_old::QueryPtr& query_ptr) {
|
||||
#if 0
|
||||
try {
|
||||
milvus::json dsl_json = json::parse(dsl_string);
|
||||
|
||||
|
@ -231,5 +236,24 @@ test(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorParam>& vect
|
|||
} catch (std::exception& e) {
|
||||
return Status(SERVER_INVALID_DSL_PARAMETER, e.what());
|
||||
}
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#endif
|
||||
// Builds and returns a default-constructed query_old::Query.
//
// The DSL-to-query translation (deserialize, validate, convert to a binary
// query tree) is compiled out with `#if 0`, so `request` is currently never
// read and the returned query has none of its fields populated by this
// function.
query_old::QueryPtr
tester(proto::service::Query* request) {
    // Only consumed by the disabled section below; still constructed so the
    // active code path stays exactly as before.
    query_old::BooleanQueryPtr bool_tree = std::make_shared<query_old::BooleanQuery>();
    query_old::QueryPtr result = std::make_shared<query_old::Query>();

#if 0
    result->collection_id = request->collection_name();
    auto status = DeserializeJsonToBoolQuery(request->placeholders(), request->dsl(), bool_tree, result);
    status = query_old::ValidateBooleanQuery(bool_tree);
    query_old::GeneralQueryPtr general_query = std::make_shared<query_old::GeneralQuery>();
    query_old::GenBinaryQuery(bool_tree, general_query->bin);
    result->root = general_query;
#endif

    return result;
}
|
||||
|
||||
|
||||
} // namespace milvus::wtf
|
|
@ -0,0 +1,15 @@
|
|||
#pragma once
|
||||
//#include "pb/message.pb.h"
|
||||
#include "pb/service_msg.pb.h"
|
||||
#include "query/BooleanQuery.h"
|
||||
#include "query/BinaryQuery.h"
|
||||
#include "query/GeneralQuery.h"
|
||||
|
||||
namespace milvus::wtf {

// Declaration only; the definition lives in a separate translation unit.
// Takes a pointer to a service-level Query message and returns a
// query_old::QueryPtr.
// NOTE(review): the signature does not express ownership or nullability of
// `query` — presumably the caller keeps ownership; confirm against the
// definition.
query_old::QueryPtr
tester(proto::service::Query* query);

} // namespace milvus::wtf
|
|
@ -0,0 +1,82 @@
|
|||
#pragma once
|
||||
#include <any>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <vector>
|
||||
namespace milvus::query {
class ExprVisitor;

// Abstract root of the predicate expression tree.
// Every node is visited through accept(); the virtual destructor makes it
// safe to destroy any node through an ExprPtr.
struct Expr {
    virtual ~Expr() = default;

    virtual void
    accept(ExprVisitor&) = 0;
};

// Owning handle to any expression node.
using ExprPtr = std::unique_ptr<Expr>;

// Abstract node owning two operand sub-expressions.
// accept() stays pure here; concrete binary nodes provide the override.
struct BinaryExpr : Expr {
    ExprPtr left_;
    ExprPtr right_;

    void
    accept(ExprVisitor&) override = 0;
};

// Abstract node owning a single operand sub-expression.
// accept() stays pure here; concrete unary nodes provide the override.
struct UnaryExpr : Expr {
    ExprPtr child_;

    void
    accept(ExprVisitor&) override = 0;
};

// TODO: not enabled in sprint 1
// Boolean operator applied to one child expression.
struct BoolUnaryExpr : UnaryExpr {
    enum class OpType { LogicalNot };
    // Initialized so a default-constructed node never holds an indeterminate
    // operator.
    OpType op_type_{OpType::LogicalNot};

    void
    accept(ExprVisitor&) override;
};

// TODO: not enabled in sprint 1
// Boolean operator applied to a left and a right child expression.
struct BoolBinaryExpr : BinaryExpr {
    enum class OpType { LogicalAnd, LogicalOr, LogicalXor };
    // Defaults to the first enumerator (same value as zero-initialization)
    // instead of being left indeterminate.
    OpType op_type_{OpType::LogicalAnd};

    void
    accept(ExprVisitor&) override;
};

// // TODO: not enabled in sprint 1
// struct ArithmeticBinaryOpExpr : BinaryExpr {
//     enum class OpType { Add, Sub, Multiply, Divide };
//     OpType op_type_;
//  public:
//     void
//     accept(ExprVisitor&) override;
// };

// Identifier of the field a leaf predicate refers to.
using FieldId = int64_t;

// Leaf predicate on one field, carrying a list of type-erased values.
// NOTE(review): presumably the predicate holds when the field equals any of
// `terms_` — confirm against the visitor implementation.
struct TermExpr : Expr {
    // Zero-initialized rather than indeterminate until the builder assigns it.
    FieldId field_id_{};
    // Element type is erased via std::any; the concrete type is not fixed here.
    std::vector<std::any> terms_;

    void
    accept(ExprVisitor&) override;
};

// Leaf predicate on one field, carrying (comparison operator, operand) pairs.
// NOTE(review): presumably all conditions must hold together (e.g. GT 1 and
// LT 100) — confirm against the visitor implementation.
struct RangeExpr : Expr {
    // Zero-initialized rather than indeterminate until the builder assigns it.
    FieldId field_id_{};
    enum class OpType { LessThan, LessEqual, GreaterThan, GreaterEqual, Equal, NotEqual };
    // Operand type is erased via std::any, mirroring TermExpr::terms_.
    std::vector<std::tuple<OpType, std::any>> conditions_;

    void
    accept(ExprVisitor&) override;
};

}  // namespace milvus::query
|
|
@ -0,0 +1,55 @@
|
|||
#pragma once
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <any>
|
||||
#include <string>
|
||||
#include <optional>
|
||||
#include "Predicate.h"
|
||||
namespace milvus::query {
|
||||
class QueryNodeVisitor;
|
||||
|
||||
enum class QueryNodeType {
|
||||
kInvalid = 0,
|
||||
kScan,
|
||||
kANNS,
|
||||
};
|
||||
|
||||
// Base of all Nodes
|
||||
struct QueryNode {
|
||||
QueryNodeType node_type;
|
||||
public:
|
||||
virtual ~QueryNode() = default;
|
||||
virtual void
|
||||
accept(QueryNodeVisitor&) = 0;
|
||||
};
|
||||
|
||||
using QueryNodePtr = std::unique_ptr<QueryNode>;
|
||||
|
||||
|
||||
struct VectorQueryNode : QueryNode {
|
||||
std::optional<QueryNodePtr> child_;
|
||||
int64_t num_queries_;
|
||||
int64_t dim_;
|
||||
FieldId field_id_;
|
||||
public:
|
||||
virtual void
|
||||
accept(QueryNodeVisitor&) = 0;
|
||||
};
|
||||
|
||||
struct FloatVectorANNS: VectorQueryNode {
|
||||
std::shared_ptr<float> data;
|
||||
std::string metric_type_; // TODO: use enum
|
||||
public:
|
||||
void
|
||||
accept(QueryNodeVisitor&) override;
|
||||
};
|
||||
|
||||
struct BinaryVectorANNS: VectorQueryNode {
|
||||
std::shared_ptr<uint8_t> data;
|
||||
std::string metric_type_; // TODO: use enum
|
||||
public:
|
||||
void
|
||||
accept(QueryNodeVisitor&) override;
|
||||
};
|
||||
|
||||
} // namespace milvus::query
|
|
@ -0,0 +1 @@
|
|||
#pragma once
|
|
@ -5,6 +5,7 @@
|
|||
#include "pb/message.pb.h"
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
|
||||
#include <cstring>
|
||||
|
||||
namespace milvus::segcore {
|
||||
|
||||
|
@ -132,7 +133,7 @@ Collection::parse() {
|
|||
int dim = 16;
|
||||
for (const auto& type_param : type_params) {
|
||||
if (type_param.key() == "dim") {
|
||||
// dim = type_param.value();
|
||||
dim = strtoll(type_param.value().c_str(), nullptr, 10);
|
||||
}
|
||||
}
|
||||
std::cout << "add Field, name :" << child.name() << ", datatype :" << child.data_type() << ", dim :" << dim
|
||||
|
|
|
@ -8,6 +8,7 @@ set(MILVUS_TEST_FILES
|
|||
test_concurrent_vector.cpp
|
||||
test_c_api.cpp
|
||||
test_indexing.cpp
|
||||
test_query.cpp
|
||||
)
|
||||
add_executable(all_tests
|
||||
${MILVUS_TEST_FILES}
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
#!python
|
||||
import random
|
||||
import copy
|
||||
|
||||
class ParamError(ValueError):
    """Raised by show_dsl when the query DSL is malformed."""


def show_dsl(query_entities):
    """Print a query DSL with its vector payloads replaced by placeholder tags.

    Works on a deep copy, so ``query_entities`` itself is never modified.
    Every ``"vector"`` entry found while walking nested dicts and lists has
    its ``"query"`` value swapped for a ``"$N"`` tag; the rewritten DSL is
    printed, followed by one ``tag:`` line per collected placeholder.

    :param query_entities: the query DSL; must be a dict.
    :raises ParamError: if ``query_entities`` is not a dict, or a vector
        entry has no ``"query"`` key.
    """
    if not isinstance(query_entities, dict):
        raise ParamError("Invalid query format. 'query_entities' must be a dict")

    duplicated_entities = copy.deepcopy(query_entities)
    vector_placeholders = dict()

    def extract_vectors_param(param, placeholders):
        # Scalars cannot contain a vector clause; nothing to do.
        if not isinstance(param, (dict, list)):
            return

        if isinstance(param, dict):
            if "vector" in param:
                # TODO: Here may not replace ph
                # The tag is computed once per "vector" clause, so every field
                # inside the same clause shares it.
                ph = "$" + str(len(placeholders))

                for pk, pv in param["vector"].items():
                    if "query" not in pv:
                        raise ParamError("param vector must contain 'query'")
                    placeholders[ph] = pv["query"]
                    param["vector"][pk]["query"] = ph

                # A vector clause is a leaf: do not descend any further.
                return
            else:
                for v in param.values():
                    extract_vectors_param(v, placeholders)

        if isinstance(param, list):
            for item in param:
                extract_vectors_param(item, placeholders)

    extract_vectors_param(duplicated_entities, vector_placeholders)
    print(duplicated_entities)

    for tag in vector_placeholders:
        print("tag: ", tag)
|
||||
|
||||
if __name__ == "__main__":
    # Demo driver: build a small boolean DSL with one term, one range and one
    # vector clause, then print it through show_dsl.
    num = 5
    dimension = 4
    # Each vector has `dimension` components (previously a hard-coded 4 that
    # silently ignored the variable above).
    vectors = [[random.random() for _ in range(dimension)] for _ in range(num)]
    dsl = {
        "bool": {
            "must": [
                {
                    "term": {"A": [1, 2, 5]}
                },
                {
                    "range": {"B": {"GT": 1, "LT": 100}}
                },
                {
                    "vector": {
                        "Vec": {"topk": 10, "query": vectors[:1], "metric_type": "L2", "params": {"nprobe": 10}}
                    }
                }
            ]
        }
    }
    show_dsl(dsl)
|
||||
|
|
@ -1,5 +1,3 @@
|
|||
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
TEST(TestNaive, Naive) {
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
#include <gtest/gtest.h>
|
||||
#include "query/Parser.h"
|
||||
#include "query/Predicate.h"
|
||||
#include "query/QueryNode.h"
|
||||
|
||||
// Placeholder test for the query headers: it records an explicit success and
// builds a sample DSL string, but does not yet hand that string to any parser
// or assert anything about it.
TEST(Query, Naive) {
    SUCCEED();
    using namespace milvus::wtf;
    // Sample boolean DSL with one term, one range and one vector clause; the
    // vector payload is given as the placeholder tag "$0".
    std::string dsl_string = R"(
{
    "bool": {
        "must": [
            {
                "term": {
                    "A": [
                        1,
                        2,
                        5
                    ]
                }
            },
            {
                "range": {
                    "B": {
                        "GT": 1,
                        "LT": 100
                    }
                }
            },
            {
                "vector": {
                    "Vec": {
                        "metric_type": "L2",
                        "params": {
                            "nprobe": 10
                        },
                        "query": "$0",
                        "topk": 10
                    }
                }
            }
        ]
    }
})";

}
|
Loading…
Reference in New Issue