mirror of https://github.com/milvus-io/milvus.git
Merge remote-tracking branch 'upstream/branch-0.3.1' into branch-0.3.1
Former-commit-id: 194ecd448af946120b58f2de72b47dcc2b38080f
pull/191/head
commit
60647624b2
|
@ -13,6 +13,7 @@ Please mark all change in change log and use the ticket from JIRA.
|
|||
- MS-157 - fix changelog
|
||||
|
||||
## Improvement
|
||||
- MS-156 - Add unittest for merge result functions
|
||||
|
||||
- MS-152 - Delete assert in MySQLMetaImpl and change MySQLConnectionPool impl
|
||||
|
||||
|
|
|
@ -32,8 +32,8 @@ public:
|
|||
using Id2IndexMap = std::unordered_map<size_t, TableFileSchemaPtr>;
|
||||
// Returns map_index_files_ (an Id2IndexMap keyed by size_t) by const reference.
const Id2IndexMap& GetIndexMap() const { return map_index_files_; }
|
||||
|
||||
using Id2ScoreMap = std::vector<std::pair<int64_t, double>>;
|
||||
using ResultSet = std::vector<Id2ScoreMap>;
|
||||
using Id2DistanceMap = std::vector<std::pair<int64_t, double>>;
|
||||
using ResultSet = std::vector<Id2DistanceMap>;
|
||||
// Read-only overload: returns result_ by const reference.
const ResultSet& GetResult() const { return result_; }
|
||||
// Mutable overload: returns result_ by reference so the caller can modify it in place.
ResultSet& GetResult() { return result_; }
|
||||
|
||||
|
|
|
@ -13,104 +13,6 @@ namespace milvus {
|
|||
namespace engine {
|
||||
|
||||
namespace {
|
||||
void ClusterResult(const std::vector<long> &output_ids,
|
||||
const std::vector<float> &output_distence,
|
||||
uint64_t nq,
|
||||
uint64_t topk,
|
||||
SearchContext::ResultSet &result_set) {
|
||||
result_set.clear();
|
||||
result_set.reserve(nq);
|
||||
for (auto i = 0; i < nq; i++) {
|
||||
SearchContext::Id2ScoreMap id_score;
|
||||
id_score.reserve(topk);
|
||||
for (auto k = 0; k < topk; k++) {
|
||||
uint64_t index = i * topk + k;
|
||||
if(output_ids[index] < 0) {
|
||||
continue;
|
||||
}
|
||||
id_score.push_back(std::make_pair(output_ids[index], output_distence[index]));
|
||||
}
|
||||
result_set.emplace_back(id_score);
|
||||
}
|
||||
}
|
||||
|
||||
//Merge score_src into score_target, keeping at most topk of the smallest scores.
//
//Both inputs are expected to be sorted by score in ascending order already.
//On equal scores the item from score_src is placed first.
// - score_src empty: nothing to merge, score_target is left untouched.
// - score_target empty: the contents of score_src are moved into score_target
//   (score_src becomes empty) and cut down to topk items.
// - otherwise: score_target is replaced by the merged list, score_src is unchanged.
void MergeResult(SearchContext::Id2ScoreMap &score_src,
                 SearchContext::Id2ScoreMap &score_target,
                 uint64_t topk) {
    if (score_src.empty()) {
        return;
    }

    if (score_target.empty()) {
        score_target.swap(score_src);
        //the taken-over list must honour the topk bound as well
        if (score_target.size() > topk) {
            score_target.resize(topk);
        }
        return;
    }

    const size_t src_count = score_src.size();
    const size_t target_count = score_target.size();
    const size_t total_count = src_count + target_count;

    SearchContext::Id2ScoreMap score_merged;
    //never more than the available items, even when topk is very large
    score_merged.reserve(topk < total_count ? topk : total_count);

    size_t src_index = 0;
    size_t target_index = 0;

    //compare score, put smallest score to score_merged one by one;
    //checking the size up front makes topk == 0 produce an empty result
    while (score_merged.size() < topk && src_index < src_count && target_index < target_count) {
        const auto &src_pair = score_src[src_index];
        const auto &target_pair = score_target[target_index];
        if (src_pair.second > target_pair.second) {
            score_merged.push_back(target_pair);
            ++target_index;
        } else {
            score_merged.push_back(src_pair);
            ++src_index;
        }
    }

    //one side is exhausted: fill up from the other one until topk is reached
    //(at most one of the two loops below does any work)
    for (; src_index < src_count && score_merged.size() < topk; ++src_index) {
        score_merged.push_back(score_src[src_index]);
    }
    for (; target_index < target_count && score_merged.size() < topk; ++target_index) {
        score_merged.push_back(score_target[target_index]);
    }

    score_target.swap(score_merged);
}
|
||||
|
||||
//Fold result_src into result_target row by row via MergeResult.
//
//An empty result_target simply takes over the contents of result_src.
//Otherwise both sets must have the same number of rows; if they do not,
//an error is logged and nothing is merged.
void TopkResult(SearchContext::ResultSet &result_src,
                uint64_t topk,
                SearchContext::ResultSet &result_target) {
    //nothing collected so far: take the source over wholesale
    if (result_target.empty()) {
        result_target.swap(result_src);
        return;
    }

    const size_t row_count = result_target.size();
    if (result_src.size() != row_count) {
        SERVER_LOG_ERROR << "Invalid result set";
        return;
    }

    //walk both sets in lockstep; the sizes are known to be equal here
    auto src_row = result_src.begin();
    for (auto &target_row : result_target) {
        MergeResult(*src_row, target_row, topk);
        ++src_row;
    }
}
|
||||
|
||||
void CollectDurationMetrics(int index_type, double total_time) {
|
||||
switch(index_type) {
|
||||
case meta::TableFileSchema::RAW: {
|
||||
|
@ -165,11 +67,11 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
|
|||
//step 3: cluster result
|
||||
SearchContext::ResultSet result_set;
|
||||
auto spec_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk();
|
||||
ClusterResult(output_ids, output_distence, context->nq(), spec_k, result_set);
|
||||
SearchTask::ClusterResult(output_ids, output_distence, context->nq(), spec_k, result_set);
|
||||
rc.Record("cluster result");
|
||||
|
||||
//step 4: pick up topk result
|
||||
TopkResult(result_set, inner_k, context->GetResult());
|
||||
SearchTask::TopkResult(result_set, inner_k, context->GetResult());
|
||||
rc.Record("reduce topk");
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
|
@ -191,6 +93,119 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//Regroup the flat search output into one id/distance list per query.
//
//output_ids and output_distence are read as nq consecutive rows of topk entries,
//i.e. entry k of query i is taken from index i * topk + k. Entries whose id is
//negative are skipped, so a row may end up with fewer than topk items.
//
//result_set is cleared before anything else, so it never keeps rows from an
//earlier call, not even when this call fails.
//
//Returns Status::Error if either array does not hold exactly nq * topk entries,
//Status::OK otherwise.
Status SearchTask::ClusterResult(const std::vector<long> &output_ids,
                                 const std::vector<float> &output_distence,
                                 uint64_t nq,
                                 uint64_t topk,
                                 SearchContext::ResultSet &result_set) {
    result_set.clear();

    const uint64_t total = nq * topk;
    if (output_ids.size() != total || output_distence.size() != total) {
        std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) +
                          " distance array size: " + std::to_string(output_distence.size());
        SERVER_LOG_ERROR << msg;
        return Status::Error(msg);
    }

    result_set.reserve(nq);
    //unsigned counters: nq and topk are uint64_t, an int counter would mix signedness
    for (uint64_t i = 0; i < nq; i++) {
        SearchContext::Id2DistanceMap id_distance;
        id_distance.reserve(topk);
        for (uint64_t k = 0; k < topk; k++) {
            const uint64_t index = i * topk + k;
            if (output_ids[index] < 0) {
                continue;
            }
            id_distance.push_back(std::make_pair(output_ids[index], output_distence[index]));
        }
        //the row is not used again, hand its buffer over instead of copying it
        result_set.emplace_back(std::move(id_distance));
    }

    return Status::OK();
}
|
||||
|
||||
//Merge distance_src into distance_target, keeping at most topk of the smallest distances.
//
//Both inputs are expected to be sorted by distance in ascending order already.
//On equal distances the item from distance_src is placed first.
// - distance_src empty: a warning is logged, distance_target is left untouched.
// - distance_target empty: the contents of distance_src are moved into
//   distance_target (distance_src becomes empty) and cut down to topk items.
// - otherwise: distance_target is replaced by the merged list, distance_src is unchanged.
//
//Always returns Status::OK.
Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
                               SearchContext::Id2DistanceMap &distance_target,
                               uint64_t topk) {
    if (distance_src.empty()) {
        SERVER_LOG_WARNING << "Empty distance source array";
        return Status::OK();
    }

    if (distance_target.empty()) {
        distance_target.swap(distance_src);
        //the taken-over list must honour the topk bound as well
        if (distance_target.size() > topk) {
            distance_target.resize(topk);
        }
        return Status::OK();
    }

    const size_t src_count = distance_src.size();
    const size_t target_count = distance_target.size();
    const size_t total_count = src_count + target_count;

    SearchContext::Id2DistanceMap distance_merged;
    //never more than the available items, even when topk is very large
    distance_merged.reserve(topk < total_count ? topk : total_count);

    size_t src_index = 0;
    size_t target_index = 0;

    //compare distance, put smallest distance to distance_merged one by one;
    //checking the size up front makes topk == 0 produce an empty result
    while (distance_merged.size() < topk && src_index < src_count && target_index < target_count) {
        const auto &src_pair = distance_src[src_index];
        const auto &target_pair = distance_target[target_index];
        if (src_pair.second > target_pair.second) {
            distance_merged.push_back(target_pair);
            ++target_index;
        } else {
            distance_merged.push_back(src_pair);
            ++src_index;
        }
    }

    //one side is exhausted: fill up from the other one until topk is reached
    //(at most one of the two loops below does any work)
    for (; src_index < src_count && distance_merged.size() < topk; ++src_index) {
        distance_merged.push_back(distance_src[src_index]);
    }
    for (; target_index < target_count && distance_merged.size() < topk; ++target_index) {
        distance_merged.push_back(distance_target[target_index]);
    }

    distance_target.swap(distance_merged);

    return Status::OK();
}
|
||||
|
||||
//Fold result_src into result_target row by row via MergeResult.
//
//An empty result_target simply takes over the contents of result_src
//(result_src ends up with whatever result_target held, i.e. nothing).
//Otherwise both sets must have the same number of rows.
//
//Returns Status::Error when the row counts differ, the first non-OK status
//reported by MergeResult if a row fails to merge, and Status::OK otherwise.
Status SearchTask::TopkResult(SearchContext::ResultSet &result_src,
                              uint64_t topk,
                              SearchContext::ResultSet &result_target) {
    if (result_target.empty()) {
        result_target.swap(result_src);
        return Status::OK();
    }

    if (result_src.size() != result_target.size()) {
        std::string msg = "Invalid result set size";
        SERVER_LOG_ERROR << msg;
        return Status::Error(msg);
    }

    for (size_t i = 0; i < result_src.size(); i++) {
        SearchContext::Id2DistanceMap &score_src = result_src[i];
        SearchContext::Id2DistanceMap &score_target = result_target[i];
        //do not swallow a failure of the per-row merge
        Status status = SearchTask::MergeResult(score_src, score_target, topk);
        if (!status.ok()) {
            return status;
        }
    }

    return Status::OK();
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,20 @@ public:
|
|||
|
||||
virtual std::shared_ptr<IScheduleTask> Execute() override;
|
||||
|
||||
static Status ClusterResult(const std::vector<long> &output_ids,
|
||||
const std::vector<float> &output_distence,
|
||||
uint64_t nq,
|
||||
uint64_t topk,
|
||||
SearchContext::ResultSet &result_set);
|
||||
|
||||
static Status MergeResult(SearchContext::Id2DistanceMap &distance_src,
|
||||
SearchContext::Id2DistanceMap &distance_target,
|
||||
uint64_t topk);
|
||||
|
||||
static Status TopkResult(SearchContext::ResultSet &result_src,
|
||||
uint64_t topk,
|
||||
SearchContext::ResultSet &result_target);
|
||||
|
||||
public:
|
||||
size_t index_id_ = 0;
|
||||
int index_type_ = 0; //for metrics
|
||||
|
|
|
@ -21,7 +21,7 @@ namespace {
|
|||
static const std::string TABLE_NAME = "test_group";
|
||||
static constexpr int64_t TABLE_DIM = 256;
|
||||
static constexpr int64_t VECTOR_COUNT = 250000;
|
||||
static constexpr int64_t INSERT_LOOP = 100000;
|
||||
static constexpr int64_t INSERT_LOOP = 10000;
|
||||
|
||||
engine::meta::TableSchema BuildTableSchema() {
|
||||
engine::meta::TableSchema table_info;
|
||||
|
|
|
@ -21,7 +21,7 @@ namespace {
|
|||
static const std::string TABLE_NAME = "test_group";
|
||||
static constexpr int64_t TABLE_DIM = 256;
|
||||
static constexpr int64_t VECTOR_COUNT = 250000;
|
||||
static constexpr int64_t INSERT_LOOP = 100000;
|
||||
static constexpr int64_t INSERT_LOOP = 10000;
|
||||
|
||||
engine::meta::TableSchema BuildTableSchema() {
|
||||
engine::meta::TableSchema table_info;
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
// Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
// Proprietary and confidential.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "db/scheduler/task/SearchTask.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
using namespace zilliz::milvus;
|
||||
|
||||
namespace {
|
||||
|
||||
static constexpr uint64_t NQ = 15;
|
||||
static constexpr uint64_t TOP_K = 64;
|
||||
|
||||
// Fill output_ids / output_distence with nq rows of top_k random entries each.
//
// Both vectors are resized to nq * top_k; previous contents are discarded.
// For column j of every row the id is drawn from [0, 100000) and the distance
// is j + drand48(), so it lies between j and j + 1 and every row is sorted in
// non-decreasing order of distance.
void BuildResult(uint64_t nq,
                 uint64_t top_k,
                 std::vector<long> &output_ids,
                 std::vector<float> &output_distence) {
    const uint64_t total = nq * top_k;
    output_ids.assign(total, 0);
    output_distence.assign(total, 0.0f);

    for (uint64_t i = 0; i < nq; i++) {
        for (uint64_t j = 0; j < top_k; j++) {
            const uint64_t index = i * top_k + j;
            output_ids[index] = static_cast<long>(drand48() * 100000);
            // explicit narrowing: the sum is computed in double, stored as float
            output_distence[index] = static_cast<float>(j + drand48());
        }
    }
}
|
||||
|
||||
// Verify that target is a valid merge of src_1 and src_2:
//  1. target is sorted by distance in non-decreasing order;
//  2. every (id, distance) item of target can be found, with the same distance
//     (within float epsilon), in src_1 or in src_2.
//
// An id that occurs in both sources is accepted if it matches either of them.
// If an id is repeated inside one source only its first occurrence is kept
// (std::map::insert does not overwrite).
void CheckResult(const engine::SearchContext::Id2DistanceMap& src_1,
                 const engine::SearchContext::Id2DistanceMap& src_2,
                 const engine::SearchContext::Id2DistanceMap& target) {
    // start at 1 and look back: "size() - 1" would wrap around for an empty target
    for(uint64_t i = 1; i < target.size(); i++) {
        ASSERT_LE(target[i - 1].second, target[i].second);
    }

    using ID2DistMap = std::map<long, float>;
    ID2DistMap src_map_1, src_map_2;
    for(const auto& pair : src_1) {
        src_map_1.insert(pair);
    }
    for(const auto& pair : src_2) {
        src_map_2.insert(pair);
    }

    const float tolerance = std::numeric_limits<float>::epsilon();
    for(const auto& pair : target) {
        // one lookup per map, no operator[] (which would insert missing keys)
        const auto it_1 = src_map_1.find(pair.first);
        const auto it_2 = src_map_2.find(pair.first);
        ASSERT_TRUE(it_1 != src_map_1.end() || it_2 != src_map_2.end());

        const bool match_1 = it_1 != src_map_1.end() && fabs(pair.second - it_1->second) < tolerance;
        const bool match_2 = it_2 != src_map_2.end() && fabs(pair.second - it_2->second) < tolerance;
        ASSERT_TRUE(match_1 || match_2);
    }
}
|
||||
|
||||
}
|
||||
|
||||
// Covers SearchTask::ClusterResult and SearchTask::TopkResult:
// invalid input, the "empty target" take-over path, and merging sources whose
// rows are shorter or longer than TOP_K.
TEST(DBSearchTest, TOPK_TEST) {
    std::vector<long> target_ids;
    std::vector<float> target_distence;
    engine::SearchContext::ResultSet src_result;

    // empty arrays do not match NQ * TOP_K entries
    auto status = engine::SearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
    ASSERT_FALSE(status.ok());
    ASSERT_TRUE(src_result.empty());

    BuildResult(NQ, TOP_K, target_ids, target_distence);
    status = engine::SearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
    ASSERT_TRUE(status.ok());
    ASSERT_EQ(src_result.size(), NQ);

    // empty source and empty target
    engine::SearchContext::ResultSet target_result;
    status = engine::SearchTask::TopkResult(target_result, TOP_K, target_result);
    ASSERT_TRUE(status.ok());

    // row count mismatch: empty source against a filled target
    status = engine::SearchTask::TopkResult(target_result, TOP_K, src_result);
    ASSERT_FALSE(status.ok());

    // empty target takes over the whole source
    status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result);
    ASSERT_TRUE(status.ok());
    ASSERT_TRUE(src_result.empty());
    ASSERT_EQ(target_result.size(), NQ);

    // source rows shorter than TOP_K
    std::vector<long> src_ids;
    std::vector<float> src_distence;
    uint64_t wrong_topk = TOP_K - 10;
    BuildResult(NQ, wrong_topk, src_ids, src_distence);

    status = engine::SearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
    ASSERT_TRUE(status.ok());

    status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result);
    ASSERT_TRUE(status.ok());
    for(uint64_t i = 0; i < NQ; i++) {
        ASSERT_EQ(target_result[i].size(), TOP_K);
    }

    // source rows longer than TOP_K
    wrong_topk = TOP_K + 10;
    BuildResult(NQ, wrong_topk, src_ids, src_distence);

    // rebuild src_result from the new arrays; without this step the merge below
    // would silently reuse the rows of the previous scenario
    status = engine::SearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
    ASSERT_TRUE(status.ok());

    status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result);
    ASSERT_TRUE(status.ok());
    for(uint64_t i = 0; i < NQ; i++) {
        ASSERT_EQ(target_result[i].size(), TOP_K);
    }
}
|
||||
|
||||
// Exercises SearchTask::MergeResult on a single query row:
// a topk smaller than the combined size, an empty target, a topk larger than
// the combined size, and the same large topk with source and target swapped.
TEST(DBSearchTest, MERGE_TEST) {
    const uint64_t src_count = 5;
    const uint64_t target_count = 8;

    // random data is generated for the source first, then for the target
    std::vector<long> src_ids;
    std::vector<float> src_distence;
    BuildResult(1, src_count, src_ids, src_distence);

    std::vector<long> target_ids;
    std::vector<float> target_distence;
    BuildResult(1, target_count, target_ids, target_distence);

    engine::SearchContext::ResultSet src_result;
    engine::SearchContext::ResultSet target_result;
    auto status = engine::SearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result);
    ASSERT_TRUE(status.ok());
    status = engine::SearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result);
    ASSERT_TRUE(status.ok());

    // the untouched reference rows; every case below works on copies of them
    const engine::SearchContext::Id2DistanceMap &src_row = src_result[0];
    const engine::SearchContext::Id2DistanceMap &target_row = target_result[0];

    {
        // topk below the combined size: result is cut at topk
        engine::SearchContext::Id2DistanceMap src = src_row;
        engine::SearchContext::Id2DistanceMap target = target_row;
        status = engine::SearchTask::MergeResult(src, target, 10);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), 10);
        CheckResult(src_row, target_row, target);
    }

    {
        // empty target: the source is handed over completely
        engine::SearchContext::Id2DistanceMap src = src_row;
        engine::SearchContext::Id2DistanceMap target;
        status = engine::SearchTask::MergeResult(src, target, 10);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), src_count);
        ASSERT_TRUE(src.empty());
        CheckResult(src_row, target_row, target);
    }

    {
        // topk above the combined size: everything is kept
        engine::SearchContext::Id2DistanceMap src = src_row;
        engine::SearchContext::Id2DistanceMap target = target_row;
        status = engine::SearchTask::MergeResult(src, target, 30);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), src_count + target_count);
        CheckResult(src_row, target_row, target);
    }

    {
        // same as above with the roles of source and target exchanged
        engine::SearchContext::Id2DistanceMap target = src_row;
        engine::SearchContext::Id2DistanceMap src = target_row;
        status = engine::SearchTask::MergeResult(src, target, 30);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), src_count + target_count);
        CheckResult(src_row, target_row, target);
    }
}
|
Loading…
Reference in New Issue