refactor knowhere inner unittest for annoy and hnsw, which a… (#1855)

* skip ci, refactor knowhere inner unit tests for annoy and hnsw to remove logging to standard output

Signed-off-by: cmli <chengming.li@zilliz.com>

* skip ci, fix lint error

Signed-off-by: cmli <chengming.li@zilliz.com>

* fix annoy load bugs and add serialize tests for annoy and hnsw

Signed-off-by: cmli <chengming.li@zilliz.com>

* update changelog

Signed-off-by: cmli <chengming.li@zilliz.com>

* update lower bound of annoy search_k param

Signed-off-by: cmli <chengming.li@zilliz.com>

* fix lint error in src/utils/ValidationUtil.cpp

Signed-off-by: cmli <chengming.li@zilliz.com>

Co-authored-by: cmli <chengming.li@zilliz.com>
pull/1868/head
op-hunter 2020-04-03 14:14:04 +08:00 committed by GitHub
parent a57c97ddc3
commit 54337e9361
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 330 additions and 124 deletions

View File

@ -125,6 +125,7 @@ Please mark all change in change log and use the issue from GitHub
- \#1601 External link bug in HTTP doc
- \#1609 Refine Compact function
- \#1808 Building index params check for Annoy
- \#1852 Search index type<Annoy> failed with reason `failed to load index file`
## Feature
- \#216 Add CLI to get server info

View File

@ -57,7 +57,7 @@ IndexAnnoy::Serialize(const Config& config) {
void
IndexAnnoy::Load(const BinarySet& index_binary) {
auto metric_type = index_binary.GetByName("annoy_metric_type");
metric_type_.resize((size_t)metric_type->size + 1);
metric_type_.resize((size_t)metric_type->size);
memcpy(metric_type_.data(), metric_type->data.get(), (size_t)metric_type->size);
auto dim_data = index_binary.GetByName("annoy_dim");
@ -74,7 +74,7 @@ IndexAnnoy::Load(const BinarySet& index_binary) {
auto index_data = index_binary.GetByName("annoy_index_data");
char* p = nullptr;
if (!index_->load_index(index_data->data.get(), index_data->size, &p)) {
if (!index_->load_index(reinterpret_cast<void*>(index_data->data.get()), index_data->size, &p)) {
std::string error_msg(p);
free(p);
KNOWHERE_THROW_MSG(error_msg);

View File

@ -817,7 +817,7 @@ class AnnoyIndexInterface {
virtual bool save(const char* filename, bool prefault=false, char** error=NULL) = 0;
virtual void unload() = 0;
virtual bool load(const char* filename, bool prefault=false, char** error=NULL) = 0;
virtual bool load_index(const unsigned char* index_data, const int64_t& index_size, char** error = NULL) = 0;
virtual bool load_index(void* index_data, const int64_t& index_size, char** error = NULL) = 0;
virtual T get_distance(S i, S j) const = 0;
virtual void get_nns_by_item(S item, size_t n, int search_k, vector<S>* result, vector<T>* distances,
faiss::ConcurrentBitsetPtr bitset = nullptr) const = 0;
@ -1109,7 +1109,7 @@ public:
return true;
}
bool load_index(const unsigned char* index_data, const int64_t& index_size, char** error) {
bool load_index(void* index_data, const int64_t& index_size, char** error) {
if (index_size == -1) {
set_error_from_errno(error, "Unable to get size");
return false;
@ -1123,7 +1123,8 @@ public:
}
_n_nodes = (S)(index_size / _s);
_nodes = (Node*)malloc(_s * _n_nodes);
// _nodes = (Node*)malloc(_s * _n_nodes);
_nodes = (Node*)malloc((size_t)index_size);
memcpy(_nodes, index_data, (size_t)index_size);
// Find the roots by scanning the end of the file and taking the nodes with most descendants
@ -1177,7 +1178,7 @@ public:
}
int64_t get_index_length() const {
return (int64_t)_s * _nodes_size;
return (int64_t)_s * _n_nodes;
}
void* get_index() const {

View File

@ -23,6 +23,184 @@ using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
// Parameterised fixture shared by the Annoy index tests.
// SetUp() generates the data set through DataGen::Generate, creates a fresh
// IndexAnnoy instance and fills in the build/search configuration.
class AnnoyTest : public DataGen, public TestWithParam<std::string> {
 protected:
    milvus::knowhere::Config conf;
    std::shared_ptr<milvus::knowhere::IndexAnnoy> index_ = nullptr;
    std::string IndexType;

    void
    SetUp() override {
        namespace kn = milvus::knowhere;

        IndexType = GetParam();
        Generate(128, 10000, 10);  // dim = 128, nb = 10000, nq = 10
        index_ = std::make_shared<kn::IndexAnnoy>();
        conf = kn::Config{
            {kn::meta::DIM, dim},
            {kn::meta::TOPK, 10},
            {kn::IndexParams::n_trees, 4},
            {kn::IndexParams::search_k, 100},
            {kn::Metric::TYPE, kn::Metric::L2},
        };
    }
};
INSTANTIATE_TEST_CASE_P(AnnoyParameters, AnnoyTest, Values("Annoy"));
// Builds an Annoy index over the generated base vectors, checks the reported
// row count and dimension, then runs a query and validates it with AssertAnns.
TEST_P(AnnoyTest, annoy_basic) {
    // gtest assertion instead of assert(): assert() is compiled out under NDEBUG
    // and, when it does fire, aborts the whole test binary rather than failing
    // only this test.
    ASSERT_FALSE(xb.empty());

    index_->BuildAll(base_dataset, conf);  // Train + Add
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dim(), dim);

    auto result = index_->Query(query_dataset, conf);
    AssertAnns(result, nq, k);

    /*
     * Uncomment to dump ids and distances for visual inspection.
    {
        auto ids = result->Get<int64_t*>(milvus::knowhere::meta::IDS);
        auto dist = result->Get<float*>(milvus::knowhere::meta::DISTANCE);

        std::stringstream ss_id;
        std::stringstream ss_dist;
        for (auto i = 0; i < nq; i++) {
            for (auto j = 0; j < k; ++j) {
                ss_id << *((int64_t*)(ids) + i * k + j) << " ";
                ss_dist << *((float*)(dist) + i * k + j) << " ";
            }
            ss_id << std::endl;
            ss_dist << std::endl;
        }
        std::cout << "id\n" << ss_id.str() << std::endl;
        std::cout << "dist\n" << ss_dist.str() << std::endl;
    }
     */
}
// Checks that blacklisting ids changes the Annoy query result: the same query
// is run before and after SetBlacklist(), and the second result is validated
// with CheckMode::CHECK_NOT_EQUAL.
TEST_P(AnnoyTest, annoy_delete) {
    // gtest assertion instead of assert(): assert() is compiled out under NDEBUG
    // and, when it does fire, aborts the whole test binary rather than failing
    // only this test.
    ASSERT_FALSE(xb.empty());

    index_->BuildAll(base_dataset, conf);  // Train + Add
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dim(), dim);

    // Mark ids [0, nq) as deleted.
    // NOTE(review): presumably these are the ids the un-filtered query returns
    // as top hits -- confirm against DataGen.
    faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
    for (auto i = 0; i < nq; ++i) {
        bitset->set(i);
    }

    auto result1 = index_->Query(query_dataset, conf);
    AssertAnns(result1, nq, k);

    index_->SetBlacklist(bitset);
    auto result2 = index_->Query(query_dataset, conf);
    AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);

    /*
     * Uncomment to compare both result sets by eye.
    auto ids1 = result1->Get<int64_t*>(milvus::knowhere::meta::IDS);
    auto ids2 = result2->Get<int64_t*>(milvus::knowhere::meta::IDS);
    std::cout << std::endl;
    for (int i = 0; i < nq; ++ i) {
        std::cout << "ids1: ";
        for (int j = 0; j < k; ++ j) {
            std::cout << *(ids1 + i * k + j) << " ";
        }
        std::cout << " ids2: ";
        for (int j = 0; j < k; ++ j) {
            std::cout << *(ids2 + i * k + j) << " ";
        }
        std::cout << std::endl;
        for (int j = 0; j < std::min(5, k>>1); ++ j) {
            ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j));
        }
    }
     */
}
// Serializes a built Annoy index, round-trips each of its three binaries
// through a file under /tmp, reloads the index from the bytes read back and
// checks that it still reports the right size and answers queries.
TEST_P(AnnoyTest, annoy_serialize) {
    // Writes `bin` to `filename` and reads the same number of bytes back into a
    // freshly allocated buffer. The buffer is owned by a shared_ptr from the
    // moment it is allocated, so it is released even if the file I/O throws.
    auto round_trip = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin) {
        std::shared_ptr<uint8_t[]> buffer(new uint8_t[bin->size]);
        {
            // Scoped so the writer is destroyed (write and flush) before reading.
            FileIOWriter writer(filename);
            writer(static_cast<void*>(bin->data.get()), bin->size);
        }
        FileIOReader reader(filename);
        reader(buffer.get(), bin->size);
        return buffer;
    };

    // serialize index
    index_->BuildAll(base_dataset, conf);
    auto binaryset = index_->Serialize();

    auto bin_data = binaryset.GetByName("annoy_index_data");
    auto index_data = round_trip("/tmp/annoy_test_data_serialize.bin", bin_data);

    auto bin_metric_type = binaryset.GetByName("annoy_metric_type");
    auto metric_data = round_trip("/tmp/annoy_test_metric_type_serialize.bin", bin_metric_type);

    auto bin_dim = binaryset.GetByName("annoy_dim");
    auto dim_data = round_trip("/tmp/annoy_test_dim_serialize.bin", bin_dim);

    // Rebuild the binary set purely from the bytes that were read back.
    binaryset.clear();
    binaryset.Append("annoy_index_data", index_data, bin_data->size);
    binaryset.Append("annoy_metric_type", metric_data, bin_metric_type->size);
    binaryset.Append("annoy_dim", dim_data, bin_dim->size);

    index_->Load(binaryset);
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dim(), dim);

    auto result = index_->Query(query_dataset, conf);
    AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
}
/*
* faiss style test
* keep it
int
main() {
int64_t d = 64; // dimension
@ -127,95 +305,4 @@ main() {
return 0;
}
/*
class AnnoyTest : public DataGen, public TestWithParam<std::string> {
protected:
void
SetUp() override {
IndexType = GetParam();
std::cout << "IndexType from GetParam() is: " << IndexType << std::endl;
Generate(128, 1000, 5);
index_ = std::make_shared<milvus::knowhere::IndexAnnoy>();
conf = milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, dim},
{milvus::knowhere::meta::TOPK, 1},
{milvus::knowhere::IndexParams::n_trees, 4},
{milvus::knowhere::IndexParams::search_k, 100},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
};
// Init_with_default();
}
protected:
milvus::knowhere::Config conf;
std::shared_ptr<milvus::knowhere::IndexAnnoy> index_ = nullptr;
std::string IndexType;
};
INSTANTIATE_TEST_CASE_P(AnnoyParameters, AnnoyTest, Values(""));
TEST_P(AnnoyTest, annoy_basic) {
assert(!xb.empty());
// index_->Train(base_dataset, conf);
index_->BuildAll(base_dataset, conf);
auto result = index_->Query(query_dataset, conf);
AssertAnns(result, nq, k);
{
auto ids = result->Get<int64_t*>(milvus::knowhere::meta::IDS);
auto dist = result->Get<float*>(milvus::knowhere::meta::DISTANCE);
std::stringstream ss_id;
std::stringstream ss_dist;
for (auto i = 0; i < nq; i++) {
for (auto j = 0; j < k; ++j) {
// ss_id << *ids->data()->GetValues<int64_t>(1, i * k + j) << " ";
// ss_dist << *dists->data()->GetValues<float>(1, i * k + j) << " ";
ss_id << *((int64_t*)(ids) + i * k + j) << " ";
ss_dist << *((float*)(dist) + i * k + j) << " ";
}
ss_id << std::endl;
ss_dist << std::endl;
}
std::cout << "id\n" << ss_id.str() << std::endl;
std::cout << "dist\n" << ss_dist.str() << std::endl;
}
}
TEST_P(AnnoyTest, annoy_delete) {
assert(!xb.empty());
// index_->Train(base_dataset, conf);
index_->BuildAll(base_dataset, conf);
// index_->Add(base_dataset, conf);
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
for (auto i = 0; i < nq; ++ i) {
bitset->set(i);
auto result = index_->Query(query_dataset, conf);
AssertAnns(result, nq, k);
{
auto ids = result->Get<int64_t*>(milvus::knowhere::meta::IDS);
auto dist = result->Get<float*>(milvus::knowhere::meta::DISTANCE);
std::stringstream ss_id;
std::stringstream ss_dist;
for (auto i = 0; i < nq; i++) {
for (auto j = 0; j < k; ++j) {
// ss_id << *ids->data()->GetValues<int64_t>(1, i * k + j) << " ";
// ss_dist << *dists->data()->GetValues<float>(1, i * k + j) << " ";
ss_id << *((int64_t*)(ids) + i * k + j) << " ";
ss_dist << *((float*)(dist) + i * k + j) << " ";
}
ss_id << std::endl;
ss_dist << std::endl;
}
std::cout << "id\n" << ss_id.str() << std::endl;
std::cout << "dist\n" << ss_dist.str() << std::endl;
} }
}
*/

View File

@ -9,12 +9,130 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <gtest/gtest.h>
#include <knowhere/index/vector_index/IndexHNSW.h>
#include <src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h>
#include <iostream>
#include <random>
#include "./utils.h"
#include "knowhere/common/Exception.h"
#include "unittest/utils.h"
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
// Parameterised fixture shared by the HNSW index tests.
// SetUp() generates the data set through DataGen::Generate, creates a fresh
// IndexHNSW instance and fills in the build/search configuration.
// It prints nothing to stdout, in line with the AnnoyTest fixture.
class HNSWTest : public DataGen, public TestWithParam<std::string> {
 protected:
    void
    SetUp() override {
        IndexType = GetParam();
        Generate(64, 10000, 10);  // dim = 64, nb = 10000, nq = 10
        index_ = std::make_shared<milvus::knowhere::IndexHNSW>();
        conf = milvus::knowhere::Config{
            // DIM comes from the generated data set rather than repeating the
            // literal passed to Generate(), so the two cannot drift apart.
            {milvus::knowhere::meta::DIM, dim},
            {milvus::knowhere::meta::TOPK, 10},
            {milvus::knowhere::IndexParams::M, 16},
            {milvus::knowhere::IndexParams::efConstruction, 200},
            {milvus::knowhere::IndexParams::ef, 200},
            {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
        };
    }

 protected:
    milvus::knowhere::Config conf;
    std::shared_ptr<milvus::knowhere::IndexHNSW> index_ = nullptr;
    std::string IndexType;
};
INSTANTIATE_TEST_CASE_P(HNSWParameters, HNSWTest, Values("HNSW"));
// Trains and populates an HNSW index with the generated base vectors, checks
// the reported row count and dimension, then runs a query and validates it
// with AssertAnns.
TEST_P(HNSWTest, HNSW_basic) {
    // gtest assertion instead of assert(): assert() is compiled out under NDEBUG
    // and, when it does fire, aborts the whole test binary rather than failing
    // only this test.
    ASSERT_FALSE(xb.empty());

    index_->Train(base_dataset, conf);
    index_->Add(base_dataset, conf);
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dim(), dim);

    auto result = index_->Query(query_dataset, conf);
    AssertAnns(result, nq, k);
}
// Checks that blacklisting ids changes the HNSW query result: the same query
// is run before and after SetBlacklist(), and the second result is validated
// with CheckMode::CHECK_NOT_EQUAL.
TEST_P(HNSWTest, HNSW_delete) {
    // gtest assertion instead of assert(): assert() is compiled out under NDEBUG
    // and, when it does fire, aborts the whole test binary rather than failing
    // only this test.
    ASSERT_FALSE(xb.empty());

    index_->Train(base_dataset, conf);
    index_->Add(base_dataset, conf);
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dim(), dim);

    // Mark ids [0, nq) as deleted.
    // NOTE(review): presumably these are the ids the un-filtered query returns
    // as top hits -- confirm against DataGen.
    faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
    for (auto i = 0; i < nq; ++i) {
        bitset->set(i);
    }

    auto result1 = index_->Query(query_dataset, conf);
    AssertAnns(result1, nq, k);

    index_->SetBlacklist(bitset);
    auto result2 = index_->Query(query_dataset, conf);
    AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);

    /*
     * Uncomment to compare both result sets by eye.
    auto ids1 = result1->Get<int64_t*>(milvus::knowhere::meta::IDS);
    auto ids2 = result2->Get<int64_t*>(milvus::knowhere::meta::IDS);
    std::cout << std::endl;
    for (int i = 0; i < nq; ++ i) {
        std::cout << "ids1: ";
        for (int j = 0; j < k; ++ j) {
            std::cout << *(ids1 + i * k + j) << " ";
        }
        std::cout << "ids2: ";
        for (int j = 0; j < k; ++ j) {
            std::cout << *(ids2 + i * k + j) << " ";
        }
        std::cout << std::endl;
        for (int j = 0; j < std::min(5, k>>1); ++ j) {
            ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j));
        }
    }
     */
}
// Serializes a built HNSW index, round-trips its "HNSW" binary through a file
// under /tmp, reloads the index from the bytes read back and checks that it
// still reports the right size and answers queries.
TEST_P(HNSWTest, HNSW_serialize) {
    // Writes `bin` to `filename` and reads the same number of bytes back into a
    // freshly allocated buffer. The buffer is owned by a shared_ptr from the
    // moment it is allocated, so it is released even if the file I/O throws.
    auto round_trip = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin) {
        std::shared_ptr<uint8_t[]> buffer(new uint8_t[bin->size]);
        {
            // Scoped so the writer is destroyed before the file is read back.
            FileIOWriter writer(filename);
            writer(static_cast<void*>(bin->data.get()), bin->size);
        }
        FileIOReader reader(filename);
        reader(buffer.get(), bin->size);
        return buffer;
    };

    index_->Train(base_dataset, conf);
    index_->Add(base_dataset, conf);
    auto binaryset = index_->Serialize();

    auto bin = binaryset.GetByName("HNSW");
    auto data = round_trip("/tmp/HNSW_test_serialize.bin", bin);

    // Rebuild the binary set purely from the bytes that were read back.
    binaryset.clear();
    binaryset.Append("HNSW", data, bin->size);

    index_->Load(binaryset);
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dim(), dim);

    auto result = index_->Query(query_dataset, conf);
    AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
}
/*
* faiss style test
* keep it
int
main() {
int64_t d = 64; // dimension
@ -34,33 +152,31 @@ main() {
xb[d * i] += i / 1000.;
ids[i] = i;
}
printf("gen xb and ids done! \n");
// printf("gen xb and ids done! \n");
// srand((unsigned)time(NULL));
auto random_seed = (unsigned)time(NULL);
printf("delete ids: \n");
// printf("delete ids: \n");
for (int i = 0; i < nq; i++) {
auto tmp = rand_r(&random_seed) % nb;
printf("%ld\n", tmp);
// printf("%ld\n", tmp);
// std::cout << "before delete, test result: " << bitset->test(tmp) << std::endl;
bitset->set(tmp);
// std::cout << "after delete, test result: " << bitset->test(tmp) << std::endl;
for (int j = 0; j < d; j++) xq[d * i + j] = xb[d * tmp + j];
// xq[d * i] += i / 1000.;
}
printf("\n");
// printf("\n");
int k = 4;
int m = 16;
int ef = 200;
milvus::knowhere::IndexHNSW index;
milvus::knowhere::DatasetPtr base_dataset = generate_dataset(nb, d, (const void*)xb, ids);
/*
base_dataset->Set(milvus::knowhere::meta::ROWS, nb);
base_dataset->Set(milvus::knowhere::meta::DIM, d);
base_dataset->Set(milvus::knowhere::meta::TENSOR, (const void*)xb);
base_dataset->Set(milvus::knowhere::meta::IDS, (const int64_t*)ids);
*/
// base_dataset->Set(milvus::knowhere::meta::ROWS, nb);
// base_dataset->Set(milvus::knowhere::meta::DIM, d);
// base_dataset->Set(milvus::knowhere::meta::TENSOR, (const void*)xb);
// base_dataset->Set(milvus::knowhere::meta::IDS, (const int64_t*)ids);
milvus::knowhere::Config base_conf{
{milvus::knowhere::meta::DIM, d},
@ -81,27 +197,27 @@ main() {
index.Train(base_dataset, base_conf);
index.Add(base_dataset, base_conf);
printf("------------sanity check----------------\n");
// printf("------------sanity check----------------\n");
{ // sanity check
auto res = index.Query(query_dataset, query_conf);
printf("Query done!\n");
// printf("Query done!\n");
const int64_t* I = res->Get<int64_t*>(milvus::knowhere::meta::IDS);
float* D = res->Get<float*>(milvus::knowhere::meta::DISTANCE);
// float* D = res->Get<float*>(milvus::knowhere::meta::DISTANCE);
printf("I=\n");
for (int i = 0; i < 5; i++) {
for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]);
printf("\n");
}
// printf("I=\n");
// for (int i = 0; i < 5; i++) {
// for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]);
// printf("\n");
// }
printf("D=\n");
for (int i = 0; i < 5; i++) {
for (int j = 0; j < k; j++) printf("%7g ", D[i * k + j]);
printf("\n");
}
// printf("D=\n");
// for (int i = 0; i < 5; i++) {
// for (int j = 0; j < k; j++) printf("%7g ", D[i * k + j]);
// printf("\n");
// }
}
printf("---------------search xq-------------\n");
// printf("---------------search xq-------------\n");
{ // search xq
auto res = index.Query(query_dataset, query_conf);
const int64_t* I = res->Get<int64_t*>(milvus::knowhere::meta::IDS);
@ -132,3 +248,4 @@ main() {
return 0;
}
*/

View File

@ -311,8 +311,8 @@ ValidationUtil::ValidateSearchParams(const milvus::json& search_params,
break;
}
case (int32_t)engine::EngineType::ANNOY: {
auto status = CheckParameterRange(search_params, knowhere::IndexParams::search_k,
std::numeric_limits<int64_t>::min(), std::numeric_limits<int64_t>::max());
auto status = CheckParameterRange(search_params, knowhere::IndexParams::search_k, topk,
std::numeric_limits<int64_t>::max());
if (!status.ok()) {
return status;
}