fix merge result (#2463)

* fix merge result

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>

* fix tests

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
pull/2481/head
shengjun.li 2020-06-01 15:31:58 +08:00 committed by GitHub
parent 1daf00dcf7
commit 257ea7782f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 9 additions and 31 deletions

View File

@ -84,11 +84,8 @@ JobMgr::worker_function() {
// TODO(zhiru): if the job is search by ids, pass any task where the ids don't exist // TODO(zhiru): if the job is search by ids, pass any task where the ids don't exist
auto search_job = std::dynamic_pointer_cast<SearchJob>(job); auto search_job = std::dynamic_pointer_cast<SearchJob>(job);
if (search_job != nullptr) { if (search_job != nullptr) {
scheduler::ResultIds ids(search_job->nq() * search_job->topk(), -1); search_job->GetResultIds().resize(search_job->nq(), -1);
scheduler::ResultDistances distances(search_job->nq() * search_job->topk(), search_job->GetResultDistances().resize(search_job->nq(), std::numeric_limits<float>::max());
std::numeric_limits<float>::max());
search_job->GetResultIds() = ids;
search_job->GetResultDistances() = distances;
if (search_job->vectors().float_data_.empty() && search_job->vectors().binary_data_.empty() && if (search_job->vectors().float_data_.empty() && search_job->vectors().binary_data_.empty() &&
!search_job->vectors().id_array_.empty()) { !search_job->vectors().id_array_.empty()) {

View File

@ -279,9 +279,7 @@ XSearchTask::Execute() {
auto spec_k = file_->row_count_ < topk ? file_->row_count_ : topk; auto spec_k = file_->row_count_ < topk ? file_->row_count_ : topk;
if (spec_k == 0) { if (spec_k == 0) {
LOG_ENGINE_WARNING_ << "Searching in an empty file. file location = " << file_->location_; LOG_ENGINE_WARNING_ << "Searching in an empty file. file location = " << file_->location_;
} } else {
{
std::unique_lock<std::mutex> lock(search_job->mutex()); std::unique_lock<std::mutex> lock(search_job->mutex());
search_job->vector_count() = nq; search_job->vector_count() = nq;
XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, ascending_reduce, XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, ascending_reduce,
@ -315,19 +313,8 @@ XSearchTask::Execute() {
if (spec_k == 0) { if (spec_k == 0) {
LOG_ENGINE_WARNING_ << LogOut("[%s][%ld] Searching in an empty file. file location = %s", "search", 0, LOG_ENGINE_WARNING_ << LogOut("[%s][%ld] Searching in an empty file. file location = %s", "search", 0,
file_->location_.c_str()); file_->location_.c_str());
} } else {
{
std::unique_lock<std::mutex> lock(search_job->mutex()); std::unique_lock<std::mutex> lock(search_job->mutex());
if (search_job->GetResultIds().size() > spec_k) {
if (search_job->GetResultIds().front() == -1) {
// initialized results set
search_job->GetResultIds().resize(spec_k * nq);
search_job->GetResultDistances().resize(spec_k * nq);
}
}
XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, ascending_reduce, XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, ascending_reduce,
search_job->GetResultIds(), search_job->GetResultDistances()); search_job->GetResultIds(), search_job->GetResultDistances());
} }

View File

@ -565,11 +565,7 @@ TEST_F(DeleteTest, delete_single_vector) {
milvus::engine::ResultDistances result_distances; milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, stat = db_->Query(dummy_context_,
collection_info.collection_id_, tags, topk, json_params, xb, result_ids, result_distances); collection_info.collection_id_, tags, topk, json_params, xb, result_ids, result_distances);
ASSERT_TRUE(result_ids.empty()); ASSERT_TRUE(result_ids.empty() || (result_ids[0] == -1));
ASSERT_TRUE(result_distances.empty());
// ASSERT_EQ(result_ids[0], -1);
// ASSERT_LT(result_distances[0], 1e-4);
// ASSERT_EQ(result_distances[0], std::numeric_limits<float>::max());
} }
TEST_F(DeleteTest, delete_add_create_index) { TEST_F(DeleteTest, delete_add_create_index) {

View File

@ -61,7 +61,7 @@ class TestDeleteBase:
status, res = connect.search(collection, top_k, vector, params=search_param) status, res = connect.search(collection, top_k, vector, params=search_param)
logging.getLogger().info(res) logging.getLogger().info(res)
assert status.OK() assert status.OK()
assert len(res) == 0 assert len(res[0]) == 0
def test_delete_vector_multi_same_ids(self, connect, collection, get_simple_index): def test_delete_vector_multi_same_ids(self, connect, collection, get_simple_index):
''' '''
@ -83,7 +83,7 @@ class TestDeleteBase:
status, res = connect.search(collection, top_k, [vectors[0]], params=search_param) status, res = connect.search(collection, top_k, [vectors[0]], params=search_param)
logging.getLogger().info(res) logging.getLogger().info(res)
assert status.OK() assert status.OK()
assert len(res) == 0 assert len(res[0]) == 0
def test_delete_vector_collection_count(self, connect, collection): def test_delete_vector_collection_count(self, connect, collection):
''' '''
@ -327,7 +327,7 @@ class TestDeleteIndexedVectors:
status, res = connect.search(collection, top_k, vector, params=search_param) status, res = connect.search(collection, top_k, vector, params=search_param)
logging.getLogger().info(res) logging.getLogger().info(res)
assert status.OK() assert status.OK()
assert len(res) == 0 assert len(res[0]) == 0
def test_insert_delete_vector(self, connect, collection, get_simple_index): def test_insert_delete_vector(self, connect, collection, get_simple_index):
''' '''
@ -399,9 +399,7 @@ class TestDeleteBinary:
status, res = connect.search(jac_collection, top_k, vector, params=search_param) status, res = connect.search(jac_collection, top_k, vector, params=search_param)
logging.getLogger().info(res) logging.getLogger().info(res)
assert status.OK() assert status.OK()
assert len(res) == 0 assert len(res[0]) == 0
assert status.OK()
assert len(res) == 0
# TODO: soft delete # TODO: soft delete
def test_delete_vector_collection_count(self, connect, jac_collection): def test_delete_vector_collection_count(self, connect, jac_collection):