1. support IDMap

2. fix some bugs
3. background job from IDMap -> IVF


Former-commit-id: ba8f24f09c5481103ad3f4c1c91d4deb70f26dad
pull/191/head
xj.lin 2019-05-05 20:50:08 +08:00
parent 56bbe40faf
commit 2ac87c1e47
6 changed files with 115 additions and 30 deletions

View File

@ -19,7 +19,7 @@ namespace vecwise {
namespace engine {
const std::string RawIndexType = "IDMap,Flat";
const std::string BuildIndexType = "IDMap,Flat";
const std::string BuildIndexType = "IVF"; // IDMap / IVF
FaissExecutionEngine::FaissExecutionEngine(uint16_t dimension, const std::string& location)

View File

@ -9,6 +9,7 @@
#include <faiss/gpu/StandardGpuResources.h>
#include "faiss/gpu/GpuIndexIVFFlat.h"
#include "faiss/gpu/GpuAutoTune.h"
#include "faiss/IndexFlat.h"
#include "IndexBuilder.h"
@ -20,6 +21,7 @@ namespace engine {
using std::vector;
static std::mutex gpu_resource;
static std::mutex cpu_resource;
IndexBuilder::IndexBuilder(const Operand_ptr &opd) {
opd_ = opd;
@ -27,14 +29,14 @@ IndexBuilder::IndexBuilder(const Operand_ptr &opd) {
// Default: build on the GPU
Index_ptr IndexBuilder::build_all(const long &nb,
const float* xb,
const long* ids,
const float *xb,
const long *ids,
const long &nt,
const float* xt) {
const float *xt) {
std::shared_ptr<faiss::Index> host_index = nullptr;
{
// TODO: list support index-type.
faiss::Index *ori_index = faiss::index_factory(opd_->d, opd_->index_type.c_str());
faiss::Index *ori_index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str());
std::lock_guard<std::mutex> lk(gpu_resource);
faiss::gpu::StandardGpuResources res;
@ -43,7 +45,7 @@ Index_ptr IndexBuilder::build_all(const long &nb,
nt == 0 || xt == nullptr ? device_index->train(nb, xb)
: device_index->train(nt, xt);
}
device_index->add_with_ids(nb, xb, ids);
device_index->add_with_ids(nb, xb, ids); // TODO: support with add_with_IDMAP
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
@ -60,8 +62,32 @@ Index_ptr IndexBuilder::build_all(const long &nb, const vector<float> &xb,
return build_all(nb, xb.data(), ids.data(), nt, xt.data());
}
// TODO: turn this into a factory pattern later
/// Constructs the builder; the build options in `opd` are forwarded to the IndexBuilder base.
BgCpuBuilder::BgCpuBuilder(const zilliz::vecwise::engine::Operand_ptr &opd)
    : IndexBuilder(opd) {}
/// Builds an index straight from faiss::index_factory (no GPU cloning step) and wraps it.
///
/// @param nb  number of vectors to add; also passed to Operand::get_index_type
/// @param xb  vectors to add
/// @param ids ids handed to add_with_ids together with xb
/// @param nt  number of training vectors; when 0 the index is trained on (nb, xb)
/// @param xt  training vectors; when nullptr the index is trained on (nb, xb)
/// @return the populated faiss index wrapped in an Index
Index_ptr BgCpuBuilder::build_all(const long &nb, const float *xb, const long *ids, const long &nt, const float *xt) {
    const auto factory_key = opd_->get_index_type(nb);
    std::shared_ptr<faiss::Index> host_index(faiss::index_factory(opd_->d, factory_key.c_str()));

    {
        // Training and insertion are serialized through the file-level cpu_resource mutex.
        std::lock_guard<std::mutex> guard(cpu_resource);

        if (!host_index->is_trained) {
            const bool has_train_set = (nt != 0 && xt != nullptr);
            if (has_train_set) {
                host_index->train(nt, xt);
            } else {
                // No dedicated training set: train on the data being added.
                host_index->train(nb, xb);
            }
        }
        host_index->add_with_ids(nb, xb, ids);
    }

    return std::make_shared<Index>(host_index);
}
// TODO: replace with a proper factory pattern later.
/// Chooses the builder implementation for the given build options.
///
/// @param opd build options; only `index_type` is inspected here (must not be null)
/// @return a BgCpuBuilder when `index_type` is exactly "IDMap", otherwise an IndexBuilder
IndexBuilderPtr GetIndexBuilder(const Operand_ptr &opd) {
    // TODO: fix hardcode
    if (opd->index_type == "IDMap") {
        return std::make_shared<BgCpuBuilder>(opd);
    }
    return std::make_shared<IndexBuilder>(opd);
}

View File

@ -11,25 +11,26 @@
#include "Operand.h"
#include "Index.h"
namespace zilliz {
namespace vecwise {
namespace engine {
class IndexBuilder {
public:
public:
explicit IndexBuilder(const Operand_ptr &opd);
Index_ptr build_all(const long &nb,
const float* xb,
const long* ids,
const long &nt = 0,
const float* xt = nullptr);
virtual Index_ptr build_all(const long &nb,
const float *xb,
const long *ids,
const long &nt = 0,
const float *xt = nullptr);
Index_ptr build_all(const long &nb,
const std::vector<float> &xb,
const std::vector<long> &ids,
const long &nt = 0,
const std::vector<float> &xt = std::vector<float>());
virtual Index_ptr build_all(const long &nb,
const std::vector<float> &xb,
const std::vector<long> &ids,
const long &nt = 0,
const std::vector<float> &xt = std::vector<float>());
void train(const long &nt,
const std::vector<float> &xt);
@ -41,10 +42,21 @@ public:
void set_build_option(const Operand_ptr &opd);
private:
protected:
Operand_ptr opd_ = nullptr;
};
class BgCpuBuilder : public IndexBuilder {
public:
BgCpuBuilder(const Operand_ptr &opd);
virtual Index_ptr build_all(const long &nb,
const float *xb,
const long *ids,
const long &nt = 0,
const float *xt = nullptr) override;
};
using IndexBuilderPtr = std::shared_ptr<IndexBuilder>;
extern IndexBuilderPtr GetIndexBuilder(const Operand_ptr &opd);

View File

@ -6,10 +6,53 @@
#include "Operand.h"
namespace zilliz {
namespace vecwise {
namespace engine {
using std::string;
// Index families that resolveIndexType() can return.
enum IndexType {
Invalid_Option = 0,  // any index_type string that is not matched
IVF = 1,  // index_type == "IVF"
IDMAP = 2  // index_type == "IDMap"
};
// Maps an index_type string to its IndexType. The comparison is exact and
// case-sensitive; anything other than "IVF" or "IDMap" yields Invalid_Option.
IndexType resolveIndexType(const string &index_type) {
    IndexType resolved = IndexType::Invalid_Option;
    if (index_type == "IVF") {
        resolved = IndexType::IVF;
    } else if (index_type == "IDMap") {
        resolved = IndexType::IDMAP;
    }
    return resolved;
}
/// Returns the index factory string "[preproc,]<index>[,postproc]" for these options.
///
/// The string is assembled on the first call and cached in `index_str`; later calls
/// return the cached value regardless of `nb`.
///
/// @param nb number of vectors to be indexed. Only consulted for IVF when `ncent` is 0,
///           where the list count is derived as nb / 1e6 * 16384 (truncated, minimum 1).
/// @return the factory string; for an unrecognised `index_type` the index part is left out.
string Operand::get_index_type(const int &nb) {
    if (!index_str.empty()) { return index_str; }

    // TODO: support OPQ or ...
    if (!preproc.empty()) { index_str += (preproc + ","); }

    switch (resolveIndexType(index_type)) {
        case Invalid_Option: {
            // TODO: add exception
            break;
        }
        case IVF: {
            int nlist = ncent;
            if (nlist == 0) {
                nlist = static_cast<int>(nb / 1000000.0 * 16384);
                // For nb below 62 the truncation gives 0 (i.e. "IVF0"); keep at least one list.
                if (nlist < 1) { nlist = 1; }
            }
            index_str += index_type + std::to_string(nlist);
            break;
        }
        case IDMAP: {
            index_str += index_type;
            break;
        }
    }

    // TODO: support PQ or ...
    if (!postproc.empty()) { index_str += ("," + postproc); }

    return index_str;
}
std::ostream &operator<<(std::ostream &os, const Operand &obj) {
os << obj.d << " "
<< obj.index_type << " "

View File

@ -11,6 +11,7 @@
#include <iostream>
#include <sstream>
namespace zilliz {
namespace vecwise {
namespace engine {
@ -21,11 +22,14 @@ struct Operand {
friend std::istream &operator>>(std::istream &is, Operand &obj);
int d;
std::string index_type = "IVF13864,Flat";
std::string metric_type = "L2"; //> L2 / Inner Product
std::string index_type = "IVF";
std::string metric_type = "L2"; //> L2 / IP(Inner Product)
std::string preproc;
std::string postproc;
int ncent;
std::string postproc = "Flat";
std::string index_str;
int ncent = 0;
std::string get_index_type(const int &nb);
};
using Operand_ptr = std::shared_ptr<Operand>;

View File

@ -18,17 +18,17 @@ TEST(operand_test, Wrapper_Test) {
using std::endl;
auto opd = std::make_shared<Operand>();
opd->index_type = "IDMap,Flat";
opd->preproc = "opq";
opd->postproc = "pq";
opd->index_type = "IVF";
opd->preproc = "OPQ";
opd->postproc = "PQ";
opd->metric_type = "L2";
opd->ncent = 256;
opd->d = 64;
auto opd_str = operand_to_str(opd);
auto new_opd = str_to_operand(opd_str);
assert(new_opd->index_type == opd->index_type);
// TODO: fix all place where using opd to build index.
assert(new_opd->get_index_type(10000) == opd->get_index_type(10000));
}
TEST(build_test, Wrapper_Test) {
@ -56,7 +56,7 @@ TEST(build_test, Wrapper_Test) {
//train the index
auto opd = std::make_shared<Operand>();
opd->index_type = "IVF16,Flat";
opd->index_type = "IVF";
opd->d = d;
opd->ncent = ncentroids;
IndexBuilderPtr index_builder_1 = GetIndexBuilder(opd);
@ -120,7 +120,7 @@ TEST(gpu_build_test, Wrapper_Test) {
for (int i = 0; i < nb; ++i) { ids[i] = i; }
auto opd = std::make_shared<Operand>();
opd->index_type = "IVF256,Flat";
opd->index_type = "IVF";
opd->d = d;
opd->ncent = 256;