add furthest neighbor search to Milvus. (#27452)

Signed-off-by: jeff <1093656867@qq.com>
summer2023
JiefengWang 2023-10-31 18:08:15 +08:00 committed by GitHub
parent ab6dbf7659
commit 85c9d37090
12 changed files with 2124 additions and 0 deletions

internal/core/src/CMakeLists.txt

@ -32,6 +32,7 @@ add_subdirectory( index )
add_subdirectory( query )
add_subdirectory( segcore )
add_subdirectory( indexbuilder )
add_subdirectory( fns )
if(USE_DYNAMIC_SIMD)
add_subdirectory( simd )
endif()

internal/core/src/fns/CMakeLists.txt

@ -0,0 +1,32 @@
# Licensed to the LF AI & Data foundation under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
find_package(OpenMP REQUIRED)
set(FURTHEST_NEIGHBOR_SEARCH_FILES
kmeans.cpp
kgraph.cpp
)
add_library(milvus_fnsearcher SHARED ${FURTHEST_NEIGHBOR_SEARCH_FILES})
if(USE_DYNAMIC_SIMD)
target_link_libraries(milvus_fnsearcher milvus_index milvus_simd)
else()
target_link_libraries(milvus_fnsearcher milvus_index)
endif()
install(TARGETS milvus_fnsearcher DESTINATION "${CMAKE_INSTALL_LIBDIR}")

internal/core/src/fns/basicDistance.h

@ -0,0 +1,68 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef INTERNAL_CORE_SRC_FNS_BASICDISTANCE_H_
#define INTERNAL_CORE_SRC_FNS_BASICDISTANCE_H_
namespace milvus::basicDistance {
template <typename Dat_Type>
float
basicL2(const Dat_Type*, const Dat_Type*, const unsigned int);
template <typename Dat_Type>
float
basicL2(const Dat_Type* vec1, const Dat_Type* vec2, const unsigned int dim) {
float res = 0.0;
const Dat_Type* v1 = vec1;
const Dat_Type* v2 = vec2;
for (unsigned int i = 0; i < dim; ++i) {
float diff = (float)(*v1 - *v2);
res += (float)(diff * diff);
v1++;
v2++;
}
return res;
}
/// ---- inner product
template <typename Dat_Type>
float
basicInnerProduct(const Dat_Type* vec1,
const Dat_Type* vec2,
const unsigned int dim);
template <typename Dat_Type>
float
basicInnerProduct(const Dat_Type* vec1,
const Dat_Type* vec2,
const unsigned int dim) {
float res = 0.0;
const Dat_Type* v1 = vec1;
const Dat_Type* v2 = vec2;
for (unsigned int i = 0; i < dim; ++i) {
float product = (float)(*v1 * *v2);
res += product;
v1++;
v2++;
}
// Note: returns 1 - <vec1, vec2>, i.e. an inner-product distance rather than the raw inner product.
return 1.0 - res;
}
} // namespace milvus::basicDistance
#endif // INTERNAL_CORE_SRC_FNS_BASICDISTANCE_H_
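For orientation (not part of this patch): a minimal sketch of how these helpers behave. Only basicL2 and basicInnerProduct come from the header above; the "fns/..." include path follows the convention of the unit test at the end of this change and is otherwise an assumption.

#include <cassert>
#include <vector>
#include "fns/basicDistance.h"  // assumed include path

int main() {
    const unsigned dim = 4;
    std::vector<float> a{1.f, 0.f, 0.f, 0.f};
    std::vector<float> b{0.f, 1.f, 0.f, 0.f};
    // Squared L2 distance: (1 - 0)^2 + (0 - 1)^2 = 2.
    assert(milvus::basicDistance::basicL2(a.data(), b.data(), dim) == 2.0f);
    // basicInnerProduct returns 1 - <a, b>, i.e. an inner-product distance,
    // so orthogonal vectors yield exactly 1.
    assert(milvus::basicDistance::basicInnerProduct(a.data(), b.data(), dim) == 1.0f);
    return 0;
}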

internal/core/src/fns/fns.hpp

@ -0,0 +1,296 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef INTERNAL_CORE_SRC_FNS_FNS_HPP_
#define INTERNAL_CORE_SRC_FNS_FNS_HPP_
#include <memory>
#include <iostream>
#include <unordered_set>
#include <algorithm>
#include <queue>
#include <utility>
#include <vector>
#include "kmeans.h"
#include "kgraph.h"
#include "basicDistance.h"
#include "visitedList.hpp"
#include "common/Types.h"
namespace milvus::fns {
using PriorQ = std::priority_queue<std::pair<distance_t, idx_t>,
std::vector<std::pair<distance_t, idx_t>>>;
using NNDIndexParam = milvus::kgraph::IndexParams;
using graph_t = std::vector<std::vector<idx_t>>;
struct BuildParam {
size_t kms_cluster_num = 1024;
size_t filter_pos = 4;
size_t filter_density = 0;
size_t filter_ctrl_size = 64;
size_t filter_max_edges = 2048;
};
template <typename T>
class FNS {
private:
const T* data_{nullptr};
float* kmeans_centroids_{nullptr};
int* kmeans_labels_{nullptr};
graph_t kmeans_label_graph_;
std::vector<idx_t> density_;
graph_t knn_graph_;
const size_t data_size_{0};
const size_t data_dim_{0};
BuildParam build_param_;
NNDIndexParam nnd_para_;
size_t search_k_{0};
VisitedListPool* visited_list_pool_{nullptr};
public:
FNS(const T* base, size_t data_size, size_t data_dim)
: data_(base), data_size_(data_size), data_dim_(data_dim) {
visited_list_pool_ = new VisitedListPool(1, data_size_);
}
~FNS() {
delete visited_list_pool_;
if (kmeans_centroids_)
delete[] kmeans_centroids_;
if (kmeans_labels_)
delete[] kmeans_labels_;
}
void
setUpBuildPara(const BuildParam& bp) {
build_param_ = bp;
}
void
setUpNNDPara(const NNDIndexParam& nnd_para) {
nnd_para_ = nnd_para;
}
int
build() {
auto kms_return_val = runKmeans();
auto kms_nnd_val = runNNDescent();
if (kms_return_val == -1 || kms_nnd_val == -1)
return -1;
updateDensity();
graph_t(build_param_.kms_cluster_num).swap(kmeans_label_graph_);
for (size_t idx = 0; idx < data_size_; ++idx) {
auto lb = kmeans_labels_[idx];
kmeans_label_graph_[lb].emplace_back(idx);
}
return 0;
}
void
setSearchK(size_t k) {
search_k_ = k;
}
int
runNNDescent() {
// load base data;
Matrix<T> base_data;
base_data.load(data_, data_size_, data_dim_);
MatrixOracle<T, metric::l2sqr> oracle(base_data);
std::unique_ptr<milvus::kgraph::KGraphConstructor> kg(
new milvus::kgraph::KGraphConstructor(oracle, nnd_para_));
int nnd_return_val = kg->build_index();
if (nnd_return_val != 0) {
return nnd_return_val;
}
auto knn_pool = kg->nhoods;
std::vector<std::vector<idx_t>>(data_size_).swap(knn_graph_);
for (size_t i = 0; i < data_size_; ++i) {
auto& kg_nbhood = knn_graph_[i];
auto const& pool = knn_pool[i].pool;
for (auto& elem : pool) {
kg_nbhood.emplace_back(elem.id);
}
}
return 0;
}
float
runKmeans() {
if (kmeans_centroids_) {
delete[] kmeans_centroids_;
kmeans_centroids_ = nullptr;
}
kmeans_centroids_ = new float[build_param_.kms_cluster_num * data_dim_];
if (kmeans_labels_) {
delete[] kmeans_labels_;
kmeans_labels_ = nullptr;
}
kmeans_labels_ = new int[data_size_];
std::unique_ptr<puck::Kmeans> kms(new puck::Kmeans(true));
auto kms_return_val = kms->kmeans(data_dim_,
data_size_,
build_param_.kms_cluster_num,
data_,
kmeans_centroids_,
nullptr,
kmeans_labels_);
return kms_return_val;
}
inline void
updateDensity() {
vector<idx_t>(data_size_, 0).swap(density_);
for (size_t i = 0; i < data_size_; ++i) {
for (size_t j = 0; j < build_param_.filter_pos &&
j < (size_t)knn_graph_[0].size();
++j) {
auto nb = knn_graph_[i][j];
++density_[nb];
}
}
}
inline PriorQ
searchFNS(T* query) {
PriorQ top_results;
if (search_k_ <= 0 || search_k_ > data_size_ || query == nullptr) {
return top_results;
}
VisitedList* vl = visited_list_pool_->getFreeVisitedList();
vl_type* visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;
std::vector<std::pair<distance_t, idx_t>> pairs_center(
build_param_.kms_cluster_num);
for (size_t i = 0; i < build_param_.kms_cluster_num; ++i) {
auto label = i;
auto dst = milvus::basicDistance::basicL2(
query, &kmeans_centroids_[i * data_dim_], data_dim_);
pairs_center[i].first = -dst;
pairs_center[i].second = label;
}
sort(pairs_center.begin(), pairs_center.end());
vector<idx_t> low_density_points;
for (auto& p : pairs_center) {
auto lb = p.second;
auto& points = kmeans_label_graph_[lb];
for (auto p : points) {
if (density_[p] <= build_param_.filter_density) {
low_density_points.emplace_back(p);
if (low_density_points.size() >
build_param_.filter_max_edges) {
break;
}
}
}
if (low_density_points.size() > build_param_.filter_max_edges) {
break;
}
}
std::vector<std::pair<distance_t, idx_t>> pairs(
low_density_points.size());
for (auto i = 0; i < low_density_points.size(); ++i) {
auto nb = low_density_points[i];
auto dst = milvus::basicDistance::basicL2(
query, data_ + nb * data_dim_, data_dim_);
pairs[i].first = -dst;
pairs[i].second = nb;
}
sort(pairs.begin(), pairs.end());
for (size_t i = 0; i < search_k_ && i < pairs.size(); i++) {
auto& p = pairs[i];
top_results.push(p);
visited_array[p.second] = visited_array_tag;
}
for (size_t i = 0; i < build_param_.filter_ctrl_size && i < pairs.size(); i++) {
auto nb = pairs[i].second;
for (size_t j = 0; j < knn_graph_[0].size(); ++j) {
size_t nn = (size_t)knn_graph_[nb][j];
if (visited_array[nn] == visited_array_tag) {
continue;
}
visited_array[nn] = visited_array_tag;
auto dst = milvus::basicDistance::basicL2(
query, data_ + nn * data_dim_, data_dim_);
if (dst > -top_results.top().first) {
top_results.emplace(-dst, nn);
}
}
while (top_results.size() > search_k_) {
top_results.pop();
}
}
visited_list_pool_->releaseVisitedList(vl);
return top_results;
}
double
evaluateRatio(std::vector<std::vector<unsigned>>& gt,
graph_t& full_fns,
size_t query_size,
T* query,
size_t checkK = 100) {
double avg_ratio = 0;
for (size_t index = 0; index < query_size; ++index) {
auto q = query + index * data_dim_;
auto& gt_list = gt[index];
auto& nn_list = full_fns[index];
double overall_ratio = 0;
checkK = std::min(checkK, gt_list.size());
for (int iter = 0; iter < checkK; ++iter) {
auto idx = nn_list[iter];
auto gt_idx = gt_list[iter];
overall_ratio += milvus::basicDistance::basicL2(
q, data_ + gt_idx * data_dim_, data_dim_) /
milvus::basicDistance::basicL2(
q, data_ + idx * data_dim_, data_dim_);
}
avg_ratio += overall_ratio / checkK;
}
return avg_ratio / query_size;
}
float
evaluateRecall(std::vector<std::vector<unsigned>>& gt,
graph_t& full_fns,
size_t checkK = 100) {
size_t hit = 0;
size_t checkSz = full_fns.size();
for (size_t iter = 0; iter < checkSz; ++iter) {
auto& fns = full_fns[iter];
auto& gt_list = gt[iter];
for (auto i = 0; i < checkK; ++i) {
auto fn = fns[i];
for (auto j = 0; j < checkK; ++j) {
auto nb = gt_list[j];
if (fn == nb) {
++hit;
break;
}
}
}
}
return 1.0 * hit / (checkSz * checkK);
}
};
} // namespace milvus::fns
#endif // INTERNAL_CORE_SRC_FNS_FNS_HPP_
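For readers of the class above, a minimal end-to-end sketch (not part of this patch) of the intended call sequence, mirroring the unit test further down. The dataset size, cluster count and search depth are illustrative assumptions.

#include <random>
#include <vector>
#include "fns/fns.hpp"  // include path as used by the unit test

void example_fns() {
    const size_t n = 1000, dim = 16;
    std::vector<float> base(n * dim);
    std::mt19937 gen(42);
    std::uniform_real_distribution<float> dis(0.f, 1.f);
    for (auto& v : base) v = dis(gen);

    milvus::fns::FNS<float> fns(base.data(), n, dim);

    milvus::fns::BuildParam bp;
    bp.kms_cluster_num = 32;         // must stay below the number of points
    fns.setUpBuildPara(bp);
    milvus::fns::NNDIndexParam nnd;  // NN-Descent defaults from kgraph.h
    fns.setUpNNDPara(nnd);
    if (fns.build() != 0) return;    // k-means or NN-Descent failed

    fns.setSearchK(10);
    auto heap = fns.searchFNS(base.data());  // query with the first base vector
    while (!heap.empty()) {
        auto [neg_dist, id] = heap.top();    // entries are (-distance, id)
        (void)neg_dist;
        (void)id;
        heap.pop();
    }
}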

internal/core/src/fns/kgraph.cpp Executable file

@ -0,0 +1,242 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "kgraph.h"
namespace milvus::kgraph {
KGraphConstructor::KGraphConstructor(IndexOracle const& o, IndexParams& p)
: oracle(o), params(p), nhoods(o.size()), n_comps(0) {
}
KGraphConstructor::~KGraphConstructor() {
}
void
KGraphConstructor::init() {
unsigned N = oracle.size();
unsigned seed = params.seed;
std::mt19937 rng(seed);
for (auto& nhood : nhoods) {
nhood.nn_new.resize(params.S * 2);
nhood.pool.resize(params.L + 1);
nhood.radius = std::numeric_limits<float>::max();
}
#pragma omp parallel
{
#ifdef _OPENMP
std::mt19937 rng(seed ^ omp_get_thread_num());
#else
std::mt19937 rng(seed);
#endif
std::vector<unsigned> random(params.S + 1);
#pragma omp for
for (unsigned n = 0; n < N; ++n) {
auto& nhood = nhoods[n];
Neighbors& pool = nhood.pool;
GenRandom(rng, &nhood.nn_new[0], nhood.nn_new.size(), N);
GenRandom(rng, &random[0], random.size(), N);
nhood.L = params.S;
nhood.M = params.S;
unsigned i = 0;
for (unsigned l = 0; l < nhood.L; ++l) {
if (random[i] == n)
++i;
auto& nn = nhood.pool[l];
nn.id = random[i++];
nn.dist = oracle(nn.id, n);
nn.flag = true;
}
sort(pool.begin(), pool.begin() + nhood.L);  // only the first L entries of the pool are initialized
}
}
}
void
KGraphConstructor::join() {
size_t cc = 0;
#pragma omp parallel for default(shared) schedule(dynamic, 100) reduction(+:cc)
for (unsigned n = 0; n < oracle.size(); ++n) {
size_t uu = 0;
nhoods[n].found = false;
nhoods[n].join([&](unsigned i, unsigned j) {
float dist = oracle(i, j);
++cc;
unsigned r;
r = nhoods[i].parallel_try_insert(j, dist);
if (r < params.K)
++uu;
r = nhoods[j].parallel_try_insert(i, dist);
if (r < params.K)
++uu;
});
nhoods[n].found = uu > 0;
}
n_comps += cc;
}
void
KGraphConstructor::update() {
unsigned N = oracle.size();
unsigned seed = params.seed;
std::mt19937 rng(seed);
for (auto& nhood : nhoods) {
nhood.nn_new.clear();
nhood.nn_old.clear();
nhood.rnn_new.clear();
nhood.rnn_old.clear();
nhood.radius = nhood.pool.back().dist;
}
//!!! compute radius2
#pragma omp parallel for
for (unsigned n = 0; n < N; ++n) {
auto& nhood = nhoods[n];
if (nhood.found) {
unsigned maxl = std::min(nhood.M + params.S, nhood.L);
unsigned c = 0;
unsigned l = 0;
while ((l < maxl) && (c < params.S)) {
if (nhood.pool[l].flag)
++c;
++l;
}
nhood.M = l;
}
BOOST_VERIFY(nhood.M > 0);
nhood.radiusM = nhood.pool[nhood.M - 1].dist;
}
#pragma omp parallel for
for (unsigned n = 0; n < N; ++n) {
auto& nhood = nhoods[n];
auto& nn_new = nhood.nn_new;
auto& nn_old = nhood.nn_old;
for (unsigned l = 0; l < nhood.M; ++l) {
auto& nn = nhood.pool[l];
auto& nhood_o =
nhoods[nn.id]; // nhood on the other side of the edge
if (nn.flag) {
nn_new.push_back(nn.id);
if (nn.dist > nhood_o.radiusM) {
LockGuard guard(nhood_o.lock);
nhood_o.rnn_new.push_back(n);
}
nn.flag = false;
} else {
nn_old.push_back(nn.id);
if (nn.dist > nhood_o.radiusM) {
LockGuard guard(nhood_o.lock);
nhood_o.rnn_old.push_back(n);
}
}
}
}
for (unsigned i = 0; i < N; ++i) {
auto& nn_new = nhoods[i].nn_new;
auto& nn_old = nhoods[i].nn_old;
auto& rnn_new = nhoods[i].rnn_new;
auto& rnn_old = nhoods[i].rnn_old;
if (params.R && (rnn_new.size() > params.R)) {
shuffle(rnn_new.begin(), rnn_new.end(), rng);
rnn_new.resize(params.R);
}
nn_new.insert(nn_new.end(), rnn_new.begin(), rnn_new.end());
if (params.R && (rnn_old.size() > params.R)) {
shuffle(rnn_old.begin(), rnn_old.end(), rng);
rnn_old.resize(params.R);
}
nn_old.insert(nn_old.end(), rnn_old.begin(), rnn_old.end());
}
}
int
KGraphConstructor::build_index() {
if (params.S > params.L) {
return -1;
}
auto start_timer = std::chrono::steady_clock::now();
unsigned N = oracle.size();
init();
float total = N * float(N - 1) / 2;
for (unsigned it = 0; (it < params.iterations); ++it) {
printf("kgraph::iter : [%d / %d]\n", it + 1, params.iterations);
join();
update();
}
auto end_timer = std::chrono::steady_clock::now();
auto duration = 1.0 *
std::chrono::duration_cast<std::chrono::milliseconds>(
end_timer - start_timer)
.count() /
1000.0;
{
// auto mem = getMemoryUsage();
std::cout << "kgraph's duration: " << duration << " seconds\n";
}
return 0;
}
template <typename RNG>
void
KGraphConstructor::GenRandom(RNG& rng,
unsigned* addr,
unsigned size,
unsigned N) {
if (N == size) {
for (unsigned i = 0; i < size; ++i) {
addr[i] = i;
}
return;
}
for (unsigned i = 0; i < size; ++i) {
addr[i] = rng() % (N - size);
}
std::sort(addr, addr + size);
for (unsigned i = 1; i < size; ++i) {
if (addr[i] <= addr[i - 1]) {
addr[i] = addr[i - 1] + 1;
}
}
unsigned off = rng() % N;
for (unsigned i = 0; i < size; ++i) {
addr[i] = (addr[i] + off) % N;
}
}
float
KGraphConstructor::evaluate(vector<vector<unsigned int>>& gt) {
auto size_n = gt.size();
auto check_k = 10;
unsigned int hit = 0;
for (unsigned int i = 0; i < size_n; ++i) {
auto& pool = nhoods[i].pool;
for (unsigned int j = 0; j < check_k; ++j) {
auto res_idx = pool[j].id;
for (unsigned int k = 0; k < check_k; ++k) {
auto gt_idx = gt[i][k];
if (res_idx == gt_idx) {
hit += 1;
break;
}
}
}
}
return (float)1.0 * hit / (size_n * check_k);
}
} // namespace milvus::kgraph

internal/core/src/fns/kgraph.h Executable file

@ -0,0 +1,222 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef INTERNAL_CORE_SRC_FNS_KGRAPH_H_
#define INTERNAL_CORE_SRC_FNS_KGRAPH_H_
#include "kgraph_data.h"
#include <string>
#include <vector>
#include <omp.h>
#include <unordered_set>
#include <iostream>
#include <fstream>
#include <random>
#include <algorithm>
#include <chrono>
#include <cstdio>
#include <limits>
#include "boost/smart_ptr/detail/spinlock.hpp"
#include <mutex>
// using namespace kgraph;
namespace milvus::kgraph {
using Neighbors = std::vector<Neighbor>;
using graph = std::vector<std::vector<Neighbor>>;
typedef boost::detail::spinlock Lock;
typedef std::lock_guard<Lock> LockGuard;
struct IndexParams {
unsigned iterations;
unsigned L;
unsigned K;
unsigned S;
unsigned R;
unsigned controls;
unsigned seed;
float delta;
float recall;
unsigned prune;
int reverse;
/// Construct with default values.
IndexParams()
: iterations(10),
L(100),
K(100),
S(24),
R(24),
controls(0),
seed(0),
delta(0),
recall(1),
prune(0),
reverse(0) {
}
};
struct Nhood { // neighborhood
Lock lock;
float radius; // distance of interesting range
float radiusM;
Neighbors pool;
unsigned L; // # valid items in the pool, L + 1 <= pool.size()
unsigned M; // we only join items in pool[0..M)
bool found; // helped found new NN in this round
std::vector<unsigned> nn_old;
std::vector<unsigned> nn_new;
std::vector<unsigned> rnn_old;
std::vector<unsigned> rnn_new;
unsigned
UpdateKnnListHelper(Neighbor* addr,
unsigned K,
const Neighbor& nn) {
// find the location to insert
unsigned j;
unsigned i = K;
while (i > 0) {
j = i - 1;
if (addr[j].dist <= nn.dist)
break;
i = j;
}
// check for equal ID
unsigned l = i;
while (l > 0) {
j = l - 1;
if (addr[j].dist < nn.dist)
break;
if (addr[j].id == nn.id)
return K + 1;
l = j;
}
// i <= K-1
j = K;
while (j > i) {
addr[j] = addr[j - 1];
--j;
}
addr[i] = nn;
return i;
}
unsigned
UpdateKnnList(Neighbor* addr, unsigned K, const Neighbor& nn) {
return UpdateKnnListHelper(addr, K, nn);
}
unsigned
parallel_try_insert(unsigned id, float dist) {
if (dist > radius)
return pool.size();
LockGuard guard(lock);
unsigned l = UpdateKnnList(&pool[0], L, Neighbor(id, dist, true));
if (l <= L) {
if (L + 1 < pool.size()) {
++L;
} else {
radius = pool[L - 1].dist;
}
}
return l;
}
template <typename C>
void
join(C callback) const {
for (unsigned const i : nn_new) {
for (unsigned const j : nn_new) {
if (i < j) {
callback(i, j);
}
}
for (unsigned j : nn_old) {
callback(i, j);
}
}
}
};
class KGraphConstructor {
public:
std::vector<Nhood> nhoods;
private:
IndexOracle const& oracle;
IndexParams params;
size_t n_comps;
private:
template <typename RNG>
void
GenRandom(RNG& rng, unsigned* addr, unsigned size, unsigned N);
void
init();
void
join();
void
update();
public:
int
build_index();
void inner_save(std::string);
float
evaluate(vector<vector<unsigned int>>&);
KGraphConstructor(IndexOracle const& o, IndexParams&);
~KGraphConstructor();
public:
inline int
parseLine(char* line) {
// This assumes that a digit will be found and the line ends in " Kb".
int i = strlen(line);
const char* p = line;
while (*p < '0' || *p > '9') p++;
line[i - 3] = '\0';
i = atoi(p);
return i;
}
double
getMemoryUsage() {
FILE* file = fopen("/proc/self/status", "r");
int highwater_mark = -1;
int current_memory = -1;
char line[128];
while (fgets(line, 128, file) != NULL) {
if (strncmp(line, "VmHWM:", 6) == 0) {
highwater_mark = parseLine(line);
}
if (strncmp(line, "VmRSS:", 6) == 0) {
current_memory = parseLine(line);
}
if (highwater_mark > 0 && current_memory > 0) {
break;
}
}
fclose(file);
return (double)1.0 * highwater_mark / 1024;
}
};
} // namespace milvus::kgraph
#endif // INTERNAL_CORE_SRC_FNS_KGRAPH_H_
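For context (not part of this patch): FNS<T>::runNNDescent drives this class exactly as sketched below. The raw-pointer data layout and the include path are assumptions.

#include "fns/kgraph.h"  // assumed include path

void example_kgraph(const float* vecs, size_t n, size_t dim) {
    milvus::Matrix<float> mat;
    mat.load(vecs, n, dim);  // copies the row-major data into an aligned matrix
    milvus::MatrixOracle<float, milvus::metric::l2sqr> oracle(mat);

    milvus::kgraph::IndexParams params;  // defaults: K = L = 100, S = R = 24
    milvus::kgraph::KGraphConstructor kg(oracle, params);
    if (kg.build_index() != 0) {
        return;  // e.g. the caller configured S > L
    }
    // kg.nhoods[i].pool now holds the approximate nearest neighbors of point i,
    // sorted by ascending distance.
}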

internal/core/src/fns/kgraph_data.h

@ -0,0 +1,413 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef INTERNAL_CORE_SRC_FNS_KGRAPH_DATA_H_
#define INTERNAL_CORE_SRC_FNS_KGRAPH_DATA_H_
#include <cmath>
#ifndef __APPLE__
#include <malloc.h>
#endif
#include <memory>
#include <cstring>
#include <cstdlib>
#include <vector>
#include <string>
#include <fstream>
#include <stdexcept>
#include <boost/assert.hpp>
#include <iostream>
#include <cstdint>
#ifdef __GNUC__
#ifdef __AVX__
#define KGRAPH_MATRIX_ALIGN 32
#else
#ifdef __SSE2__
#define KGRAPH_MATRIX_ALIGN 16
#else
#define KGRAPH_MATRIX_ALIGN 4
#endif
#endif
#endif
#ifndef KGRAPH_MATRIX_ALIGN
#define KGRAPH_MATRIX_ALIGN 4
#endif
namespace milvus {
constexpr float EPS = 1e-6;
struct Neighbor {
uint32_t id;
float dist;
bool flag; // whether this entry is a newly found one
Neighbor() {
}
Neighbor(unsigned i, float d, bool f = true) : id(i), dist(d), flag(f) {
}
bool
operator<(const Neighbor& other) const {
if (fabs(this->dist - other.dist) < EPS)
return this->id < other.id;
return this->dist < other.dist;
}
bool
operator==(const Neighbor& other) const {
return this->id == other.id && (fabs(this->dist - other.dist) < EPS);
}
bool
operator>=(const Neighbor& other) const {
return !(*this < other);
}
bool
operator<=(const Neighbor& other) const {
return (*this == other) || (*this < other);
}
bool
operator>(const Neighbor& other) const {
return !(*this <= other);
}
bool
operator!=(const Neighbor& other) const {
return !(*this == other);
}
};
using std::vector;
/// namespace for various distance metrics.
namespace metric {
/// L2 square distance.
struct l2sqr {
template <typename T>
/// L2 square distance.
static float
apply(T const* t1, T const* t2, unsigned dim) {
float r = 0;
for (unsigned i = 0; i < dim; ++i) {
float v = float(t1[i]) - float(t2[i]);
v *= v;
r += v;
}
return r;
}
/// inner product.
template <typename T>
static float
dot(T const* t1, T const* t2, unsigned dim) {
float r = 0;
for (unsigned i = 0; i < dim; ++i) {
r += float(t1[i]) * float(t2[i]);
}
return r;
}
/// L2 norm.
template <typename T>
static float
norm2(T const* t1, unsigned dim) {
float r = 0;
for (unsigned i = 0; i < dim; ++i) {
float v = float(t1[i]);
v *= v;
r += v;
}
return r;
}
};
struct l2 {
template <typename T>
static float
apply(T const* t1, T const* t2, unsigned dim) {
return sqrt(l2sqr::apply<T>(t1, t2, dim));
}
};
} // namespace metric
/// Matrix data.
template <typename T, unsigned A = KGRAPH_MATRIX_ALIGN>
class Matrix {
unsigned col;
unsigned row;
size_t stride;
char* data;
void
reset(unsigned r, unsigned c) {
row = r;
col = c;
stride = (sizeof(T) * c + A - 1) / A * A;
if (data)
free(data);
#ifndef __APPLE__
data = (char*)memalign(
A, row * stride); // SSE instruction needs data to be aligned
#else
data = (char*)malloc(row * stride);
#endif
if (!data)
throw std::runtime_error("memalign");
}
public:
Matrix() : col(0), row(0), stride(0), data(0) {
}
Matrix(unsigned r, unsigned c) : data(0) {
reset(r, c);
}
~Matrix() {
if (data)
free(data);
}
unsigned
size() const {
return row;
}
unsigned
dim() const {
return col;
}
size_t
step() const {
return stride;
}
void
resize(unsigned r, unsigned c) {
reset(r, c);
}
T const*
operator[](unsigned i) const {
return reinterpret_cast<T const*>(&data[stride * i]);
}
T*
operator[](unsigned i) {
return reinterpret_cast<T*>(&data[stride * i]);
}
void
zero() {
memset(data, 0, row * stride);
}
void
normalize2() {
#pragma omp parallel for
for (unsigned i = 0; i < row; ++i) {
T* p = operator[](i);
double sum = metric::l2sqr::norm2(p, col);
sum = std::sqrt(sum);
for (unsigned j = 0; j < col; ++j) {
p[j] /= sum;
}
}
}
void
load(const std::string& path,
unsigned dim,
unsigned skip = 0,
unsigned gap = 0) {
std::ifstream is(path.c_str(), std::ios::binary);
if (!is)
return;
is.seekg(0, std::ios::end);
size_t size = is.tellg();
size -= skip;
is.seekg(0, std::ios::beg);
is.read((char*)&dim, sizeof(unsigned int));
unsigned line = sizeof(T) * dim + gap;
unsigned N = size / line;
reset(N, dim);
zero();
is.seekg(skip, std::ios::beg);
for (unsigned i = 0; i < N; ++i) {
is.seekg(gap, std::ios::cur);
is.read(&data[stride * i], sizeof(T) * dim);
}
}
void
load(const T* base_vector, size_t data_size, size_t data_dim) {
reset(data_size, data_dim);
zero();
for (size_t i = 0; i < data_size; ++i) {
memcpy(&data[stride * i],
base_vector + i * data_dim,
sizeof(T) * data_dim);
}
}
};
namespace kgraph {
class IndexOracle {
public:
/// Returns the size of the dataset.
virtual unsigned
size() const = 0;
/// Computes similarity
/**
* 0 <= i, j < size() are the indices of two objects in the dataset.
* This method returns the distance between objects i and j.
*/
virtual float
operator()(unsigned i, unsigned j) const = 0;
};
/// Search oracle
/** The search oracle is the user-supplied plugin that computes
* the distance between the query and an arbitrary object in the dataset.
* It is used for online k-NN search.
*/
class SearchOracle {
public:
/// Returns the size of the dataset.
virtual unsigned
size() const = 0;
/// Computes similarity
/**
* 0 <= i < size() is the index of an object in the dataset.
* This method returns the distance between the query and object i.
*/
virtual float
operator()(unsigned i) const = 0;
/// Search with brute force.
/**
* Search results are guaranteed to be ranked in ascending order of distance.
*
* @param K Return at most K nearest neighbors.
* @param epsilon Only returns nearest neighbors within distance epsilon.
* @param ids Pointer to the memory where neighbor IDs are returned.
* @param dists Pointer to the memory where distance values are returned, can be nullptr.
*/
unsigned
search(unsigned K,
float epsilon,
unsigned* ids,
float* dists = nullptr) const;
};
} // namespace kgraph
/// Matrix proxy to interface with 3rd party libraries (FLANN, OpenCV, NumPy).
template <typename DATA_TYPE, unsigned A = KGRAPH_MATRIX_ALIGN>
class MatrixProxy {
unsigned rows;
unsigned cols; // # elements, not bytes, in a row,
size_t stride; // # bytes in a row, >= cols * sizeof(element)
uint8_t const* data;
public:
MatrixProxy(Matrix<DATA_TYPE> const& m)
: rows(m.size()),
cols(m.dim()),
stride(m.step()),
data(reinterpret_cast<uint8_t const*>(m[0])) {
}
#ifndef __AVX__
#ifdef FLANN_DATASET_H_
/// Construct from FLANN matrix.
MatrixProxy(flann::Matrix<DATA_TYPE> const& m)
: rows(m.rows), cols(m.cols), stride(m.stride), data(m.data) {
if (stride % A)
throw invalid_argument("bad alignment");
}
#endif
#ifdef CV_MAJOR_VERSION
/// Construct from OpenCV matrix.
MatrixProxy(cv::Mat const& m)
: rows(m.rows), cols(m.cols), stride(m.step), data(m.data) {
if (stride % A)
throw invalid_argument("bad alignment");
}
#endif
#ifdef NPY_NDARRAYOBJECT_H
/// Construct from NumPy matrix.
MatrixProxy(PyArrayObject* obj) {
if (!obj || (obj->nd != 2))
throw invalid_argument("bad array shape");
rows = obj->dimensions[0];
cols = obj->dimensions[1];
stride = obj->strides[0];
data = reinterpret_cast<uint8_t const*>(obj->data);
if (obj->descr->elsize != sizeof(DATA_TYPE))
throw invalid_argument("bad data type size");
if (stride % A)
throw invalid_argument("bad alignment");
if (!(stride >= cols * sizeof(DATA_TYPE)))
throw invalid_argument("bad stride");
}
#endif
#endif
unsigned
size() const {
return rows;
}
unsigned
dim() const {
return cols;
}
DATA_TYPE const*
operator[](unsigned i) const {
return reinterpret_cast<DATA_TYPE const*>(data + stride * i);
}
DATA_TYPE*
operator[](unsigned i) {
return const_cast<DATA_TYPE*>(
reinterpret_cast<DATA_TYPE const*>(data + stride * i));
}
};
template <typename DATA_TYPE, typename DIST_TYPE>
class MatrixOracle : public kgraph::IndexOracle {
MatrixProxy<DATA_TYPE> proxy;
public:
class SearchOracle : public kgraph::SearchOracle {
MatrixProxy<DATA_TYPE> proxy;
DATA_TYPE const* query;
public:
SearchOracle(MatrixProxy<DATA_TYPE> const& p, DATA_TYPE const* q)
: proxy(p), query(q) {
}
virtual unsigned
size() const {
return proxy.size();
}
virtual float
operator()(unsigned i) const {
return DIST_TYPE::apply(proxy[i], query, proxy.dim());
}
};
template <typename MATRIX_TYPE>
MatrixOracle(MATRIX_TYPE const& m) : proxy(m) {
}
virtual unsigned
size() const {
return proxy.size();
}
virtual float
operator()(unsigned i, unsigned j) const {
return DIST_TYPE::apply(proxy[i], proxy[j], proxy.dim());
}
SearchOracle
query(DATA_TYPE const* query) const {
return SearchOracle(proxy, query);
}
};
} // namespace milvus
#endif // INTERNAL_CORE_SRC_FNS_KGRAPH_DATA_H_
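A small sketch (not part of this patch) of the building blocks above: the l2sqr metric and the Neighbor ordering relied on when candidate pools are kept sorted. The include path is an assumption.

#include <algorithm>
#include <vector>
#include "fns/kgraph_data.h"  // assumed include path

void example_kgraph_data() {
    float a[3] = {0.f, 0.f, 0.f};
    float b[3] = {1.f, 2.f, 2.f};
    float d = milvus::metric::l2sqr::apply(a, b, 3);  // squared L2: 1 + 4 + 4 = 9
    (void)d;

    // Neighbors order by distance first and break ties by id, so a candidate
    // pool can be kept sorted with std::sort.
    std::vector<milvus::Neighbor> pool{{7, 3.0f, true}, {2, 1.5f, true}, {9, 1.5f, true}};
    std::sort(pool.begin(), pool.end());  // resulting id order: 2, 9, 7
}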

internal/core/src/fns/kmeans.cpp

@ -0,0 +1,382 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.
//Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* @file kmeans.cpp
* @author yinjie06(yinjie06@baidu.com)
* @date 2023/07/25 11:11
* @brief
*
**/
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <math.h>
#include <assert.h>
#ifndef __APPLE__
#include <cblas.h>
#else
#include <Accelerate/Accelerate.h>
#endif
#include <limits>
#include <random>
#include <utility>
#include "kmeans.h"
#include <iostream>
namespace milvus::puck {
void
Kmeans::random_init_center(const size_t total_cnt,
const size_t sample_cnt,
const uint32_t dim,
const float* train_dataset,
std::vector<size_t>& sample_ids) {
sample_ids.clear();
std::uniform_int_distribution<> dis(0, total_cnt - 1);
std::vector<bool> filter(total_cnt, false);
while (sample_ids.size() < sample_cnt) {
size_t sample_id = dis(_rnd);
// skip duplicate samples
if (filter[sample_id]) {
continue;
}
sample_ids.push_back(sample_id);
filter[sample_id] = true;
}
}
int
Kmeans::roulette_selection(std::vector<float>& wheel) {
float total_val = 0;
for (auto& val : wheel) {
total_val += val;
}
cblas_sscal(wheel.size(), 1.0 / total_val, wheel.data(), 1);
std::uniform_real_distribution<double> dis(0, 1.0);
double rd = dis(_rnd);
for (auto id = 0; id < wheel.size(); ++id) {
rd -= wheel[id];
if (rd < 0) {
return id;
}
}
return wheel.size() - 1;
}
void
Kmeans::kmeanspp_init_center(const size_t total_cnt,
const size_t sample_cnt,
const uint32_t dim,
const float* train_dataset,
std::vector<size_t>& sample_ids) {
std::vector<float> disbest(total_cnt, std::numeric_limits<float>::max());
std::vector<float> distmp(total_cnt);
sample_ids.resize(sample_cnt, 0);
sample_ids[0] = _rnd() % total_cnt;
std::vector<float> points_norm(total_cnt, 0);
#pragma omp parallel for schedule(dynamic) num_threads(_params.nt)
for (size_t j = 0; j < total_cnt; j++) {
points_norm[j] = cblas_sdot(
dim, train_dataset + j * dim, 1, train_dataset + j * dim, 1);
}
for (size_t i = 1; i < sample_cnt; i++) {
size_t newsel = sample_ids[i - 1];
const float* last_center = train_dataset + newsel * dim;
#pragma omp parallel for schedule(dynamic) num_threads(_params.nt)
for (size_t j = 0; j < total_cnt; j++) {
float temp =
points_norm[j] + points_norm[newsel] -
2.0 *
cblas_sdot(dim, train_dataset + j * dim, 1, last_center, 1);
if (temp < disbest[j]) {
disbest[j] = temp;
}
}
memcpy(distmp.data(), disbest.data(), total_cnt * sizeof(distmp[0]));
sample_ids[i] = roulette_selection(distmp);
}
}
int
Kmeans::kmeans_reassign_empty(uint32_t dim,
size_t total_cnt,
size_t k,
float* centroids,
int* assign,
int* nassign) {
std::vector<float> proba_split(k);
std::vector<float> vepsilon(dim);
std::normal_distribution<> d_normal(0, _rnd() / ((double)RAND_MAX + 1.0));
#pragma omp parallel for schedule(dynamic) num_threads(_params.nt)
for (auto c = 0; c < k; c++) {
proba_split[c] = (nassign[c] < 2 ? 0 : nassign[c] * nassign[c] - 1);
}
int nreassign = 0;
for (auto c = 0; c < k; c++) {
if (nassign[c] == 0) {
nreassign++;
auto j = roulette_selection(proba_split);
memcpy(centroids + c * dim,
centroids + j * dim,
dim * sizeof(centroids[0]));
double s = cblas_snrm2(dim, centroids + j * dim, 1) * 0.0000001;
for (auto& v : vepsilon) {
v = d_normal(_rnd);
}
cblas_sscal(dim, s, vepsilon.data(), 1);
cblas_saxpy(dim, 1.0, vepsilon.data(), 1, centroids + j * dim, 1);
cblas_saxpy(dim, -1.0, vepsilon.data(), 1, centroids + c * dim, 1);
proba_split[j] = 0;
}
}
return nreassign;
}
float
Kmeans::kmeans(uint32_t dim,
size_t total_cnt,
size_t k,
const float* train_dataset,
float* centroids_out,
float* dis_out,
int* assign_out) {
if (k == 0 || k >= total_cnt) {
return -1;
}
int nt = std::max(_params.nt, 1);
std::unique_ptr<float[]> centroids(new float[k * dim]);
std::unique_ptr<float[]> dis(new float[total_cnt]);
std::unique_ptr<int[]> assign(new int[total_cnt]);
std::unique_ptr<int[]> nassign(new int[k]);
double qerr = std::numeric_limits<double>::max();
double qerr_best = std::numeric_limits<double>::max();
std::vector<size_t> selected(k);
int core_ret = 0;
for (auto run = 0; run < _params.redo; run++) {
if (_params.init_type == KMeansCenterInitType::KMEANS_PLUS_PLUS) {
// TODO: when the dataset is too large, restrict sampling to a subset
uint32_t nsubset =
(total_cnt > k * 8 && total_cnt > 8 * 1024) ? k * 8 : total_cnt;
kmeanspp_init_center(nsubset, k, dim, train_dataset, selected);
} else {
random_init_center(total_cnt, k, dim, train_dataset, selected);
}
for (auto i = 0; i < k; i++) {
memcpy(centroids.get() + i * dim,
train_dataset + selected[i] * dim,
dim * sizeof(centroids[0]));
}
core_ret = kmeans_core(dim,
total_cnt,
k,
_params.niter,
nt,
centroids.get(),
train_dataset,
assign.get(),
nassign.get(),
dis.get(),
&qerr);
if (core_ret < 0) {
return -1;
}
if (qerr < qerr_best) {
qerr_best = qerr;
if (centroids_out != nullptr) {
memcpy(centroids_out,
centroids.get(),
k * dim * sizeof(*centroids.get()));
}
if (dis_out != nullptr) {
memcpy(dis_out, dis.get(), total_cnt * sizeof(*dis.get()));
}
if (assign_out != nullptr) {
memcpy(assign_out,
assign.get(),
total_cnt * sizeof(*assign.get()));
}
}
}
return qerr_best / total_cnt;
}
int
nearest_center(uint32_t dim,
const float* centroids,
const size_t centroid_cnt,
const float* train_dataset,
const size_t point_cnt,
int* assign,
float* dis) {
std::vector<float> points_norm(point_cnt);
std::vector<float> centroids_norm(centroid_cnt);
int nt = std::thread::hardware_concurrency();
#pragma omp parallel for schedule(dynamic) num_threads(nt)
for (size_t j = 0; j < point_cnt; j++) {
points_norm[j] = cblas_sdot(
dim, train_dataset + j * dim, 1, train_dataset + j * dim, 1);
}
#pragma omp parallel for schedule(dynamic) num_threads(nt)
for (size_t j = 0; j < centroid_cnt; j++) {
centroids_norm[j] =
cblas_sdot(dim, centroids + j * dim, 1, centroids + j * dim, 1);
}
#pragma omp parallel for schedule(dynamic) num_threads(nt)
for (size_t j = 0; j < point_cnt; j++) {
std::pair<float, uint32_t> min_centroid = {
std::numeric_limits<float>::max(), 0};
for (size_t c = 0; c < centroid_cnt; c++) {
float cur_dist = points_norm[j] + centroids_norm[c] -
2.0 * cblas_sdot(dim,
train_dataset + j * dim,
1,
centroids + c * dim,
1);
if (cur_dist < min_centroid.first) {
min_centroid = {cur_dist, c};
}
}
assign[j] = min_centroid.second;
dis[j] = min_centroid.first;
}
return 0;
}
int
Kmeans::kmeans_core(uint32_t d,
size_t n,
size_t k,
int niter,
int nt,
float* centroids,
const float* v,
int* assign,
int* nassign,
float* dis,
double* qerr_out) {
double qerr = std::numeric_limits<double>::max();
double qerr_old = std::numeric_limits<double>::max();
int tot_nreassign = 0;
auto start_timer = std::chrono::steady_clock::now();
for (auto iter = 1; iter <= niter; iter++) {
printf("kmeans::iter : [%d / %d]\n", iter, niter);
nearest_center(d, centroids, k, v, n, assign, dis);
memset(nassign, 0, k * sizeof(int));
{
memset(centroids, 0, sizeof(centroids[0]) * k * d);
for (auto i = 0; i < n; i++) {
if (assign[i] < 0 || assign[i] >= k) {
return -1;
}
nassign[assign[i]]++;
cblas_saxpy(d, 1.0, v + i * d, 1, centroids + assign[i] * d, 1);
}
for (auto i = 0; i < k; i++) {
cblas_sscal(d, 1.0 / nassign[i], centroids + i * d, 1);
}
}
auto nreassign =
kmeans_reassign_empty(d, n, k, centroids, assign, nassign);
tot_nreassign += nreassign;
if (tot_nreassign > n / 100 && tot_nreassign > 1000) {
return -1;
}
qerr_old = qerr;
qerr = 0;
for (auto i = 0; i < n; i++) {
qerr += dis[i];
}
if (std::fabs(qerr_old - qerr) < 1e-6 && nreassign == 0) {
break;
}
}
auto end_timer = std::chrono::steady_clock::now();
auto duration = 1.0 *
std::chrono::duration_cast<std::chrono::milliseconds>(
end_timer - start_timer)
.count() /
1000.0;
{ std::cout << "kmeans's duration: " << duration << " seconds\n"; }
*qerr_out = qerr;
return 0;
}
} // namespace milvus::puck

internal/core/src/fns/kmeans.h

@ -0,0 +1,128 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.
//Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* @file kmeans.h
* @author yinjie06(yinjie06@baidu.com)
* @date 2023/07/25 11:11
* @brief
*
**/
#pragma once
#include <thread>
#include <random>
#include <chrono>
#include <vector>
#include <ctime>
#include <cstdint>
namespace milvus::puck {
enum KMeansCenterInitType {
RANDOM = 0,
KMEANS_PLUS_PLUS = 1,
};
// Parameters for k-means clustering
struct KmeansParams {
int redo;
int nt;
int niter;
KMeansCenterInitType init_type;
KmeansParams(bool kmeans_pp = false) {
redo = 1;
nt = std::thread::hardware_concurrency();
niter = 30;
init_type = kmeans_pp ? KMeansCenterInitType::KMEANS_PLUS_PLUS
: KMeansCenterInitType::RANDOM;
}
};
int
nearest_center(uint32_t dim,
const float* centroids,
const size_t centroid_cnt,
const float* point,
const size_t point_cnt,
int* assign,
float* dis);
class Kmeans {
public:
Kmeans() : _rnd(time(0)) {
}
Kmeans(bool kmeans_pp) : _params(kmeans_pp), _rnd(time(0)) {
}
~Kmeans() {
}
float
kmeans(uint32_t dim,
size_t n,
size_t k,
const float* v,
float* centroids_out,
float* dis_out = nullptr,
int* assign_out = nullptr);
KmeansParams&
get_params() {
return _params;
}
protected:
void
random_init_center(const size_t total_cnt,
const size_t sample_cnt,
const uint32_t dim,
const float* train_dataset,
std::vector<size_t>& sample_ids);
//double drand_r(unsigned int* seed);
void
kmeanspp_init_center(const size_t total_cnt,
const size_t sample_cnt,
const uint32_t dim,
const float* train_dataset,
std::vector<size_t>& sample_ids);
int
kmeans_reassign_empty(uint32_t dim,
size_t total_cnt,
size_t k,
float* centroids,
int* assign,
int* nassign);
int
kmeans_core(uint32_t dim,
size_t total_cnt,
size_t k,
int niter,
int nt,
float* centroids,
const float* v,
int* assign,
int* nassign,
float* dis,
double* qerr_out);
private:
int
roulette_selection(std::vector<float>& wheel);
private:
uint32_t _dim;
KmeansParams _params;
std::mt19937 _rnd;
};
} // namespace milvus::puck
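A usage sketch (not part of this patch) that mirrors FNS<T>::runKmeans: cluster n row-major float vectors into k centroids with per-point labels. Buffer names and the include path are illustrative.

#include <cstdint>
#include <memory>
#include "fns/kmeans.h"  // assumed include path

float example_kmeans(const float* data, size_t n, uint32_t dim, size_t k) {
    std::unique_ptr<float[]> centroids(new float[k * dim]);
    std::unique_ptr<int[]> labels(new int[n]);

    milvus::puck::Kmeans kms(true);  // true selects k-means++ initialization
    // Returns the average quantization error, or -1 on invalid arguments
    // (k == 0 or k >= n).
    return kms.kmeans(dim, n, k, data, centroids.get(),
                      /*dis_out=*/nullptr, labels.get());
}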

internal/core/src/fns/visitedList.hpp

@ -0,0 +1,101 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstring>
#include <mutex>
#include <deque>
typedef unsigned int vl_type;
namespace milvus {
class VisitedList {
public:
vl_type curV;
vl_type* mass;
unsigned int numelements;
VisitedList(int numelements1) {
curV = -1;
numelements = numelements1;
mass = new vl_type[numelements];
}
void
reset() {
curV++;
if (curV == 0) {
memset(mass, 0, sizeof(vl_type) * numelements);
curV++;
}
};
~VisitedList() {
delete[] mass;
}
};
///////////////////////////////////////////////////////////
//
// Class for multi-threaded pool-management of VisitedLists
//
/////////////////////////////////////////////////////////
class VisitedListPool {
std::deque<VisitedList*> pool;
std::mutex poolguard;
int numelements;
public:
VisitedListPool(int initmaxpools, int numelements1) {
numelements = numelements1;
for (int i = 0; i < initmaxpools; i++)
pool.push_front(new VisitedList(numelements));
}
VisitedList*
getFreeVisitedList() {
VisitedList* rez;
{
std::unique_lock<std::mutex> lock(poolguard);
if (pool.size() > 0) {
rez = pool.front();
pool.pop_front();
} else {
rez = new VisitedList(numelements);
}
}
rez->reset();
return rez;
};
void
releaseVisitedList(VisitedList* vl) {
std::unique_lock<std::mutex> lock(poolguard);
pool.push_front(vl);
};
~VisitedListPool() {
while (pool.size()) {
VisitedList* rez = pool.front();
pool.pop_front();
delete rez;
}
};
};
} // namespace milvus
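A sketch (not part of this patch) of the borrow/mark/release pattern that FNS<T>::searchFNS applies to this pool; the point id below is hypothetical.

#include "fns/visitedList.hpp"  // assumed include path

void example_visited(int num_points) {
    milvus::VisitedListPool pool(/*initmaxpools=*/1, num_points);

    milvus::VisitedList* vl = pool.getFreeVisitedList();
    vl_type* visited = vl->mass;
    vl_type tag = vl->curV;

    unsigned id = 42;  // hypothetical point id, must be < num_points
    if (visited[id] != tag) {
        visited[id] = tag;  // mark as visited for this search
    }
    pool.releaseVisitedList(vl);  // the next borrower's reset() only bumps curV
}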

internal/core/unittest/CMakeLists.txt

@ -57,6 +57,7 @@ set(MILVUS_TEST_FILES
test_always_true_expr.cpp
test_plan_proto.cpp
test_chunk_cache.cpp
test_fns.cpp
)
if ( BUILD_DISK_ANN STREQUAL "ON" )
@ -125,6 +126,7 @@ target_link_libraries(all_tests
milvus_storage
milvus_indexbuilder
pthread
milvus_fnsearcher
)
install(TARGETS all_tests DESTINATION unittest)

internal/core/unittest/test_fns.cpp

@ -0,0 +1,237 @@
#include <iostream>
#include <random>
#include <gtest/gtest.h>
#include "test_utils/Distance.h"
#include "fns/fns.hpp"
// using namespace milvus;
namespace milvus::fns {
namespace {
void
GenRandom(float* addr, unsigned size) {
for (unsigned i = 0; i < size; ++i) {
addr[i] = static_cast<float>(rand()) / static_cast<float>(RAND_MAX);
}
}
void
GenGT(float* query,
float* data,
size_t query_size,
size_t data_size,
size_t dim,
std::vector<std::vector<unsigned>>& gt) {
std::vector<std::vector<unsigned>>(query_size).swap(gt);
size_t gt_size = 100;
for (unsigned i = 0; i < query_size; ++i) {
auto q = query + i * dim;
std::vector<std::pair<float, unsigned>> candidates;
for (unsigned j = 0; j < data_size; ++j) {
auto dist = L2(q, data + j * dim, (int)dim);
candidates.emplace_back(-dist, j);
}
std::sort(candidates.begin(), candidates.end());
for (size_t j = 0; j < gt_size; ++j) {
gt[i].emplace_back(candidates[j].second);
}
}
}
} // namespace
class FurthestNeighborSearchTest : public testing::Test {
protected:
static void
SetUpTestCase() {
data_size = 10000;
data_dim = 128;
query_size = 100;
query_dim = data_dim;
base_data = new float[data_size * data_dim];
query = new float[query_size * query_dim];
for (size_t i = 0; i < data_size; ++i) {
GenRandom(base_data + i * data_dim, data_dim);
}
for (size_t i = 0; i < query_size; ++i) {
GenRandom(query + i * query_dim, query_dim);
}
GenGT(query, base_data, query_size, data_size, query_dim, fns_gt);
fns_ptr_ = new FNS<float>(base_data, data_size, data_dim);
}
static void
TearDownTestCase() {
delete fns_ptr_;
fns_ptr_ = nullptr;
if (base_data) {
delete[] base_data;
base_data = nullptr;
}
if (query) {
delete[] query;
query = nullptr;
}
}
void
SetUp() override {
}
void
TearDown() override {
}
// Some expensive resource shared by all tests.
static FNS<float>* fns_ptr_;
static size_t query_size, query_dim;
static float* base_data;
static float* query;
static std::vector<std::vector<unsigned>> fns_gt;
static size_t data_size, data_dim;
};
FNS<float>* FurthestNeighborSearchTest::fns_ptr_ = nullptr;
size_t FurthestNeighborSearchTest::query_size = 0;
size_t FurthestNeighborSearchTest::query_dim = 0;
float* FurthestNeighborSearchTest::base_data = nullptr;
float* FurthestNeighborSearchTest::query = nullptr;
size_t FurthestNeighborSearchTest::data_size = 0;
size_t FurthestNeighborSearchTest::data_dim = 0;
std::vector<std::vector<unsigned>> FurthestNeighborSearchTest::fns_gt =
std::vector<std::vector<unsigned>>();
TEST_F(FurthestNeighborSearchTest, BuildKmeansZeroK) {
BuildParam bp;
bp.kms_cluster_num = 0;
fns_ptr_->setUpBuildPara(bp);
EXPECT_EQ(fns_ptr_->runKmeans(), -1);
}
TEST_F(FurthestNeighborSearchTest, BuildKmeansHugeK) {
BuildParam bp;
bp.kms_cluster_num = FurthestNeighborSearchTest::data_size;
fns_ptr_->setUpBuildPara(bp);
EXPECT_EQ(fns_ptr_->runKmeans(), -1);
}
TEST_F(FurthestNeighborSearchTest, BuildKmeansNormalK) {
BuildParam bp;
bp.kms_cluster_num = (size_t)sqrt(FurthestNeighborSearchTest::data_size);
fns_ptr_->setUpBuildPara(bp);
EXPECT_GT(fns_ptr_->runKmeans(), 0);
}
TEST_F(FurthestNeighborSearchTest, NNDescentHugeS) {
NNDIndexParam nnd_para;
nnd_para.S = 32;
nnd_para.R = 32;
nnd_para.K = 20;
nnd_para.L = 20;
fns_ptr_->setUpNNDPara(nnd_para);
EXPECT_EQ(fns_ptr_->runNNDescent(), -1);
}
TEST_F(FurthestNeighborSearchTest, NNDescentNormalS) {
NNDIndexParam nnd_para;
fns_ptr_->setUpNNDPara(nnd_para);
EXPECT_EQ(fns_ptr_->runNNDescent(), 0);
}
TEST_F(FurthestNeighborSearchTest, BuildIndexing) {
BuildParam bp;
bp.kms_cluster_num = (size_t)sqrt(FurthestNeighborSearchTest::data_size);
bp.filter_ctrl_size = 100;  // number of top candidates whose k-NN neighbors are expanded
bp.filter_pos = 3;  // top-`pos` neighbors used when computing point density
bp.filter_density = 1;  // maximum density for a point to be treated as low-density
bp.filter_max_edges = 2048;  // max low-density candidates kept for brute-force re-ranking
fns_ptr_->setUpBuildPara(bp);
NNDIndexParam nnd_para;
fns_ptr_->setUpNNDPara(nnd_para);
EXPECT_EQ(fns_ptr_->build(), 0);
}
TEST_F(FurthestNeighborSearchTest, SearchFurthestNeighborsMinusOne) {
fns_ptr_->setSearchK(-1);
#pragma omp parallel for
for (auto iter = 0; iter < query_size; ++iter) {
auto res = fns_ptr_->searchFNS(query + iter * query_dim);
EXPECT_EQ(res.size(), 0);
}
}
TEST_F(FurthestNeighborSearchTest, SearchFurthestNeighborsZero) {
fns_ptr_->setSearchK(0);
#pragma omp parallel for
for (auto iter = 0; iter < query_size; ++iter) {
auto res = fns_ptr_->searchFNS(query + iter * query_dim);
EXPECT_EQ(res.size(), 0);
}
}
TEST_F(FurthestNeighborSearchTest, SearchFurthestNeighborsTop100) {
fns_ptr_->setSearchK(100);
#pragma omp parallel for
for (auto iter = 0; iter < query_size; ++iter) {
auto res = fns_ptr_->searchFNS(query + iter * query_dim);
EXPECT_EQ(res.size(), 100);
}
}
TEST_F(FurthestNeighborSearchTest, SearchFurthestNeighborsHuge) {
fns_ptr_->setSearchK(1E9);
#pragma omp parallel for
for (auto iter = 0; iter < query_size; ++iter) {
auto res = fns_ptr_->searchFNS(query + iter * query_dim);
EXPECT_EQ(res.size(), 0);
}
}
TEST_F(FurthestNeighborSearchTest, SearchFurthestNeighborsBenchMark) {
fns_ptr_->setSearchK(100);
graph_t FNS_res(query_size);
auto start = std::chrono::steady_clock::now();
#pragma omp parallel for
for (auto iter = 0; iter < query_size; ++iter) {
auto res = fns_ptr_->searchFNS(query + iter * query_dim);
vector<idx_t> tmp;
while (res.size()) {
tmp.emplace_back(res.top().second);
res.pop();
}
reverse(tmp.begin(), tmp.end());
FNS_res[iter] = tmp;
}
auto end = std::chrono::steady_clock::now();
auto duration =
(double)1.0 *
std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
.count() /
1000.0; // seconds
std::cout << "duration = " << duration << " s\n";
std::cout << "QPS = " << (1.0 * query_size / duration) << std::endl;
std::cout << "avg_ratio@1 = "
<< (fns_ptr_->evaluateRatio(
fns_gt, FNS_res, query_size, query, 1))
<< std::endl;
std::cout << "avg_ratio@10 = "
<< (fns_ptr_->evaluateRatio(
fns_gt, FNS_res, query_size, query, 10))
<< std::endl;
std::cout << "avg_ratio@100 = "
<< (fns_ptr_->evaluateRatio(
fns_gt, FNS_res, query_size, query, 100))
<< std::endl;
std::cout << "recall@1 = " << (fns_ptr_->evaluateRecall(fns_gt, FNS_res, 1))
<< std::endl;
std::cout << "recall@10 = "
<< (fns_ptr_->evaluateRecall(fns_gt, FNS_res, 10)) << std::endl;
std::cout << "recall@100 = "
<< (fns_ptr_->evaluateRecall(fns_gt, FNS_res, 100)) << std::endl;
}
} // namespace milvus::fns