mirror of https://github.com/milvus-io/milvus.git
Update knowhere (#5006)
Improve performance of ivf::train and hnsw, and fix bugs. Signed-off-by: shengjun.li <shengjun.li@zilliz.com> (pull/5018/head)
parent 7aaae3f98c
commit a2875f9d95
@@ -27,6 +27,7 @@ namespace knowhere {
static const int64_t MIN_NBITS = 1;
static const int64_t MAX_NBITS = 16;
static const int64_t DEFAULT_NBITS = 8;
static const int64_t MIN_NLIST = 1;
static const int64_t MAX_NLIST = 65536;
static const int64_t MIN_NPROBE = 1;
@@ -91,7 +92,7 @@ ConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode m
int64_t
MatchNlist(int64_t size, int64_t nlist) {
    const int64_t MIN_POINTS_PER_CENTROID = 40;
    const int64_t MIN_POINTS_PER_CENTROID = 39;

    if (nlist * MIN_POINTS_PER_CENTROID > size) {
        // nlist is too large, adjust to a proper value
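The only functional change in this hunk is the per-centroid floor dropping from 40 to 39 training points. For context, a minimal sketch of the kind of clamping MatchNlist performs when the requested nlist is too large for the dataset; the helper name and exact rounding here are illustrative assumptions, not the project's implementation:

#include <algorithm>
#include <cstdint>
#include <iostream>

// Illustrative only: clamp an IVF nlist so that every centroid can expect at
// least MIN_POINTS_PER_CENTROID training points.  The real MatchNlist in
// knowhere may round or log differently; this just shows the intent.
int64_t AdjustNlist(int64_t size, int64_t nlist) {
    const int64_t MIN_POINTS_PER_CENTROID = 39;  // value after this change
    if (nlist * MIN_POINTS_PER_CENTROID > size) {
        // nlist is too large for the number of rows, shrink it
        nlist = std::max<int64_t>(1, size / MIN_POINTS_PER_CENTROID);
    }
    return nlist;
}

int main() {
    std::cout << AdjustNlist(10000, 1024) << std::endl;  // prints 256 (10000 / 39)
    return 0;
}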
@@ -146,9 +147,7 @@ IVFConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMod
bool
IVFSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
    const int64_t DEFAULT_NBITS = 8;
    oricfg[knowhere::IndexParams::nbits] = DEFAULT_NBITS;

    return IVFConfAdapter::CheckTrain(oricfg, mode);
}

@@ -161,7 +160,7 @@ IVFPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
    CheckIntByRange(knowhere::IndexParams::nbits, MIN_NBITS, MAX_NBITS);

    auto rows = oricfg[knowhere::meta::ROWS].get<int64_t>();
    auto nbits = oricfg[knowhere::IndexParams::nbits].get<int64_t>();
    auto nbits = oricfg.count(IndexParams::nbits) ? oricfg[IndexParams::nbits].get<int64_t>() : DEFAULT_NBITS;
    oricfg[knowhere::IndexParams::nbits] = MatchNbits(rows, nbits);

    auto m = oricfg[knowhere::IndexParams::m].get<int64_t>();
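IVFPQ's nbits is now optional: if the caller's config omits it, DEFAULT_NBITS (8) is used before MatchNbits runs. A small self-contained sketch of that lookup-with-default pattern, assuming knowhere's Config is an nlohmann::json object as the oricfg.count(...) call suggests:

#include <cstdint>
#include <iostream>
#include <nlohmann/json.hpp>

// Sketch of the lookup-with-default pattern the new IVFPQ check uses: read
// "nbits" from a JSON config if the caller supplied it, otherwise fall back
// to a default.  knowhere's Config is assumed to be nlohmann::json here.
int64_t GetNbitsOrDefault(const nlohmann::json& cfg) {
    constexpr int64_t kDefaultNbits = 8;
    return cfg.count("nbits") ? cfg["nbits"].get<int64_t>() : kDefaultNbits;
}

int main() {
    nlohmann::json with_nbits = {{"nbits", 6}};
    nlohmann::json without_nbits = {{"m", 16}};
    std::cout << GetNbitsOrDefault(with_nbits) << "\n";     // 6
    std::cout << GetNbitsOrDefault(without_nbits) << "\n";  // 8 (default)
    return 0;
}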
@@ -83,8 +83,6 @@ IndexHNSW::Load(const BinarySet& index_binary) {
        }
        // LOG_KNOWHERE_DEBUG_ << "IndexHNSW::Load finished, show statistics:";
        // LOG_KNOWHERE_DEBUG_ << hnsw_stats->ToString();

        normalize = index_->metric_type_ == 1;  // 1 == InnerProduct
    } catch (std::exception& e) {
        KNOWHERE_THROW_MSG(e.what());
    }
@@ -102,7 +100,6 @@ IndexHNSW::Train(const DatasetPtr& dataset_ptr, const Config& config) {
        space = new hnswlib::L2Space(dim);
    } else if (metric_type == Metric::IP) {
        space = new hnswlib::InnerProductSpace(dim);
        normalize = true;
    } else {
        KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type);
    }
@@ -142,7 +139,7 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais
    if (!index_) {
        KNOWHERE_THROW_MSG("index not initialize or trained");
    }
    GET_TENSOR_DATA(dataset_ptr)
    GET_TENSOR_DATA_DIM(dataset_ptr)

    size_t k = config[meta::TOPK].get<int64_t>();
    size_t id_size = sizeof(int64_t) * k;
@@ -159,44 +156,39 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais
    }

    index_->setEf(config[IndexParams::ef].get<int64_t>());

    using P = std::pair<float, int64_t>;
    auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; };
    bool transform = (index_->metric_type_ == 1);  // InnerProduct: 1

    std::chrono::high_resolution_clock::time_point query_start, query_end;
    query_start = std::chrono::high_resolution_clock::now();

#pragma omp parallel for
    for (unsigned int i = 0; i < rows; ++i) {
        std::vector<P> ret;
        const float* single_query = reinterpret_cast<const float*>(p_data) + i * Dim();

        auto single_query = (float*)p_data + i * dim;
        std::priority_queue<std::pair<float, hnswlib::labeltype>> rst;
        if (STATISTICS_LEVEL >= 3) {
            ret = index_->searchKnn(single_query, k, compare, bitset, query_stats[i]);
            rst = index_->searchKnn(single_query, k, bitset, query_stats[i]);
        } else {
            auto dummy_stat = hnswlib::StatisticsInfo();
            ret = index_->searchKnn(single_query, k, compare, bitset, dummy_stat);
            rst = index_->searchKnn(single_query, k, bitset, dummy_stat);
        }
        size_t rst_size = rst.size();

        while (ret.size() < k) {
            ret.emplace_back(std::make_pair(-1, -1));
        auto p_single_dis = p_dist + i * k;
        auto p_single_id = p_id + i * k;
        size_t idx = rst_size - 1;
        while (!rst.empty()) {
            auto& it = rst.top();
            p_single_dis[idx] = transform ? (1 - it.first) : it.first;
            p_single_id[idx] = it.second;
            rst.pop();
            idx--;
        }
        std::vector<float> dist;
        std::vector<int64_t> ids;
        MapOffsetToUid(p_single_id, rst_size);

        if (normalize) {
            std::transform(ret.begin(), ret.end(), std::back_inserter(dist),
                           [](const std::pair<float, int64_t>& e) { return float(1 - e.first); });
        } else {
            std::transform(ret.begin(), ret.end(), std::back_inserter(dist),
                           [](const std::pair<float, int64_t>& e) { return e.first; });
        for (idx = rst_size; idx < k; idx++) {
            p_single_dis[idx] = float(1.0 / 0.0);
            p_single_id[idx] = -1;
        }
        std::transform(ret.begin(), ret.end(), std::back_inserter(ids),
                       [](const std::pair<float, int64_t>& e) { return e.second; });

        MapOffsetToUid(ids.data(), ids.size());
        memcpy(p_dist + i * k, dist.data(), dist_size);
        memcpy(p_id + i * k, ids.data(), id_size);
    }
    query_end = std::chrono::high_resolution_clock::now();
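The reworked loop writes each query's neighbors straight from the max-heap returned by searchKnn into the caller's p_dist / p_id slots, converting inner-product results (which hnswlib stores as 1 - ip) back to similarities and padding the tail with +inf / -1 when fewer than k results exist. A standalone sketch of that unload-and-pad pattern; buffer and flag names mirror the diff, but the function itself is illustrative, not the project's code:

#include <cstdint>
#include <cstdio>
#include <limits>
#include <queue>
#include <utility>
#include <vector>

// Unload a max-heap of (distance, id) pairs into a fixed-size result slot.
void UnloadTopK(std::priority_queue<std::pair<float, int64_t>> rst,  // max-heap on distance
                size_t k, bool transform, float* p_dist, int64_t* p_id) {
    size_t rst_size = rst.size();           // may be smaller than k
    size_t idx = rst_size;                  // fill from the back: the heap pops worst-first
    while (!rst.empty()) {
        const auto& it = rst.top();
        --idx;
        p_dist[idx] = transform ? (1 - it.first) : it.first;  // IP similarity was stored as 1 - ip
        p_id[idx] = it.second;
        rst.pop();
    }
    for (idx = rst_size; idx < k; idx++) {  // fewer than k hits: pad the tail
        p_dist[idx] = std::numeric_limits<float>::infinity();
        p_id[idx] = -1;
    }
}

int main() {
    std::priority_queue<std::pair<float, int64_t>> heap;
    heap.emplace(0.9f, 7);
    heap.emplace(0.2f, 3);
    std::vector<float> dist(4);
    std::vector<int64_t> ids(4);
    UnloadTopK(heap, 4, /*transform=*/false, dist.data(), ids.data());
    for (size_t i = 0; i < 4; ++i) printf("%lld %.2f\n", (long long)ids[i], dist[i]);
    // prints: 3 0.20 / 7 0.90 / -1 inf / -1 inf
    return 0;
}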
@@ -56,7 +56,6 @@ class IndexHNSW : public VecIndex {
    ClearStatistics() override;

 private:
    bool normalize = false;
    std::shared_ptr<hnswlib::HierarchicalNSW<float>> index_;
};
@@ -19,10 +19,10 @@ std::shared_ptr<SPTAG::MetadataSet>
ConvertToMetadataSet(const DatasetPtr& dataset_ptr) {
    auto elems = dataset_ptr->Get<int64_t>(meta::ROWS);

    auto p_id = (int64_t*)malloc(sizeof(int64_t) * elems);
    auto p_id = new int64_t[elems];
    for (int64_t i = 0; i < elems; ++i) p_id[i] = i;

    auto p_offset = (int64_t*)malloc(sizeof(int64_t) * (elems + 1));
    auto p_offset = new int64_t[elems + 1];
    for (int64_t i = 0; i <= elems; ++i) p_offset[i] = i * 8;

    std::shared_ptr<SPTAG::MetadataSet> metaset(
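One plausible reason for switching these buffers from malloc to new[] is that whatever eventually releases them (assumed here to be SPTAG's metadata byte arrays) does so with delete[]; mixing malloc with delete[] is undefined behavior. A tiny sketch of the pairing rule:

#include <cstdint>
#include <cstdlib>

int main() {
    const int64_t elems = 4;

    // malloc'd memory must be released with free()
    auto a = (int64_t*)malloc(sizeof(int64_t) * elems);
    free(a);

    // new[]'d memory must be released with delete[] — this is the pairing a
    // consumer that calls delete[] requires; handing it a malloc'd pointer
    // would be undefined behavior.
    auto b = new int64_t[elems];
    delete[] b;

    return 0;
}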
@@ -1098,12 +1098,16 @@ void elkan_L2_sse (
        return (i > j) ? data[j + i * (i - 1) / 2] : data[i + j * (j - 1) / 2];
    };

#pragma omp parallel for
    for (size_t i = j0 + 1; i < j1; i++) {
        const float *y_i = y + i * d;
        for (size_t j = j0; j < i; j++) {
            const float *y_j = y + j * d;
            Y(i, j) = sqrt(fvec_L2sqr(y_i, y_j, d));
#pragma omp parallel
    {
        int nt = omp_get_num_threads();
        int rank = omp_get_thread_num();
        for (size_t i = j0 + 1 + rank; i < j1; i += nt) {
            const float *y_i = y + i * d;
            for (size_t j = j0; j < i; j++) {
                const float *y_j = y + j * d;
                Y(i, j) = fvec_L2sqr(y_i, y_j, d);
            }
        }
    }
@@ -1112,18 +1116,22 @@ void elkan_L2_sse (
        const float *x_i = x + i * d;

        int64_t ids_i = j0;
        float val_i = sqrt(fvec_L2sqr(x_i, y + j0 * d, d));
        float val_i_2 = val_i * 2;
        float val_i = fvec_L2sqr(x_i, y + j0 * d, d);
        float val_i_time_4 = val_i * 4;
        for (size_t j = j0 + 1; j < j1; j++) {
            if (val_i_2 <= Y(ids_i, j)) {
            if (val_i_time_4 <= Y(ids_i, j)) {
                continue;
            }
            const float *y_j = y + j * d;
            float disij = sqrt(fvec_L2sqr(x_i, y_j, d));
            float disij = fvec_L2sqr(x_i, y_j, d / 2);
            if (disij >= val_i) {
                continue;
            }
            disij += fvec_L2sqr(x_i + d / 2, y_j + d / 2, d - d / 2);
            if (disij < val_i) {
                ids_i = j;
                val_i = disij;
                val_i_2 = val_i * 2;
                val_i_time_4 = val_i * 4;
            }
        }
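The old code worked with true L2 distances, so Elkan's pruning used 2 * d(x, best) <= d(best, j); the new code keeps everything squared, and squaring both sides of that inequality turns the factor 2 into 4, which is why val_i_2 becomes val_i_time_4. It also splits fvec_L2sqr into two halves so a clearly losing centroid is abandoned after the first half. A minimal standalone sketch of both ideas, using a scalar L2sqr in place of faiss's vectorized fvec_L2sqr:

#include <cstdio>
#include <vector>

static float L2sqr(const float* a, const float* b, size_t d) {
    float s = 0;
    for (size_t i = 0; i < d; ++i) { float t = a[i] - b[i]; s += t * t; }
    return s;
}

// Trick 1: Elkan's bound with *squared* distances.  The triangle-inequality
// test "2*d(x,best) <= d(best,j) => skip j" becomes, after squaring,
// "4*d2(x,best) <= d2(best,j) => skip j".
bool CanSkip(float d2_x_best, float d2_best_j) {
    return 4 * d2_x_best <= d2_best_j;
}

// Trick 2: split the distance computation and bail out early if the first
// half already exceeds the current best squared distance.
float L2sqrWithEarlyExit(const float* x, const float* y, size_t d, float best) {
    float dis = L2sqr(x, y, d / 2);
    if (dis >= best) return dis;                 // cannot beat the current best
    return dis + L2sqr(x + d / 2, y + d / 2, d - d / 2);
}

int main() {
    std::vector<float> x = {0, 0, 0, 0}, y = {3, 0, 0, 0};
    printf("%d\n", CanSkip(1.0f, 9.0f));                               // 1: 4*1 <= 9, skip
    printf("%.1f\n", L2sqrWithEarlyExit(x.data(), y.data(), 4, 2.0f)); // 9.0: first half >= best, early exit
    return 0;
}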
@@ -0,0 +1,201 @@
[New file: the standard Apache License, Version 2.0 (January 2004, http://www.apache.org/licenses/) — full license text, Sections 1 through 9 plus the appendix boilerplate notice.]
@@ -317,50 +317,54 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
        return top_candidates;
    }

    void getNeighborsByHeuristic2(
    std::vector<tableint>
    getNeighborsByHeuristic2 (
            std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
            const size_t M) {
        if (top_candidates.size() < M) {
            return;
        }
        std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
        std::vector<std::pair<dist_t, tableint>> return_list;
        while (top_candidates.size() > 0) {
            queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
            top_candidates.pop();
        }
        std::vector<tableint> return_list;

        while (queue_closest.size()) {
            if (return_list.size() >= M)
                break;
            std::pair<dist_t, tableint> curent_pair = queue_closest.top();
            dist_t dist_to_query = -curent_pair.first;
            queue_closest.pop();
            bool good = true;
            for (std::pair<dist_t, tableint> second_pair : return_list) {
                dist_t curdist =
                    fstdistfunc_(getDataByInternalId(second_pair.second),
                                 getDataByInternalId(curent_pair.second),
                                 dist_func_param_);;
                if (curdist < dist_to_query) {
                    good = false;
                    break;
        if (top_candidates.size() < M) {
            return_list.resize(top_candidates.size());

            for (int i = static_cast<int>(top_candidates.size() - 1); i >= 0; i--) {
                return_list[i] = top_candidates.top().second;
                top_candidates.pop();
            }

        } else if (M > 0) {
            return_list.reserve(M);

            std::vector<std::pair<dist_t, tableint>> queue_closest;
            queue_closest.resize(top_candidates.size());
            for (int i = static_cast<int>(top_candidates.size() - 1); i >= 0; i--) {
                queue_closest[i] = top_candidates.top();
                top_candidates.pop();
            }

            for (std::pair<dist_t, tableint> &current_pair: queue_closest) {
                bool good = true;
                for (tableint id : return_list) {
                    dist_t curdist =
                        fstdistfunc_(getDataByInternalId(id),
                                     getDataByInternalId(current_pair.second),
                                     dist_func_param_);
                    if (curdist < current_pair.first) {
                        good = false;
                        break;
                    }
                }
                if (good) {
                    return_list.push_back(current_pair.second);
                    if (return_list.size() >= M) {
                        break;
                    }
                }
            }
            if (good) {
                return_list.push_back(curent_pair);
            }

        }

        for (std::pair<dist_t, tableint> curent_pair : return_list) {

            top_candidates.emplace(-curent_pair.first, curent_pair.second);
        }
        return return_list;
    }


    linklistsizeint *get_linklist0(tableint internal_id) const {
        return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
    };
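getNeighborsByHeuristic2 now returns the selected neighbor ids as a vector instead of rewriting the caller's priority queue, which is what lets mutuallyConnectNewElement (below) reuse the result directly. The selection rule itself is unchanged: scan candidates from closest to farthest and keep one only if it is closer to the query than to every neighbor already kept. A standalone sketch of that rule with 1-D points; this is illustrative only, not the library's code:

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

static float Dist(float a, float b) { float d = a - b; return d * d; }

std::vector<float> SelectNeighbors(std::vector<std::pair<float, float>> cand,  // (dist_to_query, point)
                                   size_t M) {
    std::sort(cand.begin(), cand.end());            // closest to the query first
    std::vector<float> selected;
    for (const auto& c : cand) {
        bool good = true;
        for (float s : selected) {
            if (Dist(s, c.second) < c.first) {      // dominated by an already kept neighbor
                good = false;
                break;
            }
        }
        if (good) {
            selected.push_back(c.second);
            if (selected.size() >= M) break;        // keep at most M neighbors
        }
    }
    return selected;                                // closest survivor first, like the new code
}

int main() {
    // query at 0.0; candidates at 1.0, 1.1 and -5.0
    auto out = SelectNeighbors({{1.0f, 1.0f}, {1.21f, 1.1f}, {25.0f, -5.0f}}, 2);
    for (float p : out) printf("%.1f ", p);  // 1.0 -5.0 — 1.1 is pruned, it is closer to 1.0 than to the query
    printf("\n");
    return 0;
}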
@@ -373,21 +377,17 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
        return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);
    };

    void mutuallyConnectNewElement(const void *data_point, tableint cur_c,
                                   std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
                                   int level) {
    tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c,
                                       std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
                                       int level) {

        size_t Mcurmax = level ? maxM_ : maxM0_;
        getNeighborsByHeuristic2(top_candidates, M_);
        if (top_candidates.size() > M_)

        std::vector<tableint> selectedNeighbors(getNeighborsByHeuristic2(top_candidates, M_));
        if (selectedNeighbors.size() > M_)
            throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");

        std::vector<tableint> selectedNeighbors;
        selectedNeighbors.reserve(M_);
        while (top_candidates.size() > 0) {
            selectedNeighbors.push_back(top_candidates.top().second);
            top_candidates.pop();
        }
        tableint next_closest_entry_point = selectedNeighbors.front();

        {
            linklistsizeint *ll_cur;
@@ -451,15 +451,11 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
                                           dist_func_param_), data[j]);
                    }

                    getNeighborsByHeuristic2(candidates, Mcurmax);

                    int indx = 0;
                    while (candidates.size() > 0) {
                        data[indx] = candidates.top().second;
                        candidates.pop();
                        indx++;
                    std::vector<tableint> selected(getNeighborsByHeuristic2(candidates, Mcurmax));
                    setListCount(ll_other, static_cast<unsigned short int>(selected.size()));
                    for (size_t idx = 0; idx < selected.size(); idx++) {
                        data[idx] = selected[idx];
                    }
                    setListCount(ll_other, indx);
                    // Nearest K:
                    /*int indx = -1;
                    for (int j = 0; j < sz_link_list_other; j++) {
@@ -475,6 +471,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
            }

        }

        return next_closest_entry_point;
    }

    std::mutex global;
@@ -499,17 +497,18 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
        std::vector<std::mutex>(new_max_elements).swap(link_list_locks_);

        // Reallocate base layer
        data_level0_memory_ = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_);
        if (data_level0_memory_ == nullptr)
        char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_);
        if (data_level0_memory_new == nullptr)
            throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
        data_level0_memory_ = data_level0_memory_new;

        // Reallocate all other layers
        linkLists_ = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements);
        if (linkLists_ == nullptr)
        char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements);
        if (linkLists_new == nullptr)
            throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
        linkLists_ = linkLists_new;

        max_elements_=new_max_elements;
        max_elements_ = new_max_elements;

    }
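The resizeIndex change fixes a classic realloc pitfall: assigning realloc's return value straight into the only pointer to the block means that on failure the original allocation is lost, and the subsequent throw leaves it unreachable. A minimal sketch of the pattern the new code adopts; grow and buf are illustrative names, not part of hnswlib:

#include <cstddef>
#include <cstdlib>
#include <stdexcept>

void grow(char*& buf, std::size_t new_bytes) {
    char* tmp = static_cast<char*>(std::realloc(buf, new_bytes));
    if (tmp == nullptr) {
        // buf is still valid here, so the caller can keep using or free it;
        // the old code overwrote buf before checking and leaked the block.
        throw std::runtime_error("Not enough memory");
    }
    buf = tmp;  // only commit the pointer after success
}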
@@ -814,9 +813,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
        }

        std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
        int curlevel = getRandomLevel(mult_);
        if (level > 0)
            curlevel = level;
        int curlevel = (level > 0) ? level : getRandomLevel(mult_);

        element_levels_[cur_c] = curlevel;
@@ -881,9 +878,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
                std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
                    currObj, data_point, level);

                mutuallyConnectNewElement(data_point, cur_c, top_candidates, level);

                currObj = top_candidates.top().second;
                currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level);
            }
        } else {
            // Do nothing for the first element
@@ -956,24 +951,6 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
        return result;
    };

    template <typename Comp>
    std::vector<std::pair<dist_t, labeltype>>
    searchKnn(const void* query_data, size_t k, Comp comp, const faiss::BitsetView bitset, StatisticsInfo &stats) {
        std::vector<std::pair<dist_t, labeltype>> result;
        if (cur_element_count == 0) return result;

        auto ret = searchKnn(query_data, k, bitset, stats);

        while (!ret.empty()) {
            result.push_back(ret.top());
            ret.pop();
        }

        std::sort(result.begin(), result.end(), comp);

        return result;
    }

    int64_t cal_size() {
        int64_t ret = 0;
        ret += sizeof(*this);
File diff suppressed because it is too large
@@ -92,10 +92,10 @@ namespace hnswlib {
class AlgorithmInterface {
 public:
    virtual void addPoint(const void *datapoint, labeltype label)=0;
    virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t, const faiss::BitsetView bitset, hnswlib::StatisticsInfo &stats) const = 0;
    template <typename Comp>
    std::vector<std::pair<dist_t, labeltype>> searchKnn(const void*, size_t, Comp, const faiss::BitsetView bitset, hnswlib::StatisticsInfo &stats) {
    }

    virtual std::priority_queue<std::pair<dist_t, labeltype >>
    searchKnn(const void *, size_t, const faiss::BitsetView bitset, hnswlib::StatisticsInfo &stats) const = 0;

    virtual void saveIndex(const std::string &location)=0;
    virtual ~AlgorithmInterface(){
    }
@@ -1,99 +0,0 @@
#pragma once
#ifndef NO_MANUAL_VECTORIZATION
#ifdef __SSE__
#define USE_SSE
#ifdef __AVX__
#define USE_AVX
#endif
#endif
#endif

#if defined(USE_AVX) || defined(USE_SSE)
#ifdef _MSC_VER
#include <intrin.h>
#include <stdexcept>
#else
#include <x86intrin.h>
#endif

#if defined(__GNUC__)
#define PORTABLE_ALIGN32 __attribute__((aligned(32)))
#else
#define PORTABLE_ALIGN32 __declspec(align(32))
#endif
#endif

#include <fstream>
#include <queue>
#include <vector>

#include <string.h>
#include <faiss/utils/ConcurrentBitset.h>
#include <faiss/utils/BitsetView.h>

namespace hnswlib_nm {
typedef int64_t labeltype;

template <typename T>
class pairGreater {
 public:
    bool operator()(const T& p1, const T& p2) {
        return p1.first > p2.first;
    }
};

template<typename T>
static void writeBinaryPOD(std::ostream &out, const T &podRef) {
    out.write((char *) &podRef, sizeof(T));
}

template<typename T>
static void readBinaryPOD(std::istream &in, T &podRef) {
    in.read((char *) &podRef, sizeof(T));
}

template<typename T, typename W>
static void writeBinaryPOD(W &out, const T &podRef) {
    out.write((char *) &podRef, sizeof(T));
}

template<typename T, typename R>
static void readBinaryPOD(R &in, T &podRef) {
    in.read((char *) &podRef, sizeof(T));
}

template<typename MTYPE>
using DISTFUNC = MTYPE(*)(const void *, const void *, const void *);


template<typename MTYPE>
class SpaceInterface {
 public:
    //virtual void search(void *);
    virtual size_t get_data_size() = 0;

    virtual DISTFUNC<MTYPE> get_dist_func() = 0;

    virtual void *get_dist_func_param() = 0;

    virtual ~SpaceInterface() {}
};

template<typename dist_t>
class AlgorithmInterface {
 public:
    virtual void addPoint(void *datapoint, labeltype label, size_t base, size_t offset)=0;
    virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn_NM(const void *, size_t, const faiss::BitsetView bitset, dist_t *pdata) const = 0;
    template <typename Comp>
    std::vector<std::pair<dist_t, labeltype>> searchKnn_NM(const void*, size_t, Comp, const faiss::BitsetView bitset, dist_t *pdata) {
    }
    virtual void saveIndex(const std::string &location)=0;
    virtual ~AlgorithmInterface(){
    }
};
}

#include "space_l2.h"
#include "space_ip.h"
#include "bruteforce.h"
#include "hnswalg_nm.h"
@@ -80,6 +80,20 @@ TEST_P(HNSWTest, HNSW_basic) {
    auto result = index_->Query(query_dataset, conf, nullptr);
    AssertAnns(result, nq, k);
    ReleaseQueryResult(result);

    // case: k > nb
    const int64_t new_rows = 6;
    base_dataset->Set(milvus::knowhere::meta::ROWS, new_rows);
    index_->Train(base_dataset, conf);
    index_->AddWithoutIds(base_dataset, conf);
    auto result2 = index_->Query(query_dataset, conf, nullptr);
    auto res_ids = result2->Get<int64_t*>(milvus::knowhere::meta::IDS);
    for (int64_t i = 0; i < nq; i++) {
        for (int64_t j = new_rows; j < k; j++) {
            ASSERT_EQ(res_ids[i * k + j], -1);
        }
    }
    ReleaseQueryResult(result2);
}

TEST_P(HNSWTest, HNSW_delete) {