index search unittest

Former-commit-id: 9632c0668ce71d07bba6f13e9ee46df4e1af6b38
pull/191/head
groot 2019-04-21 15:55:48 +08:00
parent 885b95e2ea
commit 2b035e54a6
1 changed files with 60 additions and 24 deletions

View File

@ -23,36 +23,72 @@ TEST(operand_test, Wrapper_Test) {
TEST(build_test, Wrapper_Test) {
// dimension of the vectors to index
int d = 64;
int d = 3;
// make a set of nt training vectors in the unit cube
size_t nt = 10000;
// a reasonable number of cetroids to index nb vectors
int ncentroids = 16;
std::random_device rd;
std::mt19937 gen(rd());
std::vector<float> xb;
std::vector<long> ids;
//prepare train data
std::uniform_real_distribution<> dis_xt(-1.0, 1.0);
std::vector<float> xt(nt * d);
for (size_t i = 0; i < nt * d; i++) {
xt[i] = dis_xt(gen);
}
//train the index
auto opd = std::make_shared<Operand>();
opd->index_type = "IVF16,Flat";
opd->d = d;
opd->ncent = ncentroids;
IndexBuilderPtr index_builder_1 = GetIndexBuilder(opd);
auto index_1 = index_builder_1->build_all(0, xb, ids, nt, xt);
ASSERT_TRUE(index_1 != nullptr);
// size of the database we plan to index
size_t nb = 100000;
// make a set of nt training vectors in the unit cube
size_t nt = 150000;
// a reasonable number of cetroids to index nb vectors
int ncentroids = 25;
srand48(35); // seed
std::vector<float> xb(nb * d);
for (size_t i = 0; i < nb * d; i++) {
xb[i] = drand48();
}
std::vector<long> ids(nb);
//prepare raw data
xb.resize(nb);
ids.resize(nb);
for (size_t i = 0; i < nb; i++) {
ids[i] = drand48();
xb[i] = dis_xt(gen);
ids[i] = i;
}
index_1->add_with_ids(nb, xb.data(), ids.data());
//search in first quadrant
int nq = 1, k = 10;
std::vector<float> xq = {0.5, 0.5, 0.5};
float* result_dists = new float[k];
long* result_ids = new long[k];
index_1->search(nq, xq.data(), k, result_dists, result_ids);
for(int i = 0; i < k; i++) {
if(result_ids[i] < 0) {
ASSERT_TRUE(false);
break;
}
long id = result_ids[i];
std::cout << "No." << id << " [" << xb[id*3] << ", " << xb[id*3 + 1] << ", "
<< xb[id*3 + 2] <<"] distance = " << result_dists[i] << std::endl;
//makesure result vector is in first quadrant
ASSERT_TRUE(xb[id*3] > 0.0);
ASSERT_TRUE(xb[id*3 + 1] > 0.0);
ASSERT_TRUE(xb[id*3 + 2] > 0.0);
}
std::vector<float> xt(nt * d);
for (size_t i = 0; i < nt * d; i++) {
xt[i] = drand48();
}
auto opd = std::make_shared<Operand>();
IndexBuilderPtr index_builder_1 = GetIndexBuilder(opd);
auto index_1 = index_builder_1->build_all(nb, xb, ids, nt, xt);
delete[] result_dists;
delete[] result_ids;
}