Let FAISS support all CPU SIMD (#14587)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/14601/head
Cai Yudong 2021-12-30 12:09:46 +08:00 committed by GitHub
parent c4946442cb
commit 7f6b7998db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 30 additions and 31 deletions

View File

@ -55,12 +55,9 @@ KnowhereConfig::SetSimdType(const SimdType simd_type) {
}
std::string cpu_flag;
if (faiss::hook_init(cpu_flag)) {
LOG_KNOWHERE_DEBUG_ << "FAISS hook " << cpu_flag;
return cpu_flag;
}
KNOWHERE_THROW_MSG("FAISS hook fail, CPU not supported!");
faiss::hook_init(cpu_flag);
LOG_KNOWHERE_DEBUG_ << "FAISS hook " << cpu_flag;
return cpu_flag;
}
void

View File

@ -32,34 +32,28 @@ sq_sel_inv_list_scanner_func_ptr sq_sel_inv_list_scanner = sq_select_inverted_li
/*****************************************************************************/
bool support_avx512() {
if (!faiss_use_avx512) return false;
bool cpu_support_avx512() {
InstructionSet& instruction_set_inst = InstructionSet::GetInstance();
return (instruction_set_inst.AVX512F() &&
instruction_set_inst.AVX512DQ() &&
instruction_set_inst.AVX512BW());
}
bool support_avx2() {
if (!faiss_use_avx2) return false;
bool cpu_support_avx2() {
InstructionSet& instruction_set_inst = InstructionSet::GetInstance();
return (instruction_set_inst.AVX2());
}
bool support_sse4_2() {
if (!faiss_use_sse4_2) return false;
bool cpu_support_sse4_2() {
InstructionSet& instruction_set_inst = InstructionSet::GetInstance();
return (instruction_set_inst.SSE42());
}
bool hook_init(std::string& cpu_flag) {
void hook_init(std::string& cpu_flag) {
static std::mutex hook_mutex;
std::lock_guard<std::mutex> lock(hook_mutex);
if (support_avx512()) {
if (faiss_use_avx512 && cpu_support_avx512()) {
/* for IVFFLAT */
fvec_inner_product = fvec_inner_product_avx512;
fvec_L2sqr = fvec_L2sqr_avx512;
@ -72,7 +66,7 @@ bool hook_init(std::string& cpu_flag) {
sq_sel_inv_list_scanner = sq_select_inverted_list_scanner_avx512;
cpu_flag = "AVX512";
} else if (support_avx2()) {
} else if (faiss_use_avx2 && cpu_support_avx2()) {
/* for IVFFLAT */
fvec_inner_product = fvec_inner_product_avx;
fvec_L2sqr = fvec_L2sqr_avx;
@ -85,7 +79,7 @@ bool hook_init(std::string& cpu_flag) {
sq_sel_inv_list_scanner = sq_select_inverted_list_scanner_avx;
cpu_flag = "AVX2";
} else if (support_sse4_2()) {
} else if (faiss_use_sse4_2 && cpu_support_sse4_2()) {
/* for IVFFLAT */
fvec_inner_product = fvec_inner_product_sse;
fvec_L2sqr = fvec_L2sqr_sse;
@ -99,11 +93,19 @@ bool hook_init(std::string& cpu_flag) {
cpu_flag = "SSE4_2";
} else {
cpu_flag = "UNSUPPORTED";
return false;
}
/* for IVFFLAT */
fvec_inner_product = fvec_inner_product_ref;
fvec_L2sqr = fvec_L2sqr_ref;
fvec_L1 = fvec_L1_ref;
fvec_Linf = fvec_Linf_ref;
return true;
/* for IVFSQ */
sq_get_distance_computer = sq_get_distance_computer_ref;
sq_sel_quantizer = sq_select_quantizer_ref;
sq_sel_inv_list_scanner = sq_select_inverted_list_scanner_ref;
cpu_flag = "REF";
}
}
} // namespace faiss

View File

@ -31,10 +31,10 @@ extern sq_get_distance_computer_func_ptr sq_get_distance_computer;
extern sq_sel_quantizer_func_ptr sq_sel_quantizer;
extern sq_sel_inv_list_scanner_func_ptr sq_sel_inv_list_scanner;
bool support_avx512();
bool support_avx2();
bool support_sse4_2();
bool cpu_support_avx512();
bool cpu_support_avx2();
bool cpu_support_sse4_2();
bool hook_init(std::string& cpu_flag);
void hook_init(std::string& cpu_flag);
} // namespace faiss

View File

@ -421,7 +421,7 @@ void binary_distance_knn_hc (
{
switch (metric_type) {
case METRIC_Jaccard: {
if (support_avx2() && ncodes > 64) {
if (cpu_support_avx2() && ncodes > 64) {
binary_distance_knn_hc<C, faiss::JaccardComputerAVX2>
(ncodes, ha, a, b, nb, bitset);
} else {
@ -449,7 +449,7 @@ void binary_distance_knn_hc (
}
case METRIC_Hamming: {
if (support_avx2() && ncodes > 64) {
if (cpu_support_avx2() && ncodes > 64) {
binary_distance_knn_hc<C, faiss::HammingComputerAVX2>
(ncodes, ha, a, b, nb, bitset);
} else {
@ -554,7 +554,7 @@ void binary_range_search(
case METRIC_Tanimoto:
radius = Tanimoto_2_Jaccard(radius);
case METRIC_Jaccard: {
if (support_avx2() && ncodes > 64) {
if (cpu_support_avx2() && ncodes > 64) {
binary_range_search<C, T, faiss::JaccardComputerAVX2>
(a, b, na, nb, ncodes, radius, result, buffer_size, bitset);
} else {
@ -592,7 +592,7 @@ void binary_range_search(
}
case METRIC_Hamming: {
if (support_avx2() && ncodes > 64) {
if (cpu_support_avx2() && ncodes > 64) {
binary_range_search<C, T, faiss::HammingComputerAVX2>
(a, b, na, nb, ncodes, radius, result, buffer_size, bitset);
} else {