mirror of https://github.com/milvus-io/milvus.git
optimizer parallel policy (#3014)
* optimizer parallel policy
  Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
* modify parallel policy
  Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
pull/3025/head
parent
7457438ba5
commit
1c36e6d83d
|
@ -148,7 +148,7 @@ static void knn_inner_product_sse (const float * x,
|
|||
size_t k = res->k;
|
||||
size_t thread_max_num = omp_get_max_threads();
|
||||
|
||||
if (ny > parallel_policy_threshold) {
|
||||
if (ny > parallel_policy_threshold || (nx < thread_max_num / 2 && ny >= thread_max_num * 32)) {
|
||||
size_t block_x = std::min(
|
||||
get_L3_Size() / (d * sizeof(float) + thread_max_num * k * (sizeof(float) + sizeof(int64_t))),
|
||||
nx);
|
||||
|
@ -173,24 +173,24 @@ static void knn_inner_product_sse (const float * x,
|
|||
if(!bitset || !bitset->test(j)) {
|
||||
size_t thread_no = omp_get_thread_num();
|
||||
const float *y_j = y + j * d;
|
||||
for (size_t i = x_from; i < x_to; i++) {
|
||||
const float *x_i = x + i * d;
|
||||
const float *x_i = x + x_from * d;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
float disij = fvec_inner_product (x_i, y_j, d);
|
||||
|
||||
float * val_ = value + thread_no * thread_heap_size + (i - x_from) * k;
|
||||
int64_t * ids_ = labels + thread_no * thread_heap_size + (i - x_from) * k;
|
||||
float * val_ = value + thread_no * thread_heap_size + i * k;
|
||||
int64_t * ids_ = labels + thread_no * thread_heap_size + i * k;
|
||||
if (disij > val_[0]) {
|
||||
minheap_swap_top (k, val_, ids_, disij, j);
|
||||
}
|
||||
x_i += d;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// merge heap
|
||||
for (size_t t = 1; t < thread_max_num; t++) {
|
||||
// merge heap
|
||||
for (size_t i = x_from; i < x_to; i++) {
|
||||
float * __restrict value_x = value + (i - x_from) * k;
|
||||
int64_t * __restrict labels_x = labels + (i - x_from) * k;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
float * __restrict value_x = value + i * k;
|
||||
int64_t * __restrict labels_x = labels + i * k;
|
||||
float *value_x_t = value_x + t * thread_heap_size;
|
||||
int64_t *labels_x_t = labels_x + t * thread_heap_size;
|
||||
for (size_t j = 0; j < k; j++) {
|
||||
|
@ -201,14 +201,16 @@ static void knn_inner_product_sse (const float * x,
|
|||
}
|
||||
}
|
||||
|
||||
for (size_t i = x_from; i < x_to; i++) {
|
||||
float * value_x = value + (i - x_from) * k;
|
||||
int64_t * labels_x = labels + (i - x_from) * k;
|
||||
// sort
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
float * value_x = value + i * k;
|
||||
int64_t * labels_x = labels + i * k;
|
||||
minheap_reorder (k, value_x, labels_x);
|
||||
}
|
||||
|
||||
memcpy(res->val+ x_from * k, value, thread_heap_size * sizeof(float));
|
||||
memcpy(res->ids+ x_from * k, labels, thread_heap_size * sizeof(int64_t));
|
||||
// copy result
|
||||
memcpy(res->val + x_from * k, value, thread_heap_size * sizeof(float));
|
||||
memcpy(res->ids + x_from * k, labels, thread_heap_size * sizeof(int64_t));
|
||||
}
|
||||
delete[] value;
|
||||
delete[] labels;
|
||||
|
@ -255,7 +257,7 @@ static void knn_L2sqr_sse (
|
|||
size_t k = res->k;
|
||||
size_t thread_max_num = omp_get_max_threads();
|
||||
|
||||
if (ny > parallel_policy_threshold) {
|
||||
if (ny > parallel_policy_threshold || (nx < thread_max_num / 2 && ny >= thread_max_num * 32)) {
|
||||
size_t block_x = std::min(
|
||||
get_L3_Size() / (d * sizeof(float) + thread_max_num * k * (sizeof(float) + sizeof(int64_t))),
|
||||
nx);
|
||||
|
@ -280,24 +282,24 @@ static void knn_L2sqr_sse (
|
|||
if(!bitset || !bitset->test(j)) {
|
||||
size_t thread_no = omp_get_thread_num();
|
||||
const float *y_j = y + j * d;
|
||||
for (size_t i = x_from; i < x_to; i++) {
|
||||
const float *x_i = x + i * d;
|
||||
const float *x_i = x + x_from * d;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
float disij = fvec_L2sqr (x_i, y_j, d);
|
||||
|
||||
float * val_ = value + thread_no * thread_heap_size + (i - x_from) * k;
|
||||
int64_t * ids_ = labels + thread_no * thread_heap_size + (i - x_from) * k;
|
||||
float * val_ = value + thread_no * thread_heap_size + i * k;
|
||||
int64_t * ids_ = labels + thread_no * thread_heap_size + i * k;
|
||||
if (disij < val_[0]) {
|
||||
maxheap_swap_top (k, val_, ids_, disij, j);
|
||||
}
|
||||
x_i += d;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// merge heap
|
||||
for (size_t t = 1; t < thread_max_num; t++) {
|
||||
// merge heap
|
||||
for (size_t i = x_from; i < x_to; i++) {
|
||||
float * __restrict value_x = value + (i - x_from) * k;
|
||||
int64_t * __restrict labels_x = labels + (i - x_from) * k;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
float * __restrict value_x = value + i * k;
|
||||
int64_t * __restrict labels_x = labels + i * k;
|
||||
float *value_x_t = value_x + t * thread_heap_size;
|
||||
int64_t *labels_x_t = labels_x + t * thread_heap_size;
|
||||
for (size_t j = 0; j < k; j++) {
|
||||
|
@ -308,19 +310,20 @@ static void knn_L2sqr_sse (
|
|||
}
|
||||
}
|
||||
|
||||
for (size_t i = x_from; i < x_to; i++) {
|
||||
float * value_x = value + (i - x_from) * k;
|
||||
int64_t * labels_x = labels + (i - x_from) * k;
|
||||
// sort
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
float * value_x = value + i * k;
|
||||
int64_t * labels_x = labels + i * k;
|
||||
maxheap_reorder (k, value_x, labels_x);
|
||||
}
|
||||
|
||||
memcpy(res->val+ x_from * k, value, thread_heap_size * sizeof(float));
|
||||
memcpy(res->ids+ x_from * k, labels, thread_heap_size * sizeof(int64_t));
|
||||
// copy result
|
||||
memcpy(res->val + x_from * k, value, thread_heap_size * sizeof(float));
|
||||
memcpy(res->ids + x_from * k, labels, thread_heap_size * sizeof(int64_t));
|
||||
}
|
||||
delete[] value;
|
||||
delete[] labels;
|
||||
|
||||
|
||||
} else {
|
||||
|
||||
float * value = res->val;
|
||||
|
|
Loading…
Reference in New Issue