mirror of https://github.com/milvus-io/milvus.git
Improve IVF search performance when NQ and nProbe are both large (#2984)
* fix indexflat search

Signed-off-by: cqy <yaya645@126.com>

* add parallel_policy_threshold

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>

Co-authored-by: shengjun.li <shengjun.li@zilliz.com>

Branch: pull/3006/head
parent dd938878ea
commit 91d3fe5cbc
@@ -12,6 +12,7 @@ Please mark all change in change log and use the issue from GitHub

## Feature

## Improvement
- \#2653 Improve IVF search performance when NQ and nProbe are both large

## Task
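At a high level, the patch gives the brute-force kNN kernels two parallel policies: when the number of database vectors to scan, ny (which grows with nProbe), exceeds the new parallel_policy_threshold, the scan is parallelized over ny with one top-k buffer per (query, thread) that is merged at the end; for small ny it keeps the usual parallelization over the queries nx. The sketch below illustrates only that per-thread-buffer-and-merge pattern in standalone form; the l2 helper, the data sizes, and the simple "replace the current worst" update are illustrative stand-ins, not the faiss/Milvus heap routines.

// Minimal standalone sketch of "parallelize over ny with per-thread top-k
// buffers, then merge". Compile with -fopenmp. Not the faiss/Milvus code.
#include <omp.h>
#include <algorithm>
#include <cstdio>
#include <limits>
#include <vector>

static float l2(const float* a, const float* b, size_t d) {
    float s = 0.f;
    for (size_t i = 0; i < d; i++) { float t = a[i] - b[i]; s += t * t; }
    return s;
}

int main() {
    const size_t d = 8, nx = 4, ny = 100000, k = 5;
    std::vector<float> x(nx * d, 0.f), y(ny * d);
    for (size_t i = 0; i < ny * d; i++) y[i] = float(i % 97) * 0.01f;

    int T = omp_get_max_threads();
    // one k-entry (distance, id) buffer per (query, thread)
    std::vector<float> val(size_t(T) * nx * k, std::numeric_limits<float>::max());
    std::vector<long>  ids(size_t(T) * nx * k, -1);

#pragma omp parallel for schedule(static)
    for (long j = 0; j < (long)ny; j++) {            // split the database vectors
        int t = omp_get_thread_num();
        for (size_t i = 0; i < nx; i++) {
            float dist = l2(&x[i * d], &y[size_t(j) * d], d);
            float* v = &val[(size_t(t) * nx + i) * k];
            long*  l = &ids[(size_t(t) * nx + i) * k];
            // keep the k smallest seen by this thread (naive "replace the worst")
            size_t worst = std::max_element(v, v + k) - v;
            if (dist < v[worst]) { v[worst] = dist; l[worst] = j; }
        }
    }

    // merge the per-thread candidates into the final top-k of each query
    for (size_t i = 0; i < nx; i++) {
        std::vector<std::pair<float, long>> cand;
        for (int t = 0; t < T; t++)
            for (size_t j = 0; j < k; j++)
                cand.push_back({val[(size_t(t) * nx + i) * k + j],
                                ids[(size_t(t) * nx + i) * k + j]});
        std::partial_sort(cand.begin(), cand.begin() + k, cand.end());
        printf("query %zu: best id %ld, dist %.3f\n", i, cand[0].second, cand[0].first);
    }
    return 0;
}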
@@ -132,16 +132,11 @@ void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x)

/***************************************************************************
 * KNN functions
 ***************************************************************************/

int parallel_policy_threshold = 65535;

/* Find the nearest neighbors for nx queries in a set of ny vectors */
static void knn_inner_product_sse (const float * x,

@@ -151,101 +146,103 @@ static void knn_inner_product_sse (const float * x,
        ConcurrentBitsetPtr bitset = nullptr)
{
    size_t k = res->k;

    size_t thread_max_num = omp_get_max_threads();

    if (ny > parallel_policy_threshold) {
        size_t block_x = std::min(
                get_L3_Size() / (d * sizeof(float) + thread_max_num * k * (sizeof(float) + sizeof(int64_t))),
                nx);

        size_t all_heap_size = block_x * k * thread_max_num;
        float *value = new float[all_heap_size];
        int64_t *labels = new int64_t[all_heap_size];

        for (size_t x_from = 0, x_to; x_from < nx; x_from = x_to) {
            x_to = std::min(nx, x_from + block_x);
            int size = x_to - x_from;
            int thread_heap_size = size * k;

            // init heap
            for (size_t i = 0; i < all_heap_size; i++) {
                value[i] = -1.0 / 0.0;
                labels[i] = -1;
            }

#pragma omp parallel for schedule(static)
            for (size_t j = 0; j < ny; j++) {
                if (!bitset || !bitset->test(j)) {
                    size_t thread_no = omp_get_thread_num();
                    const float *y_j = y + j * d;
                    for (size_t i = x_from; i < x_to; i++) {
                        const float *x_i = x + i * d;
                        float disij = fvec_inner_product (x_i, y_j, d);

                        float * val_ = value + thread_no * thread_heap_size + (i - x_from) * k;
                        int64_t * ids_ = labels + thread_no * thread_heap_size + (i - x_from) * k;
                        if (disij > val_[0]) {
                            minheap_swap_top (k, val_, ids_, disij, j);
                        }
                    }
                }
            }

            for (size_t t = 1; t < thread_max_num; t++) {
                // merge heap
                for (size_t i = x_from; i < x_to; i++) {
                    float * __restrict value_x = value + (i - x_from) * k;
                    int64_t * __restrict labels_x = labels + (i - x_from) * k;
                    float *value_x_t = value_x + t * thread_heap_size;
                    int64_t *labels_x_t = labels_x + t * thread_heap_size;
                    for (size_t j = 0; j < k; j++) {
                        if (value_x_t[j] > value_x[0]) {
                            minheap_swap_top (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
                        }
                    }
                }
            }

            for (size_t i = x_from; i < x_to; i++) {
                float * value_x = value + (i - x_from) * k;
                int64_t * labels_x = labels + (i - x_from) * k;
                minheap_reorder (k, value_x, labels_x);
            }

            memcpy(res->val + x_from * k, value, thread_heap_size * sizeof(float));
            memcpy(res->ids + x_from * k, labels, thread_heap_size * sizeof(int64_t));
        }
        delete[] value;
        delete[] labels;

    } else {
        float * value = res->val;
        int64_t * labels = res->ids;

#pragma omp parallel for
        for (size_t i = 0; i < nx; i++) {
            const float *x_i = x + i * d;
            const float *y_j = y;

            float * __restrict val_ = value + i * k;
            int64_t * __restrict ids_ = labels + i * k;

            for (size_t j = 0; j < k; j++) {
                val_[j] = -1.0 / 0.0;
                ids_[j] = -1;
            }

            for (size_t j = 0; j < ny; j++) {
                if (!bitset || !bitset->test(j)) {
                    float disij = fvec_inner_product (x_i, y_j, d);
                    if (disij > val_[0]) {
                        minheap_swap_top (k, val_, ids_, disij, j);
                    }
                }
                y_j += d;
            }

            minheap_reorder (k, val_, ids_);
        }
    }
}

static void knn_L2sqr_sse (
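One way to read the block_x formula in the ny-parallel branch above: each query in a block contributes its vector (d floats) plus one k-entry (float, int64_t) buffer per thread, and the block is capped so that this working set roughly fits in the L3 cache. The snippet below is a hedged, illustrative calculation of that cap (assumed 12 MB L3, d = 128, k = 16, 16 threads), not a measurement from a real deployment.

// Worked example of the block_x sizing rule; parameters are illustrative only.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t L3 = 12LL * 1024 * 1024;   // assume a 12 MB L3 cache
    const size_t d = 128, k = 16, threads = 16, nx = 100000;
    // per-query footprint: the query vector plus one k-entry heap per thread
    size_t per_query = d * sizeof(float) + threads * k * (sizeof(float) + sizeof(int64_t));
    size_t block_x = std::min<size_t>(L3 / per_query, nx);
    printf("per-query working set: %zu bytes, block_x: %zu queries\n", per_query, block_x);
    return 0;
}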
@@ -256,100 +253,105 @@ static void knn_L2sqr_sse (
        ConcurrentBitsetPtr bitset = nullptr)
{
    size_t k = res->k;

    size_t thread_max_num = omp_get_max_threads();

    if (ny > parallel_policy_threshold) {
        size_t block_x = std::min(
                get_L3_Size() / (d * sizeof(float) + thread_max_num * k * (sizeof(float) + sizeof(int64_t))),
                nx);

        size_t all_heap_size = block_x * k * thread_max_num;
        float *value = new float[all_heap_size];
        int64_t *labels = new int64_t[all_heap_size];

        for (size_t x_from = 0, x_to; x_from < nx; x_from = x_to) {
            x_to = std::min(nx, x_from + block_x);
            int size = x_to - x_from;
            int thread_heap_size = size * k;

            // init heap
            for (size_t i = 0; i < all_heap_size; i++) {
                value[i] = 1.0 / 0.0;
                labels[i] = -1;
            }

#pragma omp parallel for schedule(static)
            for (size_t j = 0; j < ny; j++) {
                if (!bitset || !bitset->test(j)) {
                    size_t thread_no = omp_get_thread_num();
                    const float *y_j = y + j * d;
                    for (size_t i = x_from; i < x_to; i++) {
                        const float *x_i = x + i * d;
                        float disij = fvec_L2sqr (x_i, y_j, d);

                        float * val_ = value + thread_no * thread_heap_size + (i - x_from) * k;
                        int64_t * ids_ = labels + thread_no * thread_heap_size + (i - x_from) * k;
                        if (disij < val_[0]) {
                            maxheap_swap_top (k, val_, ids_, disij, j);
                        }
                    }
                }
            }

            for (size_t t = 1; t < thread_max_num; t++) {
                // merge heap
                for (size_t i = x_from; i < x_to; i++) {
                    float * __restrict value_x = value + (i - x_from) * k;
                    int64_t * __restrict labels_x = labels + (i - x_from) * k;
                    float *value_x_t = value_x + t * thread_heap_size;
                    int64_t *labels_x_t = labels_x + t * thread_heap_size;
                    for (size_t j = 0; j < k; j++) {
                        if (value_x_t[j] < value_x[0]) {
                            maxheap_swap_top (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
                        }
                    }
                }
            }

            for (size_t i = x_from; i < x_to; i++) {
                float * value_x = value + (i - x_from) * k;
                int64_t * labels_x = labels + (i - x_from) * k;
                maxheap_reorder (k, value_x, labels_x);
            }

            memcpy(res->val + x_from * k, value, thread_heap_size * sizeof(float));
            memcpy(res->ids + x_from * k, labels, thread_heap_size * sizeof(int64_t));
        }
        delete[] value;
        delete[] labels;

    } else {
        float * value = res->val;
        int64_t * labels = res->ids;

#pragma omp parallel for
        for (size_t i = 0; i < nx; i++) {
            const float *x_i = x + i * d;
            const float *y_j = y;

            float * __restrict val_ = value + i * k;
            int64_t * __restrict ids_ = labels + i * k;

            for (size_t j = 0; j < k; j++) {
                val_[j] = 1.0 / 0.0;
                ids_[j] = -1;
            }

            for (size_t j = 0; j < ny; j++) {
                if (!bitset || !bitset->test(j)) {
                    float disij = fvec_L2sqr (x_i, y_j, d);
                    if (disij < val_[0]) {
                        maxheap_swap_top (k, val_, ids_, disij, j);
                    }
                }
                y_j += d;
            }

            maxheap_reorder (k, val_, ids_);
        }
    }
}

/** Find the nearest neighbors for nx queries in a set of ny vectors */
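The L2 variant above mirrors the inner-product one; the only behavioural differences are the metric and the heap direction: inner product keeps the k largest scores (min-heap over scores, buffers initialised to -infinity), while L2 keeps the k smallest distances (max-heap over distances, buffers initialised to +infinity). A tiny illustration of that replacement rule, with made-up values:

// Illustrative only: the heap-top sentinel and acceptance test of each path.
#include <cstdio>
#include <limits>

int main() {
    float worst_ip = -std::numeric_limits<float>::infinity();  // inner product: keep if score > top
    float worst_l2 =  std::numeric_limits<float>::infinity();  // L2: keep if distance < top
    float score = 0.42f, dist = 1.7f;                          // made-up candidate values
    printf("inner product keeps candidate: %d, L2 keeps candidate: %d\n",
           score > worst_ip, dist < worst_l2);
    return 0;
}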
@@ -158,6 +158,9 @@ void pairwise_indexed_inner_product (

// threshold on nx above which we switch to BLAS to compute distances
extern int distance_compute_blas_threshold;

// threshold on ny above which we parallelize the brute-force scan over ny
extern int parallel_policy_threshold;

/** Return the k nearest neighbors of each of the nx vectors x among the ny
 * vectors y, w.r.t. max inner product
 *
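Since parallel_policy_threshold is exported the same way as distance_compute_blas_threshold, code linked against this faiss fork can in principle tune it. A hedged usage sketch follows; the header path and namespace are assumed to match the fork's layout, and 10000 is an arbitrary example value.

#include <faiss/utils/distances.h>   // assumed location of the extern declaration

int main() {
    // Lower the cut-over so the ny-parallel scan kicks in for smaller segments.
    faiss::parallel_policy_threshold = 10000;   // default is 65535
    // ... build an IndexFlat / IVF index and search as usual ...
    return 0;
}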
@@ -684,4 +684,28 @@ bool check_openmp() {
    return true;
}

int64_t get_L3_Size() {
    static int64_t l3_size = -1;
    constexpr int64_t KB = 1024;
    if (l3_size == -1) {
        FILE* file = fopen("/sys/devices/system/cpu/cpu0/cache/index3/size", "r");
        int64_t result = 0;
        constexpr int64_t line_length = 128;
        char line[line_length];
        if (file) {
            char* ret = fgets(line, sizeof(line) - 1, file);
            sscanf(line, "%luK", &result);
            l3_size = result * KB;
            fclose(file);
        } else {
            l3_size = 12 * KB * KB;  // 12M
        }
    }
    return l3_size;
}

} // namespace faiss
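get_L3_Size() above relies on the Linux sysfs convention that /sys/devices/system/cpu/cpu0/cache/index3/size contains a string such as "12288K". The snippet below only demonstrates that parsing step on a hard-coded sample string; the value is illustrative.

#include <cstdio>

int main() {
    const char line[] = "12288K\n";    // sample sysfs contents for a 12 MB L3 cache
    unsigned long kb = 0;
    if (sscanf(line, "%luK", &kb) == 1) {
        printf("L3 size: %lu KB = %lu bytes\n", kb, kb * 1024UL);
    }
    return 0;
}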
@@ -160,6 +160,9 @@ uint64_t hash_bytes (const uint8_t *bytes, int64_t n);

/** Whether OpenMP annotations were respected. */
bool check_openmp();

/** Get the size of the L3 cache. */
int64_t get_L3_Size();

} // namespace faiss