mirror of https://github.com/milvus-io/milvus.git
refactor knowhere inner unittest for annoy and hnsw, which a… (#1855)
* skip ci, refactor knowhere inner unittest for annoy and hnsw, which aims to remove logs from std output Signed-off-by: cmli <chengming.li@zilliz.com> * skip ci, fix lint error Signed-off-by: cmli <chengming.li@zilliz.com> * fix annoy load bugs and add serialize tests for annoy and hnsw Signed-off-by: cmli <chengming.li@zilliz.com> * update changelog Signed-off-by: cmli <chengming.li@zilliz.com> * update lower bound of annoy search_k param Signed-off-by: cmli <chengming.li@zilliz.com> * lint error (git add src/utils/ValidationUtil.cpp) Signed-off-by: cmli <chengming.li@zilliz.com> Co-authored-by: cmli <chengming.li@zilliz.com> pull/1868/head
parent
a57c97ddc3
commit
54337e9361
|
@ -125,6 +125,7 @@ Please mark all change in change log and use the issue from GitHub
|
|||
- \#1601 External link bug in HTTP doc
|
||||
- \#1609 Refine Compact function
|
||||
- \#1808 Building index params check for Annoy
|
||||
- \#1852 Search index type<Annoy> failed with reason `failed to load index file`
|
||||
|
||||
## Feature
|
||||
- \#216 Add CLI to get server info
|
||||
|
|
|
@ -57,7 +57,7 @@ IndexAnnoy::Serialize(const Config& config) {
|
|||
void
|
||||
IndexAnnoy::Load(const BinarySet& index_binary) {
|
||||
auto metric_type = index_binary.GetByName("annoy_metric_type");
|
||||
metric_type_.resize((size_t)metric_type->size + 1);
|
||||
metric_type_.resize((size_t)metric_type->size);
|
||||
memcpy(metric_type_.data(), metric_type->data.get(), (size_t)metric_type->size);
|
||||
|
||||
auto dim_data = index_binary.GetByName("annoy_dim");
|
||||
|
@ -74,7 +74,7 @@ IndexAnnoy::Load(const BinarySet& index_binary) {
|
|||
|
||||
auto index_data = index_binary.GetByName("annoy_index_data");
|
||||
char* p = nullptr;
|
||||
if (!index_->load_index(index_data->data.get(), index_data->size, &p)) {
|
||||
if (!index_->load_index(reinterpret_cast<void*>(index_data->data.get()), index_data->size, &p)) {
|
||||
std::string error_msg(p);
|
||||
free(p);
|
||||
KNOWHERE_THROW_MSG(error_msg);
|
||||
|
|
|
@ -817,7 +817,7 @@ class AnnoyIndexInterface {
|
|||
virtual bool save(const char* filename, bool prefault=false, char** error=NULL) = 0;
|
||||
virtual void unload() = 0;
|
||||
virtual bool load(const char* filename, bool prefault=false, char** error=NULL) = 0;
|
||||
virtual bool load_index(const unsigned char* index_data, const int64_t& index_size, char** error = NULL) = 0;
|
||||
virtual bool load_index(void* index_data, const int64_t& index_size, char** error = NULL) = 0;
|
||||
virtual T get_distance(S i, S j) const = 0;
|
||||
virtual void get_nns_by_item(S item, size_t n, int search_k, vector<S>* result, vector<T>* distances,
|
||||
faiss::ConcurrentBitsetPtr bitset = nullptr) const = 0;
|
||||
|
@ -1109,7 +1109,7 @@ public:
|
|||
return true;
|
||||
}
|
||||
|
||||
bool load_index(const unsigned char* index_data, const int64_t& index_size, char** error) {
|
||||
bool load_index(void* index_data, const int64_t& index_size, char** error) {
|
||||
if (index_size == -1) {
|
||||
set_error_from_errno(error, "Unable to get size");
|
||||
return false;
|
||||
|
@ -1123,7 +1123,8 @@ public:
|
|||
}
|
||||
|
||||
_n_nodes = (S)(index_size / _s);
|
||||
_nodes = (Node*)malloc(_s * _n_nodes);
|
||||
// _nodes = (Node*)malloc(_s * _n_nodes);
|
||||
_nodes = (Node*)malloc((size_t)index_size);
|
||||
memcpy(_nodes, index_data, (size_t)index_size);
|
||||
|
||||
// Find the roots by scanning the end of the file and taking the nodes with most descendants
|
||||
|
@ -1177,7 +1178,7 @@ public:
|
|||
}
|
||||
|
||||
int64_t get_index_length() const {
|
||||
return (int64_t)_s * _nodes_size;
|
||||
return (int64_t)_s * _n_nodes;
|
||||
}
|
||||
|
||||
void* get_index() const {
|
||||
|
|
|
@ -23,6 +23,184 @@ using ::testing::Combine;
|
|||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
|
||||
// Value-parameterized gtest fixture used by all Annoy index test cases.
//
// SetUp() records the test parameter, generates the data set through the
// DataGen base, creates an empty IndexAnnoy and fills in the configuration
// that the cases pass to BuildAll() and Query().
class AnnoyTest : public DataGen, public TestWithParam<std::string> {
 protected:
    void
    SetUp() override {
        // The parameter is only stored; nothing in this fixture branches on it.
        this->IndexType = GetParam();

        // NOTE(review): arguments are presumably (dim, nb, nq) -- confirm against DataGen::Generate.
        Generate(128, 10000, 10);

        this->index_ = std::make_shared<milvus::knowhere::IndexAnnoy>();

        // `dim` is not declared in this class; it is presumably the DataGen member
        // set by Generate(), so the config follows the generated data.
        this->conf = milvus::knowhere::Config{
            {milvus::knowhere::meta::DIM, dim},
            {milvus::knowhere::meta::TOPK, 10},
            {milvus::knowhere::IndexParams::n_trees, 4},
            {milvus::knowhere::IndexParams::search_k, 100},
            {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
        };
    }

    // Build/search parameters shared by every test case.
    milvus::knowhere::Config conf;
    // Index under test; recreated in SetUp() for every case.
    std::shared_ptr<milvus::knowhere::IndexAnnoy> index_ = nullptr;
    // Value of the gtest parameter for the current instantiation.
    std::string IndexType;
};
|
||||
|
||||
// Instantiate every AnnoyTest case once, with the single parameter value "Annoy".
INSTANTIATE_TEST_CASE_P(AnnoyParameters, AnnoyTest, Values("Annoy"));
|
||||
|
||||
// Builds the Annoy index from the generated base data set, checks the reported
// row count and dimension, then runs a query and validates it with AssertAnns().
TEST_P(AnnoyTest, annoy_basic) {
    // Use a gtest assertion rather than assert(): it is still evaluated when
    // NDEBUG is defined and it fails only this test instead of aborting the binary.
    ASSERT_FALSE(xb.empty());

    index_->BuildAll(base_dataset, conf);  // Train + Add
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dim(), dim);

    auto result = index_->Query(query_dataset, conf);
    AssertAnns(result, nq, k);

    /*
     * output result to check by eyes
    {
        auto ids = result->Get<int64_t*>(milvus::knowhere::meta::IDS);
        auto dist = result->Get<float*>(milvus::knowhere::meta::DISTANCE);

        std::stringstream ss_id;
        std::stringstream ss_dist;
        for (auto i = 0; i < nq; i++) {
            for (auto j = 0; j < k; ++j) {
                // ss_id << *ids->data()->GetValues<int64_t>(1, i * k + j) << " ";
                // ss_dist << *dists->data()->GetValues<float>(1, i * k + j) << " ";
                ss_id << *((int64_t*)(ids) + i * k + j) << " ";
                ss_dist << *((float*)(dist) + i * k + j) << " ";
            }
            ss_id << std::endl;
            ss_dist << std::endl;
        }
        std::cout << "id\n" << ss_id.str() << std::endl;
        std::cout << "dist\n" << ss_dist.str() << std::endl;
    }
    */
}
|
||||
|
||||
// Checks that a blacklist changes the search result: the index is queried once
// without a blacklist, then ids 0..nq-1 are blacklisted through a bitset and the
// same query is validated again with CheckMode::CHECK_NOT_EQUAL.
TEST_P(AnnoyTest, annoy_delete) {
    // Use a gtest assertion rather than assert(): it is still evaluated when
    // NDEBUG is defined and it fails only this test instead of aborting the binary.
    ASSERT_FALSE(xb.empty());

    index_->BuildAll(base_dataset, conf);  // Train + Add
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dim(), dim);

    // Mark the first nq ids in the bitset.
    faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
    for (auto i = 0; i < nq; ++i) {
        bitset->set(i);
    }

    // Reference query, before the blacklist is installed.
    auto result1 = index_->Query(query_dataset, conf);
    AssertAnns(result1, nq, k);

    // Same query with the blacklist active.
    index_->SetBlacklist(bitset);
    auto result2 = index_->Query(query_dataset, conf);
    AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);

    /*
     * delete result checked by eyes
    auto ids1 = result1->Get<int64_t*>(milvus::knowhere::meta::IDS);
    auto ids2 = result2->Get<int64_t*>(milvus::knowhere::meta::IDS);
    std::cout << std::endl;
    for (int i = 0; i < nq; ++ i) {
        std::cout << "ids1: ";
        for (int j = 0; j < k; ++ j) {
            std::cout << *(ids1 + i * k + j) << " ";
        }
        std::cout << " ids2: ";
        for (int j = 0; j < k; ++ j) {
            std::cout << *(ids2 + i * k + j) << " ";
        }
        std::cout << std::endl;
        for (int j = 0; j < std::min(5, k>>1); ++ j) {
            ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j));
        }
    }
    */
    /*
     * output result to check by eyes
    {
        auto ids = result2->Get<int64_t*>(milvus::knowhere::meta::IDS);
        auto dist = result2->Get<float*>(milvus::knowhere::meta::DISTANCE);

        std::stringstream ss_id;
        std::stringstream ss_dist;
        for (auto i = 0; i < nq; i++) {
            for (auto j = 0; j < k; ++j) {
                // ss_id << *ids->data()->GetValues<int64_t>(1, i * k + j) << " ";
                // ss_dist << *dists->data()->GetValues<float>(1, i * k + j) << " ";
                ss_id << *((int64_t*)(ids) + i * k + j) << " ";
                ss_dist << *((float*)(dist) + i * k + j) << " ";
            }
            ss_id << std::endl;
            ss_dist << std::endl;
        }
        std::cout << "id\n" << ss_id.str() << std::endl;
        std::cout << "dist\n" << ss_dist.str() << std::endl;
    }
    */
}
|
||||
|
||||
// Round-trips a built Annoy index through files: each of the three binaries
// produced by Serialize() ("annoy_index_data", "annoy_metric_type",
// "annoy_dim") is written to a file under /tmp and read back, the binary set
// is rebuilt from the bytes read back, and the index is reloaded from it and
// queried.
TEST_P(AnnoyTest, annoy_serialize) {
    // Writes the payload of `bin` to `filename`, then reads bin->size bytes of
    // that file back into `ret`. `ret` must point to at least bin->size bytes.
    auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) {
        {
            // write and flush: the writer is scoped so that it is destroyed
            // before the file is opened again for reading
            FileIOWriter writer(filename);
            writer(static_cast<void*>(bin->data.get()), bin->size);
        }

        FileIOReader reader(filename);
        reader(ret, bin->size);
    };

    // serialize index
    index_->BuildAll(base_dataset, conf);
    auto binaryset = index_->Serialize();

    // Every read-back buffer is owned by a shared_ptr from the moment it is
    // allocated, so nothing leaks if one of the file round trips throws.
    auto bin_data = binaryset.GetByName("annoy_index_data");
    std::string filename1 = "/tmp/annoy_test_data_serialize.bin";
    std::shared_ptr<uint8_t[]> index_data(new uint8_t[bin_data->size]);
    serialize(filename1, bin_data, index_data.get());

    auto bin_metric_type = binaryset.GetByName("annoy_metric_type");
    std::string filename2 = "/tmp/annoy_test_metric_type_serialize.bin";
    std::shared_ptr<uint8_t[]> metric_data(new uint8_t[bin_metric_type->size]);
    serialize(filename2, bin_metric_type, metric_data.get());

    auto bin_dim = binaryset.GetByName("annoy_dim");
    std::string filename3 = "/tmp/annoy_test_dim_serialize.bin";
    std::shared_ptr<uint8_t[]> dim_data(new uint8_t[bin_dim->size]);
    serialize(filename3, bin_dim, dim_data.get());

    // Rebuild the binary set from the bytes read back from disk only.
    binaryset.clear();
    binaryset.Append("annoy_index_data", index_data, bin_data->size);
    binaryset.Append("annoy_metric_type", metric_data, bin_metric_type->size);
    binaryset.Append("annoy_dim", dim_data, bin_dim->size);

    // The reloaded index must report the same shape and answer queries.
    index_->Load(binaryset);
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dim(), dim);
    auto result = index_->Query(query_dataset, conf);
    AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
}
|
||||
|
||||
/*
|
||||
* faiss style test
|
||||
* keep it
|
||||
int
|
||||
main() {
|
||||
int64_t d = 64; // dimension
|
||||
|
@ -127,95 +305,4 @@ main() {
|
|||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
class AnnoyTest : public DataGen, public TestWithParam<std::string> {
|
||||
protected:
|
||||
void
|
||||
SetUp() override {
|
||||
IndexType = GetParam();
|
||||
std::cout << "IndexType from GetParam() is: " << IndexType << std::endl;
|
||||
Generate(128, 1000, 5);
|
||||
index_ = std::make_shared<milvus::knowhere::IndexAnnoy>();
|
||||
conf = milvus::knowhere::Config{
|
||||
{milvus::knowhere::meta::DIM, dim},
|
||||
{milvus::knowhere::meta::TOPK, 1},
|
||||
{milvus::knowhere::IndexParams::n_trees, 4},
|
||||
{milvus::knowhere::IndexParams::search_k, 100},
|
||||
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
|
||||
};
|
||||
|
||||
// Init_with_default();
|
||||
}
|
||||
|
||||
protected:
|
||||
milvus::knowhere::Config conf;
|
||||
std::shared_ptr<milvus::knowhere::IndexAnnoy> index_ = nullptr;
|
||||
std::string IndexType;
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(AnnoyParameters, AnnoyTest, Values(""));
|
||||
|
||||
TEST_P(AnnoyTest, annoy_basic) {
|
||||
assert(!xb.empty());
|
||||
|
||||
// index_->Train(base_dataset, conf);
|
||||
index_->BuildAll(base_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result, nq, k);
|
||||
|
||||
{
|
||||
auto ids = result->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
auto dist = result->Get<float*>(milvus::knowhere::meta::DISTANCE);
|
||||
|
||||
std::stringstream ss_id;
|
||||
std::stringstream ss_dist;
|
||||
for (auto i = 0; i < nq; i++) {
|
||||
for (auto j = 0; j < k; ++j) {
|
||||
// ss_id << *ids->data()->GetValues<int64_t>(1, i * k + j) << " ";
|
||||
// ss_dist << *dists->data()->GetValues<float>(1, i * k + j) << " ";
|
||||
ss_id << *((int64_t*)(ids) + i * k + j) << " ";
|
||||
ss_dist << *((float*)(dist) + i * k + j) << " ";
|
||||
}
|
||||
ss_id << std::endl;
|
||||
ss_dist << std::endl;
|
||||
}
|
||||
std::cout << "id\n" << ss_id.str() << std::endl;
|
||||
std::cout << "dist\n" << ss_dist.str() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(AnnoyTest, annoy_delete) {
|
||||
assert(!xb.empty());
|
||||
|
||||
// index_->Train(base_dataset, conf);
|
||||
index_->BuildAll(base_dataset, conf);
|
||||
// index_->Add(base_dataset, conf);
|
||||
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
|
||||
for (auto i = 0; i < nq; ++ i) {
|
||||
bitset->set(i);
|
||||
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
AssertAnns(result, nq, k);
|
||||
|
||||
{
|
||||
auto ids = result->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
auto dist = result->Get<float*>(milvus::knowhere::meta::DISTANCE);
|
||||
|
||||
std::stringstream ss_id;
|
||||
std::stringstream ss_dist;
|
||||
for (auto i = 0; i < nq; i++) {
|
||||
for (auto j = 0; j < k; ++j) {
|
||||
// ss_id << *ids->data()->GetValues<int64_t>(1, i * k + j) << " ";
|
||||
// ss_dist << *dists->data()->GetValues<float>(1, i * k + j) << " ";
|
||||
ss_id << *((int64_t*)(ids) + i * k + j) << " ";
|
||||
ss_dist << *((float*)(dist) + i * k + j) << " ";
|
||||
}
|
||||
ss_id << std::endl;
|
||||
ss_dist << std::endl;
|
||||
}
|
||||
std::cout << "id\n" << ss_id.str() << std::endl;
|
||||
std::cout << "dist\n" << ss_dist.str() << std::endl;
|
||||
} }
|
||||
}
|
||||
*/
|
||||
|
|
|
@ -9,12 +9,130 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <knowhere/index/vector_index/IndexHNSW.h>
|
||||
#include <src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include "./utils.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "unittest/utils.h"
|
||||
|
||||
using ::testing::Combine;
|
||||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
|
||||
// Value-parameterized gtest fixture used by all HNSW index test cases.
//
// SetUp() records the test parameter, generates the data set through the
// DataGen base, creates an empty IndexHNSW and fills in the configuration
// that the cases pass to Train()/Add() and Query().
class HNSWTest : public DataGen, public TestWithParam<std::string> {
 protected:
    void
    SetUp() override {
        // The parameter is only stored. It is intentionally not printed, to keep
        // the unit tests from writing to std output (same as the Annoy fixture).
        IndexType = GetParam();
        Generate(64, 10000, 10);  // dim = 64, nb = 10000, nq = 10
        index_ = std::make_shared<milvus::knowhere::IndexHNSW>();
        // DIM uses `dim` (as the Annoy fixture does) instead of repeating the
        // literal passed to Generate(), so the two cannot drift apart.
        // NOTE(review): assumes Generate() stores its first argument in `dim` -- confirm.
        conf = milvus::knowhere::Config{
            {milvus::knowhere::meta::DIM, dim},
            {milvus::knowhere::meta::TOPK, 10},
            {milvus::knowhere::IndexParams::M, 16},
            {milvus::knowhere::IndexParams::efConstruction, 200},
            {milvus::knowhere::IndexParams::ef, 200},
            {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
        };
    }

 protected:
    // Build/search parameters shared by every test case.
    milvus::knowhere::Config conf;
    // Index under test; recreated in SetUp() for every case.
    std::shared_ptr<milvus::knowhere::IndexHNSW> index_ = nullptr;
    // Value of the gtest parameter for the current instantiation.
    std::string IndexType;
};
|
||||
|
||||
// Instantiate every HNSWTest case once, with the single parameter value "HNSW".
INSTANTIATE_TEST_CASE_P(HNSWParameters, HNSWTest, Values("HNSW"));
|
||||
|
||||
// Trains the HNSW index and adds the generated base data set, checks the
// reported row count and dimension, then runs a query and validates it with
// AssertAnns().
TEST_P(HNSWTest, HNSW_basic) {
    // Use a gtest assertion rather than assert(): it is still evaluated when
    // NDEBUG is defined and it fails only this test instead of aborting the binary.
    ASSERT_FALSE(xb.empty());

    index_->Train(base_dataset, conf);
    index_->Add(base_dataset, conf);
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dim(), dim);

    auto result = index_->Query(query_dataset, conf);
    AssertAnns(result, nq, k);
}
|
||||
|
||||
// Checks that a blacklist changes the search result: the index is queried once
// without a blacklist, then ids 0..nq-1 are blacklisted through a bitset and the
// same query is validated again with CheckMode::CHECK_NOT_EQUAL.
TEST_P(HNSWTest, HNSW_delete) {
    // Use a gtest assertion rather than assert(): it is still evaluated when
    // NDEBUG is defined and it fails only this test instead of aborting the binary.
    ASSERT_FALSE(xb.empty());

    index_->Train(base_dataset, conf);
    index_->Add(base_dataset, conf);
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dim(), dim);

    // Mark the first nq ids in the bitset.
    faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
    for (auto i = 0; i < nq; ++i) {
        bitset->set(i);
    }

    // Reference query, before the blacklist is installed.
    auto result1 = index_->Query(query_dataset, conf);
    AssertAnns(result1, nq, k);

    // Same query with the blacklist active.
    index_->SetBlacklist(bitset);
    auto result2 = index_->Query(query_dataset, conf);
    AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);

    /*
     * delete result checked by eyes
    auto ids1 = result1->Get<int64_t*>(milvus::knowhere::meta::IDS);
    auto ids2 = result2->Get<int64_t*>(milvus::knowhere::meta::IDS);
    std::cout << std::endl;
    for (int i = 0; i < nq; ++ i) {
        std::cout << "ids1: ";
        for (int j = 0; j < k; ++ j) {
            std::cout << *(ids1 + i * k + j) << " ";
        }
        std::cout << "ids2: ";
        for (int j = 0; j < k; ++ j) {
            std::cout << *(ids2 + i * k + j) << " ";
        }
        std::cout << std::endl;
        for (int j = 0; j < std::min(5, k>>1); ++ j) {
            ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j));
        }
    }
    */
}
|
||||
|
||||
// Round-trips a built HNSW index through a file: the "HNSW" binary produced by
// Serialize() is written to a file under /tmp and read back, the binary set is
// rebuilt from the bytes read back, and the index is reloaded from it and
// queried.
TEST_P(HNSWTest, HNSW_serialize) {
    // Writes the payload of `bin` to `filename`, then reads bin->size bytes of
    // that file back into `ret`. `ret` must point to at least bin->size bytes.
    auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) {
        {
            // The writer is scoped so that it is destroyed before the file is
            // opened again for reading.
            FileIOWriter writer(filename);
            writer(static_cast<void*>(bin->data.get()), bin->size);
        }

        FileIOReader reader(filename);
        reader(ret, bin->size);
    };

    index_->Train(base_dataset, conf);
    index_->Add(base_dataset, conf);
    auto binaryset = index_->Serialize();
    auto bin = binaryset.GetByName("HNSW");

    // The read-back buffer is owned by a shared_ptr from the moment it is
    // allocated, so it does not leak if the file round trip throws.
    std::string filename = "/tmp/HNSW_test_serialize.bin";
    std::shared_ptr<uint8_t[]> data(new uint8_t[bin->size]);
    serialize(filename, bin, data.get());

    // Rebuild the binary set from the bytes read back from disk only.
    binaryset.clear();
    binaryset.Append("HNSW", data, bin->size);

    // The reloaded index must report the same shape and answer queries.
    index_->Load(binaryset);
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dim(), dim);
    auto result = index_->Query(query_dataset, conf);
    AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
}
|
||||
|
||||
/*
|
||||
* faiss style test
|
||||
* keep it
|
||||
int
|
||||
main() {
|
||||
int64_t d = 64; // dimension
|
||||
|
@ -34,33 +152,31 @@ main() {
|
|||
xb[d * i] += i / 1000.;
|
||||
ids[i] = i;
|
||||
}
|
||||
printf("gen xb and ids done! \n");
|
||||
// printf("gen xb and ids done! \n");
|
||||
|
||||
// srand((unsigned)time(NULL));
|
||||
auto random_seed = (unsigned)time(NULL);
|
||||
printf("delete ids: \n");
|
||||
// printf("delete ids: \n");
|
||||
for (int i = 0; i < nq; i++) {
|
||||
auto tmp = rand_r(&random_seed) % nb;
|
||||
printf("%ld\n", tmp);
|
||||
// printf("%ld\n", tmp);
|
||||
// std::cout << "before delete, test result: " << bitset->test(tmp) << std::endl;
|
||||
bitset->set(tmp);
|
||||
// std::cout << "after delete, test result: " << bitset->test(tmp) << std::endl;
|
||||
for (int j = 0; j < d; j++) xq[d * i + j] = xb[d * tmp + j];
|
||||
// xq[d * i] += i / 1000.;
|
||||
}
|
||||
printf("\n");
|
||||
// printf("\n");
|
||||
|
||||
int k = 4;
|
||||
int m = 16;
|
||||
int ef = 200;
|
||||
milvus::knowhere::IndexHNSW index;
|
||||
milvus::knowhere::DatasetPtr base_dataset = generate_dataset(nb, d, (const void*)xb, ids);
|
||||
/*
|
||||
base_dataset->Set(milvus::knowhere::meta::ROWS, nb);
|
||||
base_dataset->Set(milvus::knowhere::meta::DIM, d);
|
||||
base_dataset->Set(milvus::knowhere::meta::TENSOR, (const void*)xb);
|
||||
base_dataset->Set(milvus::knowhere::meta::IDS, (const int64_t*)ids);
|
||||
*/
|
||||
// base_dataset->Set(milvus::knowhere::meta::ROWS, nb);
|
||||
// base_dataset->Set(milvus::knowhere::meta::DIM, d);
|
||||
// base_dataset->Set(milvus::knowhere::meta::TENSOR, (const void*)xb);
|
||||
// base_dataset->Set(milvus::knowhere::meta::IDS, (const int64_t*)ids);
|
||||
|
||||
milvus::knowhere::Config base_conf{
|
||||
{milvus::knowhere::meta::DIM, d},
|
||||
|
@ -81,27 +197,27 @@ main() {
|
|||
index.Train(base_dataset, base_conf);
|
||||
index.Add(base_dataset, base_conf);
|
||||
|
||||
printf("------------sanity check----------------\n");
|
||||
// printf("------------sanity check----------------\n");
|
||||
{ // sanity check
|
||||
auto res = index.Query(query_dataset, query_conf);
|
||||
printf("Query done!\n");
|
||||
// printf("Query done!\n");
|
||||
const int64_t* I = res->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
float* D = res->Get<float*>(milvus::knowhere::meta::DISTANCE);
|
||||
// float* D = res->Get<float*>(milvus::knowhere::meta::DISTANCE);
|
||||
|
||||
printf("I=\n");
|
||||
for (int i = 0; i < 5; i++) {
|
||||
for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]);
|
||||
printf("\n");
|
||||
}
|
||||
// printf("I=\n");
|
||||
// for (int i = 0; i < 5; i++) {
|
||||
// for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]);
|
||||
// printf("\n");
|
||||
// }
|
||||
|
||||
printf("D=\n");
|
||||
for (int i = 0; i < 5; i++) {
|
||||
for (int j = 0; j < k; j++) printf("%7g ", D[i * k + j]);
|
||||
printf("\n");
|
||||
}
|
||||
// printf("D=\n");
|
||||
// for (int i = 0; i < 5; i++) {
|
||||
// for (int j = 0; j < k; j++) printf("%7g ", D[i * k + j]);
|
||||
// printf("\n");
|
||||
// }
|
||||
}
|
||||
|
||||
printf("---------------search xq-------------\n");
|
||||
// printf("---------------search xq-------------\n");
|
||||
{ // search xq
|
||||
auto res = index.Query(query_dataset, query_conf);
|
||||
const int64_t* I = res->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
|
@ -132,3 +248,4 @@ main() {
|
|||
|
||||
return 0;
|
||||
}
|
||||
*/
|
||||
|
|
|
@ -311,8 +311,8 @@ ValidationUtil::ValidateSearchParams(const milvus::json& search_params,
|
|||
break;
|
||||
}
|
||||
case (int32_t)engine::EngineType::ANNOY: {
|
||||
auto status = CheckParameterRange(search_params, knowhere::IndexParams::search_k,
|
||||
std::numeric_limits<int64_t>::min(), std::numeric_limits<int64_t>::max());
|
||||
auto status = CheckParameterRange(search_params, knowhere::IndexParams::search_k, topk,
|
||||
std::numeric_limits<int64_t>::max());
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue