Mirror of https://github.com/milvus-io/milvus.git
commit f49d23badd (parent bc6f233709)

@@ -29,6 +29,7 @@ Please mark all change in change log and use the ticket from JIRA.
- \#149 - Improve large query optimizer pass
- \#156 - Not return error when search_resources and index_build_device set cpu
- \#159 - Change the configuration name from 'use_gpu_threshold' to 'gpu_search_threshold'
- \#168 - Improve result reduce

## Task

@@ -67,15 +67,16 @@ class DB {
virtual Status
Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
QueryResults& results) = 0;
ResultIds& result_ids, ResultDistances& result_distances) = 0;

virtual Status
Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::DatesT& dates, QueryResults& results) = 0;
const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) = 0;

virtual Status
Query(const std::string& table_id, const std::vector<std::string>& file_ids, uint64_t k, uint64_t nq,
uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results) = 0;
uint64_t nprobe, const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) = 0;

virtual Status
Size(uint64_t& result) = 0;

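Editor's note: the `DB::Query` overloads above replace the single nested `QueryResults` out-parameter with two flat out-parameters, `ResultIds` and `ResultDistances`. A minimal caller-side sketch of the new signature follows; it assumes the Milvus headers from this commit, and the table name, top-k, and query data are hypothetical placeholders.

```cpp
// Sketch only: compiles inside the Milvus source tree of this commit.
#include "db/DB.h"     // milvus::engine::DB with the new Query overloads
#include "db/Types.h"  // milvus::engine::ResultIds / ResultDistances

#include <vector>

void
SearchExample(milvus::engine::DB& db, const std::vector<float>& query_vectors,
              uint64_t nq, uint64_t topk, uint64_t nprobe) {
    milvus::engine::ResultIds result_ids;              // flat id array, query-major
    milvus::engine::ResultDistances result_distances;  // matching distances, same order
    auto status = db.Query("example_table", topk, nq, nprobe, query_vectors.data(),
                           result_ids, result_distances);
    if (!status.ok()) {
        return;  // handle the error in real code
    }
    // Hit j of query i sits at index i * k_returned + j,
    // where k_returned = result_ids.size() / nq.
}
```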
@@ -336,20 +336,20 @@ DBImpl::DropIndex(const std::string& table_id) {
Status
DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
QueryResults& results) {
ResultIds& result_ids, ResultDistances& result_distances) {
if (shutting_down_.load(std::memory_order_acquire)) {
return Status(DB_ERROR, "Milsvus server is shutdown!");
}

meta::DatesT dates = {utils::GetDate()};
Status result = Query(table_id, k, nq, nprobe, vectors, dates, results);
Status result = Query(table_id, k, nq, nprobe, vectors, dates, result_ids, result_distances);

return result;
}

Status
DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::DatesT& dates, QueryResults& results) {
const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) {
if (shutting_down_.load(std::memory_order_acquire)) {
return Status(DB_ERROR, "Milsvus server is shutdown!");
}

@@ -372,14 +372,15 @@ DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t npr
}

cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, results);
status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, result_ids, result_distances);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query
return status;
}

Status
DBImpl::Query(const std::string& table_id, const std::vector<std::string>& file_ids, uint64_t k, uint64_t nq,
uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results) {
uint64_t nprobe, const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) {
if (shutting_down_.load(std::memory_order_acquire)) {
return Status(DB_ERROR, "Milsvus server is shutdown!");
}

@@ -413,7 +414,7 @@ DBImpl::Query(const std::string& table_id, const std::vector<std::string>& file_
}

cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, results);
status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, result_ids, result_distances);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query
return status;
}

@@ -432,7 +433,7 @@ DBImpl::Size(uint64_t& result) {
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Status
DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files, uint64_t k, uint64_t nq,
uint64_t nprobe, const float* vectors, QueryResults& results) {
uint64_t nprobe, const float* vectors, ResultIds& result_ids, ResultDistances& result_distances) {
server::CollectQueryMetrics metrics(nq);

TimeRecorder rc("");

@@ -453,7 +454,8 @@ DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSchema& fi
}

// step 3: construct results
results = job->GetResult();
result_ids = job->GetResultIds();
result_distances = job->GetResultDistances();
rc.ElapseFromBegin("Engine query totally cost");

return Status::OK();

@@ -91,15 +91,16 @@ class DBImpl : public DB {
Status
Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
QueryResults& results) override;
ResultIds& result_ids, ResultDistances& result_distances) override;

Status
Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::DatesT& dates, QueryResults& results) override;
const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) override;

Status
Query(const std::string& table_id, const std::vector<std::string>& file_ids, uint64_t k, uint64_t nq,
uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results) override;
uint64_t nprobe, const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) override;

Status
Size(uint64_t& result) override;

@@ -107,7 +108,7 @@ class DBImpl : public DB {
private:
Status
QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files, uint64_t k, uint64_t nq,
uint64_t nprobe, const float* vectors, QueryResults& results);
uint64_t nprobe, const float* vectors, ResultIds& result_ids, ResultDistances& result_distances);

void
BackgroundTimerTask();

@@ -19,6 +19,7 @@
#include "db/engine/ExecutionEngine.h"

#include <faiss/Index.h>
#include <stdint.h>
#include <utility>
#include <vector>

@@ -26,12 +27,13 @@
namespace milvus {
namespace engine {

typedef int64_t IDNumber;
using IDNumber = faiss::Index::idx_t;

typedef IDNumber* IDNumberPtr;
typedef std::vector<IDNumber> IDNumbers;

typedef std::vector<std::pair<IDNumber, double>> QueryResult;
typedef std::vector<QueryResult> QueryResults;
typedef std::vector<faiss::Index::idx_t> ResultIds;
typedef std::vector<faiss::Index::distance_t> ResultDistances;

struct TableIndex {
int32_t engine_type_ = (int)EngineType::FAISS_IDMAP;

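Editor's note: the typedefs above swap the row-oriented `QueryResult`/`QueryResults` (one vector of id/distance pairs per query) for two flat, query-major arrays. A small standalone sketch of the indexing convention, using plain `int64_t`/`float` in place of the faiss element types, is shown here:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-ins for the new typedefs; the real ones use
// faiss::Index::idx_t and faiss::Index::distance_t.
using ResultIds = std::vector<int64_t>;
using ResultDistances = std::vector<float>;

int main() {
    const size_t nq = 2, topk = 3;
    // After a search, both arrays hold nq * topk entries, query-major:
    // entries [i * topk, (i + 1) * topk) belong to query i, best hit first.
    ResultIds ids = {11, 12, 13, 21, 22, 23};
    ResultDistances distances = {0.1f, 0.4f, 0.9f, 0.2f, 0.3f, 0.7f};

    for (size_t i = 0; i < nq; ++i) {
        for (size_t j = 0; j < topk; ++j) {
            size_t idx = i * topk + j;  // flat index of hit j of query i
            std::cout << "query " << i << " hit " << j << ": id=" << ids[idx]
                      << " distance=" << distances[idx] << "\n";
        }
    }
    return 0;
}
```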
@@ -53,9 +53,14 @@ SearchJob::SearchDone(size_t index_id) {
SERVER_LOG_DEBUG << "SearchJob " << id() << " finish index file: " << index_id;
}

ResultSet&
SearchJob::GetResult() {
return result_;
ResultIds&
SearchJob::GetResultIds() {
return result_ids_;
}

ResultDistances&
SearchJob::GetResultDistances() {
return result_distances_;
}

Status&

@@ -29,6 +29,7 @@
#include <vector>

#include "Job.h"
#include "db/Types.h"
#include "db/meta/MetaTypes.h"

namespace milvus {

@@ -37,9 +38,9 @@ namespace scheduler {
using engine::meta::TableFileSchemaPtr;

using Id2IndexMap = std::unordered_map<size_t, TableFileSchemaPtr>;
using IdDistPair = std::pair<int64_t, double>;
using Id2DistVec = std::vector<IdDistPair>;
using ResultSet = std::vector<Id2DistVec>;
using ResultIds = engine::ResultIds;
using ResultDistances = engine::ResultDistances;

class SearchJob : public Job {
public:

@@ -55,8 +56,11 @@ class SearchJob : public Job {
void
SearchDone(size_t index_id);

ResultSet&
GetResult();
ResultIds&
GetResultIds();

ResultDistances&
GetResultDistances();

Status&
GetStatus();

@@ -104,7 +108,8 @@ class SearchJob : public Job {
Id2IndexMap index_files_;
// TODO: column-base better ?
ResultSet result_;
ResultIds result_ids_;
ResultDistances result_distances_;
Status status_;

std::mutex mutex_;

@@ -222,7 +222,7 @@ XSearchTask::Execute() {
{
std::unique_lock<std::mutex> lock(search_job->mutex());
XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2,
search_job->GetResult());
search_job->GetResultIds(), search_job->GetResultDistances());
}

span = rc.RecordSection(hdr + ", reduce topk");

@@ -243,71 +243,75 @@ XSearchTask::Execute() {
}

void
XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance,
uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending,
scheduler::ResultSet& result) {
if (result.empty()) {
result.resize(nq);
XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& src_ids, const std::vector<float>& src_distances,
size_t src_k, size_t nq, size_t topk, bool ascending, scheduler::ResultIds& tar_ids,
scheduler::ResultDistances& tar_distances) {
if (src_ids.empty()) {
return;
}

if (tar_ids.empty()) {
tar_ids = src_ids;
tar_distances = src_distances;
return;
}

size_t tar_k = tar_ids.size() / nq;
size_t buf_k = std::min(topk, src_k + tar_k);

scheduler::ResultIds buf_ids(nq * buf_k, -1);
scheduler::ResultDistances buf_distances(nq * buf_k, 0.0);

for (uint64_t i = 0; i < nq; i++) {
scheduler::Id2DistVec result_buf;
auto& result_i = result[i];
size_t buf_k_j = 0, src_k_j = 0, tar_k_j = 0;
size_t buf_idx, src_idx, tar_idx;

if (result[i].empty()) {
result_buf.resize(input_k, scheduler::IdDistPair(-1, 0.0));
uint64_t input_k_multi_i = topk * i;
for (auto k = 0; k < input_k; ++k) {
uint64_t idx = input_k_multi_i + k;
auto& result_buf_item = result_buf[k];
result_buf_item.first = input_ids[idx];
result_buf_item.second = input_distance[idx];
size_t buf_k_multi_i = buf_k * i;
size_t src_k_multi_i = topk * i;
size_t tar_k_multi_i = tar_k * i;

while (buf_k_j < buf_k && src_k_j < src_k && tar_k_j < tar_k) {
src_idx = src_k_multi_i + src_k_j;
tar_idx = tar_k_multi_i + tar_k_j;
buf_idx = buf_k_multi_i + buf_k_j;

if ((ascending && src_distances[src_idx] < tar_distances[tar_idx]) ||
(!ascending && src_distances[src_idx] > tar_distances[tar_idx])) {
buf_ids[buf_idx] = src_ids[src_idx];
buf_distances[buf_idx] = src_distances[src_idx];
src_k_j++;
} else {
buf_ids[buf_idx] = tar_ids[tar_idx];
buf_distances[buf_idx] = tar_distances[tar_idx];
tar_k_j++;
}
} else {
size_t tar_size = result_i.size();
uint64_t output_k = std::min(topk, input_k + tar_size);
result_buf.resize(output_k, scheduler::IdDistPair(-1, 0.0));
size_t buf_k = 0, src_k = 0, tar_k = 0;
uint64_t src_idx;
uint64_t input_k_multi_i = topk * i;
while (buf_k < output_k && src_k < input_k && tar_k < tar_size) {
src_idx = input_k_multi_i + src_k;
auto& result_buf_item = result_buf[buf_k];
auto& result_item = result_i[tar_k];
if ((ascending && input_distance[src_idx] < result_item.second) ||
(!ascending && input_distance[src_idx] > result_item.second)) {
result_buf_item.first = input_ids[src_idx];
result_buf_item.second = input_distance[src_idx];
src_k++;
} else {
result_buf_item = result_item;
tar_k++;
buf_k_j++;
}

if (buf_k_j < buf_k) {
if (src_k_j < src_k) {
while (buf_k_j < buf_k && src_k_j < src_k) {
buf_idx = buf_k_multi_i + buf_k_j;
src_idx = src_k_multi_i + src_k_j;
buf_ids[buf_idx] = src_ids[src_idx];
buf_distances[buf_idx] = src_distances[src_idx];
src_k_j++;
buf_k_j++;
}
buf_k++;
}

if (buf_k < output_k) {
if (src_k < input_k) {
while (buf_k < output_k && src_k < input_k) {
src_idx = input_k_multi_i + src_k;
auto& result_buf_item = result_buf[buf_k];
result_buf_item.first = input_ids[src_idx];
result_buf_item.second = input_distance[src_idx];
src_k++;
buf_k++;
}
} else {
while (buf_k < output_k && tar_k < tar_size) {
result_buf[buf_k] = result_i[tar_k];
tar_k++;
buf_k++;
}
} else {
while (buf_k_j < buf_k && tar_k_j < tar_k) {
buf_idx = buf_k_multi_i + buf_k_j;
tar_idx = tar_k_multi_i + tar_k_j;
buf_ids[buf_idx] = tar_ids[tar_idx];
buf_distances[buf_idx] = tar_distances[tar_idx];
tar_k_j++;
buf_k_j++;
}
}
}

result_i.swap(result_buf);
}
tar_ids.swap(buf_ids);
tar_distances.swap(buf_distances);
}

// void

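Editor's note: the rewritten `MergeTopkToResultSet` above folds each per-file top-k block into the accumulated flat buffers with a per-query two-pointer merge. Below is a compact standalone sketch of the same idea; it strides the incoming block by its own `src_k` (the production code strides it by `topk`) and omits the early-exit and tail-copy branches, so it is an illustration rather than a drop-in replacement.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Merge a new batch of per-query top-k results (src) into the accumulated
// buffers (tar), keeping at most `topk` entries per query. Both inputs are
// flat, query-major, and sorted per query (ascending = smaller distance is
// better, e.g. L2).
void MergeTopk(const std::vector<int64_t>& src_ids, const std::vector<float>& src_dist,
               size_t src_k, size_t nq, size_t topk, bool ascending,
               std::vector<int64_t>& tar_ids, std::vector<float>& tar_dist) {
    if (src_ids.empty()) return;
    if (tar_ids.empty()) {  // first batch: just adopt it
        tar_ids = src_ids;
        tar_dist = src_dist;
        return;
    }

    size_t tar_k = tar_ids.size() / nq;
    size_t buf_k = std::min(topk, src_k + tar_k);
    std::vector<int64_t> buf_ids(nq * buf_k, -1);
    std::vector<float> buf_dist(nq * buf_k, 0.0f);

    for (size_t i = 0; i < nq; ++i) {
        size_t s = 0, t = 0;  // cursors into the src and tar rows of query i
        for (size_t b = 0; b < buf_k; ++b) {
            bool take_src;
            if (s >= src_k) {
                take_src = false;
            } else if (t >= tar_k) {
                take_src = true;
            } else {
                float sd = src_dist[i * src_k + s];
                float td = tar_dist[i * tar_k + t];
                take_src = ascending ? (sd < td) : (sd > td);
            }
            if (take_src) {
                buf_ids[i * buf_k + b] = src_ids[i * src_k + s];
                buf_dist[i * buf_k + b] = src_dist[i * src_k + s];
                ++s;
            } else {
                buf_ids[i * buf_k + b] = tar_ids[i * tar_k + t];
                buf_dist[i * buf_k + b] = tar_dist[i * tar_k + t];
                ++t;
            }
        }
    }
    tar_ids.swap(buf_ids);
    tar_dist.swap(buf_dist);
}
```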
@@ -39,8 +39,9 @@ class XSearchTask : public Task {
public:
static void
MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance,
uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result);
MergeTopkToResultSet(const std::vector<int64_t>& src_ids, const std::vector<float>& src_distances, uint64_t src_k,
uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultIds& tar_ids,
scheduler::ResultDistances& tar_distances);

// static void
// MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k,

@@ -637,7 +637,8 @@ SearchTask::OnExecute() {
rc.RecordSection("prepare vector data");

// step 6: search vectors
engine::QueryResults results;
engine::ResultIds result_ids;
engine::ResultDistances result_distances;
auto record_count = (uint64_t)search_param_->query_record_array().size();

#ifdef MILVUS_ENABLE_PROFILING

@@ -647,11 +648,11 @@ SearchTask::OnExecute() {
#endif

if (file_id_array_.empty()) {
status =
DBWrapper::DB()->Query(table_name_, (size_t)top_k, record_count, nprobe, vec_f.data(), dates, results);
status = DBWrapper::DB()->Query(table_name_, (size_t)top_k, record_count, nprobe, vec_f.data(), dates,
result_ids, result_distances);
} else {
status = DBWrapper::DB()->Query(table_name_, file_id_array_, (size_t)top_k, record_count, nprobe,
vec_f.data(), dates, results);
vec_f.data(), dates, result_ids, result_distances);
}

#ifdef MILVUS_ENABLE_PROFILING

@@ -663,23 +664,20 @@ SearchTask::OnExecute() {
return status;
}

if (results.empty()) {
if (result_ids.empty()) {
return Status::OK(); // empty table
}

if (results.size() != record_count) {
std::string msg = "Search " + std::to_string(record_count) + " vectors but only return " +
std::to_string(results.size()) + " results";
return Status(SERVER_ILLEGAL_SEARCH_RESULT, msg);
}
size_t result_k = result_ids.size() / record_count;

// step 7: construct result array
for (auto& result : results) {
for (size_t i = 0; i < record_count; i++) {
::milvus::grpc::TopKQueryResult* topk_query_result = topk_result_list->add_topk_query_result();
for (auto& pair : result) {
for (size_t j = 0; j < result_k; j++) {
::milvus::grpc::QueryResult* grpc_result = topk_query_result->add_query_result_arrays();
grpc_result->set_id(pair.first);
grpc_result->set_distance(pair.second);
size_t idx = i * result_k + j;
grpc_result->set_id(result_ids[idx]);
grpc_result->set_distance(result_distances[idx]);
}
}

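Editor's note: with the nested `QueryResults` gone, step 7 above derives `result_k` from the array size and regroups the flat arrays per query while filling the gRPC response. A standalone sketch of that regrouping, with plain STL containers standing in for the protobuf types and a hypothetical `GroupByQuery` helper name:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Regroup flat search results into one (id, distance) list per query,
// mirroring the "construct result array" loop above. The protobuf result
// objects are replaced by plain vectors for the sketch.
std::vector<std::vector<std::pair<int64_t, float>>>
GroupByQuery(const std::vector<int64_t>& result_ids,
             const std::vector<float>& result_distances, size_t record_count) {
    std::vector<std::vector<std::pair<int64_t, float>>> per_query(record_count);
    if (record_count == 0 || result_ids.empty()) {
        return per_query;  // empty table or no results
    }
    size_t result_k = result_ids.size() / record_count;  // hits returned per query
    for (size_t i = 0; i < record_count; ++i) {
        per_query[i].reserve(result_k);
        for (size_t j = 0; j < result_k; ++j) {
            size_t idx = i * result_k + j;
            per_query[i].emplace_back(result_ids[idx], result_distances[idx]);
        }
    }
    return per_query;
}
```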
@@ -175,7 +175,8 @@ TEST_F(DBTest, DB_TEST) {
BuildVectors(qb, qxb);

std::thread search([&]() {
milvus::engine::QueryResults results;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
int k = 10;
std::this_thread::sleep_for(std::chrono::seconds(2));

@@ -190,17 +191,17 @@ TEST_F(DBTest, DB_TEST) {
prev_count = count;

START_TIMER;
stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results);
stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), result_ids, result_distances);
ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
STOP_TIMER(ss.str());

ASSERT_TRUE(stat.ok());
for (auto k = 0; k < qb; ++k) {
ASSERT_EQ(results[k][0].first, target_ids[k]);
for (auto i = 0; i < qb; ++i) {
ASSERT_EQ(result_ids[i*k], target_ids[i]);
ss.str("");
ss << "Result [" << k << "]:";
for (auto result : results[k]) {
ss << result.first << " ";
ss << "Result [" << i << "]:";
for (auto t = 0; t < k; t++) {
ss << result_ids[i * k + t] << " ";
}
/* LOG(DEBUG) << ss.str(); */
}

@@ -284,16 +285,18 @@ TEST_F(DBTest, SEARCH_TEST) {
db_->CreateIndex(TABLE_NAME, index); // wait until build index finish

{
milvus::engine::QueryResults results;
stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results);
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}

{//search by specify index file
milvus::engine::meta::DatesT dates;
std::vector<std::string> file_ids = {"1", "2", "3", "4", "5", "6"};
milvus::engine::QueryResults results;
stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, results);
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}

@@ -303,22 +306,25 @@ TEST_F(DBTest, SEARCH_TEST) {
db_->CreateIndex(TABLE_NAME, index); // wait until build index finish

{
milvus::engine::QueryResults results;
stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results);
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}

{
milvus::engine::QueryResults large_nq_results;
stat = db_->Query(TABLE_NAME, k, 200, 10, xq.data(), large_nq_results);
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(TABLE_NAME, k, 200, 10, xq.data(), result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}

{//search by specify index file
milvus::engine::meta::DatesT dates;
std::vector<std::string> file_ids = {"1", "2", "3", "4", "5", "6"};
milvus::engine::QueryResults results;
stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, results);
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}

@@ -391,11 +397,12 @@ TEST_F(DBTest, SHUTDOWN_TEST) {
ASSERT_FALSE(stat.ok());

milvus::engine::meta::DatesT dates;
milvus::engine::QueryResults results;
stat = db_->Query(table_info.table_id_, 1, 1, 1, nullptr, dates, results);
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(table_info.table_id_, 1, 1, 1, nullptr, dates, result_ids, result_distances);
ASSERT_FALSE(stat.ok());
std::vector<std::string> file_ids;
stat = db_->Query(table_info.table_id_, file_ids, 1, 1, 1, nullptr, dates, results);
stat = db_->Query(table_info.table_id_, file_ids, 1, 1, 1, nullptr, dates, result_ids, result_distances);
ASSERT_FALSE(stat.ok());

stat = db_->DeleteTable(table_info.table_id_, dates);

@@ -81,7 +81,8 @@ TEST_F(MySqlDBTest, DB_TEST) {
ASSERT_EQ(target_ids.size(), qb);

std::thread search([&]() {
milvus::engine::QueryResults results;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
int k = 10;
std::this_thread::sleep_for(std::chrono::seconds(5));

@@ -96,25 +97,25 @@ TEST_F(MySqlDBTest, DB_TEST) {
prev_count = count;

START_TIMER;
stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results);
stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), result_ids, result_distances);
ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
STOP_TIMER(ss.str());

ASSERT_TRUE(stat.ok());
for (auto k = 0; k < qb; ++k) {
for (auto i = 0; i < qb; ++i) {
// std::cout << results[k][0].first << " " << target_ids[k] << std::endl;
// ASSERT_EQ(results[k][0].first, target_ids[k]);
bool exists = false;
for (auto &result : results[k]) {
if (result.first == target_ids[k]) {
for (auto t = 0; t < k; t++) {
if (result_ids[i * k + t] == target_ids[i]) {
exists = true;
}
}
ASSERT_TRUE(exists);
ss.str("");
ss << "Result [" << k << "]:";
for (auto result : results[k]) {
ss << result.first << " ";
ss << "Result [" << i << "]:";
for (auto t = 0; t < k; t++) {
ss << result_ids[i * k + t] << " ";
}
/* LOG(DEBUG) << ss.str(); */
}

@@ -188,8 +189,9 @@ TEST_F(MySqlDBTest, SEARCH_TEST) {
sleep(2); // wait until build index finish

milvus::engine::QueryResults results;
stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results);
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}

@@ -259,10 +259,11 @@ TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) {
int topk = 10, nprobe = 10;
for (auto& pair : search_vectors) {
auto& search = pair.second;
milvus::engine::QueryResults results;
stat = db_->Query(GetTableName(), topk, 1, nprobe, search.data(), results);
ASSERT_EQ(results[0][0].first, pair.first);
ASSERT_LT(results[0][0].second, 1e-4);
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(GetTableName(), topk, 1, nprobe, search.data(), result_ids, result_distances);
ASSERT_EQ(result_ids[0], pair.first);
ASSERT_LT(result_distances[0], 1e-4);
}
}

@@ -314,7 +315,8 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) {
BuildVectors(qb, qxb);

std::thread search([&]() {
milvus::engine::QueryResults results;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
int k = 10;
std::this_thread::sleep_for(std::chrono::seconds(2));

@@ -329,17 +331,17 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) {
prev_count = count;

START_TIMER;
stat = db_->Query(GetTableName(), k, qb, 10, qxb.data(), results);
stat = db_->Query(GetTableName(), k, qb, 10, qxb.data(), result_ids, result_distances);
ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
STOP_TIMER(ss.str());

ASSERT_TRUE(stat.ok());
for (auto k = 0; k < qb; ++k) {
ASSERT_EQ(results[k][0].first, target_ids[k]);
for (auto i = 0; i < qb; ++i) {
ASSERT_EQ(result_ids[i * k], target_ids[i]);
ss.str("");
ss << "Result [" << k << "]:";
for (auto result : results[k]) {
ss << result.first << " ";
ss << "Result [" << i << "]:";
for (auto t = 0; t < k; t++) {
ss << result_ids[i * k + t] << " ";
}
/* LOG(DEBUG) << ss.str(); */
}

@@ -85,8 +85,10 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1,
uint64_t topk,
uint64_t nq,
bool ascending,
const milvus::scheduler::ResultSet& result) {
ASSERT_EQ(result.size(), nq);
const ms::ResultIds& result_ids,
const ms::ResultDistances& result_distances) {
ASSERT_EQ(result_ids.size(), nq * topk);
ASSERT_EQ(result_distances.size(), nq * topk);
ASSERT_EQ(input_ids_1.size(), input_distance_1.size());
ASSERT_EQ(input_ids_2.size(), input_distance_2.size());

@@ -111,15 +113,16 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1,
++iter;
}

uint64_t n = std::min(topk, result[i].size());
uint64_t n = std::min(topk, result_ids.size() / nq);
for (uint64_t j = 0; j < n; j++) {
if (result[i][j].first < 0) {
uint64_t idx = i * n + j;
if (result_ids[idx] < 0) {
continue;
}
if (src_vec[j] != result[i][j].second) {
std::cout << src_vec[j] << " " << result[i][j].second << std::endl;
if (src_vec[j] != result_distances[idx]) {
std::cout << src_vec[j] << " " << result_distances[idx] << std::endl;
}
ASSERT_TRUE(src_vec[j] == result[i][j].second);
ASSERT_TRUE(src_vec[j] == result_distances[idx]);
}
}
}

@@ -130,12 +133,13 @@ void
MergeTopkToResultSetTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) {
std::vector<int64_t> ids1, ids2;
std::vector<float> dist1, dist2;
ms::ResultSet result;
ms::ResultIds result_ids;
ms::ResultDistances result_distances;
BuildResult(ids1, dist1, topk_1, topk, nq, ascending);
BuildResult(ids2, dist2, topk_2, topk, nq, ascending);
ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, topk_1, nq, topk, ascending, result);
ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, topk_2, nq, topk, ascending, result);
CheckTopkResult(ids1, dist1, ids2, dist2, topk, nq, ascending, result);
ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, topk_1, nq, topk, ascending, result_ids, result_distances);
ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, topk_2, nq, topk, ascending, result_ids, result_distances);
CheckTopkResult(ids1, dist1, ids2, dist2, topk, nq, ascending, result_ids, result_distances);
}

TEST(DBSearchTest, MERGE_RESULT_SET_TEST) {

@@ -222,9 +226,9 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) {
int32_t index_file_num = 478; /* sift1B dataset, index files num */
bool ascending = true;

std::vector<int32_t> thread_vec = {4, 8};
std::vector<int32_t> nq_vec = {1, 10, 100};
std::vector<int32_t> topk_vec = {1, 4, 16, 64};
std::vector<int32_t> thread_vec = {4};
std::vector<int32_t> nq_vec = {1000};
std::vector<int32_t> topk_vec = {64};
int32_t NQ = nq_vec[nq_vec.size() - 1];
int32_t TOPK = topk_vec[topk_vec.size() - 1];

@@ -247,7 +251,8 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) {
for (int32_t nq : nq_vec) {
for (int32_t top_k : topk_vec) {
ms::ResultSet final_result, final_result_2, final_result_3;
ms::ResultIds final_result_ids, final_result_ids_2, final_result_ids_3;
ms::ResultDistances final_result_distances, final_result_distances_2, final_result_distances_3;

std::vector<std::vector<int64_t>> id_vec_1(index_file_num);
std::vector<std::vector<float>> dist_vec_1(index_file_num);

@@ -268,8 +273,10 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) {
nq,
top_k,
ascending,
final_result);
ASSERT_EQ(final_result.size(), nq);
final_result_ids,
final_result_distances);
ASSERT_EQ(final_result_ids.size(), nq * top_k);
ASSERT_EQ(final_result_distances.size(), nq * top_k);
}

rc1.RecordSection("reduce done");

@@ -75,7 +75,8 @@ TEST_F(MetricTest, METRIC_TEST) {
}

std::thread search([&]() {
milvus::engine::QueryResults results;
// milvus::engine::ResultIds result_ids;
// milvus::engine::ResultDistances result_distances;
int k = 10;
std::this_thread::sleep_for(std::chrono::seconds(2));

@@ -90,7 +91,7 @@ TEST_F(MetricTest, METRIC_TEST) {
prev_count = count;

START_TIMER;
// stat = db_->Query(group_name, k, qb, qxb, results);
// stat = db_->Query(group_name, k, qb, qxb, result_ids, result_distances);
ss << "Search " << j << " With Size " << (float) (count * group_dim * sizeof(float)) / (1024 * 1024)
<< " M";