diff --git a/CHANGELOG.md b/CHANGELOG.md index 016fd33a5a..4b86ca0d74 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ Please mark all change in change log and use the ticket from JIRA. - \#513 - Unittest DELETE_BY_RANGE sometimes failed - \#527 - faiss benchmark not compatible with faiss 1.6.0 - \#530 - BuildIndex stop when do build index and search simultaneously +- \#533 - NSG build failed with MetricType Inner Product ## Feature - \#12 - Pure CPU version for Milvus diff --git a/core/src/index/knowhere/CMakeLists.txt b/core/src/index/knowhere/CMakeLists.txt index 285461bdef..a7d3966481 100644 --- a/core/src/index/knowhere/CMakeLists.txt +++ b/core/src/index/knowhere/CMakeLists.txt @@ -38,6 +38,7 @@ set(index_srcs knowhere/index/vector_index/nsg/NSG.cpp knowhere/index/vector_index/nsg/NSGIO.cpp knowhere/index/vector_index/nsg/NSGHelper.cpp + knowhere/index/vector_index/nsg/Distance.cpp knowhere/index/vector_index/IndexIVFSQ.cpp knowhere/index/vector_index/IndexIVFPQ.cpp knowhere/index/vector_index/FaissBaseIndex.cpp diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp index 204819517a..3cf0122233 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp @@ -115,10 +115,6 @@ NSG::Train(const DatasetPtr& dataset, const Config& config) { build_cfg->CheckValid(); // throw exception } - if (build_cfg->metric_type != METRICTYPE::L2) { - KNOWHERE_THROW_MSG("NSG not support this kind of metric type"); - } - // TODO(linxj): dev IndexFactory, support more IndexType #ifdef MILVUS_GPU_VERSION auto preprocess_index = std::make_shared(build_cfg->gpu_id); diff --git a/core/src/index/knowhere/knowhere/index/vector_index/nsg/Distance.cpp b/core/src/index/knowhere/knowhere/index/vector_index/nsg/Distance.cpp new file mode 100644 index 0000000000..8508b65218 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/nsg/Distance.cpp @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "knowhere/index/vector_index/nsg/Distance.h" + +namespace knowhere { +namespace algo { + +float +DistanceL2::Compare(const float* a, const float* b, unsigned size) const { + float result = 0; + +#ifdef __GNUC__ +#ifdef __AVX__ + +#define AVX_L2SQR(addr1, addr2, dest, tmp1, tmp2) \ + tmp1 = _mm256_loadu_ps(addr1); \ + tmp2 = _mm256_loadu_ps(addr2); \ + tmp1 = _mm256_sub_ps(tmp1, tmp2); \ + tmp1 = _mm256_mul_ps(tmp1, tmp1); \ + dest = _mm256_add_ps(dest, tmp1); + + __m256 sum; + __m256 l0, l1; + __m256 r0, r1; + unsigned D = (size + 7) & ~7U; + unsigned DR = D % 16; + unsigned DD = D - DR; + const float* l = a; + const float* r = b; + const float* e_l = l + DD; + const float* e_r = r + DD; + float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; + + sum = _mm256_loadu_ps(unpack); + if (DR) { + AVX_L2SQR(e_l, e_r, sum, l0, r0); + } + + for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { + AVX_L2SQR(l, r, sum, l0, r0); + AVX_L2SQR(l + 8, r + 8, sum, l1, r1); + } + _mm256_storeu_ps(unpack, sum); + result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7]; + +#else +#ifdef __SSE2__ +#define SSE_L2SQR(addr1, addr2, dest, tmp1, tmp2) \ + tmp1 = _mm_load_ps(addr1); \ + tmp2 = _mm_load_ps(addr2); \ + tmp1 = _mm_sub_ps(tmp1, tmp2); \ + tmp1 = _mm_mul_ps(tmp1, tmp1); \ + dest = _mm_add_ps(dest, tmp1); + + __m128 sum; + __m128 l0, l1, l2, l3; + __m128 r0, r1, r2, r3; + unsigned D = (size + 3) & ~3U; + unsigned DR = D % 16; + unsigned DD = D - DR; + const float* l = a; + const float* r = b; + const float* e_l = l + DD; + const float* e_r = r + DD; + float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0}; + + sum = _mm_load_ps(unpack); + switch (DR) { + case 12: + SSE_L2SQR(e_l + 8, e_r + 8, sum, l2, r2); + case 8: + SSE_L2SQR(e_l + 4, e_r + 4, sum, l1, r1); + case 4: + SSE_L2SQR(e_l, e_r, sum, l0, r0); + default: + break; + } + for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { + SSE_L2SQR(l, r, sum, l0, r0); + SSE_L2SQR(l + 4, r + 4, sum, l1, r1); + SSE_L2SQR(l + 8, r + 8, sum, l2, r2); + SSE_L2SQR(l + 12, r + 12, sum, l3, r3); + } + _mm_storeu_ps(unpack, sum); + result += unpack[0] + unpack[1] + unpack[2] + unpack[3]; + +// nomal distance +#else + + float diff0, diff1, diff2, diff3; + const float* last = a + size; + const float* unroll_group = last - 3; + + /* Process 4 items with each loop for efficiency. */ + while (a < unroll_group) { + diff0 = a[0] - b[0]; + diff1 = a[1] - b[1]; + diff2 = a[2] - b[2]; + diff3 = a[3] - b[3]; + result += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3; + a += 4; + b += 4; + } + /* Process last 0-3 pixels. Not needed for standard vector lengths. */ + while (a < last) { + diff0 = *a++ - *b++; + result += diff0 * diff0; + } +#endif +#endif +#endif + + return result; +} + +float +DistanceIP::Compare(const float* a, const float* b, unsigned size) const { + float result = 0; + +#ifdef __GNUC__ +#ifdef __AVX__ +#define AVX_DOT(addr1, addr2, dest, tmp1, tmp2) \ + tmp1 = _mm256_loadu_ps(addr1); \ + tmp2 = _mm256_loadu_ps(addr2); \ + tmp1 = _mm256_mul_ps(tmp1, tmp2); \ + dest = _mm256_add_ps(dest, tmp1); + + __m256 sum; + __m256 l0, l1; + __m256 r0, r1; + unsigned D = (size + 7) & ~7U; + unsigned DR = D % 16; + unsigned DD = D - DR; + const float* l = a; + const float* r = b; + const float* e_l = l + DD; + const float* e_r = r + DD; + float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; + + sum = _mm256_loadu_ps(unpack); + if (DR) { + AVX_DOT(e_l, e_r, sum, l0, r0); + } + + for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { + AVX_DOT(l, r, sum, l0, r0); + AVX_DOT(l + 8, r + 8, sum, l1, r1); + } + _mm256_storeu_ps(unpack, sum); + result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7]; + +#else +#ifdef __SSE2__ +#define SSE_DOT(addr1, addr2, dest, tmp1, tmp2) \ + tmp1 = _mm128_loadu_ps(addr1); \ + tmp2 = _mm128_loadu_ps(addr2); \ + tmp1 = _mm128_mul_ps(tmp1, tmp2); \ + dest = _mm128_add_ps(dest, tmp1); + __m128 sum; + __m128 l0, l1, l2, l3; + __m128 r0, r1, r2, r3; + unsigned D = (size + 3) & ~3U; + unsigned DR = D % 16; + unsigned DD = D - DR; + const float* l = a; + const float* r = b; + const float* e_l = l + DD; + const float* e_r = r + DD; + float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0}; + + sum = _mm_load_ps(unpack); + switch (DR) { + case 12: + SSE_DOT(e_l + 8, e_r + 8, sum, l2, r2); + case 8: + SSE_DOT(e_l + 4, e_r + 4, sum, l1, r1); + case 4: + SSE_DOT(e_l, e_r, sum, l0, r0); + default: + break; + } + for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { + SSE_DOT(l, r, sum, l0, r0); + SSE_DOT(l + 4, r + 4, sum, l1, r1); + SSE_DOT(l + 8, r + 8, sum, l2, r2); + SSE_DOT(l + 12, r + 12, sum, l3, r3); + } + _mm_storeu_ps(unpack, sum); + result += unpack[0] + unpack[1] + unpack[2] + unpack[3]; +#else + + float dot0, dot1, dot2, dot3; + const float* last = a + size; + const float* unroll_group = last - 3; + + /* Process 4 items with each loop for efficiency. */ + while (a < unroll_group) { + dot0 = a[0] * b[0]; + dot1 = a[1] * b[1]; + dot2 = a[2] * b[2]; + dot3 = a[3] * b[3]; + result += dot0 + dot1 + dot2 + dot3; + a += 4; + b += 4; + } + /* Process last 0-3 pixels. Not needed for standard vector lengths. */ + while (a < last) { + result += *a++ * *b++; + } +#endif +#endif +#endif + return result; +} + +//#include +// float +// DistanceL2::Compare(const float* a, const float* b, unsigned size) const { +// return faiss::fvec_L2sqr(a,b,size); +//} +// +// float +// DistanceIP::Compare(const float* a, const float* b, unsigned size) const { +// return faiss::fvec_inner_product(a,b,size); +//} + +} // namespace algo +} // namespace knowhere diff --git a/core/src/index/knowhere/knowhere/index/vector_index/nsg/Distance.h b/core/src/index/knowhere/knowhere/index/vector_index/nsg/Distance.h new file mode 100644 index 0000000000..df24ca8725 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/nsg/Distance.h @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +namespace knowhere { +namespace algo { + +struct Distance { + virtual float + Compare(const float* a, const float* b, unsigned size) const = 0; +}; + +struct DistanceL2 : public Distance { + float + Compare(const float* a, const float* b, unsigned size) const override; +}; + +struct DistanceIP : public Distance { + float + Compare(const float* a, const float* b, unsigned size) const override; +}; + +} // namespace algo +} // namespace knowhere diff --git a/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSG.cpp b/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSG.cpp index b4e00e57b7..e9e65b1191 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSG.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSG.cpp @@ -35,17 +35,24 @@ namespace knowhere { namespace algo { -NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, MetricType metric) +NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, METRICTYPE metric) : dimension(dimension), ntotal(n), metric_type(metric) { + switch (metric) { + case METRICTYPE::L2: + distance_ = new DistanceL2; + break; + case METRICTYPE::IP: + distance_ = new DistanceIP; + break; + } } NsgIndex::~NsgIndex() { delete[] ori_data_; delete[] ids_; + delete distance_; } -// void NsgIndex::Build(size_t nb, const float *data, const BuildParam ¶meters) { -//} void NsgIndex::Build_with_ids(size_t nb, const float* data, const int64_t* ids, const BuildParams& parameters) { TimeRecorder rc("NSG"); @@ -126,7 +133,7 @@ NsgIndex::InitNavigationPoint() { //>> Debug code ///// - // float r1 = calculate(center, ori_data_ + navigation_point * dimension, dimension); + // float r1 = distance_->Compare(center, ori_data_ + navigation_point * dimension, dimension); // assert(r1 == resset[0].distance); ///// } @@ -180,7 +187,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector& resset, std::v continue; } - float dist = calculate(ori_data_ + dimension * id, query, dimension); + float dist = distance_->Compare(ori_data_ + dimension * id, query, dimension); resset[i] = Neighbor(id, dist, false); ///////////// difference from other GetNeighbors /////////////// @@ -205,7 +212,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector& resset, std::v continue; has_calculated_dist[id] = true; - float dist = calculate(query, ori_data_ + dimension * id, dimension); + float dist = distance_->Compare(query, ori_data_ + dimension * id, dimension); Neighbor nn(id, dist, false); fullset.push_back(nn); @@ -278,7 +285,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector& resset, std::v continue; } - float dist = calculate(ori_data_ + id * dimension, query, dimension); + float dist = distance_->Compare(ori_data_ + id * dimension, query, dimension); resset[i] = Neighbor(id, dist, false); } std::sort(resset.begin(), resset.end()); // sort by distance @@ -299,7 +306,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector& resset, std::v continue; has_calculated_dist[id] = true; - float dist = calculate(ori_data_ + dimension * id, query, dimension); + float dist = distance_->Compare(ori_data_ + dimension * id, query, dimension); Neighbor nn(id, dist, false); fullset.push_back(nn); @@ -371,7 +378,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector& resset, Graph& continue; } - float dist = calculate(ori_data_ + id * dimension, query, dimension); + float dist = distance_->Compare(ori_data_ + id * dimension, query, dimension); resset[i] = Neighbor(id, dist, false); } std::sort(resset.begin(), resset.end()); // sort by distance @@ -399,7 +406,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector& resset, Graph& continue; has_calculated_dist[id] = true; - float dist = calculate(query, ori_data_ + dimension * id, dimension); + float dist = distance_->Compare(query, ori_data_ + dimension * id, dimension); if (dist >= resset[buffer_size - 1].distance) continue; @@ -449,7 +456,7 @@ NsgIndex::Link() { //>> Debug code ///// - // float r1 = calculate(ori_data_ + n * dimension, ori_data_ + temp[0].id * dimension, dimension); + // float r1 = distance_->Compare(ori_data_ + n * dimension, ori_data_ + temp[0].id * dimension, dimension); // assert(r1 == temp[0].distance); ///// SyncPrune(n, fullset, flags, cut_graph_dist); @@ -496,7 +503,7 @@ NsgIndex::SyncPrune(size_t n, std::vector& pool, boost::dynamic_bitset auto id = knng[n][i]; if (has_calculated[id]) continue; - float dist = calculate(ori_data_ + dimension * n, ori_data_ + dimension * id, dimension); + float dist = distance_->Compare(ori_data_ + dimension * n, ori_data_ + dimension * id, dimension); pool.emplace_back(Neighbor(id, dist, true)); } @@ -613,7 +620,8 @@ NsgIndex::SelectEdge(unsigned& cursor, std::vector& sort_pool, std::ve auto& p = pool[cursor]; bool should_link = true; for (size_t t = 0; t < result.size(); ++t) { - float dist = calculate(ori_data_ + dimension * result[t].id, ori_data_ + dimension * p.id, dimension); + float dist = + distance_->Compare(ori_data_ + dimension * result[t].id, ori_data_ + dimension * p.id, dimension); if (dist < p.distance) { should_link = false; diff --git a/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSG.h b/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSG.h index 160c076e45..5dd128610f 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSG.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSG.h @@ -22,18 +22,16 @@ #include #include + +#include "Distance.h" #include "Neighbor.h" +#include "knowhere/common/Config.h" namespace knowhere { namespace algo { using node_t = int64_t; -enum class MetricType { - METRIC_INNER_PRODUCT = 0, - METRIC_L2 = 1, -}; - struct BuildParams { size_t search_length; size_t out_degree; @@ -50,7 +48,8 @@ class NsgIndex { public: size_t dimension; size_t ntotal; // totabl nb of indexed vectors - MetricType metric_type; // L2 | IP + METRICTYPE metric_type; // L2 | IP + Distance* distance_; float* ori_data_; int64_t* ids_; // TODO: support different type @@ -69,7 +68,7 @@ class NsgIndex { size_t out_degree; public: - explicit NsgIndex(const size_t& dimension, const size_t& n, MetricType metric = MetricType::METRIC_L2); + explicit NsgIndex(const size_t& dimension, const size_t& n, METRICTYPE metric = METRICTYPE::L2); NsgIndex() = default; diff --git a/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSGHelper.cpp b/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSGHelper.cpp index 05e8d18787..dd250570b8 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSGHelper.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSGHelper.cpp @@ -16,7 +16,6 @@ // under the License. #include -#include #include "knowhere/index/vector_index/nsg/NSGHelper.h" @@ -27,9 +26,9 @@ namespace algo { int InsertIntoPool(Neighbor* addr, unsigned K, Neighbor nn) { //>> Fix: Add assert - for (unsigned int i = 0; i < K; ++i) { - assert(addr[i].id != nn.id); - } + // for (unsigned int i = 0; i < K; ++i) { + // assert(addr[i].id != nn.id); + // } // find the location to insert int left = 0, right = K - 1; @@ -68,114 +67,5 @@ InsertIntoPool(Neighbor* addr, unsigned K, Neighbor nn) { return right; } -// TODO: support L2 / IP -float -calculate(const float* a, const float* b, unsigned size) { - float result = 0; - -#ifdef __GNUC__ -#ifdef __AVX__ - -#define AVX_L2SQR(addr1, addr2, dest, tmp1, tmp2) \ - tmp1 = _mm256_loadu_ps(addr1); \ - tmp2 = _mm256_loadu_ps(addr2); \ - tmp1 = _mm256_sub_ps(tmp1, tmp2); \ - tmp1 = _mm256_mul_ps(tmp1, tmp1); \ - dest = _mm256_add_ps(dest, tmp1); - - __m256 sum; - __m256 l0, l1; - __m256 r0, r1; - unsigned D = (size + 7) & ~7U; - unsigned DR = D % 16; - unsigned DD = D - DR; - const float* l = a; - const float* r = b; - const float* e_l = l + DD; - const float* e_r = r + DD; - float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; - - sum = _mm256_loadu_ps(unpack); - if (DR) { - AVX_L2SQR(e_l, e_r, sum, l0, r0); - } - - for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { - AVX_L2SQR(l, r, sum, l0, r0); - AVX_L2SQR(l + 8, r + 8, sum, l1, r1); - } - _mm256_storeu_ps(unpack, sum); - result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7]; - -#else -#ifdef __SSE2__ -#define SSE_L2SQR(addr1, addr2, dest, tmp1, tmp2) \ - tmp1 = _mm_load_ps(addr1); \ - tmp2 = _mm_load_ps(addr2); \ - tmp1 = _mm_sub_ps(tmp1, tmp2); \ - tmp1 = _mm_mul_ps(tmp1, tmp1); \ - dest = _mm_add_ps(dest, tmp1); - - __m128 sum; - __m128 l0, l1, l2, l3; - __m128 r0, r1, r2, r3; - unsigned D = (size + 3) & ~3U; - unsigned DR = D % 16; - unsigned DD = D - DR; - const float* l = a; - const float* r = b; - const float* e_l = l + DD; - const float* e_r = r + DD; - float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0}; - - sum = _mm_load_ps(unpack); - switch (DR) { - case 12: - SSE_L2SQR(e_l + 8, e_r + 8, sum, l2, r2); - case 8: - SSE_L2SQR(e_l + 4, e_r + 4, sum, l1, r1); - case 4: - SSE_L2SQR(e_l, e_r, sum, l0, r0); - default: - break; - } - for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { - SSE_L2SQR(l, r, sum, l0, r0); - SSE_L2SQR(l + 4, r + 4, sum, l1, r1); - SSE_L2SQR(l + 8, r + 8, sum, l2, r2); - SSE_L2SQR(l + 12, r + 12, sum, l3, r3); - } - _mm_storeu_ps(unpack, sum); - result += unpack[0] + unpack[1] + unpack[2] + unpack[3]; - -// nomal distance -#else - - float diff0, diff1, diff2, diff3; - const float* last = a + size; - const float* unroll_group = last - 3; - - /* Process 4 items with each loop for efficiency. */ - while (a < unroll_group) { - diff0 = a[0] - b[0]; - diff1 = a[1] - b[1]; - diff2 = a[2] - b[2]; - diff3 = a[3] - b[3]; - result += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3; - a += 4; - b += 4; - } - /* Process last 0-3 pixels. Not needed for standard vector lengths. */ - while (a < last) { - diff0 = *a++ - *b++; - result += diff0 * diff0; - } -#endif -#endif -#endif - - return result; -} - -} // namespace algo +}; // namespace algo } // namespace knowhere diff --git a/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSGHelper.h b/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSGHelper.h index 5007cf019c..a909dd84e7 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSGHelper.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSGHelper.h @@ -17,21 +17,13 @@ #pragma once -#include -#include - -#include - -#include "NSG.h" -#include "knowhere/common/Config.h" +#include "Neighbor.h" namespace knowhere { namespace algo { extern int InsertIntoPool(Neighbor* addr, unsigned K, Neighbor nn); -extern float -calculate(const float* a, const float* b, unsigned size); } // namespace algo } // namespace knowhere diff --git a/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSGIO.h b/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSGIO.h index 12913b69df..9f2a42c4ad 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSGIO.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSGIO.h @@ -18,7 +18,6 @@ #pragma once #include "NSG.h" -#include "knowhere/index/vector_index/IndexIVF.h" #include "knowhere/index/vector_index/helpers/FaissIO.h" namespace knowhere { @@ -26,6 +25,7 @@ namespace algo { extern void write_index(NsgIndex* index, MemoryIOWriter& writer); + extern NsgIndex* read_index(MemoryIOReader& reader); diff --git a/core/src/index/unittest/test_nsg/test_nsg.cpp b/core/src/index/unittest/test_nsg/test_nsg.cpp index 47c014e691..4722c7e8f6 100644 --- a/core/src/index/unittest/test_nsg/test_nsg.cpp +++ b/core/src/index/unittest/test_nsg/test_nsg.cpp @@ -24,6 +24,8 @@ #ifdef MILVUS_GPU_VERSION #include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" #endif + +#include "knowhere/common/Timer.h" #include "knowhere/index/vector_index/nsg/NSGIO.h" #include "unittest/utils.h" @@ -95,20 +97,19 @@ TEST_F(NSGInterfaceTest, basic_test) { index_->Add(base_dataset, knowhere::Config()); index_->Seal(); }); - - { - // std::cout << "k = 1" << std::endl; - // new_index->Search(GenQuery(1), Config::object{{"k", 1}}); - // new_index->Search(GenQuery(10), Config::object{{"k", 1}}); - // new_index->Search(GenQuery(100), Config::object{{"k", 1}}); - // new_index->Search(GenQuery(1000), Config::object{{"k", 1}}); - // new_index->Search(GenQuery(10000), Config::object{{"k", 1}}); - - // std::cout << "k = 5" << std::endl; - // new_index->Search(GenQuery(1), Config::object{{"k", 5}}); - // new_index->Search(GenQuery(20), Config::object{{"k", 5}}); - // new_index->Search(GenQuery(100), Config::object{{"k", 5}}); - // new_index->Search(GenQuery(300), Config::object{{"k", 5}}); - // new_index->Search(GenQuery(500), Config::object{{"k", 5}}); - } +} + +TEST_F(NSGInterfaceTest, comparetest) { + knowhere::algo::DistanceL2 distanceL2; + knowhere::algo::DistanceIP distanceIP; + + knowhere::TimeRecorder tc("Compare"); + for (int i = 0; i < 1000; ++i) { + distanceL2.Compare(xb.data(), xq.data(), 256); + } + tc.RecordSection("L2"); + for (int i = 0; i < 1000; ++i) { + distanceIP.Compare(xb.data(), xq.data(), 256); + } + tc.RecordSection("IP"); }