mirror of https://github.com/milvus-io/milvus.git
Fix test_query_range_valid_ranges bug (#3224)
* Fix TestSearchDSL level 2 bugs
* Fix QueryTest
* Add annotation in milvus.proto
* Fix CreateIndex in C++ sdk
* Fix C++ sdk range test
* Fix test_search_ip_index_partitions
* Fix test_query_range_valid_ranges
* Fix GetCollectionInfo

Signed-off-by: fishpenguin <kun.yu@zilliz.com>
Co-authored-by: quicksilver <zhifeng.zhang@zilliz.com>
parent
3735e3d19d
commit
92e2a26a78
@@ -462,32 +462,35 @@ ExecutionEngineImpl::IndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, const
     segment_reader_->GetSegment(segment_ptr);
     knowhere::IndexPtr index_ptr = nullptr;
     auto attr_index = segment_ptr->GetStructuredIndex(field_name, index_ptr);
+    if (!index_ptr) {
+        return Status(DB_ERROR, "Get field: " + field_name + " structured index failed");
+    }
     switch (data_type) {
         case DataType::INT8: {
-            ProcessIndexedTermQuery<int8_t>(bitset, index_ptr, term_values_json);
+            STATUS_CHECK(ProcessIndexedTermQuery<int8_t>(bitset, index_ptr, term_values_json));
             break;
         }
         case DataType::INT16: {
-            ProcessIndexedTermQuery<int16_t>(bitset, index_ptr, term_values_json);
+            STATUS_CHECK(ProcessIndexedTermQuery<int16_t>(bitset, index_ptr, term_values_json));
             break;
         }
         case DataType::INT32: {
-            ProcessIndexedTermQuery<int32_t>(bitset, index_ptr, term_values_json);
+            STATUS_CHECK(ProcessIndexedTermQuery<int32_t>(bitset, index_ptr, term_values_json));
             break;
         }
         case DataType::INT64: {
-            ProcessIndexedTermQuery<int64_t>(bitset, index_ptr, term_values_json);
+            STATUS_CHECK(ProcessIndexedTermQuery<int64_t>(bitset, index_ptr, term_values_json));
             break;
         }
         case DataType::FLOAT: {
-            ProcessIndexedTermQuery<float>(bitset, index_ptr, term_values_json);
+            STATUS_CHECK(ProcessIndexedTermQuery<float>(bitset, index_ptr, term_values_json));
             break;
         }
         case DataType::DOUBLE: {
-            ProcessIndexedTermQuery<double>(bitset, index_ptr, term_values_json);
+            STATUS_CHECK(ProcessIndexedTermQuery<double>(bitset, index_ptr, term_values_json));
             break;
         }
-        default: { return Status{SERVER_INVALID_ARGUMENT, "Attribute:" + field_name + " type is wrong"}; }
+        default: { return Status(SERVER_INVALID_ARGUMENT, "Attribute:" + field_name + " type is wrong"); }
     }
     return Status::OK();
 }
@@ -544,27 +547,27 @@ ExecutionEngineImpl::IndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, const
     auto status = Status::OK();
     switch (data_type) {
         case DataType::INT8: {
-            ProcessIndexedRangeQuery<int8_t>(bitset, index_ptr, range_values_json);
+            STATUS_CHECK(ProcessIndexedRangeQuery<int8_t>(bitset, index_ptr, range_values_json));
             break;
         }
         case DataType::INT16: {
-            ProcessIndexedRangeQuery<int16_t>(bitset, index_ptr, range_values_json);
+            STATUS_CHECK(ProcessIndexedRangeQuery<int16_t>(bitset, index_ptr, range_values_json));
             break;
         }
         case DataType::INT32: {
-            ProcessIndexedRangeQuery<int32_t>(bitset, index_ptr, range_values_json);
+            STATUS_CHECK(ProcessIndexedRangeQuery<int32_t>(bitset, index_ptr, range_values_json));
             break;
         }
         case DataType::INT64: {
-            ProcessIndexedRangeQuery<int64_t>(bitset, index_ptr, range_values_json);
+            STATUS_CHECK(ProcessIndexedRangeQuery<int64_t>(bitset, index_ptr, range_values_json));
             break;
         }
         case DataType::FLOAT: {
-            ProcessIndexedRangeQuery<float>(bitset, index_ptr, range_values_json);
+            STATUS_CHECK(ProcessIndexedRangeQuery<float>(bitset, index_ptr, range_values_json));
             break;
         }
         case DataType::DOUBLE: {
-            ProcessIndexedRangeQuery<double>(bitset, index_ptr, range_values_json);
+            STATUS_CHECK(ProcessIndexedRangeQuery<double>(bitset, index_ptr, range_values_json));
             break;
         }
         default:
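Both switches previously discarded the Status returned by the per-type ProcessIndexed*Query helpers, so a failed term or range lookup still fell through to Status::OK(). Wrapping the calls in STATUS_CHECK propagates the first failure to the caller. The macro's definition is not part of this diff; a minimal sketch of the usual shape of such a macro, with a stand-in Status type so it compiles on its own:

    #include <string>

    // Stand-in for milvus's Status type, just enough for the sketch.
    struct Status {
        int code = 0;
        std::string msg;
        bool ok() const { return code == 0; }
        static Status OK() { return {}; }
    };

    // Evaluate the expression once; bail out of the enclosing function on failure.
    #define STATUS_CHECK(expr)      \
        do {                        \
            Status _s = (expr);     \
            if (!_s.ok()) {         \
                return _s;          \
            }                       \
        } while (false)

    Status ProcessQuery() { return Status{1, "structured index lookup failed"}; }

    Status IndexedQuery() {
        STATUS_CHECK(ProcessQuery());  // returns early instead of dropping the error
        return Status::OK();
    }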
@@ -105,7 +105,7 @@ ConcurrentBitset::operator|=(ConcurrentBitset& bitset) {
     size_t n64 = n8 / 8;

     for (size_t i = 0; i < n64; i++) {
-        u64_1[i] &= u64_2[i];
+        u64_1[i] |= u64_2[i];
     }

     size_t remain = n8 % 8;
@@ -134,7 +134,7 @@ ConcurrentBitset::operator|(const std::shared_ptr<ConcurrentBitset>& bitset) {
     size_t n64 = n8 / 8;

     for (size_t i = 0; i < n64; i++) {
-        result_64[i] = u64_1[i] & u64_2[i];
+        result_64[i] = u64_1[i] | u64_2[i];
     }

     size_t remain = n8 % 8;
@@ -150,7 +150,7 @@ ConcurrentBitset::operator|(const std::shared_ptr<ConcurrentBitset>& bitset) {

 ConcurrentBitset&
 ConcurrentBitset::operator^=(ConcurrentBitset& bitset) {
-    // for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) {
+    // for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) {
     //     bitset_[i].fetch_xor(bitset.bitset()[i].load());
     // }

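The operator fixes above are the heart of the query-result bugs: operator|= and operator| computed AND (&=, &) instead of OR, so combining two filter bitsets intersected them rather than uniting them. A self-contained sketch of the corrected word-wise loop (the function name and signature here are illustrative, not the ConcurrentBitset API):

    #include <cstddef>
    #include <cstdint>

    // OR `src` into `dst`, both n8 bytes long: eight bytes at a time through
    // uint64_t, then a byte-wise tail, mirroring the n64/remain split in the diff.
    void bitset_or_inplace(uint8_t* dst, const uint8_t* src, size_t n8) {
        auto* u64_1 = reinterpret_cast<uint64_t*>(dst);
        const auto* u64_2 = reinterpret_cast<const uint64_t*>(src);
        size_t n64 = n8 / 8;
        for (size_t i = 0; i < n64; i++) {
            u64_1[i] |= u64_2[i];  // the bug: this used &= before the fix
        }
        for (size_t i = n64 * 8; i < n8; i++) {
            dst[i] |= src[i];  // remaining n8 % 8 bytes
        }
    }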
@@ -51,6 +51,10 @@ Job::TaskDone(Task* task) {
         return;
     }

+    auto json = task->Dump();
+    std::string task_desc = json.dump();
+    LOG_SERVER_DEBUG_ << LogOut("scheduler job [%ld] task %s finish", id(), task_desc.c_str());
+
     std::unique_lock<std::mutex> lock(mutex_);
     for (JobTasks::iterator iter = tasks_.begin(); iter != tasks_.end(); ++iter) {
         if (task == (*iter).get()) {
@@ -61,10 +65,6 @@ Job::TaskDone(Task* task) {
     if (tasks_.empty()) {
         cv_.notify_all();
     }
-
-    auto json = task->Dump();
-    std::string task_desc = json.dump();
-    LOG_SERVER_DEBUG_ << LogOut("scheduler job [%ld] task %s finish", id(), task_desc.c_str());
 }

 void
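These two hunks move the dump-and-log step from after the task is erased to before the lock is taken. The commit does not state the motive, but one plausible reading is that tasks_ owns the tasks via shared_ptr, so erasing the matching entry can destroy *task and make the old post-erase Dump() a use-after-free. A self-contained sketch of that hazard, with hypothetical stand-in types:

    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    struct Task {
        std::string Dump() const { return R"({"id": 1})"; }
    };

    struct Job {
        std::vector<std::shared_ptr<Task>> tasks;

        void TaskDone(Task* task) {
            std::string desc = task->Dump();  // dump BEFORE erasing: *task is alive
            for (auto it = tasks.begin(); it != tasks.end(); ++it) {
                if (it->get() == task) {
                    tasks.erase(it);  // may drop the last reference to *task
                    break;
                }
            }
            std::cout << "task finished: " << desc << "\n";  // uses the saved copy
        }
    };

    int main() {
        Job job;
        job.tasks.push_back(std::make_shared<Task>());
        job.TaskDone(job.tasks.front().get());
    }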
@@ -24,7 +24,7 @@ Task::Load(LoadType type, uint8_t device_id) {
     if (job_) {
         if (!status.ok()) {
             job_->status() = status;
-            job_->TaskDone(this);
+            // job_->TaskDone(this);
         }
     } else {
         LOG_ENGINE_ERROR_ << "Scheduler task's parent job not specified!";
@@ -52,20 +52,22 @@ GetCollectionInfoReq::OnExecute() {
     for (auto& field_kv : field_mappings) {
         auto field = field_kv.first;

+        FieldSchema field_schema;
         milvus::json field_index_param;
         auto field_elements = field_kv.second;
         for (const auto& element : field_elements) {
             if (element->GetFtype() == (engine::snapshot::FTYPE_TYPE)engine::FieldElementType::FET_INDEX) {
                 field_index_param = element->GetParams();
-                auto type = element->GetTypeName();
+                field_schema.index_params_ = field_index_param;
+                field_schema.index_params_[engine::PARAM_INDEX_TYPE] = element->GetTypeName();
                 break;
             }
         }

         auto field_name = field->GetName();
-        FieldSchema field_schema;
         field_schema.field_type_ = (engine::DataType)field->GetFtype();
         field_schema.field_params_ = field->GetParams();
-        field_schema.index_params_ = field_index_param;

         collection_schema_.fields_.insert(std::make_pair(field_name, field_schema));
     }
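This appears to be the GetCollectionInfo fix named in the commit message: the old loop read the index type name into an unused local (auto type = element->GetTypeName();) and never stored it, so the returned schema carried index params without the index type. Declaring field_schema once at the top of the loop and writing both GetParams() and the PARAM_INDEX_TYPE key directly into field_schema.index_params_ closes that gap.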
@@ -1038,7 +1038,11 @@ GrpcRequestHandler::DescribeCollection(::grpc::ServerContext* context, const ::m
     for (auto& item : field_schema.index_params_.items()) {
         auto grpc_index_param = field->add_index_params();
         grpc_index_param->set_key(item.key());
-        grpc_index_param->set_value(item.value());
+        if (item.value().is_object()) {
+            grpc_index_param->set_value(item.value().dump());
+        } else {
+            grpc_index_param->set_value(item.value());
+        }
     }
 }

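set_value() expects a string, and a JSON entry whose value is itself an object (for example a nested params map) cannot be converted implicitly, so it is serialized with dump() first. A minimal sketch of the distinction, assuming milvus::json is nlohmann::json (which the .items()/.dump() usage suggests):

    #include <iostream>
    #include <string>
    #include <nlohmann/json.hpp>

    int main() {
        nlohmann::json index_params = {
            {"index_type", "IVF_FLAT"},     // plain string value
            {"params", {{"nlist", 1024}}},  // nested object value
        };
        for (auto& item : index_params.items()) {
            // Object values must be dumped to text; string values convert directly.
            std::string value = item.value().is_object()
                                    ? item.value().dump()              // {"nlist":1024}
                                    : item.value().get<std::string>();
            std::cout << item.key() << " = " << value << "\n";
        }
        return 0;
    }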
@@ -326,7 +326,7 @@ ClientTest::Test() {
     ListCollections(table_array);

     CreateCollection(collection_name);
-    GetCollectionInfo(collection_name);
+    // GetCollectionInfo(collection_name);
     GetCollectionStats(collection_name);

     ListCollections(table_array);
@@ -336,6 +336,7 @@ ClientTest::Test() {
     Flush(collection_name);
     CountEntities(collection_name);
     CreateIndex(collection_name, 1024);
+    GetCollectionInfo(collection_name);
     // GetCollectionStats(collection_name);
     //
     BuildVectors(NQ, COLLECTION_DIMENSION);
@@ -553,7 +553,7 @@ class TestSearchBase:
                 if min_distance > tmp_dis:
                     min_distance = tmp_dis
         res = connect.search(collection, query)
-        assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= gen_inaccuracy(res[0]._distances[0])
+        assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= epsilon

     # TODO
     @pytest.mark.level(2)
@@ -1133,7 +1133,7 @@ class TestSearchDSL(object):
         return request.param

     # TODO:
-    def _test_query_range_valid_ranges(self, connect, collection, get_valid_ranges):
+    def test_query_range_valid_ranges(self, connect, collection, get_valid_ranges):
        '''
        method: build query with valid ranges
        expected: pass
@@ -44,7 +44,7 @@ default_index_params = [
     {"nlist": 1024, "m": 16},
     {"M": 48, "efConstruction": 500},
     # {"search_length": 50, "out_degree": 40, "candidate_pool_size": 100, "knng": 50},
-    {"n_trees": 4},
+    {"n_trees": 50},
     {"nlist": 1024},
     {"nlist": 1024}
 ]
@@ -314,7 +314,7 @@ def gen_default_term_expr(keyword="term", values=None):
 def gen_default_range_expr(keyword="range", ranges=None):
     if ranges is None:
         ranges = {"GT": 1, "LT": nb // 2}
-    expr = {keyword: {"int64": {"ranges": ranges}}}
+    expr = {keyword: {"int64": ranges}}
     return expr

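As the hunk shows, the generated range expression drops the extra "ranges" wrapper: the comparison operators now sit directly under the field name, e.g. {"range": {"int64": {"GT": 1, "LT": nb // 2}}} instead of {"range": {"int64": {"ranges": {"GT": 1, "LT": nb // 2}}}}, evidently matching the shape the server-side DSL parser expects.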
@@ -341,7 +341,7 @@ def gen_invalid_ranges():
 def gen_valid_ranges():
     ranges = [
         {"GT": 0, "LT": nb//2},
-        {"GT": nb, "LT": nb*2},
+        {"GT": nb // 2, "LT": nb*2},
         {"GT": 0},
         {"LT": nb},
         {"GT": -1, "LT": top_k},
@@ -766,7 +766,7 @@ def get_search_param(index_type):
     elif index_type == "NSG":
         search_params.update({"search_length": 100})
     elif index_type == "ANNOY":
-        search_params.update({"search_k": 100})
+        search_params.update({"search_k": 1000})
     else:
         logging.getLogger().error("Invalid index_type.")
         raise Exception("Invalid index_type.")