Add queryNodeSegStatsMsg for msgStream

Signed-off-by: xige-16 <xi.ge@zilliz.com>
pull/4973/head^2
xige-16 2020-11-16 10:55:49 +08:00 committed by yefu.chen
parent 23d9ddb4a3
commit d59f6ac6ca
24 changed files with 558 additions and 74 deletions

View File

@ -21,6 +21,7 @@ ENDFOREACH(proto_file)
add_library(milvus_proto STATIC
${MILVUS_PROTO_SRCS}
)
message(${MILVUS_PROTO_SRCS})
target_link_libraries(milvus_proto
libprotobuf

View File

@ -5,8 +5,9 @@ set(MILVUS_QUERY_SRCS
generated/Expr.cpp
visitors/ShowPlanNodeVisitor.cpp
visitors/ExecPlanNodeVisitor.cpp
visitors/ShowExprVisitor.cpp
Parser.cpp
Plan.cpp
)
add_library(milvus_query ${MILVUS_QUERY_SRCS})
target_link_libraries(milvus_query libprotobuf)
target_link_libraries(milvus_query milvus_proto)

View File

@ -4,6 +4,8 @@
#include <any>
#include <string>
#include <optional>
#include "segcore/SegmentDefs.h"
namespace milvus::query {
class ExprVisitor;
@ -58,7 +60,13 @@ using FieldId = std::string;
struct TermExpr : Expr {
FieldId field_id_;
std::vector<std::any> terms_; //
segcore::DataType data_type_;
// std::vector<std::any> terms_;
protected:
// prevent accidential instantiation
TermExpr() = default;
public:
void
accept(ExprVisitor&) override;
@ -66,12 +74,14 @@ struct TermExpr : Expr {
struct RangeExpr : Expr {
FieldId field_id_;
enum class OpType { LessThan, LessEqual, GreaterThan, GreaterEqual, Equal, NotEqual };
std::vector<std::tuple<OpType, std::any>> conditions_;
segcore::DataType data_type_;
// std::vector<std::tuple<OpType, std::any>> conditions_;
protected:
// prevent accidential instantiation
RangeExpr() = default;
public:
void
accept(ExprVisitor&) override;
};
} // namespace milvus::query

View File

@ -0,0 +1,16 @@
#pragma once
#include "Expr.h"
namespace milvus::query {
template <typename T>
struct TermExprImpl : TermExpr {
std::vector<T> terms_;
};
template <typename T>
struct RangeExprImpl : RangeExpr {
enum class OpType { LessThan, LessEqual, GreaterThan, GreaterEqual, Equal, NotEqual };
std::vector<std::tuple<OpType, T>> conditions_;
};
} // namespace milvus::query

View File

@ -23,7 +23,7 @@ CreateVec(const std::string& field_name, const json& vec_info) {
static std::unique_ptr<Plan>
CreatePlanImplNaive(const std::string& dsl_str) {
auto plan = std::unique_ptr<Plan>();
auto plan = std::make_unique<Plan>();
auto dsl = nlohmann::json::parse(dsl_str);
nlohmann::json vec_pack;
@ -36,17 +36,19 @@ CreatePlanImplNaive(const std::string& dsl_str) {
auto key = iter.key();
auto& body = iter.value();
plan->plan_node_ = CreateVec(key, body);
return plan;
}
}
PanicInfo("Unsupported DSL: vector node not detected");
} else if (bool_dsl.contains("vector")) {
auto iter = bool_dsl["vector"].begin();
auto key = iter.key();
auto& body = iter.value();
plan->plan_node_ = CreateVec(key, body);
return plan;
} else {
PanicInfo("Unsupported DSL: vector node not detected");
}
return plan;
}
void
@ -55,6 +57,7 @@ CheckNull(const Json& json) {
}
class PlanParser {
public:
void
ParseBoolBody(const Json& dsl) {
CheckNull(dsl);
@ -74,6 +77,8 @@ class PlanParser {
}
PanicInfo("unimplemented");
}
private:
};
std::unique_ptr<Plan>
@ -83,11 +88,12 @@ CreatePlan(const std::string& dsl_str) {
}
std::unique_ptr<PlaceholderGroup>
ParsePlaceholderGroup(const char* placeholder_group_blob) {
ParsePlaceholderGroup(const std::string& blob) {
namespace ser = milvus::proto::service;
auto result = std::unique_ptr<PlaceholderGroup>();
auto result = std::make_unique<PlaceholderGroup>();
ser::PlaceholderGroup ph_group;
GOOGLE_PROTOBUF_PARSER_ASSERT(ph_group.ParseFromString(placeholder_group_blob));
auto ok = ph_group.ParseFromString(blob);
Assert(ok);
for (auto& info : ph_group.placeholders()) {
Placeholder element;
element.tag_ = info.tag();

View File

@ -13,7 +13,7 @@ std::unique_ptr<Plan>
CreatePlan(const std::string& dsl);
std::unique_ptr<PlaceholderGroup>
ParsePlaceholderGroup(const char* placeholder_group_blob);
ParsePlaceholderGroup(const std::string& placeholder_group_blob);
int64_t
GetNumOfQueries(const PlaceholderGroup*);
@ -24,3 +24,5 @@ int64_t
GetTopK(const Plan*);
} // namespace milvus::query
#include "PlanImpl.h"

View File

@ -28,7 +28,6 @@ struct PlanNode {
using PlanNodePtr = std::unique_ptr<PlanNode>;
struct QueryInfo {
int64_t num_queries_;
int64_t topK_;
FieldId field_id_;
std::string metric_type_; // TODO: use enum

View File

@ -1,7 +1,11 @@
#pragma once
// Generated File
// DO NOT EDIT
#include "utils/Json.h"
#include "query/PlanImpl.h"
#include "segcore/SegmentBase.h"
#include "PlanNodeVisitor.h"
namespace milvus::query {
class ExecPlanNodeVisitor : PlanNodeVisitor {
public:

View File

@ -1,7 +1,11 @@
#pragma once
// Generated File
// DO NOT EDIT
#include "query/Plan.h"
#include "utils/EasyAssert.h"
#include "utils/Json.h"
#include "ExprVisitor.h"
namespace milvus::query {
class ShowExprVisitor : ExprVisitor {
public:
@ -18,5 +22,35 @@ class ShowExprVisitor : ExprVisitor {
visit(RangeExpr& expr) override;
public:
using RetType = Json;
public:
RetType
call_child(Expr& expr) {
assert(!ret_.has_value());
expr.accept(*this);
assert(ret_.has_value());
auto ret = std::move(ret_);
ret_ = std::nullopt;
return std::move(ret.value());
}
Json
combine(Json&& extra, UnaryExpr& expr) {
auto result = std::move(extra);
result["child"] = call_child(*expr.child_);
return result;
}
Json
combine(Json&& extra, BinaryExpr& expr) {
auto result = std::move(extra);
result["left_child"] = call_child(*expr.left_);
result["right_child"] = call_child(*expr.right_);
return result;
}
private:
std::optional<RetType> ret_;
};
} // namespace milvus::query

View File

@ -1,7 +1,12 @@
#pragma once
// Generated File
// DO NOT EDIT
#include "utils/EasyAssert.h"
#include "utils/Json.h"
#include <optional>
#include "PlanNodeVisitor.h"
namespace milvus::query {
class ShowPlanNodeVisitor : PlanNodeVisitor {
public:
@ -21,6 +26,7 @@ class ShowPlanNodeVisitor : PlanNodeVisitor {
node.accept(*this);
assert(ret_.has_value());
auto ret = std::move(ret_);
ret_ = std::nullopt;
return std::move(ret.value());
}

View File

@ -48,8 +48,10 @@ ExecPlanNodeVisitor::visit(FloatVectorANNS& node) {
auto segment = dynamic_cast<segcore::SegmentSmallIndex*>(&segment_);
AssertInfo(segment, "support SegmentSmallIndex Only");
RetType ret;
auto src_data = placeholder_group_.at(0).get_blob<float>();
segment->QueryBruteForceImpl(node.query_info_, src_data, timestamp_, ret);
auto& ph = placeholder_group_.at(0);
auto src_data = ph.get_blob<float>();
auto num_queries = ph.num_of_queries_;
segment->QueryBruteForceImpl(node.query_info_, src_data, num_queries, timestamp_, ret);
ret_ = ret;
}

View File

@ -0,0 +1,173 @@
#include "query/Plan.h"
#include "utils/EasyAssert.h"
#include "utils/Json.h"
#include "query/generated/ShowExprVisitor.h"
#include "query/ExprImpl.h"
namespace milvus::query {
using Json = nlohmann::json;
#if 1
// THIS CONTAINS EXTRA BODY FOR VISITOR
// WILL BE USED BY GENERATOR
namespace impl {
class ShowExprNodeVisitor : ExprVisitor {
public:
using RetType = Json;
public:
RetType
call_child(Expr& expr) {
assert(!ret_.has_value());
expr.accept(*this);
assert(ret_.has_value());
auto ret = std::move(ret_);
ret_ = std::nullopt;
return std::move(ret.value());
}
Json
combine(Json&& extra, UnaryExpr& expr) {
auto result = std::move(extra);
result["child"] = call_child(*expr.child_);
return result;
}
Json
combine(Json&& extra, BinaryExpr& expr) {
auto result = std::move(extra);
result["left_child"] = call_child(*expr.left_);
result["right_child"] = call_child(*expr.right_);
return result;
}
private:
std::optional<RetType> ret_;
};
} // namespace impl
#endif
void
ShowExprVisitor::visit(BoolUnaryExpr& expr) {
Assert(!ret_.has_value());
using OpType = BoolUnaryExpr::OpType;
// TODO: use magic_enum if available
Assert(expr.op_type_ == OpType::LogicalNot);
auto op_name = "LogicalNot";
Json extra{
{"expr_type", "BoolUnary"},
{"op", op_name},
};
ret_ = this->combine(std::move(extra), expr);
}
void
ShowExprVisitor::visit(BoolBinaryExpr& expr) {
Assert(!ret_.has_value());
using OpType = BoolBinaryExpr::OpType;
// TODO: use magic_enum if available
auto op_name = [](OpType op) {
switch (op) {
case OpType::LogicalAnd:
return "LogicalAnd";
case OpType::LogicalOr:
return "LogicalOr";
case OpType::LogicalXor:
return "LogicalXor";
default:
PanicInfo("unsupported op");
}
}(expr.op_type_);
Json extra{
{"expr_type", "BoolBinary"},
{"op", op_name},
};
ret_ = this->combine(std::move(extra), expr);
}
template <typename T>
static Json
TermExtract(const TermExpr& expr_raw) {
auto expr = dynamic_cast<const TermExprImpl<T>*>(&expr_raw);
Assert(expr);
return Json{expr->terms_};
}
void
ShowExprVisitor::visit(TermExpr& expr) {
Assert(!ret_.has_value());
Assert(segcore::field_is_vector(expr.data_type_) == false);
using segcore::DataType;
auto terms = [&] {
switch (expr.data_type_) {
case DataType::INT8:
return TermExtract<int8_t>(expr);
case DataType::INT16:
return TermExtract<int16_t>(expr);
case DataType::INT32:
return TermExtract<int32_t>(expr);
case DataType::INT64:
return TermExtract<int64_t>(expr);
case DataType::DOUBLE:
return TermExtract<double>(expr);
case DataType::FLOAT:
return TermExtract<float>(expr);
case DataType::BOOL:
return TermExtract<bool>(expr);
default:
PanicInfo("unsupported type");
}
}();
Json res{{"expr_type", "Term"},
{"field_id", expr.field_id_},
{"data_type", segcore::datatype_name(expr.data_type_)},
{"terms", std::move(terms)}};
ret_ = res;
}
template <typename T>
static Json
CondtionExtract(const RangeExpr& expr_raw) {
auto expr = dynamic_cast<const TermExprImpl<T>*>(&expr_raw);
Assert(expr);
return Json{expr->terms_};
}
void
ShowExprVisitor::visit(RangeExpr& expr) {
Assert(!ret_.has_value());
Assert(segcore::field_is_vector(expr.data_type_) == false);
using segcore::DataType;
auto conditions = [&] {
switch (expr.data_type_) {
case DataType::BOOL:
return CondtionExtract<bool>(expr);
case DataType::INT8:
return CondtionExtract<int8_t>(expr);
case DataType::INT16:
return CondtionExtract<int16_t>(expr);
case DataType::INT32:
return CondtionExtract<int32_t>(expr);
case DataType::INT64:
return CondtionExtract<int64_t>(expr);
case DataType::DOUBLE:
return CondtionExtract<double>(expr);
case DataType::FLOAT:
return CondtionExtract<float>(expr);
default:
PanicInfo("unsupported type");
}
}();
Json res{{"expr_type", "Range"},
{"field_id", expr.field_id_},
{"data_type", segcore::datatype_name(expr.data_type_)},
{"conditions", std::move(conditions)}};
}
} // namespace milvus::query

View File

@ -19,6 +19,7 @@ class ShowPlanNodeVisitorImpl : PlanNodeVisitor {
node.accept(*this);
assert(ret_.has_value());
auto ret = std::move(ret_);
ret_ = std::nullopt;
return std::move(ret.value());
}
@ -40,11 +41,9 @@ ShowPlanNodeVisitor::visit(FloatVectorANNS& node) {
assert(!ret_);
auto& info = node.query_info_;
Json json_body{
{"node_type", "FloatVectorANNS"}, //
{"metric_type", info.metric_type_}, //
// {"dim", info.dim_}, //
{"node_type", "FloatVectorANNS"}, //
{"metric_type", info.metric_type_}, //
{"field_id_", info.field_id_}, //
{"num_queries", info.num_queries_}, //
{"topK", info.topK_}, //
{"search_params", info.search_params_}, //
{"placeholder_tag", node.placeholder_tag_}, //
@ -52,7 +51,7 @@ ShowPlanNodeVisitor::visit(FloatVectorANNS& node) {
if (node.predicate_.has_value()) {
PanicInfo("unimplemented");
} else {
json_body["predicate"] = "nullopt";
json_body["predicate"] = "None";
}
ret_ = json_body;
}

View File

@ -50,6 +50,36 @@ field_sizeof(DataType data_type, int dim = 1) {
}
}
// TODO: use magic_enum when available
inline std::string
datatype_name(DataType data_type) {
switch (data_type) {
case DataType::BOOL:
return "bool";
case DataType::DOUBLE:
return "double";
case DataType::FLOAT:
return "float";
case DataType::INT8:
return "int8_t";
case DataType::INT16:
return "int16_t";
case DataType::INT32:
return "int32_t";
case DataType::INT64:
return "int64_t";
case DataType::VECTOR_FLOAT:
return "vector_float";
case DataType::VECTOR_BINARY: {
return "vector_binary";
}
default: {
auto err_msg = "Unsupported DataType(" + std::to_string((int)data_type) + ")";
PanicInfo(err_msg);
}
}
}
inline bool
field_is_vector(DataType datatype) {
return datatype == DataType::VECTOR_BINARY || datatype == DataType::VECTOR_FLOAT;

View File

@ -223,6 +223,7 @@ get_barrier(const RecordType& record, Timestamp timestamp) {
Status
SegmentSmallIndex::QueryBruteForceImpl(const query::QueryInfo& info,
const float* query_data,
int64_t num_queries,
Timestamp timestamp,
QueryResult& results) {
// step 1: binary search to find the barrier of the snapshot
@ -247,7 +248,6 @@ SegmentSmallIndex::QueryBruteForceImpl(const query::QueryInfo& info,
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
auto dim = field.get_dim();
auto topK = info.topK_;
auto num_queries = info.num_queries_;
auto total_count = topK * num_queries;
// TODO: optimize
@ -321,7 +321,6 @@ SegmentSmallIndex::QueryDeprecated(query::QueryDeprecatedPtr query_info, Timesta
int64_t inferred_dim = query_info->query_raw_data.size() / query_info->num_queries;
// TODO
query::QueryInfo info{
query_info->num_queries,
query_info->topK,
query_info->field_name,
"L2",
@ -329,7 +328,8 @@ SegmentSmallIndex::QueryDeprecated(query::QueryDeprecatedPtr query_info, Timesta
{"nprobe", 10},
},
};
return QueryBruteForceImpl(info, query_info->query_raw_data.data(), timestamp, result);
auto num_queries = query_info->num_queries;
return QueryBruteForceImpl(info, query_info->query_raw_data.data(), num_queries, timestamp, result);
}
Status
@ -453,14 +453,15 @@ SegmentSmallIndex::GetMemoryUsageInBytes() {
}
Status
SegmentSmallIndex::Search(const query::Plan* Plan,
SegmentSmallIndex::Search(const query::Plan* plan,
const query::PlaceholderGroup** placeholder_groups,
const Timestamp* timestamps,
int num_groups,
QueryResult& results) {
Assert(num_groups == 1);
query::ExecPlanNodeVisitor visitor(*this, timestamps[0], *placeholder_groups[0]);
PanicInfo("unimplemented");
results = visitor.get_moved_result(*plan->plan_node_);
return Status::OK();
}
} // namespace milvus::segcore

View File

@ -143,6 +143,7 @@ class SegmentSmallIndex : public SegmentBase {
Status
QueryBruteForceImpl(const query::QueryInfo& info,
const float* query_data,
int64_t num_queries,
Timestamp timestamp,
QueryResult& results);

View File

@ -9,6 +9,7 @@ set(MILVUS_TEST_FILES
test_c_api.cpp
test_indexing.cpp
test_query.cpp
test_expr.cpp
)
add_executable(all_tests
${MILVUS_TEST_FILES}

View File

@ -0,0 +1,70 @@
#include <gtest/gtest.h>
#include "query/Parser.h"
#include "query/Expr.h"
#include "query/PlanNode.h"
#include "query/generated/ExprVisitor.h"
#include "query/generated/PlanNodeVisitor.h"
#include "test_utils/DataGen.h"
#include "query/generated/ShowPlanNodeVisitor.h"
TEST(Expr, Naive) {
SUCCEED();
using namespace milvus::wtf;
std::string dsl_string = R"(
{
"bool": {
"must": [
{
"term": {
"A": [
1,
2,
5
]
}
},
{
"range": {
"B": {
"GT": 1,
"LT": 100
}
}
},
{
"vector": {
"Vec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
]
}
})";
}
TEST(Expr, ShowExecutor) {
using namespace milvus::query;
using namespace milvus::segcore;
auto node = std::make_unique<FloatVectorANNS>();
auto schema = std::make_shared<Schema>();
int64_t num_queries = 100L;
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
auto raw_data = DataGen(schema, num_queries);
auto& info = node->query_info_;
info.metric_type_ = "L2";
info.topK_ = 20;
info.field_id_ = "fakevec";
node->predicate_ = std::nullopt;
ShowPlanNodeVisitor show_visitor;
PlanNodePtr base(node.release());
auto res = show_visitor.call_child(*base);
auto dup = res;
dup["data"] = "...collased...";
std::cout << dup.dump(4);
}

View File

@ -6,6 +6,8 @@
#include "query/generated/PlanNodeVisitor.h"
#include "test_utils/DataGen.h"
#include "query/generated/ShowPlanNodeVisitor.h"
#include "query/generated/ExecPlanNodeVisitor.h"
#include "query/PlanImpl.h"
TEST(Query, Naive) {
SUCCEED();
@ -58,7 +60,6 @@ TEST(Query, ShowExecutor) {
auto raw_data = DataGen(schema, num_queries);
auto& info = node->query_info_;
info.metric_type_ = "L2";
info.num_queries_ = 10;
info.topK_ = 20;
info.field_id_ = "fakevec";
node->predicate_ = std::nullopt;
@ -66,6 +67,87 @@ TEST(Query, ShowExecutor) {
PlanNodePtr base(node.release());
auto res = show_visitor.call_child(*base);
auto dup = res;
dup["data"] = "...collased...";
std::cout << dup.dump(4);
}
}
TEST(Query, DSL) {
using namespace milvus::query;
using namespace milvus::segcore;
ShowPlanNodeVisitor shower;
std::string dsl_string = R"(
{
"bool": {
"must": [
{
"vector": {
"Vec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
]
}
})";
auto plan = CreatePlan(dsl_string);
auto res = shower.call_child(*plan->plan_node_);
std::cout << res.dump(4) << std::endl;
std::string dsl_string2 = R"(
{
"bool": {
"vector": {
"Vec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
})";
auto plan2 = CreatePlan(dsl_string2);
auto res2 = shower.call_child(*plan2->plan_node_);
std::cout << res2.dump(4) << std::endl;
ASSERT_EQ(res, res2);
}
TEST(Query, ParsePlaceholderGroup) {
using namespace milvus::query;
using namespace milvus::segcore;
namespace ser = milvus::proto::service;
int num_queries = 10;
int dim = 16;
std::default_random_engine e;
std::normal_distribution<double> dis(0, 1);
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::VECTOR_FLOAT);
for(int i = 0; i < num_queries; ++i) {
std::vector<float> vec;
for(int d = 0; d < dim; ++d) {
vec.push_back(dis(e));
}
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
value->add_values(vec.data(), vec.size() * sizeof(float));
}
auto blob = raw_group.SerializeAsString();
//ser::PlaceholderGroup new_group;
//new_group.ParseFromString()
auto fuck = ParsePlaceholderGroup(blob);
int x = 1+1;
}
TEST(Query, Exec) {
using namespace milvus::query;
using namespace milvus::segcore;
}

View File

@ -106,6 +106,16 @@ func getTsMsg(msgType MsgType, reqID UniqueID, hashValue int32) *TsMsg {
TimeTickMsg: timeTickResult,
}
tsMsg = timeTickMsg
case internalPb.MsgType_kQueryNodeSegStats:
queryNodeSegStats := internalPb.QueryNodeSegStats{
MsgType: internalPb.MsgType_kQueryNodeSegStats,
PeerID: reqID,
}
queryNodeSegStatsMsg := &QueryNodeSegStatsMsg{
BaseMsg: baseMsg,
QueryNodeSegStats: queryNodeSegStats,
}
tsMsg = queryNodeSegStatsMsg
}
return &tsMsg
}
@ -452,24 +462,11 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) {
consumerChannels := []string{"insert1", "insert2"}
consumerSubName := "subInsert"
baseMsg := BaseMsg{
BeginTimestamp: 0,
EndTimestamp: 0,
HashValues: []int32{1},
}
timeTickRequest := internalPb.TimeTickMsg{
MsgType: internalPb.MsgType_kTimeTick,
PeerID: int64(1),
Timestamp: uint64(1),
}
timeTick := &TimeTickMsg{
BaseMsg: baseMsg,
TimeTickMsg: timeTickRequest,
}
var tsMsg TsMsg = timeTick
msgPack := MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, &tsMsg)
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kTimeTick, 1, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kSearch, 2, 2))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kSearchResult, 3, 3))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kQueryNodeSegStats, 4, 4))
inputStream := NewPulsarMsgStream(context.Background(), 100)
inputStream.SetPulsarCient(pulsarAddress)

View File

@ -57,24 +57,24 @@ func (it *InsertMsg) Marshal(input *TsMsg) ([]byte, error) {
func (it *InsertMsg) Unmarshal(input []byte) (*TsMsg, error) {
insertRequest := internalPb.InsertRequest{}
err := proto.Unmarshal(input, &insertRequest)
insertMsg := &InsertMsg{InsertRequest: insertRequest}
if err != nil {
return nil, err
}
insertMsg := &InsertMsg{InsertRequest: insertRequest}
for _, timestamp := range insertMsg.Timestamps {
it.BeginTimestamp = timestamp
it.EndTimestamp = timestamp
insertMsg.BeginTimestamp = timestamp
insertMsg.EndTimestamp = timestamp
break
}
for _, timestamp := range insertMsg.Timestamps {
if timestamp > it.EndTimestamp {
it.EndTimestamp = timestamp
if timestamp > insertMsg.EndTimestamp {
insertMsg.EndTimestamp = timestamp
}
if timestamp < it.BeginTimestamp {
it.BeginTimestamp = timestamp
if timestamp < insertMsg.BeginTimestamp {
insertMsg.BeginTimestamp = timestamp
}
}
var tsMsg TsMsg = insertMsg
return &tsMsg, nil
}
@ -102,24 +102,24 @@ func (dt *DeleteMsg) Marshal(input *TsMsg) ([]byte, error) {
func (dt *DeleteMsg) Unmarshal(input []byte) (*TsMsg, error) {
deleteRequest := internalPb.DeleteRequest{}
err := proto.Unmarshal(input, &deleteRequest)
deleteMsg := &DeleteMsg{DeleteRequest: deleteRequest}
if err != nil {
return nil, err
}
deleteMsg := &DeleteMsg{DeleteRequest: deleteRequest}
for _, timestamp := range deleteMsg.Timestamps {
dt.BeginTimestamp = timestamp
dt.EndTimestamp = timestamp
deleteMsg.BeginTimestamp = timestamp
deleteMsg.EndTimestamp = timestamp
break
}
for _, timestamp := range deleteMsg.Timestamps {
if timestamp > dt.EndTimestamp {
dt.EndTimestamp = timestamp
if timestamp > deleteMsg.EndTimestamp {
deleteMsg.EndTimestamp = timestamp
}
if timestamp < dt.BeginTimestamp {
dt.BeginTimestamp = timestamp
if timestamp < deleteMsg.BeginTimestamp {
deleteMsg.BeginTimestamp = timestamp
}
}
var tsMsg TsMsg = deleteMsg
return &tsMsg, nil
}
@ -147,13 +147,13 @@ func (st *SearchMsg) Marshal(input *TsMsg) ([]byte, error) {
func (st *SearchMsg) Unmarshal(input []byte) (*TsMsg, error) {
searchRequest := internalPb.SearchRequest{}
err := proto.Unmarshal(input, &searchRequest)
searchMsg := &SearchMsg{SearchRequest: searchRequest}
if err != nil {
return nil, err
}
st.BeginTimestamp = searchMsg.Timestamp
st.EndTimestamp = searchMsg.Timestamp
searchMsg := &SearchMsg{SearchRequest: searchRequest}
searchMsg.BeginTimestamp = searchMsg.Timestamp
searchMsg.EndTimestamp = searchMsg.Timestamp
var tsMsg TsMsg = searchMsg
return &tsMsg, nil
}
@ -181,13 +181,13 @@ func (srt *SearchResultMsg) Marshal(input *TsMsg) ([]byte, error) {
func (srt *SearchResultMsg) Unmarshal(input []byte) (*TsMsg, error) {
searchResultRequest := internalPb.SearchResult{}
err := proto.Unmarshal(input, &searchResultRequest)
searchResultMsg := &SearchResultMsg{SearchResult: searchResultRequest}
if err != nil {
return nil, err
}
srt.BeginTimestamp = searchResultMsg.Timestamp
srt.EndTimestamp = searchResultMsg.Timestamp
searchResultMsg := &SearchResultMsg{SearchResult: searchResultRequest}
searchResultMsg.BeginTimestamp = searchResultMsg.Timestamp
searchResultMsg.EndTimestamp = searchResultMsg.Timestamp
var tsMsg TsMsg = searchResultMsg
return &tsMsg, nil
}
@ -215,17 +215,49 @@ func (tst *TimeTickMsg) Marshal(input *TsMsg) ([]byte, error) {
func (tst *TimeTickMsg) Unmarshal(input []byte) (*TsMsg, error) {
timeTickMsg := internalPb.TimeTickMsg{}
err := proto.Unmarshal(input, &timeTickMsg)
timeTick := &TimeTickMsg{TimeTickMsg: timeTickMsg}
if err != nil {
return nil, err
}
tst.BeginTimestamp = timeTick.Timestamp
tst.EndTimestamp = timeTick.Timestamp
timeTick := &TimeTickMsg{TimeTickMsg: timeTickMsg}
timeTick.BeginTimestamp = timeTick.Timestamp
timeTick.EndTimestamp = timeTick.Timestamp
var tsMsg TsMsg = timeTick
return &tsMsg, nil
}
/////////////////////////////////////////QueryNodeSegStats//////////////////////////////////////////
type QueryNodeSegStatsMsg struct {
BaseMsg
internalPb.QueryNodeSegStats
}
func (qs *QueryNodeSegStatsMsg) Type() MsgType {
return qs.MsgType
}
func (qs *QueryNodeSegStatsMsg) Marshal(input *TsMsg) ([]byte, error) {
queryNodeSegStatsTask := (*input).(*QueryNodeSegStatsMsg)
queryNodeSegStats := &queryNodeSegStatsTask.QueryNodeSegStats
mb, err := proto.Marshal(queryNodeSegStats)
if err != nil {
return nil, err
}
return mb, nil
}
func (qs *QueryNodeSegStatsMsg) Unmarshal(input []byte) (*TsMsg, error) {
queryNodeSegStats := internalPb.QueryNodeSegStats{}
err := proto.Unmarshal(input, &queryNodeSegStats)
if err != nil {
return nil, err
}
queryNodeSegStatsMsg := &QueryNodeSegStatsMsg{QueryNodeSegStats: queryNodeSegStats}
var tsMsg TsMsg = queryNodeSegStatsMsg
return &tsMsg, nil
}
///////////////////////////////////////////Key2Seg//////////////////////////////////////////
//type Key2SegMsg struct {
// BaseMsg

View File

@ -30,12 +30,14 @@ func (dispatcher *UnmarshalDispatcher) addDefaultMsgTemplates() {
searchMsg := SearchMsg{}
searchResultMsg := SearchResultMsg{}
timeTickMsg := TimeTickMsg{}
queryNodeSegStatsMsg := QueryNodeSegStatsMsg{}
dispatcher.tempMap = make(map[internalPb.MsgType]UnmarshalFunc)
dispatcher.tempMap[internalPb.MsgType_kInsert] = insertMsg.Unmarshal
dispatcher.tempMap[internalPb.MsgType_kDelete] = deleteMsg.Unmarshal
dispatcher.tempMap[internalPb.MsgType_kSearch] = searchMsg.Unmarshal
dispatcher.tempMap[internalPb.MsgType_kSearchResult] = searchResultMsg.Unmarshal
dispatcher.tempMap[internalPb.MsgType_kTimeTick] = timeTickMsg.Unmarshal
dispatcher.tempMap[internalPb.MsgType_kQueryNodeSegStats] = queryNodeSegStatsMsg.Unmarshal
}
func NewUnmarshalDispatcher() *UnmarshalDispatcher {

View File

@ -14,7 +14,7 @@ def gen_file(rootfile, template, output, **kwargs):
def extract_extra_body(visitor_info, query_path):
pattern = re.compile("class(.*){\n((.|\n)*?)\n};", re.MULTILINE)
pattern = re.compile(r"class(.*){\n((.|\n)*?)\n};", re.MULTILINE)
for node, visitors in visitor_info.items():
for visitor in visitors:
@ -22,11 +22,24 @@ def extract_extra_body(visitor_info, query_path):
vis_file = query_path + "visitors/" + vis_name + ".cpp"
body = ' public:'
inc_pattern_str = r'^(#include(.|\n)*)\n#include "query/generated/{}.h"'.format(vis_name)
inc_pattern = re.compile(inc_pattern_str, re.MULTILINE)
if os.path.exists(vis_file):
infos = pattern.findall(readfile(vis_file))
content = readfile(vis_file)
infos = pattern.findall(content)
assert len(infos) <= 1
if len(infos) == 1:
name, body, _ = infos[0]
extra_inc_infos = inc_pattern.findall(content)
assert(len(extra_inc_infos) <= 1)
print(extra_inc_infos)
if len(extra_inc_infos) == 1:
extra_inc_body, _ = extra_inc_infos[0]
visitor["ctor_and_member"] = body
visitor["extra_inc"] = extra_inc_body
if __name__ == "__main__":
query_path = "../../internal/core/src/query/"

View File

@ -9,7 +9,9 @@
#pragma once
// Generated File
// DO NOT EDIT
@@extra_inc@@
#include "@@base_visitor@@.h"
namespace @@namespace@@ {
class @@visitor_name@@ : @@base_visitor@@ {
public: