mirror of https://github.com/milvus-io/milvus.git
Let FAISS support all CPU SIMD (#14587)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>pull/14601/head
parent
c4946442cb
commit
7f6b7998db
|
@ -55,12 +55,9 @@ KnowhereConfig::SetSimdType(const SimdType simd_type) {
|
|||
}
|
||||
|
||||
std::string cpu_flag;
|
||||
if (faiss::hook_init(cpu_flag)) {
|
||||
LOG_KNOWHERE_DEBUG_ << "FAISS hook " << cpu_flag;
|
||||
return cpu_flag;
|
||||
}
|
||||
|
||||
KNOWHERE_THROW_MSG("FAISS hook fail, CPU not supported!");
|
||||
faiss::hook_init(cpu_flag);
|
||||
LOG_KNOWHERE_DEBUG_ << "FAISS hook " << cpu_flag;
|
||||
return cpu_flag;
|
||||
}
|
||||
|
||||
void
|
||||
|
|
|
@ -32,34 +32,28 @@ sq_sel_inv_list_scanner_func_ptr sq_sel_inv_list_scanner = sq_select_inverted_li
|
|||
|
||||
/*****************************************************************************/
|
||||
|
||||
bool support_avx512() {
|
||||
if (!faiss_use_avx512) return false;
|
||||
|
||||
bool cpu_support_avx512() {
|
||||
InstructionSet& instruction_set_inst = InstructionSet::GetInstance();
|
||||
return (instruction_set_inst.AVX512F() &&
|
||||
instruction_set_inst.AVX512DQ() &&
|
||||
instruction_set_inst.AVX512BW());
|
||||
}
|
||||
|
||||
bool support_avx2() {
|
||||
if (!faiss_use_avx2) return false;
|
||||
|
||||
bool cpu_support_avx2() {
|
||||
InstructionSet& instruction_set_inst = InstructionSet::GetInstance();
|
||||
return (instruction_set_inst.AVX2());
|
||||
}
|
||||
|
||||
bool support_sse4_2() {
|
||||
if (!faiss_use_sse4_2) return false;
|
||||
|
||||
bool cpu_support_sse4_2() {
|
||||
InstructionSet& instruction_set_inst = InstructionSet::GetInstance();
|
||||
return (instruction_set_inst.SSE42());
|
||||
}
|
||||
|
||||
bool hook_init(std::string& cpu_flag) {
|
||||
void hook_init(std::string& cpu_flag) {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
|
||||
if (support_avx512()) {
|
||||
if (faiss_use_avx512 && cpu_support_avx512()) {
|
||||
/* for IVFFLAT */
|
||||
fvec_inner_product = fvec_inner_product_avx512;
|
||||
fvec_L2sqr = fvec_L2sqr_avx512;
|
||||
|
@ -72,7 +66,7 @@ bool hook_init(std::string& cpu_flag) {
|
|||
sq_sel_inv_list_scanner = sq_select_inverted_list_scanner_avx512;
|
||||
|
||||
cpu_flag = "AVX512";
|
||||
} else if (support_avx2()) {
|
||||
} else if (faiss_use_avx2 && cpu_support_avx2()) {
|
||||
/* for IVFFLAT */
|
||||
fvec_inner_product = fvec_inner_product_avx;
|
||||
fvec_L2sqr = fvec_L2sqr_avx;
|
||||
|
@ -85,7 +79,7 @@ bool hook_init(std::string& cpu_flag) {
|
|||
sq_sel_inv_list_scanner = sq_select_inverted_list_scanner_avx;
|
||||
|
||||
cpu_flag = "AVX2";
|
||||
} else if (support_sse4_2()) {
|
||||
} else if (faiss_use_sse4_2 && cpu_support_sse4_2()) {
|
||||
/* for IVFFLAT */
|
||||
fvec_inner_product = fvec_inner_product_sse;
|
||||
fvec_L2sqr = fvec_L2sqr_sse;
|
||||
|
@ -99,11 +93,19 @@ bool hook_init(std::string& cpu_flag) {
|
|||
|
||||
cpu_flag = "SSE4_2";
|
||||
} else {
|
||||
cpu_flag = "UNSUPPORTED";
|
||||
return false;
|
||||
}
|
||||
/* for IVFFLAT */
|
||||
fvec_inner_product = fvec_inner_product_ref;
|
||||
fvec_L2sqr = fvec_L2sqr_ref;
|
||||
fvec_L1 = fvec_L1_ref;
|
||||
fvec_Linf = fvec_Linf_ref;
|
||||
|
||||
return true;
|
||||
/* for IVFSQ */
|
||||
sq_get_distance_computer = sq_get_distance_computer_ref;
|
||||
sq_sel_quantizer = sq_select_quantizer_ref;
|
||||
sq_sel_inv_list_scanner = sq_select_inverted_list_scanner_ref;
|
||||
|
||||
cpu_flag = "REF";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -31,10 +31,10 @@ extern sq_get_distance_computer_func_ptr sq_get_distance_computer;
|
|||
extern sq_sel_quantizer_func_ptr sq_sel_quantizer;
|
||||
extern sq_sel_inv_list_scanner_func_ptr sq_sel_inv_list_scanner;
|
||||
|
||||
bool support_avx512();
|
||||
bool support_avx2();
|
||||
bool support_sse4_2();
|
||||
bool cpu_support_avx512();
|
||||
bool cpu_support_avx2();
|
||||
bool cpu_support_sse4_2();
|
||||
|
||||
bool hook_init(std::string& cpu_flag);
|
||||
void hook_init(std::string& cpu_flag);
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -421,7 +421,7 @@ void binary_distance_knn_hc (
|
|||
{
|
||||
switch (metric_type) {
|
||||
case METRIC_Jaccard: {
|
||||
if (support_avx2() && ncodes > 64) {
|
||||
if (cpu_support_avx2() && ncodes > 64) {
|
||||
binary_distance_knn_hc<C, faiss::JaccardComputerAVX2>
|
||||
(ncodes, ha, a, b, nb, bitset);
|
||||
} else {
|
||||
|
@ -449,7 +449,7 @@ void binary_distance_knn_hc (
|
|||
}
|
||||
|
||||
case METRIC_Hamming: {
|
||||
if (support_avx2() && ncodes > 64) {
|
||||
if (cpu_support_avx2() && ncodes > 64) {
|
||||
binary_distance_knn_hc<C, faiss::HammingComputerAVX2>
|
||||
(ncodes, ha, a, b, nb, bitset);
|
||||
} else {
|
||||
|
@ -554,7 +554,7 @@ void binary_range_search(
|
|||
case METRIC_Tanimoto:
|
||||
radius = Tanimoto_2_Jaccard(radius);
|
||||
case METRIC_Jaccard: {
|
||||
if (support_avx2() && ncodes > 64) {
|
||||
if (cpu_support_avx2() && ncodes > 64) {
|
||||
binary_range_search<C, T, faiss::JaccardComputerAVX2>
|
||||
(a, b, na, nb, ncodes, radius, result, buffer_size, bitset);
|
||||
} else {
|
||||
|
@ -592,7 +592,7 @@ void binary_range_search(
|
|||
}
|
||||
|
||||
case METRIC_Hamming: {
|
||||
if (support_avx2() && ncodes > 64) {
|
||||
if (cpu_support_avx2() && ncodes > 64) {
|
||||
binary_range_search<C, T, faiss::HammingComputerAVX2>
|
||||
(a, b, na, nb, ncodes, radius, result, buffer_size, bitset);
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue