Enable Query Executor without predicates

Signed-off-by: FluorineDog <guilin.gou@zilliz.com>
pull/4973/head^2
FluorineDog 2020-11-10 13:17:31 +08:00 committed by yefu.chen
parent 62a59d094b
commit 5e67e5eb43
17 changed files with 134 additions and 58 deletions

View File

@ -54,7 +54,7 @@ struct BoolBinaryExpr : BinaryExpr {
accept(ExprVisitor&) override;
};
using FieldId = int64_t;
using FieldId = std::string;
struct TermExpr : Expr {
FieldId field_id_;

View File

@ -26,11 +26,17 @@ struct PlanNode {
using PlanNodePtr = std::unique_ptr<PlanNode>;
struct VectorPlanNode : PlanNode {
std::optional<ExprPtr> predicate_;
// Parameters describing one vector-search request; carried by every VectorPlanNode.
//
// The integer fields have default member initializers so that a default-initialized
// QueryInfo never exposes indeterminate values. The struct is still an aggregate
// (C++14 and later), so brace initialization in declaration order keeps working.
struct QueryInfo {
    int64_t num_queries_ = 0;  // number of query vectors in the request
    int64_t dim_ = 0;          // number of elements in each query vector
    int64_t topK_ = 0;         // number of results kept per query vector
    FieldId field_id_;         // identifies the vector field to search
    std::string metric_type_;  // metric name, e.g. "L2"  TODO: use enum
};
struct VectorPlanNode : PlanNode {
std::optional<ExprPtr> predicate_;
QueryInfo query_info_;
public:
virtual void
@ -38,16 +44,12 @@ struct VectorPlanNode : PlanNode {
};
struct FloatVectorANNS : VectorPlanNode {
std::vector<float> data_;
std::string metric_type_; // TODO: use enum
public:
void
accept(PlanNodeVisitor&) override;
};
struct BinaryVectorANNS : VectorPlanNode {
std::vector<uint8_t> data_;
std::string metric_type_; // TODO: use enum
public:
void
accept(PlanNodeVisitor&) override;

View File

@ -12,5 +12,24 @@ class ExecPlanNodeVisitor : PlanNodeVisitor {
visit(BinaryVectorANNS& node) override;
public:
using RetType = segcore::QueryResult;
ExecPlanNodeVisitor(segcore::SegmentBase& segment, segcore::Timestamp timestamp, const float* src_data)
: segment_(segment), timestamp_(timestamp), src_data_(src_data) {
}
// using RetType = nlohmann::json;
RetType get_moved_result(){
assert(ret_.has_value());
auto ret = std::move(ret_).value();
ret_ = std::nullopt;
return ret;
}
private:
// std::optional<RetType> ret_;
segcore::SegmentBase& segment_;
segcore::Timestamp timestamp_;
const float* src_data_;
std::optional<RetType> ret_;
};
} // namespace milvus::query

View File

@ -1,4 +1,3 @@
#pragma once
// Generated File
// DO NOT EDIT
#include "query/Expr.h"

View File

@ -1,4 +1,3 @@
#pragma once
// Generated File
// DO NOT EDIT
#include "query/PlanNode.h"

View File

@ -1,9 +1,50 @@
#include "utils/Json.h"
#include "segcore/SegmentBase.h"
#include "query/generated/ExecPlanNodeVisitor.h"
#include "segcore/SegmentSmallIndex.h"
namespace milvus::query {
#if 1
namespace impl {
// THIS CONTAINS EXTRA BODY FOR VISITOR
// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/
// Visitor that executes a plan node against one segment and produces a
// segcore::QueryResult.
// This class body serves as input to the visitor code generator, so keep its
// textual shape stable when editing.
class ExecPlanNodeVisitor : PlanNodeVisitor {
public:
using RetType = segcore::QueryResult;
// Binds the visitor to the segment to search, the timestamp to query at, and
// the raw float query vectors.
// `segment` is stored by reference and `src_data` as a raw pointer; nothing
// here copies or owns them, so both presumably must outlive the visitor --
// confirm against callers.
ExecPlanNodeVisitor(segcore::SegmentBase& segment, segcore::Timestamp timestamp, const float* src_data)
: segment_(segment), timestamp_(timestamp), src_data_(src_data) {
}
// using RetType = nlohmann::json;
// Runs `node` through this visitor and returns the result it produced.
// Asserts that no result is pending before dispatch and that the visit
// stored one afterwards. The stored result is moved out and the slot is
// reset to empty, so the visitor can be used for another call.
RetType get_moved_result(PlanNode& node){
assert(!ret_.has_value());
node.accept(*this);
assert(ret_.has_value());
auto ret = std::move(ret_).value();
ret_ = std::nullopt;
return ret;
}
private:
// std::optional<RetType> ret_;
// Segment the plan is executed against (not owned).
segcore::SegmentBase& segment_;
// Timestamp forwarded to the segment's query routine.
segcore::Timestamp timestamp_;
// Raw query vector data forwarded to the segment's query routine (not owned).
const float* src_data_;
// Result slot filled by visit(); empty between calls.
std::optional<RetType> ret_;
};
} // namespace impl
#endif
// Executes a float-vector ANN search described by `node` against the bound segment
// and stores the outcome in `ret_`.
//
// Only SegmentSmallIndex is supported: the bound segment is downcast and the call
// fails via AssertInfo when it is any other segment type. `ret_` must be empty on
// entry.
void
ExecPlanNodeVisitor::visit(FloatVectorANNS& node) {
    // TODO: optimize here, remove the dynamic cast
    assert(!ret_.has_value());
    auto segment = dynamic_cast<segcore::SegmentSmallIndex*>(&segment_);
    AssertInfo(segment, "support SegmentSmallIndex Only");

    RetType ret;
    // NOTE(review): QueryBruteForceImpl returns a Status that is discarded here;
    // consider checking it once the error-handling policy for visitors is settled.
    segment->QueryBruteForceImpl(node.query_info_, src_data_, timestamp_, ret);

    // Hand the result over without copying it.
    ret_ = std::move(ret);
}
void

View File

@ -38,18 +38,19 @@ void
ShowPlanNodeVisitor::visit(FloatVectorANNS& node) {
// std::vector<float> data(node.data_.get(), node.data_.get() + node.num_queries_ * node.dim_);
assert(!ret_);
auto& info = node.query_info_;
Json json_body{
{"node_type", "FloatVectorANNS"}, //
{"metric_type", node.metric_type_}, //
{"dim", node.dim_}, //
{"field_id_", node.field_id_}, //
{"num_queries", node.num_queries_}, //
{"data", node.data_}, //
{"metric_type", info.metric_type_}, //
{"dim", info.dim_}, //
{"field_id_", info.field_id_}, //
{"num_queries", info.num_queries_}, //
{"topK", info.topK_}, //
};
if (node.predicate_.has_value()) {
AssertInfo(false, "unimplemented");
} else {
json_body["predicate"] = "nullopt";
// json_body["predicate"] = "nullopt";
}
ret_ = json_body;
}

View File

@ -49,7 +49,7 @@ class SegmentBase {
// query contains metadata of
virtual Status
Query(query::QueryPtr query, Timestamp timestamp, QueryResult& results) = 0;
QueryDeprecated(query::QueryPtr query, Timestamp timestamp, QueryResult& results) = 0;
// // THIS FUNCTION IS REMOVED
// virtual Status

View File

@ -458,7 +458,7 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que
}
Status
SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
SegmentNaive::QueryDeprecated(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
// TODO: enable delete
// TODO: enable index
// TODO: remove mock

View File

@ -45,7 +45,7 @@ class SegmentNaive : public SegmentBase {
// query contains metadata of
Status
Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override;
QueryDeprecated(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override;
// stop receive insert requests
// will move data to immutable vector or something

View File

@ -1,13 +1,16 @@
#include <segcore/SegmentSmallIndex.h>
#include <random>
#include <algorithm>
#include <numeric>
#include <thread>
#include <queue>
#include "segcore/SegmentNaive.h"
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
#include <knowhere/index/vector_index/VecIndexFactory.h>
#include <faiss/utils/distances.h>
#include "segcore/SegmentSmallIndex.h"
#include "query/PlanNode.h"
namespace milvus::segcore {
@ -251,8 +254,9 @@ merge_into(int64_t queries,
}
}
Status
SegmentSmallIndex::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) {
SegmentSmallIndex::QueryBruteForceImpl(const query::QueryInfo& info, const float* query_data, Timestamp timestamp, QueryResult& results) {
// step 1: binary search to find the barrier of the snapshot
auto ins_barrier = get_barrier(record_, timestamp);
auto del_barrier = get_barrier(deleted_record_, timestamp);
@ -263,21 +267,23 @@ SegmentSmallIndex::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp tim
#endif
// step 2.1: get meta
auto& field = schema_->operator[](query_info->field_name);
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
auto dim = field.get_dim();
auto topK = query_info->topK;
auto num_queries = query_info->num_queries;
auto total_count = topK * num_queries;
// TODO: optimize
// step 2.2: get which vector field to search
auto vecfield_offset_opt = schema_->get_offset(query_info->field_name);
auto vecfield_offset_opt = schema_->get_offset(info.field_id_);
Assert(vecfield_offset_opt.has_value());
auto vecfield_offset = vecfield_offset_opt.value();
Assert(vecfield_offset < record_.entity_vec_.size());
auto& field = schema_->operator[](vecfield_offset);
auto vec_ptr = std::static_pointer_cast<ConcurrentVector<float>>(record_.entity_vec_.at(vecfield_offset));
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
auto dim = field.get_dim();
auto topK = info.topK_;
auto num_queries = info.num_queries_;
auto total_count = topK * num_queries;
// TODO: optimize
// step 3: small indexing search
std::vector<int64_t> final_uids(total_count, -1);
std::vector<float> final_dis(total_count, std::numeric_limits<float>::max());
@ -308,7 +314,7 @@ SegmentSmallIndex::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp tim
auto src_data = vec_ptr->get_chunk(chunk_id).data();
auto nsize =
chunk_id != max_chunk - 1 ? DefaultElementPerChunk : ins_barrier - chunk_id * DefaultElementPerChunk;
faiss::knn_L2sqr(query_info->query_raw_data.data(), src_data, dim, num_queries, nsize, &buf);
faiss::knn_L2sqr(query_data, src_data, dim, num_queries, nsize, &buf);
merge_into(num_queries, topK, final_dis.data(), final_uids.data(), buf_dis.data(), buf_uids.data());
}
@ -327,7 +333,7 @@ SegmentSmallIndex::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp tim
}
Status
SegmentSmallIndex::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
SegmentSmallIndex::QueryDeprecated(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
// TODO: enable delete
// TODO: enable index
// TODO: remove mock
@ -345,9 +351,16 @@ SegmentSmallIndex::Query(query::QueryPtr query_info, Timestamp timestamp, QueryR
x = dis(e);
}
}
int64_t inferred_dim = query_info->query_raw_data.size() / query_info->num_queries;
// TODO
return QueryBruteForceImpl(query_info, timestamp, result);
query::QueryInfo info {
query_info->num_queries,
inferred_dim,
query_info->topK,
query_info->field_name,
"L2"
};
return QueryBruteForceImpl(info, query_info->query_raw_data.data(), timestamp, result);
}
Status

View File

@ -6,6 +6,7 @@
#include <shared_mutex>
#include <knowhere/index/vector_index/VecIndex.h>
#include <query/PlanNode.h>
#include "AckResponder.h"
#include "ConcurrentVector.h"
@ -70,7 +71,7 @@ class SegmentSmallIndex : public SegmentBase {
// query contains metadata of
Status
Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override;
QueryDeprecated(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override;
// stop receive insert requests
// will move data to immutable vector or something
@ -125,21 +126,18 @@ class SegmentSmallIndex : public SegmentBase {
// Builds a segment for the given schema: stores the schema pointer and constructs
// record_ and indexing_record_ from the schema it points to.
// NOTE(review): record_ and indexing_record_ dereference the member schema_, not the
// parameter, so this relies on schema_ being declared before them in the class
// (members are initialized in declaration order) -- confirm against the member list.
explicit SegmentSmallIndex(SchemaPtr schema) : schema_(schema), record_(*schema_), indexing_record_(*schema_) {
}
private:
// struct MutableRecord {
// ConcurrentVector<uint64_t> uids_;
// tbb::concurrent_vector<Timestamp> timestamps_;
// std::vector<tbb::concurrent_vector<float>> entity_vecs_;
//
// MutableRecord(int entity_size) : entity_vecs_(entity_size) {
// }
// };
public:
std::shared_ptr<DeletedRecord::TmpBitmap>
get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp, int64_t insert_barrier, bool force = false);
// Status
// QueryBruteForceImpl(query::QueryPtr query, Timestamp timestamp, QueryResult& results);
Status
QueryBruteForceImpl(query::QueryPtr query, Timestamp timestamp, QueryResult& results);
QueryBruteForceImpl(const query::QueryInfo& info,
const float* query_data,
Timestamp timestamp,
QueryResult& results);
template <typename Type>
knowhere::IndexPtr

View File

@ -134,7 +134,7 @@ Search(CSegmentBase c_segment,
query_ptr->query_raw_data.resize(num_of_query_raw_data);
memcpy(query_ptr->query_raw_data.data(), query_raw_data, num_of_query_raw_data * sizeof(float));
auto res = segment->Query(query_ptr, timestamp, query_result);
auto res = segment->QueryDeprecated(query_ptr, timestamp, query_result);
// result_ids and result_distances have been allocated memory in goLang,
// so we don't need to malloc here.

View File

@ -11,5 +11,5 @@ EasyAssertInfo(
bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info);
}
#define AssertInfo(expr, info) impl::EasyAssertInfo(bool(expr), #expr, __FILE__, __LINE__, (info))
#define AssertInfo(expr, info) milvus::impl::EasyAssertInfo(bool(expr), #expr, __FILE__, __LINE__, (info))
#define Assert(expr) AssertInfo((expr), "")

View File

@ -56,14 +56,17 @@ TEST(Query, ShowExecutor) {
int64_t num_queries = 100L;
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
auto raw_data = DataGen(schema, num_queries);
node->data_ = raw_data.get_col<float>(0);
node->metric_type_ = "L2";
node->num_queries_ = 10;
node->dim_ = 16;
auto& info = node->query_info_;
info.metric_type_ = "L2";
info.num_queries_ = 10;
info.dim_ = 16;
info.topK_ = 20;
info.field_id_ = "fakevec";
node->predicate_ = std::nullopt;
ShowPlanNodeVisitor show_visitor;
PlanNodePtr base(node.release());
auto res = show_visitor.call_child(*base);
res["data"] = "...collased...";
std::cout << res.dump(4);
auto dup = res;
dup["data"] = "...collased...";
std::cout << dup.dump(4);
}

View File

@ -36,15 +36,17 @@ def meta_gen(content):
if len(pack) == 1:
pack.append(None)
struct_name, base_name = pack
if not base_name:
root_base = struct_name
body_res = body_pattern.findall(body)
if len(body_res) != 1:
continue
eprint(struct_name)
eprint(body_res)
eprint(body)
assert(false)
struct_name, base_name = pack
if not base_name:
root_base = struct_name
visitor_name, state = body_res[0]
assert(visitor_name == root_base)
if state.strip() == 'override':

View File

@ -6,7 +6,6 @@ void
####
@@@@main
#pragma once
// Generated File
// DO NOT EDIT
#include "query/@@root_base@@.h"