fix hamming (#3338)

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
pull/3470/head
shengjun.li 2020-08-20 09:42:04 +08:00 committed by GitHub
parent 1962b6e78b
commit 83cbe0f490
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 56 additions and 145 deletions

View File

@ -53,64 +53,13 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
// QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, Config());
QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config);
auto ret_ds = std::make_shared<Dataset>();
if (index_->metric_type == faiss::METRIC_Hamming) {
auto pf_dist = (float*)malloc(p_dist_size);
int32_t* pi_dist = (int32_t*)p_dist;
for (int i = 0; i < elems; i++) {
*(pf_dist + i) = (float)(*(pi_dist + i));
}
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, pf_dist);
free(p_dist);
} else {
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
}
return ret_ds;
}
#if 0
DatasetPtr
BinaryIDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
auto dim = dataset_ptr->Get<int64_t>(meta::DIM);
auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
int64_t k = config[meta::TOPK].get<int64_t>();
auto elems = rows * k;
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
auto* pdistances = (int32_t*)p_dist;
index_->search_by_id(rows, p_data, k, pdistances, p_id, bitset_);
auto ret_ds = std::make_shared<Dataset>();
if (index_->metric_type == faiss::METRIC_Hamming) {
auto pf_dist = (float*)malloc(p_dist_size);
int32_t* pi_dist = (int32_t*)p_dist;
for (int i = 0; i < elems; i++) {
*(pf_dist + i) = (float)(*(pi_dist + i));
}
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, pf_dist);
free(p_dist);
} else {
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
}
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
}
#endif
int64_t
BinaryIDMAP::Count() {
@ -187,39 +136,25 @@ BinaryIDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config)
index_->add_with_ids(rows, (uint8_t*)p_data, new_ids.data());
}
#if 0
DatasetPtr
BinaryIDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
// GETBINARYTENSOR(dataset_ptr)
// auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
auto elems = dataset_ptr->Get<int64_t>(meta::DIM);
size_t p_x_size = sizeof(uint8_t) * elems;
auto p_x = (uint8_t*)malloc(p_x_size);
index_->get_vector_by_id(1, p_data, p_x, bitset_);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::TENSOR, p_x);
return ret_ds;
}
#endif
void
BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
const Config& config) {
int32_t* pdistances = (int32_t*)distances;
auto flat_index = dynamic_cast<faiss::IndexBinaryIDMap*>(index_.get())->index;
auto default_type = flat_index->metric_type;
if (config.contains(Metric::TYPE))
flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
index_->search(n, (uint8_t*)data, k, pdistances, labels, bitset_);
int32_t* i_distances = reinterpret_cast<int32_t*>(distances);
index_->search(n, (uint8_t*)data, k, i_distances, labels, bitset_);
// if hamming, it need transform int32 to float
if (flat_index->metric_type == faiss::METRIC_Hamming) {
int64_t num = n * k;
for (int64_t i = 0; i < num; i++) {
distances[i] = static_cast<float>(i_distances[i]);
}
}
flat_index->metric_type = default_type;
}

View File

@ -50,11 +50,6 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
#if 0
DatasetPtr
QueryById(const DatasetPtr& dataset_ptr, const Config& config) override;
#endif
int64_t
Count() override;
@ -66,11 +61,6 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
return Count() * Dim() / 8;
}
#if 0
DatasetPtr
GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) override;
#endif
virtual const uint8_t*
GetRawVectors();

View File

@ -62,19 +62,9 @@ BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config);
auto ret_ds = std::make_shared<Dataset>();
if (index_->metric_type == faiss::METRIC_Hamming) {
auto pf_dist = (float*)malloc(p_dist_size);
int32_t* pi_dist = (int32_t*)p_dist;
for (int i = 0; i < elems; i++) {
*(pf_dist + i) = (float)(*(pi_dist + i));
}
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, pf_dist);
free(p_dist);
} else {
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
}
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
} catch (faiss::FaissException& e) {
KNOWHERE_THROW_MSG(e.what());
@ -215,11 +205,10 @@ BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances
auto params = GenParams(config);
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
ivf_index->nprobe = params->nprobe;
int32_t* pdistances = (int32_t*)distances;
stdclock::time_point before = stdclock::now();
// todo: remove static cast (zhiru)
static_cast<faiss::IndexBinary*>(index_.get())->search(n, (uint8_t*)data, k, pdistances, labels, bitset_);
stdclock::time_point before = stdclock::now();
int32_t* i_distances = reinterpret_cast<int32_t*>(distances);
index_->search(n, (uint8_t*)data, k, i_distances, labels, bitset_);
stdclock::time_point after = stdclock::now();
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
@ -228,6 +217,14 @@ BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
faiss::indexIVF_stats.quantization_time = 0;
faiss::indexIVF_stats.search_time = 0;
// if hamming, it need transform int32 to float
if (ivf_index->metric_type == faiss::METRIC_Hamming) {
int64_t num = n * k;
for (int64_t i = 0; i < num; i++) {
distances[i] = static_cast<float>(i_distances[i]);
}
}
}
} // namespace knowhere

View File

@ -109,7 +109,26 @@ TestProcess(std::shared_ptr<milvus::Connection> connection,
TOP_K,
NPROBE,
search_entity_array,
topk_query_result);
topk_query_result,
milvus::MetricType::HAMMING);
milvus_sdk::Utils::DoSearch(connection,
collection_param.collection_name,
partition_tags,
TOP_K,
NPROBE,
search_entity_array,
topk_query_result,
milvus::MetricType::SUBSTRUCTURE);
milvus_sdk::Utils::DoSearch(connection,
collection_param.collection_name,
partition_tags,
TOP_K,
NPROBE,
search_entity_array,
topk_query_result,
milvus::MetricType::SUPERSTRUCTURE);
}
{ // wait unit build index finish
@ -170,41 +189,5 @@ ClientTest::Test(const std::string& address, const std::string& port) {
TestProcess(connection, collection_param, index_param);
}
{
milvus::CollectionParam collection_param = {
"collection_2",
512, // dimension
512, // index file size
milvus::MetricType::SUBSTRUCTURE
};
JSON json_params = {};
milvus::IndexParam index_param = {
collection_param.collection_name,
milvus::IndexType::FLAT,
json_params.dump()
};
TestProcess(connection, collection_param, index_param);
}
{
milvus::CollectionParam collection_param = {
"collection_3",
128, // dimension
1024, // index file size
milvus::MetricType::SUPERSTRUCTURE
};
JSON json_params = {};
milvus::IndexParam index_param = {
collection_param.collection_name,
milvus::IndexType::FLAT,
json_params.dump()
};
TestProcess(connection, collection_param, index_param);
}
milvus::Connection::Destroy(connection);
}

View File

@ -202,7 +202,7 @@ void
Utils::DoSearch(std::shared_ptr<milvus::Connection> conn, const std::string& collection_name,
const std::vector<std::string>& partition_tags, int64_t top_k, int64_t nprobe,
const std::vector<std::pair<int64_t, milvus::Entity>>& entity_array,
milvus::TopKQueryResult& topk_query_result) {
milvus::TopKQueryResult& topk_query_result, milvus::MetricType metric_type) {
topk_query_result.clear();
std::vector<milvus::Entity> temp_entity_array;
@ -213,6 +213,10 @@ Utils::DoSearch(std::shared_ptr<milvus::Connection> conn, const std::string& col
{
BLOCK_SPLITER
JSON json_params = {{"nprobe", nprobe}};
if (metric_type != milvus::MetricType::INVALID) {
json_params["metric_type"] = metric_type;
}
milvus_sdk::TimeRecorder rc("Search");
milvus::Status stat =
conn->Search(collection_name,

View File

@ -69,7 +69,8 @@ class Utils {
DoSearch(std::shared_ptr<milvus::Connection> conn, const std::string& collection_name,
const std::vector<std::string>& partition_tags, int64_t top_k, int64_t nprobe,
const std::vector<std::pair<int64_t, milvus::Entity>>& entity_array,
milvus::TopKQueryResult& topk_query_result);
milvus::TopKQueryResult& topk_query_result,
milvus::MetricType metric_type = milvus::MetricType::INVALID);
static std::vector<milvus::LeafQueryPtr>
GenLeafQuery();

View File

@ -42,6 +42,7 @@ enum class IndexType {
};
enum class MetricType {
INVALID = 0,
L2 = 1, // Euclidean Distance
IP = 2, // Cosine Similarity
HAMMING = 3, // Hamming Distance