mirror of https://github.com/milvus-io/milvus.git
* fix Signed-off-by: Nicky <nicky.xj.lin@gmail.com> * update... Signed-off-by: Nicky <nicky.xj.lin@gmail.com> * fix2 Signed-off-by: Nicky <nicky.xj.lin@gmail.com> * fix3 Signed-off-by: Nicky <nicky.xj.lin@gmail.com> * update changelog Signed-off-by: Nicky <nicky.xj.lin@gmail.com>pull/1206/head
parent
f10f6cd5f4
commit
191a8c9941
|
@ -34,6 +34,7 @@ Please mark all change in change log and use the issue from GitHub
|
|||
- \#813 - Add push mode for prometheus monitor
|
||||
- \#815 - Support MinIO storage
|
||||
- \#823 - Support binary vector tanimoto/jaccard/hamming metric
|
||||
- \#853 - Support HNSW
|
||||
- \#910 - Change Milvus c++ standard to c++17
|
||||
|
||||
## Improvement
|
||||
|
|
|
@ -15,20 +15,21 @@
|
|||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/IndexHNSW.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <iterator>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexHNSW.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
#include "hnswlib/hnswalg.h"
|
||||
#include "hnswlib/space_ip.h"
|
||||
#include "hnswlib/space_l2.h"
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
|
@ -71,39 +72,37 @@ IndexHNSW::Load(const BinarySet& index_binary) {
|
|||
|
||||
// Runs a top-k nearest-neighbour query against the HNSW graph for every query
// vector in `dataset` and returns the ids and distances as two flat arrays.
//
// @param dataset  Query batch. The GETTENSOR macro is expected to expose it as
//                 `rows`, `dim` and `p_data` (presumably row count, vector
//                 dimension and a contiguous float buffer — confirm against
//                 the macro definition).
// @param config   Search configuration; `config->k` is the number of
//                 neighbours requested per query. If it is an HNSWCfg it is
//                 validated first.
// @return A Dataset holding meta::IDS (int64_t[rows * k]) and meta::DISTANCE
//         (float[rows * k]); row i occupies slots [i * k, (i + 1) * k).
//         Slots for which the index returned no neighbour hold id -1 and
//         distance -1.
// @throws Whatever CheckValid() throws on an invalid config, and a knowhere
//         exception when the index is missing or the result buffers cannot
//         be allocated.
//
// NOTE(review): the result buffers are malloc'ed here and handed to the
// returned Dataset as raw pointers; whoever consumes the Dataset presumably
// frees them — confirm against the callers.
DatasetPtr
IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) {
    auto search_cfg = std::dynamic_pointer_cast<HNSWCfg>(config);
    if (search_cfg != nullptr) {
        search_cfg->CheckValid();  // throw exception
    }

    if (!index_) {
        KNOWHERE_THROW_MSG("index not initialize or trained");
    }

    GETTENSOR(dataset)

    const size_t k = static_cast<size_t>(config->k);
    const int64_t row_count = static_cast<int64_t>(rows);
    const size_t total = k * static_cast<size_t>(row_count);

    auto p_id = static_cast<int64_t*>(malloc(sizeof(int64_t) * total));
    auto p_dist = static_cast<float*>(malloc(sizeof(float) * total));
    // malloc(0) may legitimately return nullptr, so only treat a null result
    // as a failure when something was actually requested. The check happens
    // before the parallel region because throwing out of an OpenMP loop is
    // not allowed.
    if (total != 0 && (p_id == nullptr || p_dist == nullptr)) {
        free(p_id);
        free(p_dist);
        KNOWHERE_THROW_MSG("failed to allocate search result buffers");
    }

    using P = std::pair<float, int64_t>;
    // Order candidates by distance (the first member of the pair).
    auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; };

    // Each iteration writes only to its own k-wide slice of the output
    // buffers, so iterations do not touch each other's results.
#pragma omp parallel for
    for (int64_t i = 0; i < row_count; ++i) {
        const float* single_query = p_data + i * dim;
        std::vector<P> ret = index_->searchKnn(single_query, k, compare);

        // The index can return fewer than k neighbours (e.g. when it holds
        // fewer than k vectors); never read past what was actually returned.
        const size_t found = std::min(ret.size(), k);
        int64_t* id_row = p_id + static_cast<size_t>(i) * k;
        float* dist_row = p_dist + static_cast<size_t>(i) * k;

        for (size_t j = 0; j < found; ++j) {
            dist_row[j] = ret[j].first;
            id_row[j] = ret[j].second;
        }
        // Mark the unused tail so callers never see uninitialised memory.
        for (size_t j = found; j < k; ++j) {
            dist_row[j] = -1;
            id_row[j] = -1;
        }
    }

    auto ret_ds = std::make_shared<Dataset>();
    ret_ds->Set(meta::IDS, p_id);
    ret_ds->Set(meta::DISTANCE, p_dist);
    return ret_ds;
}
|
||||
|
||||
IndexModelPtr
|
||||
|
|
|
@ -40,6 +40,7 @@ enum class IndexType {
|
|||
IVFPQ = 6,
|
||||
SPTAGKDT = 7,
|
||||
SPTAGBKT = 8,
|
||||
HNSW = 11,
|
||||
};
|
||||
|
||||
enum class MetricType {
|
||||
|
|
Loading…
Reference in New Issue