NSG support MetricType IP

pull/554/head
xiaojun.lin 2019-11-26 17:15:14 +08:00
parent f34cdad372
commit 88cb2b5991
11 changed files with 338 additions and 164 deletions

View File

@ -26,6 +26,7 @@ Please mark all change in change log and use the ticket from JIRA.
- \#513 - Unittest DELETE_BY_RANGE sometimes failed
- \#527 - faiss benchmark not compatible with faiss 1.6.0
- \#530 - BuildIndex stop when do build index and search simultaneously
- \#533 - NSG build failed with MetricType Inner Product
## Feature
- \#12 - Pure CPU version for Milvus

View File

@ -38,6 +38,7 @@ set(index_srcs
knowhere/index/vector_index/nsg/NSG.cpp
knowhere/index/vector_index/nsg/NSGIO.cpp
knowhere/index/vector_index/nsg/NSGHelper.cpp
knowhere/index/vector_index/nsg/Distance.cpp
knowhere/index/vector_index/IndexIVFSQ.cpp
knowhere/index/vector_index/IndexIVFPQ.cpp
knowhere/index/vector_index/FaissBaseIndex.cpp

View File

@ -115,10 +115,6 @@ NSG::Train(const DatasetPtr& dataset, const Config& config) {
build_cfg->CheckValid(); // throw exception
}
if (build_cfg->metric_type != METRICTYPE::L2) {
KNOWHERE_THROW_MSG("NSG not support this kind of metric type");
}
// TODO(linxj): dev IndexFactory, support more IndexType
#ifdef MILVUS_GPU_VERSION
auto preprocess_index = std::make_shared<GPUIVF>(build_cfg->gpu_id);

View File

@ -0,0 +1,247 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <immintrin.h>
#include "knowhere/index/vector_index/nsg/Distance.h"
namespace knowhere {
namespace algo {
float
DistanceL2::Compare(const float* a, const float* b, unsigned size) const {
float result = 0;
#ifdef __GNUC__
#ifdef __AVX__
#define AVX_L2SQR(addr1, addr2, dest, tmp1, tmp2) \
tmp1 = _mm256_loadu_ps(addr1); \
tmp2 = _mm256_loadu_ps(addr2); \
tmp1 = _mm256_sub_ps(tmp1, tmp2); \
tmp1 = _mm256_mul_ps(tmp1, tmp1); \
dest = _mm256_add_ps(dest, tmp1);
__m256 sum;
__m256 l0, l1;
__m256 r0, r1;
unsigned D = (size + 7) & ~7U;
unsigned DR = D % 16;
unsigned DD = D - DR;
const float* l = a;
const float* r = b;
const float* e_l = l + DD;
const float* e_r = r + DD;
float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0};
sum = _mm256_loadu_ps(unpack);
if (DR) {
AVX_L2SQR(e_l, e_r, sum, l0, r0);
}
for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) {
AVX_L2SQR(l, r, sum, l0, r0);
AVX_L2SQR(l + 8, r + 8, sum, l1, r1);
}
_mm256_storeu_ps(unpack, sum);
result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7];
#else
#ifdef __SSE2__
#define SSE_L2SQR(addr1, addr2, dest, tmp1, tmp2) \
tmp1 = _mm_load_ps(addr1); \
tmp2 = _mm_load_ps(addr2); \
tmp1 = _mm_sub_ps(tmp1, tmp2); \
tmp1 = _mm_mul_ps(tmp1, tmp1); \
dest = _mm_add_ps(dest, tmp1);
__m128 sum;
__m128 l0, l1, l2, l3;
__m128 r0, r1, r2, r3;
unsigned D = (size + 3) & ~3U;
unsigned DR = D % 16;
unsigned DD = D - DR;
const float* l = a;
const float* r = b;
const float* e_l = l + DD;
const float* e_r = r + DD;
float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0};
sum = _mm_load_ps(unpack);
switch (DR) {
case 12:
SSE_L2SQR(e_l + 8, e_r + 8, sum, l2, r2);
case 8:
SSE_L2SQR(e_l + 4, e_r + 4, sum, l1, r1);
case 4:
SSE_L2SQR(e_l, e_r, sum, l0, r0);
default:
break;
}
for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) {
SSE_L2SQR(l, r, sum, l0, r0);
SSE_L2SQR(l + 4, r + 4, sum, l1, r1);
SSE_L2SQR(l + 8, r + 8, sum, l2, r2);
SSE_L2SQR(l + 12, r + 12, sum, l3, r3);
}
_mm_storeu_ps(unpack, sum);
result += unpack[0] + unpack[1] + unpack[2] + unpack[3];
// nomal distance
#else
float diff0, diff1, diff2, diff3;
const float* last = a + size;
const float* unroll_group = last - 3;
/* Process 4 items with each loop for efficiency. */
while (a < unroll_group) {
diff0 = a[0] - b[0];
diff1 = a[1] - b[1];
diff2 = a[2] - b[2];
diff3 = a[3] - b[3];
result += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
a += 4;
b += 4;
}
/* Process last 0-3 pixels. Not needed for standard vector lengths. */
while (a < last) {
diff0 = *a++ - *b++;
result += diff0 * diff0;
}
#endif
#endif
#endif
return result;
}
float
DistanceIP::Compare(const float* a, const float* b, unsigned size) const {
float result = 0;
#ifdef __GNUC__
#ifdef __AVX__
#define AVX_DOT(addr1, addr2, dest, tmp1, tmp2) \
tmp1 = _mm256_loadu_ps(addr1); \
tmp2 = _mm256_loadu_ps(addr2); \
tmp1 = _mm256_mul_ps(tmp1, tmp2); \
dest = _mm256_add_ps(dest, tmp1);
__m256 sum;
__m256 l0, l1;
__m256 r0, r1;
unsigned D = (size + 7) & ~7U;
unsigned DR = D % 16;
unsigned DD = D - DR;
const float* l = a;
const float* r = b;
const float* e_l = l + DD;
const float* e_r = r + DD;
float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0};
sum = _mm256_loadu_ps(unpack);
if (DR) {
AVX_DOT(e_l, e_r, sum, l0, r0);
}
for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) {
AVX_DOT(l, r, sum, l0, r0);
AVX_DOT(l + 8, r + 8, sum, l1, r1);
}
_mm256_storeu_ps(unpack, sum);
result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7];
#else
#ifdef __SSE2__
#define SSE_DOT(addr1, addr2, dest, tmp1, tmp2) \
tmp1 = _mm128_loadu_ps(addr1); \
tmp2 = _mm128_loadu_ps(addr2); \
tmp1 = _mm128_mul_ps(tmp1, tmp2); \
dest = _mm128_add_ps(dest, tmp1);
__m128 sum;
__m128 l0, l1, l2, l3;
__m128 r0, r1, r2, r3;
unsigned D = (size + 3) & ~3U;
unsigned DR = D % 16;
unsigned DD = D - DR;
const float* l = a;
const float* r = b;
const float* e_l = l + DD;
const float* e_r = r + DD;
float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0};
sum = _mm_load_ps(unpack);
switch (DR) {
case 12:
SSE_DOT(e_l + 8, e_r + 8, sum, l2, r2);
case 8:
SSE_DOT(e_l + 4, e_r + 4, sum, l1, r1);
case 4:
SSE_DOT(e_l, e_r, sum, l0, r0);
default:
break;
}
for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) {
SSE_DOT(l, r, sum, l0, r0);
SSE_DOT(l + 4, r + 4, sum, l1, r1);
SSE_DOT(l + 8, r + 8, sum, l2, r2);
SSE_DOT(l + 12, r + 12, sum, l3, r3);
}
_mm_storeu_ps(unpack, sum);
result += unpack[0] + unpack[1] + unpack[2] + unpack[3];
#else
float dot0, dot1, dot2, dot3;
const float* last = a + size;
const float* unroll_group = last - 3;
/* Process 4 items with each loop for efficiency. */
while (a < unroll_group) {
dot0 = a[0] * b[0];
dot1 = a[1] * b[1];
dot2 = a[2] * b[2];
dot3 = a[3] * b[3];
result += dot0 + dot1 + dot2 + dot3;
a += 4;
b += 4;
}
/* Process last 0-3 pixels. Not needed for standard vector lengths. */
while (a < last) {
result += *a++ * *b++;
}
#endif
#endif
#endif
return result;
}
//#include <faiss/utils/distances.h>
// float
// DistanceL2::Compare(const float* a, const float* b, unsigned size) const {
// return faiss::fvec_L2sqr(a,b,size);
//}
//
// float
// DistanceIP::Compare(const float* a, const float* b, unsigned size) const {
// return faiss::fvec_inner_product(a,b,size);
//}
} // namespace algo
} // namespace knowhere

View File

@ -0,0 +1,39 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
namespace knowhere {
namespace algo {
struct Distance {
virtual float
Compare(const float* a, const float* b, unsigned size) const = 0;
};
struct DistanceL2 : public Distance {
float
Compare(const float* a, const float* b, unsigned size) const override;
};
struct DistanceIP : public Distance {
float
Compare(const float* a, const float* b, unsigned size) const override;
};
} // namespace algo
} // namespace knowhere

View File

@ -35,17 +35,24 @@
namespace knowhere {
namespace algo {
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, MetricType metric)
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, METRICTYPE metric)
: dimension(dimension), ntotal(n), metric_type(metric) {
switch (metric) {
case METRICTYPE::L2:
distance_ = new DistanceL2;
break;
case METRICTYPE::IP:
distance_ = new DistanceIP;
break;
}
}
NsgIndex::~NsgIndex() {
delete[] ori_data_;
delete[] ids_;
delete distance_;
}
// void NsgIndex::Build(size_t nb, const float *data, const BuildParam &parameters) {
//}
void
NsgIndex::Build_with_ids(size_t nb, const float* data, const int64_t* ids, const BuildParams& parameters) {
TimeRecorder rc("NSG");
@ -126,7 +133,7 @@ NsgIndex::InitNavigationPoint() {
//>> Debug code
/////
// float r1 = calculate(center, ori_data_ + navigation_point * dimension, dimension);
// float r1 = distance_->Compare(center, ori_data_ + navigation_point * dimension, dimension);
// assert(r1 == resset[0].distance);
/////
}
@ -180,7 +187,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector<Neighbor>& resset, std::v
continue;
}
float dist = calculate(ori_data_ + dimension * id, query, dimension);
float dist = distance_->Compare(ori_data_ + dimension * id, query, dimension);
resset[i] = Neighbor(id, dist, false);
///////////// difference from other GetNeighbors ///////////////
@ -205,7 +212,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector<Neighbor>& resset, std::v
continue;
has_calculated_dist[id] = true;
float dist = calculate(query, ori_data_ + dimension * id, dimension);
float dist = distance_->Compare(query, ori_data_ + dimension * id, dimension);
Neighbor nn(id, dist, false);
fullset.push_back(nn);
@ -278,7 +285,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector<Neighbor>& resset, std::v
continue;
}
float dist = calculate(ori_data_ + id * dimension, query, dimension);
float dist = distance_->Compare(ori_data_ + id * dimension, query, dimension);
resset[i] = Neighbor(id, dist, false);
}
std::sort(resset.begin(), resset.end()); // sort by distance
@ -299,7 +306,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector<Neighbor>& resset, std::v
continue;
has_calculated_dist[id] = true;
float dist = calculate(ori_data_ + dimension * id, query, dimension);
float dist = distance_->Compare(ori_data_ + dimension * id, query, dimension);
Neighbor nn(id, dist, false);
fullset.push_back(nn);
@ -371,7 +378,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector<Neighbor>& resset, Graph&
continue;
}
float dist = calculate(ori_data_ + id * dimension, query, dimension);
float dist = distance_->Compare(ori_data_ + id * dimension, query, dimension);
resset[i] = Neighbor(id, dist, false);
}
std::sort(resset.begin(), resset.end()); // sort by distance
@ -399,7 +406,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector<Neighbor>& resset, Graph&
continue;
has_calculated_dist[id] = true;
float dist = calculate(query, ori_data_ + dimension * id, dimension);
float dist = distance_->Compare(query, ori_data_ + dimension * id, dimension);
if (dist >= resset[buffer_size - 1].distance)
continue;
@ -449,7 +456,7 @@ NsgIndex::Link() {
//>> Debug code
/////
// float r1 = calculate(ori_data_ + n * dimension, ori_data_ + temp[0].id * dimension, dimension);
// float r1 = distance_->Compare(ori_data_ + n * dimension, ori_data_ + temp[0].id * dimension, dimension);
// assert(r1 == temp[0].distance);
/////
SyncPrune(n, fullset, flags, cut_graph_dist);
@ -496,7 +503,7 @@ NsgIndex::SyncPrune(size_t n, std::vector<Neighbor>& pool, boost::dynamic_bitset
auto id = knng[n][i];
if (has_calculated[id])
continue;
float dist = calculate(ori_data_ + dimension * n, ori_data_ + dimension * id, dimension);
float dist = distance_->Compare(ori_data_ + dimension * n, ori_data_ + dimension * id, dimension);
pool.emplace_back(Neighbor(id, dist, true));
}
@ -613,7 +620,8 @@ NsgIndex::SelectEdge(unsigned& cursor, std::vector<Neighbor>& sort_pool, std::ve
auto& p = pool[cursor];
bool should_link = true;
for (size_t t = 0; t < result.size(); ++t) {
float dist = calculate(ori_data_ + dimension * result[t].id, ori_data_ + dimension * p.id, dimension);
float dist =
distance_->Compare(ori_data_ + dimension * result[t].id, ori_data_ + dimension * p.id, dimension);
if (dist < p.distance) {
should_link = false;

View File

@ -22,18 +22,16 @@
#include <vector>
#include <boost/dynamic_bitset.hpp>
#include "Distance.h"
#include "Neighbor.h"
#include "knowhere/common/Config.h"
namespace knowhere {
namespace algo {
using node_t = int64_t;
enum class MetricType {
METRIC_INNER_PRODUCT = 0,
METRIC_L2 = 1,
};
struct BuildParams {
size_t search_length;
size_t out_degree;
@ -50,7 +48,8 @@ class NsgIndex {
public:
size_t dimension;
size_t ntotal; // totabl nb of indexed vectors
MetricType metric_type; // L2 | IP
METRICTYPE metric_type; // L2 | IP
Distance* distance_;
float* ori_data_;
int64_t* ids_; // TODO: support different type
@ -69,7 +68,7 @@ class NsgIndex {
size_t out_degree;
public:
explicit NsgIndex(const size_t& dimension, const size_t& n, MetricType metric = MetricType::METRIC_L2);
explicit NsgIndex(const size_t& dimension, const size_t& n, METRICTYPE metric = METRICTYPE::L2);
NsgIndex() = default;

View File

@ -16,7 +16,6 @@
// under the License.
#include <cstring>
#include <fstream>
#include "knowhere/index/vector_index/nsg/NSGHelper.h"
@ -27,9 +26,9 @@ namespace algo {
int
InsertIntoPool(Neighbor* addr, unsigned K, Neighbor nn) {
//>> Fix: Add assert
for (unsigned int i = 0; i < K; ++i) {
assert(addr[i].id != nn.id);
}
// for (unsigned int i = 0; i < K; ++i) {
// assert(addr[i].id != nn.id);
// }
// find the location to insert
int left = 0, right = K - 1;
@ -68,114 +67,5 @@ InsertIntoPool(Neighbor* addr, unsigned K, Neighbor nn) {
return right;
}
// TODO: support L2 / IP
float
calculate(const float* a, const float* b, unsigned size) {
float result = 0;
#ifdef __GNUC__
#ifdef __AVX__
#define AVX_L2SQR(addr1, addr2, dest, tmp1, tmp2) \
tmp1 = _mm256_loadu_ps(addr1); \
tmp2 = _mm256_loadu_ps(addr2); \
tmp1 = _mm256_sub_ps(tmp1, tmp2); \
tmp1 = _mm256_mul_ps(tmp1, tmp1); \
dest = _mm256_add_ps(dest, tmp1);
__m256 sum;
__m256 l0, l1;
__m256 r0, r1;
unsigned D = (size + 7) & ~7U;
unsigned DR = D % 16;
unsigned DD = D - DR;
const float* l = a;
const float* r = b;
const float* e_l = l + DD;
const float* e_r = r + DD;
float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0};
sum = _mm256_loadu_ps(unpack);
if (DR) {
AVX_L2SQR(e_l, e_r, sum, l0, r0);
}
for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) {
AVX_L2SQR(l, r, sum, l0, r0);
AVX_L2SQR(l + 8, r + 8, sum, l1, r1);
}
_mm256_storeu_ps(unpack, sum);
result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7];
#else
#ifdef __SSE2__
#define SSE_L2SQR(addr1, addr2, dest, tmp1, tmp2) \
tmp1 = _mm_load_ps(addr1); \
tmp2 = _mm_load_ps(addr2); \
tmp1 = _mm_sub_ps(tmp1, tmp2); \
tmp1 = _mm_mul_ps(tmp1, tmp1); \
dest = _mm_add_ps(dest, tmp1);
__m128 sum;
__m128 l0, l1, l2, l3;
__m128 r0, r1, r2, r3;
unsigned D = (size + 3) & ~3U;
unsigned DR = D % 16;
unsigned DD = D - DR;
const float* l = a;
const float* r = b;
const float* e_l = l + DD;
const float* e_r = r + DD;
float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0};
sum = _mm_load_ps(unpack);
switch (DR) {
case 12:
SSE_L2SQR(e_l + 8, e_r + 8, sum, l2, r2);
case 8:
SSE_L2SQR(e_l + 4, e_r + 4, sum, l1, r1);
case 4:
SSE_L2SQR(e_l, e_r, sum, l0, r0);
default:
break;
}
for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) {
SSE_L2SQR(l, r, sum, l0, r0);
SSE_L2SQR(l + 4, r + 4, sum, l1, r1);
SSE_L2SQR(l + 8, r + 8, sum, l2, r2);
SSE_L2SQR(l + 12, r + 12, sum, l3, r3);
}
_mm_storeu_ps(unpack, sum);
result += unpack[0] + unpack[1] + unpack[2] + unpack[3];
// nomal distance
#else
float diff0, diff1, diff2, diff3;
const float* last = a + size;
const float* unroll_group = last - 3;
/* Process 4 items with each loop for efficiency. */
while (a < unroll_group) {
diff0 = a[0] - b[0];
diff1 = a[1] - b[1];
diff2 = a[2] - b[2];
diff3 = a[3] - b[3];
result += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
a += 4;
b += 4;
}
/* Process last 0-3 pixels. Not needed for standard vector lengths. */
while (a < last) {
diff0 = *a++ - *b++;
result += diff0 * diff0;
}
#endif
#endif
#endif
return result;
}
} // namespace algo
}; // namespace algo
} // namespace knowhere

View File

@ -17,21 +17,13 @@
#pragma once
#include <x86intrin.h>
#include <iostream>
#include <faiss/AutoTune.h>
#include "NSG.h"
#include "knowhere/common/Config.h"
#include "Neighbor.h"
namespace knowhere {
namespace algo {
extern int
InsertIntoPool(Neighbor* addr, unsigned K, Neighbor nn);
extern float
calculate(const float* a, const float* b, unsigned size);
} // namespace algo
} // namespace knowhere

View File

@ -18,7 +18,6 @@
#pragma once
#include "NSG.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
namespace knowhere {
@ -26,6 +25,7 @@ namespace algo {
extern void
write_index(NsgIndex* index, MemoryIOWriter& writer);
extern NsgIndex*
read_index(MemoryIOReader& reader);

View File

@ -24,6 +24,8 @@
#ifdef MILVUS_GPU_VERSION
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
#endif
#include "knowhere/common/Timer.h"
#include "knowhere/index/vector_index/nsg/NSGIO.h"
#include "unittest/utils.h"
@ -95,20 +97,19 @@ TEST_F(NSGInterfaceTest, basic_test) {
index_->Add(base_dataset, knowhere::Config());
index_->Seal();
});
{
// std::cout << "k = 1" << std::endl;
// new_index->Search(GenQuery(1), Config::object{{"k", 1}});
// new_index->Search(GenQuery(10), Config::object{{"k", 1}});
// new_index->Search(GenQuery(100), Config::object{{"k", 1}});
// new_index->Search(GenQuery(1000), Config::object{{"k", 1}});
// new_index->Search(GenQuery(10000), Config::object{{"k", 1}});
// std::cout << "k = 5" << std::endl;
// new_index->Search(GenQuery(1), Config::object{{"k", 5}});
// new_index->Search(GenQuery(20), Config::object{{"k", 5}});
// new_index->Search(GenQuery(100), Config::object{{"k", 5}});
// new_index->Search(GenQuery(300), Config::object{{"k", 5}});
// new_index->Search(GenQuery(500), Config::object{{"k", 5}});
}
}
TEST_F(NSGInterfaceTest, comparetest) {
knowhere::algo::DistanceL2 distanceL2;
knowhere::algo::DistanceIP distanceIP;
knowhere::TimeRecorder tc("Compare");
for (int i = 0; i < 1000; ++i) {
distanceL2.Compare(xb.data(), xq.data(), 256);
}
tc.RecordSection("L2");
for (int i = 0; i < 1000; ++i) {
distanceIP.Compare(xb.data(), xq.data(), 256);
}
tc.RecordSection("IP");
}