mirror of https://github.com/milvus-io/milvus.git
Enable Query Executor without predicates
Signed-off-by: FluorineDog <guilin.gou@zilliz.com>pull/4973/head^2
parent
62a59d094b
commit
5e67e5eb43
|
@ -54,7 +54,7 @@ struct BoolBinaryExpr : BinaryExpr {
|
|||
accept(ExprVisitor&) override;
|
||||
};
|
||||
|
||||
using FieldId = int64_t;
|
||||
using FieldId = std::string;
|
||||
|
||||
struct TermExpr : Expr {
|
||||
FieldId field_id_;
|
||||
|
|
|
@ -26,11 +26,17 @@ struct PlanNode {
|
|||
|
||||
using PlanNodePtr = std::unique_ptr<PlanNode>;
|
||||
|
||||
struct VectorPlanNode : PlanNode {
|
||||
std::optional<ExprPtr> predicate_;
|
||||
struct QueryInfo{
|
||||
int64_t num_queries_;
|
||||
int64_t dim_;
|
||||
int64_t topK_;
|
||||
FieldId field_id_;
|
||||
std::string metric_type_; // TODO: use enum
|
||||
};
|
||||
|
||||
struct VectorPlanNode : PlanNode {
|
||||
std::optional<ExprPtr> predicate_;
|
||||
QueryInfo query_info_;
|
||||
|
||||
public:
|
||||
virtual void
|
||||
|
@ -38,16 +44,12 @@ struct VectorPlanNode : PlanNode {
|
|||
};
|
||||
|
||||
struct FloatVectorANNS : VectorPlanNode {
|
||||
std::vector<float> data_;
|
||||
std::string metric_type_; // TODO: use enum
|
||||
public:
|
||||
void
|
||||
accept(PlanNodeVisitor&) override;
|
||||
};
|
||||
|
||||
struct BinaryVectorANNS : VectorPlanNode {
|
||||
std::vector<uint8_t> data_;
|
||||
std::string metric_type_; // TODO: use enum
|
||||
public:
|
||||
void
|
||||
accept(PlanNodeVisitor&) override;
|
||||
|
|
|
@ -12,5 +12,24 @@ class ExecPlanNodeVisitor : PlanNodeVisitor {
|
|||
visit(BinaryVectorANNS& node) override;
|
||||
|
||||
public:
|
||||
using RetType = segcore::QueryResult;
|
||||
ExecPlanNodeVisitor(segcore::SegmentBase& segment, segcore::Timestamp timestamp, const float* src_data)
|
||||
: segment_(segment), timestamp_(timestamp), src_data_(src_data) {
|
||||
}
|
||||
// using RetType = nlohmann::json;
|
||||
|
||||
RetType get_moved_result(){
|
||||
assert(ret_.has_value());
|
||||
auto ret = std::move(ret_).value();
|
||||
ret_ = std::nullopt;
|
||||
return ret;
|
||||
}
|
||||
private:
|
||||
// std::optional<RetType> ret_;
|
||||
segcore::SegmentBase& segment_;
|
||||
segcore::Timestamp timestamp_;
|
||||
const float* src_data_;
|
||||
|
||||
std::optional<RetType> ret_;
|
||||
};
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
#pragma once
|
||||
// Generated File
|
||||
// DO NOT EDIT
|
||||
#include "query/Expr.h"
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
#pragma once
|
||||
// Generated File
|
||||
// DO NOT EDIT
|
||||
#include "query/PlanNode.h"
|
||||
|
|
|
@ -1,9 +1,50 @@
|
|||
#include "utils/Json.h"
|
||||
#include "segcore/SegmentBase.h"
|
||||
#include "query/generated/ExecPlanNodeVisitor.h"
|
||||
#include "segcore/SegmentSmallIndex.h"
|
||||
|
||||
namespace milvus::query {
|
||||
|
||||
#if 1
|
||||
namespace impl {
|
||||
// THIS CONTAINS EXTRA BODY FOR VISITOR
|
||||
// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/
|
||||
class ExecPlanNodeVisitor : PlanNodeVisitor {
|
||||
public:
|
||||
using RetType = segcore::QueryResult;
|
||||
ExecPlanNodeVisitor(segcore::SegmentBase& segment, segcore::Timestamp timestamp, const float* src_data)
|
||||
: segment_(segment), timestamp_(timestamp), src_data_(src_data) {
|
||||
}
|
||||
// using RetType = nlohmann::json;
|
||||
|
||||
RetType get_moved_result(PlanNode& node){
|
||||
assert(!ret_.has_value());
|
||||
node.accept(*this);
|
||||
assert(ret_.has_value());
|
||||
auto ret = std::move(ret_).value();
|
||||
ret_ = std::nullopt;
|
||||
return ret;
|
||||
}
|
||||
private:
|
||||
// std::optional<RetType> ret_;
|
||||
segcore::SegmentBase& segment_;
|
||||
segcore::Timestamp timestamp_;
|
||||
const float* src_data_;
|
||||
|
||||
std::optional<RetType> ret_;
|
||||
};
|
||||
} // namespace impl
|
||||
#endif
|
||||
|
||||
void
|
||||
ExecPlanNodeVisitor::visit(FloatVectorANNS& node) {
|
||||
// TODO
|
||||
// TODO: optimize here, remove the dynamic cast
|
||||
assert(!ret_.has_value());
|
||||
auto segment = dynamic_cast<segcore::SegmentSmallIndex*>(&segment_);
|
||||
AssertInfo(segment, "support SegmentSmallIndex Only");
|
||||
RetType ret;
|
||||
segment->QueryBruteForceImpl(node.query_info_, src_data_, timestamp_, ret);
|
||||
ret_ = ret;
|
||||
}
|
||||
|
||||
void
|
||||
|
|
|
@ -38,18 +38,19 @@ void
|
|||
ShowPlanNodeVisitor::visit(FloatVectorANNS& node) {
|
||||
// std::vector<float> data(node.data_.get(), node.data_.get() + node.num_queries_ * node.dim_);
|
||||
assert(!ret_);
|
||||
auto& info = node.query_info_;
|
||||
Json json_body{
|
||||
{"node_type", "FloatVectorANNS"}, //
|
||||
{"metric_type", node.metric_type_}, //
|
||||
{"dim", node.dim_}, //
|
||||
{"field_id_", node.field_id_}, //
|
||||
{"num_queries", node.num_queries_}, //
|
||||
{"data", node.data_}, //
|
||||
{"metric_type", info.metric_type_}, //
|
||||
{"dim", info.dim_}, //
|
||||
{"field_id_", info.field_id_}, //
|
||||
{"num_queries", info.num_queries_}, //
|
||||
{"topK", info.topK_}, //
|
||||
};
|
||||
if (node.predicate_.has_value()) {
|
||||
AssertInfo(false, "unimplemented");
|
||||
} else {
|
||||
json_body["predicate"] = "nullopt";
|
||||
// json_body["predicate"] = "nullopt";
|
||||
}
|
||||
ret_ = json_body;
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ class SegmentBase {
|
|||
|
||||
// query contains metadata of
|
||||
virtual Status
|
||||
Query(query::QueryPtr query, Timestamp timestamp, QueryResult& results) = 0;
|
||||
QueryDeprecated(query::QueryPtr query, Timestamp timestamp, QueryResult& results) = 0;
|
||||
|
||||
// // THIS FUNCTION IS REMOVED
|
||||
// virtual Status
|
||||
|
|
|
@ -458,7 +458,7 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que
|
|||
}
|
||||
|
||||
Status
|
||||
SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
|
||||
SegmentNaive::QueryDeprecated(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
|
||||
// TODO: enable delete
|
||||
// TODO: enable index
|
||||
// TODO: remove mock
|
||||
|
|
|
@ -45,7 +45,7 @@ class SegmentNaive : public SegmentBase {
|
|||
|
||||
// query contains metadata of
|
||||
Status
|
||||
Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override;
|
||||
QueryDeprecated(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override;
|
||||
|
||||
// stop receive insert requests
|
||||
// will move data to immutable vector or something
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
#include <segcore/SegmentSmallIndex.h>
|
||||
#include <random>
|
||||
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <thread>
|
||||
#include <queue>
|
||||
|
||||
#include "segcore/SegmentNaive.h"
|
||||
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
|
||||
#include <knowhere/index/vector_index/VecIndexFactory.h>
|
||||
#include <faiss/utils/distances.h>
|
||||
#include "segcore/SegmentSmallIndex.h"
|
||||
#include "query/PlanNode.h"
|
||||
|
||||
namespace milvus::segcore {
|
||||
|
||||
|
@ -251,8 +254,9 @@ merge_into(int64_t queries,
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
Status
|
||||
SegmentSmallIndex::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) {
|
||||
SegmentSmallIndex::QueryBruteForceImpl(const query::QueryInfo& info, const float* query_data, Timestamp timestamp, QueryResult& results) {
|
||||
// step 1: binary search to find the barrier of the snapshot
|
||||
auto ins_barrier = get_barrier(record_, timestamp);
|
||||
auto del_barrier = get_barrier(deleted_record_, timestamp);
|
||||
|
@ -263,21 +267,23 @@ SegmentSmallIndex::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp tim
|
|||
#endif
|
||||
|
||||
// step 2.1: get meta
|
||||
auto& field = schema_->operator[](query_info->field_name);
|
||||
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
|
||||
auto dim = field.get_dim();
|
||||
auto topK = query_info->topK;
|
||||
auto num_queries = query_info->num_queries;
|
||||
auto total_count = topK * num_queries;
|
||||
// TODO: optimize
|
||||
|
||||
// step 2.2: get which vector field to search
|
||||
auto vecfield_offset_opt = schema_->get_offset(query_info->field_name);
|
||||
auto vecfield_offset_opt = schema_->get_offset(info.field_id_);
|
||||
Assert(vecfield_offset_opt.has_value());
|
||||
auto vecfield_offset = vecfield_offset_opt.value();
|
||||
Assert(vecfield_offset < record_.entity_vec_.size());
|
||||
|
||||
auto& field = schema_->operator[](vecfield_offset);
|
||||
auto vec_ptr = std::static_pointer_cast<ConcurrentVector<float>>(record_.entity_vec_.at(vecfield_offset));
|
||||
|
||||
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
|
||||
auto dim = field.get_dim();
|
||||
auto topK = info.topK_;
|
||||
auto num_queries = info.num_queries_;
|
||||
auto total_count = topK * num_queries;
|
||||
// TODO: optimize
|
||||
|
||||
|
||||
// step 3: small indexing search
|
||||
std::vector<int64_t> final_uids(total_count, -1);
|
||||
std::vector<float> final_dis(total_count, std::numeric_limits<float>::max());
|
||||
|
@ -308,7 +314,7 @@ SegmentSmallIndex::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp tim
|
|||
auto src_data = vec_ptr->get_chunk(chunk_id).data();
|
||||
auto nsize =
|
||||
chunk_id != max_chunk - 1 ? DefaultElementPerChunk : ins_barrier - chunk_id * DefaultElementPerChunk;
|
||||
faiss::knn_L2sqr(query_info->query_raw_data.data(), src_data, dim, num_queries, nsize, &buf);
|
||||
faiss::knn_L2sqr(query_data, src_data, dim, num_queries, nsize, &buf);
|
||||
merge_into(num_queries, topK, final_dis.data(), final_uids.data(), buf_dis.data(), buf_uids.data());
|
||||
}
|
||||
|
||||
|
@ -327,7 +333,7 @@ SegmentSmallIndex::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp tim
|
|||
}
|
||||
|
||||
Status
|
||||
SegmentSmallIndex::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
|
||||
SegmentSmallIndex::QueryDeprecated(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
|
||||
// TODO: enable delete
|
||||
// TODO: enable index
|
||||
// TODO: remove mock
|
||||
|
@ -345,9 +351,16 @@ SegmentSmallIndex::Query(query::QueryPtr query_info, Timestamp timestamp, QueryR
|
|||
x = dis(e);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t inferred_dim = query_info->query_raw_data.size() / query_info->num_queries;
|
||||
// TODO
|
||||
return QueryBruteForceImpl(query_info, timestamp, result);
|
||||
query::QueryInfo info {
|
||||
query_info->num_queries,
|
||||
inferred_dim,
|
||||
query_info->topK,
|
||||
query_info->field_name,
|
||||
"L2"
|
||||
};
|
||||
return QueryBruteForceImpl(info, query_info->query_raw_data.data(), timestamp, result);
|
||||
}
|
||||
|
||||
Status
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
#include <shared_mutex>
|
||||
#include <knowhere/index/vector_index/VecIndex.h>
|
||||
#include <query/PlanNode.h>
|
||||
|
||||
#include "AckResponder.h"
|
||||
#include "ConcurrentVector.h"
|
||||
|
@ -70,7 +71,7 @@ class SegmentSmallIndex : public SegmentBase {
|
|||
|
||||
// query contains metadata of
|
||||
Status
|
||||
Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override;
|
||||
QueryDeprecated(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override;
|
||||
|
||||
// stop receive insert requests
|
||||
// will move data to immutable vector or something
|
||||
|
@ -125,21 +126,18 @@ class SegmentSmallIndex : public SegmentBase {
|
|||
explicit SegmentSmallIndex(SchemaPtr schema) : schema_(schema), record_(*schema_), indexing_record_(*schema_) {
|
||||
}
|
||||
|
||||
private:
|
||||
// struct MutableRecord {
|
||||
// ConcurrentVector<uint64_t> uids_;
|
||||
// tbb::concurrent_vector<Timestamp> timestamps_;
|
||||
// std::vector<tbb::concurrent_vector<float>> entity_vecs_;
|
||||
//
|
||||
// MutableRecord(int entity_size) : entity_vecs_(entity_size) {
|
||||
// }
|
||||
// };
|
||||
|
||||
public:
|
||||
std::shared_ptr<DeletedRecord::TmpBitmap>
|
||||
get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp, int64_t insert_barrier, bool force = false);
|
||||
|
||||
// Status
|
||||
// QueryBruteForceImpl(query::QueryPtr query, Timestamp timestamp, QueryResult& results);
|
||||
|
||||
Status
|
||||
QueryBruteForceImpl(query::QueryPtr query, Timestamp timestamp, QueryResult& results);
|
||||
QueryBruteForceImpl(const query::QueryInfo& info,
|
||||
const float* query_data,
|
||||
Timestamp timestamp,
|
||||
QueryResult& results);
|
||||
|
||||
template <typename Type>
|
||||
knowhere::IndexPtr
|
||||
|
|
|
@ -134,7 +134,7 @@ Search(CSegmentBase c_segment,
|
|||
query_ptr->query_raw_data.resize(num_of_query_raw_data);
|
||||
memcpy(query_ptr->query_raw_data.data(), query_raw_data, num_of_query_raw_data * sizeof(float));
|
||||
|
||||
auto res = segment->Query(query_ptr, timestamp, query_result);
|
||||
auto res = segment->QueryDeprecated(query_ptr, timestamp, query_result);
|
||||
|
||||
// result_ids and result_distances have been allocated memory in goLang,
|
||||
// so we don't need to malloc here.
|
||||
|
|
|
@ -11,5 +11,5 @@ EasyAssertInfo(
|
|||
bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info);
|
||||
}
|
||||
|
||||
#define AssertInfo(expr, info) impl::EasyAssertInfo(bool(expr), #expr, __FILE__, __LINE__, (info))
|
||||
#define AssertInfo(expr, info) milvus::impl::EasyAssertInfo(bool(expr), #expr, __FILE__, __LINE__, (info))
|
||||
#define Assert(expr) AssertInfo((expr), "")
|
||||
|
|
|
@ -56,14 +56,17 @@ TEST(Query, ShowExecutor) {
|
|||
int64_t num_queries = 100L;
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
auto raw_data = DataGen(schema, num_queries);
|
||||
node->data_ = raw_data.get_col<float>(0);
|
||||
node->metric_type_ = "L2";
|
||||
node->num_queries_ = 10;
|
||||
node->dim_ = 16;
|
||||
auto& info = node->query_info_;
|
||||
info.metric_type_ = "L2";
|
||||
info.num_queries_ = 10;
|
||||
info.dim_ = 16;
|
||||
info.topK_ = 20;
|
||||
info.field_id_ = "fakevec";
|
||||
node->predicate_ = std::nullopt;
|
||||
ShowPlanNodeVisitor show_visitor;
|
||||
PlanNodePtr base(node.release());
|
||||
auto res = show_visitor.call_child(*base);
|
||||
res["data"] = "...collased...";
|
||||
std::cout << res.dump(4);
|
||||
auto dup = res;
|
||||
dup["data"] = "...collased...";
|
||||
std::cout << dup.dump(4);
|
||||
}
|
|
@ -36,15 +36,17 @@ def meta_gen(content):
|
|||
|
||||
if len(pack) == 1:
|
||||
pack.append(None)
|
||||
struct_name, base_name = pack
|
||||
if not base_name:
|
||||
root_base = struct_name
|
||||
|
||||
body_res = body_pattern.findall(body)
|
||||
if len(body_res) != 1:
|
||||
continue
|
||||
eprint(struct_name)
|
||||
eprint(body_res)
|
||||
eprint(body)
|
||||
assert(false)
|
||||
struct_name, base_name = pack
|
||||
if not base_name:
|
||||
root_base = struct_name
|
||||
visitor_name, state = body_res[0]
|
||||
assert(visitor_name == root_base)
|
||||
if state.strip() == 'override':
|
||||
|
|
|
@ -6,7 +6,6 @@ void
|
|||
####
|
||||
|
||||
@@@@main
|
||||
#pragma once
|
||||
// Generated File
|
||||
// DO NOT EDIT
|
||||
#include "query/@@root_base@@.h"
|
||||
|
|
Loading…
Reference in New Issue