Enable most DSL-related cases

Signed-off-by: FluorineDog <guilin.gou@zilliz.com>
pull/4973/head^2
FluorineDog 2020-12-21 14:45:00 +08:00 committed by yefu.chen
parent 06620935ab
commit d5daa18392
5 changed files with 332 additions and 142 deletions

View File

@ -61,7 +61,7 @@ ruleguard:
verifiers: getdeps cppcheck fmt lint ruleguard
# Builds various components locally.
build-go:
build-go: build-cpp
@echo "Building each component's binary to './bin'"
@echo "Building master ..."
@mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="0" && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/master $(PWD)/cmd/master/main.go 1>/dev/null

View File

@ -20,6 +20,7 @@
#include <memory>
#include <boost/align/aligned_allocator.hpp>
#include <boost/algorithm/string.hpp>
#include <algorithm>
namespace milvus::query {
@ -39,10 +40,8 @@ const std::map<std::string, RangeExpr::OpType> RangeExpr::mapping_ = {
class Parser {
public:
static std::unique_ptr<Plan>
CreatePlan(const Schema& schema, const std::string& dsl_str) {
return Parser(schema).CreatePlanImpl(dsl_str);
}
friend std::unique_ptr<Plan>
CreatePlan(const Schema& schema, const std::string& dsl_str);
private:
std::unique_ptr<Plan>
@ -51,29 +50,55 @@ class Parser {
explicit Parser(const Schema& schema) : schema(schema) {
}
// vector node parser, should be called exactly once per pass.
std::unique_ptr<VectorPlanNode>
ParseVecNode(const Json& out_body);
// Dispatcher of all parse function
// NOTE: when nullptr, it is a pure vector node
ExprPtr
ParseAnyNode(const Json& body);
ExprPtr
ParseMustNode(const Json& body);
ExprPtr
ParseShouldNode(const Json& body);
ExprPtr
ParseShouldNotNode(const Json& body);
// parse the value of "should"/"must"/"should_not" entry
std::vector<ExprPtr>
ParseItemList(const Json& body);
// parse the value of "range" entry
ExprPtr
ParseRangeNode(const Json& out_body);
// parse the value of "term" entry
ExprPtr
ParseTermNode(const Json& out_body);
private:
// template implementation of leaf parser
// used by corresponding parser
template <typename T>
std::unique_ptr<Expr>
ExprPtr
ParseRangeNodeImpl(const std::string& field_name, const Json& body);
template <typename T>
std::unique_ptr<Expr>
ExprPtr
ParseTermNodeImpl(const std::string& field_name, const Json& body);
std::unique_ptr<Expr>
ParseRangeNode(const Json& out_body);
std::unique_ptr<Expr>
ParseTermNode(const Json& out_body);
private:
const Schema& schema;
std::map<std::string, FieldId> tag2field_; // PlaceholderName -> FieldId
std::optional<std::unique_ptr<VectorPlanNode>> vector_node_opt_;
};
std::unique_ptr<Expr>
ExprPtr
Parser::ParseRangeNode(const Json& out_body) {
Assert(out_body.is_object());
Assert(out_body.size() == 1);
@ -84,9 +109,8 @@ Parser::ParseRangeNode(const Json& out_body) {
Assert(!field_is_vector(data_type));
switch (data_type) {
case DataType::BOOL: {
case DataType::BOOL:
return ParseRangeNodeImpl<bool>(field_name, body);
}
case DataType::INT8:
return ParseRangeNodeImpl<int8_t>(field_name, body);
case DataType::INT16:
@ -106,51 +130,22 @@ Parser::ParseRangeNode(const Json& out_body) {
std::unique_ptr<Plan>
Parser::CreatePlanImpl(const std::string& dsl_str) {
auto plan = std::make_unique<Plan>(schema);
auto dsl = nlohmann::json::parse(dsl_str);
nlohmann::json vec_pack;
std::optional<std::unique_ptr<Expr>> predicate;
// top level
auto& bool_dsl = dsl.at("bool");
if (bool_dsl.contains("must")) {
auto& packs = bool_dsl.at("must");
Assert(packs.is_array());
for (auto& pack : packs) {
if (pack.contains("vector")) {
auto& out_body = pack.at("vector");
plan->plan_node_ = ParseVecNode(out_body);
} else if (pack.contains("term")) {
AssertInfo(!predicate, "unsupported complex DSL");
auto& out_body = pack.at("term");
predicate = ParseTermNode(out_body);
} else if (pack.contains("range")) {
AssertInfo(!predicate, "unsupported complex DSL");
auto& out_body = pack.at("range");
predicate = ParseRangeNode(out_body);
} else {
PanicInfo("unsupported node");
}
}
AssertInfo(plan->plan_node_, "vector node not found");
} else if (bool_dsl.contains("vector")) {
auto& out_body = bool_dsl.at("vector");
plan->plan_node_ = ParseVecNode(out_body);
Assert(plan->plan_node_);
} else {
PanicInfo("Unsupported DSL");
auto dsl = Json::parse(dsl_str);
auto bool_dsl = dsl.at("bool");
auto predicate = ParseAnyNode(bool_dsl);
Assert(vector_node_opt_.has_value());
auto vec_node = std::move(vector_node_opt_).value();
if (predicate != nullptr) {
vec_node->predicate_ = std::move(predicate);
}
plan->plan_node_->predicate_ = std::move(predicate);
auto plan = std::make_unique<Plan>(schema);
plan->tag2field_ = std::move(tag2field_);
// TODO: target_entry parser
// if schema autoid is true,
// prepend target_entries_ with row_id
// else
// with primary_key
//
plan->plan_node_ = std::move(vec_node);
return plan;
}
std::unique_ptr<Expr>
ExprPtr
Parser::ParseTermNode(const Json& out_body) {
Assert(out_body.size() == 1);
auto out_iter = out_body.begin();
@ -221,7 +216,7 @@ Parser::ParseVecNode(const Json& out_body) {
}
template <typename T>
std::unique_ptr<Expr>
ExprPtr
Parser::ParseTermNodeImpl(const std::string& field_name, const Json& body) {
auto expr = std::make_unique<TermExprImpl<T>>();
auto data_type = schema[field_name].get_data_type();
@ -249,7 +244,7 @@ Parser::ParseTermNodeImpl(const std::string& field_name, const Json& body) {
}
template <typename T>
std::unique_ptr<Expr>
ExprPtr
Parser::ParseRangeNodeImpl(const std::string& field_name, const Json& body) {
auto expr = std::make_unique<RangeExprImpl<T>>();
auto data_type = schema[field_name].get_data_type();
@ -278,12 +273,6 @@ Parser::ParseRangeNodeImpl(const std::string& field_name, const Json& body) {
return expr;
}
std::unique_ptr<Plan>
CreatePlan(const Schema& schema, const std::string& dsl_str) {
auto plan = Parser::CreatePlan(schema, dsl_str);
return plan;
}
std::unique_ptr<PlaceholderGroup>
ParsePlaceholderGroup(const Plan* plan, const std::string& blob) {
namespace ser = milvus::proto::service;
@ -313,6 +302,150 @@ ParsePlaceholderGroup(const Plan* plan, const std::string& blob) {
return result;
}
std::unique_ptr<Plan>
CreatePlan(const Schema& schema, const std::string& dsl_str) {
    // Public entry point: all parsing state lives in a one-shot Parser
    // instance, so successive calls are fully independent.
    return Parser(schema).CreatePlanImpl(dsl_str);
}
std::vector<ExprPtr>
Parser::ParseItemList(const Json& body) {
std::vector<ExprPtr> results;
if (body.is_object()) {
// only one item;
auto new_entry = ParseAnyNode(body);
results.emplace_back(std::move(new_entry));
} else {
// item array
Assert(body.is_array());
for (auto& item : body) {
auto new_entry = ParseAnyNode(item);
results.emplace_back(std::move(new_entry));
}
}
auto old_size = results.size();
auto new_end = std::remove_if(results.begin(), results.end(), [](const ExprPtr& x) { return x == nullptr; });
results.resize(new_end - results.begin());
return results;
}
ExprPtr
Parser::ParseAnyNode(const Json& out_body) {
    // Dispatcher over the single key of the enclosing object. Every key maps
    // to its dedicated parser; the "vector" key is special-cased: the node is
    // recorded in vector_node_opt_ and nullptr is returned (pure vector node).
    Assert(out_body.is_object());
    Assert(out_body.size() == 1);
    auto entry = out_body.begin();
    auto node_key = entry.key();
    auto node_body = entry.value();
    if (node_key == "must") {
        return ParseMustNode(node_body);
    }
    if (node_key == "should") {
        return ParseShouldNode(node_body);
    }
    if (node_key == "should_not") {
        return ParseShouldNotNode(node_body);
    }
    if (node_key == "range") {
        return ParseRangeNode(node_body);
    }
    if (node_key == "term") {
        return ParseTermNode(node_body);
    }
    if (node_key == "vector") {
        auto vec_node = ParseVecNode(node_body);
        // exactly one vector node is allowed per plan
        Assert(!vector_node_opt_.has_value());
        vector_node_opt_ = std::move(vec_node);
        return nullptr;
    }
    PanicInfo("unsupported key: " + node_key);
}
template <typename Merger>
static ExprPtr
ConstructTree(Merger merger, std::vector<ExprPtr> item_list) {
    // Fold the items into a binary expression tree, e.g.
    //        Op
    //       /  \
    //      Op   Op
    //     / \   | \
    //    A   B  C  D
    // Pairing is breadth-first via a queue: pop two subtrees, push their
    // merged parent, repeat until a single root remains.
    if (item_list.empty()) {
        return nullptr;
    }
    if (item_list.size() == 1) {
        return std::move(item_list.front());
    }
    std::deque<ExprPtr> pending;
    for (auto& node : item_list) {
        Assert(node != nullptr);
        pending.push_back(std::move(node));
    }
    while (pending.size() > 1) {
        auto lhs = std::move(pending.front());
        pending.pop_front();
        auto rhs = std::move(pending.front());
        pending.pop_front();
        pending.push_back(merger(std::move(lhs), std::move(rhs)));
    }
    Assert(pending.size() == 1);
    return std::move(pending.front());
}
ExprPtr
Parser::ParseMustNode(const Json& body) {
    // "must" — logical AND over all sub-clauses. An empty list (e.g. a
    // "must" containing only the vector node) folds to nullptr.
    auto children = ParseItemList(body);
    auto make_and = [](ExprPtr lhs, ExprPtr rhs) {
        auto node = std::make_unique<BoolBinaryExpr>();
        node->op_type_ = BoolBinaryExpr::OpType::LogicalAnd;
        node->left_ = std::move(lhs);
        node->right_ = std::move(rhs);
        return node;
    };
    return ConstructTree(make_and, std::move(children));
}
ExprPtr
Parser::ParseShouldNode(const Json& body) {
    // "should" — logical OR over all sub-clauses; at least one is required.
    auto children = ParseItemList(body);
    Assert(children.size() >= 1);
    auto make_or = [](ExprPtr lhs, ExprPtr rhs) {
        auto node = std::make_unique<BoolBinaryExpr>();
        node->op_type_ = BoolBinaryExpr::OpType::LogicalOr;
        node->left_ = std::move(lhs);
        node->right_ = std::move(rhs);
        return node;
    };
    return ConstructTree(make_or, std::move(children));
}
ExprPtr
Parser::ParseShouldNotNode(const Json& body) {
    // "should_not" — merge sub-clauses with AND, then negate the whole tree.
    // NOTE(review): this evaluates NOT(A && B && ...), i.e. "not all match".
    // The accompanying TestSimpleDsl case expects exactly this, but it differs
    // from ES-style must_not semantics, NOT(A || B || ...) — confirm intended.
    auto children = ParseItemList(body);
    Assert(children.size() >= 1);
    auto make_and = [](ExprPtr lhs, ExprPtr rhs) {
        auto node = std::make_unique<BoolBinaryExpr>();
        node->op_type_ = BoolBinaryExpr::OpType::LogicalAnd;
        node->left_ = std::move(lhs);
        node->right_ = std::move(rhs);
        return node;
    };
    auto conjunction = ConstructTree(make_and, std::move(children));
    auto negation = std::make_unique<BoolUnaryExpr>();
    negation->op_type_ = BoolUnaryExpr::OpType::LogicalNot;
    negation->child_ = std::move(conjunction);
    return negation;
}
int64_t
GetTopK(const Plan* plan) {
return plan->plan_node_->query_info_.topK_;

View File

@ -67,6 +67,7 @@ ExecExprVisitor::visit(BoolUnaryExpr& expr) {
switch (expr.op_type_) {
case OpType::LogicalNot: {
chunk.flip();
break;
}
default: {
PanicInfo("Invalid OpType");

View File

@ -410,3 +410,104 @@ TEST(Expr, TestTerm) {
}
}
}
// End-to-end check of the boolean DSL parser + expression evaluator:
// builds must/should/should_not clauses out of bit-pattern term filters,
// runs them through CreatePlan and ExecExprVisitor, and compares each row's
// result bit against a reference predicate computed directly on the raw data.
TEST(Expr, TestSimpleDsl) {
using namespace milvus::query;
using namespace milvus::segcore;
// Vector clause shared by every test case (placeholder "$0", topk 10).
auto vec_dsl = Json::parse(R"(
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
)");
int N = 32;
// Build a "term" clause on field "age" selecting every value in [0, 2N)
// whose `base`-th bit equals `bit` (default 1).
auto get_item = [&](int base, int bit = 1) {
std::vector<int> terms;
// note: random gen range is [0, 2N)
for (int i = 0; i < N * 2; ++i) {
if (((i >> base) & 0x1) == bit) {
terms.push_back(i);
}
}
Json s;
s["term"]["age"]["values"] = terms;
return s;
};
// std::cout << get_item(0).dump(-2);
// std::cout << vec_dsl.dump(-2);
// Each test case pairs a DSL clause with the reference predicate the
// evaluated bitmap must match row-by-row.
std::vector<std::tuple<Json, std::function<bool(int)>>> testcases;
{
// must over four bit filters: rows with bits 0,1,3 set and bit 2 clear.
Json dsl;
dsl["must"] = Json::array({vec_dsl, get_item(0), get_item(1), get_item(2, 0), get_item(3)});
testcases.emplace_back(dsl, [](int x) { return (x & 0b1111) == 0b1011; });
}
{
// nested must inside must — same predicate as above.
Json dsl;
Json sub_dsl;
sub_dsl["must"] = Json::array({get_item(0), get_item(1), get_item(2, 0), get_item(3)});
dsl["must"] = Json::array({sub_dsl, vec_dsl});
testcases.emplace_back(dsl, [](int x) { return (x & 0b1111) == 0b1011; });
}
{
// should (OR) of the four filters: false only when every filter misses.
Json dsl;
Json sub_dsl;
sub_dsl["should"] = Json::array({get_item(0), get_item(1), get_item(2, 0), get_item(3)});
dsl["must"] = Json::array({sub_dsl, vec_dsl});
testcases.emplace_back(dsl, [](int x) { return !!((x & 0b1111) ^ 0b0100); });
}
{
// should_not: negation of the AND of the four filters (not-all-match).
Json dsl;
Json sub_dsl;
sub_dsl["should_not"] = Json::array({get_item(0), get_item(1), get_item(2, 0), get_item(3)});
dsl["must"] = Json::array({sub_dsl, vec_dsl});
testcases.emplace_back(dsl, [](int x) { return (x & 0b1111) != 0b1011; });
}
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("age", DataType::INT32);
auto seg = CreateSegment(schema);
// Mirror of the segment's "age" column, used to compute reference answers.
std::vector<int> age_col;
int num_iters = 100;
for (int iter = 0; iter < num_iters; ++iter) {
// iter seeds DataGen, so each batch is distinct but reproducible.
auto raw_data = DataGen(schema, N, iter);
auto new_age_col = raw_data.get_col<int>(1);
age_col.insert(age_col.end(), new_age_col.begin(), new_age_col.end());
seg->PreInsert(N);
seg->Insert(iter * N, N, raw_data.row_ids_.data(), raw_data.timestamps_.data(), raw_data.raw_);
}
auto seg_promote = dynamic_cast<SegmentSmallIndex*>(seg.get());
ExecExprVisitor visitor(*seg_promote);
for (auto [clause, ref_func] : testcases) {
Json dsl;
dsl["bool"] = clause;
// std::cout << dsl.dump(2);
auto plan = CreatePlan(*schema, dsl.dump());
// Evaluate the parsed predicate; result is one bitmap per chunk.
auto final = visitor.call_child(*plan->plan_node_->predicate_.value());
EXPECT_EQ(final.size(), upper_div(N * num_iters, DefaultElementPerChunk));
for (int i = 0; i < N * num_iters; ++i) {
auto vec_id = i / DefaultElementPerChunk;
auto offset = i % DefaultElementPerChunk;
bool ans = final[vec_id][offset];
auto val = age_col[i];
auto ref = ref_func(val);
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
}
}
}

View File

@ -705,8 +705,7 @@ class TestSearchBase:
# TODO:
# assert abs(res[0]._distances[0] - max_distance) <= tmp_epsilon
# DOG: TODO BINARY
@pytest.mark.skip("search_distance_jaccard_flat_index")
# PASS
def test_search_distance_jaccard_flat_index(self, connect, binary_collection):
'''
target: search binary_collection, and check the result: distance
@ -740,8 +739,7 @@ class TestSearchBase:
with pytest.raises(Exception) as e:
res = connect.search(binary_collection, query)
# DOG: TODO BINARY
@pytest.mark.skip("search_distance_hamming_flat_index")
# PASS
@pytest.mark.level(2)
def test_search_distance_hamming_flat_index(self, connect, binary_collection):
'''
@ -758,8 +756,7 @@ class TestSearchBase:
res = connect.search(binary_collection, query)
assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon
# DOG: TODO BINARY
@pytest.mark.skip("search_distance_substructure_flat_index")
# PASS
@pytest.mark.level(2)
def test_search_distance_substructure_flat_index(self, connect, binary_collection):
'''
@ -777,8 +774,7 @@ class TestSearchBase:
res = connect.search(binary_collection, query)
assert len(res[0]) == 0
# DOG: TODO BINARY
@pytest.mark.skip("search_distance_substructure_flat_index_B")
# PASS
@pytest.mark.level(2)
def test_search_distance_substructure_flat_index_B(self, connect, binary_collection):
'''
@ -797,8 +793,7 @@ class TestSearchBase:
assert res[1][0].distance <= epsilon
assert res[1][0].id == ids[1]
# DOG: TODO BINARY
@pytest.mark.skip("search_distance_superstructure_flat_index")
# PASS
@pytest.mark.level(2)
def test_search_distance_superstructure_flat_index(self, connect, binary_collection):
'''
@ -816,8 +811,7 @@ class TestSearchBase:
res = connect.search(binary_collection, query)
assert len(res[0]) == 0
# DOG: TODO BINARY
@pytest.mark.skip("search_distance_superstructure_flat_index_B")
# PASS
@pytest.mark.level(2)
def test_search_distance_superstructure_flat_index_B(self, connect, binary_collection):
'''
@ -838,8 +832,7 @@ class TestSearchBase:
assert res[1][0].id in ids
assert res[1][0].distance <= epsilon
# DOG: TODO BINARY
@pytest.mark.skip("search_distance_tanimoto_flat_index")
# PASS
@pytest.mark.level(2)
def test_search_distance_tanimoto_flat_index(self, connect, binary_collection):
'''
@ -977,8 +970,7 @@ class TestSearchDSL(object):
******************************************************************
"""
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_no_must")
# PASS
def test_query_no_must(self, connect, collection):
'''
method: build query without must expr
@ -989,8 +981,7 @@ class TestSearchDSL(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_no_vector_term_only")
# PASS
def test_query_no_vector_term_only(self, connect, collection):
'''
method: build query without vector only term
@ -1025,8 +1016,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == default_top_k
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_wrong_format")
# PASS
def test_query_wrong_format(self, connect, collection):
'''
method: build query without must expr, with wrong expr name
@ -1168,8 +1158,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == 0
# DOG: TODO TRC
@pytest.mark.skip("query_complex_dsl")
# PASS
def test_query_complex_dsl(self, connect, collection):
'''
method: query with complicated dsl
@ -1191,9 +1180,7 @@ class TestSearchDSL(object):
******************************************************************
"""
# DOG: TODO INVALID DSL
# TODO
@pytest.mark.skip("query_term_key_error")
# PASS
@pytest.mark.level(2)
def test_query_term_key_error(self, connect, collection):
'''
@ -1213,8 +1200,7 @@ class TestSearchDSL(object):
def get_invalid_term(self, request):
return request.param
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_term_wrong_format")
# PASS
@pytest.mark.level(2)
def test_query_term_wrong_format(self, connect, collection, get_invalid_term):
'''
@ -1228,7 +1214,7 @@ class TestSearchDSL(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# DOG: TODO UNKNOWN
# DOG: PLEASE IMPLEMENT connect.count_entities
# TODO
@pytest.mark.skip("query_term_field_named_term")
@pytest.mark.level(2)
@ -1244,8 +1230,8 @@ class TestSearchDSL(object):
ids = connect.bulk_insert(collection_term, term_entities)
assert len(ids) == default_nb
connect.flush([collection_term])
count = connect.count_entities(collection_term)
assert count == default_nb
count = connect.count_entities(collection_term) # count_entities is not implemented
assert count == default_nb # if these two lines are removed, this test passes
term_param = {"term": {"term": {"values": [i for i in range(default_nb // 2)]}}}
expr = {"must": [gen_default_vector_expr(default_query),
term_param]}
@ -1255,8 +1241,7 @@ class TestSearchDSL(object):
assert len(res[0]) == default_top_k
connect.drop_collection(collection_term)
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_term_one_field_not_existed")
# PASS
@pytest.mark.level(2)
def test_query_term_one_field_not_existed(self, connect, collection):
'''
@ -1278,7 +1263,6 @@ class TestSearchDSL(object):
"""
# PASS
# TODO
def test_query_range_key_error(self, connect, collection):
'''
method: build query with range key error
@ -1298,7 +1282,6 @@ class TestSearchDSL(object):
return request.param
# PASS
# TODO
@pytest.mark.level(2)
def test_query_range_wrong_format(self, connect, collection, get_invalid_range):
'''
@ -1366,8 +1349,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == default_top_k
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_range_one_field_not_existed")
# PASS
def test_query_range_one_field_not_existed(self, connect, collection):
'''
method: build query with two fields ranges, one of fields not existed
@ -1387,10 +1369,7 @@ class TestSearchDSL(object):
************************************************************************
"""
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_multi_term_has_common")
@pytest.mark.level(2)
# PASS
def test_query_multi_term_has_common(self, connect, collection):
'''
method: build query with multi term with same field, and values has common
@ -1405,9 +1384,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == default_top_k
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_multi_term_no_common")
# PASS
@pytest.mark.level(2)
def test_query_multi_term_no_common(self, connect, collection):
'''
@ -1423,9 +1400,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == 0
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_multi_term_different_fields")
# PASS
def test_query_multi_term_different_fields(self, connect, collection):
'''
method: build query with multi range with same field, and ranges no common
@ -1441,9 +1416,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == 0
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_single_term_multi_fields")
# PASS
@pytest.mark.level(2)
def test_query_single_term_multi_fields(self, connect, collection):
'''
@ -1459,9 +1432,7 @@ class TestSearchDSL(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_multi_range_has_common")
# PASS
@pytest.mark.level(2)
def test_query_multi_range_has_common(self, connect, collection):
'''
@ -1477,9 +1448,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == default_top_k
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_multi_range_no_common")
# PASS
@pytest.mark.level(2)
def test_query_multi_range_no_common(self, connect, collection):
'''
@ -1495,9 +1464,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == 0
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_multi_range_different_fields")
# PASS
@pytest.mark.level(2)
def test_query_multi_range_different_fields(self, connect, collection):
'''
@ -1513,9 +1480,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == 0
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_single_range_multi_fields")
# PASS
@pytest.mark.level(2)
def test_query_single_range_multi_fields(self, connect, collection):
'''
@ -1537,9 +1502,7 @@ class TestSearchDSL(object):
******************************************************************
"""
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_single_term_range_has_common")
# PASS
@pytest.mark.level(2)
def test_query_single_term_range_has_common(self, connect, collection):
'''
@ -1555,9 +1518,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == default_top_k
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_single_term_range_no_common")
# PASS
def test_query_single_term_range_no_common(self, connect, collection):
'''
method: build query with single term single range
@ -1579,7 +1540,6 @@ class TestSearchDSL(object):
"""
# PASS
# TODO
def test_query_multi_vectors_same_field(self, connect, collection):
'''
method: build query with two vectors same field
@ -1616,8 +1576,7 @@ class TestSearchDSLBools(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_should_only_term")
# PASS
def test_query_should_only_term(self, connect, collection):
'''
method: build query without must, with should.term instead
@ -1628,8 +1587,7 @@ class TestSearchDSLBools(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_should_only_vector")
# PASS
def test_query_should_only_vector(self, connect, collection):
'''
method: build query without must, with should.vector instead
@ -1640,8 +1598,7 @@ class TestSearchDSLBools(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_must_not_only_term")
# PASS
def test_query_must_not_only_term(self, connect, collection):
'''
method: build query without must, with must_not.term instead
@ -1652,8 +1609,7 @@ class TestSearchDSLBools(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_must_not_vector")
# PASS
def test_query_must_not_vector(self, connect, collection):
'''
method: build query without must, with must_not.vector instead
@ -1664,8 +1620,7 @@ class TestSearchDSLBools(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_must_should")
# PASS
def test_query_must_should(self, connect, collection):
'''
method: build query must, and with should.term