mirror of https://github.com/milvus-io/milvus.git
Fix binaryflat (#4681)
* fix binary distance Signed-off-by: shengjun.li <shengjun.li@zilliz.com> * improve the performance of BinaryFlat Signed-off-by: shengjun.li <shengjun.li@zilliz.com> * avoid negative zero in Tanimoto Signed-off-by: shengjun.li <shengjun.li@zilliz.com>pull/4694/head
parent
3efbab1df5
commit
bbbf254b7c
|
@ -4,9 +4,11 @@ Please mark all change in change log and use the issue from GitHub
|
|||
|
||||
# Milvus 0.10.6 (TBD)
|
||||
## Bug
|
||||
- \#4683 A negative zero may be returned if the metric_type is Tanimoto
|
||||
- \#4678 Server crash on BinaryFlat if dimension is not a power of 2
|
||||
|
||||
## Feature
|
||||
- \#4676 make metrics label configurable
|
||||
- \#4676 Support configurable metric labels for Prometheus
|
||||
|
||||
## Improvement
|
||||
|
||||
|
|
|
@ -41,60 +41,34 @@ void IndexBinaryFlat::reset() {
|
|||
void IndexBinaryFlat::search(idx_t n, const uint8_t *x, idx_t k,
|
||||
int32_t *distances, idx_t *labels,
|
||||
ConcurrentBitsetPtr bitset) const {
|
||||
const idx_t block_size = query_batch_size;
|
||||
|
||||
if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
|
||||
float *D = reinterpret_cast<float*>(distances);
|
||||
for (idx_t s = 0; s < n; s += block_size) {
|
||||
idx_t nn = block_size;
|
||||
if (s + block_size > n) {
|
||||
nn = n - s;
|
||||
}
|
||||
float_maxheap_array_t res = {
|
||||
size_t(n), size_t(k), labels, D
|
||||
};
|
||||
binary_distance_knn_hc(METRIC_Jaccard, &res, x, xb.data(), ntotal, code_size, bitset);
|
||||
|
||||
// We see the distances and labels as heaps.
|
||||
float_maxheap_array_t res = {
|
||||
size_t(nn), size_t(k), labels + s * k, D + s * k
|
||||
};
|
||||
|
||||
binary_distence_knn_hc(metric_type, &res, x + s * code_size, xb.data(), ntotal, code_size,
|
||||
/* ordered = */ true, bitset);
|
||||
|
||||
}
|
||||
if (metric_type == METRIC_Tanimoto) {
|
||||
for (int i = 0; i < k * n; i++) {
|
||||
D[i] = -log2(1-D[i]);
|
||||
D[i] = Jaccard_2_Tanimoto(D[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} else if (metric_type == METRIC_Hamming) {
|
||||
int_maxheap_array_t res = {
|
||||
size_t(n), size_t(k), labels, distances
|
||||
};
|
||||
binary_distance_knn_hc(METRIC_Hamming, &res, x, xb.data(), ntotal, code_size, bitset);
|
||||
|
||||
} else if (metric_type == METRIC_Substructure || metric_type == METRIC_Superstructure) {
|
||||
float *D = reinterpret_cast<float*>(distances);
|
||||
for (idx_t s = 0; s < n; s += block_size) {
|
||||
idx_t nn = block_size;
|
||||
if (s + block_size > n) {
|
||||
nn = n - s;
|
||||
}
|
||||
|
||||
// only match ids will be chosed, not to use heap
|
||||
binary_distence_knn_mc(metric_type, x + s * code_size, xb.data(), nn, ntotal, k, code_size,
|
||||
D + s * k, labels + s * k, bitset);
|
||||
}
|
||||
// only matched ids will be chosen, not to use heap
|
||||
binary_distance_knn_mc(metric_type, x, xb.data(), n, ntotal, k, code_size,
|
||||
D, labels, bitset);
|
||||
} else {
|
||||
for (idx_t s = 0; s < n; s += block_size) {
|
||||
idx_t nn = block_size;
|
||||
if (s + block_size > n) {
|
||||
nn = n - s;
|
||||
}
|
||||
if (use_heap) {
|
||||
// We see the distances and labels as heaps.
|
||||
int_maxheap_array_t res = {
|
||||
size_t(nn), size_t(k), labels + s * k, distances + s * k
|
||||
};
|
||||
|
||||
hammings_knn_hc(&res, x + s * code_size, xb.data(), ntotal, code_size,
|
||||
/* ordered = */ true, bitset);
|
||||
} else {
|
||||
hammings_knn_mc(x + s * code_size, xb.data(), nn, ntotal, k, code_size,
|
||||
distances + s * k, labels + s * k, bitset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -260,12 +260,12 @@ struct FlatHammingDis : DistanceComputer {
|
|||
|
||||
float operator () (idx_t i) override {
|
||||
ndis++;
|
||||
return hc.hamming(b + i * code_size);
|
||||
return hc.compute(b + i * code_size);
|
||||
}
|
||||
|
||||
float symmetric_dis(idx_t i, idx_t j) override {
|
||||
return HammingComputerDefault(b + j * code_size, code_size)
|
||||
.hamming(b + i * code_size);
|
||||
.compute(b + i * code_size);
|
||||
}
|
||||
|
||||
|
||||
|
@ -312,11 +312,7 @@ DistanceComputer *IndexBinaryHNSW::get_distance_computer() const {
|
|||
case 64:
|
||||
return new FlatHammingDis<HammingComputer64>(*flat_storage);
|
||||
default:
|
||||
if (code_size % 8 == 0) {
|
||||
return new FlatHammingDis<HammingComputerM8>(*flat_storage);
|
||||
} else if (code_size % 4 == 0) {
|
||||
return new FlatHammingDis<HammingComputerM4>(*flat_storage);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
|
||||
|
|
|
@ -173,7 +173,7 @@ search_single_query_template(const IndexBinaryHash & index, const uint8_t *q,
|
|||
} else {
|
||||
const uint8_t *codes = il.vecs.data();
|
||||
for (size_t i = 0; i < nv; i++) {
|
||||
int dis = hc.hamming (codes);
|
||||
int dis = hc.compute (codes);
|
||||
res.add(dis, il.ids[i]);
|
||||
codes += code_size;
|
||||
}
|
||||
|
@ -196,12 +196,8 @@ search_single_query(const IndexBinaryHash & index, const uint8_t *q,
|
|||
case 16: HC(HammingComputer16); break;
|
||||
case 20: HC(HammingComputer20); break;
|
||||
case 32: HC(HammingComputer32); break;
|
||||
default:
|
||||
if (index.code_size % 8 == 0) {
|
||||
HC(HammingComputerM8);
|
||||
} else {
|
||||
HC(HammingComputerDefault);
|
||||
}
|
||||
default: HC(HammingComputerDefault);
|
||||
|
||||
}
|
||||
#undef HC
|
||||
}
|
||||
|
@ -364,7 +360,7 @@ void verify_shortlist(
|
|||
const uint8_t *codes = index.xb.data();
|
||||
|
||||
for (auto i: shortlist) {
|
||||
int dis = hc.hamming (codes + i * code_size);
|
||||
int dis = hc.compute (codes + i * code_size);
|
||||
res.add(dis, i);
|
||||
}
|
||||
}
|
||||
|
@ -416,12 +412,7 @@ search_1_query_multihash(const IndexBinaryMultiHash & index, const uint8_t *xi,
|
|||
case 16: HC(HammingComputer16); break;
|
||||
case 20: HC(HammingComputer20); break;
|
||||
case 32: HC(HammingComputer32); break;
|
||||
default:
|
||||
if (index.code_size % 8 == 0) {
|
||||
HC(HammingComputerM8);
|
||||
} else {
|
||||
HC(HammingComputerDefault);
|
||||
}
|
||||
default: HC(HammingComputerDefault);
|
||||
}
|
||||
#undef HC
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include <faiss/utils/BinaryDistance.h>
|
||||
#include <faiss/utils/hamming.h>
|
||||
#include <faiss/utils/jaccard-inl.h>
|
||||
#include <faiss/utils/utils.h>
|
||||
#include <faiss/utils/Heap.h>
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
|
@ -362,7 +363,7 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
|
|||
}
|
||||
|
||||
uint32_t distance_to_code (const uint8_t *code) const override {
|
||||
return hc.hamming (code);
|
||||
return hc.compute (code);
|
||||
}
|
||||
|
||||
size_t scan_codes (size_t n,
|
||||
|
@ -377,7 +378,7 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
|
|||
size_t nup = 0;
|
||||
for (size_t j = 0; j < n; j++) {
|
||||
if (!bitset || !bitset->test(ids[j])) {
|
||||
uint32_t dis = hc.hamming (codes);
|
||||
uint32_t dis = hc.compute (codes);
|
||||
if (dis < simi[0]) {
|
||||
idx_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
||||
heap_swap_top<C> (k, simi, idxi, dis, id);
|
||||
|
@ -397,7 +398,7 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
|
|||
{
|
||||
size_t nup = 0;
|
||||
for (size_t j = 0; j < n; j++) {
|
||||
uint32_t dis = hc.hamming (codes);
|
||||
uint32_t dis = hc.compute (codes);
|
||||
if (dis < radius) {
|
||||
int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
|
||||
result.add (dis, id);
|
||||
|
@ -471,14 +472,7 @@ BinaryInvertedListScanner *select_IVFBinaryScannerL2 (size_t code_size) {
|
|||
case 20: HC(HammingComputer20);
|
||||
case 32: HC(HammingComputer32);
|
||||
case 64: HC(HammingComputer64);
|
||||
default:
|
||||
if (code_size % 8 == 0) {
|
||||
HC(HammingComputerM8);
|
||||
} else if (code_size % 4 == 0) {
|
||||
HC(HammingComputerM4);
|
||||
} else {
|
||||
HC(HammingComputerDefault);
|
||||
}
|
||||
default: HC(HammingComputerDefault);
|
||||
}
|
||||
#undef HC
|
||||
}
|
||||
|
@ -798,16 +792,8 @@ void search_knn_hamming_count_1 (
|
|||
HANDLE_CS(64);
|
||||
#undef HANDLE_CS
|
||||
default:
|
||||
if (ivf.code_size % 8 == 0) {
|
||||
search_knn_hamming_count<HammingComputerM8, store_pairs>
|
||||
(ivf, nx, x, keys, k, distances, labels, params, bitset);
|
||||
} else if (ivf.code_size % 4 == 0) {
|
||||
search_knn_hamming_count<HammingComputerM4, store_pairs>
|
||||
(ivf, nx, x, keys, k, distances, labels, params, bitset);
|
||||
} else {
|
||||
search_knn_hamming_count<HammingComputerDefault, store_pairs>
|
||||
(ivf, nx, x, keys, k, distances, labels, params, bitset);
|
||||
}
|
||||
search_knn_hamming_count<HammingComputerDefault, store_pairs>
|
||||
(ivf, nx, x, keys, k, distances, labels, params, bitset);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -856,7 +842,7 @@ void IndexBinaryIVF::search_preassigned(idx_t n, const uint8_t *x, idx_t k,
|
|||
params, bitset);
|
||||
if (metric_type == METRIC_Tanimoto) {
|
||||
for (int i = 0; i < k * n; i++) {
|
||||
D[i] = -log2(1-D[i]);
|
||||
D[i] = Jaccard_2_Tanimoto(D[i]);
|
||||
}
|
||||
}
|
||||
memcpy(distances, D, sizeof(float) * n * k);
|
||||
|
|
|
@ -969,7 +969,7 @@ struct IVFPQScannerT: QueryTables {
|
|||
|
||||
for (size_t j = 0; j < ncode; j++) {
|
||||
const uint8_t *b_code = codes;
|
||||
int hd = hc.hamming (b_code);
|
||||
int hd = hc.compute (b_code);
|
||||
if (hd < ht) {
|
||||
n_hamming_pass ++;
|
||||
PQDecoder decoder(codes, pq.nbits);
|
||||
|
@ -1013,14 +1013,9 @@ struct IVFPQScannerT: QueryTables {
|
|||
HANDLE_CODE_SIZE(64);
|
||||
#undef HANDLE_CODE_SIZE
|
||||
default:
|
||||
if (pq.code_size % 8 == 0)
|
||||
scan_list_polysemous_hc
|
||||
<HammingComputerM8, SearchResultType>
|
||||
(ncode, codes, res, bitset);
|
||||
else
|
||||
scan_list_polysemous_hc
|
||||
<HammingComputerM4, SearchResultType>
|
||||
(ncode, codes, res, bitset);
|
||||
scan_list_polysemous_hc
|
||||
<HammingComputerDefault, SearchResultType>
|
||||
(ncode, codes, res, bitset);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -254,7 +254,7 @@ struct IVFScanner: InvertedListScanner {
|
|||
}
|
||||
|
||||
float distance_to_code (const uint8_t *code) const final {
|
||||
return hc.hamming (code);
|
||||
return hc.compute (code);
|
||||
}
|
||||
|
||||
size_t scan_codes (size_t list_size,
|
||||
|
@ -267,7 +267,7 @@ struct IVFScanner: InvertedListScanner {
|
|||
size_t nup = 0;
|
||||
for (size_t j = 0; j < list_size; j++) {
|
||||
if (!bitset || !bitset->test(ids[j])) {
|
||||
float dis = hc.hamming (codes);
|
||||
float dis = hc.compute (codes);
|
||||
|
||||
if (dis < simi [0]) {
|
||||
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
||||
|
@ -288,7 +288,7 @@ struct IVFScanner: InvertedListScanner {
|
|||
ConcurrentBitsetPtr bitset = nullptr) const override
|
||||
{
|
||||
for (size_t j = 0; j < list_size; j++) {
|
||||
float dis = hc.hamming (codes);
|
||||
float dis = hc.compute (codes);
|
||||
if (dis < radius) {
|
||||
int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
|
||||
res.add (dis, id);
|
||||
|
@ -317,10 +317,8 @@ InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner
|
|||
HANDLE_CODE_SIZE(64);
|
||||
#undef HANDLE_CODE_SIZE
|
||||
default:
|
||||
if (code_size % 8 == 0) {
|
||||
return new IVFScanner<HammingComputerM8>(this, store_pairs);
|
||||
} else if (code_size % 4 == 0) {
|
||||
return new IVFScanner<HammingComputerM4>(this, store_pairs);
|
||||
if (code_size % 4 == 0) {
|
||||
return new IVFScanner<HammingComputerDefault>(this, store_pairs);
|
||||
} else {
|
||||
FAISS_THROW_MSG("not supported");
|
||||
}
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#include <faiss/utils/utils.h>
|
||||
#include <faiss/utils/hamming.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/utils/BinaryDistance.h>
|
||||
#include <faiss/utils/Heap.h>
|
||||
|
||||
|
||||
namespace faiss {
|
||||
|
@ -147,9 +149,8 @@ void IndexLSH::search (
|
|||
|
||||
int_maxheap_array_t res = { size_t(n), size_t(k), labels, idistances};
|
||||
|
||||
hammings_knn_hc (&res, qcodes, codes.data(),
|
||||
ntotal, bytes_per_vec, true);
|
||||
|
||||
binary_distance_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)&qcodes, codes.data(), ntotal,
|
||||
bytes_per_vec, bitset);
|
||||
|
||||
// convert distances to floats
|
||||
for (int i = 0; i < k * n; i++)
|
||||
|
|
|
@ -20,6 +20,8 @@
|
|||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/utils/hamming.h>
|
||||
#include <faiss/utils/BinaryDistance.h>
|
||||
#include <faiss/utils/Heap.h>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
|
@ -198,11 +200,6 @@ DistanceComputer * IndexPQ::get_distance_computer() const {
|
|||
/*****************************************
|
||||
* IndexPQ polysemous search routines
|
||||
******************************************/
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
void IndexPQ::search (idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels,
|
||||
ConcurrentBitsetPtr bitset) const
|
||||
|
@ -264,8 +261,8 @@ void IndexPQ::search (idx_t n, const float *x, idx_t k,
|
|||
|
||||
if (search_type == ST_HE) {
|
||||
|
||||
hammings_knn_hc (&res, q_codes, codes.data(),
|
||||
ntotal, pq.code_size, true);
|
||||
binary_distance_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t *)q_codes, codes.data(),
|
||||
ntotal, pq.code_size, bitset);
|
||||
|
||||
} else if (search_type == ST_generalized_HE) {
|
||||
|
||||
|
@ -317,7 +314,7 @@ static size_t polysemous_inner_loop (
|
|||
HammingComputer hc (q_code, code_size);
|
||||
|
||||
for (int64_t bi = 0; bi < ntotal; bi++) {
|
||||
int hd = hc.hamming (b_code);
|
||||
int hd = hc.compute (b_code);
|
||||
|
||||
if (hd < ht) {
|
||||
n_pass_i ++;
|
||||
|
@ -400,11 +397,8 @@ void IndexPQ::search_core_polysemous (idx_t n, const float *x, idx_t k,
|
|||
(*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
|
||||
break;
|
||||
default:
|
||||
if (pq.code_size % 8 == 0) {
|
||||
n_pass += polysemous_inner_loop<HammingComputerM8>
|
||||
(*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
|
||||
} else if (pq.code_size % 4 == 0) {
|
||||
n_pass += polysemous_inner_loop<HammingComputerM4>
|
||||
if (pq.code_size % 4 == 0) {
|
||||
n_pass += polysemous_inner_loop<HammingComputerDefault>
|
||||
(*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
|
||||
} else {
|
||||
FAISS_THROW_FMT(
|
||||
|
|
|
@ -51,9 +51,9 @@ TEST(BinaryFlat, accuracy) {
|
|||
|
||||
for (size_t i = 0; i < nq; ++i) {
|
||||
faiss::HammingComputer8 hc(queries.data() + i * (d / 8), d / 8);
|
||||
hamdis_t dist_min = hc.hamming(database.data());
|
||||
hamdis_t dist_min = hc.compute(database.data());
|
||||
for (size_t j = 1; j < nb; ++j) {
|
||||
hamdis_t dist = hc.hamming(database.data() + j * (d / 8));
|
||||
hamdis_t dist = hc.compute(database.data() + j * (d / 8));
|
||||
if (dist < dist_min) {
|
||||
dist_min = dist;
|
||||
}
|
||||
|
|
|
@ -1,168 +1,129 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <faiss/utils/BinaryDistance.h>
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include <math.h>
|
||||
#include <assert.h>
|
||||
#include <limits.h>
|
||||
#include <omp.h>
|
||||
#include <typeinfo>
|
||||
|
||||
#include <faiss/utils/Heap.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/utils/utils.h>
|
||||
#include <faiss/FaissHook.h>
|
||||
#include <faiss/utils/hamming.h>
|
||||
#include <faiss/utils/jaccard-inl.h>
|
||||
#include <faiss/utils/substructure-inl.h>
|
||||
#include <faiss/utils/superstructure-inl.h>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
static const size_t size_1M = 1 * 1024 * 1024;
|
||||
static const size_t batch_size = 65536;
|
||||
|
||||
template <class T>
|
||||
static
|
||||
void binary_distence_knn_hc(
|
||||
int bytes_per_code,
|
||||
float_maxheap_array_t * ha,
|
||||
const uint8_t * bs1,
|
||||
const uint8_t * bs2,
|
||||
size_t n2,
|
||||
bool order = true,
|
||||
bool init_heap = true,
|
||||
ConcurrentBitsetPtr bitset = nullptr)
|
||||
{
|
||||
size_t k = ha->k;
|
||||
|
||||
if ((bytes_per_code + k * (sizeof(float) + sizeof(int64_t))) * ha->nh < size_1M) {
|
||||
int thread_max_num = omp_get_max_threads();
|
||||
// init heap
|
||||
size_t thread_heap_size = ha->nh * k;
|
||||
size_t all_heap_size = thread_heap_size * thread_max_num;
|
||||
float *value = new float[all_heap_size];
|
||||
int64_t *labels = new int64_t[all_heap_size];
|
||||
for (int i = 0; i < all_heap_size; i++) {
|
||||
value[i] = 1.0 / 0.0;
|
||||
labels[i] = -1;
|
||||
}
|
||||
|
||||
T *hc = new T[ha->nh];
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
hc[i].set(bs1 + i * bytes_per_code, bytes_per_code);
|
||||
}
|
||||
|
||||
#pragma omp parallel for
|
||||
for (size_t j = 0; j < n2; j++) {
|
||||
if(!bitset || !bitset->test(j)) {
|
||||
int thread_no = omp_get_thread_num();
|
||||
|
||||
const uint8_t * bs2_ = bs2 + j * bytes_per_code;
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
tadis_t dis = hc[i].compute (bs2_);
|
||||
|
||||
float * val_ = value + thread_no * thread_heap_size + i * k;
|
||||
int64_t * ids_ = labels + thread_no * thread_heap_size + i * k;
|
||||
if (dis < val_[0]) {
|
||||
faiss::maxheap_swap_top<tadis_t> (k, val_, ids_, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t t = 1; t < thread_max_num; t++) {
|
||||
// merge heap
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
float * __restrict value_x = value + i * k;
|
||||
int64_t * __restrict labels_x = labels + i * k;
|
||||
float *value_x_t = value_x + t * thread_heap_size;
|
||||
int64_t *labels_x_t = labels_x + t * thread_heap_size;
|
||||
for (size_t j = 0; j < k; j++) {
|
||||
if (value_x_t[j] < value_x[0]) {
|
||||
faiss::maxheap_swap_top<tadis_t> (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copy result
|
||||
memcpy(ha->val, value, thread_heap_size * sizeof(float));
|
||||
memcpy(ha->ids, labels, thread_heap_size * sizeof(int64_t));
|
||||
|
||||
delete[] hc;
|
||||
delete[] value;
|
||||
delete[] labels;
|
||||
|
||||
} else {
|
||||
if (init_heap) ha->heapify ();
|
||||
|
||||
const size_t block_size = batch_size;
|
||||
for (size_t j0 = 0; j0 < n2; j0 += block_size) {
|
||||
const size_t j1 = std::min(j0 + block_size, n2);
|
||||
#pragma omp parallel for
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
T hc (bs1 + i * bytes_per_code, bytes_per_code);
|
||||
|
||||
const uint8_t * bs2_ = bs2 + j0 * bytes_per_code;
|
||||
tadis_t dis;
|
||||
tadis_t * __restrict bh_val_ = ha->val + i * k;
|
||||
int64_t * __restrict bh_ids_ = ha->ids + i * k;
|
||||
size_t j;
|
||||
for (j = j0; j < j1; j++, bs2_+= bytes_per_code) {
|
||||
if(!bitset || !bitset->test(j)){
|
||||
dis = hc.compute (bs2_);
|
||||
if (dis < bh_val_[0]) {
|
||||
faiss::maxheap_swap_top<tadis_t> (k, bh_val_, bh_ids_, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
#define fast_loop_imp(fun_u64, fun_u8) \
|
||||
auto a = reinterpret_cast<const uint64_t*>(data1); \
|
||||
auto b = reinterpret_cast<const uint64_t*>(data2); \
|
||||
int div = code_size / 8; \
|
||||
int mod = code_size % 8; \
|
||||
int i = 0, len = div; \
|
||||
switch(len & 7) { \
|
||||
default: \
|
||||
while (len > 7) { \
|
||||
len -= 8; \
|
||||
fun_u64; i++; \
|
||||
case 7: fun_u64; i++; \
|
||||
case 6: fun_u64; i++; \
|
||||
case 5: fun_u64; i++; \
|
||||
case 4: fun_u64; i++; \
|
||||
case 3: fun_u64; i++; \
|
||||
case 2: fun_u64; i++; \
|
||||
case 1: fun_u64; i++; \
|
||||
} \
|
||||
} \
|
||||
if (mod) { \
|
||||
auto a = data1 + 8 * div; \
|
||||
auto b = data2 + 8 * div; \
|
||||
switch (mod) { \
|
||||
case 7: fun_u8(6); \
|
||||
case 6: fun_u8(5); \
|
||||
case 5: fun_u8(4); \
|
||||
case 4: fun_u8(3); \
|
||||
case 3: fun_u8(2); \
|
||||
case 2: fun_u8(1); \
|
||||
case 1: fun_u8(0); \
|
||||
default: break; \
|
||||
} \
|
||||
}
|
||||
|
||||
if (order) ha->reorder ();
|
||||
int popcnt(const uint8_t* data, const size_t code_size) {
|
||||
auto data1 = data, data2 = data; // for the macro fast_loop_imp
|
||||
#define fun_u64 accu += popcount64(a[i])
|
||||
#define fun_u8(i) accu += lookup8bit[a[i]]
|
||||
int accu = 0;
|
||||
fast_loop_imp(fun_u64, fun_u8);
|
||||
return accu;
|
||||
#undef fun_u64
|
||||
#undef fun_u8
|
||||
}
|
||||
|
||||
void binary_distence_knn_hc (
|
||||
MetricType metric_type,
|
||||
float_maxheap_array_t * ha,
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
size_t nb,
|
||||
size_t ncodes,
|
||||
int order,
|
||||
ConcurrentBitsetPtr bitset)
|
||||
{
|
||||
switch (metric_type) {
|
||||
case METRIC_Jaccard:
|
||||
case METRIC_Tanimoto:
|
||||
switch (ncodes) {
|
||||
#define binary_distence_knn_hc_jaccard(ncodes) \
|
||||
case ncodes: \
|
||||
binary_distence_knn_hc<faiss::JaccardComputer ## ncodes> \
|
||||
(ncodes, ha, a, b, nb, order, true, bitset); \
|
||||
break;
|
||||
binary_distence_knn_hc_jaccard(8);
|
||||
binary_distence_knn_hc_jaccard(16);
|
||||
binary_distence_knn_hc_jaccard(32);
|
||||
binary_distence_knn_hc_jaccard(64);
|
||||
binary_distence_knn_hc_jaccard(128);
|
||||
binary_distence_knn_hc_jaccard(256);
|
||||
binary_distence_knn_hc_jaccard(512);
|
||||
#undef binary_distence_knn_hc_jaccard
|
||||
default:
|
||||
binary_distence_knn_hc<faiss::JaccardComputerDefault>
|
||||
(ncodes, ha, a, b, nb, order, true, bitset);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
int xor_popcnt(const uint8_t* data1, const uint8_t*data2, const size_t code_size) {
|
||||
#define fun_u64 accu += popcount64(a[i] ^ b[i]);
|
||||
#define fun_u8(i) accu += lookup8bit[a[i] ^ b[i]];
|
||||
int accu = 0;
|
||||
fast_loop_imp(fun_u64, fun_u8);
|
||||
return accu;
|
||||
#undef fun_u64
|
||||
#undef fun_u8
|
||||
}
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
int or_popcnt(const uint8_t* data1, const uint8_t*data2, const size_t code_size) {
|
||||
#define fun_u64 accu += popcount64(a[i] | b[i])
|
||||
#define fun_u8(i) accu += lookup8bit[a[i] | b[i]]
|
||||
int accu = 0;
|
||||
fast_loop_imp(fun_u64, fun_u8);
|
||||
return accu;
|
||||
#undef fun_u64
|
||||
#undef fun_u8
|
||||
}
|
||||
|
||||
int and_popcnt(const uint8_t* data1, const uint8_t*data2, const size_t code_size) {
|
||||
#define fun_u64 accu += popcount64(a[i] & b[i])
|
||||
#define fun_u8(i) accu += lookup8bit[a[i] & b[i]]
|
||||
int accu = 0;
|
||||
fast_loop_imp(fun_u64, fun_u8);
|
||||
return accu;
|
||||
#undef fun_u64
|
||||
#undef fun_u8
|
||||
}
|
||||
|
||||
bool is_subset(const uint8_t* data1, const uint8_t* data2, const size_t code_size) {
|
||||
#define fun_u64 if((a[i] & b[i]) != a[i]) return false
|
||||
#define fun_u8(i) if((a[i] & b[i]) != a[i]) return false
|
||||
fast_loop_imp(fun_u64, fun_u8);
|
||||
return true;
|
||||
#undef fun_u64
|
||||
#undef fun_u8
|
||||
}
|
||||
|
||||
float bvec_jaccard (const uint8_t* data1, const uint8_t* data2, const size_t code_size) {
|
||||
#define fun_u64 accu_num += popcount64(a[i] & b[i]); accu_den += popcount64(a[i] | b[i])
|
||||
#define fun_u8(i) accu_num += lookup8bit[a[i] & b[i]]; accu_den += lookup8bit[a[i] | b[i]]
|
||||
int accu_num = 0;
|
||||
int accu_den = 0;
|
||||
fast_loop_imp(fun_u64, fun_u8);
|
||||
return (accu_den == 0) ? 1.0 : ((float)(accu_den - accu_num) / (float)(accu_den));
|
||||
#undef fun_u64
|
||||
#undef fun_u8
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static
|
||||
void binary_distence_knn_mc(
|
||||
void binary_distance_knn_mc(
|
||||
int bytes_per_code,
|
||||
const uint8_t * bs1,
|
||||
const uint8_t * bs2,
|
||||
|
@ -173,9 +134,13 @@ void binary_distence_knn_mc(
|
|||
int64_t *labels,
|
||||
ConcurrentBitsetPtr bitset)
|
||||
{
|
||||
if ((bytes_per_code + sizeof(size_t) + k * sizeof(int64_t)) * n1 < size_1M) {
|
||||
int thread_max_num = omp_get_max_threads();
|
||||
int thread_max_num = omp_get_max_threads();
|
||||
size_t l3_size = get_L3_Size();
|
||||
|
||||
/*
|
||||
* Later we may propose a more reasonable strategy.
|
||||
*/
|
||||
if (n1 < n2) {
|
||||
size_t group_num = n1 * thread_max_num;
|
||||
size_t *match_num = new size_t[group_num];
|
||||
int64_t *match_data = new int64_t[group_num * k];
|
||||
|
@ -229,12 +194,13 @@ void binary_distence_knn_mc(
|
|||
delete[] match_data;
|
||||
|
||||
} else {
|
||||
const size_t block_size = l3_size / bytes_per_code;
|
||||
|
||||
size_t *num = new size_t[n1];
|
||||
for (size_t i = 0; i < n1; i++) {
|
||||
num[i] = 0;
|
||||
}
|
||||
|
||||
const size_t block_size = batch_size;
|
||||
for (size_t j0 = 0; j0 < n2; j0 += block_size) {
|
||||
const size_t j1 = std::min(j0 + block_size, n2);
|
||||
#pragma omp parallel for
|
||||
|
@ -272,7 +238,7 @@ void binary_distence_knn_mc(
|
|||
}
|
||||
}
|
||||
|
||||
void binary_distence_knn_mc (
|
||||
void binary_distance_knn_mc (
|
||||
MetricType metric_type,
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
|
@ -287,21 +253,21 @@ void binary_distence_knn_mc (
|
|||
switch (metric_type) {
|
||||
case METRIC_Substructure:
|
||||
switch (ncodes) {
|
||||
#define binary_distence_knn_mc_Substructure(ncodes) \
|
||||
#define binary_distance_knn_mc_Substructure(ncodes) \
|
||||
case ncodes: \
|
||||
binary_distence_knn_mc<faiss::SubstructureComputer ## ncodes> \
|
||||
binary_distance_knn_mc<faiss::SubstructureComputer ## ncodes> \
|
||||
(ncodes, a, b, na, nb, k, distances, labels, bitset); \
|
||||
break;
|
||||
binary_distence_knn_mc_Substructure(8);
|
||||
binary_distence_knn_mc_Substructure(16);
|
||||
binary_distence_knn_mc_Substructure(32);
|
||||
binary_distence_knn_mc_Substructure(64);
|
||||
binary_distence_knn_mc_Substructure(128);
|
||||
binary_distence_knn_mc_Substructure(256);
|
||||
binary_distence_knn_mc_Substructure(512);
|
||||
#undef binary_distence_knn_mc_Substructure
|
||||
binary_distance_knn_mc_Substructure(8);
|
||||
binary_distance_knn_mc_Substructure(16);
|
||||
binary_distance_knn_mc_Substructure(32);
|
||||
binary_distance_knn_mc_Substructure(64);
|
||||
binary_distance_knn_mc_Substructure(128);
|
||||
binary_distance_knn_mc_Substructure(256);
|
||||
binary_distance_knn_mc_Substructure(512);
|
||||
#undef binary_distance_knn_mc_Substructure
|
||||
default:
|
||||
binary_distence_knn_mc<faiss::SubstructureComputerDefault>
|
||||
binary_distance_knn_mc<faiss::SubstructureComputerDefault>
|
||||
(ncodes, a, b, na, nb, k, distances, labels, bitset);
|
||||
break;
|
||||
}
|
||||
|
@ -309,21 +275,21 @@ void binary_distence_knn_mc (
|
|||
|
||||
case METRIC_Superstructure:
|
||||
switch (ncodes) {
|
||||
#define binary_distence_knn_mc_Superstructure(ncodes) \
|
||||
#define binary_distance_knn_mc_Superstructure(ncodes) \
|
||||
case ncodes: \
|
||||
binary_distence_knn_mc<faiss::SuperstructureComputer ## ncodes> \
|
||||
binary_distance_knn_mc<faiss::SuperstructureComputer ## ncodes> \
|
||||
(ncodes, a, b, na, nb, k, distances, labels, bitset); \
|
||||
break;
|
||||
binary_distence_knn_mc_Superstructure(8);
|
||||
binary_distence_knn_mc_Superstructure(16);
|
||||
binary_distence_knn_mc_Superstructure(32);
|
||||
binary_distence_knn_mc_Superstructure(64);
|
||||
binary_distence_knn_mc_Superstructure(128);
|
||||
binary_distence_knn_mc_Superstructure(256);
|
||||
binary_distence_knn_mc_Superstructure(512);
|
||||
#undef binary_distence_knn_mc_Superstructure
|
||||
binary_distance_knn_mc_Superstructure(8);
|
||||
binary_distance_knn_mc_Superstructure(16);
|
||||
binary_distance_knn_mc_Superstructure(32);
|
||||
binary_distance_knn_mc_Superstructure(64);
|
||||
binary_distance_knn_mc_Superstructure(128);
|
||||
binary_distance_knn_mc_Superstructure(256);
|
||||
binary_distance_knn_mc_Superstructure(512);
|
||||
#undef binary_distance_knn_mc_Superstructure
|
||||
default:
|
||||
binary_distence_knn_mc<faiss::SuperstructureComputerDefault>
|
||||
binary_distance_knn_mc<faiss::SuperstructureComputerDefault>
|
||||
(ncodes, a, b, na, nb, k, distances, labels, bitset);
|
||||
break;
|
||||
}
|
||||
|
@ -334,4 +300,193 @@ void binary_distence_knn_mc (
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
template <class C, class MetricComputer>
|
||||
void binary_distance_knn_hc (
|
||||
int bytes_per_code,
|
||||
HeapArray<C> * ha,
|
||||
const uint8_t * bs1,
|
||||
const uint8_t * bs2,
|
||||
size_t n2,
|
||||
ConcurrentBitsetPtr bitset)
|
||||
{
|
||||
typedef typename C::T T;
|
||||
size_t k = ha->k;
|
||||
|
||||
size_t l3_size = get_L3_Size();
|
||||
size_t thread_max_num = omp_get_max_threads();
|
||||
|
||||
/*
|
||||
* Here is an empirical formula, and later we may propose a more reasonable strategy.
|
||||
*/
|
||||
if ((bytes_per_code + k * (sizeof(T) + sizeof(int64_t))) * ha->nh * thread_max_num <= l3_size &&
|
||||
(ha->nh < (n2 >> 11) + thread_max_num / 3)) {
|
||||
// init heap
|
||||
size_t thread_heap_size = ha->nh * k;
|
||||
size_t all_heap_size = thread_heap_size * thread_max_num;
|
||||
T *value = new T[all_heap_size];
|
||||
int64_t *labels = new int64_t[all_heap_size];
|
||||
T init_value = (typeid(T) == typeid(float)) ? (1.0 / 0.0) : 0x7fffffff;
|
||||
for (int i = 0; i < all_heap_size; i++) {
|
||||
value[i] = init_value;
|
||||
labels[i] = -1;
|
||||
}
|
||||
|
||||
MetricComputer *hc = new MetricComputer[ha->nh];
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
hc[i].set(bs1 + i * bytes_per_code, bytes_per_code);
|
||||
}
|
||||
|
||||
#pragma omp parallel for
|
||||
for (size_t j = 0; j < n2; j++) {
|
||||
if(!bitset || !bitset->test(j)) {
|
||||
int thread_no = omp_get_thread_num();
|
||||
|
||||
const uint8_t * bs2_ = bs2 + j * bytes_per_code;
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
T dis = hc[i].compute (bs2_);
|
||||
T *val_ = value + thread_no * thread_heap_size + i * k;
|
||||
int64_t *ids_ = labels + thread_no * thread_heap_size + i * k;
|
||||
if (C::cmp(val_[0], dis)) {
|
||||
faiss::heap_swap_top<C>(k, val_, ids_, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t t = 1; t < thread_max_num; t++) {
|
||||
// merge heap
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
T * __restrict value_x = value + i * k;
|
||||
int64_t * __restrict labels_x = labels + i * k;
|
||||
T *value_x_t = value_x + t * thread_heap_size;
|
||||
int64_t *labels_x_t = labels_x + t * thread_heap_size;
|
||||
for (size_t j = 0; j < k; j++) {
|
||||
if (C::cmp(value_x[0], value_x_t[j])) {
|
||||
faiss::heap_swap_top<C>(k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copy result
|
||||
memcpy(ha->val, value, thread_heap_size * sizeof(T));
|
||||
memcpy(ha->ids, labels, thread_heap_size * sizeof(int64_t));
|
||||
|
||||
delete[] value;
|
||||
delete[] labels;
|
||||
|
||||
} else {
|
||||
const size_t block_size = l3_size / bytes_per_code;
|
||||
|
||||
ha->heapify ();
|
||||
|
||||
for (size_t j0 = 0; j0 < n2; j0 += block_size) {
|
||||
const size_t j1 = std::min(j0 + block_size, n2);
|
||||
#pragma omp parallel for
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
MetricComputer hc (bs1 + i * bytes_per_code, bytes_per_code);
|
||||
|
||||
const uint8_t *bs2_ = bs2 + j0 * bytes_per_code;
|
||||
T dis;
|
||||
T *__restrict bh_val_ = ha->val + i * k;
|
||||
int64_t *__restrict bh_ids_ = ha->ids + i * k;
|
||||
for (size_t j = j0; j < j1; j++, bs2_ += bytes_per_code) {
|
||||
if (!bitset || !bitset->test(j)) {
|
||||
dis = hc.compute (bs2_);
|
||||
if (C::cmp(bh_val_[0], dis)) {
|
||||
faiss::heap_swap_top<C>(k, bh_val_, bh_ids_, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
ha->reorder ();
|
||||
}
|
||||
|
||||
template <class C>
|
||||
void binary_distance_knn_hc (
|
||||
MetricType metric_type,
|
||||
HeapArray<C> * ha,
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
size_t nb,
|
||||
size_t ncodes,
|
||||
ConcurrentBitsetPtr bitset)
|
||||
{
|
||||
size_t dim = ncodes * 8;
|
||||
switch (metric_type) {
|
||||
case METRIC_Jaccard: {
|
||||
switch (ncodes) {
|
||||
#define binary_distance_knn_hc_jaccard(ncodes) \
|
||||
case ncodes: \
|
||||
binary_distance_knn_hc<C, faiss::JaccardComputer ## ncodes> \
|
||||
(ncodes, ha, a, b, nb, bitset); \
|
||||
break;
|
||||
binary_distance_knn_hc_jaccard(8);
|
||||
binary_distance_knn_hc_jaccard(16);
|
||||
binary_distance_knn_hc_jaccard(32);
|
||||
binary_distance_knn_hc_jaccard(64);
|
||||
binary_distance_knn_hc_jaccard(128);
|
||||
binary_distance_knn_hc_jaccard(256);
|
||||
binary_distance_knn_hc_jaccard(512);
|
||||
#undef binary_distence_knn_hc_jaccard
|
||||
default:
|
||||
binary_distance_knn_hc<C, faiss::JaccardComputerDefault>
|
||||
(ncodes, ha, a, b, nb, bitset);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case METRIC_Hamming: {
|
||||
switch (ncodes) {
|
||||
#define binary_distance_knn_hc_hamming(ncodes) \
|
||||
case ncodes: \
|
||||
binary_distance_knn_hc<C, faiss::HammingComputer ## ncodes> \
|
||||
(ncodes, ha, a, b, nb, bitset); \
|
||||
break;
|
||||
binary_distance_knn_hc_hamming(4);
|
||||
binary_distance_knn_hc_hamming(8);
|
||||
binary_distance_knn_hc_hamming(16);
|
||||
binary_distance_knn_hc_hamming(20);
|
||||
binary_distance_knn_hc_hamming(32);
|
||||
binary_distance_knn_hc_hamming(64);
|
||||
#undef binary_distence_knn_hc_jaccard
|
||||
default:
|
||||
binary_distance_knn_hc<C, faiss::HammingComputerDefault>
|
||||
(ncodes, ha, a, b, nb, bitset);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template
|
||||
void binary_distance_knn_hc<CMax<int, int64_t>>(
|
||||
MetricType metric_type,
|
||||
int_maxheap_array_t * ha,
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
size_t nb,
|
||||
size_t ncodes,
|
||||
ConcurrentBitsetPtr bitset);
|
||||
|
||||
template
|
||||
void binary_distance_knn_hc<CMax<float, int64_t>>(
|
||||
MetricType metric_type,
|
||||
float_maxheap_array_t * ha,
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
size_t nb,
|
||||
size_t ncodes,
|
||||
ConcurrentBitsetPtr bitset);
|
||||
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -1,39 +1,83 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#ifndef FAISS_BINARY_DISTANCE_H
|
||||
#define FAISS_BINARY_DISTANCE_H
|
||||
|
||||
#include "faiss/Index.h"
|
||||
|
||||
#include <faiss/utils/hamming.h>
|
||||
|
||||
#include "faiss/MetricType.h"
|
||||
#include <stdint.h>
|
||||
|
||||
#include <faiss/utils/ConcurrentBitset.h>
|
||||
#include <faiss/utils/Heap.h>
|
||||
|
||||
/* The binary distance type */
|
||||
typedef float tadis_t;
|
||||
|
||||
namespace faiss {
|
||||
/**
|
||||
* Calculate the number of bit 1
|
||||
*/
|
||||
extern int popcnt(
|
||||
const uint8_t* data,
|
||||
const size_t code_size);
|
||||
|
||||
/** Return the k smallest distances for a set of binary query vectors,
|
||||
* using a max heap.
|
||||
* @param a queries, size ha->nh * ncodes
|
||||
* @param b database, size nb * ncodes
|
||||
* @param nb number of database vectors
|
||||
* @param ncodes size of the binary codes (bytes)
|
||||
* @param ordered if != 0: order the results by decreasing distance
|
||||
* (may be bottleneck for k/n > 0.01) */
|
||||
void binary_distence_knn_hc (
|
||||
MetricType metric_type,
|
||||
float_maxheap_array_t * ha,
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
size_t nb,
|
||||
size_t ncodes,
|
||||
int ordered,
|
||||
ConcurrentBitsetPtr bitset = nullptr);
|
||||
/**
|
||||
* Calculate the number of bit 1 after xor
|
||||
*/
|
||||
extern int xor_popcnt(
|
||||
const uint8_t* data1,
|
||||
const uint8_t* data2,
|
||||
const size_t code_size);
|
||||
|
||||
/**
|
||||
* Calculate the number of bit 1 after or
|
||||
*/
|
||||
extern int or_popcnt(
|
||||
const uint8_t* data1,
|
||||
const uint8_t* data2,
|
||||
const size_t code_size);
|
||||
|
||||
/**
|
||||
* Calculate the number of bit 1 after and
|
||||
*/
|
||||
extern int and_popcnt(
|
||||
const uint8_t* data1,
|
||||
const uint8_t* data2,
|
||||
const size_t code_size);
|
||||
|
||||
/**
|
||||
* Judge whether data1 is a subset of data2
|
||||
*/
|
||||
extern bool is_subset(
|
||||
const uint8_t* data1,
|
||||
const uint8_t* data2,
|
||||
const size_t code_size);
|
||||
|
||||
/**
|
||||
* Calculate Jaccard distance
|
||||
*/
|
||||
extern float bvec_jaccard (
|
||||
const uint8_t* data1,
|
||||
const uint8_t* data2,
|
||||
const size_t code_size);
|
||||
|
||||
/**
|
||||
* Distance conversion between Jaccard and Tanimoto
|
||||
*/
|
||||
inline float Jaccard_2_Tanimoto (float jcd) {
|
||||
// To avoid a negative zero in C language
|
||||
return (jcd == 0.0) ? 0.0 : -log2(1 - jcd);
|
||||
}
|
||||
|
||||
/** Return the k matched distances for a set of binary query vectors,
|
||||
* using a max heap.
|
||||
* using an array.
|
||||
* @param a queries, size ha->nh * ncodes
|
||||
* @param b database, size nb * ncodes
|
||||
* @param na number of queries vectors
|
||||
|
@ -41,7 +85,7 @@ namespace faiss {
|
|||
* @param k number of the matched vectors to return
|
||||
* @param ncodes size of the binary codes (bytes)
|
||||
*/
|
||||
void binary_distence_knn_mc (
|
||||
void binary_distance_knn_mc (
|
||||
MetricType metric_type,
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
|
@ -51,12 +95,50 @@ namespace faiss {
|
|||
size_t ncodes,
|
||||
float *distances,
|
||||
int64_t *labels,
|
||||
ConcurrentBitsetPtr bitset);
|
||||
ConcurrentBitsetPtr bitset = nullptr);
|
||||
|
||||
/** Return the k smallest distances for a set of binary query vectors,
|
||||
* using a heap.
|
||||
* @param a queries, size ha->nh * ncodes
|
||||
* @param b database, size nb * ncodes
|
||||
* @param nb number of database vectors
|
||||
* @param ncodes size of the binary codes (bytes)
|
||||
* @param ordered if != 0: order the results by decreasing distance
|
||||
* (may be bottleneck for k/n > 0.01)
|
||||
*/
|
||||
template <class C>
|
||||
void binary_distance_knn_hc (
|
||||
MetricType metric_type,
|
||||
HeapArray<C> * ha,
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
size_t nb,
|
||||
size_t ncodes,
|
||||
ConcurrentBitsetPtr bitset = nullptr);
|
||||
|
||||
|
||||
extern template
|
||||
void binary_distance_knn_hc<CMax<int, int64_t>>(
|
||||
MetricType metric_type,
|
||||
int_maxheap_array_t * ha,
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
size_t nb,
|
||||
size_t ncodes,
|
||||
ConcurrentBitsetPtr bitset = nullptr);
|
||||
|
||||
extern template
|
||||
void binary_distance_knn_hc<CMax<float, int64_t>>(
|
||||
MetricType metric_type,
|
||||
float_maxheap_array_t * ha,
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
size_t nb,
|
||||
size_t ncodes,
|
||||
ConcurrentBitsetPtr bitset = nullptr);
|
||||
|
||||
} // namespace faiss
|
||||
|
||||
#include <faiss/utils/jaccard-inl.h>
|
||||
#include <faiss/utils/substructure-inl.h>
|
||||
#include <faiss/utils/superstructure-inl.h>
|
||||
|
||||
|
||||
#endif // FAISS_BINARY_DISTANCE_H
|
||||
|
|
|
@ -129,9 +129,6 @@ void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x)
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
/***************************************************************************
|
||||
* KNN functions
|
||||
***************************************************************************/
|
||||
|
|
|
@ -46,11 +46,6 @@ float fvec_Linf_sse (
|
|||
size_t d);
|
||||
#endif
|
||||
|
||||
float fvec_jaccard (
|
||||
const float * x,
|
||||
const float * y,
|
||||
size_t d);
|
||||
|
||||
/** Compute pairwise distances between sets of vectors
|
||||
*
|
||||
* @param d dimension of the vectors
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
|
||||
#include <faiss/utils/BinaryDistance.h>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
|
@ -93,7 +93,7 @@ struct HammingComputer4 {
|
|||
a0 = *(uint32_t *)a;
|
||||
}
|
||||
|
||||
inline int hamming (const uint8_t *b) const {
|
||||
inline int compute (const uint8_t *b) const {
|
||||
return popcount64 (*(uint32_t *)b ^ a0);
|
||||
}
|
||||
|
||||
|
@ -113,7 +113,7 @@ struct HammingComputer8 {
|
|||
a0 = *(uint64_t *)a;
|
||||
}
|
||||
|
||||
inline int hamming (const uint8_t *b) const {
|
||||
inline int compute (const uint8_t *b) const {
|
||||
return popcount64 (*(uint64_t *)b ^ a0);
|
||||
}
|
||||
|
||||
|
@ -135,7 +135,7 @@ struct HammingComputer16 {
|
|||
a0 = a[0]; a1 = a[1];
|
||||
}
|
||||
|
||||
inline int hamming (const uint8_t *b8) const {
|
||||
inline int compute (const uint8_t *b8) const {
|
||||
const uint64_t *b = (uint64_t *)b8;
|
||||
return popcount64 (b[0] ^ a0) + popcount64 (b[1] ^ a1);
|
||||
}
|
||||
|
@ -160,7 +160,7 @@ struct HammingComputer20 {
|
|||
a0 = a[0]; a1 = a[1]; a2 = a[2];
|
||||
}
|
||||
|
||||
inline int hamming (const uint8_t *b8) const {
|
||||
inline int compute (const uint8_t *b8) const {
|
||||
const uint64_t *b = (uint64_t *)b8;
|
||||
return popcount64 (b[0] ^ a0) + popcount64 (b[1] ^ a1) +
|
||||
popcount64 (*(uint32_t*)(b + 2) ^ a2);
|
||||
|
@ -182,7 +182,7 @@ struct HammingComputer32 {
|
|||
a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
|
||||
}
|
||||
|
||||
inline int hamming (const uint8_t *b8) const {
|
||||
inline int compute (const uint8_t *b8) const {
|
||||
const uint64_t *b = (uint64_t *)b8;
|
||||
return popcount64 (b[0] ^ a0) + popcount64 (b[1] ^ a1) +
|
||||
popcount64 (b[2] ^ a2) + popcount64 (b[3] ^ a3);
|
||||
|
@ -206,7 +206,7 @@ struct HammingComputer64 {
|
|||
a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7];
|
||||
}
|
||||
|
||||
inline int hamming (const uint8_t *b8) const {
|
||||
inline int compute (const uint8_t *b8) const {
|
||||
const uint64_t *b = (uint64_t *)b8;
|
||||
return popcount64 (b[0] ^ a0) + popcount64 (b[1] ^ a1) +
|
||||
popcount64 (b[2] ^ a2) + popcount64 (b[3] ^ a3) +
|
||||
|
@ -216,7 +216,6 @@ struct HammingComputer64 {
|
|||
|
||||
};
|
||||
|
||||
// very inefficient...
|
||||
struct HammingComputerDefault {
|
||||
const uint8_t *a;
|
||||
int n;
|
||||
|
@ -232,64 +231,8 @@ struct HammingComputerDefault {
|
|||
n = code_size;
|
||||
}
|
||||
|
||||
int hamming (const uint8_t *b8) const {
|
||||
int accu = 0;
|
||||
for (int i = 0; i < n; i++)
|
||||
accu += popcount64 (a[i] ^ b8[i]);
|
||||
return accu;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct HammingComputerM8 {
|
||||
const uint64_t *a;
|
||||
int n;
|
||||
|
||||
HammingComputerM8 () {}
|
||||
|
||||
HammingComputerM8 (const uint8_t *a8, int code_size) {
|
||||
set (a8, code_size);
|
||||
}
|
||||
|
||||
void set (const uint8_t *a8, int code_size) {
|
||||
assert (code_size % 8 == 0);
|
||||
a = (uint64_t *)a8;
|
||||
n = code_size / 8;
|
||||
}
|
||||
|
||||
int hamming (const uint8_t *b8) const {
|
||||
const uint64_t *b = (uint64_t *)b8;
|
||||
int accu = 0;
|
||||
for (int i = 0; i < n; i++)
|
||||
accu += popcount64 (a[i] ^ b[i]);
|
||||
return accu;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// even more inefficient!
|
||||
struct HammingComputerM4 {
|
||||
const uint32_t *a;
|
||||
int n;
|
||||
|
||||
HammingComputerM4 () {}
|
||||
|
||||
HammingComputerM4 (const uint8_t *a4, int code_size) {
|
||||
set (a4, code_size);
|
||||
}
|
||||
|
||||
void set (const uint8_t *a4, int code_size) {
|
||||
assert (code_size % 4 == 0);
|
||||
a = (uint32_t *)a4;
|
||||
n = code_size / 4;
|
||||
}
|
||||
|
||||
int hamming (const uint8_t *b8) const {
|
||||
const uint32_t *b = (uint32_t *)b8;
|
||||
int accu = 0;
|
||||
for (int i = 0; i < n; i++)
|
||||
accu += popcount64 (a[i] ^ b[i]);
|
||||
return accu;
|
||||
int compute (const uint8_t *b8) const {
|
||||
return xor_popcnt(a, b8, n);
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -300,9 +243,9 @@ struct HammingComputerM4 {
|
|||
|
||||
// default template
|
||||
template<int CODE_SIZE>
|
||||
struct HammingComputer: HammingComputerM8 {
|
||||
struct HammingComputer: HammingComputerDefault {
|
||||
HammingComputer (const uint8_t *a, int code_size):
|
||||
HammingComputerM8(a, code_size) {}
|
||||
HammingComputerDefault(a, code_size) {}
|
||||
};
|
||||
|
||||
#define SPECIALIZED_HC(CODE_SIZE) \
|
||||
|
@ -345,7 +288,7 @@ struct GenHammingComputer8 {
|
|||
a0 = *(uint64_t *)a;
|
||||
}
|
||||
|
||||
inline int hamming (const uint8_t *b) const {
|
||||
inline int compute (const uint8_t *b) const {
|
||||
return generalized_hamming_64 (*(uint64_t *)b ^ a0);
|
||||
}
|
||||
|
||||
|
@ -360,7 +303,7 @@ struct GenHammingComputer16 {
|
|||
a0 = a[0]; a1 = a[1];
|
||||
}
|
||||
|
||||
inline int hamming (const uint8_t *b8) const {
|
||||
inline int compute (const uint8_t *b8) const {
|
||||
const uint64_t *b = (uint64_t *)b8;
|
||||
return generalized_hamming_64 (b[0] ^ a0) +
|
||||
generalized_hamming_64 (b[1] ^ a1);
|
||||
|
@ -377,7 +320,7 @@ struct GenHammingComputer32 {
|
|||
a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
|
||||
}
|
||||
|
||||
inline int hamming (const uint8_t *b8) const {
|
||||
inline int compute (const uint8_t *b8) const {
|
||||
const uint64_t *b = (uint64_t *)b8;
|
||||
return generalized_hamming_64 (b[0] ^ a0) +
|
||||
generalized_hamming_64 (b[1] ^ a1) +
|
||||
|
@ -397,7 +340,7 @@ struct GenHammingComputerM8 {
|
|||
n = code_size / 8;
|
||||
}
|
||||
|
||||
int hamming (const uint8_t *b8) const {
|
||||
int compute (const uint8_t *b8) const {
|
||||
const uint64_t *b = (uint64_t *)b8;
|
||||
int accu = 0;
|
||||
for (int i = 0; i < n; i++)
|
||||
|
@ -449,7 +392,7 @@ struct HCounterState {
|
|||
k(k) {}
|
||||
|
||||
void update_counter(const uint8_t *y, size_t j) {
|
||||
int32_t dis = hc.hamming(y);
|
||||
int32_t dis = hc.compute(y);
|
||||
|
||||
if (dis <= thres) {
|
||||
if (dis < thres) {
|
||||
|
|
|
@ -33,9 +33,9 @@
|
|||
#include <omp.h>
|
||||
|
||||
#include <faiss/utils/Heap.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/utils/utils.h>
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
|
||||
static const size_t BLOCKSIZE_QUERY = 8192;
|
||||
static const size_t size_1M = 1 * 1024 * 1024;
|
||||
|
@ -44,7 +44,7 @@ namespace faiss {
|
|||
|
||||
size_t hamming_batch_size = 65536;
|
||||
|
||||
static const uint8_t hamdis_tab_ham_bytes[256] = {
|
||||
const uint8_t lookup8bit[256] = {
|
||||
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
|
||||
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
|
||||
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
|
||||
|
@ -63,7 +63,6 @@ static const uint8_t hamdis_tab_ham_bytes[256] = {
|
|||
4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8
|
||||
};
|
||||
|
||||
|
||||
/* Elementary Hamming distance computation: unoptimized */
|
||||
template <size_t nbits, typename T>
|
||||
T hamming (const uint8_t *bs1,
|
||||
|
@ -73,7 +72,7 @@ T hamming (const uint8_t *bs1,
|
|||
size_t i;
|
||||
T h = 0;
|
||||
for (i = 0; i < nbytes; i++)
|
||||
h += (T) hamdis_tab_ham_bytes[bs1[i]^bs2[i]];
|
||||
h += (T) lookup8bit[bs1[i]^bs2[i]];
|
||||
return h;
|
||||
}
|
||||
|
||||
|
@ -262,254 +261,6 @@ size_t match_hamming_thres (
|
|||
}
|
||||
|
||||
|
||||
/* Return closest neighbors w.r.t Hamming distance, using a heap. */
|
||||
template <class HammingComputer>
|
||||
static
|
||||
void hammings_knn_hc (
|
||||
int bytes_per_code,
|
||||
int_maxheap_array_t * ha,
|
||||
const uint8_t * bs1,
|
||||
const uint8_t * bs2,
|
||||
size_t n2,
|
||||
bool order = true,
|
||||
bool init_heap = true,
|
||||
ConcurrentBitsetPtr bitset = nullptr)
|
||||
{
|
||||
size_t k = ha->k;
|
||||
|
||||
if ((bytes_per_code + k * (sizeof(hamdis_t) + sizeof(int64_t))) * ha->nh < size_1M) {
|
||||
int thread_max_num = omp_get_max_threads();
|
||||
// init heap
|
||||
size_t thread_heap_size = ha->nh * k;
|
||||
size_t all_heap_size = thread_heap_size * thread_max_num;
|
||||
hamdis_t *value = new hamdis_t[all_heap_size];
|
||||
int64_t *labels = new int64_t[all_heap_size];
|
||||
for (int i = 0; i < all_heap_size; i++) {
|
||||
value[i] = 0x7fffffff;
|
||||
labels[i] = -1;
|
||||
}
|
||||
|
||||
HammingComputer *hc = new HammingComputer[ha->nh];
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
hc[i].set(bs1 + i * bytes_per_code, bytes_per_code);
|
||||
}
|
||||
|
||||
#pragma omp parallel for
|
||||
for (size_t j = 0; j < n2; j++) {
|
||||
if(!bitset || !bitset->test(j)) {
|
||||
int thread_no = omp_get_thread_num();
|
||||
|
||||
const uint8_t * bs2_ = bs2 + j * bytes_per_code;
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
hamdis_t dis = hc[i].hamming (bs2_);
|
||||
|
||||
hamdis_t * val_ = value + thread_no * thread_heap_size + i * k;
|
||||
int64_t * ids_ = labels + thread_no * thread_heap_size + i * k;
|
||||
if (dis < val_[0]) {
|
||||
faiss::maxheap_swap_top<hamdis_t> (k, val_, ids_, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t t = 1; t < thread_max_num; t++) {
|
||||
// merge heap
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
hamdis_t * __restrict value_x = value + i * k;
|
||||
int64_t * __restrict labels_x = labels + i * k;
|
||||
hamdis_t *value_x_t = value_x + t * thread_heap_size;
|
||||
int64_t *labels_x_t = labels_x + t * thread_heap_size;
|
||||
for (size_t j = 0; j < k; j++) {
|
||||
if (value_x_t[j] < value_x[0]) {
|
||||
faiss::maxheap_swap_top<hamdis_t> (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copy result
|
||||
memcpy(ha->val, value, thread_heap_size * sizeof(hamdis_t));
|
||||
memcpy(ha->ids, labels, thread_heap_size * sizeof(int64_t));
|
||||
|
||||
delete[] hc;
|
||||
delete[] value;
|
||||
delete[] labels;
|
||||
|
||||
} else {
|
||||
if (init_heap) ha->heapify ();
|
||||
const size_t block_size = hamming_batch_size;
|
||||
for (size_t j0 = 0; j0 < n2; j0 += block_size) {
|
||||
const size_t j1 = std::min(j0 + block_size, n2);
|
||||
#pragma omp parallel for
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
HammingComputer hc (bs1 + i * bytes_per_code, bytes_per_code);
|
||||
|
||||
const uint8_t * bs2_ = bs2 + j0 * bytes_per_code;
|
||||
hamdis_t dis;
|
||||
hamdis_t * __restrict bh_val_ = ha->val + i * k;
|
||||
int64_t * __restrict bh_ids_ = ha->ids + i * k;
|
||||
size_t j;
|
||||
for (j = j0; j < j1; j++, bs2_+= bytes_per_code) {
|
||||
if(!bitset || !bitset->test(j)){
|
||||
dis = hc.hamming (bs2_);
|
||||
if (dis < bh_val_[0]) {
|
||||
faiss::maxheap_swap_top<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (order) ha->reorder ();
|
||||
}
|
||||
|
||||
/* Return closest neighbors w.r.t Hamming distance, using max count. */
|
||||
template <class HammingComputer>
|
||||
static
|
||||
void hammings_knn_mc (
|
||||
int bytes_per_code,
|
||||
const uint8_t *a,
|
||||
const uint8_t *b,
|
||||
size_t na,
|
||||
size_t nb,
|
||||
size_t k,
|
||||
int32_t *distances,
|
||||
int64_t *labels,
|
||||
ConcurrentBitsetPtr bitset = nullptr)
|
||||
{
|
||||
const int nBuckets = bytes_per_code * 8 + 1;
|
||||
std::vector<int> all_counters(na * nBuckets, 0);
|
||||
std::unique_ptr<int64_t[]> all_ids_per_dis(new int64_t[na * nBuckets * k]);
|
||||
|
||||
std::vector<HCounterState<HammingComputer>> cs;
|
||||
for (size_t i = 0; i < na; ++i) {
|
||||
cs.push_back(HCounterState<HammingComputer>(
|
||||
all_counters.data() + i * nBuckets,
|
||||
all_ids_per_dis.get() + i * nBuckets * k,
|
||||
a + i * bytes_per_code,
|
||||
8 * bytes_per_code,
|
||||
k
|
||||
));
|
||||
}
|
||||
|
||||
const size_t block_size = hamming_batch_size;
|
||||
for (size_t j0 = 0; j0 < nb; j0 += block_size) {
|
||||
const size_t j1 = std::min(j0 + block_size, nb);
|
||||
#pragma omp parallel for
|
||||
for (size_t i = 0; i < na; ++i) {
|
||||
for (size_t j = j0; j < j1; ++j) {
|
||||
if (!bitset || !bitset->test(j)) {
|
||||
cs[i].update_counter(b + j * bytes_per_code, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < na; ++i) {
|
||||
HCounterState<HammingComputer>& csi = cs[i];
|
||||
|
||||
int nres = 0;
|
||||
for (int b = 0; b < nBuckets && nres < k; b++) {
|
||||
for (int l = 0; l < csi.counters[b] && nres < k; l++) {
|
||||
labels[i * k + nres] = csi.ids_per_dis[b * k + l];
|
||||
distances[i * k + nres] = b;
|
||||
nres++;
|
||||
}
|
||||
}
|
||||
while (nres < k) {
|
||||
labels[i * k + nres] = -1;
|
||||
distances[i * k + nres] = std::numeric_limits<int32_t>::max();
|
||||
++nres;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
// works faster than the template version
|
||||
static
|
||||
void hammings_knn_hc_1 (
|
||||
int_maxheap_array_t * ha,
|
||||
const uint64_t * bs1,
|
||||
const uint64_t * bs2,
|
||||
size_t n2,
|
||||
bool order = true,
|
||||
bool init_heap = true,
|
||||
ConcurrentBitsetPtr bitset = nullptr)
|
||||
{
|
||||
const size_t nwords = 1;
|
||||
size_t k = ha->k;
|
||||
|
||||
if (init_heap) {
|
||||
ha->heapify ();
|
||||
}
|
||||
|
||||
int thread_max_num = omp_get_max_threads();
|
||||
if (ha->nh == 1) {
|
||||
// omp for n2
|
||||
int all_heap_size = thread_max_num * k;
|
||||
hamdis_t *value = new hamdis_t[all_heap_size];
|
||||
int64_t *labels = new int64_t[all_heap_size];
|
||||
|
||||
// init heap
|
||||
for (int i = 0; i < all_heap_size; i++) {
|
||||
value[i] = 0x7fffffff;
|
||||
}
|
||||
const uint64_t bs1_ = bs1[0];
|
||||
#pragma omp parallel for
|
||||
for (size_t j = 0; j < n2; j++) {
|
||||
if(!bitset || !bitset->test(j)) {
|
||||
hamdis_t dis = popcount64 (bs1_ ^ bs2[j]);
|
||||
|
||||
int thread_no = omp_get_thread_num();
|
||||
hamdis_t * __restrict val_ = value + thread_no * k;
|
||||
int64_t * __restrict ids_ = labels + thread_no * k;
|
||||
if (dis < val_[0]) {
|
||||
faiss::maxheap_swap_top<hamdis_t> (k, val_, ids_, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
// merge heap
|
||||
hamdis_t * __restrict bh_val_ = ha->val;
|
||||
int64_t * __restrict bh_ids_ = ha->ids;
|
||||
for (int i = 0; i < all_heap_size; i++) {
|
||||
if (value[i] < bh_val_[0]) {
|
||||
faiss::maxheap_swap_top<hamdis_t> (k, bh_val_, bh_ids_, value[i], labels[i]);
|
||||
}
|
||||
}
|
||||
|
||||
delete[] value;
|
||||
delete[] labels;
|
||||
|
||||
} else {
|
||||
#pragma omp parallel for
|
||||
for (size_t i = 0; i < ha->nh; i++) {
|
||||
const uint64_t bs1_ = bs1 [i];
|
||||
const uint64_t * bs2_ = bs2;
|
||||
hamdis_t dis;
|
||||
hamdis_t * bh_val_ = ha->val + i * k;
|
||||
hamdis_t bh_val_0 = bh_val_[0];
|
||||
int64_t * bh_ids_ = ha->ids + i * k;
|
||||
size_t j;
|
||||
for (j = 0; j < n2; j++, bs2_+= nwords) {
|
||||
if(!bitset || !bitset->test(j)){
|
||||
dis = popcount64 (bs1_ ^ *bs2_);
|
||||
if (dis < bh_val_0) {
|
||||
faiss::maxheap_swap_top<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
|
||||
bh_val_0 = bh_val_[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (order) {
|
||||
ha->reorder ();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/* Functions to maps vectors to bits. Assume proper allocation done beforehand,
|
||||
meaning that b should be be able to receive as many bits as x may produce. */
|
||||
|
||||
|
@ -648,102 +399,6 @@ void hammings (
|
|||
}
|
||||
}
|
||||
|
||||
void hammings_knn(
|
||||
int_maxheap_array_t *ha,
|
||||
const uint8_t *a,
|
||||
const uint8_t *b,
|
||||
size_t nb,
|
||||
size_t ncodes,
|
||||
int order)
|
||||
{
|
||||
hammings_knn_hc(ha, a, b, nb, ncodes, order);
|
||||
}
|
||||
|
||||
void hammings_knn_hc (
|
||||
int_maxheap_array_t * ha,
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
size_t nb,
|
||||
size_t ncodes,
|
||||
int order,
|
||||
ConcurrentBitsetPtr bitset)
|
||||
{
|
||||
switch (ncodes) {
|
||||
case 4:
|
||||
hammings_knn_hc<faiss::HammingComputer4>
|
||||
(4, ha, a, b, nb, order, true, bitset);
|
||||
break;
|
||||
case 8:
|
||||
hammings_knn_hc_1 (ha, C64(a), C64(b), nb, order, true, bitset);
|
||||
// hammings_knn_hc<faiss::HammingComputer8>
|
||||
// (8, ha, a, b, nb, order, true);
|
||||
break;
|
||||
case 16:
|
||||
hammings_knn_hc<faiss::HammingComputer16>
|
||||
(16, ha, a, b, nb, order, true, bitset);
|
||||
break;
|
||||
case 32:
|
||||
hammings_knn_hc<faiss::HammingComputer32>
|
||||
(32, ha, a, b, nb, order, true, bitset);
|
||||
break;
|
||||
default:
|
||||
if(ncodes % 8 == 0) {
|
||||
hammings_knn_hc<faiss::HammingComputerM8>
|
||||
(ncodes, ha, a, b, nb, order, true, bitset);
|
||||
} else {
|
||||
hammings_knn_hc<faiss::HammingComputerDefault>
|
||||
(ncodes, ha, a, b, nb, order, true, bitset);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void hammings_knn_mc(
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
size_t na,
|
||||
size_t nb,
|
||||
size_t k,
|
||||
size_t ncodes,
|
||||
int32_t *distances,
|
||||
int64_t *labels,
|
||||
ConcurrentBitsetPtr bitset)
|
||||
{
|
||||
switch (ncodes) {
|
||||
case 4:
|
||||
hammings_knn_mc<faiss::HammingComputer4>(
|
||||
4, a, b, na, nb, k, distances, labels, bitset
|
||||
);
|
||||
break;
|
||||
case 8:
|
||||
// TODO(hoss): Write analog to hammings_knn_hc_1
|
||||
// hammings_knn_hc_1 (ha, C64(a), C64(b), nb, order, true);
|
||||
hammings_knn_mc<faiss::HammingComputer8>(
|
||||
8, a, b, na, nb, k, distances, labels, bitset
|
||||
);
|
||||
break;
|
||||
case 16:
|
||||
hammings_knn_mc<faiss::HammingComputer16>(
|
||||
16, a, b, na, nb, k, distances, labels, bitset
|
||||
);
|
||||
break;
|
||||
case 32:
|
||||
hammings_knn_mc<faiss::HammingComputer32>(
|
||||
32, a, b, na, nb, k, distances, labels, bitset
|
||||
);
|
||||
break;
|
||||
default:
|
||||
if(ncodes % 8 == 0) {
|
||||
hammings_knn_mc<faiss::HammingComputerM8>(
|
||||
ncodes, a, b, na, nb, k, distances, labels, bitset
|
||||
);
|
||||
} else {
|
||||
hammings_knn_mc<faiss::HammingComputerDefault>(
|
||||
ncodes, a, b, na, nb, k, distances, labels, bitset
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
template <class HammingComputer>
|
||||
static
|
||||
void hamming_range_search_template (
|
||||
|
@ -767,7 +422,7 @@ void hamming_range_search_template (
|
|||
RangeQueryResult & qres = pres.new_result (i);
|
||||
|
||||
for (size_t j = 0; j < nb; j++) {
|
||||
int dis = hc.hamming (yi);
|
||||
int dis = hc.compute (yi);
|
||||
if (dis < radius) {
|
||||
qres.add(dis, j);
|
||||
}
|
||||
|
@ -795,12 +450,7 @@ void hamming_range_search (
|
|||
case 8: HC(HammingComputer8); break;
|
||||
case 16: HC(HammingComputer16); break;
|
||||
case 32: HC(HammingComputer32); break;
|
||||
default:
|
||||
if (code_size % 8 == 0) {
|
||||
HC(HammingComputerM8);
|
||||
} else {
|
||||
HC(HammingComputerDefault);
|
||||
}
|
||||
default: HC(HammingComputerDefault);
|
||||
}
|
||||
#undef HC
|
||||
}
|
||||
|
@ -922,7 +572,7 @@ static void hamming_dis_inner_loop (
|
|||
HammingComputer hc (ca, code_size);
|
||||
|
||||
for (size_t j = 0; j < nb; j++) {
|
||||
int ndiff = hc.hamming (cb);
|
||||
int ndiff = hc.compute (cb);
|
||||
cb += code_size;
|
||||
if (ndiff < bh_val_[0]) {
|
||||
maxheap_swap_top<hamdis_t> (k, bh_val_, bh_ids_, ndiff, j);
|
||||
|
|
|
@ -35,6 +35,8 @@ typedef int32_t hamdis_t;
|
|||
|
||||
namespace faiss {
|
||||
|
||||
extern const uint8_t lookup8bit[256];
|
||||
|
||||
/**************************************************
|
||||
* General bit vector functions
|
||||
**************************************************/
|
||||
|
@ -107,8 +109,6 @@ struct BitstringReader {
|
|||
* Hamming distance computation functions
|
||||
**************************************************/
|
||||
|
||||
|
||||
|
||||
extern size_t hamming_batch_size;
|
||||
|
||||
inline int popcount64(uint64_t x) {
|
||||
|
@ -131,58 +131,6 @@ void hammings (
|
|||
hamdis_t * dis);
|
||||
|
||||
|
||||
|
||||
|
||||
/** Return the k smallest Hamming distances for a set of binary query vectors,
|
||||
* using a max heap.
|
||||
* @param a queries, size ha->nh * ncodes
|
||||
* @param b database, size nb * ncodes
|
||||
* @param nb number of database vectors
|
||||
* @param ncodes size of the binary codes (bytes)
|
||||
* @param ordered if != 0: order the results by decreasing distance
|
||||
* (may be bottleneck for k/n > 0.01) */
|
||||
void hammings_knn_hc (
|
||||
int_maxheap_array_t * ha,
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
size_t nb,
|
||||
size_t ncodes,
|
||||
int ordered,
|
||||
ConcurrentBitsetPtr bitset = nullptr);
|
||||
|
||||
/* Legacy alias to hammings_knn_hc. */
|
||||
void hammings_knn (
|
||||
int_maxheap_array_t * ha,
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
size_t nb,
|
||||
size_t ncodes,
|
||||
int ordered,
|
||||
ConcurrentBitsetPtr bitset = nullptr);
|
||||
|
||||
/** Return the k smallest Hamming distances for a set of binary query vectors,
|
||||
* using counting max.
|
||||
* @param a queries, size na * ncodes
|
||||
* @param b database, size nb * ncodes
|
||||
* @param na number of query vectors
|
||||
* @param nb number of database vectors
|
||||
* @param k number of vectors/distances to return
|
||||
* @param ncodes size of the binary codes (bytes)
|
||||
* @param distances output distances from each query vector to its k nearest
|
||||
* neighbors
|
||||
* @param labels output ids of the k nearest neighbors to each query vector
|
||||
*/
|
||||
void hammings_knn_mc (
|
||||
const uint8_t * a,
|
||||
const uint8_t * b,
|
||||
size_t na,
|
||||
size_t nb,
|
||||
size_t k,
|
||||
size_t ncodes,
|
||||
int32_t *distances,
|
||||
int64_t *labels,
|
||||
ConcurrentBitsetPtr bitset = nullptr);
|
||||
|
||||
/** same as hammings_knn except we are doing a range search with radius */
|
||||
void hamming_range_search (
|
||||
const uint8_t * a,
|
||||
|
|
|
@ -1,3 +1,19 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#ifndef FAISS_JACCARD_INL_H
|
||||
#define FAISS_JACCARD_INL_H
|
||||
|
||||
#include <faiss/utils/BinaryDistance.h>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
struct JaccardComputer8 {
|
||||
|
@ -19,9 +35,7 @@ namespace faiss {
|
|||
const uint64_t *b = (uint64_t *)b8;
|
||||
int accu_num = popcount64 (b[0] & a0);
|
||||
int accu_den = popcount64 (b[0] | a0);
|
||||
if (accu_num == 0)
|
||||
return 1.0;
|
||||
return 1.0 - (float)(accu_num) / (float)(accu_den);
|
||||
return (accu_den == 0) ? 1.0 : ((float)(accu_den - accu_num) / (float)(accu_den));
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -45,9 +59,7 @@ namespace faiss {
|
|||
const uint64_t *b = (uint64_t *)b8;
|
||||
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1);
|
||||
int accu_den = popcount64 (b[0] | a0) + popcount64 (b[1] | a1);
|
||||
if (accu_num == 0)
|
||||
return 1.0;
|
||||
return 1.0 - (float)(accu_num) / (float)(accu_den);
|
||||
return (accu_den == 0) ? 1.0 : ((float)(accu_den - accu_num) / (float)(accu_den));
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -73,9 +85,7 @@ namespace faiss {
|
|||
popcount64 (b[2] & a2) + popcount64 (b[3] & a3);
|
||||
int accu_den = popcount64 (b[0] | a0) + popcount64 (b[1] | a1) +
|
||||
popcount64 (b[2] | a2) + popcount64 (b[3] | a3);
|
||||
if (accu_num == 0)
|
||||
return 1.0;
|
||||
return 1.0 - (float)(accu_num) / (float)(accu_den);
|
||||
return (accu_den == 0) ? 1.0 : ((float)(accu_den - accu_num) / (float)(accu_den));
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -106,9 +116,7 @@ namespace faiss {
|
|||
popcount64 (b[2] | a2) + popcount64 (b[3] | a3) +
|
||||
popcount64 (b[4] | a4) + popcount64 (b[5] | a5) +
|
||||
popcount64 (b[6] | a6) + popcount64 (b[7] | a7);
|
||||
if (accu_num == 0)
|
||||
return 1.0;
|
||||
return 1.0 - (float)(accu_num) / (float)(accu_den);
|
||||
return (accu_den == 0) ? 1.0 : ((float)(accu_den - accu_num) / (float)(accu_den));
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -150,9 +158,7 @@ namespace faiss {
|
|||
popcount64 (b[10] | a10) + popcount64 (b[11] | a11) +
|
||||
popcount64 (b[12] | a12) + popcount64 (b[13] | a13) +
|
||||
popcount64 (b[14] | a14) + popcount64 (b[15] | a15);
|
||||
if (accu_num == 0)
|
||||
return 1.0;
|
||||
return 1.0 - (float)(accu_num) / (float)(accu_den);
|
||||
return (accu_den == 0) ? 1.0 : ((float)(accu_den - accu_num) / (float)(accu_den));
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -216,9 +222,7 @@ struct JaccardComputer256 {
|
|||
popcount64 (b[26] | a26) + popcount64 (b[27] | a27) +
|
||||
popcount64 (b[28] | a28) + popcount64 (b[29] | a29) +
|
||||
popcount64 (b[30] | a30) + popcount64 (b[31] | a31);
|
||||
if (accu_num == 0)
|
||||
return 1.0;
|
||||
return 1.0 - (float)(accu_num) / (float)(accu_den);
|
||||
return (accu_den == 0) ? 1.0 : ((float)(accu_den - accu_num) / (float)(accu_den));
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -326,9 +330,7 @@ struct JaccardComputer256 {
|
|||
popcount64 (b[58] | a58) + popcount64 (b[59] | a59) +
|
||||
popcount64 (b[60] | a60) + popcount64 (b[61] | a61) +
|
||||
popcount64 (b[62] | a62) + popcount64 (b[63] | a63);
|
||||
if (accu_num == 0)
|
||||
return 1.0;
|
||||
return 1.0 - (float)(accu_num) / (float)(accu_den);
|
||||
return (accu_den == 0) ? 1.0 : ((float)(accu_den - accu_num) / (float)(accu_den));
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -349,15 +351,7 @@ struct JaccardComputer256 {
|
|||
}
|
||||
|
||||
float compute (const uint8_t *b8) const {
|
||||
int accu_num = 0;
|
||||
int accu_den = 0;
|
||||
for (int i = 0; i < n; i++) {
|
||||
accu_num += popcount64(a[i] & b8[i]);
|
||||
accu_den += popcount64(a[i] | b8[i]);
|
||||
}
|
||||
if (accu_num == 0)
|
||||
return 1.0;
|
||||
return 1.0 - (float)(accu_num) / (float)(accu_den);
|
||||
return bvec_jaccard(a, b8, n);
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -387,3 +381,5 @@ struct JaccardComputer256 {
|
|||
#undef SPECIALIZED_HC
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
|
@ -1,3 +1,19 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#ifndef FAISS_SBUSTRUCTURE_INL_H
|
||||
#define FAISS_SBUSTRUCTURE_INL_H
|
||||
|
||||
#include <faiss/utils/BinaryDistance.h>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
struct SubstructureComputer8 {
|
||||
|
@ -264,13 +280,7 @@ namespace faiss {
|
|||
}
|
||||
|
||||
bool compute (const uint8_t *b8) const {
|
||||
const uint64_t *b = (uint64_t *)b8;
|
||||
for (int i = 0; i < n; i++) {
|
||||
if ((a[i] & b[i]) != a[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return is_subset(a, b8, n);
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -300,3 +310,5 @@ namespace faiss {
|
|||
#undef SPECIALIZED_HC
|
||||
|
||||
}
|
||||
|
||||
#endif
|
|
@ -1,3 +1,19 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#ifndef FAISS_SUPERSTRUCTURE_INL_H
|
||||
#define FAISS_SUPERSTRUCTURE_INL_H
|
||||
|
||||
#include <faiss/utils/BinaryDistance.h>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
struct SuperstructureComputer8 {
|
||||
|
@ -264,13 +280,7 @@ namespace faiss {
|
|||
}
|
||||
|
||||
bool compute (const uint8_t *b8) const {
|
||||
const uint64_t *b = (uint64_t *)b8;
|
||||
for (int i = 0; i < n; i++) {
|
||||
if ((a[i] & b[i]) != b[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return is_subset(b8, a, n);
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -300,3 +310,5 @@ namespace faiss {
|
|||
#undef SPECIALIZED_HC
|
||||
|
||||
}
|
||||
|
||||
#endif
|
Loading…
Reference in New Issue