update wrapper and wrapper test

Former-commit-id: a57d040aae7d8f5ba20c99ea6a0c6220efcaeacd
pull/191/head
xj.lin 2019-07-02 18:37:33 +08:00
parent a8b068db57
commit 82150885d0
11 changed files with 369 additions and 118 deletions


@ -0,0 +1,48 @@
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#include "data_transfer.h"
namespace zilliz {
namespace vecwise {
namespace engine {
using namespace zilliz::knowhere;
DatasetPtr
GenDatasetWithIds(const int64_t &nb, const int64_t &dim, const float *xb, const long *ids) {
std::vector<int64_t> shape{nb, dim};
auto tensor = ConstructFloatTensor((uint8_t *) xb, nb * dim * sizeof(float), shape);
std::vector<TensorPtr> tensors{tensor};
std::vector<FieldPtr> tensor_fields{ConstructFloatField("data")};
auto tensor_schema = std::make_shared<Schema>(tensor_fields);
auto id_array = ConstructInt64Array((uint8_t *) ids, nb * sizeof(int64_t));
std::vector<ArrayPtr> arrays{id_array};
std::vector<FieldPtr> array_fields{ConstructInt64Field("id")};
auto array_schema = std::make_shared<Schema>(array_fields);
auto dataset = std::make_shared<Dataset>(std::move(arrays), array_schema,
std::move(tensors), tensor_schema);
return dataset;
}
DatasetPtr
GenDataset(const int64_t &nb, const int64_t &dim, const float *xb) {
std::vector<int64_t> shape{nb, dim};
auto tensor = ConstructFloatTensor((uint8_t *) xb, nb * dim * sizeof(float), shape);
std::vector<TensorPtr> tensors{tensor};
std::vector<FieldPtr> tensor_fields{ConstructFloatField("data")};
auto tensor_schema = std::make_shared<Schema>(tensor_fields);
auto dataset = std::make_shared<Dataset>(std::move(tensors), tensor_schema);
return dataset;
}
}
}
}
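A hedged usage sketch for the two helpers above; the demo function name and data values are illustrative, the signatures come from this file:

// Illustrative only: shows how a caller feeds raw buffers into the
// knowhere Dataset wrappers defined above.
#include "data_transfer.h"

void data_transfer_demo() {
    using namespace zilliz::vecwise::engine;
    const int64_t nb = 2, dim = 4;
    float xb[nb * dim] = {0.1f, 0.2f, 0.3f, 0.4f,
                          0.5f, 0.6f, 0.7f, 0.8f};
    long ids[nb] = {0, 1};
    // Base data: vectors plus ids, as consumed by BuildAll/Add.
    auto base_set = GenDatasetWithIds(nb, dim, xb, ids);
    // Query data: vectors only, as consumed by Search.
    auto query_set = GenDataset(nb, dim, xb);
}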


@ -6,24 +6,19 @@
#pragma once
#define GENDATASET(n,d,xb,ids)\
size_t elems = (n) * (d);\
std::vector<int64_t> shape{n, d};\
auto tensor = ConstructFloatTensor((uint8_t *) (xb), elems * sizeof(float), shape);\
std::vector<TensorPtr> tensors{tensor};\
std::vector<FieldPtr> tensor_fields{ConstructFloatField("data")};\
auto tensor_schema = std::make_shared<Schema>(tensor_fields);\
auto id_array = ConstructInt64Array((uint8_t *) (ids), (n) * sizeof(int64_t));\
std::vector<ArrayPtr> arrays{id_array};\
std::vector<FieldPtr> array_fields{ConstructInt64Field("id")};\
auto array_schema = std::make_shared<Schema>(tensor_fields);\
auto dataset = std::make_shared<Dataset>(std::move(arrays), array_schema, std::move(tensors), tensor_schema);\
#include "knowhere/adapter/structure.h"
#define GENQUERYDATASET(n,d,xb)\
size_t elems = (n) * (d);\
std::vector<int64_t> shape{(n), (d)};\
auto tensor = ConstructFloatTensor((uint8_t *) (xb), elems * sizeof(float), shape);\
std::vector<TensorPtr> tensors{tensor};\
std::vector<FieldPtr> tensor_fields{ConstructFloatField("data")};\
auto tensor_schema = std::make_shared<Schema>(tensor_fields);\
auto dataset = std::make_shared<Dataset>(std::move(tensors), tensor_schema);\
namespace zilliz {
namespace vecwise {
namespace engine {
extern zilliz::knowhere::DatasetPtr
GenDatasetWithIds(const int64_t &nb, const int64_t &dim, const float *xb, const long *ids);
extern zilliz::knowhere::DatasetPtr
GenDataset(const int64_t &nb, const int64_t &dim, const float *xb);
}
}
}


@ -13,11 +13,6 @@
#include "vec_impl.h"
#include "data_transfer.h"
//using Index = zilliz::knowhere::Index;
//using IndexModel = zilliz::knowhere::IndexModel;
//using IndexType = zilliz::knowhere::IndexType;
//using IndexPtr = std::shared_ptr<Index>;
//using IndexModelPtr = std::shared_ptr<IndexModel>;
namespace zilliz {
namespace vecwise {
@ -31,24 +26,21 @@ void VecIndexImpl::BuildAll(const long &nb,
const Config &cfg,
const long &nt,
const float *xt) {
using namespace zilliz::knowhere;
auto d = cfg["dim"].as<int>();
GENDATASET(nb, d, xb, ids)
auto dataset = GenDatasetWithIds(nb, d, xb, ids);
Config train_cfg;
Config add_cfg;
Config search_cfg;
auto preprocessor = index_->BuildPreprocessor(dataset, cfg);
index_->set_preprocessor(preprocessor);
auto model = index_->Train(dataset, cfg);
index_->set_index_model(model);
index_->Add(dataset, add_cfg);
index_->Add(dataset, cfg);
}
void VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) {
// TODO: Assert index is trained;
// TODO(linxj): Assert index is trained;
auto d = cfg["dim"].as<int>();
GENDATASET(nb, d, xb, ids)
auto dataset = GenDatasetWithIds(nb, d, xb, ids);
index_->Add(dataset, cfg);
}
@ -58,12 +50,13 @@ void VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *id
auto d = cfg["dim"].as<int>();
auto k = cfg["k"].as<int>();
GENQUERYDATASET(nq, d, xq)
auto dataset = GenDataset(nq, d, xq);
Config search_cfg;
auto res = index_->Search(dataset, cfg);
auto ids_array = res->array()[0];
auto dis_array = res->array()[1];
//{
// auto& ids = ids_array;
// auto& dists = dis_array;
@ -81,10 +74,10 @@ void VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *id
// std::cout << "dist\n" << ss_dist.str() << std::endl;
//}
// TODO: deep copy here.
auto p_ids = ids_array->data()->GetValues<int64_t>(1, 0);
auto p_dist = dis_array->data()->GetValues<float>(1, 0);
// TODO(linxj): avoid copy here.
memcpy(ids, p_ids, sizeof(int64_t) * nq * k);
memcpy(dist, p_dist, sizeof(float) * nq * k);
}
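As a caller-side note on the Search path above, dist and ids must each hold nq * k entries to match the memcpy; a minimal sketch, assuming the index was already built and that the config carries "dim" and "k" as in the test below:

// Sketch: sizing the output buffers for VecIndex::Search.
// 'index' is a previously built VecIndexPtr (see GetVecIndexFactory below).
int nq = 10, k = 10, dim = 64;
Config cfg = Config::object{{"dim", dim}, {"k", k}};
std::vector<float> xq(nq * dim);     // query vectors
std::vector<float> dist(nq * k);     // distances out
std::vector<long> res_ids(nq * k);   // result ids out
index->Search(nq, xq.data(), dist.data(), res_ids.data(), cfg);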


@ -7,6 +7,7 @@
#pragma once
#include "knowhere/index/vector_index/vector_index.h"
#include "vec_index.h"
@ -16,7 +17,7 @@ namespace engine {
class VecIndexImpl : public VecIndex {
public:
explicit VecIndexImpl(std::shared_ptr<zilliz::knowhere::VectorIndex> index):index_(std::move(index)){};
explicit VecIndexImpl(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : index_(std::move(index)) {};
void BuildAll(const long &nb,
const float *xb,
const long *ids,


@ -5,6 +5,7 @@
////////////////////////////////////////////////////////////////////////////////
#include "knowhere/index/vector_index/ivf.h"
#include "knowhere/index/vector_index/gpu_ivf.h"
#include "knowhere/index/vector_index/cpu_kdt_rng.h"
#include "vec_index.h"
#include "vec_impl.h"
@ -14,18 +15,29 @@ namespace zilliz {
namespace vecwise {
namespace engine {
// TODO(linxj): index_type => enum struct
VecIndexPtr GetVecIndexFactory(const std::string &index_type) {
std::shared_ptr<zilliz::knowhere::VectorIndex> index;
if (index_type == "IVF") {
index = std::make_shared<zilliz::knowhere::IVF>();
} else if (index_type == "GPUIVF") {
index = std::make_shared<zilliz::knowhere::GPUIVF>();
index = std::make_shared<zilliz::knowhere::GPUIVF>(0);
} else if (index_type == "SPTAG") {
index = std::make_shared<zilliz::knowhere::CPUKDTRNG>();
}
auto ret_index = std::make_shared<VecIndexImpl>(index);
//return std::static_pointer_cast<VecIndex>(std::make_shared<VecIndexImpl>(index));
// TODO(linxj): Support NSG
//else if (index_type == "NSG") {
// index = std::make_shared<zilliz::knowhere::NSG>();
//}
return std::make_shared<VecIndexImpl>(index);
}
VecIndexPtr LoadVecIndex(const std::string &index_type, const zilliz::knowhere::BinarySet &index_binary) {
auto index = GetVecIndexFactory(index_type);
index->Load(index_binary);
return index;
}
}
}
}
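A hedged end-to-end sketch of the factory and the new LoadVecIndex overload; the "IVF" type string and config keys are taken from the test in this commit, the rest is illustrative:

// Build with one index object, serialize, then reload into a fresh one.
// nb/xb/ids/nq/xq/dist/res_ids follow the shapes used in the test below.
auto index = GetVecIndexFactory("IVF");
Config build_cfg = Config::object{{"dim", 64}, {"nlist", 100}};
index->BuildAll(nb, xb.data(), ids.data(), build_cfg);

zilliz::knowhere::BinarySet binary = index->Serialize();
auto reloaded = LoadVecIndex("IVF", binary);

Config search_cfg = Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 20}};
reloaded->Search(nq, xq.data(), dist.data(), res_ids.data(), search_cfg);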


@ -17,6 +17,7 @@ namespace zilliz {
namespace vecwise {
namespace engine {
// TODO(linxj): jsoncons => rapidjson or other.
using Config = zilliz::knowhere::Config;
class VecIndex {
@ -31,13 +32,13 @@ class VecIndex {
virtual void Add(const long &nb,
const float *xb,
const long *ids,
const Config &cfg) = 0;
const Config &cfg = Config()) = 0;
virtual void Search(const long &nq,
const float *xq,
float *dist,
long *ids,
const Config &cfg) = 0;
const Config &cfg = Config()) = 0;
virtual zilliz::knowhere::BinarySet Serialize() = 0;
@ -48,8 +49,7 @@ using VecIndexPtr = std::shared_ptr<VecIndex>;
extern VecIndexPtr GetVecIndexFactory(const std::string &index_type);
// TODO
extern VecIndexPtr LoadVecIndex(const zilliz::knowhere::BinarySet &index_binary);
extern VecIndexPtr LoadVecIndex(const std::string &index_type, const zilliz::knowhere::BinarySet &index_binary);
}
}
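A note on the new default arguments: the interface now allows omitting cfg, but VecIndexImpl::Add and ::Search still read cfg["dim"] (and "k"), so in practice callers keep passing a populated config; a sketch:

// Explicit config, matching what VecIndexImpl expects:
index->Search(nq, xq.data(), dist.data(), res_ids.data(),
              Config::object{{"dim", 64}, {"k", 10}});
// The default-argument form now compiles, but with the current
// VecIndexImpl it would fail once cfg["dim"] is accessed:
// index->Search(nq, xq.data(), dist.data(), res_ids.data());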

@ -1 +1 @@
Subproject commit 291b3b422664f2509bab79d5cc63823dedbe903c
Subproject commit 32187bacbaac0460676f5f6aa54ad904f5f2b5bc


@ -3,6 +3,9 @@ link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64")
aux_source_directory(${MILVUS_ENGINE_SRC}/wrapper/knowhere knowhere_src)
set(helper
utils.cpp)
set(knowhere_libs
knowhere
SPTAGLibStatic
@ -11,9 +14,10 @@ set(knowhere_libs
faiss
openblas
lapack
tbb
cudart
cublas
)
add_executable(knowhere_test knowhere_test.cpp ${knowhere_src})
add_executable(knowhere_test knowhere_test.cpp ${knowhere_src} ${helper})
target_link_libraries(knowhere_test ${knowhere_libs} ${unittest_libs})


@ -8,86 +8,142 @@
#include <wrapper/knowhere/vec_index.h>
#include "utils.h"
using namespace zilliz::vecwise::engine;
using namespace zilliz::knowhere;
TEST(knowhere_test, ivf_test) {
auto d = 128;
auto nt = 1000;
auto nb = 10000;
auto nq = 10;
//{
// std::vector<float> xb;
// std::vector<float> xt;
// std::vector<float> xq;
// std::vector<long> ids;
//
// //prepare train data
// std::uniform_real_distribution<> dis_xt(-1.0, 1.0);
// std::random_device rd;
// std::mt19937 gen(rd());
// xt.resize(nt*d);
// for (size_t i = 0; i < nt * d; i++) {
// xt[i] = dis_xt(gen);
// }
// xb.resize(nb*d);
// ids.resize(nb);
// for (size_t i = 0; i < nb * d; i++) {
// xb[i] = dis_xt(gen);
// if (i < nb) {
// ids[i] = i;
// }
// }
// xq.resize(nq*d);
// for (size_t i = 0; i < nq * d; i++) {
// xq[i] = dis_xt(gen);
// }
//}
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::Combine;
auto elems = nb * d;
auto p_data = (float *) malloc(elems * sizeof(float));
auto p_id = (int64_t *) malloc(elems * sizeof(int64_t));
assert(p_data != nullptr && p_id != nullptr);
for (auto i = 0; i < nb; ++i) {
for (auto j = 0; j < d; ++j) {
p_data[i * d + j] = drand48();
}
p_data[d * i] += i / 1000.;
p_id[i] = i;
class KnowhereWrapperTest
: public TestWithParam<::std::tuple<std::string, std::string, int, int, int, int, Config, Config>> {
protected:
void SetUp() override {
std::string generator_type;
std::tie(index_type, generator_type, dim, nb, nq, k, train_cfg, search_cfg) = GetParam();
//auto generator = GetGenerateFactory(generator_type);
auto generator = std::make_shared<DataGenBase>();
generator->GenData(dim, nb, nq, xb, xq, ids, k, gt_ids);
index_ = GetVecIndexFactory(index_type);
}
auto q_elems = nq * d;
auto q_data = (float *) malloc(q_elems * sizeof(float));
protected:
std::string index_type;
Config train_cfg;
Config search_cfg;
for (auto i = 0; i < nq; ++i) {
for (auto j = 0; j < d; ++j) {
q_data[i * d + j] = drand48();
}
q_data[d * i] += i / 1000.;
int dim = 64;
int nb = 10000;
int nq = 10;
int k = 10;
std::vector<float> xb;
std::vector<float> xq;
std::vector<long> ids;
VecIndexPtr index_ = nullptr;
// Ground Truth
std::vector<long> gt_ids;
};
INSTANTIATE_TEST_CASE_P(WrapperParam, KnowhereWrapperTest,
Values(
// ["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"]
std::make_tuple("IVF", "Default",
64, 10000, 10, 10,
Config::object{{"nlist", 100}, {"dim", 64}},
Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 20}}
),
std::make_tuple("SPTAG", "Default",
64, 10000, 10, 10,
Config::object{{"TPTNumber", 1}, {"dim", 64}},
Config::object{{"dim", 64}, {"k", 10}}
)
)
);
void AssertAnns(const std::vector<long> &gt,
const std::vector<long> &res,
const int &nq,
const int &k) {
EXPECT_EQ(res.size(), nq * k);
for (auto i = 0; i < nq; i++) {
EXPECT_EQ(gt[i * k], res[i * k]);
}
Config build_cfg = Config::object{
{"dim", d},
{"nlist", 100},
};
int match = 0;
for (int i = 0; i < nq; ++i) {
for (int j = 0; j < k; ++j) {
for (int l = 0; l < k; ++l) {
if (gt[i * k + j] == res[i * k + l]) match++;
}
}
}
auto k = 10;
Config search_cfg = Config::object{
{"dim", d},
{"k", k},
};
std::vector<float> ret_dist(nq*k);
std::vector<long> ret_ids(nq*k);
const std::string& index_type = "IVF";
auto index = GetVecIndexFactory(index_type);
index->BuildAll(nb, p_data, p_id, build_cfg);
auto add_bin = index->Serialize();
index->Load(add_bin);
index->Search(nq, q_data, ret_dist.data(), ret_ids.data(), search_cfg);
std::cout << "he";
// TODO(linxj): precision check
EXPECT_GT(float(match) / (nq * k), 0.5);
}
TEST_P(KnowhereWrapperTest, base_test) {
std::vector<long> res_ids;
float *D = new float[k * nq];
res_ids.resize(nq * k);
index_->BuildAll(nb, xb.data(), ids.data(), train_cfg);
index_->Search(nq, xq.data(), D, res_ids.data(), search_cfg);
AssertAnns(gt_ids, res_ids, nq, k);
delete[] D;
}
TEST_P(KnowhereWrapperTest, serialize_test) {
std::vector<long> res_ids;
float *D = new float[k * nq];
res_ids.resize(nq * k);
index_->BuildAll(nb, xb.data(), ids.data(), train_cfg);
index_->Search(nq, xq.data(), D, res_ids.data(), search_cfg);
AssertAnns(gt_ids, res_ids, nq, k);
{
auto binaryset = index_->Serialize();
int fileno = 0;
const std::string &base_name = "/tmp/wrapper_serialize_test_bin_";
std::vector<std::string> filename_list;
std::vector<std::pair<std::string, size_t >> meta_list;
for (auto &iter: binaryset.binary_map_) {
const std::string &filename = base_name + std::to_string(fileno);
FileIOWriter writer(filename);
writer(iter.second.data, iter.second.size);
meta_list.push_back(std::make_pair(iter.first, iter.second.size));
filename_list.push_back(filename);
++fileno;
}
BinarySet load_data_list;
for (int i = 0; i < filename_list.size() && i < meta_list.size(); ++i) {
auto bin_size = meta_list[i].second;
FileIOReader reader(filename_list[i]);
std::vector<uint8_t> load_data(bin_size);
reader(load_data.data(), bin_size);
load_data_list.Append(meta_list[i].first, load_data);
}
res_ids.clear();
res_ids.resize(nq * k);
auto new_index = GetVecIndexFactory(index_type);
new_index->Load(load_data_list);
new_index->Search(nq, xq.data(), D, res_ids.data(), search_cfg);
AssertAnns(gt_ids, res_ids, nq, k);
}
delete[] D;
}
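Since GetVecIndexFactory also recognizes "GPUIVF", a further tuple could be appended to Values(...); the config below simply mirrors the IVF entry and is an unverified assumption, not a tested setting:

// Hypothetical extra parameter set (config keys copied from the IVF tuple,
// not verified against GPUIVF):
std::make_tuple("GPUIVF", "Default",
                64, 10000, 10, 10,
                Config::object{{"nlist", 100}, {"dim", 64}},
                Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 20}})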


@ -0,0 +1,81 @@
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#include <faiss/IndexFlat.h>
#include "utils.h"
DataGenPtr GetGenerateFactory(const std::string &gen_type) {
std::shared_ptr<DataGenBase> generator;
if (gen_type == "default") {
generator = std::make_shared<DataGenBase>();
}
return generator;
}
void DataGenBase::GenData(const int &dim, const int &nb, const int &nq,
float *xb, float *xq, long *ids,
const int &k, long *gt_ids) {
for (auto i = 0; i < nb; ++i) {
for (auto j = 0; j < dim; ++j) {
//p_data[i * d + j] = float(base + i);
xb[i * dim + j] = drand48();
}
xb[dim * i] += i / 1000.;
ids[i] = i;
}
for (size_t i = 0; i < nq * dim; ++i) {
xq[i] = xb[i];
}
faiss::IndexFlatL2 index(dim);
//index.add_with_ids(nb, xb, ids);
index.add(nb, xb);
float *D = new float[k * nq];
index.search(nq, xq, k, D, gt_ids);
delete[] D;
}
void DataGenBase::GenData(const int &dim,
const int &nb,
const int &nq,
std::vector<float> &xb,
std::vector<float> &xq,
std::vector<long> &ids,
const int &k,
std::vector<long> &gt_ids) {
xb.resize(nb * dim);
xq.resize(nq * dim);
ids.resize(nb);
gt_ids.resize(nq * k);
GenData(dim, nb, nq, xb.data(), xq.data(), ids.data(), k, gt_ids.data());
}
FileIOReader::FileIOReader(const std::string &fname) {
name = fname;
fs = std::fstream(name, std::ios::in | std::ios::binary);
}
FileIOReader::~FileIOReader() {
fs.close();
}
size_t FileIOReader::operator()(void *ptr, size_t size) {
fs.read(reinterpret_cast<char *>(ptr), size);
return size;
}
FileIOWriter::FileIOWriter(const std::string &fname) {
name = fname;
fs = std::fstream(name, std::ios::out | std::ios::binary);
}
FileIOWriter::~FileIOWriter() {
fs.close();
}
size_t FileIOWriter::operator()(void *ptr, size_t size) {
fs.write(reinterpret_cast<char *>(ptr), size);
return size;
}
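A minimal round-trip sketch for the writer/reader pair above; the file path is illustrative:

{
    uint8_t payload[4] = {1, 2, 3, 4};
    FileIOWriter writer("/tmp/wrapper_utils_demo.bin");
    writer(payload, sizeof(payload));
}   // ~FileIOWriter closes the stream
uint8_t readback[4];
FileIOReader reader("/tmp/wrapper_utils_demo.bin");
reader(readback, sizeof(readback));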


@ -0,0 +1,61 @@
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#pragma once
#include <memory>
#include <vector>
#include <cstdlib>
#include <cstdio>
#include <fstream>
class DataGenBase;
using DataGenPtr = std::shared_ptr<DataGenBase>;
extern DataGenPtr GetGenerateFactory(const std::string &gen_type);
class DataGenBase {
public:
virtual void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids,
const int &k, long *gt_ids);
virtual void GenData(const int &dim,
const int &nb,
const int &nq,
std::vector<float> &xb,
std::vector<float> &xq,
std::vector<long> &ids,
const int &k,
std::vector<long> &gt_ids);
};
class SanityCheck : public DataGenBase {
public:
void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids,
const int &k, long *gt_ids) override;
};
struct FileIOWriter {
std::fstream fs;
std::string name;
FileIOWriter(const std::string &fname);
~FileIOWriter();
size_t operator()(void *ptr, size_t size);
};
struct FileIOReader {
std::fstream fs;
std::string name;
FileIOReader(const std::string &fname);
~FileIOReader();
size_t operator()(void *ptr, size_t size);
};
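For reference, a usage sketch of DataGenBase mirroring the test fixture's SetUp above:

auto generator = std::make_shared<DataGenBase>();
std::vector<float> xb, xq;
std::vector<long> ids, gt_ids;
int dim = 64, nb = 10000, nq = 10, k = 10;
// Fills xb/xq/ids and computes brute-force ground truth into gt_ids.
generator->GenData(dim, nb, nq, xb, xq, ids, k, gt_ids);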