diff --git a/internal/core/src/index/thirdparty/faiss/IndexFlat.cpp b/internal/core/src/index/thirdparty/faiss/IndexFlat.cpp index d1e3066e14..7ba448fd20 100644 --- a/internal/core/src/index/thirdparty/faiss/IndexFlat.cpp +++ b/internal/core/src/index/thirdparty/faiss/IndexFlat.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include diff --git a/internal/core/src/index/thirdparty/faiss/utils/distances.cpp b/internal/core/src/index/thirdparty/faiss/utils/distances.cpp index ce0ffc246f..b2806a8019 100644 --- a/internal/core/src/index/thirdparty/faiss/utils/distances.cpp +++ b/internal/core/src/index/thirdparty/faiss/utils/distances.cpp @@ -16,7 +16,6 @@ #include #include -#include #include @@ -792,240 +791,6 @@ void knn_L2sqr_by_idx (const float * x, } - - - -/*************************************************************************** - * Range search - ***************************************************************************/ - -/** Find the nearest neighbors for nx queries in a set of ny vectors - * compute_l2 = compute pairwise squared L2 distance rather than inner prod - */ - template -static void range_search_blas ( - const float * x, - const float * y, - size_t d, size_t nx, size_t ny, - float radius, - std::vector &res, - size_t buffer_size, - const BitsetView &bitset) -{ - - // BLAS does not like empty matrices - if (nx == 0 || ny == 0) return; - - /* block sizes */ - const size_t bs_x = 4096, bs_y = 1024; - // const size_t bs_x = 16, bs_y = 16; - float *ip_block = new float[bs_x * bs_y]; - ScopeDeleter del0(ip_block); - - float *x_norms = nullptr, *y_norms = nullptr; - ScopeDeleter del1, del2; - if (compute_l2) { - x_norms = new float[nx]; - del1.set (x_norms); - fvec_norms_L2sqr (x_norms, x, d, nx); - - y_norms = new float[ny]; - del2.set (y_norms); - fvec_norms_L2sqr (y_norms, y, d, ny); - } - - for (size_t j0 = 0; j0 < ny; j0 += bs_y) { - size_t j1 = j0 + bs_y; - if (j1 > ny) j1 = ny; - RangeSearchResult *tmp_res = new 
RangeSearchResult(nx); - tmp_res->buffer_size = buffer_size; - RangeSearchPartialResult * pres = new RangeSearchPartialResult (tmp_res); - res.push_back (pres); - - for (size_t i0 = 0; i0 < nx; i0 += bs_x) { - size_t i1 = i0 + bs_x; - if(i1 > nx) i1 = nx; - - /* compute the actual dot products */ - { - float one = 1, zero = 0; - FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d; - sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one, - y + j0 * d, &di, - x + i0 * d, &di, &zero, - ip_block, &nyi); - } - - - for (size_t i = i0; i < i1; i++) { - const float *ip_line = ip_block + (i - i0) * (j1 - j0); - - RangeQueryResult & qres = pres->new_result (i); - - for (size_t j = j0; j < j1; j++) { - float ip = *ip_line++; - if (bitset.empty() || !bitset.test((int64_t)j)) { - if (compute_l2) { - float dis = x_norms[i] + y_norms[j] - 2 * ip; - if (dis < radius) { - qres.add (dis, j); - } - } else { - if (ip > radius) { - qres.add (ip, j); - } - } - } - } - } - } - InterruptCallback::check (); - } - -// RangeSearchPartialResult::merge (partial_results); -} - - -template -static void range_search_sse (const float * x, - const float * y, - size_t d, size_t nx, size_t ny, - float radius, - std::vector &res, - size_t buffer_size, - const BitsetView &bitset) -{ - -#pragma omp parallel - { - RangeSearchResult *tmp_res = new RangeSearchResult(nx); - tmp_res->buffer_size = buffer_size; - auto pres = new RangeSearchPartialResult(tmp_res); - -#pragma omp for - for (size_t i = 0; i < nx; i++) { - const float * x_ = x + i * d; - const float * y_ = y; - size_t j; - - RangeQueryResult & qres = pres->new_result (i); - - for (j = 0; j < ny; j++) { - if (bitset.empty() || !bitset.test((int64_t)j)) { - if (compute_l2) { - float disij = fvec_L2sqr (x_, y_, d); - if (disij < radius) { - qres.add (disij, j); - } - } else { - float ip = fvec_inner_product (x_, y_, d); - if (ip > radius) { - qres.add (ip, j); - } - } - } - y_ += d; - } - - } -#pragma omp critical - res.push_back(pres); - } - - // 
check just at the end because the use case is typically just - // when the nb of queries is low. - InterruptCallback::check(); -} - -// range search by sse when nq = 1, namely single query situation -template -static void range_search_sse_sq (const float * x, - const float * y, - size_t d, size_t nx, size_t ny, - float radius, - std::vector &res, - size_t buffer_size, - const BitsetView &bitset) -{ - -#pragma omp parallel - { - RangeSearchResult *tmp_res = new RangeSearchResult(nx); - tmp_res->buffer_size = buffer_size; - auto pres = new RangeSearchPartialResult(tmp_res); - - const float * x_ = x; - size_t j; - RangeQueryResult & qres = pres->new_result (0); - -#pragma omp for - for (j = 0; j < ny; j++) { - const float * y_ = y + j * d; - if (bitset.empty() || !bitset.test((int64_t)j)) { - if (compute_l2) { - float disij = fvec_L2sqr (x_, y_, d); - if (disij < radius) { - qres.add (disij, j); - } - } else { - float ip = fvec_inner_product (x_, y_, d); - if (ip > radius) { - qres.add (ip, j); - } - } - } - } -#pragma omp critical - res.push_back(pres); - } - - // check just at the end because the use case is typically just - // when the nb of queries is low. 
- InterruptCallback::check(); -} - - -void range_search_L2sqr ( - const float * x, - const float * y, - size_t d, size_t nx, size_t ny, - float radius, - std::vector &res, - size_t buffer_size, - const BitsetView &bitset) -{ - - if (nx < distance_compute_blas_threshold) { - if (nx == 1) { - range_search_sse_sq (x, y, d, nx, ny, radius, res, buffer_size, bitset); - } else { - range_search_sse (x, y, d, nx, ny, radius, res, buffer_size, bitset); - } - } else { - range_search_blas (x, y, d, nx, ny, radius, res, buffer_size, bitset); - } -} - -void range_search_inner_product ( - const float * x, - const float * y, - size_t d, size_t nx, size_t ny, - float radius, - std::vector &res, - size_t buffer_size, - const BitsetView &bitset) -{ - - if (nx < distance_compute_blas_threshold) { - if (nx == 1) - range_search_sse_sq (x, y, d, nx, ny, radius, res, buffer_size, bitset); - else - range_search_sse (x, y, d, nx, ny, radius, res, buffer_size, bitset); - } else { - range_search_blas (x, y, d, nx, ny, radius, res, buffer_size, bitset); - } -} - void pairwise_L2sqr (int64_t d, int64_t nq, const float *xq, int64_t nb, const float *xb, @@ -1071,6 +836,10 @@ void pairwise_L2sqr (int64_t d, } +/*************************************************************************** + * elkan + ***************************************************************************/ + void elkan_L2_sse ( const float * x, const float * y, diff --git a/internal/core/src/index/thirdparty/faiss/utils/distances.h b/internal/core/src/index/thirdparty/faiss/utils/distances.h index eb59711faf..b59f94616a 100644 --- a/internal/core/src/index/thirdparty/faiss/utils/distances.h +++ b/internal/core/src/index/thirdparty/faiss/utils/distances.h @@ -16,7 +16,6 @@ #include #include -#include namespace faiss { @@ -214,42 +213,6 @@ void knn_L2sqr_by_idx (const float * x, size_t d, size_t nx, size_t ny, float_maxheap_array_t * res); -/*************************************************************************** - * Range 
search - ***************************************************************************/ - - - -/// Forward declaration, see AuxIndexStructures.h -struct RangeSearchResult; - -/** Return the k nearest neighors of each of the nx vectors x among the ny - * vector y, w.r.t to max inner product - * - * @param x query vectors, size nx * d - * @param y database vectors, size ny * d - * @param radius search radius around the x vectors - * @param result result structure - */ -void range_search_L2sqr ( - const float * x, - const float * y, - size_t d, size_t nx, size_t ny, - float radius, - std::vector &result, - size_t buffer_size, - const BitsetView &bitset = nullptr); - -/// same as range_search_L2sqr for the inner product similarity -void range_search_inner_product ( - const float * x, - const float * y, - size_t d, size_t nx, size_t ny, - float radius, - std::vector &result, - size_t buffer_size, - const BitsetView &bitset = nullptr); - /*************************************************************************** * elkan diff --git a/internal/core/src/index/thirdparty/faiss/utils/distances_range.cpp b/internal/core/src/index/thirdparty/faiss/utils/distances_range.cpp new file mode 100644 index 0000000000..89adeebb02 --- /dev/null +++ b/internal/core/src/index/thirdparty/faiss/utils/distances_range.cpp @@ -0,0 +1,272 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +// -*- c++ -*- + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + + +#ifndef FINTEGER +#define FINTEGER long +#endif + + +extern "C" { + +/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */ + +int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER * + n, FINTEGER *k, const float *alpha, const float *a, + FINTEGER *lda, const float *b, FINTEGER * + ldb, float *beta, float *c, FINTEGER *ldc); + +} + + +namespace faiss { + +/*************************************************************************** + * Range search + ***************************************************************************/ + +/** Find, for each of the nx queries, all of the ny vectors within the search radius + * compute_l2 = compare squared L2 distance (dis < radius) rather than inner product (ip > radius) + */ + template +static void range_search_blas ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float radius, + std::vector &res, + size_t buffer_size, + const BitsetView &bitset) +{ + + // BLAS does not like empty matrices + if (nx == 0 || ny == 0) return; + + /* block sizes */ + const size_t bs_x = 4096, bs_y = 1024; + // const size_t bs_x = 16, bs_y = 16; + float *ip_block = new float[bs_x * bs_y]; + ScopeDeleter del0(ip_block); + + float *x_norms = nullptr, *y_norms = nullptr; + ScopeDeleter del1, del2; + if (compute_l2) { + x_norms = new float[nx]; + del1.set (x_norms); + fvec_norms_L2sqr (x_norms, x, d, nx); + + y_norms = new float[ny]; + del2.set (y_norms); + fvec_norms_L2sqr (y_norms, y, d, ny); + } + + for (size_t j0 = 0; j0 < ny; j0 += bs_y) { + size_t j1 = j0 + bs_y; + if (j1 > ny) j1 = ny; + RangeSearchResult *tmp_res = new RangeSearchResult(nx); + tmp_res->buffer_size = buffer_size; + RangeSearchPartialResult * pres = new RangeSearchPartialResult (tmp_res); + res.push_back (pres); + + for (size_t i0 = 0; i0 < nx; i0 += bs_x) { + size_t i1 = i0 + bs_x; + if(i1 > nx) i1 = nx; + + /* compute the actual
dot products */ + { + float one = 1, zero = 0; + FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d; + sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one, + y + j0 * d, &di, + x + i0 * d, &di, &zero, + ip_block, &nyi); + } + + for (size_t i = i0; i < i1; i++) { + const float *ip_line = ip_block + (i - i0) * (j1 - j0); + + RangeQueryResult & qres = pres->new_result (i); + + for (size_t j = j0; j < j1; j++) { + float ip = *ip_line++; + if (bitset.empty() || !bitset.test((int64_t)j)) { + if (compute_l2) { + float dis = x_norms[i] + y_norms[j] - 2 * ip; + if (dis < radius) { + qres.add (dis, j); + } + } else { + if (ip > radius) { + qres.add (ip, j); + } + } + } + } + } + } + InterruptCallback::check (); + } + +// RangeSearchPartialResult::merge (partial_results); +} + + +template +static void range_search_sse (const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float radius, + std::vector &res, + size_t buffer_size, + const BitsetView &bitset) +{ + +#pragma omp parallel + { + RangeSearchResult *tmp_res = new RangeSearchResult(nx); + tmp_res->buffer_size = buffer_size; + auto pres = new RangeSearchPartialResult(tmp_res); + +#pragma omp for + for (size_t i = 0; i < nx; i++) { + const float * x_ = x + i * d; + const float * y_ = y; + size_t j; + + RangeQueryResult & qres = pres->new_result (i); + + for (j = 0; j < ny; j++) { + if (bitset.empty() || !bitset.test((int64_t)j)) { + if (compute_l2) { + float disij = fvec_L2sqr (x_, y_, d); + if (disij < radius) { + qres.add (disij, j); + } + } else { + float ip = fvec_inner_product (x_, y_, d); + if (ip > radius) { + qres.add (ip, j); + } + } + } + y_ += d; + } + + } +#pragma omp critical + res.push_back(pres); + } + + // check just at the end because the use case is typically just + // when the nb of queries is low. 
+ InterruptCallback::check(); +} + +// range search by sse when nq = 1, namely single query situation +template +static void range_search_sse_sq (const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float radius, + std::vector &res, + size_t buffer_size, + const BitsetView &bitset) +{ + +#pragma omp parallel + { + RangeSearchResult *tmp_res = new RangeSearchResult(nx); + tmp_res->buffer_size = buffer_size; + auto pres = new RangeSearchPartialResult(tmp_res); + + const float * x_ = x; + size_t j; + RangeQueryResult & qres = pres->new_result (0); + +#pragma omp for + for (j = 0; j < ny; j++) { + const float * y_ = y + j * d; + if (bitset.empty() || !bitset.test((int64_t)j)) { + if (compute_l2) { + float disij = fvec_L2sqr (x_, y_, d); + if (disij < radius) { + qres.add (disij, j); + } + } else { + float ip = fvec_inner_product (x_, y_, d); + if (ip > radius) { + qres.add (ip, j); + } + } + } + } +#pragma omp critical + res.push_back(pres); + } + + // check just at the end because the use case is typically just + // when the nb of queries is low. 
+ InterruptCallback::check(); +} + + +void range_search_L2sqr ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float radius, + std::vector &res, + size_t buffer_size, + const BitsetView &bitset) +{ + + if (nx < distance_compute_blas_threshold) { + if (nx == 1) { + range_search_sse_sq (x, y, d, nx, ny, radius, res, buffer_size, bitset); + } else { + range_search_sse (x, y, d, nx, ny, radius, res, buffer_size, bitset); + } + } else { + range_search_blas (x, y, d, nx, ny, radius, res, buffer_size, bitset); + } +} + +void range_search_inner_product ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float radius, + std::vector &res, + size_t buffer_size, + const BitsetView &bitset) +{ + + if (nx < distance_compute_blas_threshold) { + if (nx == 1) + range_search_sse_sq (x, y, d, nx, ny, radius, res, buffer_size, bitset); + else + range_search_sse (x, y, d, nx, ny, radius, res, buffer_size, bitset); + } else { + range_search_blas (x, y, d, nx, ny, radius, res, buffer_size, bitset); + } +} + +} // namespace faiss diff --git a/internal/core/src/index/thirdparty/faiss/utils/distances_range.h b/internal/core/src/index/thirdparty/faiss/utils/distances_range.h new file mode 100644 index 0000000000..d6f781434a --- /dev/null +++ b/internal/core/src/index/thirdparty/faiss/utils/distances_range.h @@ -0,0 +1,54 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +// -*- c++ -*- + +#pragma once + +#include + +#include +#include + + +namespace faiss { + +/*************************************************************************** + * Range search + ***************************************************************************/ + +/// Forward declaration, see AuxIndexStructures.h +struct RangeSearchResult; + +/** Return, for each of the nx vectors x, all of the ny vectors y whose + * squared L2 distance to x is below radius + * + * @param x query vectors, size nx * d + * @param y database vectors, size ny * d + * @param radius search radius around the x vectors + * @param result result structure + */ +void range_search_L2sqr ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float radius, + std::vector &result, + size_t buffer_size, + const BitsetView &bitset = nullptr); + +/// same as range_search_L2sqr for the inner product similarity +void range_search_inner_product ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float radius, + std::vector &result, + size_t buffer_size, + const BitsetView &bitset = nullptr); + +} // namespace faiss