BinaryFlat support AVX2 (#4694)

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
pull/4698/head
shengjun.li 2021-02-04 16:42:42 +08:00 committed by GitHub
parent bbbf254b7c
commit 1211e9bac3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 337 additions and 36 deletions

View File

@ -8,9 +8,10 @@ Please mark all change in change log and use the issue from GitHub
- \#4678 Server crash on BinaryFlat if dimension is not a power of 2
## Feature
- \#4676 Support configurable metric labels for Prometheus
## Improvement
- \#1970 Improve the performance of BinaryFlat using AVX2
- \#4676 Support configurable metric labels cluster and instance for Prometheus
## Task

View File

@ -419,46 +419,56 @@ void binary_distance_knn_hc (
size_t dim = ncodes * 8;
switch (metric_type) {
case METRIC_Jaccard: {
switch (ncodes) {
#define binary_distance_knn_hc_jaccard(ncodes) \
case ncodes: \
binary_distance_knn_hc<C, faiss::JaccardComputer ## ncodes> \
(ncodes, ha, a, b, nb, bitset); \
break;
binary_distance_knn_hc_jaccard(8);
binary_distance_knn_hc_jaccard(16);
binary_distance_knn_hc_jaccard(32);
binary_distance_knn_hc_jaccard(64);
binary_distance_knn_hc_jaccard(128);
binary_distance_knn_hc_jaccard(256);
binary_distance_knn_hc_jaccard(512);
#undef binary_distence_knn_hc_jaccard
default:
binary_distance_knn_hc<C, faiss::JaccardComputerDefault>
if (support_avx2() && ncodes > 64) {
binary_distance_knn_hc<C, faiss::JaccardComputerAVX2>
(ncodes, ha, a, b, nb, bitset);
break;
} else {
switch (ncodes) {
#define binary_distance_knn_hc_jaccard(ncodes) \
case ncodes: \
binary_distance_knn_hc<C, faiss::JaccardComputer ## ncodes> \
(ncodes, ha, a, b, nb, bitset); \
break;
binary_distance_knn_hc_jaccard(8);
binary_distance_knn_hc_jaccard(16);
binary_distance_knn_hc_jaccard(32);
binary_distance_knn_hc_jaccard(64);
binary_distance_knn_hc_jaccard(128);
binary_distance_knn_hc_jaccard(256);
binary_distance_knn_hc_jaccard(512);
#undef binary_distence_knn_hc_jaccard
default:
binary_distance_knn_hc<C, faiss::JaccardComputerDefault>
(ncodes, ha, a, b, nb, bitset);
break;
}
}
break;
}
case METRIC_Hamming: {
switch (ncodes) {
#define binary_distance_knn_hc_hamming(ncodes) \
case ncodes: \
binary_distance_knn_hc<C, faiss::HammingComputer ## ncodes> \
(ncodes, ha, a, b, nb, bitset); \
break;
binary_distance_knn_hc_hamming(4);
binary_distance_knn_hc_hamming(8);
binary_distance_knn_hc_hamming(16);
binary_distance_knn_hc_hamming(20);
binary_distance_knn_hc_hamming(32);
binary_distance_knn_hc_hamming(64);
#undef binary_distence_knn_hc_jaccard
default:
binary_distance_knn_hc<C, faiss::HammingComputerDefault>
if (support_avx2() && ncodes > 64) {
binary_distance_knn_hc<C, faiss::HammingComputerAVX2>
(ncodes, ha, a, b, nb, bitset);
break;
} else {
switch (ncodes) {
#define binary_distance_knn_hc_hamming(ncodes) \
case ncodes: \
binary_distance_knn_hc<C, faiss::HammingComputer ## ncodes> \
(ncodes, ha, a, b, nb, bitset); \
break;
binary_distance_knn_hc_hamming(4);
binary_distance_knn_hc_hamming(8);
binary_distance_knn_hc_hamming(16);
binary_distance_knn_hc_hamming(20);
binary_distance_knn_hc_hamming(32);
binary_distance_knn_hc_hamming(64);
#undef binary_distence_knn_hc_jaccard
default:
binary_distance_knn_hc<C, faiss::HammingComputerDefault>
(ncodes, ha, a, b, nb, bitset);
break;
}
}
break;
}

View File

@ -7,6 +7,7 @@
#pragma once
#include <stddef.h>
#include <stdint.h>
namespace faiss {
@ -29,4 +30,21 @@ fvec_L1_avx(const float* x, const float* y, size_t d);
float
fvec_Linf_avx(const float* x, const float* y, size_t d);
/// binary distance
/// The three functions below combine data1 and data2 byte by byte over n
/// bytes and return the number of set bits in the combined result.

/// Set bits of (data1 XOR data2), i.e. the number of differing bits.
int
xor_popcnt_AVX2_lookup(const uint8_t* data1, const uint8_t* data2, const size_t n);

/// Set bits of (data1 OR data2).
int
or_popcnt_AVX2_lookup(const uint8_t* data1, const uint8_t* data2, const size_t n);

/// Set bits of (data1 AND data2).
int
and_popcnt_AVX2_lookup(const uint8_t* data1, const uint8_t* data2, const size_t n);

/// popcnt
/// Number of set bits in the n bytes starting at data.
int
popcnt_AVX2_lookup(const uint8_t* data, const size_t n);

/// Jaccard distance between two n-byte bit vectors:
/// (popcount(a OR b) - popcount(a AND b)) / popcount(a OR b).
/// Returns 1.0 when popcount(a OR b) is zero.
float
jaccard__AVX2(const uint8_t * a, const uint8_t * b, size_t n);
} // namespace faiss

View File

@ -13,6 +13,8 @@
namespace faiss {
extern const uint8_t lookup8bit[256];
#ifdef __SSE__
// reads 0 <= d < 4 floats as __m128
static inline __m128 masked_read (int d, const float *x) {
@ -186,6 +188,233 @@ float fvec_Linf_avx (const float* x, const float* y, size_t d) {
return _mm_cvtss_f32 (msum2);
}
// Bit-count table for a 4-bit value: entry k holds the number of set bits in
// k (0..15). The 16 entries are repeated for the upper 128-bit half because
// _mm256_shuffle_epi8, which the popcount routines below use to index this
// table, looks bytes up independently within each 128-bit lane.
const __m256i lookup = _mm256_setr_epi8(
/* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
/* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
/* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
/* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4,
/* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
/* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
/* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
/* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4
);
// Counts the set bits in the n bytes starting at `data`.
//
// Whole 32-byte vectors are processed by splitting every byte into two
// nibbles and looking each nibble up in `lookup`; the remaining n % 32 bytes
// go through the scalar `lookup8bit` table.
//
// data: input bytes (read with unaligned loads)
// n:    number of bytes to read
// Returns the total number of 1 bits, narrowed to int.
int popcnt_AVX2_lookup(const uint8_t* data, const size_t n) {
    const size_t kVecBytes = 32;
    const __m256i nibble_mask = _mm256_set1_epi8(0x0f);
    const __m256i zero = _mm256_setzero_si256();
    __m256i totals = zero;  // four 64-bit partial sums
    size_t pos = 0;

    // Adds the per-byte bit counts of the next 32 input bytes to `bytes`
    // and advances `pos` past them.
    const auto step = [&](__m256i& bytes) {
        const __m256i chunk = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(data + pos));
        const __m256i low = _mm256_and_si256(chunk, nibble_mask);
        const __m256i high = _mm256_and_si256(_mm256_srli_epi16(chunk, 4), nibble_mask);
        bytes = _mm256_add_epi8(bytes, _mm256_shuffle_epi8(lookup, low));
        bytes = _mm256_add_epi8(bytes, _mm256_shuffle_epi8(lookup, high));
        pos += kVecBytes;
    };

    // Groups of eight vectors: every 8-bit counter grows by at most 8 per
    // step, so it cannot wrap before _mm256_sad_epu8 widens it to 64 bits.
    while (pos + 8 * kVecBytes <= n) {
        __m256i bytes = zero;
        step(bytes); step(bytes); step(bytes); step(bytes);
        step(bytes); step(bytes); step(bytes); step(bytes);
        totals = _mm256_add_epi64(totals, _mm256_sad_epu8(bytes, zero));
    }

    // Fewer than eight whole vectors are left here.
    __m256i bytes = zero;
    while (pos + kVecBytes <= n) {
        step(bytes);
    }
    totals = _mm256_add_epi64(totals, _mm256_sad_epu8(bytes, zero));

    // Fold the four 64-bit lanes into the scalar result.
    int result = static_cast<int>(
        static_cast<uint64_t>(_mm256_extract_epi64(totals, 0)) +
        static_cast<uint64_t>(_mm256_extract_epi64(totals, 1)) +
        static_cast<uint64_t>(_mm256_extract_epi64(totals, 2)) +
        static_cast<uint64_t>(_mm256_extract_epi64(totals, 3)));

    // Scalar tail for the last n % 32 bytes.
    while (pos < n) {
        result += lookup8bit[data[pos]];
        ++pos;
    }
    return result;
}
// Counts the set bits of (data1 XOR data2) over n bytes, i.e. the number of
// bit positions in which the two inputs differ.
//
// Whole 32-byte vectors are processed by splitting every XOR-ed byte into
// two nibbles and looking each nibble up in `lookup`; the remaining n % 32
// bytes go through the scalar `lookup8bit` table.
//
// data1, data2: input bytes (read with unaligned loads)
// n:            number of bytes to read from each input
// Returns the total number of 1 bits, narrowed to int.
int xor_popcnt_AVX2_lookup(const uint8_t* data1, const uint8_t* data2, const size_t n) {
    const size_t kVecBytes = 32;
    const __m256i nibble_mask = _mm256_set1_epi8(0x0f);
    const __m256i zero = _mm256_setzero_si256();
    __m256i totals = zero;  // four 64-bit partial sums
    size_t pos = 0;

    // Adds the per-byte bit counts of the next 32 XOR-ed bytes to `bytes`
    // and advances `pos` past them.
    const auto step = [&](__m256i& bytes) {
        const __m256i lhs = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(data1 + pos));
        const __m256i rhs = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(data2 + pos));
        const __m256i chunk = _mm256_xor_si256(lhs, rhs);
        const __m256i low = _mm256_and_si256(chunk, nibble_mask);
        const __m256i high = _mm256_and_si256(_mm256_srli_epi16(chunk, 4), nibble_mask);
        bytes = _mm256_add_epi8(bytes, _mm256_shuffle_epi8(lookup, low));
        bytes = _mm256_add_epi8(bytes, _mm256_shuffle_epi8(lookup, high));
        pos += kVecBytes;
    };

    // Groups of eight vectors: every 8-bit counter grows by at most 8 per
    // step, so it cannot wrap before _mm256_sad_epu8 widens it to 64 bits.
    while (pos + 8 * kVecBytes <= n) {
        __m256i bytes = zero;
        step(bytes); step(bytes); step(bytes); step(bytes);
        step(bytes); step(bytes); step(bytes); step(bytes);
        totals = _mm256_add_epi64(totals, _mm256_sad_epu8(bytes, zero));
    }

    // Fewer than eight whole vectors are left here.
    __m256i bytes = zero;
    while (pos + kVecBytes <= n) {
        step(bytes);
    }
    totals = _mm256_add_epi64(totals, _mm256_sad_epu8(bytes, zero));

    // Fold the four 64-bit lanes into the scalar result.
    int result = static_cast<int>(
        static_cast<uint64_t>(_mm256_extract_epi64(totals, 0)) +
        static_cast<uint64_t>(_mm256_extract_epi64(totals, 1)) +
        static_cast<uint64_t>(_mm256_extract_epi64(totals, 2)) +
        static_cast<uint64_t>(_mm256_extract_epi64(totals, 3)));

    // Scalar tail for the last n % 32 bytes.
    while (pos < n) {
        result += lookup8bit[data1[pos]^data2[pos]];
        ++pos;
    }
    return result;
}
// Counts the set bits of (data1 OR data2) over n bytes.
//
// Whole 32-byte vectors are processed by splitting every OR-ed byte into two
// nibbles and looking each nibble up in `lookup`; the remaining n % 32 bytes
// go through the scalar `lookup8bit` table.
//
// data1, data2: input bytes (read with unaligned loads)
// n:            number of bytes to read from each input
// Returns the total number of 1 bits, narrowed to int.
int or_popcnt_AVX2_lookup(const uint8_t* data1, const uint8_t* data2, const size_t n) {
    const size_t kVecBytes = 32;
    const __m256i nibble_mask = _mm256_set1_epi8(0x0f);
    const __m256i zero = _mm256_setzero_si256();
    __m256i totals = zero;  // four 64-bit partial sums
    size_t pos = 0;

    // Adds the per-byte bit counts of the next 32 OR-ed bytes to `bytes`
    // and advances `pos` past them.
    const auto step = [&](__m256i& bytes) {
        const __m256i lhs = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(data1 + pos));
        const __m256i rhs = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(data2 + pos));
        const __m256i chunk = _mm256_or_si256(lhs, rhs);
        const __m256i low = _mm256_and_si256(chunk, nibble_mask);
        const __m256i high = _mm256_and_si256(_mm256_srli_epi16(chunk, 4), nibble_mask);
        bytes = _mm256_add_epi8(bytes, _mm256_shuffle_epi8(lookup, low));
        bytes = _mm256_add_epi8(bytes, _mm256_shuffle_epi8(lookup, high));
        pos += kVecBytes;
    };

    // Groups of eight vectors: every 8-bit counter grows by at most 8 per
    // step, so it cannot wrap before _mm256_sad_epu8 widens it to 64 bits.
    while (pos + 8 * kVecBytes <= n) {
        __m256i bytes = zero;
        step(bytes); step(bytes); step(bytes); step(bytes);
        step(bytes); step(bytes); step(bytes); step(bytes);
        totals = _mm256_add_epi64(totals, _mm256_sad_epu8(bytes, zero));
    }

    // Fewer than eight whole vectors are left here.
    __m256i bytes = zero;
    while (pos + kVecBytes <= n) {
        step(bytes);
    }
    totals = _mm256_add_epi64(totals, _mm256_sad_epu8(bytes, zero));

    // Fold the four 64-bit lanes into the scalar result.
    int result = static_cast<int>(
        static_cast<uint64_t>(_mm256_extract_epi64(totals, 0)) +
        static_cast<uint64_t>(_mm256_extract_epi64(totals, 1)) +
        static_cast<uint64_t>(_mm256_extract_epi64(totals, 2)) +
        static_cast<uint64_t>(_mm256_extract_epi64(totals, 3)));

    // Scalar tail for the last n % 32 bytes.
    while (pos < n) {
        result += lookup8bit[data1[pos]|data2[pos]];
        ++pos;
    }
    return result;
}
// Counts the set bits of (data1 AND data2) over n bytes.
//
// Whole 32-byte vectors are processed by splitting every AND-ed byte into
// two nibbles and looking each nibble up in `lookup`; the remaining n % 32
// bytes go through the scalar `lookup8bit` table.
//
// data1, data2: input bytes (read with unaligned loads)
// n:            number of bytes to read from each input
// Returns the total number of 1 bits, narrowed to int.
int and_popcnt_AVX2_lookup(const uint8_t* data1, const uint8_t* data2, const size_t n) {
    const size_t kVecBytes = 32;
    const __m256i nibble_mask = _mm256_set1_epi8(0x0f);
    const __m256i zero = _mm256_setzero_si256();
    __m256i totals = zero;  // four 64-bit partial sums
    size_t pos = 0;

    // Adds the per-byte bit counts of the next 32 AND-ed bytes to `bytes`
    // and advances `pos` past them.
    const auto step = [&](__m256i& bytes) {
        const __m256i lhs = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(data1 + pos));
        const __m256i rhs = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(data2 + pos));
        const __m256i chunk = _mm256_and_si256(lhs, rhs);
        const __m256i low = _mm256_and_si256(chunk, nibble_mask);
        const __m256i high = _mm256_and_si256(_mm256_srli_epi16(chunk, 4), nibble_mask);
        bytes = _mm256_add_epi8(bytes, _mm256_shuffle_epi8(lookup, low));
        bytes = _mm256_add_epi8(bytes, _mm256_shuffle_epi8(lookup, high));
        pos += kVecBytes;
    };

    // Groups of eight vectors: every 8-bit counter grows by at most 8 per
    // step, so it cannot wrap before _mm256_sad_epu8 widens it to 64 bits.
    while (pos + 8 * kVecBytes <= n) {
        __m256i bytes = zero;
        step(bytes); step(bytes); step(bytes); step(bytes);
        step(bytes); step(bytes); step(bytes); step(bytes);
        totals = _mm256_add_epi64(totals, _mm256_sad_epu8(bytes, zero));
    }

    // Fewer than eight whole vectors are left here.
    __m256i bytes = zero;
    while (pos + kVecBytes <= n) {
        step(bytes);
    }
    totals = _mm256_add_epi64(totals, _mm256_sad_epu8(bytes, zero));

    // Fold the four 64-bit lanes into the scalar result.
    int result = static_cast<int>(
        static_cast<uint64_t>(_mm256_extract_epi64(totals, 0)) +
        static_cast<uint64_t>(_mm256_extract_epi64(totals, 1)) +
        static_cast<uint64_t>(_mm256_extract_epi64(totals, 2)) +
        static_cast<uint64_t>(_mm256_extract_epi64(totals, 3)));

    // Scalar tail for the last n % 32 bytes.
    while (pos < n) {
        result += lookup8bit[(data1[pos]&data2[pos])];
        ++pos;
    }
    return result;
}
// Jaccard distance between the n-byte bit vectors a and b:
// (bits set in a OR b, minus bits set in a AND b) / (bits set in a OR b).
// Returns 1.0 when no bit is set in either input, so the division is never
// performed with a zero denominator.
float
jaccard__AVX2(const uint8_t * a, const uint8_t * b, size_t n) {
    const int intersection_bits = and_popcnt_AVX2_lookup(a, b, n);
    const int union_bits = or_popcnt_AVX2_lookup(a, b, n);
    if (union_bits == 0) {
        return 1.0f;
    }
    return static_cast<float>(union_bits - intersection_bits) /
           static_cast<float>(union_bits);
}
#else
float fvec_inner_product_avx(const float* x, const float* y, size_t d) {

View File

@ -6,9 +6,11 @@
*/
#include <faiss/utils/BinaryDistance.h>
#include <faiss/utils/distances_avx.h>
namespace faiss {
extern const uint8_t lookup8bit[256];
inline BitstringWriter::BitstringWriter(uint8_t *code, int code_size):
code (code), code_size (code_size), i(0)
@ -237,6 +239,27 @@ struct HammingComputerDefault {
};
/// Hamming distance computer for codes of any byte length, implemented with
/// xor_popcnt_AVX2_lookup.
///
/// The object stores only a pointer to the reference code; it does not copy
/// or free it, so the caller must keep that buffer alive while compute() is
/// used.
struct HammingComputerAVX2 {
    /// Reference code that compute() compares against (not owned).
    const uint8_t *a = nullptr;
    /// Code length in bytes, forwarded as the byte count to the popcount call.
    int n = 0;

    /// Creates an empty computer. The members are initialized to
    /// (nullptr, 0) instead of being left indeterminate; call set() before
    /// compute().
    HammingComputerAVX2 () {}

    /// Creates a computer for reference code a8 of code_size bytes.
    HammingComputerAVX2 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    /// Replaces the reference code and its length in bytes.
    void set (const uint8_t *a8, int code_size) {
        a = a8;
        n = code_size;
    }

    /// Returns the number of bits in which b8 differs from the reference
    /// code, over n bytes.
    int compute (const uint8_t *b8) const {
        return xor_popcnt_AVX2_lookup(a, b8, n);
    }
};
/***************************************************************************
* Equivalence with a template class when code size is known at compile time
**************************************************************************/

View File

@ -35,8 +35,6 @@ typedef int32_t hamdis_t;
namespace faiss {
extern const uint8_t lookup8bit[256];
/**************************************************
* General bit vector functions
**************************************************/

View File

@ -13,6 +13,7 @@
#define FAISS_JACCARD_INL_H
#include <faiss/utils/BinaryDistance.h>
#include <faiss/utils/distances_avx.h>
namespace faiss {
@ -356,6 +357,27 @@ struct JaccardComputer256 {
};
/// Jaccard distance computer for codes of any byte length, implemented with
/// jaccard__AVX2.
///
/// The object stores only a pointer to the reference code; it does not copy
/// or free it, so the caller must keep that buffer alive while compute() is
/// used.
struct JaccardComputerAVX2 {
    /// Reference code that compute() compares against (not owned).
    const uint8_t *a = nullptr;
    /// Code length in bytes, forwarded as the byte count to jaccard__AVX2.
    int n = 0;

    /// Creates an empty computer. The members are initialized to
    /// (nullptr, 0) instead of being left indeterminate; call set() before
    /// compute().
    JaccardComputerAVX2 () {}

    /// Creates a computer for reference code a8 of code_size bytes.
    JaccardComputerAVX2 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    /// Replaces the reference code and its length in bytes.
    void set (const uint8_t *a8, int code_size) {
        a = a8;
        n = code_size;
    }

    /// Returns the Jaccard distance between the reference code and b8,
    /// computed over n bytes.
    float compute (const uint8_t *b8) const {
        return jaccard__AVX2(a, b8, n);
    }
};
// default template
template<int CODE_SIZE>
struct JaccardComputer: JaccardComputerDefault {