mirror of https://github.com/milvus-io/milvus.git
Merge branch 'search' into jinhai
Former-commit-id: 6de85a6c8412b0f9f98aace5a76a6f5ffb27497a
pull/191/head
commit 393b3873b7
@@ -56,6 +56,7 @@ Status DBImpl::add_vectors(const std::string& group_id_,
     }
 }
 
+// TODO(XUPENG): add search range based on time
 Status DBImpl::search(const std::string &group_id, size_t k, size_t nq,
                       const float *vectors, QueryResults &results) {
     meta::DatePartionedGroupFilesSchema files;
@@ -63,75 +64,92 @@ Status DBImpl::search(const std::string &group_id, size_t k, size_t nq,
     auto status = _pMeta->files_to_search(group_id, partition, files);
     if (!status.ok()) { return status; }
 
-    // TODO: optimized
     meta::GroupFilesSchema index_files;
     meta::GroupFilesSchema raw_files;
     for (auto &day_files : files) {
         for (auto &file : day_files.second) {
-            file.file_type == meta::GroupFileSchema::RAW ?
-                raw_files.push_back(file) :
-                index_files.push_back(file);
+            file.file_type == meta::GroupFileSchema::INDEX ?
+                index_files.push_back(file) : raw_files.push_back(file);
         }
     }
-    int dim = raw_files[0].dimension;
 
+    int dim = 0;
+    if (!index_files.empty()) {
+        dim = index_files[0].dimension;
+    } else if (!raw_files.empty()) {
+        dim = raw_files[0].dimension;
+    } else {
+        return Status::OK();
+    }
 
-    // merge raw files
+    // merge raw files and build flat index.
     faiss::Index *index(faiss::index_factory(dim, "IDMap,Flat"));
 
     for (auto &file : raw_files) {
         auto file_index = dynamic_cast<faiss::IndexIDMap *>(faiss::read_index(file.location.c_str()));
-        index->add_with_ids(file_index->ntotal, dynamic_cast<faiss::IndexFlat *>(file_index->index)->xb.data(),
+        index->add_with_ids(file_index->ntotal,
+                            dynamic_cast<faiss::IndexFlat *>(file_index->index)->xb.data(),
                             file_index->id_map.data());
     }
-    float *xb = dynamic_cast<faiss::IndexFlat *>(index)->xb.data();
-    int64_t *ids = dynamic_cast<faiss::IndexIDMap *>(index)->id_map.data();
-    long totoal = index->ntotal;
 
-    std::vector<float> distence;
-    std::vector<long> result_ids;
     {
-        // allocate memory
+        // [{ids, distence}, ...]
+        using SearchResult = std::pair<std::vector<long>, std::vector<float>>;
+        std::vector<SearchResult> batchresult(nq); // allocate nq cells.
 
+        auto cluster = [&](long *nns, float *dis) -> void {
+            for (int i = 0; i < nq; ++i) {
+                auto f_begin = batchresult[i].first.cbegin();
+                auto s_begin = batchresult[i].second.cbegin();
+                batchresult[i].first.insert(f_begin, nns + i * k, nns + i * k + k);
+                batchresult[i].second.insert(s_begin, dis + i * k, dis + i * k + k);
+            }
+        };
 
+        // Allocate Memory
         float *output_distence;
         long *output_ids;
-        output_distence = (float *) malloc(k * sizeof(float));
-        output_ids = (long *) malloc(k * sizeof(long));
-        // build and search in raw file
-        // TODO: HardCode
-        auto opd = std::make_shared<Operand>();
-        opd->index_type = "IDMap,Flat";
-        IndexBuilderPtr builder = GetIndexBuilder(opd);
-        auto index = builder->build_all(totoal, xb, ids);
+        output_distence = (float *) malloc(k * nq * sizeof(float));
+        output_ids = (long *) malloc(k * nq * sizeof(long));
+        memset(output_distence, 0, k * nq * sizeof(float));
+        memset(output_ids, 0, k * nq * sizeof(long));
 
+        // search in raw file
        index->search(nq, vectors, k, output_distence, output_ids);
-        distence.insert(distence.begin(), output_distence, output_distence + k);
-        result_ids.insert(result_ids.begin(), output_ids, output_ids + k);
-        memset(output_distence, 0, k * sizeof(float));
-        memset(output_ids, 0, k * sizeof(long));
+        cluster(output_ids, output_distence); // cluster to each query
+        memset(output_distence, 0, k * nq * sizeof(float));
+        memset(output_ids, 0, k * nq * sizeof(long));
 
-        // search in index file
+        // Search in index file
         for (auto &file : index_files) {
             auto index = read_index(file.location.c_str());
             index->search(nq, vectors, k, output_distence, output_ids);
-            distence.insert(distence.begin(), output_distence, output_distence + k);
-            result_ids.insert(result_ids.begin(), output_ids, output_ids + k);
-            memset(output_distence, 0, k * sizeof(float));
-            memset(output_ids, 0, k * sizeof(long));
+            cluster(output_ids, output_distence); // cluster to each query
+            memset(output_distence, 0, k * nq * sizeof(float));
+            memset(output_ids, 0, k * nq * sizeof(long));
         }
 
-        // TopK
-        TopK(distence.data(), distence.size(), k, output_distence, output_ids);
-        distence.clear();
-        result_ids.clear();
-        distence.insert(distence.begin(), output_distence, output_distence + k);
-        result_ids.insert(result_ids.begin(), output_ids, output_ids + k);
+        auto cluster_topk = [&]() -> void {
+            QueryResult res;
+            for (auto &result_pair : batchresult) {
+                auto &dis = result_pair.second;
+                auto &nns = result_pair.first;
+                TopK(dis.data(), dis.size(), k, output_distence, output_ids);
+                for (int i = 0; i < k; ++i) {
+                    res.emplace_back(nns[output_ids[i]]); // mapping
+                }
+                results.push_back(res); // append to result list
+                res.clear();
+            }
+        };
+        cluster_topk();
 
-        // free
         free(output_distence);
        free(output_ids);
     }
 
+    if (results.empty()) {
+        return Status::NotFound("Group " + group_id + ", search result not found!");
+    }
     return Status::OK();
 }
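The rewritten DBImpl::search above works in three passes: split the group's files into raw_files and index_files, fold every raw file into one temporary "IDMap,Flat" faiss index, then run the same nq queries against that merged index and against each prebuilt index file. The cluster lambda appends each batch of k hits to a per-query bucket in batchresult, and cluster_topk re-ranks every bucket so only the k best ids per query reach the caller; the old code sized its output buffers for a single query, so this per-query bookkeeping is the substance of the change. The snippet below is a minimal standalone sketch of that final merge step, assuming distances where smaller is better; the names Candidates and merge_topk and the use of std::partial_sort are illustrative only and stand in for the project's own TopK helper.

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

// One query's accumulated candidates: parallel arrays of ids and distances.
using Candidates = std::pair<std::vector<long>, std::vector<float>>;

// Keep only the k best (smallest-distance) ids out of everything the
// merged raw-file index and the individual index files returned.
std::vector<long> merge_topk(const Candidates &cand, size_t k) {
    const auto &ids = cand.first;
    const auto &dis = cand.second;
    std::vector<size_t> order(dis.size());
    std::iota(order.begin(), order.end(), 0);          // 0, 1, 2, ...
    size_t keep = std::min(k, order.size());
    std::partial_sort(order.begin(), order.begin() + keep, order.end(),
                      [&](size_t a, size_t b) { return dis[a] < dis[b]; });
    std::vector<long> result;
    result.reserve(keep);
    for (size_t i = 0; i < keep; ++i) {
        result.push_back(ids[order[i]]);               // map position back to real id
    }
    return result;
}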
@@ -4,8 +4,12 @@
 // Proprietary and confidential.
 ////////////////////////////////////////////////////////////////////////////////
 #include <gtest/gtest.h>
+#include <faiss/IndexFlat.h>
+#include <faiss/MetaIndexes.h>
+#include <faiss/AutoTune.h>
 
 #include "db/DB.h"
+#include "faiss/Index.h"
 
 using namespace zilliz::vecwise;
 
@@ -51,12 +55,90 @@ TEST(DBTest, DB_TEST) {
     stat = db->add_vectors(group_name, 1, vec_f.data(), vector_ids);
     ASSERT_STATS(stat);
 
-    engine::QueryResults results;
-    std::vector<float> vec_s = vec_f;
-    stat = db->search(group_name, 1, 1, vec_f.data(), results);
-
-    ASSERT_STATS(stat);
-    ASSERT_EQ(results.size(), 1);
-    ASSERT_EQ(results[0][0], vector_ids[0]);
+    //engine::QueryResults results;
+    //std::vector<float> vec_s = vec_f;
+    //stat = db->search(group_name, 1, 1, vec_f.data(), results);
+    //ASSERT_STATS(stat);
+    //ASSERT_EQ(results.size(), 1);
+    //ASSERT_EQ(results[0][0], vector_ids[0]);
 
     delete db;
 }
+
+TEST(SearchTest, DB_TEST) {
+    static const std::string group_name = "test_group";
+    static const int group_dim = 256;
+
+    engine::Options opt;
+    opt.meta.backend_uri = "http://127.0.0.1";
+    opt.meta.path = "/tmp/search_test";
+    opt.index_trigger_size = 100000 * group_dim;
+    opt.memory_sync_interval = 1;
+    opt.merge_trigger_number = 1;
+
+    engine::DB* db = nullptr;
+    engine::DB::Open(opt, &db);
+    ASSERT_TRUE(db != nullptr);
+
+    engine::meta::GroupSchema group_info;
+    group_info.dimension = group_dim;
+    group_info.group_id = group_name;
+    engine::Status stat = db->add_group(group_info);
+    //ASSERT_STATS(stat);
+
+    engine::meta::GroupSchema group_info_get;
+    group_info_get.group_id = group_name;
+    stat = db->get_group(group_info_get);
+    ASSERT_STATS(stat);
+    ASSERT_EQ(group_info_get.dimension, group_dim);
+
+    // prepare raw data
+    size_t nb = 25000;
+    size_t nq = 10;
+    size_t k = 5;
+    std::vector<float> xb(nb*group_dim);
+    std::vector<float> xq(nq*group_dim);
+    std::vector<long> ids(nb);
+
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::uniform_real_distribution<> dis_xt(-1.0, 1.0);
+    for (size_t i = 0; i < nb*group_dim; i++) {
+        xb[i] = dis_xt(gen);
+        if (i < nb){
+            ids[i] = i;
+        }
+    }
+    for (size_t i = 0; i < nq*group_dim; i++) {
+        xq[i] = dis_xt(gen);
+    }
+
+    // result data
+    //std::vector<long> nns_gt(k*nq);
+    std::vector<long> nns(k*nq); // nns = nearst neg search
+    //std::vector<float> dis_gt(k*nq);
+    std::vector<float> dis(k*nq);
+
+    // prepare ground-truth
+    //faiss::Index* index_gt(faiss::index_factory(group_dim, "IDMap,Flat"));
+    //index_gt->add_with_ids(nb, xb.data(), ids.data());
+    //index_gt->search(nq, xq.data(), 1, dis_gt.data(), nns_gt.data());
+
+    // insert data
+    const int batch_size = 100;
+    for (int j = 0; j < nb / batch_size; ++j) {
+        stat = db->add_vectors(group_name, batch_size, xb.data()+batch_size*j*group_dim, ids);
+        ASSERT_STATS(stat);
+    }
+
+    //sleep(10); // wait until build index finish
+
+    engine::QueryResults results;
+    stat = db->search(group_name, k, nq, xq.data(), results);
+    ASSERT_STATS(stat);
+
+    // TODO(linxj): add groundTruth assert
+
+    delete db;
+}
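SearchTest ends with a TODO about asserting against ground truth, and the faiss index_gt that would provide it is still commented out. A rough sketch of such a check follows; recall_at_k is a hypothetical helper, and the usage comment assumes index_gt is searched with the same k and that the ids returned by db->search line up with the ids fed to index_gt (add_vectors fills the id vector itself, so that mapping still needs to be verified).

#include <algorithm>
#include <cstddef>
#include <vector>

// Fraction of ground-truth neighbours that the engine also returned
// for one query (recall@k); `truth` and `found` each hold k ids.
static double recall_at_k(const std::vector<long> &truth,
                          const std::vector<long> &found) {
    size_t hits = 0;
    for (long id : truth) {
        if (std::find(found.begin(), found.end(), id) != found.end()) {
            ++hits;
        }
    }
    return truth.empty() ? 0.0 : static_cast<double>(hits) / truth.size();
}

// Hypothetical use once index_gt is re-enabled and searched with k neighbours:
//   for (size_t q = 0; q < nq; ++q) {
//       std::vector<long> truth(nns_gt.begin() + q * k, nns_gt.begin() + (q + 1) * k);
//       std::vector<long> found(results[q].begin(), results[q].end());
//       ASSERT_GT(recall_at_k(truth, found), 0.8);  // threshold is illustrative
//   }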
@@ -92,3 +92,35 @@ TEST(build_test, Wrapper_Test) {
     delete[] result_ids;
 }
 
+TEST(search_test, Wrapper_Test) {
+    const int dim = 256;
+
+    size_t nb = 25000;
+    size_t nq = 100;
+    size_t k = 100;
+    std::vector<float> xb(nb*dim);
+    std::vector<float> xq(nq*dim);
+    std::vector<long> ids(nb*dim);
+
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::uniform_real_distribution<> dis_xt(-1.0, 1.0);
+    for (size_t i = 0; i < nb*dim; i++) {
+        xb[i] = dis_xt(gen);
+        ids[i] = i;
+    }
+    for (size_t i = 0; i < nq*dim; i++) {
+        xq[i] = dis_xt(gen);
+    }
+
+    // result data
+    std::vector<long> nns_gt(nq*k); // nns = nearst neg search
+    std::vector<long> nns(nq*k);
+    std::vector<float> dis_gt(nq*k);
+    std::vector<float> dis(nq*k);
+    faiss::Index* index_gt(faiss::index_factory(dim, "IDMap,Flat"));
+    index_gt->add_with_ids(nb, xb.data(), ids.data());
+    index_gt->search(nq, xq.data(), 10, dis_gt.data(), nns_gt.data());
+    std::cout << "data: " << nns_gt[0];
+
+}
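The new search_test stops after printing the first ground-truth id and never runs the wrapper against the same data. A possible continuation, reusing the Operand / GetIndexBuilder / build_all calls that appear in the removed DBImpl::search code above, might look like the sketch below; the signatures are only inferred from this diff, so treat it as a guess at intent rather than the final test.

    // Hypothetical continuation of search_test (signatures inferred from the
    // removed DBImpl::search code; verify against the wrapper headers first).
    auto opd = std::make_shared<Operand>();
    opd->index_type = "IDMap,Flat";
    IndexBuilderPtr builder = GetIndexBuilder(opd);
    auto index = builder->build_all(nb, xb.data(), ids.data());

    // Search the wrapper-built index with the same queries ...
    index->search(nq, xq.data(), k, dis.data(), nns.data());

    // ... and compare against the faiss ground truth, e.g. on the first query:
    // ASSERT_EQ(nns[0], nns_gt[0]);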