mirror of https://github.com/milvus-io/milvus.git
fix merge result (#2463)
* fix merge result Signed-off-by: shengjun.li <shengjun.li@zilliz.com> * fix tests Signed-off-by: shengjun.li <shengjun.li@zilliz.com>pull/2481/head
parent
1daf00dcf7
commit
257ea7782f
|
@ -84,11 +84,8 @@ JobMgr::worker_function() {
|
||||||
// TODO(zhiru): if the job is search by ids, pass any task where the ids don't exist
|
// TODO(zhiru): if the job is search by ids, pass any task where the ids don't exist
|
||||||
auto search_job = std::dynamic_pointer_cast<SearchJob>(job);
|
auto search_job = std::dynamic_pointer_cast<SearchJob>(job);
|
||||||
if (search_job != nullptr) {
|
if (search_job != nullptr) {
|
||||||
scheduler::ResultIds ids(search_job->nq() * search_job->topk(), -1);
|
search_job->GetResultIds().resize(search_job->nq(), -1);
|
||||||
scheduler::ResultDistances distances(search_job->nq() * search_job->topk(),
|
search_job->GetResultDistances().resize(search_job->nq(), std::numeric_limits<float>::max());
|
||||||
std::numeric_limits<float>::max());
|
|
||||||
search_job->GetResultIds() = ids;
|
|
||||||
search_job->GetResultDistances() = distances;
|
|
||||||
|
|
||||||
if (search_job->vectors().float_data_.empty() && search_job->vectors().binary_data_.empty() &&
|
if (search_job->vectors().float_data_.empty() && search_job->vectors().binary_data_.empty() &&
|
||||||
!search_job->vectors().id_array_.empty()) {
|
!search_job->vectors().id_array_.empty()) {
|
||||||
|
|
|
@ -279,9 +279,7 @@ XSearchTask::Execute() {
|
||||||
auto spec_k = file_->row_count_ < topk ? file_->row_count_ : topk;
|
auto spec_k = file_->row_count_ < topk ? file_->row_count_ : topk;
|
||||||
if (spec_k == 0) {
|
if (spec_k == 0) {
|
||||||
LOG_ENGINE_WARNING_ << "Searching in an empty file. file location = " << file_->location_;
|
LOG_ENGINE_WARNING_ << "Searching in an empty file. file location = " << file_->location_;
|
||||||
}
|
} else {
|
||||||
|
|
||||||
{
|
|
||||||
std::unique_lock<std::mutex> lock(search_job->mutex());
|
std::unique_lock<std::mutex> lock(search_job->mutex());
|
||||||
search_job->vector_count() = nq;
|
search_job->vector_count() = nq;
|
||||||
XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, ascending_reduce,
|
XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, ascending_reduce,
|
||||||
|
@ -315,19 +313,8 @@ XSearchTask::Execute() {
|
||||||
if (spec_k == 0) {
|
if (spec_k == 0) {
|
||||||
LOG_ENGINE_WARNING_ << LogOut("[%s][%ld] Searching in an empty file. file location = %s", "search", 0,
|
LOG_ENGINE_WARNING_ << LogOut("[%s][%ld] Searching in an empty file. file location = %s", "search", 0,
|
||||||
file_->location_.c_str());
|
file_->location_.c_str());
|
||||||
}
|
} else {
|
||||||
|
|
||||||
{
|
|
||||||
std::unique_lock<std::mutex> lock(search_job->mutex());
|
std::unique_lock<std::mutex> lock(search_job->mutex());
|
||||||
|
|
||||||
if (search_job->GetResultIds().size() > spec_k) {
|
|
||||||
if (search_job->GetResultIds().front() == -1) {
|
|
||||||
// initialized results set
|
|
||||||
search_job->GetResultIds().resize(spec_k * nq);
|
|
||||||
search_job->GetResultDistances().resize(spec_k * nq);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, ascending_reduce,
|
XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, ascending_reduce,
|
||||||
search_job->GetResultIds(), search_job->GetResultDistances());
|
search_job->GetResultIds(), search_job->GetResultDistances());
|
||||||
}
|
}
|
||||||
|
|
|
@ -565,11 +565,7 @@ TEST_F(DeleteTest, delete_single_vector) {
|
||||||
milvus::engine::ResultDistances result_distances;
|
milvus::engine::ResultDistances result_distances;
|
||||||
stat = db_->Query(dummy_context_,
|
stat = db_->Query(dummy_context_,
|
||||||
collection_info.collection_id_, tags, topk, json_params, xb, result_ids, result_distances);
|
collection_info.collection_id_, tags, topk, json_params, xb, result_ids, result_distances);
|
||||||
ASSERT_TRUE(result_ids.empty());
|
ASSERT_TRUE(result_ids.empty() || (result_ids[0] == -1));
|
||||||
ASSERT_TRUE(result_distances.empty());
|
|
||||||
// ASSERT_EQ(result_ids[0], -1);
|
|
||||||
// ASSERT_LT(result_distances[0], 1e-4);
|
|
||||||
// ASSERT_EQ(result_distances[0], std::numeric_limits<float>::max());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeleteTest, delete_add_create_index) {
|
TEST_F(DeleteTest, delete_add_create_index) {
|
||||||
|
|
|
@ -61,7 +61,7 @@ class TestDeleteBase:
|
||||||
status, res = connect.search(collection, top_k, vector, params=search_param)
|
status, res = connect.search(collection, top_k, vector, params=search_param)
|
||||||
logging.getLogger().info(res)
|
logging.getLogger().info(res)
|
||||||
assert status.OK()
|
assert status.OK()
|
||||||
assert len(res) == 0
|
assert len(res[0]) == 0
|
||||||
|
|
||||||
def test_delete_vector_multi_same_ids(self, connect, collection, get_simple_index):
|
def test_delete_vector_multi_same_ids(self, connect, collection, get_simple_index):
|
||||||
'''
|
'''
|
||||||
|
@ -83,7 +83,7 @@ class TestDeleteBase:
|
||||||
status, res = connect.search(collection, top_k, [vectors[0]], params=search_param)
|
status, res = connect.search(collection, top_k, [vectors[0]], params=search_param)
|
||||||
logging.getLogger().info(res)
|
logging.getLogger().info(res)
|
||||||
assert status.OK()
|
assert status.OK()
|
||||||
assert len(res) == 0
|
assert len(res[0]) == 0
|
||||||
|
|
||||||
def test_delete_vector_collection_count(self, connect, collection):
|
def test_delete_vector_collection_count(self, connect, collection):
|
||||||
'''
|
'''
|
||||||
|
@ -327,7 +327,7 @@ class TestDeleteIndexedVectors:
|
||||||
status, res = connect.search(collection, top_k, vector, params=search_param)
|
status, res = connect.search(collection, top_k, vector, params=search_param)
|
||||||
logging.getLogger().info(res)
|
logging.getLogger().info(res)
|
||||||
assert status.OK()
|
assert status.OK()
|
||||||
assert len(res) == 0
|
assert len(res[0]) == 0
|
||||||
|
|
||||||
def test_insert_delete_vector(self, connect, collection, get_simple_index):
|
def test_insert_delete_vector(self, connect, collection, get_simple_index):
|
||||||
'''
|
'''
|
||||||
|
@ -399,9 +399,7 @@ class TestDeleteBinary:
|
||||||
status, res = connect.search(jac_collection, top_k, vector, params=search_param)
|
status, res = connect.search(jac_collection, top_k, vector, params=search_param)
|
||||||
logging.getLogger().info(res)
|
logging.getLogger().info(res)
|
||||||
assert status.OK()
|
assert status.OK()
|
||||||
assert len(res) == 0
|
assert len(res[0]) == 0
|
||||||
assert status.OK()
|
|
||||||
assert len(res) == 0
|
|
||||||
|
|
||||||
# TODO: soft delete
|
# TODO: soft delete
|
||||||
def test_delete_vector_collection_count(self, connect, jac_collection):
|
def test_delete_vector_collection_count(self, connect, jac_collection):
|
||||||
|
|
Loading…
Reference in New Issue