mirror of https://github.com/milvus-io/milvus.git
* Optimize index flat L2/IP for SSE Signed-off-by: sahuang <xiaohai.xu@zilliz.com> * parallel optimization Signed-off-by: sahuang <xiaohai.xu@zilliz.com> * fix threshold Signed-off-by: sahuang <xiaohai.xu@zilliz.com> * add changelog Signed-off-by: sahuang <xiaohai.xu@zilliz.com> * add changelog Signed-off-by: sahuang <xiaohai.xu@zilliz.com> Co-authored-by: sahuang <xiaohai.xu@zilliz.com>pull/1680/head
parent
3de34d3831
commit
59dab6cb84
|
@ -19,8 +19,9 @@ Please mark all change in change log and use the issue from GitHub
|
|||
- \#1546 Move Config.cpp to config directory
|
||||
- \#1547 Rename storage/file to storage/disk and rename classes
|
||||
- \#1548 Move store/Directory to storage/Operation and add FSHandler
|
||||
- \#1649 Fix Milvus crash on old CPU
|
||||
- \#1619 Improve compact performance
|
||||
- \#1649 Fix Milvus crash on old CPU
|
||||
- \#1653 IndexFlat performance improvement for NQ < thread_number
|
||||
|
||||
## Task
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ namespace faiss {
|
|||
if (init_heap) ha->heapify ();
|
||||
|
||||
int thread_max_num = omp_get_max_threads();
|
||||
if (ha->nh < thread_max_num) {
|
||||
if (ha->nh < 4) {
|
||||
// omp for n2
|
||||
int all_hash_size = thread_max_num * k;
|
||||
float *value = new float[all_hash_size];
|
||||
|
|
|
@ -152,39 +152,84 @@ static void knn_inner_product_sse (const float * x,
|
|||
ConcurrentBitsetPtr bitset = nullptr)
|
||||
{
|
||||
size_t k = res->k;
|
||||
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
||||
|
||||
check_period *= omp_get_max_threads();
|
||||
|
||||
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
||||
size_t i1 = std::min(i0 + check_period, nx);
|
||||
size_t thread_max_num = omp_get_max_threads();
|
||||
if (nx < 4) {
|
||||
// omp for ny
|
||||
size_t all_hash_size = thread_max_num * k;
|
||||
float *value = new float[all_hash_size];
|
||||
int64_t *labels = new int64_t[all_hash_size];
|
||||
|
||||
for (size_t i = 0; i < nx; i++) {
|
||||
// init hash
|
||||
for (size_t i = 0; i < all_hash_size; i++) {
|
||||
value[i] = -1.0 / 0.0;
|
||||
}
|
||||
const float *x_i = x + i * d;
|
||||
#pragma omp parallel for
|
||||
for (size_t i = i0; i < i1; i++) {
|
||||
const float * x_i = x + i * d;
|
||||
const float * y_j = y;
|
||||
|
||||
float * __restrict simi = res->get_val(i);
|
||||
int64_t * __restrict idxi = res->get_ids (i);
|
||||
|
||||
minheap_heapify (k, simi, idxi);
|
||||
|
||||
for (size_t j = 0; j < ny; j++) {
|
||||
if(!bitset || !bitset->test(j)){
|
||||
if(!bitset || !bitset->test(j)) {
|
||||
const float *y_j = y + j * d;
|
||||
float ip = fvec_inner_product (x_i, y_j, d);
|
||||
|
||||
if (ip > simi[0]) {
|
||||
minheap_pop (k, simi, idxi);
|
||||
minheap_push (k, simi, idxi, ip, j);
|
||||
size_t thread_no = omp_get_thread_num();
|
||||
float * __restrict val_ = value + thread_no * k;
|
||||
int64_t * __restrict ids_ = labels + thread_no * k;
|
||||
if (ip > val_[0]) {
|
||||
minheap_pop (k, val_, ids_);
|
||||
minheap_push (k, val_, ids_, ip, j);
|
||||
}
|
||||
}
|
||||
y_j += d;
|
||||
}
|
||||
|
||||
// merge hash
|
||||
float * __restrict simi = res->get_val(i);
|
||||
int64_t * __restrict idxi = res->get_ids (i);
|
||||
minheap_heapify (k, simi, idxi);
|
||||
for (size_t i = 0; i < all_hash_size; i++) {
|
||||
if (value[i] > simi[0]) {
|
||||
minheap_pop (k, simi, idxi);
|
||||
minheap_push (k, simi, idxi, value[i], labels[i]);
|
||||
}
|
||||
}
|
||||
minheap_reorder (k, simi, idxi);
|
||||
}
|
||||
InterruptCallback::check ();
|
||||
}
|
||||
delete[] value;
|
||||
delete[] labels;
|
||||
|
||||
} else {
|
||||
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
||||
check_period *= thread_max_num;
|
||||
|
||||
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
||||
size_t i1 = std::min(i0 + check_period, nx);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (size_t i = i0; i < i1; i++) {
|
||||
const float * x_i = x + i * d;
|
||||
const float * y_j = y;
|
||||
|
||||
float * __restrict simi = res->get_val(i);
|
||||
int64_t * __restrict idxi = res->get_ids (i);
|
||||
|
||||
minheap_heapify (k, simi, idxi);
|
||||
|
||||
for (size_t j = 0; j < ny; j++) {
|
||||
if(!bitset || !bitset->test(j)){
|
||||
float ip = fvec_inner_product (x_i, y_j, d);
|
||||
|
||||
if (ip > simi[0]) {
|
||||
minheap_pop (k, simi, idxi);
|
||||
minheap_push (k, simi, idxi, ip, j);
|
||||
}
|
||||
}
|
||||
y_j += d;
|
||||
}
|
||||
minheap_reorder (k, simi, idxi);
|
||||
}
|
||||
InterruptCallback::check ();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void knn_L2sqr_sse (
|
||||
|
@ -196,37 +241,87 @@ static void knn_L2sqr_sse (
|
|||
{
|
||||
size_t k = res->k;
|
||||
|
||||
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
||||
check_period *= omp_get_max_threads();
|
||||
|
||||
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
||||
size_t i1 = std::min(i0 + check_period, nx);
|
||||
size_t thread_max_num = omp_get_max_threads();
|
||||
if (nx < 4) {
|
||||
// omp for ny
|
||||
size_t all_hash_size = thread_max_num * k;
|
||||
float *value = new float[all_hash_size];
|
||||
int64_t *labels = new int64_t[all_hash_size];
|
||||
|
||||
for (size_t i = 0; i < nx; i++) {
|
||||
// init hash
|
||||
for (size_t i = 0; i < all_hash_size; i++) {
|
||||
value[i] = 1.0 / 0.0;
|
||||
}
|
||||
for (size_t i = 0; i < k; i++) {
|
||||
labels[i] = -1;
|
||||
}
|
||||
const float *x_i = x + i * d;
|
||||
#pragma omp parallel for
|
||||
for (size_t i = i0; i < i1; i++) {
|
||||
const float * x_i = x + i * d;
|
||||
const float * y_j = y;
|
||||
size_t j;
|
||||
float * simi = res->get_val(i);
|
||||
int64_t * idxi = res->get_ids (i);
|
||||
|
||||
maxheap_heapify (k, simi, idxi);
|
||||
for (j = 0; j < ny; j++) {
|
||||
if(!bitset || !bitset->test(j)){
|
||||
for (size_t j = 0; j < ny; j++) {
|
||||
if(!bitset || !bitset->test(j)) {
|
||||
const float *y_j = y + j * d;
|
||||
float disij = fvec_L2sqr (x_i, y_j, d);
|
||||
|
||||
if (disij < simi[0]) {
|
||||
maxheap_pop (k, simi, idxi);
|
||||
maxheap_push (k, simi, idxi, disij, j);
|
||||
size_t thread_no = omp_get_thread_num();
|
||||
float * __restrict val_ = value + thread_no * k;
|
||||
int64_t * __restrict ids_ = labels + thread_no * k;
|
||||
if (disij < val_[0]) {
|
||||
maxheap_pop (k, val_, ids_);
|
||||
maxheap_push (k, val_, ids_, disij, j);
|
||||
}
|
||||
}
|
||||
y_j += d;
|
||||
}
|
||||
|
||||
// merge hash
|
||||
float * __restrict simi = res->get_val(i);
|
||||
int64_t * __restrict idxi = res->get_ids (i);
|
||||
memcpy(simi, value, k * sizeof(float));
|
||||
memcpy(idxi, labels, k * sizeof(int64_t));
|
||||
maxheap_heapify (k, simi, idxi, value, labels, k);
|
||||
for (size_t i = k; i < all_hash_size; i++) {
|
||||
if (value[i] < simi[0]) {
|
||||
maxheap_pop (k, simi, idxi);
|
||||
maxheap_push (k, simi, idxi, value[i], labels[i]);
|
||||
}
|
||||
}
|
||||
maxheap_reorder (k, simi, idxi);
|
||||
}
|
||||
InterruptCallback::check ();
|
||||
}
|
||||
delete[] value;
|
||||
delete[] labels;
|
||||
|
||||
} else {
|
||||
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
||||
check_period *= thread_max_num;
|
||||
|
||||
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
||||
size_t i1 = std::min(i0 + check_period, nx);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (size_t i = i0; i < i1; i++) {
|
||||
const float * x_i = x + i * d;
|
||||
const float * y_j = y;
|
||||
float * simi = res->get_val(i);
|
||||
int64_t * idxi = res->get_ids (i);
|
||||
|
||||
maxheap_heapify (k, simi, idxi);
|
||||
|
||||
for (size_t j = 0; j < ny; j++) {
|
||||
if(!bitset || !bitset->test(j)){
|
||||
float disij = fvec_L2sqr (x_i, y_j, d);
|
||||
|
||||
if (disij < simi[0]) {
|
||||
maxheap_pop (k, simi, idxi);
|
||||
maxheap_push (k, simi, idxi, disij, j);
|
||||
}
|
||||
}
|
||||
y_j += d;
|
||||
}
|
||||
maxheap_reorder (k, simi, idxi);
|
||||
}
|
||||
InterruptCallback::check ();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -899,4 +994,4 @@ void pairwise_L2sqr (int64_t d,
|
|||
}
|
||||
|
||||
|
||||
} // namespace faiss
|
||||
} // namespace faiss
|
|
@ -281,7 +281,7 @@ void hammings_knn_hc (
|
|||
if (init_heap) ha->heapify ();
|
||||
|
||||
int thread_max_num = omp_get_max_threads();
|
||||
if (ha->nh < thread_max_num) {
|
||||
if (ha->nh < 4) {
|
||||
// omp for n2
|
||||
int all_hash_size = thread_max_num * k;
|
||||
hamdis_t *value = new hamdis_t[all_hash_size];
|
||||
|
@ -432,7 +432,7 @@ void hammings_knn_hc_1 (
|
|||
}
|
||||
|
||||
int thread_max_num = omp_get_max_threads();
|
||||
if (ha->nh < thread_max_num) {
|
||||
if (ha->nh < 4) {
|
||||
// omp for n2
|
||||
int all_hash_size = thread_max_num * k;
|
||||
hamdis_t *value = new hamdis_t[all_hash_size];
|
||||
|
|
Loading…
Reference in New Issue