mirror of https://github.com/milvus-io/milvus.git
Merge branch 'update_unittest' into 'branch-0.4.0'
MS-538 1. update kdt unittest See merge request megasearch/milvus!539 Former-commit-id: eca8fb73f491f73ddc2775a26d0b03db9dc92fe2pull/191/head
commit
941e4dc83c
|
@ -1,26 +1,26 @@
|
|||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include "preprocessor.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
class NormalizePreprocessor : public Preprocessor {
|
||||
public:
|
||||
DatasetPtr
|
||||
Preprocess(const DatasetPtr &input) override;
|
||||
|
||||
private:
|
||||
|
||||
void
|
||||
Normalize(float *arr, int64_t dimension);
|
||||
};
|
||||
|
||||
|
||||
using NormalizePreprocessorPtr = std::shared_ptr<NormalizePreprocessor>;
|
||||
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
//#pragma once
|
||||
//
|
||||
//#include <memory>
|
||||
//#include "preprocessor.h"
|
||||
//
|
||||
//
|
||||
//namespace zilliz {
|
||||
//namespace knowhere {
|
||||
//
|
||||
//class NormalizePreprocessor : public Preprocessor {
|
||||
// public:
|
||||
// DatasetPtr
|
||||
// Preprocess(const DatasetPtr &input) override;
|
||||
//
|
||||
// private:
|
||||
//
|
||||
// void
|
||||
// Normalize(float *arr, int64_t dimension);
|
||||
//};
|
||||
//
|
||||
//
|
||||
//using NormalizePreprocessorPtr = std::shared_ptr<NormalizePreprocessor>;
|
||||
//
|
||||
//
|
||||
//} // namespace knowhere
|
||||
//} // namespace zilliz
|
||||
|
|
|
@ -27,8 +27,8 @@ class CPUKDTRNG : public VectorIndex {
|
|||
Load(const BinarySet &index_array) override;
|
||||
|
||||
public:
|
||||
PreprocessorPtr
|
||||
BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override;
|
||||
//PreprocessorPtr
|
||||
//BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override;
|
||||
int64_t Count() override;
|
||||
int64_t Dimension() override;
|
||||
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
|
||||
#include "knowhere/index/vector_index/definitions.h"
|
||||
#include "knowhere/common/config.h"
|
||||
#include "knowhere/index/preprocessor/normalize.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace knowhere {
|
||||
|
||||
DatasetPtr
|
||||
NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) {
|
||||
//
|
||||
//#include "knowhere/index/vector_index/definitions.h"
|
||||
//#include "knowhere/common/config.h"
|
||||
//#include "knowhere/index/preprocessor/normalize.h"
|
||||
//
|
||||
//
|
||||
//namespace zilliz {
|
||||
//namespace knowhere {
|
||||
//
|
||||
//DatasetPtr
|
||||
//NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) {
|
||||
// // TODO: wrap dataset->tensor
|
||||
// auto tensor = dataset->tensor()[0];
|
||||
// auto p_data = (float *)tensor->raw_mutable_data();
|
||||
|
@ -19,24 +19,24 @@ NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) {
|
|||
// for (auto i = 0; i < rows; ++i) {
|
||||
// Normalize(&(p_data[i * dimension]), dimension);
|
||||
// }
|
||||
}
|
||||
|
||||
void
|
||||
NormalizePreprocessor::Normalize(float *arr, int64_t dimension) {
|
||||
//double vector_length = 0;
|
||||
//for (auto j = 0; j < dimension; j++) {
|
||||
// double val = arr[j];
|
||||
// vector_length += val * val;
|
||||
//}
|
||||
//vector_length = std::sqrt(vector_length);
|
||||
//if (vector_length < 1e-6) {
|
||||
// auto val = (float) (1.0 / std::sqrt((double) dimension));
|
||||
// for (int j = 0; j < dimension; j++) arr[j] = val;
|
||||
//} else {
|
||||
// for (int j = 0; j < dimension; j++) arr[j] = (float) (arr[j] / vector_length);
|
||||
//}
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace zilliz
|
||||
//}
|
||||
//
|
||||
//void
|
||||
//NormalizePreprocessor::Normalize(float *arr, int64_t dimension) {
|
||||
// double vector_length = 0;
|
||||
// for (auto j = 0; j < dimension; j++) {
|
||||
// double val = arr[j];
|
||||
// vector_length += val * val;
|
||||
// }
|
||||
// vector_length = std::sqrt(vector_length);
|
||||
// if (vector_length < 1e-6) {
|
||||
// auto val = (float) (1.0 / std::sqrt((double) dimension));
|
||||
// for (int j = 0; j < dimension; j++) arr[j] = val;
|
||||
// } else {
|
||||
// for (int j = 0; j < dimension; j++) arr[j] = (float) (arr[j] / vector_length);
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//} // namespace knowhere
|
||||
//} // namespace zilliz
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
#include "knowhere/index/vector_index/cpu_kdt_rng.h"
|
||||
#include "knowhere/index/vector_index/definitions.h"
|
||||
#include "knowhere/index/preprocessor/normalize.h"
|
||||
//#include "knowhere/index/preprocessor/normalize.h"
|
||||
#include "knowhere/index/vector_index/kdt_parameters.h"
|
||||
#include "knowhere/adapter/sptag.h"
|
||||
#include "knowhere/common/exception.h"
|
||||
|
@ -60,10 +60,10 @@ CPUKDTRNG::Load(const BinarySet &binary_set) {
|
|||
index_ptr_->LoadIndexFromMemory(index_blobs);
|
||||
}
|
||||
|
||||
PreprocessorPtr
|
||||
CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) {
|
||||
return std::make_shared<NormalizePreprocessor>();
|
||||
}
|
||||
//PreprocessorPtr
|
||||
//CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) {
|
||||
// return std::make_shared<NormalizePreprocessor>();
|
||||
//}
|
||||
|
||||
IndexModelPtr
|
||||
CPUKDTRNG::Train(const DatasetPtr &origin, const Config &train_config) {
|
||||
|
@ -72,7 +72,7 @@ CPUKDTRNG::Train(const DatasetPtr &origin, const Config &train_config) {
|
|||
|
||||
//if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
|
||||
// && preprocessor_) {
|
||||
preprocessor_->Preprocess(dataset);
|
||||
// preprocessor_->Preprocess(dataset);
|
||||
//}
|
||||
|
||||
auto vectorset = ConvertToVectorSet(dataset);
|
||||
|
@ -90,7 +90,7 @@ CPUKDTRNG::Add(const DatasetPtr &origin, const Config &add_config) {
|
|||
|
||||
//if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
|
||||
// && preprocessor_) {
|
||||
preprocessor_->Preprocess(dataset);
|
||||
// preprocessor_->Preprocess(dataset);
|
||||
//}
|
||||
|
||||
auto vectorset = ConvertToVectorSet(dataset);
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "knowhere/common/exception.h"
|
||||
|
||||
#include "knowhere/index/vector_index/cpu_kdt_rng.h"
|
||||
#include "knowhere/index/vector_index/definitions.h"
|
||||
|
@ -125,6 +126,10 @@ TEST_P(KDTTest, kdt_serialize) {
|
|||
auto result = new_index->Search(query_dataset, search_cfg);
|
||||
AssertAnns(result, nq, k);
|
||||
PrintResult(result, nq, k);
|
||||
ASSERT_EQ(new_index->Count(), nb);
|
||||
ASSERT_EQ(new_index->Dimension(), dim);
|
||||
ASSERT_THROW({new_index->Clone();}, zilliz::knowhere::KnowhereException);
|
||||
ASSERT_NO_THROW({new_index->Seal();});
|
||||
|
||||
{
|
||||
int fileno = 0;
|
||||
|
|
Loading…
Reference in New Issue