mirror of https://github.com/milvus-io/milvus.git
Tanimoto distance (#1016)
* Add log to debug #678 * Rename nsg_mix to RNSG in C++ sdk #735 * [skip ci] change __function__ * clang-format * #766 If partition tag is similar, wrong partition is searched * #766 If partition tag is similar, wrong partition is searched * reorder changelog id * typo * define interface * Define interface (#832) * If partition tag is similar, wrong partition is searched (#825) * #766 If partition tag is similar, wrong partition is searched * #766 If partition tag is similar, wrong partition is searched * reorder changelog id * typo * define interface Co-authored-by: groot <yihua.mo@zilliz.com> * faiss & knowhere * faiss & knowhere (#842) * Add log to debug #678 * Rename nsg_mix to RNSG in C++ sdk #735 * [skip ci] change __function__ * clang-format * If partition tag is similar, wrong partition is searched (#825) * #766 If partition tag is similar, wrong partition is searched * #766 If partition tag is similar, wrong partition is searched * reorder changelog id * typo * faiss & knowhere Co-authored-by: groot <yihua.mo@zilliz.com> * support binary input * code lint * add wrapper interface * add knowhere unittest * sdk support binary * support using metric tanimoto and hamming * sdk binary insert/query example * fix bug * fix bug * update wrapper * format * Improve unittest and fix bugs * delete printresult * fix bug * #823 Support binary vector tanimoto metric * fix typo * dimension limit to 32768 * fix * dimension limit to 32768 * fix describe index bug * fix #886 * fix #889 * add jaccard cases * hamming dev-test case * change test_connect * Add tanimoto cases * change the output type of hamming * add abs * merge master * rearrange changelog id * modify feature description Co-authored-by: Yukikaze-CZR <48198922+Yukikaze-CZR@users.noreply.github.com> Co-authored-by: Tinkerrr <linxiaojun.cn@outlook.com>
parent 297e7e8831
commit 0f1aa5f8bb
@@ -16,12 +16,13 @@ Please mark all change in change log and use the issue from GitHub
- \#216 - Add CLI to get server info
- \#343 - Add Opentracing
- \#665 - Support get/set config via CLI
- \#759 - Put C++ sdk out of milvus/core
- \#766 - If partition tag is similar, wrong partition is searched
- \#771 - Add server build commit info interface
- \#759 - Put C++ sdk out of milvus/core
- \#788 - Add web server into server module
- \#813 - Add push mode for prometheus monitor
- \#815 - Support MinIO storage
- \#823 - Support binary vector tanimoto/jaccard/hamming metric
- \#910 - Change Milvus c++ standard to c++17

## Improvement
@@ -29,7 +29,7 @@ constexpr uint64_t T = K * G;

constexpr uint64_t MAX_TABLE_FILE_MEM = 128 * M;

constexpr int VECTOR_TYPE_SIZE = sizeof(float);
constexpr int FLOAT_TYPE_SIZE = sizeof(float);

static constexpr uint64_t ONE_KB = K;
static constexpr uint64_t ONE_MB = ONE_KB * ONE_KB;
@@ -84,25 +84,22 @@ class DB {
ShowPartitions(const std::string& table_id, std::vector<meta::TableSchema>& partition_schema_array) = 0;

virtual Status
InsertVectors(const std::string& table_id, const std::string& partition_tag, uint64_t n, const float* vectors,
IDNumbers& vector_ids_) = 0;
InsertVectors(const std::string& table_id, const std::string& partition_tag, VectorsData& vectors) = 0;

virtual Status
Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, ResultIds& result_ids, ResultDistances& result_distances) = 0;
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) = 0;

virtual Status
Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) = 0;
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) = 0;

virtual Status
QueryByFileID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) = 0;
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) = 0;

virtual Status
Size(uint64_t& result) = 0;
@@ -317,8 +317,7 @@ DBImpl::ShowPartitions(const std::string& table_id, std::vector<meta::TableSchem
}

Status
DBImpl::InsertVectors(const std::string& table_id, const std::string& partition_tag, uint64_t n, const float* vectors,
IDNumbers& vector_ids) {
DBImpl::InsertVectors(const std::string& table_id, const std::string& partition_tag, VectorsData& vectors) {
// ENGINE_LOG_DEBUG << "Insert " << n << " vectors to cache";
if (!initialized_.load(std::memory_order_acquire)) {
return SHUTDOWN_ERROR;

@@ -337,8 +336,8 @@ DBImpl::InsertVectors(const std::string& table_id, const std::string& partition_
}

// insert vectors into target table
milvus::server::CollectInsertMetrics metrics(n, status);
status = mem_mgr_->InsertVectors(target_table_name, n, vectors, vector_ids);
milvus::server::CollectInsertMetrics metrics(vectors.vector_count_, status);
status = mem_mgr_->InsertVectors(target_table_name, vectors);

return status;
}
@@ -407,23 +406,21 @@ DBImpl::DropIndex(const std::string& table_id) {

Status
DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, ResultIds& result_ids, ResultDistances& result_distances) {
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) {
if (!initialized_.load(std::memory_order_acquire)) {
return SHUTDOWN_ERROR;
}

meta::DatesT dates = {utils::GetDate()};
Status result =
Query(context, table_id, partition_tags, k, nq, nprobe, vectors, dates, result_ids, result_distances);
Status result = Query(context, table_id, partition_tags, k, nprobe, vectors, dates, result_ids, result_distances);
return result;
}

Status
DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) {
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) {
auto query_ctx = context->Child("Query");

if (!initialized_.load(std::memory_order_acquire)) {

@@ -460,7 +457,7 @@ DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string
}

cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
status = QueryAsync(query_ctx, table_id, files_array, k, nq, nprobe, vectors, result_ids, result_distances);
status = QueryAsync(query_ctx, table_id, files_array, k, nprobe, vectors, result_ids, result_distances);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query

query_ctx->GetTraceContext()->GetSpan()->Finish();

@@ -470,9 +467,8 @@ DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string

Status
DBImpl::QueryByFileID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) {
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) {
auto query_ctx = context->Child("Query by file id");

if (!initialized_.load(std::memory_order_acquire)) {

@@ -501,7 +497,7 @@ DBImpl::QueryByFileID(const std::shared_ptr<server::Context>& context, const std
}

cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
status = QueryAsync(query_ctx, table_id, files_array, k, nq, nprobe, vectors, result_ids, result_distances);
status = QueryAsync(query_ctx, table_id, files_array, k, nprobe, vectors, result_ids, result_distances);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query

query_ctx->GetTraceContext()->GetSpan()->Finish();

@@ -523,11 +519,11 @@ DBImpl::Size(uint64_t& result) {
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Status
DBImpl::QueryAsync(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const meta::TableFilesSchema& files, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::TableFilesSchema& files, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) {
auto query_async_ctx = context->Child("Query Async");

server::CollectQueryMetrics metrics(nq);
server::CollectQueryMetrics metrics(vectors.vector_count_);

TimeRecorder rc("");

@@ -535,7 +531,7 @@ DBImpl::QueryAsync(const std::shared_ptr<server::Context>& context, const std::s
auto status = ongoing_files_checker_.MarkOngoingFiles(files);

ENGINE_LOG_DEBUG << "Engine query begin, index file count: " << files.size();
scheduler::SearchJobPtr job = std::make_shared<scheduler::SearchJob>(query_async_ctx, k, nq, nprobe, vectors);
scheduler::SearchJobPtr job = std::make_shared<scheduler::SearchJob>(query_async_ctx, k, nprobe, vectors);
for (auto& file : files) {
scheduler::TableFileSchemaPtr file_ptr = std::make_shared<meta::TableFileSchema>(file);
job->AddIndexFile(file_ptr);
@@ -92,8 +92,7 @@ class DBImpl : public DB {
ShowPartitions(const std::string& table_id, std::vector<meta::TableSchema>& partition_schema_array) override;

Status
InsertVectors(const std::string& table_id, const std::string& partition_tag, uint64_t n, const float* vectors,
IDNumbers& vector_ids) override;
InsertVectors(const std::string& table_id, const std::string& partition_tag, VectorsData& vectors) override;

Status
CreateIndex(const std::string& table_id, const TableIndex& index) override;

@@ -106,20 +105,18 @@ class DBImpl : public DB {

Status
Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, ResultIds& result_ids, ResultDistances& result_distances) override;
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) override;

Status
Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) override;
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) override;

Status
QueryByFileID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) override;
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) override;

Status
Size(uint64_t& result) override;

@@ -127,7 +124,7 @@ class DBImpl : public DB {
private:
Status
QueryAsync(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const meta::TableFilesSchema& files, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::TableFilesSchema& files, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances);

void
@@ -43,6 +43,13 @@ struct TableIndex {
int32_t metric_type_ = (int)MetricType::L2;
};

struct VectorsData {
uint64_t vector_count_ = 0;
std::vector<float> float_data_;
std::vector<uint8_t> binary_data_;
IDNumbers id_array_;
};

using File2RefCount = std::map<std::string, int64_t>;
using Table2Files = std::map<std::string, File2RefCount>;
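A minimal usage sketch (not part of this diff, assuming the include paths, helper name, and table name shown): how a caller might fill the new VectorsData struct for a binary-metric table and hand it to the reworked DB::InsertVectors.

#include "db/DB.h"     // assumed include paths for DB, Status and VectorsData
#include "db/Types.h"

// Hypothetical helper: insert `count` binary vectors of dimension `dim`, letting the engine assign IDs.
milvus::Status
InsertBinaryExample(milvus::engine::DB& db, const std::string& table_id, uint64_t count, uint16_t dim) {
    milvus::engine::VectorsData vectors;
    vectors.vector_count_ = count;
    vectors.binary_data_.resize(count * dim / 8);  // one bit per dimension, packed into bytes
    // ... fill vectors.binary_data_ with real packed bits here ...
    // float_data_ stays empty for binary metrics; id_array_ is populated by InsertVectors on success.
    return db.InsertVectors(table_id, "", vectors);
}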
@@ -37,12 +37,18 @@ enum class EngineType {
FAISS_PQ,
SPTAG_KDT,
SPTAG_BKT,
MAX_VALUE = SPTAG_BKT,
FAISS_BIN_IDMAP,
FAISS_BIN_IVFFLAT,
MAX_VALUE = FAISS_BIN_IVFFLAT,
};

enum class MetricType {
L2 = 1,
IP = 2,
L2 = 1, // Euclidean Distance
IP = 2, // Cosine Similarity
HAMMING = 3, // Hamming Distance
JACCARD = 4, // Jaccard Distance
TANIMOTO = 5, // Tanimoto Distance
MAX_VALUE = TANIMOTO,
};

class ExecutionEngine {
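Background on the three new binary metrics (general definitions, not taken from this diff): for bit vectors $a$ and $b$ with $n_a$ and $n_b$ set bits and $c$ bits set in both,

$$ d_{\mathrm{Hamming}}(a,b) = \lvert\{\, i : a_i \neq b_i \,\}\rvert, \qquad s_{\mathrm{Jaccard}}(a,b) = \frac{c}{n_a + n_b - c}, \qquad d_{\mathrm{Jaccard}}(a,b) = 1 - s_{\mathrm{Jaccard}}(a,b), $$

and for binary data the Tanimoto coefficient coincides with the Jaccard coefficient. The exact distance convention applied downstream by knowhere/faiss is not shown in this diff.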
@@ -50,6 +56,9 @@ class ExecutionEngine {
virtual Status
AddWithIds(int64_t n, const float* xdata, const int64_t* xids) = 0;

virtual Status
AddWithIds(int64_t n, const uint8_t* xdata, const int64_t* xids) = 0;

virtual size_t
Count() const = 0;

@@ -86,6 +95,10 @@ class ExecutionEngine {
virtual Status
Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels, bool hybrid) = 0;

virtual Status
Search(int64_t n, const uint8_t* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
bool hybrid) = 0;

virtual std::shared_ptr<ExecutionEngine>
BuildIndex(const std::string& location, EngineType engine_type) = 0;
@@ -25,7 +25,9 @@
#include "utils/CommonUtil.h"
#include "utils/Exception.h"
#include "utils/Log.h"
#include "utils/ValidationUtil.h"

#include "wrapper/BinVecImpl.h"
#include "wrapper/ConfAdapter.h"
#include "wrapper/ConfAdapterMgr.h"
#include "wrapper/VecImpl.h"
@@ -39,6 +41,40 @@
namespace milvus {
namespace engine {

namespace {

Status
MappingMetricType(MetricType metric_type, knowhere::METRICTYPE& kw_type) {
switch (metric_type) {
case MetricType::IP:
kw_type = knowhere::METRICTYPE::IP;
break;
case MetricType::L2:
kw_type = knowhere::METRICTYPE::L2;
break;
case MetricType::HAMMING:
kw_type = knowhere::METRICTYPE::HAMMING;
break;
case MetricType::JACCARD:
kw_type = knowhere::METRICTYPE::JACCARD;
break;
case MetricType::TANIMOTO:
kw_type = knowhere::METRICTYPE::TANIMOTO;
break;
default:
return Status(DB_ERROR, "Unsupported metric type");
}

return Status::OK();
}

bool
IsBinaryIndexType(IndexType type) {
return type == IndexType::FAISS_BIN_IDMAP || type == IndexType::FAISS_BIN_IVFLAT_CPU;
}

} // namespace

class CachedQuantizer : public cache::DataObj {
public:
explicit CachedQuantizer(knowhere::QuantizerPtr data) : data_(std::move(data)) {
@@ -61,7 +97,10 @@ class CachedQuantizer : public cache::DataObj {
ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string& location, EngineType index_type,
MetricType metric_type, int32_t nlist)
: location_(location), dim_(dimension), index_type_(index_type), metric_type_(metric_type), nlist_(nlist) {
index_ = CreatetVecIndex(EngineType::FAISS_IDMAP);
EngineType tmp_index_type = server::ValidationUtil::IsBinaryMetricType((int32_t)metric_type)
? EngineType::FAISS_BIN_IDMAP
: EngineType::FAISS_IDMAP;
index_ = CreatetVecIndex(tmp_index_type);
if (!index_) {
throw Exception(DB_ERROR, "Unsupported index type");
}

@@ -69,11 +108,20 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string&
TempMetaConf temp_conf;
temp_conf.gpu_id = gpu_num_;
temp_conf.dim = dimension;
temp_conf.metric_type = (metric_type_ == MetricType::IP) ? knowhere::METRICTYPE::IP : knowhere::METRICTYPE::L2;
auto status = MappingMetricType(metric_type, temp_conf.metric_type);
if (!status.ok()) {
throw Exception(DB_ERROR, status.message());
}

auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->Match(temp_conf);

auto ec = std::static_pointer_cast<BFIndex>(index_)->Build(conf);
ErrorCode ec = KNOWHERE_UNEXPECTED_ERROR;
if (auto bf_index = std::dynamic_pointer_cast<BFIndex>(index_)) {
ec = bf_index->Build(conf);
} else if (auto bf_bin_index = std::dynamic_pointer_cast<BinBFIndex>(index_)) {
ec = bf_bin_index->Build(conf);
}
if (ec != KNOWHERE_SUCCESS) {
throw Exception(DB_ERROR, "Build index error");
}
@@ -148,6 +196,14 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
index = GetVecIndexFactory(IndexType::SPTAG_BKT_RNT_CPU);
break;
}
case EngineType::FAISS_BIN_IDMAP: {
index = GetVecIndexFactory(IndexType::FAISS_BIN_IDMAP);
break;
}
case EngineType::FAISS_BIN_IVFFLAT: {
index = GetVecIndexFactory(IndexType::FAISS_BIN_IVFLAT_CPU);
break;
}
default: {
ENGINE_LOG_ERROR << "Unsupported index type";
return nullptr;
@@ -242,6 +298,12 @@ ExecutionEngineImpl::AddWithIds(int64_t n, const float* xdata, const int64_t* xi
return status;
}

Status
ExecutionEngineImpl::AddWithIds(int64_t n, const uint8_t* xdata, const int64_t* xids) {
auto status = index_->Add(n, xdata, xids);
return status;
}

size_t
ExecutionEngineImpl::Count() const {
if (index_ == nullptr) {
@@ -253,7 +315,11 @@ ExecutionEngineImpl::Count() const {

size_t
ExecutionEngineImpl::Size() const {
return (size_t)(Count() * Dimension()) * sizeof(float);
if (IsBinaryIndexType(index_->GetType())) {
return (size_t)(Count() * Dimension() / 8);
} else {
return (size_t)(Count() * Dimension()) * sizeof(float);
}
}

size_t
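For a rough sense of the new Size() arithmetic (illustrative numbers, not from the diff): 1,000,000 vectors of dimension 512 amount to 1,000,000 * 512 / 8 = 64 MB under a binary index, versus 1,000,000 * 512 * 4 bytes, roughly 2 GB, when stored as floats.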
@@ -483,6 +549,14 @@ ExecutionEngineImpl::Merge(const std::string& location) {
ENGINE_LOG_DEBUG << "Finish merge index file: " << location;
}
return status;
} else if (auto bin_index = std::dynamic_pointer_cast<BinBFIndex>(to_merge)) {
auto status = index_->Add(bin_index->Count(), bin_index->GetRawVectors(), bin_index->GetRawIds());
if (!status.ok()) {
ENGINE_LOG_ERROR << "Failed to merge: " << location << " to: " << location_;
} else {
ENGINE_LOG_DEBUG << "Finish merge index file: " << location;
}
return status;
} else {
return Status(DB_ERROR, "file index type is not idmap");
}
@@ -493,7 +567,8 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t
ENGINE_LOG_DEBUG << "Build index file: " << location << " from: " << location_;

auto from_index = std::dynamic_pointer_cast<BFIndex>(index_);
if (from_index == nullptr) {
auto bin_from_index = std::dynamic_pointer_cast<BinBFIndex>(index_);
if (from_index == nullptr && bin_from_index == nullptr) {
ENGINE_LOG_ERROR << "ExecutionEngineImpl: from_index is null, failed to build index";
return nullptr;
}

@@ -507,13 +582,20 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t
temp_conf.gpu_id = gpu_num_;
temp_conf.dim = Dimension();
temp_conf.nlist = nlist_;
temp_conf.metric_type = (metric_type_ == MetricType::IP) ? knowhere::METRICTYPE::IP : knowhere::METRICTYPE::L2;
temp_conf.size = Count();
auto status = MappingMetricType(metric_type_, temp_conf.metric_type);
if (!status.ok()) {
throw Exception(DB_ERROR, status.message());
}

auto adapter = AdapterMgr::GetInstance().GetAdapter(to_index->GetType());
auto conf = adapter->Match(temp_conf);

auto status = to_index->BuildAll(Count(), from_index->GetRawVectors(), from_index->GetRawIds(), conf);
if (from_index) {
status = to_index->BuildAll(Count(), from_index->GetRawVectors(), from_index->GetRawIds(), conf);
} else if (bin_from_index) {
status = to_index->BuildAll(Count(), bin_from_index->GetRawVectors(), bin_from_index->GetRawIds(), conf);
}
if (!status.ok()) {
throw Exception(DB_ERROR, status.message());
}
@@ -611,6 +693,40 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
return status;
}

Status
ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, int64_t nprobe, float* distances,
int64_t* labels, bool hybrid) {
if (index_ == nullptr) {
ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, failed to search";
return Status(DB_ERROR, "index is null");
}

ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe;

// TODO(linxj): remove here. Get conf from function
TempMetaConf temp_conf;
temp_conf.k = k;
temp_conf.nprobe = nprobe;

auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());

if (hybrid) {
HybridLoad();
}

auto status = index_->Search(n, data, distances, labels, conf);

if (hybrid) {
HybridUnset();
}

if (!status.ok()) {
ENGINE_LOG_ERROR << "Search error:" << status.message();
}
return status;
}

Status
ExecutionEngineImpl::Cache() {
cache::DataObjPtr obj = std::static_pointer_cast<cache::DataObj>(index_);
@@ -37,6 +37,9 @@ class ExecutionEngineImpl : public ExecutionEngine {
Status
AddWithIds(int64_t n, const float* xdata, const int64_t* xids) override;

Status
AddWithIds(int64_t n, const uint8_t* xdata, const int64_t* xids) override;

size_t
Count() const override;

@@ -74,6 +77,10 @@ class ExecutionEngineImpl : public ExecutionEngine {
Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
bool hybrid = false) override;

Status
Search(int64_t n, const uint8_t* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
bool hybrid = false) override;

ExecutionEnginePtr
BuildIndex(const std::string& location, EngineType engine_type) override;
@@ -30,7 +30,7 @@ namespace engine {
class MemManager {
public:
virtual Status
InsertVectors(const std::string& table_id, size_t n, const float* vectors, IDNumbers& vector_ids) = 0;
InsertVectors(const std::string& table_id, VectorsData& vectors) = 0;

virtual Status
Serialize(std::set<std::string>& table_ids) = 0;
@@ -37,26 +37,25 @@ MemManagerImpl::GetMemByTable(const std::string& table_id) {
}

Status
MemManagerImpl::InsertVectors(const std::string& table_id_, size_t n_, const float* vectors_, IDNumbers& vector_ids_) {
MemManagerImpl::InsertVectors(const std::string& table_id, VectorsData& vectors) {
while (GetCurrentMem() > options_.insert_buffer_size_) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}

std::unique_lock<std::mutex> lock(mutex_);

return InsertVectorsNoLock(table_id_, n_, vectors_, vector_ids_);
return InsertVectorsNoLock(table_id, vectors);
}

Status
MemManagerImpl::InsertVectorsNoLock(const std::string& table_id, size_t n, const float* vectors,
IDNumbers& vector_ids) {
MemManagerImpl::InsertVectorsNoLock(const std::string& table_id, VectorsData& vectors) {
MemTablePtr mem = GetMemByTable(table_id);
VectorSourcePtr source = std::make_shared<VectorSource>(n, vectors);
VectorSourcePtr source = std::make_shared<VectorSource>(vectors);

auto status = mem->Add(source, vector_ids);
auto status = mem->Add(source);
if (status.ok()) {
if (vector_ids.empty()) {
vector_ids = source->GetVectorIds();
if (vectors.id_array_.empty()) {
vectors.id_array_ = source->GetVectorIds();
}
}
return status;
@@ -41,7 +41,7 @@ class MemManagerImpl : public MemManager {
}

Status
InsertVectors(const std::string& table_id, size_t n, const float* vectors, IDNumbers& vector_ids) override;
InsertVectors(const std::string& table_id, VectorsData& vectors) override;

Status
Serialize(std::set<std::string>& table_ids) override;

@@ -63,7 +63,7 @@ class MemManagerImpl : public MemManager {
GetMemByTable(const std::string& table_id);

Status
InsertVectorsNoLock(const std::string& table_id, size_t n, const float* vectors, IDNumbers& vector_ids);
InsertVectorsNoLock(const std::string& table_id, VectorsData& vectors);
Status
ToImmutable();
@@ -29,7 +29,7 @@ MemTable::MemTable(const std::string& table_id, const meta::MetaPtr& meta, const
}

Status
MemTable::Add(VectorSourcePtr& source, IDNumbers& vector_ids) {
MemTable::Add(VectorSourcePtr& source) {
while (!source->AllAdded()) {
MemTableFilePtr current_mem_table_file;
if (!mem_table_file_list_.empty()) {

@@ -39,12 +39,12 @@ MemTable::Add(VectorSourcePtr& source, IDNumbers& vector_ids) {
Status status;
if (mem_table_file_list_.empty() || current_mem_table_file->IsFull()) {
MemTableFilePtr new_mem_table_file = std::make_shared<MemTableFile>(table_id_, meta_, options_);
status = new_mem_table_file->Add(source, vector_ids);
status = new_mem_table_file->Add(source);
if (status.ok()) {
mem_table_file_list_.emplace_back(new_mem_table_file);
}
} else {
status = current_mem_table_file->Add(source, vector_ids);
status = current_mem_table_file->Add(source);
}

if (!status.ok()) {
@@ -36,7 +36,7 @@ class MemTable {
MemTable(const std::string& table_id, const meta::MetaPtr& meta, const DBOptions& options);

Status
Add(VectorSourcePtr& source, IDNumbers& vector_ids);
Add(VectorSourcePtr& source);

void
GetCurrentMemTableFile(MemTableFilePtr& mem_table_file);
@@ -53,7 +53,7 @@ MemTableFile::CreateTableFile() {
}

Status
MemTableFile::Add(const VectorSourcePtr& source, IDNumbers& vector_ids) {
MemTableFile::Add(VectorSourcePtr& source) {
if (table_file_schema_.dimension_ <= 0) {
std::string err_msg =
"MemTableFile::Add: table_file_schema dimension = " + std::to_string(table_file_schema_.dimension_) +

@@ -62,13 +62,12 @@ MemTableFile::Add(const VectorSourcePtr& source, IDNumbers& vector_ids) {
return Status(DB_ERROR, "Not able to create table file");
}

size_t single_vector_mem_size = table_file_schema_.dimension_ * VECTOR_TYPE_SIZE;
size_t single_vector_mem_size = source->SingleVectorSize(table_file_schema_.dimension_);
size_t mem_left = GetMemLeft();
if (mem_left >= single_vector_mem_size) {
size_t num_vectors_to_add = std::ceil(mem_left / single_vector_mem_size);
size_t num_vectors_added;
auto status =
source->Add(execution_engine_, table_file_schema_, num_vectors_to_add, num_vectors_added, vector_ids);
auto status = source->Add(execution_engine_, table_file_schema_, num_vectors_to_add, num_vectors_added);
if (status.ok()) {
current_mem_ += (num_vectors_added * single_vector_mem_size);
}

@@ -89,7 +88,7 @@ MemTableFile::GetMemLeft() {

bool
MemTableFile::IsFull() {
size_t single_vector_mem_size = table_file_schema_.dimension_ * VECTOR_TYPE_SIZE;
size_t single_vector_mem_size = table_file_schema_.dimension_ * FLOAT_TYPE_SIZE;
return (GetMemLeft() < single_vector_mem_size);
}

@@ -104,7 +103,8 @@ MemTableFile::Serialize() {

// if index type isn't IDMAP, set file type to TO_INDEX if file size execeed index_file_size
// else set file type to RAW, no need to build index
if (table_file_schema_.engine_type_ != (int)EngineType::FAISS_IDMAP) {
if (table_file_schema_.engine_type_ != (int)EngineType::FAISS_IDMAP &&
table_file_schema_.engine_type_ != (int)EngineType::FAISS_BIN_IDMAP) {
table_file_schema_.file_type_ = (size >= table_file_schema_.index_file_size_) ? meta::TableFileSchema::TO_INDEX
: meta::TableFileSchema::RAW;
} else {
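To make the SingleVectorSize change above concrete (illustrative numbers, not from the diff): for a table of dimension 128, one float row costs 128 * 4 = 512 bytes while one packed binary row costs 128 / 8 = 16 bytes, so the num_vectors_to_add derived from the remaining buffer differs by a factor of 32 between the two layouts.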
@@ -33,7 +33,7 @@ class MemTableFile {
MemTableFile(const std::string& table_id, const meta::MetaPtr& meta, const DBOptions& options);

Status
Add(const VectorSourcePtr& source, IDNumbers& vector_ids);
Add(VectorSourcePtr& source);

size_t
GetCurrentMem();
@@ -24,30 +24,41 @@
namespace milvus {
namespace engine {

VectorSource::VectorSource(const size_t& n, const float* vectors)
: n_(n), vectors_(vectors), id_generator_(std::make_shared<SimpleIDGenerator>()) {
VectorSource::VectorSource(VectorsData& vectors)
: vectors_(vectors), id_generator_(std::make_shared<SimpleIDGenerator>()) {
current_num_vectors_added = 0;
}

Status
VectorSource::Add(const ExecutionEnginePtr& execution_engine, const meta::TableFileSchema& table_file_schema,
const size_t& num_vectors_to_add, size_t& num_vectors_added, IDNumbers& vector_ids) {
server::CollectAddMetrics metrics(n_, table_file_schema.dimension_);
const size_t& num_vectors_to_add, size_t& num_vectors_added) {
uint64_t n = vectors_.vector_count_;
server::CollectAddMetrics metrics(n, table_file_schema.dimension_);

num_vectors_added =
current_num_vectors_added + num_vectors_to_add <= n_ ? num_vectors_to_add : n_ - current_num_vectors_added;
current_num_vectors_added + num_vectors_to_add <= n ? num_vectors_to_add : n - current_num_vectors_added;
IDNumbers vector_ids_to_add;
if (vector_ids.empty()) {
if (vectors_.id_array_.empty()) {
id_generator_->GetNextIDNumbers(num_vectors_added, vector_ids_to_add);
} else {
vector_ids_to_add.resize(num_vectors_added);
for (int pos = current_num_vectors_added; pos < current_num_vectors_added + num_vectors_added; pos++) {
vector_ids_to_add[pos - current_num_vectors_added] = vector_ids[pos];
vector_ids_to_add[pos - current_num_vectors_added] = vectors_.id_array_[pos];
}
}
Status status = execution_engine->AddWithIds(num_vectors_added,
vectors_ + current_num_vectors_added * table_file_schema.dimension_,
vector_ids_to_add.data());

Status status;
if (!vectors_.float_data_.empty()) {
status = execution_engine->AddWithIds(
num_vectors_added, vectors_.float_data_.data() + current_num_vectors_added * table_file_schema.dimension_,
vector_ids_to_add.data());
} else if (!vectors_.binary_data_.empty()) {
status = execution_engine->AddWithIds(
num_vectors_added,
vectors_.binary_data_.data() + current_num_vectors_added * SingleVectorSize(table_file_schema.dimension_),
vector_ids_to_add.data());
}

if (status.ok()) {
current_num_vectors_added += num_vectors_added;
vector_ids_.insert(vector_ids_.end(), std::make_move_iterator(vector_ids_to_add.begin()),

@@ -64,9 +75,20 @@ VectorSource::GetNumVectorsAdded() {
return current_num_vectors_added;
}

size_t
VectorSource::SingleVectorSize(uint16_t dimension) {
if (!vectors_.float_data_.empty()) {
return dimension * FLOAT_TYPE_SIZE;
} else if (!vectors_.binary_data_.empty()) {
return dimension / 8;
}

return 0;
}

bool
VectorSource::AllAdded() {
return (current_num_vectors_added == n_);
return (current_num_vectors_added == vectors_.vector_count_);
}

IDNumbers
@@ -29,15 +29,18 @@ namespace engine {

class VectorSource {
public:
VectorSource(const size_t& n, const float* vectors);
explicit VectorSource(VectorsData& vectors);

Status
Add(const ExecutionEnginePtr& execution_engine, const meta::TableFileSchema& table_file_schema,
const size_t& num_vectors_to_add, size_t& num_vectors_added, IDNumbers& vector_ids);
const size_t& num_vectors_to_add, size_t& num_vectors_added);

size_t
GetNumVectorsAdded();

size_t
SingleVectorSize(uint16_t dimension);

bool
AllAdded();

@@ -45,8 +48,7 @@ class VectorSource {
GetVectorIds();

private:
const size_t n_;
const float* vectors_;
VectorsData& vectors_;
IDNumbers vector_ids_;

size_t current_num_vectors_added;
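A short sketch of the per-row offset arithmetic that VectorSource::Add relies on, for reference only; the helper name is an assumption, not part of this commit.

// Byte offset of row `i` inside a VectorsData payload, mirroring the pointer math in VectorSource::Add.
inline size_t RowByteOffset(const milvus::engine::VectorsData& v, size_t i, uint16_t dim) {
    if (!v.float_data_.empty()) {
        return i * dim * sizeof(float);  // float rows: dim elements of 4 bytes each
    }
    return i * (dim / 8);                // binary rows: one bit per dimension, packed into dim/8 bytes
}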
@ -462,7 +462,8 @@ const ::PROTOBUF_NAMESPACE_ID::uint32 TableStruct_milvus_2eproto::offsets[] PROT
|
|||
~0u, // no _extensions_
|
||||
~0u, // no _oneof_case_
|
||||
~0u, // no _weak_field_map_
|
||||
PROTOBUF_FIELD_OFFSET(::milvus::grpc::RowRecord, vector_data_),
|
||||
PROTOBUF_FIELD_OFFSET(::milvus::grpc::RowRecord, float_data_),
|
||||
PROTOBUF_FIELD_OFFSET(::milvus::grpc::RowRecord, binary_data_),
|
||||
~0u, // no _has_bits_
|
||||
PROTOBUF_FIELD_OFFSET(::milvus::grpc::InsertParam, _internal_metadata_),
|
||||
~0u, // no _extensions_
|
||||
|
@ -565,18 +566,18 @@ static const ::PROTOBUF_NAMESPACE_ID::internal::MigrationSchema schemas[] PROTOB
|
|||
{ 37, -1, sizeof(::milvus::grpc::PartitionList)},
|
||||
{ 44, -1, sizeof(::milvus::grpc::Range)},
|
||||
{ 51, -1, sizeof(::milvus::grpc::RowRecord)},
|
||||
{ 57, -1, sizeof(::milvus::grpc::InsertParam)},
|
||||
{ 66, -1, sizeof(::milvus::grpc::VectorIds)},
|
||||
{ 73, -1, sizeof(::milvus::grpc::SearchParam)},
|
||||
{ 84, -1, sizeof(::milvus::grpc::SearchInFilesParam)},
|
||||
{ 91, -1, sizeof(::milvus::grpc::TopKQueryResult)},
|
||||
{ 100, -1, sizeof(::milvus::grpc::StringReply)},
|
||||
{ 107, -1, sizeof(::milvus::grpc::BoolReply)},
|
||||
{ 114, -1, sizeof(::milvus::grpc::TableRowCount)},
|
||||
{ 121, -1, sizeof(::milvus::grpc::Command)},
|
||||
{ 127, -1, sizeof(::milvus::grpc::Index)},
|
||||
{ 134, -1, sizeof(::milvus::grpc::IndexParam)},
|
||||
{ 142, -1, sizeof(::milvus::grpc::DeleteByDateParam)},
|
||||
{ 58, -1, sizeof(::milvus::grpc::InsertParam)},
|
||||
{ 67, -1, sizeof(::milvus::grpc::VectorIds)},
|
||||
{ 74, -1, sizeof(::milvus::grpc::SearchParam)},
|
||||
{ 85, -1, sizeof(::milvus::grpc::SearchInFilesParam)},
|
||||
{ 92, -1, sizeof(::milvus::grpc::TopKQueryResult)},
|
||||
{ 101, -1, sizeof(::milvus::grpc::StringReply)},
|
||||
{ 108, -1, sizeof(::milvus::grpc::BoolReply)},
|
||||
{ 115, -1, sizeof(::milvus::grpc::TableRowCount)},
|
||||
{ 122, -1, sizeof(::milvus::grpc::Command)},
|
||||
{ 128, -1, sizeof(::milvus::grpc::Index)},
|
||||
{ 135, -1, sizeof(::milvus::grpc::IndexParam)},
|
||||
{ 143, -1, sizeof(::milvus::grpc::DeleteByDateParam)},
|
||||
};
|
||||
|
||||
static ::PROTOBUF_NAMESPACE_ID::Message const * const file_default_instances[] = {
|
||||
|
@ -617,65 +618,65 @@ const char descriptor_table_protodef_milvus_2eproto[] PROTOBUF_SECTION_VARIABLE(
|
|||
"ilvus.grpc.Status\0224\n\017partition_array\030\002 \003"
|
||||
"(\0132\033.milvus.grpc.PartitionParam\"/\n\005Range"
|
||||
"\022\023\n\013start_value\030\001 \001(\t\022\021\n\tend_value\030\002 \001(\t"
|
||||
"\" \n\tRowRecord\022\023\n\013vector_data\030\001 \003(\002\"\200\001\n\013I"
|
||||
"nsertParam\022\022\n\ntable_name\030\001 \001(\t\0220\n\020row_re"
|
||||
"cord_array\030\002 \003(\0132\026.milvus.grpc.RowRecord"
|
||||
"\022\024\n\014row_id_array\030\003 \003(\003\022\025\n\rpartition_tag\030"
|
||||
"\004 \001(\t\"I\n\tVectorIds\022#\n\006status\030\001 \001(\0132\023.mil"
|
||||
"vus.grpc.Status\022\027\n\017vector_id_array\030\002 \003(\003"
|
||||
"\"\277\001\n\013SearchParam\022\022\n\ntable_name\030\001 \001(\t\0222\n\022"
|
||||
"query_record_array\030\002 \003(\0132\026.milvus.grpc.R"
|
||||
"owRecord\022-\n\021query_range_array\030\003 \003(\0132\022.mi"
|
||||
"lvus.grpc.Range\022\014\n\004topk\030\004 \001(\003\022\016\n\006nprobe\030"
|
||||
"\005 \001(\003\022\033\n\023partition_tag_array\030\006 \003(\t\"[\n\022Se"
|
||||
"archInFilesParam\022\025\n\rfile_id_array\030\001 \003(\t\022"
|
||||
".\n\014search_param\030\002 \001(\0132\030.milvus.grpc.Sear"
|
||||
"chParam\"g\n\017TopKQueryResult\022#\n\006status\030\001 \001"
|
||||
"(\0132\023.milvus.grpc.Status\022\017\n\007row_num\030\002 \001(\003"
|
||||
"\022\013\n\003ids\030\003 \003(\003\022\021\n\tdistances\030\004 \003(\002\"H\n\013Stri"
|
||||
"ngReply\022#\n\006status\030\001 \001(\0132\023.milvus.grpc.St"
|
||||
"atus\022\024\n\014string_reply\030\002 \001(\t\"D\n\tBoolReply\022"
|
||||
"#\n\006status\030\001 \001(\0132\023.milvus.grpc.Status\022\022\n\n"
|
||||
"bool_reply\030\002 \001(\010\"M\n\rTableRowCount\022#\n\006sta"
|
||||
"tus\030\001 \001(\0132\023.milvus.grpc.Status\022\027\n\017table_"
|
||||
"row_count\030\002 \001(\003\"\026\n\007Command\022\013\n\003cmd\030\001 \001(\t\""
|
||||
"*\n\005Index\022\022\n\nindex_type\030\001 \001(\005\022\r\n\005nlist\030\002 "
|
||||
"\001(\005\"h\n\nIndexParam\022#\n\006status\030\001 \001(\0132\023.milv"
|
||||
"us.grpc.Status\022\022\n\ntable_name\030\002 \001(\t\022!\n\005in"
|
||||
"dex\030\003 \001(\0132\022.milvus.grpc.Index\"J\n\021DeleteB"
|
||||
"yDateParam\022!\n\005range\030\001 \001(\0132\022.milvus.grpc."
|
||||
"Range\022\022\n\ntable_name\030\002 \001(\t2\272\t\n\rMilvusServ"
|
||||
"ice\022>\n\013CreateTable\022\030.milvus.grpc.TableSc"
|
||||
"hema\032\023.milvus.grpc.Status\"\000\022<\n\010HasTable\022"
|
||||
"\026.milvus.grpc.TableName\032\026.milvus.grpc.Bo"
|
||||
"olReply\"\000\022C\n\rDescribeTable\022\026.milvus.grpc"
|
||||
".TableName\032\030.milvus.grpc.TableSchema\"\000\022B"
|
||||
"\n\nCountTable\022\026.milvus.grpc.TableName\032\032.m"
|
||||
"ilvus.grpc.TableRowCount\"\000\022@\n\nShowTables"
|
||||
"\022\024.milvus.grpc.Command\032\032.milvus.grpc.Tab"
|
||||
"leNameList\"\000\022:\n\tDropTable\022\026.milvus.grpc."
|
||||
"TableName\032\023.milvus.grpc.Status\"\000\022=\n\013Crea"
|
||||
"teIndex\022\027.milvus.grpc.IndexParam\032\023.milvu"
|
||||
"s.grpc.Status\"\000\022B\n\rDescribeIndex\022\026.milvu"
|
||||
"s.grpc.TableName\032\027.milvus.grpc.IndexPara"
|
||||
"m\"\000\022:\n\tDropIndex\022\026.milvus.grpc.TableName"
|
||||
"\032\023.milvus.grpc.Status\"\000\022E\n\017CreatePartiti"
|
||||
"on\022\033.milvus.grpc.PartitionParam\032\023.milvus"
|
||||
".grpc.Status\"\000\022F\n\016ShowPartitions\022\026.milvu"
|
||||
"s.grpc.TableName\032\032.milvus.grpc.Partition"
|
||||
"List\"\000\022C\n\rDropPartition\022\033.milvus.grpc.Pa"
|
||||
"rtitionParam\032\023.milvus.grpc.Status\"\000\022<\n\006I"
|
||||
"nsert\022\030.milvus.grpc.InsertParam\032\026.milvus"
|
||||
".grpc.VectorIds\"\000\022B\n\006Search\022\030.milvus.grp"
|
||||
"c.SearchParam\032\034.milvus.grpc.TopKQueryRes"
|
||||
"ult\"\000\022P\n\rSearchInFiles\022\037.milvus.grpc.Sea"
|
||||
"rchInFilesParam\032\034.milvus.grpc.TopKQueryR"
|
||||
"esult\"\000\0227\n\003Cmd\022\024.milvus.grpc.Command\032\030.m"
|
||||
"ilvus.grpc.StringReply\"\000\022E\n\014DeleteByDate"
|
||||
"\022\036.milvus.grpc.DeleteByDateParam\032\023.milvu"
|
||||
"s.grpc.Status\"\000\022=\n\014PreloadTable\022\026.milvus"
|
||||
".grpc.TableName\032\023.milvus.grpc.Status\"\000b\006"
|
||||
"proto3"
|
||||
"\"4\n\tRowRecord\022\022\n\nfloat_data\030\001 \003(\002\022\023\n\013bin"
|
||||
"ary_data\030\002 \001(\014\"\200\001\n\013InsertParam\022\022\n\ntable_"
|
||||
"name\030\001 \001(\t\0220\n\020row_record_array\030\002 \003(\0132\026.m"
|
||||
"ilvus.grpc.RowRecord\022\024\n\014row_id_array\030\003 \003"
|
||||
"(\003\022\025\n\rpartition_tag\030\004 \001(\t\"I\n\tVectorIds\022#"
|
||||
"\n\006status\030\001 \001(\0132\023.milvus.grpc.Status\022\027\n\017v"
|
||||
"ector_id_array\030\002 \003(\003\"\277\001\n\013SearchParam\022\022\n\n"
|
||||
"table_name\030\001 \001(\t\0222\n\022query_record_array\030\002"
|
||||
" \003(\0132\026.milvus.grpc.RowRecord\022-\n\021query_ra"
|
||||
"nge_array\030\003 \003(\0132\022.milvus.grpc.Range\022\014\n\004t"
|
||||
"opk\030\004 \001(\003\022\016\n\006nprobe\030\005 \001(\003\022\033\n\023partition_t"
|
||||
"ag_array\030\006 \003(\t\"[\n\022SearchInFilesParam\022\025\n\r"
|
||||
"file_id_array\030\001 \003(\t\022.\n\014search_param\030\002 \001("
|
||||
"\0132\030.milvus.grpc.SearchParam\"g\n\017TopKQuery"
|
||||
"Result\022#\n\006status\030\001 \001(\0132\023.milvus.grpc.Sta"
|
||||
"tus\022\017\n\007row_num\030\002 \001(\003\022\013\n\003ids\030\003 \003(\003\022\021\n\tdis"
|
||||
"tances\030\004 \003(\002\"H\n\013StringReply\022#\n\006status\030\001 "
|
||||
"\001(\0132\023.milvus.grpc.Status\022\024\n\014string_reply"
|
||||
"\030\002 \001(\t\"D\n\tBoolReply\022#\n\006status\030\001 \001(\0132\023.mi"
|
||||
"lvus.grpc.Status\022\022\n\nbool_reply\030\002 \001(\010\"M\n\r"
|
||||
"TableRowCount\022#\n\006status\030\001 \001(\0132\023.milvus.g"
|
||||
"rpc.Status\022\027\n\017table_row_count\030\002 \001(\003\"\026\n\007C"
|
||||
"ommand\022\013\n\003cmd\030\001 \001(\t\"*\n\005Index\022\022\n\nindex_ty"
|
||||
"pe\030\001 \001(\005\022\r\n\005nlist\030\002 \001(\005\"h\n\nIndexParam\022#\n"
|
||||
"\006status\030\001 \001(\0132\023.milvus.grpc.Status\022\022\n\nta"
|
||||
"ble_name\030\002 \001(\t\022!\n\005index\030\003 \001(\0132\022.milvus.g"
|
||||
"rpc.Index\"J\n\021DeleteByDateParam\022!\n\005range\030"
|
||||
"\001 \001(\0132\022.milvus.grpc.Range\022\022\n\ntable_name\030"
|
||||
"\002 \001(\t2\272\t\n\rMilvusService\022>\n\013CreateTable\022\030"
|
||||
".milvus.grpc.TableSchema\032\023.milvus.grpc.S"
|
||||
"tatus\"\000\022<\n\010HasTable\022\026.milvus.grpc.TableN"
|
||||
"ame\032\026.milvus.grpc.BoolReply\"\000\022C\n\rDescrib"
|
||||
"eTable\022\026.milvus.grpc.TableName\032\030.milvus."
|
||||
"grpc.TableSchema\"\000\022B\n\nCountTable\022\026.milvu"
|
||||
"s.grpc.TableName\032\032.milvus.grpc.TableRowC"
|
||||
"ount\"\000\022@\n\nShowTables\022\024.milvus.grpc.Comma"
|
||||
"nd\032\032.milvus.grpc.TableNameList\"\000\022:\n\tDrop"
|
||||
"Table\022\026.milvus.grpc.TableName\032\023.milvus.g"
|
||||
"rpc.Status\"\000\022=\n\013CreateIndex\022\027.milvus.grp"
|
||||
"c.IndexParam\032\023.milvus.grpc.Status\"\000\022B\n\rD"
|
||||
"escribeIndex\022\026.milvus.grpc.TableName\032\027.m"
|
||||
"ilvus.grpc.IndexParam\"\000\022:\n\tDropIndex\022\026.m"
|
||||
"ilvus.grpc.TableName\032\023.milvus.grpc.Statu"
|
||||
"s\"\000\022E\n\017CreatePartition\022\033.milvus.grpc.Par"
|
||||
"titionParam\032\023.milvus.grpc.Status\"\000\022F\n\016Sh"
|
||||
"owPartitions\022\026.milvus.grpc.TableName\032\032.m"
|
||||
"ilvus.grpc.PartitionList\"\000\022C\n\rDropPartit"
|
||||
"ion\022\033.milvus.grpc.PartitionParam\032\023.milvu"
|
||||
"s.grpc.Status\"\000\022<\n\006Insert\022\030.milvus.grpc."
|
||||
"InsertParam\032\026.milvus.grpc.VectorIds\"\000\022B\n"
|
||||
"\006Search\022\030.milvus.grpc.SearchParam\032\034.milv"
|
||||
"us.grpc.TopKQueryResult\"\000\022P\n\rSearchInFil"
|
||||
"es\022\037.milvus.grpc.SearchInFilesParam\032\034.mi"
|
||||
"lvus.grpc.TopKQueryResult\"\000\0227\n\003Cmd\022\024.mil"
|
||||
"vus.grpc.Command\032\030.milvus.grpc.StringRep"
|
||||
"ly\"\000\022E\n\014DeleteByDate\022\036.milvus.grpc.Delet"
|
||||
"eByDateParam\032\023.milvus.grpc.Status\"\000\022=\n\014P"
|
||||
"reloadTable\022\026.milvus.grpc.TableName\032\023.mi"
|
||||
"lvus.grpc.Status\"\000b\006proto3"
|
||||
;
|
||||
static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_milvus_2eproto_deps[1] = {
|
||||
&::descriptor_table_status_2eproto,
|
||||
|
@ -705,7 +706,7 @@ static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_mil
|
|||
static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_milvus_2eproto_once;
|
||||
static bool descriptor_table_milvus_2eproto_initialized = false;
|
||||
const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_milvus_2eproto = {
|
||||
&descriptor_table_milvus_2eproto_initialized, descriptor_table_protodef_milvus_2eproto, "milvus.proto", 2886,
|
||||
&descriptor_table_milvus_2eproto_initialized, descriptor_table_protodef_milvus_2eproto, "milvus.proto", 2906,
|
||||
&descriptor_table_milvus_2eproto_once, descriptor_table_milvus_2eproto_sccs, descriptor_table_milvus_2eproto_deps, 20, 1,
|
||||
schemas, file_default_instances, TableStruct_milvus_2eproto::offsets,
|
||||
file_level_metadata_milvus_2eproto, 20, file_level_enum_descriptors_milvus_2eproto, file_level_service_descriptors_milvus_2eproto,
|
||||
|
@ -3122,12 +3123,18 @@ RowRecord::RowRecord()
|
|||
RowRecord::RowRecord(const RowRecord& from)
|
||||
: ::PROTOBUF_NAMESPACE_ID::Message(),
|
||||
_internal_metadata_(nullptr),
|
||||
vector_data_(from.vector_data_) {
|
||||
float_data_(from.float_data_) {
|
||||
_internal_metadata_.MergeFrom(from._internal_metadata_);
|
||||
binary_data_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
|
||||
if (!from.binary_data().empty()) {
|
||||
binary_data_.AssignWithDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from.binary_data_);
|
||||
}
|
||||
// @@protoc_insertion_point(copy_constructor:milvus.grpc.RowRecord)
|
||||
}
|
||||
|
||||
void RowRecord::SharedCtor() {
|
||||
::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_RowRecord_milvus_2eproto.base);
|
||||
binary_data_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
|
||||
}
|
||||
|
||||
RowRecord::~RowRecord() {
|
||||
|
@ -3136,6 +3143,7 @@ RowRecord::~RowRecord() {
|
|||
}
|
||||
|
||||
void RowRecord::SharedDtor() {
|
||||
binary_data_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
|
||||
}
|
||||
|
||||
void RowRecord::SetCachedSize(int size) const {
|
||||
|
@ -3153,7 +3161,8 @@ void RowRecord::Clear() {
|
|||
// Prevent compiler warnings about cached_has_bits being unused
|
||||
(void) cached_has_bits;
|
||||
|
||||
vector_data_.Clear();
|
||||
float_data_.Clear();
|
||||
binary_data_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
|
||||
_internal_metadata_.Clear();
|
||||
}
|
||||
|
||||
|
@ -3165,16 +3174,23 @@ const char* RowRecord::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::
|
|||
ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag);
|
||||
CHK_(ptr);
|
||||
switch (tag >> 3) {
|
||||
// repeated float vector_data = 1;
|
||||
// repeated float float_data = 1;
|
||||
case 1:
|
||||
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) {
|
||||
ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(mutable_vector_data(), ptr, ctx);
|
||||
ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(mutable_float_data(), ptr, ctx);
|
||||
CHK_(ptr);
|
||||
} else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 13) {
|
||||
add_vector_data(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad<float>(ptr));
|
||||
add_float_data(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad<float>(ptr));
|
||||
ptr += sizeof(float);
|
||||
} else goto handle_unusual;
|
||||
continue;
|
||||
// bytes binary_data = 2;
|
||||
case 2:
|
||||
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) {
|
||||
ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(mutable_binary_data(), ptr, ctx);
|
||||
CHK_(ptr);
|
||||
} else goto handle_unusual;
|
||||
continue;
|
||||
default: {
|
||||
handle_unusual:
|
||||
if ((tag & 7) == 4 || tag == 0) {
|
||||
|
@ -3205,16 +3221,27 @@ bool RowRecord::MergePartialFromCodedStream(
|
|||
tag = p.first;
|
||||
if (!p.second) goto handle_unusual;
|
||||
switch (::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::GetTagFieldNumber(tag)) {
|
||||
// repeated float vector_data = 1;
|
||||
// repeated float float_data = 1;
|
||||
case 1: {
|
||||
if (static_cast< ::PROTOBUF_NAMESPACE_ID::uint8>(tag) == (10 & 0xFF)) {
|
||||
DO_((::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::ReadPackedPrimitive<
|
||||
float, ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_FLOAT>(
|
||||
input, this->mutable_vector_data())));
|
||||
input, this->mutable_float_data())));
|
||||
} else if (static_cast< ::PROTOBUF_NAMESPACE_ID::uint8>(tag) == (13 & 0xFF)) {
|
||||
DO_((::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::ReadRepeatedPrimitiveNoInline<
|
||||
float, ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_FLOAT>(
|
||||
1, 10u, input, this->mutable_vector_data())));
|
||||
1, 10u, input, this->mutable_float_data())));
|
||||
} else {
|
||||
goto handle_unusual;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// bytes binary_data = 2;
|
||||
case 2: {
|
||||
if (static_cast< ::PROTOBUF_NAMESPACE_ID::uint8>(tag) == (18 & 0xFF)) {
|
||||
DO_(::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::ReadBytes(
|
||||
input, this->mutable_binary_data()));
|
||||
} else {
|
||||
goto handle_unusual;
|
||||
}
|
||||
|
@ -3248,13 +3275,19 @@ void RowRecord::SerializeWithCachedSizes(
|
|||
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
|
||||
(void) cached_has_bits;
|
||||
|
||||
// repeated float vector_data = 1;
|
||||
if (this->vector_data_size() > 0) {
|
||||
// repeated float float_data = 1;
|
||||
if (this->float_data_size() > 0) {
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteTag(1, ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED, output);
|
||||
output->WriteVarint32(_vector_data_cached_byte_size_.load(
|
||||
output->WriteVarint32(_float_data_cached_byte_size_.load(
|
||||
std::memory_order_relaxed));
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatArray(
|
||||
this->vector_data().data(), this->vector_data_size(), output);
|
||||
this->float_data().data(), this->float_data_size(), output);
|
||||
}
|
||||
|
||||
// bytes binary_data = 2;
|
||||
if (this->binary_data().size() > 0) {
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBytesMaybeAliased(
|
||||
2, this->binary_data(), output);
|
||||
}
|
||||
|
||||
if (_internal_metadata_.have_unknown_fields()) {
|
||||
|
@ -3270,17 +3303,24 @@ void RowRecord::SerializeWithCachedSizes(
|
|||
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
|
||||
(void) cached_has_bits;
|
||||
|
||||
// repeated float vector_data = 1;
|
||||
if (this->vector_data_size() > 0) {
|
||||
// repeated float float_data = 1;
|
||||
if (this->float_data_size() > 0) {
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteTagToArray(
|
||||
1,
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
|
||||
target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::io::CodedOutputStream::WriteVarint32ToArray(
|
||||
_vector_data_cached_byte_size_.load(std::memory_order_relaxed),
|
||||
_float_data_cached_byte_size_.load(std::memory_order_relaxed),
|
||||
target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::
|
||||
WriteFloatNoTagToArray(this->vector_data_, target);
|
||||
WriteFloatNoTagToArray(this->float_data_, target);
|
||||
}
|
||||
|
||||
// bytes binary_data = 2;
|
||||
if (this->binary_data().size() > 0) {
|
||||
target =
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBytesToArray(
|
||||
2, this->binary_data(), target);
|
||||
}
|
||||
|
||||
if (_internal_metadata_.have_unknown_fields()) {
|
||||
|
@ -3304,9 +3344,9 @@ size_t RowRecord::ByteSizeLong() const {
|
|||
// Prevent compiler warnings about cached_has_bits being unused
|
||||
(void) cached_has_bits;
|
||||
|
||||
// repeated float vector_data = 1;
|
||||
// repeated float float_data = 1;
|
||||
{
|
||||
unsigned int count = static_cast<unsigned int>(this->vector_data_size());
|
||||
unsigned int count = static_cast<unsigned int>(this->float_data_size());
|
||||
size_t data_size = 4UL * count;
|
||||
if (data_size > 0) {
|
||||
total_size += 1 +
|
||||
|
@ -3314,11 +3354,18 @@ size_t RowRecord::ByteSizeLong() const {
|
|||
static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size));
|
||||
}
|
||||
int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size);
|
||||
_vector_data_cached_byte_size_.store(cached_size,
|
||||
_float_data_cached_byte_size_.store(cached_size,
|
||||
std::memory_order_relaxed);
|
||||
total_size += data_size;
|
||||
}
|
||||
|
||||
// bytes binary_data = 2;
|
||||
if (this->binary_data().size() > 0) {
|
||||
total_size += 1 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize(
|
||||
this->binary_data());
|
||||
}
|
||||
|
||||
int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size);
|
||||
SetCachedSize(cached_size);
|
||||
return total_size;
|
||||
|
@ -3346,7 +3393,11 @@ void RowRecord::MergeFrom(const RowRecord& from) {
|
|||
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
|
||||
(void) cached_has_bits;
|
||||
|
||||
vector_data_.MergeFrom(from.vector_data_);
|
||||
float_data_.MergeFrom(from.float_data_);
|
||||
if (from.binary_data().size() > 0) {
|
||||
|
||||
binary_data_.AssignWithDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from.binary_data_);
|
||||
}
|
||||
}
|
||||
|
||||
void RowRecord::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) {
|
||||
|
@ -3370,7 +3421,9 @@ bool RowRecord::IsInitialized() const {
|
|||
void RowRecord::InternalSwap(RowRecord* other) {
|
||||
using std::swap;
|
||||
_internal_metadata_.Swap(&other->_internal_metadata_);
|
||||
vector_data_.InternalSwap(&other->vector_data_);
|
||||
float_data_.InternalSwap(&other->float_data_);
|
||||
binary_data_.Swap(&other->binary_data_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(),
|
||||
GetArenaNoVirtual());
|
||||
}
|
||||
|
||||
::PROTOBUF_NAMESPACE_ID::Metadata RowRecord::GetMetadata() const {
|
||||
|
|
|
@ -1314,26 +1314,39 @@ class RowRecord :
|
|||
// accessors -------------------------------------------------------
|
||||
|
||||
enum : int {
|
||||
kVectorDataFieldNumber = 1,
|
||||
kFloatDataFieldNumber = 1,
|
||||
kBinaryDataFieldNumber = 2,
|
||||
};
|
||||
// repeated float vector_data = 1;
|
||||
int vector_data_size() const;
|
||||
void clear_vector_data();
|
||||
float vector_data(int index) const;
|
||||
void set_vector_data(int index, float value);
|
||||
void add_vector_data(float value);
|
||||
// repeated float float_data = 1;
|
||||
int float_data_size() const;
|
||||
void clear_float_data();
|
||||
float float_data(int index) const;
|
||||
void set_float_data(int index, float value);
|
||||
void add_float_data(float value);
|
||||
const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >&
|
||||
vector_data() const;
|
||||
float_data() const;
|
||||
::PROTOBUF_NAMESPACE_ID::RepeatedField< float >*
|
||||
mutable_vector_data();
|
||||
mutable_float_data();
|
||||
|
||||
// bytes binary_data = 2;
|
||||
void clear_binary_data();
|
||||
const std::string& binary_data() const;
|
||||
void set_binary_data(const std::string& value);
|
||||
void set_binary_data(std::string&& value);
|
||||
void set_binary_data(const char* value);
|
||||
void set_binary_data(const void* value, size_t size);
|
||||
std::string* mutable_binary_data();
|
||||
std::string* release_binary_data();
|
||||
void set_allocated_binary_data(std::string* binary_data);
|
||||
|
||||
// @@protoc_insertion_point(class_scope:milvus.grpc.RowRecord)
|
||||
private:
|
||||
class _Internal;
|
||||
|
||||
::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_;
|
||||
::PROTOBUF_NAMESPACE_ID::RepeatedField< float > vector_data_;
|
||||
mutable std::atomic<int> _vector_data_cached_byte_size_;
|
||||
::PROTOBUF_NAMESPACE_ID::RepeatedField< float > float_data_;
|
||||
mutable std::atomic<int> _float_data_cached_byte_size_;
|
||||
::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr binary_data_;
|
||||
mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_;
|
||||
friend struct ::TableStruct_milvus_2eproto;
|
||||
};
|
||||
|
@ -3907,34 +3920,85 @@ inline void Range::set_allocated_end_value(std::string* end_value) {
|
|||
|
||||
// RowRecord
|
||||
|
||||
// repeated float vector_data = 1;
|
||||
inline int RowRecord::vector_data_size() const {
|
||||
return vector_data_.size();
|
||||
// repeated float float_data = 1;
|
||||
inline int RowRecord::float_data_size() const {
|
||||
return float_data_.size();
|
||||
}
|
||||
inline void RowRecord::clear_vector_data() {
|
||||
vector_data_.Clear();
|
||||
inline void RowRecord::clear_float_data() {
|
||||
float_data_.Clear();
|
||||
}
|
||||
inline float RowRecord::vector_data(int index) const {
|
||||
// @@protoc_insertion_point(field_get:milvus.grpc.RowRecord.vector_data)
|
||||
return vector_data_.Get(index);
|
||||
inline float RowRecord::float_data(int index) const {
|
||||
// @@protoc_insertion_point(field_get:milvus.grpc.RowRecord.float_data)
|
||||
return float_data_.Get(index);
|
||||
}
|
||||
inline void RowRecord::set_vector_data(int index, float value) {
|
||||
vector_data_.Set(index, value);
|
||||
// @@protoc_insertion_point(field_set:milvus.grpc.RowRecord.vector_data)
|
||||
inline void RowRecord::set_float_data(int index, float value) {
|
||||
float_data_.Set(index, value);
|
||||
// @@protoc_insertion_point(field_set:milvus.grpc.RowRecord.float_data)
|
||||
}
|
||||
inline void RowRecord::add_vector_data(float value) {
|
||||
vector_data_.Add(value);
|
||||
// @@protoc_insertion_point(field_add:milvus.grpc.RowRecord.vector_data)
|
||||
inline void RowRecord::add_float_data(float value) {
|
||||
float_data_.Add(value);
|
||||
// @@protoc_insertion_point(field_add:milvus.grpc.RowRecord.float_data)
|
||||
}
|
||||
inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >&
|
||||
RowRecord::vector_data() const {
|
||||
// @@protoc_insertion_point(field_list:milvus.grpc.RowRecord.vector_data)
|
||||
return vector_data_;
|
||||
RowRecord::float_data() const {
|
||||
// @@protoc_insertion_point(field_list:milvus.grpc.RowRecord.float_data)
|
||||
return float_data_;
|
||||
}
|
||||
inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >*
|
||||
RowRecord::mutable_vector_data() {
|
||||
// @@protoc_insertion_point(field_mutable_list:milvus.grpc.RowRecord.vector_data)
|
||||
return &vector_data_;
|
||||
RowRecord::mutable_float_data() {
|
||||
// @@protoc_insertion_point(field_mutable_list:milvus.grpc.RowRecord.float_data)
|
||||
return &float_data_;
|
||||
}
|
||||
|
||||
// bytes binary_data = 2;
|
||||
inline void RowRecord::clear_binary_data() {
|
||||
binary_data_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
|
||||
}
|
||||
inline const std::string& RowRecord::binary_data() const {
|
||||
// @@protoc_insertion_point(field_get:milvus.grpc.RowRecord.binary_data)
|
||||
return binary_data_.GetNoArena();
|
||||
}
|
||||
inline void RowRecord::set_binary_data(const std::string& value) {
|
||||
|
||||
binary_data_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value);
|
||||
// @@protoc_insertion_point(field_set:milvus.grpc.RowRecord.binary_data)
|
||||
}
|
||||
inline void RowRecord::set_binary_data(std::string&& value) {
|
||||
|
||||
binary_data_.SetNoArena(
|
||||
&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value));
|
||||
// @@protoc_insertion_point(field_set_rvalue:milvus.grpc.RowRecord.binary_data)
|
||||
}
|
||||
inline void RowRecord::set_binary_data(const char* value) {
|
||||
GOOGLE_DCHECK(value != nullptr);
|
||||
|
||||
binary_data_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value));
|
||||
// @@protoc_insertion_point(field_set_char:milvus.grpc.RowRecord.binary_data)
|
||||
}
|
||||
inline void RowRecord::set_binary_data(const void* value, size_t size) {
|
||||
|
||||
binary_data_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(),
|
||||
::std::string(reinterpret_cast<const char*>(value), size));
|
||||
// @@protoc_insertion_point(field_set_pointer:milvus.grpc.RowRecord.binary_data)
|
||||
}
|
||||
inline std::string* RowRecord::mutable_binary_data() {
|
||||
|
||||
// @@protoc_insertion_point(field_mutable:milvus.grpc.RowRecord.binary_data)
|
||||
return binary_data_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
|
||||
}
|
||||
inline std::string* RowRecord::release_binary_data() {
|
||||
// @@protoc_insertion_point(field_release:milvus.grpc.RowRecord.binary_data)
|
||||
|
||||
return binary_data_.ReleaseNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
|
||||
}
|
||||
inline void RowRecord::set_allocated_binary_data(std::string* binary_data) {
|
||||
if (binary_data != nullptr) {
|
||||
|
||||
} else {
|
||||
|
||||
}
|
||||
binary_data_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), binary_data);
|
||||
// @@protoc_insertion_point(field_set_allocated:milvus.grpc.RowRecord.binary_data)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
|
|
@@ -66,7 +66,8 @@ message Range {
 * @brief Record inserted
 */
message RowRecord {
    repeated float vector_data = 1;             //binary vector data
    repeated float float_data = 1;              //float vector data
    bytes binary_data = 2;                      //binary vector data
}

/**
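As an aside, a minimal client-side sketch of how the two RowRecord fields might be filled from C++ using the generated accessors that appear earlier in this diff (add_float_data / set_binary_data). The header name and helper functions are illustrative assumptions, not part of this change.

#include <cstdint>
#include <vector>

#include "milvus.pb.h"  // assumed name of the header generated from milvus.proto

// Float vectors go into float_data; packed binary vectors go into binary_data.
milvus::grpc::RowRecord MakeFloatRecord(const std::vector<float>& vec) {
    milvus::grpc::RowRecord record;
    for (float v : vec) {
        record.add_float_data(v);                      // repeated float float_data = 1
    }
    return record;
}

milvus::grpc::RowRecord MakeBinaryRecord(const std::vector<uint8_t>& code) {
    milvus::grpc::RowRecord record;
    record.set_binary_data(code.data(), code.size());  // bytes binary_data = 2
    return record;
}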
@@ -32,6 +32,9 @@ set(index_srcs
    knowhere/index/vector_index/IndexSPTAG.cpp
    knowhere/index/vector_index/IndexIDMAP.cpp
    knowhere/index/vector_index/IndexIVF.cpp
    knowhere/index/vector_index/IndexBinaryIVF.cpp
    knowhere/index/vector_index/FaissBaseBinaryIndex.cpp
    knowhere/index/vector_index/IndexBinaryIDMAP.cpp
    knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp
    knowhere/index/vector_index/IndexNSG.cpp
    knowhere/index/vector_index/nsg/NSG.cpp
@@ -35,4 +35,9 @@ extern const char* DISTANCE;
    auto rows = dataset->Get<int64_t>(meta::ROWS); \
    auto p_data = dataset->Get<const float*>(meta::TENSOR);

#define GETBINARYTENSOR(dataset)                      \
    auto dim = dataset->Get<int64_t>(meta::DIM);      \
    auto rows = dataset->Get<int64_t>(meta::ROWS);    \
    auto p_data = dataset->Get<const uint8_t*>(meta::TENSOR);

}  // namespace knowhere
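For orientation, a sketch of the kind of dataset these macros unpack. The meta::DIM / meta::ROWS / meta::TENSOR / meta::IDS keys and the Dataset::Get calls are taken from this diff; the Dataset::Set construction side is an assumption, so treat the helper as illustrative.

#include <cstdint>
#include <memory>

#include "knowhere/common/Dataset.h"

// Hypothetical helper: wraps `rows` packed binary codes (dim bits each) so that
// GETBINARYTENSOR(dataset) can pull out dim, rows and the uint8_t* tensor.
inline knowhere::DatasetPtr
MakeBinaryDataset(int64_t dim, int64_t rows, const uint8_t* codes, const int64_t* ids) {
    auto dataset = std::make_shared<knowhere::Dataset>();
    dataset->Set(knowhere::meta::DIM, dim);       // bits per vector
    dataset->Set(knowhere::meta::ROWS, rows);     // number of vectors
    dataset->Set(knowhere::meta::TENSOR, codes);  // rows * dim / 8 bytes of data
    dataset->Set(knowhere::meta::IDS, ids);       // one int64 id per vector
    return dataset;
}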
@@ -20,6 +20,7 @@
#include <memory>
#include <sstream>
#include "Log.h"
#include "knowhere/common/Exception.h"

namespace knowhere {

@@ -27,6 +28,9 @@ enum class METRICTYPE {
    INVALID = 0,
    L2 = 1,
    IP = 2,
    HAMMING = 20,
    JACCARD = 21,
    TANIMOTO = 22,
};

// General Config

@@ -50,7 +54,13 @@ struct Cfg {

    virtual bool
    CheckValid() {
        return true;
        if (metric_type == METRICTYPE::IP || metric_type == METRICTYPE::L2) {
            return true;
        }
        std::stringstream ss;
        ss << "MetricType: " << int(metric_type) << " not support!";
        KNOWHERE_THROW_MSG(ss.str());
        return false;
    }

    void
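For reference, the three binary metrics added to METRICTYPE above, written the way the rest of this change computes them (popcount-based Jaccard in jaccard-inl.h, Tanimoto as the -log2 transform applied in IndexBinaryFlat / IndexBinaryIVF); |.| denotes popcount over packed codes a and b:

\[
  d_{\mathrm{Hamming}}(a,b) = \lvert a \oplus b \rvert, \qquad
  d_{\mathrm{Jaccard}}(a,b) = 1 - \frac{\lvert a \wedge b \rvert}{\lvert a \vee b \rvert}, \qquad
  d_{\mathrm{Tanimoto}}(a,b) = -\log_{2}\bigl(1 - d_{\mathrm{Jaccard}}(a,b)\bigr).
\]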
@ -0,0 +1,65 @@
|
|||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include <faiss/index_io.h>
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
FaissBaseBinaryIndex::FaissBaseBinaryIndex(std::shared_ptr<faiss::IndexBinary> index) : index_(std::move(index)) {
|
||||
}
|
||||
|
||||
BinarySet
|
||||
FaissBaseBinaryIndex::SerializeImpl() {
|
||||
try {
|
||||
faiss::IndexBinary* index = index_.get();
|
||||
|
||||
// SealImpl();
|
||||
|
||||
MemoryIOWriter writer;
|
||||
faiss::write_index_binary(index, &writer);
|
||||
auto data = std::make_shared<uint8_t>();
|
||||
data.reset(writer.data_);
|
||||
|
||||
BinarySet res_set;
|
||||
// TODO(linxj): use virtual func Name() instead of raw string.
|
||||
res_set.Append("BinaryIVF", data, writer.rp);
|
||||
return res_set;
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
FaissBaseBinaryIndex::LoadImpl(const BinarySet& index_binary) {
|
||||
auto binary = index_binary.GetByName("BinaryIVF");
|
||||
|
||||
MemoryIOReader reader;
|
||||
reader.total = binary->size;
|
||||
reader.data_ = binary->data.get();
|
||||
|
||||
faiss::IndexBinary* index = faiss::read_index_binary(&reader);
|
||||
|
||||
index_.reset(index);
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
|
@ -0,0 +1,43 @@
|
|||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <faiss/IndexBinary.h>
|
||||
|
||||
#include "knowhere/common/BinarySet.h"
|
||||
#include "knowhere/common/Dataset.h"
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
class FaissBaseBinaryIndex {
|
||||
protected:
|
||||
explicit FaissBaseBinaryIndex(std::shared_ptr<faiss::IndexBinary> index);
|
||||
|
||||
virtual BinarySet
|
||||
SerializeImpl();
|
||||
|
||||
virtual void
|
||||
LoadImpl(const BinarySet& index_binary);
|
||||
|
||||
public:
|
||||
std::shared_ptr<faiss::IndexBinary> index_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
|
@ -0,0 +1,146 @@
|
|||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include <faiss/IndexBinaryFlat.h>
|
||||
#include <faiss/MetaIndexes.h>
|
||||
|
||||
#include <faiss/index_factory.h>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
BinarySet
|
||||
BinaryIDMAP::Serialize() {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
return SerializeImpl();
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIDMAP::Load(const BinarySet& index_binary) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
LoadImpl(index_binary);
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIDMAP::Search(const DatasetPtr& dataset, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
GETBINARYTENSOR(dataset)
|
||||
|
||||
auto elems = rows * config->k;
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
search_impl(rows, (uint8_t*)p_data, config->k, p_dist, p_id, Config());
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
if (index_->metric_type == faiss::METRIC_Hamming) {
|
||||
auto pf_dist = (float*)malloc(p_dist_size);
|
||||
int32_t* pi_dist = (int32_t*)p_dist;
|
||||
for (int i = 0; i < elems; i++) {
|
||||
*(pf_dist + i) = (float)(*(pi_dist + i));
|
||||
}
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, pf_dist);
|
||||
free(p_dist);
|
||||
} else {
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
}
|
||||
return ret_ds;
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIDMAP::search_impl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
|
||||
const Config& cfg) {
|
||||
int32_t* pdistances = (int32_t*)distances;
|
||||
index_->search(n, (uint8_t*)data, k, pdistances, labels);
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIDMAP::Add(const DatasetPtr& dataset, const Config& config) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
GETBINARYTENSOR(dataset)
|
||||
|
||||
auto p_ids = dataset->Get<const int64_t*>(meta::IDS);
|
||||
index_->add_with_ids(rows, (uint8_t*)p_data, p_ids);
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIDMAP::Train(const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<BinIDMAPCfg>(config);
|
||||
if (build_cfg == nullptr) {
|
||||
KNOWHERE_THROW_MSG("not support this kind of config");
|
||||
}
|
||||
config->CheckValid();
|
||||
|
||||
const char* type = "BFlat";
|
||||
auto index = faiss::index_binary_factory(config->d, type, GetMetricType(config->metric_type));
|
||||
index_.reset(index);
|
||||
}
|
||||
|
||||
int64_t
|
||||
BinaryIDMAP::Count() {
|
||||
return index_->ntotal;
|
||||
}
|
||||
|
||||
int64_t
|
||||
BinaryIDMAP::Dimension() {
|
||||
return index_->d;
|
||||
}
|
||||
|
||||
const uint8_t*
|
||||
BinaryIDMAP::GetRawVectors() {
|
||||
try {
|
||||
auto file_index = dynamic_cast<faiss::IndexBinaryIDMap*>(index_.get());
|
||||
auto flat_index = dynamic_cast<faiss::IndexBinaryFlat*>(file_index->index);
|
||||
return flat_index->xb.data();
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t*
|
||||
BinaryIDMAP::GetRawIds() {
|
||||
try {
|
||||
auto file_index = dynamic_cast<faiss::IndexBinaryIDMap*>(index_.get());
|
||||
return file_index->id_map.data();
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIDMAP::Seal() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
|
@ -0,0 +1,78 @@
|
|||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "FaissBaseBinaryIndex.h"
|
||||
#include "VectorIndex.h"
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
class BinaryIDMAP : public VectorIndex, public FaissBaseBinaryIndex {
|
||||
public:
|
||||
BinaryIDMAP() : FaissBaseBinaryIndex(nullptr) {
|
||||
}
|
||||
|
||||
explicit BinaryIDMAP(std::shared_ptr<faiss::IndexBinary> index) : FaissBaseBinaryIndex(std::move(index)) {
|
||||
}
|
||||
|
||||
BinarySet
|
||||
Serialize() override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& index_binary) override;
|
||||
|
||||
DatasetPtr
|
||||
Search(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
Train(const Config& config);
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
||||
int64_t
|
||||
Dimension() override;
|
||||
|
||||
void
|
||||
Seal() override;
|
||||
|
||||
const uint8_t*
|
||||
GetRawVectors();
|
||||
|
||||
const int64_t*
|
||||
GetRawIds();
|
||||
|
||||
protected:
|
||||
virtual void
|
||||
search_impl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& cfg);
|
||||
|
||||
protected:
|
||||
std::mutex mutex_;
|
||||
};
|
||||
|
||||
using BinaryIDMAPPtr = std::shared_ptr<BinaryIDMAP>;
|
||||
|
||||
} // namespace knowhere
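A usage sketch for the BinaryIDMAP declared above (illustrative, not part of the change). The config fields (d, k, metric_type) follow how they are read elsewhere in this diff; MakeBinaryDataset stands for any code that fills a Dataset as sketched after the VectorAdapter hunk; header paths are assumed.

#include <memory>

#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"

void BinaryIDMAPExample(int64_t dim, int64_t rows, const uint8_t* codes, const int64_t* ids) {
    auto conf = std::make_shared<knowhere::BinIDMAPCfg>();
    conf->d = dim;                                      // dimension in bits
    conf->k = 10;                                       // top-k per query
    conf->metric_type = knowhere::METRICTYPE::TANIMOTO;

    knowhere::BinaryIDMAP index;
    index.Train(conf);                                  // builds the "BFlat" faiss binary index
    index.Add(MakeBinaryDataset(dim, rows, codes, ids), conf);

    // Query with the first vector; results carry int64 ids and float distances.
    auto result = index.Search(MakeBinaryDataset(dim, 1, codes, ids), conf);
    auto result_ids = result->Get<int64_t*>(knowhere::meta::IDS);
    auto result_distances = result->Get<float*>(knowhere::meta::DISTANCE);
    (void)result_ids;
    (void)result_distances;
}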
|
|
@ -0,0 +1,162 @@
|
|||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include <faiss/IndexBinaryFlat.h>
|
||||
#include <faiss/IndexBinaryIVF.h>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexBinaryIVF.h"
|
||||
|
||||
#include <chrono>
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
using stdclock = std::chrono::high_resolution_clock;
|
||||
|
||||
BinarySet
|
||||
BinaryIVF::Serialize() {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
return SerializeImpl();
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIVF::Load(const BinarySet& index_binary) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
LoadImpl(index_binary);
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIVF::Search(const DatasetPtr& dataset, const Config& config) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
// auto search_cfg = std::dynamic_pointer_cast<IVFBinCfg>(config);
|
||||
// if (search_cfg == nullptr) {
|
||||
// KNOWHERE_THROW_MSG("not support this kind of config");
|
||||
// }
|
||||
|
||||
GETBINARYTENSOR(dataset)
|
||||
|
||||
try {
|
||||
auto elems = rows * config->k;
|
||||
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
search_impl(rows, (uint8_t*)p_data, config->k, p_dist, p_id, config);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
if (index_->metric_type == faiss::METRIC_Hamming) {
|
||||
auto pf_dist = (float*)malloc(p_dist_size);
|
||||
int32_t* pi_dist = (int32_t*)p_dist;
|
||||
for (int i = 0; i < elems; i++) {
|
||||
*(pf_dist + i) = (float)(*(pi_dist + i));
|
||||
}
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, pf_dist);
|
||||
free(p_dist);
|
||||
} else {
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
ret_ds->Set(meta::DISTANCE, p_dist);
|
||||
}
|
||||
return ret_ds;
|
||||
} catch (faiss::FaissException& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
} catch (std::exception& e) {
|
||||
KNOWHERE_THROW_MSG(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIVF::search_impl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
|
||||
const Config& cfg) {
|
||||
auto params = GenParams(cfg);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
|
||||
ivf_index->nprobe = params->nprobe;
|
||||
int32_t* pdistances = (int32_t*)distances;
|
||||
stdclock::time_point before = stdclock::now();
|
||||
ivf_index->search(n, (uint8_t*)data, k, pdistances, labels);
|
||||
stdclock::time_point after = stdclock::now();
|
||||
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
|
||||
KNOWHERE_LOG_DEBUG << "IVF search cost: " << search_cost
|
||||
<< ", quantization cost: " << faiss::indexIVF_stats.quantization_time
|
||||
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
|
||||
faiss::indexIVF_stats.quantization_time = 0;
|
||||
faiss::indexIVF_stats.search_time = 0;
|
||||
}
|
||||
|
||||
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
BinaryIVF::GenParams(const Config& config) {
|
||||
auto params = std::make_shared<faiss::IVFSearchParameters>();
|
||||
|
||||
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
|
||||
params->nprobe = search_cfg->nprobe;
|
||||
// params->max_codes = config.get_with_default("max_codes", size_t(0));
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
IndexModelPtr
|
||||
BinaryIVF::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
auto build_cfg = std::dynamic_pointer_cast<IVFBinCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
|
||||
GETBINARYTENSOR(dataset)
|
||||
auto p_ids = dataset->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
faiss::IndexBinary* coarse_quantizer = new faiss::IndexBinaryFlat(dim, GetMetricType(build_cfg->metric_type));
|
||||
auto index = std::make_shared<faiss::IndexBinaryIVF>(coarse_quantizer, dim, build_cfg->nlist,
|
||||
GetMetricType(build_cfg->metric_type));
|
||||
index->train(rows, (uint8_t*)p_data);
|
||||
index->add_with_ids(rows, (uint8_t*)p_data, p_ids);
|
||||
index_ = index;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int64_t
|
||||
BinaryIVF::Count() {
|
||||
return index_->ntotal;
|
||||
}
|
||||
|
||||
int64_t
|
||||
BinaryIVF::Dimension() {
|
||||
return index_->d;
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIVF::Add(const DatasetPtr& dataset, const Config& config) {
|
||||
KNOWHERE_THROW_MSG("not support yet");
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIVF::Seal() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
|
@ -0,0 +1,76 @@
|
|||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "FaissBaseBinaryIndex.h"
|
||||
#include "VectorIndex.h"
|
||||
#include "faiss/IndexIVF.h"
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
class BinaryIVF : public VectorIndex, public FaissBaseBinaryIndex {
|
||||
public:
|
||||
BinaryIVF() : FaissBaseBinaryIndex(nullptr) {
|
||||
}
|
||||
|
||||
explicit BinaryIVF(std::shared_ptr<faiss::IndexBinary> index) : FaissBaseBinaryIndex(std::move(index)) {
|
||||
}
|
||||
|
||||
BinarySet
|
||||
Serialize() override;
|
||||
|
||||
void
|
||||
Load(const BinarySet& index_binary) override;
|
||||
|
||||
DatasetPtr
|
||||
Search(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
Add(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
void
|
||||
Seal() override;
|
||||
|
||||
IndexModelPtr
|
||||
Train(const DatasetPtr& dataset, const Config& config) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
||||
int64_t
|
||||
Dimension() override;
|
||||
|
||||
protected:
|
||||
virtual std::shared_ptr<faiss::IVFSearchParameters>
|
||||
GenParams(const Config& config);
|
||||
|
||||
virtual void
|
||||
search_impl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& cfg);
|
||||
|
||||
protected:
|
||||
std::mutex mutex_;
|
||||
};
|
||||
|
||||
using BinaryIVFIndexPtr = std::shared_ptr<BinaryIVF>;
|
||||
|
||||
} // namespace knowhere
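And the IVF flavour, again a sketch under the same assumptions as the BinaryIDMAP example above: IVFBinCfg carries nlist/nprobe in addition to d, k and metric_type, Train() both clusters and inserts the vectors, and Add() is not supported yet in this change.

#include <memory>

#include "knowhere/index/vector_index/IndexBinaryIVF.h"

void BinaryIVFExample(int64_t dim, int64_t rows, const uint8_t* codes, const int64_t* ids) {
    auto conf = std::make_shared<knowhere::IVFBinCfg>();
    conf->d = dim;
    conf->k = 10;
    conf->nlist = 128;                                  // coarse clusters
    conf->nprobe = 8;                                   // clusters probed per query
    conf->metric_type = knowhere::METRICTYPE::JACCARD;

    knowhere::BinaryIVF index;
    index.Train(MakeBinaryDataset(dim, rows, codes, ids), conf);  // trains and add_with_ids
    auto result = index.Search(MakeBinaryDataset(dim, 1, codes, ids), conf);
    (void)result;
}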
|
|
@@ -66,7 +66,6 @@ IDMAP::Search(const DatasetPtr& dataset, const Config& config) {
    if (!index_) {
        KNOWHERE_THROW_MSG("index not initialize");
    }
    config->CheckValid();
    GETTENSOR(dataset)

    auto elems = rows * config->k;

@@ -149,12 +148,11 @@ IDMAP::GetRawIds() {
    }
}

const char* type = "IDMap,Flat";

void
IDMAP::Train(const Config& config) {
    config->CheckValid();

    const char* type = "IDMap,Flat";
    auto index = faiss::index_factory(config->d, type, GetMetricType(config->metric_type));
    index_.reset(index);
}
@@ -112,8 +112,8 @@ IVF::Search(const DatasetPtr& dataset, const Config& config) {
    }

    auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
    if (search_cfg != nullptr) {
        search_cfg->CheckValid();  // throw exception
    if (search_cfg == nullptr) {
        KNOWHERE_THROW_MSG("not support this kind of config");
    }

    GETTENSOR(dataset)
@@ -73,9 +73,9 @@ NSG::Load(const BinarySet& index_binary) {
DatasetPtr
NSG::Search(const DatasetPtr& dataset, const Config& config) {
    auto build_cfg = std::dynamic_pointer_cast<NSGCfg>(config);
    if (build_cfg != nullptr) {
        build_cfg->CheckValid();  // throw exception
    }
    // if (build_cfg != nullptr) {
    //     build_cfg->CheckValid();  // throw exception
    // }

    if (!index_ || !index_->is_trained) {
        KNOWHERE_THROW_MSG("index not initialize or trained");
@@ -218,9 +218,9 @@ CPUSPTAGRNG::SetParameters(const Config& config) {
DatasetPtr
CPUSPTAGRNG::Search(const DatasetPtr& dataset, const Config& config) {
    SetParameters(config);
    if (config != nullptr) {
        config->CheckValid();  // throw exception
    }
    // if (config != nullptr) {
    //     config->CheckValid();  // throw exception
    // }

    auto p_data = dataset->Get<const float*>(meta::TENSOR);
    for (auto i = 0; i < 10; ++i) {
@@ -30,6 +30,16 @@ GetMetricType(METRICTYPE& type) {
    if (type == METRICTYPE::IP) {
        return faiss::METRIC_INNER_PRODUCT;
    }
    // binary only
    if (type == METRICTYPE::JACCARD) {
        return faiss::METRIC_Jaccard;
    }
    if (type == METRICTYPE::TANIMOTO) {
        return faiss::METRIC_Tanimoto;
    }
    if (type == METRICTYPE::HAMMING) {
        return faiss::METRIC_Hamming;
    }

    KNOWHERE_THROW_MSG("Metric type is invalid");
}
@@ -82,13 +82,27 @@ struct IVFCfg : public Cfg {
    std::stringstream
    DumpImpl() override;

    bool
    CheckValid() override {
        return true;
    };
    //    bool
    //    CheckValid() override {
    //        return true;
    //    };
};
using IVFConfig = std::shared_ptr<IVFCfg>;

struct IVFBinCfg : public IVFCfg {
    bool
    CheckValid() override {
        if (metric_type == METRICTYPE::HAMMING || metric_type == METRICTYPE::TANIMOTO ||
            metric_type == METRICTYPE::JACCARD) {
            return true;
        }
        std::stringstream ss;
        ss << "MetricType: " << int(metric_type) << " not support!";
        KNOWHERE_THROW_MSG(ss.str());
        return false;
    }
};

struct IVFSQCfg : public IVFCfg {
    // TODO(linxj): cpu only support SQ4 SQ6 SQ8 SQ16, gpu only support SQ4, SQ8, SQ16
    int64_t nbits = DEFAULT_NBITS;
@ -103,10 +117,10 @@ struct IVFSQCfg : public IVFCfg {
|
|||
|
||||
IVFSQCfg() = default;
|
||||
|
||||
bool
|
||||
CheckValid() override {
|
||||
return true;
|
||||
};
|
||||
// bool
|
||||
// CheckValid() override {
|
||||
// return true;
|
||||
// };
|
||||
};
|
||||
using IVFSQConfig = std::shared_ptr<IVFSQCfg>;
|
||||
|
||||
|
@ -126,10 +140,10 @@ struct IVFPQCfg : public IVFCfg {
|
|||
|
||||
IVFPQCfg() = default;
|
||||
|
||||
bool
|
||||
CheckValid() override {
|
||||
return true;
|
||||
};
|
||||
// bool
|
||||
// CheckValid() override {
|
||||
// return true;
|
||||
// };
|
||||
};
|
||||
using IVFPQConfig = std::shared_ptr<IVFPQCfg>;
|
||||
|
||||
|
@ -154,10 +168,10 @@ struct NSGCfg : public IVFCfg {
|
|||
std::stringstream
|
||||
DumpImpl() override;
|
||||
|
||||
bool
|
||||
CheckValid() override {
|
||||
return true;
|
||||
};
|
||||
// bool
|
||||
// CheckValid() override {
|
||||
// return true;
|
||||
// };
|
||||
};
|
||||
using NSGConfig = std::shared_ptr<NSGCfg>;
|
||||
|
||||
|
@ -193,10 +207,10 @@ struct KDTCfg : public SPTAGCfg {
|
|||
|
||||
KDTCfg() = default;
|
||||
|
||||
bool
|
||||
CheckValid() override {
|
||||
return true;
|
||||
};
|
||||
// bool
|
||||
// CheckValid() override {
|
||||
// return true;
|
||||
// };
|
||||
};
|
||||
using KDTConfig = std::shared_ptr<KDTCfg>;
|
||||
|
||||
|
@@ -207,11 +221,25 @@ struct BKTCfg : public SPTAGCfg {

    BKTCfg() = default;

    bool
    CheckValid() override {
        return true;
    };
    //    bool
    //    CheckValid() override {
    //        return true;
    //    };
};
using BKTConfig = std::shared_ptr<BKTCfg>;

struct BinIDMAPCfg : public Cfg {
    bool
    CheckValid() override {
        if (metric_type == METRICTYPE::HAMMING || metric_type == METRICTYPE::TANIMOTO ||
            metric_type == METRICTYPE::JACCARD) {
            return true;
        }
        std::stringstream ss;
        ss << "MetricType: " << int(metric_type) << " not support!";
        KNOWHERE_THROW_MSG(ss.str());
        return false;
    }
};

}  // namespace knowhere
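A small illustration of what the binary-config validation above buys (sketch only; the exception type is whatever KNOWHERE_THROW_MSG raises, so catching std::exception keeps the example non-committal):

#include <exception>

void CheckBinaryConfigExample() {
    knowhere::BinIDMAPCfg cfg;

    cfg.metric_type = knowhere::METRICTYPE::JACCARD;
    cfg.CheckValid();                 // accepted: JACCARD / TANIMOTO / HAMMING

    cfg.metric_type = knowhere::METRICTYPE::L2;
    try {
        cfg.CheckValid();             // rejected: throws "MetricType: 1 not support!"
    } catch (const std::exception& e) {
        // handle / report the unsupported metric
    }
}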
@@ -47,6 +47,9 @@ enum MetricType {
    METRIC_L1,          ///< L1 (aka cityblock)
    METRIC_Linf,        ///< infinity distance
    METRIC_Lp,          ///< L_p distance, p is given by metric_arg
    METRIC_Jaccard,
    METRIC_Tanimoto,
    METRIC_Hamming,

    /// some additional metrics defined in scipy.spatial.distance
    METRIC_Canberra = 20,
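A standalone illustration of what the three new enum values mean for packed binary codes. This mirrors the popcount-based JaccardComputer added in jaccard-inl.h later in this diff and the -log2 Tanimoto transform; it is not the faiss code path itself.

#include <bitset>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Jaccard distance = 1 - popcount(a & b) / popcount(a | b); all-zero pairs are
// reported as distance 1.0, the same convention used by JaccardComputer*.
float BinaryJaccardDistance(const uint8_t* a, const uint8_t* b, size_t code_size) {
    int accu_num = 0, accu_den = 0;
    for (size_t i = 0; i < code_size; i++) {
        accu_num += std::bitset<8>(a[i] & b[i]).count();
        accu_den += std::bitset<8>(a[i] | b[i]).count();
    }
    if (accu_num == 0) {
        return 1.0f;
    }
    return 1.0f - static_cast<float>(accu_num) / static_cast<float>(accu_den);
}

// Tanimoto distance is the -log2 of the Jaccard similarity.
float BinaryTanimotoDistance(const uint8_t* a, const uint8_t* b, size_t code_size) {
    return -std::log2(1.0f - BinaryJaccardDistance(a, b, code_size));
}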
@ -7,10 +7,14 @@
|
|||
|
||||
// -*- c++ -*-
|
||||
|
||||
#include <faiss/Index.h>
|
||||
#include <faiss/IndexBinary.h>
|
||||
#include <faiss/IndexBinaryFlat.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <faiss/utils/hamming.h>
|
||||
#include <faiss/utils/jaccard.h>
|
||||
#include <faiss/utils/utils.h>
|
||||
#include <faiss/utils/Heap.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
|
@ -21,6 +25,9 @@ namespace faiss {
|
|||
IndexBinaryFlat::IndexBinaryFlat(idx_t d)
|
||||
: IndexBinary(d) {}
|
||||
|
||||
IndexBinaryFlat::IndexBinaryFlat(idx_t d, MetricType metric)
|
||||
: IndexBinary(d, metric) {}
|
||||
|
||||
void IndexBinaryFlat::add(idx_t n, const uint8_t *x) {
|
||||
xb.insert(xb.end(), x, x + n * code_size);
|
||||
ntotal += n;
|
||||
|
@@ -34,24 +41,54 @@ void IndexBinaryFlat::reset() {
void IndexBinaryFlat::search(idx_t n, const uint8_t *x, idx_t k,
                             int32_t *distances, idx_t *labels) const {
  const idx_t block_size = query_batch_size;
  for (idx_t s = 0; s < n; s += block_size) {
    idx_t nn = block_size;
    if (s + block_size > n) {
      nn = n - s;
    }
  if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
    float *D = new float[k * n];
    for (idx_t s = 0; s < n; s += block_size) {
      idx_t nn = block_size;
      if (s + block_size > n) {
        nn = n - s;
      }

    if (use_heap) {
      // We see the distances and labels as heaps.
      int_maxheap_array_t res = {
        size_t(nn), size_t(k), labels + s * k, distances + s * k
      };
      if (use_heap) {
        // We see the distances and labels as heaps.

      hammings_knn_hc(&res, x + s * code_size, xb.data(), ntotal, code_size,
        float_maxheap_array_t res = {
          size_t(nn), size_t(k), labels + s * k, D + s * k
        };

        jaccard_knn_hc(&res, x + s * code_size, xb.data(), ntotal, code_size,
                       /* ordered = */ true);

      } else {
        FAISS_THROW_MSG("tanimoto_knn_mc not implemented");
      }
    }
    if (metric_type == METRIC_Tanimoto) {
      for (int i = 0; i < k * n; i++) {
        D[i] = -log2(1-D[i]);
      }
    }
    memcpy(distances, D, sizeof(float) * n * k);
    delete [] D;
  } else {
    for (idx_t s = 0; s < n; s += block_size) {
      idx_t nn = block_size;
      if (s + block_size > n) {
        nn = n - s;
      }
      if (use_heap) {
        // We see the distances and labels as heaps.
        int_maxheap_array_t res = {
          size_t(nn), size_t(k), labels + s * k, distances + s * k
        };

        hammings_knn_hc(&res, x + s * code_size, xb.data(), ntotal, code_size,
                        /* ordered = */ true);
      } else {
        hammings_knn_mc(x + s * code_size, xb.data(), nn, ntotal, k, code_size,
                        distances + s * k, labels + s * k);
      }
    } else {
      hammings_knn_mc(x + s * code_size, xb.data(), nn, ntotal, k, code_size,
                      distances + s * k, labels + s * k);
    }
  }
}
}
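The new branch above first fills a temporary float buffer D with Jaccard distances, then, for METRIC_Tanimoto, rewrites every entry in place as -log2(1 - d). A tiny self-contained check of that mapping (values made up):

#include <cmath>
#include <cstdio>

int main() {
    const float jaccard_dist[3] = {0.0f, 0.5f, 0.75f};
    for (float d : jaccard_dist) {
        float tanimoto = -std::log2(1.0f - d);   // 0.00 -> 0, 0.50 -> 1, 0.75 -> 2
        std::printf("jaccard=%.2f  tanimoto=%.2f\n", d, tanimoto);
    }
    return 0;
}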
@@ -31,6 +31,8 @@ struct IndexBinaryFlat : IndexBinary {

  explicit IndexBinaryFlat(idx_t d);

  IndexBinaryFlat(idx_t d, MetricType metric);

  void add(idx_t n, const uint8_t *x) override;

  void reset() override;
@ -8,17 +8,21 @@
|
|||
// Copyright 2004-present Facebook. All Rights Reserved
|
||||
// -*- c++ -*-
|
||||
|
||||
#include <faiss/Index.h>
|
||||
#include <faiss/IndexFlat.h>
|
||||
#include <faiss/IndexBinaryIVF.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <memory>
|
||||
#include <cmath>
|
||||
|
||||
#include <faiss/utils/hamming.h>
|
||||
#include <faiss/utils/jaccard.h>
|
||||
#include <faiss/utils/utils.h>
|
||||
#include <faiss/utils/Heap.h>
|
||||
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/IndexFlat.h>
|
||||
|
||||
|
||||
namespace faiss {
|
||||
|
@ -41,6 +45,24 @@ IndexBinaryIVF::IndexBinaryIVF(IndexBinary *quantizer, size_t d, size_t nlist)
|
|||
cp.niter = 10;
|
||||
}
|
||||
|
||||
IndexBinaryIVF::IndexBinaryIVF(IndexBinary *quantizer, size_t d, size_t nlist, MetricType metric)
|
||||
: IndexBinary(d, metric),
|
||||
invlists(new ArrayInvertedLists(nlist, code_size)),
|
||||
own_invlists(true),
|
||||
nprobe(1),
|
||||
max_codes(0),
|
||||
maintain_direct_map(false),
|
||||
quantizer(quantizer),
|
||||
nlist(nlist),
|
||||
own_fields(false),
|
||||
clustering_index(nullptr)
|
||||
{
|
||||
FAISS_THROW_IF_NOT (d == quantizer->d);
|
||||
is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
|
||||
|
||||
cp.niter = 10;
|
||||
}
|
||||
|
||||
IndexBinaryIVF::IndexBinaryIVF()
|
||||
: invlists(nullptr),
|
||||
own_invlists(false),
|
||||
|
@@ -270,7 +292,13 @@ void IndexBinaryIVF::train(idx_t n, const uint8_t *x) {
    std::unique_ptr<float[]> x_f(new float[n * d]);
    binary_to_real(n * d, x, x_f.get());

    IndexFlatL2 index_tmp(d);
    IndexFlat index_tmp;

    if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
      index_tmp = IndexFlat(d, METRIC_Jaccard);
    } else {
      index_tmp = IndexFlat(d, METRIC_L2);
    }

    if (clustering_index && verbose) {
      printf("using clustering_index of dimension %d to do the clustering\n",
@ -369,6 +397,50 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
|
|||
|
||||
};
|
||||
|
||||
template<class JaccardComputer, bool store_pairs>
|
||||
struct IVFBinaryScannerJaccard: BinaryInvertedListScanner {
|
||||
|
||||
JaccardComputer hc;
|
||||
size_t code_size;
|
||||
|
||||
IVFBinaryScannerJaccard (size_t code_size): code_size (code_size)
|
||||
{}
|
||||
|
||||
void set_query (const uint8_t *query_vector) override {
|
||||
hc.set (query_vector, code_size);
|
||||
}
|
||||
|
||||
idx_t list_no;
|
||||
void set_list (idx_t list_no, uint8_t /* coarse_dis */) override {
|
||||
this->list_no = list_no;
|
||||
}
|
||||
|
||||
uint32_t distance_to_code (const uint8_t *code) const override {
|
||||
}
|
||||
|
||||
size_t scan_codes (size_t n,
|
||||
const uint8_t *codes,
|
||||
const idx_t *ids,
|
||||
int32_t *simi, idx_t *idxi,
|
||||
size_t k) const override
|
||||
{
|
||||
using C = CMax<float, idx_t>;
|
||||
float* psimi = (float*)simi;
|
||||
size_t nup = 0;
|
||||
for (size_t j = 0; j < n; j++) {
|
||||
float dis = hc.jaccard (codes);
|
||||
if (dis < psimi[0]) {
|
||||
heap_pop<C> (k, psimi, idxi);
|
||||
idx_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
||||
heap_push<C> (k, psimi, idxi, dis, id);
|
||||
nup++;
|
||||
}
|
||||
codes += code_size;
|
||||
}
|
||||
return nup;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <bool store_pairs>
|
||||
BinaryInvertedListScanner *select_IVFBinaryScannerL2 (size_t code_size) {
|
||||
|
@@ -398,6 +470,23 @@ BinaryInvertedListScanner *select_IVFBinaryScannerL2 (size_t code_size) {
    }
}

template <bool store_pairs>
BinaryInvertedListScanner *select_IVFBinaryScannerJaccard (size_t code_size) {
    switch (code_size) {
#define HANDLE_CS(cs)                                                        \
      case cs:                                                               \
        return new IVFBinaryScannerJaccard<JaccardComputer ## cs, store_pairs> (cs);
      HANDLE_CS(16)
      HANDLE_CS(32)
      HANDLE_CS(64)
      HANDLE_CS(128)
#undef HANDLE_CS
    default:
        return new IVFBinaryScannerJaccard<JaccardComputerDefault,
                store_pairs>(code_size);
    }
}


void search_knn_hamming_heap(const IndexBinaryIVF& ivf,
                             size_t n,
@ -491,6 +580,89 @@ void search_knn_hamming_heap(const IndexBinaryIVF& ivf,
|
|||
|
||||
}
|
||||
|
||||
void search_knn_jaccard_heap(const IndexBinaryIVF& ivf,
|
||||
size_t n,
|
||||
const uint8_t *x,
|
||||
idx_t k,
|
||||
const idx_t *keys,
|
||||
const float * coarse_dis,
|
||||
float *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params)
|
||||
{
|
||||
long nprobe = params ? params->nprobe : ivf.nprobe;
|
||||
long max_codes = params ? params->max_codes : ivf.max_codes;
|
||||
MetricType metric_type = ivf.metric_type;
|
||||
|
||||
// almost verbatim copy from IndexIVF::search_preassigned
|
||||
|
||||
size_t nlistv = 0, ndis = 0, nheap = 0;
|
||||
using HeapForJaccard = CMax<float, idx_t>;
|
||||
|
||||
#pragma omp parallel if(n > 1) reduction(+: nlistv, ndis, nheap)
|
||||
{
|
||||
std::unique_ptr<BinaryInvertedListScanner> scanner
|
||||
(ivf.get_InvertedListScannerJaccard (store_pairs));
|
||||
|
||||
#pragma omp for
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
const uint8_t *xi = x + i * ivf.code_size;
|
||||
scanner->set_query(xi);
|
||||
|
||||
const idx_t * keysi = keys + i * nprobe;
|
||||
float * simi = distances + k * i;
|
||||
idx_t * idxi = labels + k * i;
|
||||
|
||||
heap_heapify<HeapForJaccard> (k, simi, idxi);
|
||||
|
||||
size_t nscan = 0;
|
||||
|
||||
for (size_t ik = 0; ik < nprobe; ik++) {
|
||||
idx_t key = keysi[ik]; /* select the list */
|
||||
if (key < 0) {
|
||||
// not enough centroids for multiprobe
|
||||
continue;
|
||||
}
|
||||
FAISS_THROW_IF_NOT_FMT
|
||||
(key < (idx_t) ivf.nlist,
|
||||
"Invalid key=%ld at ik=%ld nlist=%ld\n",
|
||||
key, ik, ivf.nlist);
|
||||
|
||||
scanner->set_list (key, (int32_t)coarse_dis[i * nprobe + ik]);
|
||||
|
||||
nlistv++;
|
||||
|
||||
size_t list_size = ivf.invlists->list_size(key);
|
||||
InvertedLists::ScopedCodes scodes (ivf.invlists, key);
|
||||
std::unique_ptr<InvertedLists::ScopedIds> sids;
|
||||
const Index::idx_t * ids = nullptr;
|
||||
|
||||
if (!store_pairs) {
|
||||
sids.reset (new InvertedLists::ScopedIds (ivf.invlists, key));
|
||||
ids = sids->get();
|
||||
}
|
||||
|
||||
nheap += scanner->scan_codes (list_size, scodes.get(),
|
||||
ids, (int32_t*)simi, idxi, k);
|
||||
|
||||
nscan += list_size;
|
||||
if (max_codes && nscan >= max_codes)
|
||||
break;
|
||||
}
|
||||
|
||||
ndis += nscan;
|
||||
heap_reorder<HeapForJaccard> (k, simi, idxi);
|
||||
|
||||
} // parallel for
|
||||
} // parallel
|
||||
|
||||
indexIVF_stats.nq += n;
|
||||
indexIVF_stats.nlist += nlistv;
|
||||
indexIVF_stats.ndis += ndis;
|
||||
indexIVF_stats.nheap_updates += nheap;
|
||||
|
||||
}
|
||||
|
||||
template<class HammingComputer, bool store_pairs>
|
||||
void search_knn_hamming_count(const IndexBinaryIVF& ivf,
|
||||
size_t nx,
|
||||
|
@@ -634,6 +806,16 @@ BinaryInvertedListScanner *IndexBinaryIVF::get_InvertedListScanner
    }
}

BinaryInvertedListScanner *IndexBinaryIVF::get_InvertedListScannerJaccard
    (bool store_pairs) const
{
    if (store_pairs) {
        return select_IVFBinaryScannerJaccard<true> (code_size);
    } else {
        return select_IVFBinaryScannerJaccard<false> (code_size);
    }
}

void IndexBinaryIVF::search_preassigned(idx_t n, const uint8_t *x, idx_t k,
                                        const idx_t *idx,
                                        const int32_t * coarse_dis,
@@ -642,17 +824,38 @@ void IndexBinaryIVF::search_preassigned(idx_t n, const uint8_t *x, idx_t k,
                                        const IVFSearchParameters *params
                                        ) const {

    if (use_heap) {
        search_knn_hamming_heap (*this, n, x, k, idx, coarse_dis,
                                 distances, labels, store_pairs,
                                 params);
    } else {
        if (store_pairs) {
            search_knn_hamming_count_1<true>
                (*this, n, x, idx, k, distances, labels, params);
    if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
        if (use_heap) {
            float *D = new float[k * n];
            float *c_dis = new float [n * nprobe];
            memcpy(c_dis, coarse_dis, sizeof(float) * n * nprobe);
            search_knn_jaccard_heap (*this, n, x, k, idx, c_dis ,
                                     D, labels, store_pairs,
                                     params);
            if (metric_type == METRIC_Tanimoto) {
                for (int i = 0; i < k * n; i++) {
                    D[i] = -log2(1-D[i]);
                }
            }
            memcpy(distances, D, sizeof(float) * n * k);
            delete [] D;
            delete [] c_dis;
        } else {
            search_knn_hamming_count_1<false>
                (*this, n, x, idx, k, distances, labels, params);
            //not implemented
        }
    } else {
        if (use_heap) {
            search_knn_hamming_heap (*this, n, x, k, idx, coarse_dis,
                                     distances, labels, store_pairs,
                                     params);
        } else {
            if (store_pairs) {
                search_knn_hamming_count_1<true>
                    (*this, n, x, idx, k, distances, labels, params);
            } else {
                search_knn_hamming_count_1<false>
                    (*this, n, x, idx, k, distances, labels, params);
            }
        }
    }
}
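Note the buffer handling above: for Jaccard/Tanimoto the float results are computed into a temporary D and then memcpy'd over the int32_t* distances argument, so callers reinterpret that buffer as float (knowhere's BinaryIVF::search_impl in this diff does exactly this cast). A caller-side sketch:

#include <cstdint>

#include <faiss/IndexBinaryIVF.h>

void SearchBinaryJaccard(const faiss::IndexBinaryIVF& index, int64_t n, const uint8_t* queries,
                         int64_t k, float* distances, int64_t* labels) {
    static_assert(sizeof(int32_t) == sizeof(float), "distance buffer is reinterpreted");
    // For METRIC_Jaccard / METRIC_Tanimoto the int32_t* slot actually receives floats.
    index.search(n, queries, k, reinterpret_cast<int32_t*>(distances), labels);
}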
@@ -64,6 +64,8 @@ struct IndexBinaryIVF : IndexBinary {
   */
  IndexBinaryIVF(IndexBinary *quantizer, size_t d, size_t nlist);

  IndexBinaryIVF(IndexBinary *quantizer, size_t d, size_t nlist, MetricType metric);

  IndexBinaryIVF();

  ~IndexBinaryIVF() override;

@@ -109,6 +111,9 @@ struct IndexBinaryIVF : IndexBinary {
  virtual BinaryInvertedListScanner *get_InvertedListScanner (
      bool store_pairs=false) const;

  virtual BinaryInvertedListScanner *get_InvertedListScannerJaccard (
      bool store_pairs=false) const;

  /** assign the vectors, then call search_preassign */
  virtual void search(idx_t n, const uint8_t *x, idx_t k,
                      int32_t *distances, idx_t *labels) const override;
@@ -52,6 +52,10 @@ void IndexFlat::search (idx_t n, const float *x, idx_t k,
        float_maxheap_array_t res = {
            size_t(n), size_t(k), labels, distances};
        knn_L2sqr (x, xb.data(), d, n, ntotal, &res);
    } else if (metric_type == METRIC_Jaccard) {
        float_maxheap_array_t res = {
            size_t(n), size_t(k), labels, distances};
        knn_jaccard (x, xb.data(), d, n, ntotal, &res);
    } else {
        float_maxheap_array_t res = {
            size_t(n), size_t(k), labels, distances};
@ -371,23 +371,24 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
|
|||
return index;
|
||||
}
|
||||
|
||||
IndexBinary *index_binary_factory(int d, const char *description)
|
||||
IndexBinary *index_binary_factory(int d, const char *description, MetricType metric = METRIC_L2)
|
||||
{
|
||||
IndexBinary *index = nullptr;
|
||||
|
||||
int ncentroids = -1;
|
||||
int M;
|
||||
|
||||
ScopeDeleter1<IndexBinary> del_index;
|
||||
if (sscanf(description, "BIVF%d_HNSW%d", &ncentroids, &M) == 2) {
|
||||
IndexBinaryIVF *index_ivf = new IndexBinaryIVF(
|
||||
new IndexBinaryHNSW(d, M), d, ncentroids
|
||||
new IndexBinaryHNSW(d, M), d, ncentroids
|
||||
);
|
||||
index_ivf->own_fields = true;
|
||||
index = index_ivf;
|
||||
|
||||
} else if (sscanf(description, "BIVF%d", &ncentroids) == 1) {
|
||||
IndexBinaryIVF *index_ivf = new IndexBinaryIVF(
|
||||
new IndexBinaryFlat(d), d, ncentroids
|
||||
new IndexBinaryFlat(d), d, ncentroids
|
||||
);
|
||||
index_ivf->own_fields = true;
|
||||
index = index_ivf;
|
||||
|
@ -397,13 +398,27 @@ IndexBinary *index_binary_factory(int d, const char *description)
|
|||
index = index_hnsw;
|
||||
|
||||
} else if (std::string(description) == "BFlat") {
|
||||
index = new IndexBinaryFlat(d);
|
||||
IndexBinary* index_x = new IndexBinaryFlat(d, metric);
|
||||
|
||||
{
|
||||
IndexBinaryIDMap *idmap = new IndexBinaryIDMap(index_x);
|
||||
del_index.set (idmap);
|
||||
idmap->own_fields = true;
|
||||
index_x = idmap;
|
||||
}
|
||||
|
||||
if (index_x) {
|
||||
index = index_x;
|
||||
del_index.set(index);
|
||||
}
|
||||
|
||||
} else {
|
||||
FAISS_THROW_IF_NOT_FMT(index, "description %s did not generate an index",
|
||||
description);
|
||||
description);
|
||||
}
|
||||
|
||||
del_index.release();
|
||||
|
||||
return index;
|
||||
}
|
||||
|
||||
|
|
|
@@ -19,7 +19,7 @@ namespace faiss {
Index *index_factory (int d, const char *description,
                      MetricType metric = METRIC_L2);

IndexBinary *index_binary_factory (int d, const char *description);
IndexBinary *index_binary_factory (int d, const char *description, MetricType metric = METRIC_L2);


}
@ -340,7 +340,78 @@ static void knn_L2sqr_blas (const float * x,
|
|||
|
||||
}
|
||||
|
||||
template<class DistanceCorrection>
|
||||
static void knn_jaccard_blas (const float * x,
|
||||
const float * y,
|
||||
size_t d, size_t nx, size_t ny,
|
||||
float_maxheap_array_t * res,
|
||||
const DistanceCorrection &corr)
|
||||
{
|
||||
res->heapify ();
|
||||
|
||||
// BLAS does not like empty matrices
|
||||
if (nx == 0 || ny == 0) return;
|
||||
|
||||
size_t k = res->k;
|
||||
|
||||
/* block sizes */
|
||||
const size_t bs_x = 4096, bs_y = 1024;
|
||||
// const size_t bs_x = 16, bs_y = 16;
|
||||
float *ip_block = new float[bs_x * bs_y];
|
||||
float *x_norms = new float[nx];
|
||||
float *y_norms = new float[ny];
|
||||
ScopeDeleter<float> del1(ip_block), del3(x_norms), del2(y_norms);
|
||||
|
||||
fvec_norms_L2sqr (x_norms, x, d, nx);
|
||||
fvec_norms_L2sqr (y_norms, y, d, ny);
|
||||
|
||||
|
||||
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
||||
size_t i1 = i0 + bs_x;
|
||||
if(i1 > nx) i1 = nx;
|
||||
|
||||
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
||||
size_t j1 = j0 + bs_y;
|
||||
if (j1 > ny) j1 = ny;
|
||||
/* compute the actual dot products */
|
||||
{
|
||||
float one = 1, zero = 0;
|
||||
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
||||
sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
|
||||
y + j0 * d, &di,
|
||||
x + i0 * d, &di, &zero,
|
||||
ip_block, &nyi);
|
||||
}
|
||||
|
||||
/* collect minima */
|
||||
#pragma omp parallel for
|
||||
for (size_t i = i0; i < i1; i++) {
|
||||
float * __restrict simi = res->get_val(i);
|
||||
int64_t * __restrict idxi = res->get_ids (i);
|
||||
const float *ip_line = ip_block + (i - i0) * (j1 - j0);
|
||||
|
||||
for (size_t j = j0; j < j1; j++) {
|
||||
float ip = *ip_line++;
|
||||
float dis = 1.0 - ip / (x_norms[i] + y_norms[j] - ip);
|
||||
|
||||
// negative values can occur for identical vectors
|
||||
// due to roundoff errors
|
||||
if (dis < 0) dis = 0;
|
||||
|
||||
dis = corr (dis, i, j);
|
||||
|
||||
if (dis < simi[0]) {
|
||||
maxheap_pop (k, simi, idxi);
|
||||
maxheap_push (k, simi, idxi, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
InterruptCallback::check ();
|
||||
}
|
||||
res->reorder ();
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
@@ -387,6 +458,20 @@ void knn_L2sqr (const float * x,
    }
}

void knn_jaccard (const float * x,
                  const float * y,
                  size_t d, size_t nx, size_t ny,
                  float_maxheap_array_t * res)
{
    if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
        // knn_jaccard_sse (x, y, d, nx, ny, res);
        printf("sse_not implemented!\n");
    } else {
        NopDistanceCorrection nop;
        knn_jaccard_blas (x, y, d, nx, ny, res, nop);
    }
}

struct BaseShiftDistanceCorrection {
    const float *base_shift;
    float operator()(float dis, size_t /*qno*/, size_t bno) const {
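For the float-vector path, knn_jaccard_blas above (and the VectorDistanceTanimoto functor later in this diff) uses the generalized real-valued Jaccard coefficient, with Tanimoto as its -log2 transform; spelled out:

\[
  d_{J}(x,y) = 1 - \frac{\langle x, y\rangle}{\lVert x\rVert^{2} + \lVert y\rVert^{2} - \langle x, y\rangle},
  \qquad
  d_{T}(x,y) = -\log_{2}\bigl(1 - d_{J}(x,y)\bigr)
             = -\log_{2}\frac{\langle x, y\rangle}{\lVert x\rVert^{2} + \lVert y\rVert^{2} - \langle x, y\rangle}.
\]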
@@ -47,6 +47,10 @@ float fvec_Linf (
        const float * y,
        size_t d);

float fvec_jaccard (
        const float * x,
        const float * y,
        size_t d);

/** Compute pairwise distances between sets of vectors
 *

@@ -175,7 +179,11 @@ void knn_L2sqr (
        size_t d, size_t nx, size_t ny,
        float_maxheap_array_t * res);


void knn_jaccard (
        const float * x,
        const float * y,
        size_t d, size_t nx, size_t ny,
        float_maxheap_array_t * res);

/** same as knn_L2sqr, but base_shift[bno] is subtracted to all
 * computed distances.
@@ -112,7 +112,39 @@ struct VectorDistanceJensenShannon {
    }
};

struct VectorDistanceJaccard {
    size_t d;

    float operator () (const float *x, const float *y) const {
        float accu_num = 0, accu_den = 0;
        const float EPSILON = 0.000001;
        for (size_t i = 0; i < d; i++) {
            float xi = x[i], yi = y[i];
            if (fabs (xi - yi) < EPSILON) {
                accu_num += xi;
                accu_den += xi;
            } else {
                accu_den += xi;
                accu_den += yi;
            }
        }
        return 1 - accu_num / accu_den;
    }
};

struct VectorDistanceTanimoto {
    size_t d;

    float operator () (const float *x, const float *y) const {
        float accu_num = 0, accu_den = 0;
        for (size_t i = 0; i < d; i++) {
            float xi = x[i], yi = y[i];
            accu_num += xi * yi;
            accu_den += xi * xi + yi * yi - xi * yi;
        }
        return -log2(accu_num / accu_den) ;
    }
};


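A quick worked check of the two functors above on 0/1 vectors (illustrative only; the functors live inside this .cpp, so imagine them visible to the test). For x = {1,0,1,0}, y = {1,1,0,0} the Jaccard distance is 1 - 1/3 ≈ 0.667 and the Tanimoto distance is -log2(1/3) ≈ 1.585, i.e. Tanimoto == -log2(1 - Jaccard) on binary-valued data.

#include <cassert>
#include <cmath>

void CheckExtraDistanceFunctors() {
    const float x[4] = {1, 0, 1, 0};
    const float y[4] = {1, 1, 0, 0};

    VectorDistanceJaccard jaccard{(size_t)4};    // intersection 1, union 3
    VectorDistanceTanimoto tanimoto{(size_t)4};

    float dj = jaccard(x, y);                    // 1 - 1/3
    float dt = tanimoto(x, y);                   // -log2(1/3)
    assert(std::fabs(dt + std::log2(1.0f - dj)) < 1e-5f);
}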
@@ -263,6 +295,18 @@ void pairwise_extra_distances (
                dis, ldq, ldb, ldd);
        break;
    }
    case METRIC_Jaccard: {
        VectorDistanceJaccard vd({(size_t) d});
        pairwise_extra_distances_template(vd, nq, xq, nb, xb,
                dis, ldq, ldb, ldd);
        break;
    }
    case METRIC_Tanimoto: {
        VectorDistanceTanimoto vd({(size_t) d});
        pairwise_extra_distances_template(vd, nq, xq, nb, xb,
                dis, ldq, ldb, ldd);
        break;
    }
    default:
        FAISS_THROW_MSG ("metric type not implemented");
    }
@@ -296,6 +340,16 @@ void knn_extra_metrics (
        knn_extra_metrics_template (vd, x, y, nx, ny, res);
        break;
    }
    case METRIC_Jaccard: {
        VectorDistanceJaccard vd({(size_t) d});
        knn_extra_metrics_template(vd, x, y, nx, ny, res);
        break;
    }
    case METRIC_Tanimoto: {
        VectorDistanceTanimoto vd({(size_t) d});
        knn_extra_metrics_template(vd, x, y, nx, ny, res);
        break;
    }
    default:
        FAISS_THROW_MSG ("metric type not implemented");
    }
@@ -0,0 +1,196 @@
//
// Created by czr on 2019/12/19.
//

namespace faiss {

struct JaccardComputer16 {
    uint64_t a0, a1;

    JaccardComputer16 () {}

    JaccardComputer16 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 16);
        const uint64_t *a = (uint64_t *)a8;
        a0 = a[0]; a1 = a[1];
    }

    inline float jaccard (const uint8_t *b8) const {
        const uint64_t *b = (uint64_t *)b8;
        int accu_num = 0;
        int accu_den = 0;
        accu_num += popcount64 (b[0] & a0) + popcount64 (b[1] & a1);
        accu_den += popcount64 (b[0] | a0) + popcount64 (b[1] | a1);
        if (accu_num == 0)
            return 1.0;
        return 1.0 - (float)(accu_num) / (float)(accu_den);
    }
};

struct JaccardComputer32 {
    uint64_t a0, a1, a2, a3;

    JaccardComputer32 () {}

    JaccardComputer32 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 32);
        const uint64_t *a = (uint64_t *)a8;
        a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
    }

    inline float jaccard (const uint8_t *b8) const {
        const uint64_t *b = (uint64_t *)b8;
        int accu_num = 0;
        int accu_den = 0;
        accu_num += popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
                    popcount64 (b[2] & a2) + popcount64 (b[3] & a3);
        accu_den += popcount64 (b[0] | a0) + popcount64 (b[1] | a1) +
                    popcount64 (b[2] | a2) + popcount64 (b[3] | a3);
        if (accu_num == 0)
            return 1.0;
        return 1.0 - (float)(accu_num) / (float)(accu_den);
    }
};

struct JaccardComputer64 {
    uint64_t a0, a1, a2, a3, a4, a5, a6, a7;

    JaccardComputer64 () {}

    JaccardComputer64 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 64);
        const uint64_t *a = (uint64_t *)a8;
        a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
        a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7];
    }

    inline float jaccard (const uint8_t *b8) const {
        const uint64_t *b = (uint64_t *)b8;
        int accu_num = 0;
        int accu_den = 0;
        accu_num += popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
                    popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
                    popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
                    popcount64 (b[6] & a6) + popcount64 (b[7] & a7);
        accu_den += popcount64 (b[0] | a0) + popcount64 (b[1] | a1) +
                    popcount64 (b[2] | a2) + popcount64 (b[3] | a3) +
                    popcount64 (b[4] | a4) + popcount64 (b[5] | a5) +
                    popcount64 (b[6] | a6) + popcount64 (b[7] | a7);
        if (accu_num == 0)
            return 1.0;
        return 1.0 - (float)(accu_num) / (float)(accu_den);
    }
};

struct JaccardComputer128 {
    uint64_t a0, a1, a2, a3, a4, a5, a6, a7,
             a8, a9, a10, a11, a12, a13, a14, a15;

    JaccardComputer128 () {}

    JaccardComputer128 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    void set (const uint8_t *a16, int code_size) {
        assert (code_size == 128);
        const uint64_t *a = (uint64_t *)a16;
        a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
        a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7];
        a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11];
        a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15];
    }

    inline float jaccard (const uint8_t *b16) const {
        const uint64_t *b = (uint64_t *)b16;
        int accu_num = 0;
        int accu_den = 0;
        accu_num += popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
                    popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
                    popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
                    popcount64 (b[6] & a6) + popcount64 (b[7] & a7) +
                    popcount64 (b[8] & a8) + popcount64 (b[9] & a9) +
                    popcount64 (b[10] & a10) + popcount64 (b[11] & a11) +
                    popcount64 (b[12] & a12) + popcount64 (b[13] & a13) +
                    popcount64 (b[14] & a14) + popcount64 (b[15] & a15);
        accu_den += popcount64 (b[0] | a0) + popcount64 (b[1] | a1) +
                    popcount64 (b[2] | a2) + popcount64 (b[3] | a3) +
                    popcount64 (b[4] | a4) + popcount64 (b[5] | a5) +
                    popcount64 (b[6] | a6) + popcount64 (b[7] | a7) +
                    popcount64 (b[8] | a8) + popcount64 (b[9] | a9) +
                    popcount64 (b[10] | a10) + popcount64 (b[11] | a11) +
                    popcount64 (b[12] | a12) + popcount64 (b[13] | a13) +
                    popcount64 (b[14] | a14) + popcount64 (b[15] | a15);
        if (accu_num == 0)
            return 1.0;
        return 1.0 - (float)(accu_num) / (float)(accu_den);
    }
};

struct JaccardComputerDefault {
    const uint8_t *a;
    int n;

    JaccardComputerDefault () {}

    JaccardComputerDefault (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    void set (const uint8_t *a8, int code_size) {
        a = a8;
        n = code_size;
    }

    float jaccard (const uint8_t *b8) const {
        int accu_num = 0;
        int accu_den = 0;
        for (int i = 0; i < n; i++) {
            accu_num += popcount64(a[i] & b8[i]);
            accu_den += popcount64(a[i] | b8[i]);
        }
        if (accu_num == 0)
            return 1.0;
        return 1.0 - (float)(accu_num) / (float)(accu_den);
    }
};

// default template
template<int CODE_SIZE>
struct JaccardComputer: JaccardComputerDefault {
    JaccardComputer (const uint8_t *a, int code_size):
        JaccardComputerDefault(a, code_size) {}
};

#define SPECIALIZED_HC(CODE_SIZE) \
    template<> struct JaccardComputer<CODE_SIZE>: \
            JaccardComputer ## CODE_SIZE { \
        JaccardComputer (const uint8_t *a): \
            JaccardComputer ## CODE_SIZE(a, CODE_SIZE) {} \
    }

SPECIALIZED_HC(16);
SPECIALIZED_HC(32);
SPECIALIZED_HC(64);
SPECIALIZED_HC(128);

#undef SPECIALIZED_HC

}
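A usage sketch of the fixed-width computers above, under the assumption that it is compiled inside this faiss tree where popcount64 and the JaccardComputer structs are visible; the query code is set once and streamed against 16-byte database codes.

#include <cstdint>
#include <cstdio>

void scan_jaccard16(const uint8_t* query, const uint8_t* base, size_t nb) {
    faiss::JaccardComputer16 jc(query, 16);          // code_size must be 16 bytes
    for (size_t i = 0; i < nb; i++) {
        float d = jc.jaccard(base + i * 16);         // 1 - popcount(A&B) / popcount(A|B)
        std::printf("code %zu: jaccard distance %f\n", i, d);
    }
}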
@@ -0,0 +1,89 @@
#include <faiss/utils/jaccard.h>

#include <vector>
#include <memory>
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <assert.h>
#include <limits.h>

#include <faiss/utils/Heap.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/utils.h>

namespace faiss {

size_t jaccard_batch_size = 65536;

template <class JaccardComputer>
static
void jaccard_knn_hc(
        int bytes_per_code,
        float_maxheap_array_t * ha,
        const uint8_t * bs1,
        const uint8_t * bs2,
        size_t n2,
        bool order = true,
        bool init_heap = true)
{
    size_t k = ha->k;
    if (init_heap) ha->heapify ();

    const size_t block_size = jaccard_batch_size;
    for (size_t j0 = 0; j0 < n2; j0 += block_size) {
        const size_t j1 = std::min(j0 + block_size, n2);
#pragma omp parallel for
        for (size_t i = 0; i < ha->nh; i++) {
            JaccardComputer hc (bs1 + i * bytes_per_code, bytes_per_code);

            const uint8_t * bs2_ = bs2 + j0 * bytes_per_code;
            tadis_t dis;
            tadis_t * __restrict bh_val_ = ha->val + i * k;
            int64_t * __restrict bh_ids_ = ha->ids + i * k;
            size_t j;
            for (j = j0; j < j1; j++, bs2_+= bytes_per_code) {
                dis = hc.jaccard (bs2_);
                if (dis < bh_val_[0]) {
                    faiss::maxheap_pop<tadis_t> (k, bh_val_, bh_ids_);
                    faiss::maxheap_push<tadis_t> (k, bh_val_, bh_ids_, dis, j);
                }
            }
        }
    }
    if (order) ha->reorder ();
}

void jaccard_knn_hc (
        float_maxheap_array_t * ha,
        const uint8_t * a,
        const uint8_t * b,
        size_t nb,
        size_t ncodes,
        int order)
{
    switch (ncodes) {
    case 16:
        jaccard_knn_hc<faiss::JaccardComputer16>
            (16, ha, a, b, nb, order, true);
        break;
    case 32:
        jaccard_knn_hc<faiss::JaccardComputer32>
            (32, ha, a, b, nb, order, true);
        break;
    case 64:
        jaccard_knn_hc<faiss::JaccardComputer64>
            (64, ha, a, b, nb, order, true);
        break;
    case 128:
        jaccard_knn_hc<faiss::JaccardComputer128>
            (128, ha, a, b, nb, order, true);
        break;
    default:
        jaccard_knn_hc<faiss::JaccardComputerDefault>
            (ncodes, ha, a, b, nb, order, true);
    }
}

}
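A caller sketch for the public entry point, assuming it is built against this tree; it collects the k smallest Jaccard distances for nq 32-byte binary queries.

#include <vector>
#include <faiss/utils/jaccard.h>

void knn_jaccard_binary(const uint8_t* queries, size_t nq,
                        const uint8_t* base, size_t nb, size_t k) {
    std::vector<float> distances(nq * k);
    std::vector<int64_t> labels(nq * k);
    faiss::float_maxheap_array_t heaps = {nq, k, labels.data(), distances.data()};
    faiss::jaccard_knn_hc(&heaps, queries, base, nb, /*ncodes=*/32, /*ordered=*/1);
    // distances[i * k + j] and labels[i * k + j] now hold the j-th neighbour of query i
}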
@@ -0,0 +1,37 @@
#ifndef FAISS_JACCARD_H
#define FAISS_JACCARD_H

#include <faiss/utils/hamming.h>

#include <stdint.h>

#include <faiss/utils/Heap.h>

/* The Jaccard distance type */
typedef float tadis_t;

namespace faiss {

extern size_t jaccard_batch_size;

/** Return the k smallest Jaccard distances for a set of binary query vectors,
 * using a max heap.
 * @param a       queries, size ha->nh * ncodes
 * @param b       database, size nb * ncodes
 * @param nb      number of database vectors
 * @param ncodes  size of the binary codes (bytes)
 * @param ordered if != 0: order the results by decreasing distance
 *                (may be bottleneck for k/n > 0.01) */
void jaccard_knn_hc (
        float_maxheap_array_t * ha,
        const uint8_t * a,
        const uint8_t * b,
        size_t nb,
        size_t ncodes,
        int ordered);

} //namespace faiss

#include <faiss/utils/jaccard-inl.h>

#endif //FAISS_JACCARD_H
@@ -58,6 +58,9 @@ set(ivf_srcs
    ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp
    ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp
    ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp
    ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/FaissBaseBinaryIndex.cpp
    ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp
    ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp
    )
if (KNOWHERE_GPU_VERSION)
    set(ivf_srcs ${ivf_srcs}
@@ -74,6 +77,11 @@ if (NOT TARGET test_ivf)
endif ()
target_link_libraries(test_ivf ${depend_libs} ${unittest_libs} ${basic_libs})

if (NOT TARGET test_binaryivf)
    add_executable(test_binaryivf test_binaryivf.cpp ${ivf_srcs} ${util_srcs})
endif ()
target_link_libraries(test_binaryivf ${depend_libs} ${unittest_libs} ${basic_libs})

#<IDMAP-TEST>
if (NOT TARGET test_idmap)
@@ -81,6 +89,12 @@ if (NOT TARGET test_idmap)
endif ()
target_link_libraries(test_idmap ${depend_libs} ${unittest_libs} ${basic_libs})

#<BinaryIDMAP-TEST>
if (NOT TARGET test_binaryidmap)
    add_executable(test_binaryidmap test_binaryidmap.cpp ${ivf_srcs} ${util_srcs})
endif ()
target_link_libraries(test_binaryidmap ${depend_libs} ${unittest_libs} ${basic_libs})

#<SPTAG-TEST>
set(sptag_srcs
    ${INDEX_SOURCE_DIR}/knowhere/knowhere/adapter/SptagAdapter.cpp
@@ -104,7 +118,9 @@ if (KNOWHERE_GPU_VERSION)
endif ()

install(TARGETS test_ivf DESTINATION unittest)
install(TARGETS test_binaryivf DESTINATION unittest)
install(TARGETS test_idmap DESTINATION unittest)
install(TARGETS test_binaryidmap DESTINATION unittest)
install(TARGETS test_sptag DESTINATION unittest)
if (KNOWHERE_GPU_VERSION)
    install(TARGETS test_gpuresource DESTINATION unittest)
@@ -0,0 +1,119 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <gtest/gtest.h>

#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"

#include "Helper.h"
#include "unittest/utils.h"

using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;

class BinaryIDMAPTest : public BinaryDataGen, public TestWithParam<knowhere::METRICTYPE> {
 protected:
    void
    SetUp() override {
        Init_with_binary_default();
        index_ = std::make_shared<knowhere::BinaryIDMAP>();
    }

    void
    TearDown() override{};

 protected:
    knowhere::BinaryIDMAPPtr index_ = nullptr;
};

INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIDMAPTest,
                        Values(knowhere::METRICTYPE::JACCARD, knowhere::METRICTYPE::TANIMOTO,
                               knowhere::METRICTYPE::HAMMING));

TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
    ASSERT_TRUE(!xb.empty());

    knowhere::METRICTYPE MetricType = GetParam();
    auto conf = std::make_shared<knowhere::BinIDMAPCfg>();
    conf->d = dim;
    conf->k = k;
    conf->metric_type = MetricType;

    index_->Train(conf);
    index_->Add(base_dataset, conf);
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dimension(), dim);
    ASSERT_TRUE(index_->GetRawVectors() != nullptr);
    ASSERT_TRUE(index_->GetRawIds() != nullptr);
    auto result = index_->Search(query_dataset, conf);
    AssertAnns(result, nq, k);
    // PrintResult(result, nq, k);

    auto binaryset = index_->Serialize();
    auto new_index = std::make_shared<knowhere::BinaryIDMAP>();
    new_index->Load(binaryset);
    auto re_result = index_->Search(query_dataset, conf);
    AssertAnns(re_result, nq, k);
    // PrintResult(re_result, nq, k);
}

TEST_P(BinaryIDMAPTest, binaryidmap_serialize) {
    auto serialize = [](const std::string& filename, knowhere::BinaryPtr& bin, uint8_t* ret) {
        FileIOWriter writer(filename);
        writer(static_cast<void*>(bin->data.get()), bin->size);

        FileIOReader reader(filename);
        reader(ret, bin->size);
    };

    knowhere::METRICTYPE MetricType = GetParam();
    auto conf = std::make_shared<knowhere::BinIDMAPCfg>();
    conf->d = dim;
    conf->k = k;
    conf->metric_type = MetricType;

    {
        // serialize index
        index_->Train(conf);
        index_->Add(base_dataset, knowhere::Config());
        auto re_result = index_->Search(query_dataset, conf);
        AssertAnns(re_result, nq, k);
        // PrintResult(re_result, nq, k);
        EXPECT_EQ(index_->Count(), nb);
        EXPECT_EQ(index_->Dimension(), dim);
        auto binaryset = index_->Serialize();
        auto bin = binaryset.GetByName("BinaryIVF");

        std::string filename = "/tmp/bianryidmap_test_serialize.bin";
        auto load_data = new uint8_t[bin->size];
        serialize(filename, bin, load_data);

        binaryset.clear();
        auto data = std::make_shared<uint8_t>();
        data.reset(load_data);
        binaryset.Append("BinaryIVF", data, bin->size);

        index_->Load(binaryset);
        EXPECT_EQ(index_->Count(), nb);
        EXPECT_EQ(index_->Dimension(), dim);
        auto result = index_->Search(query_dataset, conf);
        AssertAnns(result, nq, k);
        // PrintResult(result, nq, k);
    }
}
@@ -0,0 +1,144 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <gtest/gtest.h>

#include <iostream>
#include <thread>

#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Timer.h"

#include "knowhere/index/vector_index/IndexBinaryIVF.h"

#include "unittest/Helper.h"
#include "unittest/utils.h"

using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;

class BinaryIVFTest : public BinaryDataGen, public TestWithParam<knowhere::METRICTYPE> {
 protected:
    void
    SetUp() override {
        knowhere::METRICTYPE MetricType = GetParam();
        Init_with_binary_default();
        //        nb = 1000000;
        //        nq = 1000;
        //        k = 1000;
        //        Generate(DIM, NB, NQ);
        index_ = std::make_shared<knowhere::BinaryIVF>();
        auto x_conf = std::make_shared<knowhere::IVFBinCfg>();
        x_conf->d = dim;
        x_conf->k = k;
        x_conf->metric_type = MetricType;
        x_conf->nlist = 100;
        x_conf->nprobe = 10;
        conf = x_conf;
        conf->Dump();
    }

    void
    TearDown() override {
    }

 protected:
    std::string index_type;
    knowhere::Config conf;
    knowhere::BinaryIVFIndexPtr index_ = nullptr;
};

INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIVFTest,
                        Values(knowhere::METRICTYPE::JACCARD, knowhere::METRICTYPE::TANIMOTO,
                               knowhere::METRICTYPE::HAMMING));

TEST_P(BinaryIVFTest, binaryivf_basic) {
    assert(!xb.empty());

    //    auto preprocessor = index_->BuildPreprocessor(base_dataset, conf);
    //    index_->set_preprocessor(preprocessor);

    index_->Train(base_dataset, conf);
    //    index_->set_index_model(model);
    //    index_->Add(base_dataset, conf);
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dimension(), dim);

    auto result = index_->Search(query_dataset, conf);
    AssertAnns(result, nq, conf->k);
    // PrintResult(result, nq, k);
}

TEST_P(BinaryIVFTest, binaryivf_serialize) {
    auto serialize = [](const std::string& filename, knowhere::BinaryPtr& bin, uint8_t* ret) {
        FileIOWriter writer(filename);
        writer(static_cast<void*>(bin->data.get()), bin->size);

        FileIOReader reader(filename);
        reader(ret, bin->size);
    };

    //    {
    //        // serialize index-model
    //        auto model = index_->Train(base_dataset, conf);
    //        auto binaryset = model->Serialize();
    //        auto bin = binaryset.GetByName("BinaryIVF");
    //
    //        std::string filename = "/tmp/binaryivf_test_model_serialize.bin";
    //        auto load_data = new uint8_t[bin->size];
    //        serialize(filename, bin, load_data);
    //
    //        binaryset.clear();
    //        auto data = std::make_shared<uint8_t>();
    //        data.reset(load_data);
    //        binaryset.Append("BinaryIVF", data, bin->size);
    //
    //        model->Load(binaryset);
    //
    //        index_->set_index_model(model);
    //        index_->Add(base_dataset, conf);
    //        auto result = index_->Search(query_dataset, conf);
    //        AssertAnns(result, nq, conf->k);
    //    }

    {
        // serialize index
        index_->Train(base_dataset, conf);
        //        index_->set_index_model(model);
        //        index_->Add(base_dataset, conf);
        auto binaryset = index_->Serialize();
        auto bin = binaryset.GetByName("BinaryIVF");

        std::string filename = "/tmp/binaryivf_test_serialize.bin";
        auto load_data = new uint8_t[bin->size];
        serialize(filename, bin, load_data);

        binaryset.clear();
        auto data = std::make_shared<uint8_t>();
        data.reset(load_data);
        binaryset.Append("BinaryIVF", data, bin->size);

        index_->Load(binaryset);
        EXPECT_EQ(index_->Count(), nb);
        EXPECT_EQ(index_->Dimension(), dim);
        auto result = index_->Search(query_dataset, conf);
        AssertAnns(result, nq, conf->k);
        // PrintResult(result, nq, k);
    }
}
@@ -42,11 +42,13 @@ class SPTAGTest : public DataGen, public TestWithParam<std::string> {
            auto tempconf = std::make_shared<knowhere::KDTCfg>();
            tempconf->tptnumber = 1;
            tempconf->k = 10;
            tempconf->metric_type = knowhere::METRICTYPE::L2;
            conf = tempconf;
        } else {
            auto tempconf = std::make_shared<knowhere::BKTCfg>();
            tempconf->tptnumber = 1;
            tempconf->k = 10;
            tempconf->metric_type = knowhere::METRICTYPE::L2;
            conf = tempconf;
        }
@@ -38,6 +38,11 @@ DataGen::Init_with_default() {
    Generate(dim, nb, nq);
}

void
BinaryDataGen::Init_with_binary_default() {
    Generate(dim, nb, nq);
}

void
DataGen::Generate(const int& dim, const int& nb, const int& nq) {
    this->nb = nb;
@@ -52,6 +57,21 @@ DataGen::Generate(const int& dim, const int& nb, const int& nq) {
    query_dataset = generate_query_dataset(nq, dim, xq.data());
}

void
BinaryDataGen::Generate(const int& dim, const int& nb, const int& nq) {
    this->nb = nb;
    this->nq = nq;
    this->dim = dim;

    int64_t dim_x = dim / 8;
    GenBinaryAll(dim_x, nb, xb, ids, nq, xq);
    assert(xb.size() == (size_t)dim_x * nb);
    assert(xq.size() == (size_t)dim_x * nq);

    base_dataset = generate_binary_dataset(nb, dim, xb.data(), ids.data());
    query_dataset = generate_binary_query_dataset(nq, dim, xq.data());
}
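The generator divides the table dimension by 8 because a binary vector of dim bits is stored as dim / 8 uint8_t codes. A tiny illustration of that bookkeeping (the helper name is ours, not part of the test utilities):

#include <cstddef>

static size_t binary_bytes_needed(size_t nb, size_t dim_bits) {
    return nb * (dim_bits / 8);   // e.g. 10000 vectors of 512 bits -> 640000 bytes
}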
knowhere::DatasetPtr
DataGen::GenQuery(const int& nq) {
    xq.resize(nq * dim);
@@ -78,6 +98,23 @@ GenAll(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids, const int
    }
}

void
GenBinaryAll(const int64_t dim, const int64_t& nb, std::vector<uint8_t>& xb, std::vector<int64_t>& ids,
             const int64_t& nq, std::vector<uint8_t>& xq) {
    xb.resize(nb * dim);
    xq.resize(nq * dim);
    ids.resize(nb);
    GenBinaryAll(dim, nb, xb.data(), ids.data(), nq, xq.data());
}

void
GenBinaryAll(const int64_t& dim, const int64_t& nb, uint8_t* xb, int64_t* ids, const int64_t& nq, uint8_t* xq) {
    GenBinaryBase(dim, nb, xb, ids);
    for (int64_t i = 0; i < nq * dim; ++i) {
        xq[i] = xb[i];
    }
}

void
GenBase(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids) {
    for (auto i = 0; i < nb; ++i) {
@@ -90,6 +127,17 @@ GenBase(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids) {
    }
}

void
GenBinaryBase(const int64_t& dim, const int64_t& nb, uint8_t* xb, int64_t* ids) {
    for (auto i = 0; i < nb; ++i) {
        for (auto j = 0; j < dim; ++j) {
            // p_data[i * d + j] = float(base + i);
            xb[i * dim + j] = (uint8_t)lrand48();
        }
        ids[i] = i;
    }
}

FileIOReader::FileIOReader(const std::string& fname) {
    name = fname;
    fs = std::fstream(name, std::ios::in | std::ios::binary);
@@ -130,6 +178,16 @@ generate_dataset(int64_t nb, int64_t dim, const float* xb, const int64_t* ids) {
    return ret_ds;
}

knowhere::DatasetPtr
generate_binary_dataset(int64_t nb, int64_t dim, const uint8_t* xb, const int64_t* ids) {
    auto ret_ds = std::make_shared<knowhere::Dataset>();
    ret_ds->Set(knowhere::meta::ROWS, nb);
    ret_ds->Set(knowhere::meta::DIM, dim);
    ret_ds->Set(knowhere::meta::TENSOR, xb);
    ret_ds->Set(knowhere::meta::IDS, ids);
    return ret_ds;
}

knowhere::DatasetPtr
generate_query_dataset(int64_t nb, int64_t dim, const float* xb) {
    auto ret_ds = std::make_shared<knowhere::Dataset>();
@@ -139,6 +197,15 @@ generate_query_dataset(int64_t nb, int64_t dim, const float* xb) {
    return ret_ds;
}

knowhere::DatasetPtr
generate_binary_query_dataset(int64_t nb, int64_t dim, const uint8_t* xb) {
    auto ret_ds = std::make_shared<knowhere::Dataset>();
    ret_ds->Set(knowhere::meta::ROWS, nb);
    ret_ds->Set(knowhere::meta::DIM, dim);
    ret_ds->Set(knowhere::meta::TENSOR, xb);
    return ret_ds;
}

void
AssertAnns(const knowhere::DatasetPtr& result, const int& nq, const int& k) {
    auto ids = result->Get<int64_t*>(knowhere::meta::IDS);
@@ -49,6 +49,29 @@ class DataGen {
    knowhere::DatasetPtr query_dataset = nullptr;
};

class BinaryDataGen {
 protected:
    void
    Init_with_binary_default();

    void
    Generate(const int& dim, const int& nb, const int& nq);

    knowhere::DatasetPtr
    GenQuery(const int& nq);

 protected:
    int nb = 10000;
    int nq = 10;
    int dim = 512;
    int k = 10;
    std::vector<uint8_t> xb;
    std::vector<uint8_t> xq;
    std::vector<int64_t> ids;
    knowhere::DatasetPtr base_dataset = nullptr;
    knowhere::DatasetPtr query_dataset = nullptr;
};

extern void
GenAll(const int64_t dim, const int64_t& nb, std::vector<float>& xb, std::vector<int64_t>& ids, const int64_t& nq,
       std::vector<float>& xq);
@@ -56,18 +79,34 @@ GenAll(const int64_t dim, const int64_t& nb, std::vector<float>& xb, std::vector
extern void
GenAll(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids, const int64_t& nq, float* xq);

extern void
GenBinaryAll(const int64_t dim, const int64_t& nb, std::vector<uint8_t>& xb, std::vector<int64_t>& ids,
             const int64_t& nq, std::vector<uint8_t>& xq);

extern void
GenBinaryAll(const int64_t& dim, const int64_t& nb, uint8_t* xb, int64_t* ids, const int64_t& nq, uint8_t* xq);

extern void
GenBase(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids);

extern void
GenBinaryBase(const int64_t& dim, const int64_t& nb, uint8_t* xb, int64_t* ids);

extern void
InitLog();

knowhere::DatasetPtr
generate_dataset(int64_t nb, int64_t dim, const float* xb, const int64_t* ids);

knowhere::DatasetPtr
generate_binary_dataset(int64_t nb, int64_t dim, const uint8_t* xb, const int64_t* ids);

knowhere::DatasetPtr
generate_query_dataset(int64_t nb, int64_t dim, const float* xb);

knowhere::DatasetPtr
generate_binary_query_dataset(int64_t nb, int64_t dim, const uint8_t* xb);

void
AssertAnns(const knowhere::DatasetPtr& result, const int& nq, const int& k);
@@ -22,9 +22,9 @@
namespace milvus {
namespace scheduler {

SearchJob::SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, uint64_t nq, uint64_t nprobe,
                     const float* vectors)
    : Job(JobType::SEARCH), context_(context), topk_(topk), nq_(nq), nprobe_(nprobe), vectors_(vectors) {
SearchJob::SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, uint64_t nprobe,
                     const engine::VectorsData& vectors)
    : Job(JobType::SEARCH), context_(context), topk_(topk), nprobe_(nprobe), vectors_(vectors) {
}

bool
@@ -77,7 +77,7 @@ json
SearchJob::Dump() const {
    json ret{
        {"topk", topk_},
        {"nq", nq_},
        {"nq", vectors_.vector_count_},
        {"nprobe", nprobe_},
    };
    auto base = Job::Dump();
@@ -46,8 +46,8 @@ using ResultDistances = engine::ResultDistances;

class SearchJob : public Job {
 public:
    SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, uint64_t nq, uint64_t nprobe,
              const float* vectors);
    SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, uint64_t nprobe,
              const engine::VectorsData& vectors);

 public:
    bool
@@ -82,7 +82,7 @@ class SearchJob : public Job {

    uint64_t
    nq() const {
        return nq_;
        return vectors_.vector_count_;
    }

    uint64_t
@@ -90,7 +90,7 @@ class SearchJob : public Job {
        return nprobe_;
    }

    const float*
    const engine::VectorsData&
    vectors() const {
        return vectors_;
    }
@@ -109,10 +109,9 @@ class SearchJob : public Job {
    const std::shared_ptr<server::Context> context_;

    uint64_t topk_ = 0;
    uint64_t nq_ = 0;
    uint64_t nprobe_ = 0;
    // TODO: smart pointer
    const float* vectors_ = nullptr;
    const engine::VectorsData& vectors_;

    Id2IndexMap index_files_;
    // TODO: column-base better ?
@@ -103,8 +103,10 @@ CollectFileMetrics(int file_type, size_t file_size) {
XSearchTask::XSearchTask(const std::shared_ptr<server::Context>& context, TableFileSchemaPtr file, TaskLabelPtr label)
    : Task(TaskType::SearchTask, std::move(label)), context_(context), file_(file) {
    if (file_) {
        if (file_->metric_type_ != static_cast<int>(MetricType::L2)) {
            metric_l2 = false;
        // distance -- value 0 means two vectors equal, ascending reduce, L2/HAMMING/JACCARD/TANIMOTO ...
        // similarity -- infinity value means two vectors equal, descending reduce, IP
        if (file_->metric_type_ == static_cast<int>(MetricType::IP)) {
            ascending_reduce = false;
        }
        index_engine_ = EngineFactory::Build(file_->dimension_, file_->location_, (EngineType)file_->engine_type_,
                                             (MetricType)file_->metric_type_, file_->nlist_);
@@ -207,7 +209,7 @@ XSearchTask::Execute() {
        uint64_t nq = search_job->nq();
        uint64_t topk = search_job->topk();
        uint64_t nprobe = search_job->nprobe();
        const float* vectors = search_job->vectors();
        const engine::VectorsData& vectors = search_job->vectors();

        output_ids.resize(topk * nq);
        output_distance.resize(topk * nq);
@@ -221,8 +223,14 @@ XSearchTask::Execute() {
                ResMgrInst::GetInstance()->GetResource(path().Last())->type() == ResourceType::CPU) {
                hybrid = true;
            }
            Status s =
                index_engine_->Search(nq, vectors, topk, nprobe, output_distance.data(), output_ids.data(), hybrid);
            Status s;
            if (!vectors.float_data_.empty()) {
                s = index_engine_->Search(nq, vectors.float_data_.data(), topk, nprobe, output_distance.data(),
                                          output_ids.data(), hybrid);
            } else if (!vectors.binary_data_.empty()) {
                s = index_engine_->Search(nq, vectors.binary_data_.data(), topk, nprobe, output_distance.data(),
                                          output_ids.data(), hybrid);
            }
            if (!s.ok()) {
                search_job->GetStatus() = s;
                search_job->SearchDone(index_id_);
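The task now dispatches on whichever buffer of VectorsData is populated. A stripped-down sketch of that convention (struct and helper are ours, mirroring the fields used in this patch): exactly one of the two buffers is expected to be non-empty for a given table.

#include <cstdint>
#include <vector>

struct VectorsDataSketch {
    uint64_t vector_count_ = 0;
    std::vector<float> float_data_;      // populated for float-metric tables
    std::vector<uint8_t> binary_data_;   // populated for JACCARD/TANIMOTO/HAMMING tables
};

static bool is_binary_query(const VectorsDataSketch& v) {
    return v.float_data_.empty() && !v.binary_data_.empty();
}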
@@ -236,7 +244,7 @@ XSearchTask::Execute() {
            auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk;
            {
                std::unique_lock<std::mutex> lock(search_job->mutex());
                XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2,
                XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, ascending_reduce,
                                                  search_job->GetResultIds(), search_job->GetResultDistances());
            }

@@ -57,7 +57,10 @@ class XSearchTask : public Task {
    size_t index_id_ = 0;
    int index_type_ = 0;
    ExecutionEnginePtr index_engine_ = nullptr;
    bool metric_l2 = true;

    // distance -- value 0 means two vectors equal, ascending reduce, L2/HAMMING/JACCARD/TANIMOTO ...
    // similarity -- infinity value means two vectors equal, descending reduce, IP
    bool ascending_reduce = true;
};

}  // namespace scheduler
@@ -75,11 +75,9 @@ RequestHandler::CreateIndex(const std::shared_ptr<Context>& context, const std:
}

Status
RequestHandler::Insert(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
                       std::vector<float>& data_list, const std::string& partition_tag,
                       std::vector<int64_t>& id_array) {
    BaseRequestPtr request_ptr =
        InsertRequest::Create(context, table_name, record_size, data_list, partition_tag, id_array);
RequestHandler::Insert(const std::shared_ptr<Context>& context, const std::string& table_name,
                       engine::VectorsData& vectors, const std::string& partition_tag) {
    BaseRequestPtr request_ptr = InsertRequest::Create(context, table_name, vectors, partition_tag);
    RequestScheduler::ExecRequest(request_ptr);

    return request_ptr->status();
@@ -94,13 +92,13 @@ RequestHandler::ShowTables(const std::shared_ptr<Context>& context, std::vector<
}

Status
RequestHandler::Search(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
                       const std::vector<float>& data_list,
RequestHandler::Search(const std::shared_ptr<Context>& context, const std::string& table_name,
                       const engine::VectorsData& vectors,
                       const std::vector<std::pair<std::string, std::string>>& range_list, int64_t topk, int64_t nprobe,
                       const std::vector<std::string>& partition_list, const std::vector<std::string>& file_id_list,
                       TopKQueryResult& result) {
    BaseRequestPtr request_ptr = SearchRequest::Create(context, table_name, record_size, data_list, range_list, topk,
                                                       nprobe, partition_list, file_id_list, result);
    BaseRequestPtr request_ptr = SearchRequest::Create(context, table_name, vectors, range_list, topk, nprobe,
                                                       partition_list, file_id_list, result);
    RequestScheduler::ExecRequest(request_ptr);

    return request_ptr->status();
@@ -47,15 +47,15 @@ class RequestHandler {
                int64_t nlist);

    Status
    Insert(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
           std::vector<float>& data_list, const std::string& partition_tag, std::vector<int64_t>& id_array);
    Insert(const std::shared_ptr<Context>& context, const std::string& table_name, engine::VectorsData& vectors,
           const std::string& partition_tag);

    Status
    ShowTables(const std::shared_ptr<Context>& context, std::vector<std::string>& tables);

    Status
    Search(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
           const std::vector<float>& data_list, const std::vector<Range>& range_list, int64_t topk, int64_t nprobe,
    Search(const std::shared_ptr<Context>& context, const std::string& table_name, const engine::VectorsData& vectors,
           const std::vector<Range>& range_list, int64_t topk, int64_t nprobe,
           const std::vector<std::string>& partition_list, const std::vector<std::string>& file_id_list,
           TopKQueryResult& result);
@@ -71,23 +71,36 @@ CreateIndexRequest::OnExecute() {
            return status;
        }

        // step 2: binary and float vector support different index/metric type, need to adapt here
        engine::meta::TableSchema table_info;
        table_info.table_id_ = table_name_;
        status = DBWrapper::DB()->DescribeTable(table_info);

        int32_t adapter_index_type = index_type_;
        if (ValidationUtil::IsBinaryMetricType(table_info.metric_type_)) {  // binary vector not allow
            if (adapter_index_type == static_cast<int32_t>(engine::EngineType::FAISS_IDMAP)) {
                adapter_index_type = static_cast<int32_t>(engine::EngineType::FAISS_BIN_IDMAP);
            } else if (adapter_index_type == static_cast<int32_t>(engine::EngineType::FAISS_IVFFLAT)) {
                adapter_index_type = static_cast<int32_t>(engine::EngineType::FAISS_BIN_IVFFLAT);
            } else {
                return Status(SERVER_INVALID_INDEX_TYPE, "Invalid index type for table metric type");
            }
        }

#ifdef MILVUS_GPU_VERSION
        Status s;
        bool enable_gpu = false;
        server::Config& config = server::Config::GetInstance();
        s = config.GetGpuResourceConfigEnable(enable_gpu);
        engine::meta::TableSchema table_info;
        table_info.table_id_ = table_name_;
        status = DBWrapper::DB()->DescribeTable(table_info);
        if (s.ok() && index_type_ == (int)engine::EngineType::FAISS_PQ &&
        if (s.ok() && adapter_index_type == (int)engine::EngineType::FAISS_PQ &&
            table_info.metric_type_ == (int)engine::MetricType::IP) {
            return Status(SERVER_UNEXPECTED_ERROR, "PQ not support IP in GPU version!");
        }
#endif

        // step 2: check table existence
        // step 3: create index
        engine::TableIndex index;
        index.engine_type_ = index_type_;
        index.engine_type_ = adapter_index_type;
        index.nlist_ = nlist_;
        status = DBWrapper::DB()->CreateIndex(table_name_, index);
        if (!status.ok()) {
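A sketch of the adaptation rule above, factored into a helper for illustration only (the helper itself is ours, not part of the patch; the enum values come from the diff): tables with a binary metric silently map the float engines to their binary counterparts, and any other index type is rejected.

static engine::EngineType AdaptForBinaryMetric(engine::EngineType type, bool& ok) {
    ok = true;
    switch (type) {
        case engine::EngineType::FAISS_IDMAP:
            return engine::EngineType::FAISS_BIN_IDMAP;
        case engine::EngineType::FAISS_IVFFLAT:
            return engine::EngineType::FAISS_BIN_IVFFLAT;
        default:
            ok = false;          // e.g. FAISS_PQ is not available for binary metrics
            return type;
    }
}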
@@ -78,6 +78,15 @@ CreateTableRequest::OnExecute() {
        table_info.index_file_size_ = index_file_size_;
        table_info.metric_type_ = metric_type_;

        // some metric types only support binary vectors, adapt the index type
        if (ValidationUtil::IsBinaryMetricType(metric_type_)) {
            if (table_info.engine_type_ == static_cast<int32_t>(engine::EngineType::FAISS_IDMAP)) {
                table_info.engine_type_ = static_cast<int32_t>(engine::EngineType::FAISS_BIN_IDMAP);
            } else if (table_info.engine_type_ == static_cast<int32_t>(engine::EngineType::FAISS_IVFFLAT)) {
                table_info.engine_type_ = static_cast<int32_t>(engine::EngineType::FAISS_BIN_IVFFLAT);
            }
        }

        // step 3: create table
        status = DBWrapper::DB()->CreateTable(table_info);
        if (!status.ok()) {
@@ -56,6 +56,14 @@ DescribeIndexRequest::OnExecute() {
            return status;
        }

        // for binary vectors, IDMAP and IVFFLAT are treated as BIN_IDMAP and BIN_IVFFLAT internally;
        // return IDMAP and IVFFLAT to the outside caller
        if (index.engine_type_ == (int32_t)engine::EngineType::FAISS_BIN_IDMAP) {
            index.engine_type_ = (int32_t)engine::EngineType::FAISS_IDMAP;
        } else if (index.engine_type_ == (int32_t)engine::EngineType::FAISS_BIN_IVFFLAT) {
            index.engine_type_ = (int32_t)engine::EngineType::FAISS_IVFFLAT;
        }

        index_param_.table_name_ = table_name_;
        index_param_.index_type_ = index.engine_type_;
        index_param_.nlist_ = index.nlist_;
@@ -29,27 +29,24 @@ namespace milvus {
namespace server {

InsertRequest::InsertRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
                             int64_t record_size, std::vector<float>& data_list, const std::string& partition_tag,
                             std::vector<int64_t>& id_array)
                             engine::VectorsData& vectors, const std::string& partition_tag)
    : BaseRequest(context, DDL_DML_REQUEST_GROUP),
      table_name_(table_name),
      record_size_(record_size),
      data_list_(data_list),
      partition_tag_(partition_tag),
      id_array_(id_array) {
      vectors_data_(vectors),
      partition_tag_(partition_tag) {
}

BaseRequestPtr
InsertRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
                      std::vector<float>& data_list, const std::string& partition_tag, std::vector<int64_t>& id_array) {
    return std::shared_ptr<BaseRequest>(
        new InsertRequest(context, table_name, record_size, data_list, partition_tag, id_array));
InsertRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name,
                      engine::VectorsData& vectors, const std::string& partition_tag) {
    return std::shared_ptr<BaseRequest>(new InsertRequest(context, table_name, vectors, partition_tag));
}

Status
InsertRequest::OnExecute() {
    try {
        std::string hdr = "InsertRequest(table=" + table_name_ + ", n=" + std::to_string(record_size_) +
        int64_t vector_count = vectors_data_.vector_count_;
        std::string hdr = "InsertRequest(table=" + table_name_ + ", n=" + std::to_string(vector_count) +
                          ", partition_tag=" + partition_tag_ + ")";
        TimeRecorder rc(hdr);

@@ -58,13 +55,13 @@ InsertRequest::OnExecute() {
        if (!status.ok()) {
            return status;
        }
        if (data_list_.empty()) {
        if (vectors_data_.float_data_.empty() && vectors_data_.binary_data_.empty()) {
            return Status(SERVER_INVALID_ROWRECORD_ARRAY,
                          "The vector array is empty. Make sure you have entered vector records.");
        }

        if (!id_array_.empty()) {
            if (id_array_.size() != record_size_) {
        if (!vectors_data_.id_array_.empty()) {
            if (vectors_data_.id_array_.size() != vector_count) {
                return Status(SERVER_ILLEGAL_VECTOR_ID,
                              "The size of vector ID array must be equal to the size of the vector.");
            }
@@ -84,7 +81,7 @@ InsertRequest::OnExecute() {

        // step 3: check table flag
        // all user provide id, or all internal id
        bool user_provide_ids = !id_array_.empty();
        bool user_provide_ids = !vectors_data_.id_array_.empty();
        // user already provided id before, all insert action require user id
        if ((table_info.flag_ & engine::meta::FLAG_MASK_HAS_USERID) != 0 && !user_provide_ids) {
            return Status(SERVER_ILLEGAL_VECTOR_ID,
@@ -105,27 +102,49 @@ InsertRequest::OnExecute() {
            "/tmp/insert_" + std::to_string(this->insert_param_->row_record_array_size()) + ".profiling";
        ProfilerStart(fname.c_str());
#endif
        // step 4: some metric type doesn't support float vectors
        if (!vectors_data_.float_data_.empty()) {  // insert float vectors
            if (ValidationUtil::IsBinaryMetricType(table_info.metric_type_)) {
                return Status(SERVER_INVALID_ROWRECORD_ARRAY, "Table metric type doesn't support float vectors.");
            }

        // step 4: check prepared float data
        if (data_list_.size() % record_size_ != 0) {
            return Status(SERVER_INVALID_ROWRECORD_ARRAY, "The vector dimension must be equal to the table dimension.");
        }
            // check prepared float data
            if (vectors_data_.float_data_.size() % vector_count != 0) {
                return Status(SERVER_INVALID_ROWRECORD_ARRAY,
                              "The vector dimension must be equal to the table dimension.");
            }

        if (data_list_.size() / record_size_ != table_info.dimension_) {
            return Status(SERVER_INVALID_VECTOR_DIMENSION,
                          "The vector dimension must be equal to the table dimension.");
            if (vectors_data_.float_data_.size() / vector_count != table_info.dimension_) {
                return Status(SERVER_INVALID_VECTOR_DIMENSION,
                              "The vector dimension must be equal to the table dimension.");
            }
        } else if (!vectors_data_.binary_data_.empty()) {  // insert binary vectors
            if (!ValidationUtil::IsBinaryMetricType(table_info.metric_type_)) {
                return Status(SERVER_INVALID_ROWRECORD_ARRAY, "Table metric type doesn't support binary vectors.");
            }

            // check prepared binary data
            if (vectors_data_.binary_data_.size() % vector_count != 0) {
                return Status(SERVER_INVALID_ROWRECORD_ARRAY,
                              "The vector dimension must be equal to the table dimension.");
            }

            if (vectors_data_.binary_data_.size() * 8 / vector_count != table_info.dimension_) {
                return Status(SERVER_INVALID_VECTOR_DIMENSION,
                              "The vector dimension must be equal to the table dimension.");
            }
        }

        // step 5: insert vectors
        auto vec_count = static_cast<uint64_t>(record_size_);
        auto vec_count = static_cast<uint64_t>(vector_count);

        rc.RecordSection("prepare vectors data");
        status = DBWrapper::DB()->InsertVectors(table_name_, partition_tag_, vec_count, data_list_.data(), id_array_);
        status = DBWrapper::DB()->InsertVectors(table_name_, partition_tag_, vectors_data_);
        if (!status.ok()) {
            return status;
        }

        auto ids_size = id_array_.size();
        auto ids_size = vectors_data_.id_array_.size();
        if (ids_size != vec_count) {
            std::string msg =
                "Add " + std::to_string(vec_count) + " vectors but only return " + std::to_string(ids_size) + " id";
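The binary branch above checks dimensions in bits rather than floats: each record carries dimension / 8 bytes, so the byte count of the whole batch times 8, divided by the record count, must equal the table dimension. An isolated illustration of that check (standalone helper, ours):

#include <cstdint>

static bool binary_batch_matches_dimension(uint64_t total_bytes,
                                           uint64_t vector_count,
                                           uint64_t table_dimension) {
    return vector_count != 0 && (total_bytes * 8) / vector_count == table_dimension;
}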
@@ -29,23 +29,20 @@ namespace server {
class InsertRequest : public BaseRequest {
 public:
    static BaseRequestPtr
    Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
           std::vector<float>& data_list, const std::string& partition_tag, std::vector<int64_t>& id_array);
    Create(const std::shared_ptr<Context>& context, const std::string& table_name, engine::VectorsData& vectors,
           const std::string& partition_tag);

 protected:
    InsertRequest(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
                  std::vector<float>& data_list, const std::string& partition_tag, std::vector<int64_t>& id_array);
    InsertRequest(const std::shared_ptr<Context>& context, const std::string& table_name, engine::VectorsData& vectors,
                  const std::string& partition_tag);

    Status
    OnExecute() override;

 private:
    const std::string table_name_;
    int64_t record_size_;
    const std::vector<float>& data_list_;
    engine::VectorsData& vectors_data_;
    const std::string partition_tag_;

    std::vector<int64_t>& id_array_;
};

}  // namespace server
@@ -23,20 +23,16 @@

#include <memory>

#

namespace milvus {
namespace server {

SearchRequest::SearchRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
                             int64_t record_size, const std::vector<float>& data_list,
                             const std::vector<Range>& range_list, int64_t topk, int64_t nprobe,
                             const std::vector<std::string>& partition_list,
                             const engine::VectorsData& vectors, const std::vector<Range>& range_list, int64_t topk,
                             int64_t nprobe, const std::vector<std::string>& partition_list,
                             const std::vector<std::string>& file_id_list, TopKQueryResult& result)
    : BaseRequest(context, DQL_REQUEST_GROUP),
      table_name_(table_name),
      record_size_(record_size),
      data_list_(data_list),
      vectors_data_(vectors),
      range_list_(range_list),
      topk_(topk),
      nprobe_(nprobe),
@@ -46,20 +42,21 @@ SearchRequest::SearchRequest(const std::shared_ptr<Context>& context, const std:
}

BaseRequestPtr
SearchRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
                      const std::vector<float>& data_list, const std::vector<Range>& range_list, int64_t topk,
SearchRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name,
                      const engine::VectorsData& vectors, const std::vector<Range>& range_list, int64_t topk,
                      int64_t nprobe, const std::vector<std::string>& partition_list,
                      const std::vector<std::string>& file_id_list, TopKQueryResult& result) {
    return std::shared_ptr<BaseRequest>(new SearchRequest(context, table_name, record_size, data_list, range_list, topk,
                                                          nprobe, partition_list, file_id_list, result));
    return std::shared_ptr<BaseRequest>(new SearchRequest(context, table_name, vectors, range_list, topk, nprobe,
                                                          partition_list, file_id_list, result));
}

Status
SearchRequest::OnExecute() {
    try {
        uint64_t vector_count = vectors_data_.vector_count_;
        auto pre_query_ctx = context_->Child("Pre query");

        std::string hdr = "SearchRequest(table=" + table_name_ + ", nq=" + std::to_string(record_size_) +
        std::string hdr = "SearchRequest(table=" + table_name_ + ", nq=" + std::to_string(vector_count) +
                          ", k=" + std::to_string(topk_) + ", nprob=" + std::to_string(nprobe_) + ")";

        TimeRecorder rc(hdr);
@@ -93,7 +90,7 @@ SearchRequest::OnExecute() {
            return status;
        }

        if (data_list_.empty()) {
        if (vectors_data_.float_data_.empty() && vectors_data_.binary_data_.empty()) {
            return Status(SERVER_INVALID_ROWRECORD_ARRAY,
                          "The vector array is empty. Make sure you have entered vector records.");
        }
@@ -107,14 +104,28 @@ SearchRequest::OnExecute() {

        rc.RecordSection("check validation");

        // step 5: check prepared float data
        if (data_list_.size() % record_size_ != 0) {
            return Status(SERVER_INVALID_ROWRECORD_ARRAY, "The vector dimension must be equal to the table dimension.");
        }
        if (ValidationUtil::IsBinaryMetricType(table_info.metric_type_)) {
            // check prepared binary data
            if (vectors_data_.binary_data_.size() % vector_count != 0) {
                return Status(SERVER_INVALID_ROWRECORD_ARRAY,
                              "The vector dimension must be equal to the table dimension.");
            }

        if (data_list_.size() / record_size_ != table_info.dimension_) {
            return Status(SERVER_INVALID_VECTOR_DIMENSION,
                          "The vector dimension must be equal to the table dimension.");
            if (vectors_data_.binary_data_.size() * 8 / vector_count != table_info.dimension_) {
                return Status(SERVER_INVALID_VECTOR_DIMENSION,
                              "The vector dimension must be equal to the table dimension.");
            }
        } else {
            // check prepared float data
            if (vectors_data_.float_data_.size() % vector_count != 0) {
                return Status(SERVER_INVALID_ROWRECORD_ARRAY,
                              "The vector dimension must be equal to the table dimension.");
            }

            if (vectors_data_.float_data_.size() / vector_count != table_info.dimension_) {
                return Status(SERVER_INVALID_VECTOR_DIMENSION,
                              "The vector dimension must be equal to the table dimension.");
            }
        }

        rc.RecordSection("prepare vector data");
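The branch above validates a query batch in bits for binary metrics and in floats otherwise. A compact restatement of that rule (helper is ours, mirroring the checks in the patch):

static bool query_dims_ok(bool binary_metric, size_t payload_size,
                          uint64_t vector_count, int64_t dimension) {
    if (vector_count == 0 || payload_size % vector_count != 0) return false;
    size_t per_record = payload_size / vector_count;
    return binary_metric ? per_record * 8 == static_cast<size_t>(dimension)
                         : per_record == static_cast<size_t>(dimension);
}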
@@ -122,7 +133,6 @@ SearchRequest::OnExecute() {
        // step 6: search vectors
        engine::ResultIds result_ids;
        engine::ResultDistances result_distances;
        auto record_count = static_cast<uint64_t>(record_size_);

#ifdef MILVUS_ENABLE_PROFILING
        std::string fname =
@@ -138,11 +148,11 @@ SearchRequest::OnExecute() {
                return status;
            }

            status = DBWrapper::DB()->Query(context_, table_name_, partition_list_, (size_t)topk_, record_count,
                                            nprobe_, data_list_.data(), dates, result_ids, result_distances);
            status = DBWrapper::DB()->Query(context_, table_name_, partition_list_, (size_t)topk_, nprobe_,
                                            vectors_data_, dates, result_ids, result_distances);
        } else {
            status = DBWrapper::DB()->QueryByFileID(context_, table_name_, file_id_list_, (size_t)topk_, record_count,
                                                    nprobe_, data_list_.data(), dates, result_ids, result_distances);
            status = DBWrapper::DB()->QueryByFileID(context_, table_name_, file_id_list_, (size_t)topk_, nprobe_,
                                                    vectors_data_, dates, result_ids, result_distances);
        }

#ifdef MILVUS_ENABLE_PROFILING
@@ -161,7 +171,7 @@ SearchRequest::OnExecute() {
        auto post_query_ctx = context_->Child("Constructing result");

        // step 7: construct result array
        result_.row_num_ = record_count;
        result_.row_num_ = vector_count;
        result_.distance_list_ = result_distances;
        result_.id_list_ = result_ids;
@@ -29,14 +29,14 @@ namespace server {
class SearchRequest : public BaseRequest {
 public:
    static BaseRequestPtr
    Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
           const std::vector<float>& data_list, const std::vector<Range>& range_list, int64_t topk, int64_t nprobe,
    Create(const std::shared_ptr<Context>& context, const std::string& table_name, const engine::VectorsData& vectors,
           const std::vector<Range>& range_list, int64_t topk, int64_t nprobe,
           const std::vector<std::string>& partition_list, const std::vector<std::string>& file_id_list,
           TopKQueryResult& result);

 protected:
    SearchRequest(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
                  const std::vector<float>& data_list, const std::vector<Range>& range_list, int64_t topk,
    SearchRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
                  const engine::VectorsData& vectors, const std::vector<Range>& range_list, int64_t topk,
                  int64_t nprobe, const std::vector<std::string>& partition_list,
                  const std::vector<std::string>& file_id_list, TopKQueryResult& result);

@@ -45,8 +45,7 @@ class SearchRequest : public BaseRequest {

 private:
    const std::string table_name_;
    int64_t record_size_;
    const std::vector<float>& data_list_;
    const engine::VectorsData& vectors_data_;
    const std::vector<Range> range_list_;
    int64_t topk_;
    int64_t nprobe_;
@@ -72,6 +72,49 @@ ErrorMap(ErrorCode code) {
}
}

namespace {
void
CopyRowRecords(const google::protobuf::RepeatedPtrField<::milvus::grpc::RowRecord>& grpc_records,
const google::protobuf::RepeatedField<google::protobuf::int64>& grpc_id_array,
engine::VectorsData& vectors) {
// step 1: copy vector data
int64_t float_data_size = 0, binary_data_size = 0;
for (auto& record : grpc_records) {
float_data_size += record.float_data_size();
binary_data_size += record.binary_data().size();
}

std::vector<float> float_array(float_data_size, 0.0f);
std::vector<uint8_t> binary_array(binary_data_size, 0);
int64_t float_offset = 0, binary_offset = 0;
if (float_data_size > 0) {
for (auto& record : grpc_records) {
memcpy(&float_array[float_offset], record.float_data().data(), record.float_data_size() * sizeof(float));
float_offset += record.float_data_size();
}
} else if (binary_data_size > 0) {
for (auto& record : grpc_records) {
memcpy(&binary_array[binary_offset], record.binary_data().data(), record.binary_data().size());
binary_offset += record.binary_data().size();
}
}

// step 2: copy id array
std::vector<int64_t> id_array;
if (grpc_id_array.size() > 0) {
id_array.resize(grpc_id_array.size());
memcpy(id_array.data(), grpc_id_array.data(), grpc_id_array.size() * sizeof(int64_t));
}

// step 3: construct vectors
vectors.vector_count_ = grpc_records.size();
vectors.float_data_.swap(float_array);
vectors.binary_data_.swap(binary_array);
vectors.id_array_.swap(id_array);
}

} // namespace
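
The helper above flattens every gRPC RowRecord into one contiguous buffer inside engine::VectorsData, filling float_data_ or binary_data_ depending on which payload the records carry, so only one of the two ends up populated. The standalone sketch below mirrors that flattening logic with plain structs instead of the generated protobuf and engine types; RowLike, FlatVectors and their field names are illustrative assumptions, not part of this patch.

#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-ins for ::milvus::grpc::RowRecord and engine::VectorsData.
struct RowLike {
    std::vector<float> float_data;     // one row of float vector data
    std::vector<uint8_t> binary_data;  // or one row of packed binary data
};

struct FlatVectors {
    uint64_t vector_count = 0;
    std::vector<float> float_data;
    std::vector<uint8_t> binary_data;
};

// Flatten all rows into a single contiguous buffer, as CopyRowRecords does for protobuf records.
static FlatVectors Flatten(const std::vector<RowLike>& rows) {
    FlatVectors out;
    out.vector_count = rows.size();
    size_t float_total = 0, binary_total = 0;
    for (const auto& r : rows) {
        float_total += r.float_data.size();
        binary_total += r.binary_data.size();
    }
    if (float_total > 0) {
        out.float_data.resize(float_total);
        size_t offset = 0;
        for (const auto& r : rows) {
            if (r.float_data.empty()) continue;
            memcpy(out.float_data.data() + offset, r.float_data.data(), r.float_data.size() * sizeof(float));
            offset += r.float_data.size();
        }
    } else if (binary_total > 0) {
        out.binary_data.resize(binary_total);
        size_t offset = 0;
        for (const auto& r : rows) {
            if (r.binary_data.empty()) continue;
            memcpy(out.binary_data.data() + offset, r.binary_data.data(), r.binary_data.size());
            offset += r.binary_data.size();
        }
    }
    return out;
}

As in the original helper, either the float buffer or the binary buffer is filled, never both, which matches a table accepting only one vector type.
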
GrpcRequestHandler::GrpcRequestHandler(const std::shared_ptr<opentracing::Tracer>& tracer)
|
||||
: tracer_(tracer), random_num_generator_() {
|
||||
std::random_device random_device;
|
||||
|
@ -206,30 +249,18 @@ GrpcRequestHandler::Insert(::grpc::ServerContext* context, const ::milvus::grpc:
|
|||
::milvus::grpc::VectorIds* response) {
|
||||
CHECK_NULLPTR_RETURN(request);
|
||||
|
||||
int64_t record_data_size = 0;
|
||||
for (auto& record : request->row_record_array()) {
|
||||
record_data_size += record.vector_data_size();
|
||||
}
|
||||
|
||||
std::vector<float> record_array(record_data_size, 0.0f);
|
||||
int64_t offset = 0;
|
||||
for (auto& record : request->row_record_array()) {
|
||||
memcpy(&record_array[offset], record.vector_data().data(), record.vector_data_size() * sizeof(float));
|
||||
offset += record.vector_data_size();
|
||||
}
|
||||
|
||||
std::vector<int64_t> id_array;
|
||||
if (request->row_id_array_size() > 0) {
|
||||
id_array.resize(request->row_id_array_size());
|
||||
memcpy(id_array.data(), request->row_id_array().data(), request->row_id_array_size() * sizeof(int64_t));
|
||||
}
|
||||
// step 1: copy vector data
|
||||
engine::VectorsData vectors;
|
||||
CopyRowRecords(request->row_record_array(), request->row_id_array(), vectors);
|
||||
|
||||
// step 2: insert vectors
|
||||
Status status =
|
||||
request_handler_.Insert(context_map_[context], request->table_name(), request->row_record_array_size(),
|
||||
record_array, request->partition_tag(), id_array);
|
||||
request_handler_.Insert(context_map_[context], request->table_name(), vectors, request->partition_tag());
|
||||
|
||||
response->mutable_vector_id_array()->Resize(static_cast<int>(id_array.size()), 0);
|
||||
memcpy(response->mutable_vector_id_array()->mutable_data(), id_array.data(), id_array.size() * sizeof(int64_t));
|
||||
// step 3: return id array
|
||||
response->mutable_vector_id_array()->Resize(static_cast<int>(vectors.id_array_.size()), 0);
|
||||
memcpy(response->mutable_vector_id_array()->mutable_data(), vectors.id_array_.data(),
|
||||
vectors.id_array_.size() * sizeof(int64_t));
|
||||
|
||||
SET_RESPONSE(response->mutable_status(), status, context);
|
||||
return ::grpc::Status::OK;
|
||||
|
@ -240,36 +271,29 @@ GrpcRequestHandler::Search(::grpc::ServerContext* context, const ::milvus::grpc:
|
|||
::milvus::grpc::TopKQueryResult* response) {
|
||||
CHECK_NULLPTR_RETURN(request);
|
||||
|
||||
int64_t record_data_size = 0;
|
||||
for (auto& record : request->query_record_array()) {
|
||||
record_data_size += record.vector_data_size();
|
||||
}
|
||||
|
||||
std::vector<float> record_array(record_data_size);
|
||||
int64_t offset = 0;
|
||||
for (auto& record : request->query_record_array()) {
|
||||
memcpy(&record_array[offset], record.vector_data().data(), record.vector_data_size() * sizeof(float));
|
||||
offset += record.vector_data_size();
|
||||
}
|
||||
// step 1: copy vector data
|
||||
engine::VectorsData vectors;
|
||||
CopyRowRecords(request->query_record_array(), google::protobuf::RepeatedField<google::protobuf::int64>(), vectors);
|
||||
|
||||
// deprecated
|
||||
std::vector<Range> ranges;
|
||||
for (auto& range : request->query_range_array()) {
|
||||
ranges.emplace_back(range.start_value(), range.end_value());
|
||||
}
|
||||
|
||||
// step 2: partition tags
|
||||
std::vector<std::string> partitions;
|
||||
for (auto& partition : request->partition_tag_array()) {
|
||||
partitions.emplace_back(partition);
|
||||
}
|
||||
|
||||
// step 3: search vectors
|
||||
std::vector<std::string> file_ids;
|
||||
TopKQueryResult result;
|
||||
Status status = request_handler_.Search(context_map_[context], request->table_name(), vectors, ranges,
|
||||
request->topk(), request->nprobe(), partitions, file_ids, result);
|
||||
|
||||
Status status =
|
||||
request_handler_.Search(context_map_[context], request->table_name(), request->query_record_array_size(),
|
||||
record_array, ranges, request->topk(), request->nprobe(), partitions, file_ids, result);
|
||||
|
||||
// construct result
|
||||
// step 4: construct and return result
|
||||
response->set_row_num(result.row_num_);
|
||||
|
||||
response->mutable_ids()->Resize(static_cast<int>(result.id_list_.size()), 0);
|
||||
|
@ -289,42 +313,38 @@ GrpcRequestHandler::SearchInFiles(::grpc::ServerContext* context, const ::milvus
|
|||
::milvus::grpc::TopKQueryResult* response) {
|
||||
CHECK_NULLPTR_RETURN(request);
|
||||
|
||||
std::vector<std::string> file_ids;
|
||||
for (auto& file_id : request->file_id_array()) {
|
||||
file_ids.emplace_back(file_id);
|
||||
}
|
||||
|
||||
auto* search_request = &request->search_param();
|
||||
|
||||
int64_t record_data_size = 0;
|
||||
for (auto& record : search_request->query_record_array()) {
|
||||
record_data_size += record.vector_data_size();
|
||||
}
|
||||
|
||||
std::vector<float> record_array(record_data_size);
|
||||
int64_t offset = 0;
|
||||
for (auto& record : search_request->query_record_array()) {
|
||||
memcpy(&record_array[offset], record.vector_data().data(), record.vector_data_size() * sizeof(float));
|
||||
offset += record.vector_data_size();
|
||||
}
|
||||
// step 1: copy vector data
|
||||
engine::VectorsData vectors;
|
||||
CopyRowRecords(search_request->query_record_array(), google::protobuf::RepeatedField<google::protobuf::int64>(),
|
||||
vectors);
|
||||
|
||||
// deprecated
|
||||
std::vector<Range> ranges;
|
||||
for (auto& range : search_request->query_range_array()) {
|
||||
ranges.emplace_back(range.start_value(), range.end_value());
|
||||
}
|
||||
|
||||
// step 2: copy file id array
|
||||
std::vector<std::string> file_ids;
|
||||
for (auto& file_id : request->file_id_array()) {
|
||||
file_ids.emplace_back(file_id);
|
||||
}
|
||||
|
||||
// step 3: partition tags
|
||||
std::vector<std::string> partitions;
|
||||
for (auto& partition : search_request->partition_tag_array()) {
|
||||
partitions.emplace_back(partition);
|
||||
}
|
||||
|
||||
// step 4: search vectors
|
||||
TopKQueryResult result;
|
||||
Status status =
|
||||
request_handler_.Search(context_map_[context], search_request->table_name(), vectors, ranges,
|
||||
search_request->topk(), search_request->nprobe(), partitions, file_ids, result);
|
||||
|
||||
Status status = request_handler_.Search(
|
||||
context_map_[context], search_request->table_name(), search_request->query_record_array_size(), record_array,
|
||||
ranges, search_request->topk(), search_request->nprobe(), partitions, file_ids, result);
|
||||
|
||||
// construct result
|
||||
// step 5: construct and return result
|
||||
response->set_row_num(result.row_num_);
|
||||
|
||||
response->mutable_ids()->Resize(static_cast<int>(result.id_list_.size()), 0);
|
||||
|
|
|
@ -79,6 +79,78 @@ WebErrorMap(ErrorCode code) {
|
|||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
Status
|
||||
CopyRowRecords(const InsertRequestDto::ObjectWrapper& param, engine::VectorsData& vectors) {
|
||||
vectors.float_data_.clear();
|
||||
vectors.binary_data_.clear();
|
||||
vectors.id_array_.clear();
|
||||
vectors.vector_count_ = param->records->count();
|
||||
|
||||
// step 1: copy vector data
|
||||
if (nullptr == param->records.get()) {
|
||||
return Status(SERVER_INVALID_ROWRECORD_ARRAY, "");
|
||||
}
|
||||
|
||||
size_t tal_size = 0;
|
||||
for (int64_t i = 0; i < param->records->count(); i++) {
|
||||
tal_size += param->records->get(i)->count();
|
||||
}
|
||||
|
||||
std::vector<float>& datas = vectors.float_data_;
|
||||
datas.resize(tal_size);
|
||||
size_t index_offset = 0;
|
||||
param->records->forEach([&datas, &index_offset](const OList<OFloat32>::ObjectWrapper& row_item) {
|
||||
row_item->forEach([&datas, &index_offset](const OFloat32& item) {
|
||||
datas[index_offset] = item->getValue();
|
||||
index_offset++;
|
||||
});
|
||||
});
|
||||
|
||||
// step 2: copy id array
|
||||
if (nullptr == param->ids.get()) {
|
||||
return Status(SERVER_ILLEGAL_VECTOR_ID, "");
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < param->ids->count(); i++) {
|
||||
vectors.id_array_.emplace_back(param->ids->get(i)->getValue());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
CopyRowRecords(const SearchRequestDto::ObjectWrapper& param, engine::VectorsData& vectors) {
|
||||
vectors.float_data_.clear();
|
||||
vectors.binary_data_.clear();
|
||||
vectors.id_array_.clear();
|
||||
vectors.vector_count_ = param->records->count();
|
||||
|
||||
// step 1: copy vector data
|
||||
if (nullptr == param->records.get()) {
|
||||
return Status(SERVER_INVALID_ROWRECORD_ARRAY, "");
|
||||
}
|
||||
|
||||
size_t tal_size = 0;
|
||||
for (int64_t i = 0; i < param->records->count(); i++) {
|
||||
tal_size += param->records->get(i)->count();
|
||||
}
|
||||
|
||||
std::vector<float>& datas = vectors.float_data_;
|
||||
datas.resize(tal_size);
|
||||
size_t index_offset = 0;
|
||||
param->records->forEach([&datas, &index_offset](const OList<OFloat32>::ObjectWrapper& row_item) {
|
||||
row_item->forEach([&datas, &index_offset](const OFloat32& item) {
|
||||
datas[index_offset] = item->getValue();
|
||||
index_offset++;
|
||||
});
|
||||
});
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
///////////////////////// WebRequestHandler methods ///////////////////////////////////////
|
||||
|
||||
Status
|
||||
|
@ -567,37 +639,17 @@ WebRequestHandler::DropPartition(const OString& table_name, const OString& tag)
|
|||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::Insert(const OString& table_name, const InsertRequestDto::ObjectWrapper& param,
|
||||
VectorIdsDto::ObjectWrapper& ids_dto) {
|
||||
std::vector<int64_t> ids;
|
||||
if (nullptr != param->ids.get() && param->ids->count() > 0) {
|
||||
for (int64_t i = 0; i < param->ids->count(); i++) {
|
||||
ids.emplace_back(param->ids->get(i)->getValue());
|
||||
}
|
||||
}
|
||||
|
||||
if (nullptr == param->records.get()) {
|
||||
engine::VectorsData vectors;
|
||||
auto status = CopyRowRecords(param, vectors);
|
||||
if (status.code() == SERVER_INVALID_ROWRECORD_ARRAY) {
|
||||
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill vectors")
|
||||
}
|
||||
|
||||
size_t tal_size = 0;
|
||||
for (int64_t i = 0; i < param->records->count(); i++) {
|
||||
tal_size += param->records->get(i)->count();
|
||||
}
|
||||
|
||||
std::vector<float> datas(tal_size);
|
||||
size_t index_offset = 0;
|
||||
param->records->forEach([&datas, &index_offset](const OList<OFloat32>::ObjectWrapper& row_item) {
|
||||
row_item->forEach([&datas, &index_offset](const OFloat32& item) {
|
||||
datas[index_offset] = item->getValue();
|
||||
index_offset++;
|
||||
});
|
||||
});
|
||||
|
||||
auto status = request_handler_.Insert(context_ptr_, table_name->std_str(), param->records->count(), datas,
|
||||
param->tag->std_str(), ids);
|
||||
status = request_handler_.Insert(context_ptr_, table_name->std_str(), vectors, param->tag->std_str());
|
||||
|
||||
if (status.ok()) {
|
||||
ids_dto->ids = ids_dto->ids->createShared();
|
||||
for (auto& id : ids) {
|
||||
for (auto& id : vectors.id_array_) {
|
||||
ids_dto->ids->pushBack(std::to_string(id).c_str());
|
||||
}
|
||||
}
|
||||
|
@ -634,25 +686,18 @@ WebRequestHandler::Search(const OString& table_name, const SearchRequestDto::Obj
|
|||
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill query vectors")
|
||||
}
|
||||
|
||||
size_t tal_size = 0;
|
||||
search_request->records->forEach(
|
||||
[&tal_size](const OList<OFloat32>::ObjectWrapper& item) { tal_size += item->count(); });
|
||||
|
||||
std::vector<float> datas(tal_size);
|
||||
size_t index_offset = 0;
|
||||
search_request->records->forEach([&datas, &index_offset](const OList<OFloat32>::ObjectWrapper& elem) {
|
||||
elem->forEach([&datas, &index_offset](const OFloat32& item) {
|
||||
datas[index_offset] = item->getValue();
|
||||
index_offset++;
|
||||
});
|
||||
});
|
||||
engine::VectorsData vectors;
|
||||
auto status = CopyRowRecords(search_request, vectors);
|
||||
if (status.code() == SERVER_INVALID_ROWRECORD_ARRAY) {
|
||||
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill vectors")
|
||||
}
|
||||
|
||||
std::vector<Range> range_list;
|
||||
|
||||
TopKQueryResult result;
|
||||
auto context_ptr = GenContextPtr("Web Handler");
|
||||
auto status = request_handler_.Search(context_ptr, table_name->std_str(), search_request->records->count(), datas,
|
||||
range_list, topk_t, nprobe_t, tag_list, file_id_list, result);
|
||||
status = request_handler_.Search(context_ptr, table_name->std_str(), vectors, range_list, topk_t, nprobe_t,
|
||||
tag_list, file_id_list, result);
|
||||
if (!status.ok()) {
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
|
|
|
@ -21,9 +21,13 @@
|
|||
#include "utils/StringHelpFunctions.h"
|
||||
|
||||
#include <arpa/inet.h>
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <regex>
|
||||
|
@ -33,7 +37,7 @@ namespace milvus {
|
|||
namespace server {
|
||||
|
||||
constexpr size_t TABLE_NAME_SIZE_LIMIT = 255;
|
||||
constexpr int64_t TABLE_DIMENSION_LIMIT = 16384;
|
||||
constexpr int64_t TABLE_DIMENSION_LIMIT = 32768;
|
||||
constexpr int32_t INDEX_FILE_SIZE_LIMIT = 4096; // index trigger size max = 4096 MB
|
||||
|
||||
Status
|
||||
|
@ -78,7 +82,8 @@ Status
|
|||
ValidationUtil::ValidateTableDimension(int64_t dimension) {
|
||||
if (dimension <= 0 || dimension > TABLE_DIMENSION_LIMIT) {
|
||||
std::string msg = "Invalid table dimension: " + std::to_string(dimension) + ". " +
|
||||
"The table dimension must be within the range of 1 ~ 16384.";
|
||||
"The table dimension must be within the range of 1 ~ " +
|
||||
std::to_string(TABLE_DIMENSION_LIMIT) + ".";
|
||||
SERVER_LOG_ERROR << msg;
|
||||
return Status(SERVER_INVALID_VECTOR_DIMENSION, msg);
|
||||
} else {
|
||||
|
@ -108,6 +113,12 @@ ValidationUtil::ValidateTableIndexType(int32_t index_type) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool
|
||||
ValidationUtil::IsBinaryIndexType(int32_t index_type) {
|
||||
return (index_type == static_cast<int32_t>(engine::EngineType::FAISS_BIN_IDMAP)) ||
|
||||
(index_type == static_cast<int32_t>(engine::EngineType::FAISS_BIN_IVFFLAT));
|
||||
}
|
||||
|
||||
Status
|
||||
ValidationUtil::ValidateTableIndexNlist(int32_t nlist) {
|
||||
if (nlist <= 0) {
|
||||
|
@@ -135,16 +146,22 @@ ValidationUtil::ValidateTableIndexFileSize(int64_t index_file_size) {

Status
ValidationUtil::ValidateTableIndexMetricType(int32_t metric_type) {
if (metric_type != static_cast<int32_t>(engine::MetricType::L2) &&
metric_type != static_cast<int32_t>(engine::MetricType::IP)) {
if (metric_type <= 0 || metric_type > static_cast<int32_t>(engine::MetricType::MAX_VALUE)) {
std::string msg = "Invalid index metric type: " + std::to_string(metric_type) + ". " +
"Make sure the metric type is either MetricType.L2 or MetricType.IP.";
"Make sure the metric type is in MetricType list.";
SERVER_LOG_ERROR << msg;
return Status(SERVER_INVALID_INDEX_METRIC_TYPE, msg);
}
return Status::OK();
}

bool
ValidationUtil::IsBinaryMetricType(int32_t metric_type) {
return (metric_type == static_cast<int32_t>(engine::MetricType::HAMMING)) ||
(metric_type == static_cast<int32_t>(engine::MetricType::JACCARD)) ||
(metric_type == static_cast<int32_t>(engine::MetricType::TANIMOTO));
}
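
IsBinaryMetricType restricts HAMMING, JACCARD and TANIMOTO to binary-vector tables. As a quick illustration of how these metrics relate on packed bit vectors, here is a minimal self-contained sketch; it only computes the raw coefficients, assumes 8 dimensions packed per byte, and is not part of the patch (the distance an index actually reports may be a further transform of these values, e.g. 1 - similarity).

#include <bitset>
#include <cstddef>
#include <cstdint>

// Raw binary metrics for two vectors packed 8 bits per byte.
struct BinaryMetrics {
    size_t hamming;   // number of differing bits
    double jaccard;   // |x AND y| / |x OR y|
    double tanimoto;  // for bit vectors this coefficient equals the Jaccard coefficient
};

static BinaryMetrics ComputeBinaryMetrics(const uint8_t* x, const uint8_t* y, size_t n_bytes) {
    size_t common = 0, either = 0, differing = 0;
    for (size_t i = 0; i < n_bytes; ++i) {
        common += std::bitset<8>(x[i] & y[i]).count();
        either += std::bitset<8>(x[i] | y[i]).count();
        differing += std::bitset<8>(x[i] ^ y[i]).count();
    }
    double jaccard = (either == 0) ? 1.0 : static_cast<double>(common) / static_cast<double>(either);
    return BinaryMetrics{differing, jaccard, jaccard};
}
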
Status
|
||||
ValidationUtil::ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchema& table_schema) {
|
||||
if (top_k <= 0 || top_k > 2048) {
|
||||
|
@ -219,6 +236,7 @@ ValidationUtil::ValidateGpuIndex(int32_t gpu_index) {
|
|||
}
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
|
||||
Status
|
||||
ValidationUtil::GetGpuMemory(int32_t gpu_index, size_t& memory) {
|
||||
cudaDeviceProp deviceProp;
|
||||
|
@ -233,6 +251,7 @@ ValidationUtil::GetGpuMemory(int32_t gpu_index, size_t& memory) {
|
|||
memory = deviceProp.totalGlobalMem;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
Status
|
||||
|
|
|
@ -40,6 +40,9 @@ class ValidationUtil {
|
|||
static Status
|
||||
ValidateTableIndexType(int32_t index_type);
|
||||
|
||||
static bool
|
||||
IsBinaryIndexType(int32_t index_type);
|
||||
|
||||
static Status
|
||||
ValidateTableIndexNlist(int32_t nlist);
|
||||
|
||||
|
@ -49,6 +52,9 @@ class ValidationUtil {
|
|||
static Status
|
||||
ValidateTableIndexMetricType(int32_t metric_type);
|
||||
|
||||
static bool
|
||||
IsBinaryMetricType(int32_t metric_type);
|
||||
|
||||
static Status
|
||||
ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchema& table_schema);
|
||||
|
||||
|
|
|
@ -0,0 +1,187 @@
|
|||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include "wrapper/BinVecImpl.h"
|
||||
#include "WrapperException.h"
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"
|
||||
#include "utils/Log.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
|
||||
Status
|
||||
BinVecImpl::BuildAll(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
|
||||
const uint8_t* xt) {
|
||||
try {
|
||||
dim = cfg->d;
|
||||
|
||||
auto ret_ds = std::make_shared<knowhere::Dataset>();
|
||||
ret_ds->Set(knowhere::meta::ROWS, nb);
|
||||
ret_ds->Set(knowhere::meta::DIM, dim);
|
||||
ret_ds->Set(knowhere::meta::TENSOR, xb);
|
||||
ret_ds->Set(knowhere::meta::IDS, ids);
|
||||
|
||||
index_->Train(ret_ds, cfg);
|
||||
} catch (knowhere::KnowhereException& e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
|
||||
} catch (std::exception& e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return Status(KNOWHERE_ERROR, e.what());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
BinVecImpl::Search(const int64_t& nq, const uint8_t* xq, float* dist, int64_t* ids, const Config& cfg) {
|
||||
try {
|
||||
auto k = cfg->k;
|
||||
auto ret_ds = std::make_shared<knowhere::Dataset>();
|
||||
ret_ds->Set(knowhere::meta::ROWS, nq);
|
||||
ret_ds->Set(knowhere::meta::DIM, dim);
|
||||
ret_ds->Set(knowhere::meta::TENSOR, xq);
|
||||
|
||||
Config search_cfg = cfg;
|
||||
|
||||
auto res = index_->Search(ret_ds, search_cfg);
|
||||
//{
|
||||
// auto& ids = ids_array;
|
||||
// auto& dists = dis_array;
|
||||
// std::stringstream ss_id;
|
||||
// std::stringstream ss_dist;
|
||||
// for (auto i = 0; i < 10; i++) {
|
||||
// for (auto j = 0; j < k; ++j) {
|
||||
// ss_id << *(ids->data()->GetValues<int64_t>(1, i * k + j)) << " ";
|
||||
// ss_dist << *(dists->data()->GetValues<float>(1, i * k + j)) << " ";
|
||||
// }
|
||||
// ss_id << std::endl;
|
||||
// ss_dist << std::endl;
|
||||
// }
|
||||
// std::cout << "id\n" << ss_id.str() << std::endl;
|
||||
// std::cout << "dist\n" << ss_dist.str() << std::endl;
|
||||
//}
|
||||
|
||||
// auto p_ids = ids_array->data()->GetValues<int64_t>(1, 0);
|
||||
// auto p_dist = dis_array->data()->GetValues<float>(1, 0);
|
||||
|
||||
// TODO(linxj): avoid copy here.
|
||||
auto res_ids = res->Get<int64_t*>(knowhere::meta::IDS);
|
||||
auto res_dist = res->Get<float*>(knowhere::meta::DISTANCE);
|
||||
memcpy(ids, res_ids, sizeof(int64_t) * nq * k);
|
||||
memcpy(dist, res_dist, sizeof(float) * nq * k);
|
||||
free(res_ids);
|
||||
free(res_dist);
|
||||
} catch (knowhere::KnowhereException& e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
|
||||
} catch (std::exception& e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return Status(KNOWHERE_ERROR, e.what());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
BinVecImpl::Add(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg) {
|
||||
try {
|
||||
auto ret_ds = std::make_shared<knowhere::Dataset>();
|
||||
ret_ds->Set(knowhere::meta::ROWS, nb);
|
||||
ret_ds->Set(knowhere::meta::DIM, dim);
|
||||
ret_ds->Set(knowhere::meta::TENSOR, xb);
|
||||
ret_ds->Set(knowhere::meta::IDS, ids);
|
||||
|
||||
index_->Add(ret_ds, cfg);
|
||||
} catch (knowhere::KnowhereException& e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
|
||||
} catch (std::exception& e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return Status(KNOWHERE_ERROR, e.what());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
VecIndexPtr
|
||||
BinVecImpl::CopyToGpu(const int64_t& device_id, const Config& cfg) {
|
||||
char* errmsg = "Binary Index not support CopyToGpu";
|
||||
WRAPPER_LOG_ERROR << errmsg;
|
||||
throw WrapperException("errmsg");
|
||||
}
|
||||
|
||||
VecIndexPtr
|
||||
BinVecImpl::CopyToCpu(const Config& cfg) {
|
||||
char* errmsg = "Binary Index not support CopyToCpu";
|
||||
WRAPPER_LOG_ERROR << errmsg;
|
||||
throw WrapperException("errmsg");
|
||||
}
|
||||
|
||||
ErrorCode
|
||||
BinBFIndex::Build(const Config& cfg) {
|
||||
try {
|
||||
dim = cfg->d;
|
||||
std::static_pointer_cast<knowhere::BinaryIDMAP>(index_)->Train(cfg);
|
||||
} catch (knowhere::KnowhereException& e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return KNOWHERE_UNEXPECTED_ERROR;
|
||||
} catch (std::exception& e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return KNOWHERE_ERROR;
|
||||
}
|
||||
return KNOWHERE_SUCCESS;
|
||||
}
|
||||
|
||||
Status
|
||||
BinBFIndex::BuildAll(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
|
||||
const uint8_t* xt) {
|
||||
try {
|
||||
dim = cfg->d;
|
||||
auto ret_ds = std::make_shared<knowhere::Dataset>();
|
||||
ret_ds->Set(knowhere::meta::ROWS, nb);
|
||||
ret_ds->Set(knowhere::meta::DIM, dim);
|
||||
ret_ds->Set(knowhere::meta::TENSOR, xb);
|
||||
ret_ds->Set(knowhere::meta::IDS, ids);
|
||||
|
||||
std::static_pointer_cast<knowhere::BinaryIDMAP>(index_)->Train(cfg);
|
||||
index_->Add(ret_ds, cfg);
|
||||
} catch (knowhere::KnowhereException& e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
|
||||
} catch (std::exception& e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return Status(KNOWHERE_ERROR, e.what());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const uint8_t*
BinBFIndex::GetRawVectors() {
auto raw_index = std::dynamic_pointer_cast<knowhere::BinaryIDMAP>(index_);
if (raw_index) {
return raw_index->GetRawVectors();
}
return nullptr;
}

const int64_t*
BinBFIndex::GetRawIds() {
return std::static_pointer_cast<knowhere::BinaryIDMAP>(index_)->GetRawIds();
}

} // namespace engine
} // namespace milvus
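
Combined with the FAISS_BIN_IDMAP case added to GetVecIndexFactory later in this diff, the new wrapper can be driven much like the float path. The following is a rough usage sketch only: error handling is omitted, it assumes each binary vector occupies dim/8 bytes as is usual for FAISS binary indexes, and it assumes the included header makes knowhere::BinIDMAPCfg and knowhere::METRICTYPE visible.

#include <cstdint>
#include <memory>
#include <vector>

#include "wrapper/VecIndex.h"  // assumed to pull in knowhere::BinIDMAPCfg and METRICTYPE transitively

void BinaryBruteForceExample() {
    using namespace milvus::engine;

    const int64_t dim = 512;   // dimension in bits
    const int64_t nb = 1000;   // number of base vectors
    const int64_t nq = 10;     // number of query vectors
    const int64_t topk = 5;

    std::vector<uint8_t> base(nb * dim / 8);     // packed binary base vectors
    std::vector<uint8_t> queries(nq * dim / 8);  // packed binary queries
    std::vector<int64_t> ids(nb);
    for (int64_t i = 0; i < nb; ++i) {
        ids[i] = i;
    }

    auto cfg = std::make_shared<knowhere::BinIDMAPCfg>();
    cfg->d = dim;
    cfg->k = topk;
    cfg->metric_type = knowhere::METRICTYPE::HAMMING;

    VecIndexPtr index = GetVecIndexFactory(IndexType::FAISS_BIN_IDMAP, cfg);
    index->BuildAll(nb, base.data(), ids.data(), cfg);

    std::vector<int64_t> result_ids(nq * topk);
    std::vector<float> result_distances(nq * topk);
    index->Search(nq, queries.data(), result_distances.data(), result_ids.data(), cfg);
}
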
@ -0,0 +1,71 @@
|
|||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "VecImpl.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
|
||||
class BinVecImpl : public VecIndexImpl {
|
||||
public:
|
||||
explicit BinVecImpl(std::shared_ptr<knowhere::VectorIndex> index, const IndexType& type)
|
||||
: VecIndexImpl(std::move(index), type) {
|
||||
}
|
||||
|
||||
Status
|
||||
BuildAll(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
|
||||
const uint8_t* xt) override;
|
||||
Status
|
||||
Search(const int64_t& nq, const uint8_t* xq, float* dist, int64_t* ids, const Config& cfg) override;
|
||||
|
||||
Status
|
||||
Add(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg) override;
|
||||
|
||||
VecIndexPtr
|
||||
CopyToGpu(const int64_t& device_id, const Config& cfg) override;
|
||||
|
||||
VecIndexPtr
|
||||
CopyToCpu(const Config& cfg) override;
|
||||
};
|
||||
|
||||
class BinBFIndex : public BinVecImpl {
|
||||
public:
|
||||
explicit BinBFIndex(std::shared_ptr<knowhere::VectorIndex> index)
|
||||
: BinVecImpl(std::move(index), IndexType::FAISS_BIN_IDMAP) {
|
||||
}
|
||||
|
||||
ErrorCode
|
||||
Build(const Config& cfg);
|
||||
|
||||
Status
|
||||
BuildAll(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
|
||||
const uint8_t* xt) override;
|
||||
|
||||
const uint8_t*
|
||||
GetRawVectors();
|
||||
|
||||
const int64_t*
|
||||
GetRawIds();
|
||||
};
|
||||
|
||||
} // namespace engine
|
||||
} // namespace milvus
|
|
@ -37,9 +37,9 @@ namespace engine {
|
|||
#endif
|
||||
|
||||
void
|
||||
ConfAdapter::MatchBase(knowhere::Config conf) {
|
||||
ConfAdapter::MatchBase(knowhere::Config conf, knowhere::METRICTYPE default_metric) {
|
||||
if (conf->metric_type == knowhere::DEFAULT_TYPE)
|
||||
conf->metric_type = knowhere::METRICTYPE::L2;
|
||||
conf->metric_type = default_metric;
|
||||
}
|
||||
|
||||
knowhere::Config
|
||||
|
@ -63,7 +63,7 @@ ConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
|
|||
knowhere::Config
|
||||
IVFConfAdapter::Match(const TempMetaConf& metaconf) {
|
||||
auto conf = std::make_shared<knowhere::IVFCfg>();
|
||||
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist);
|
||||
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist, 16384);
|
||||
conf->d = metaconf.dim;
|
||||
conf->metric_type = metaconf.metric_type;
|
||||
conf->gpu_id = metaconf.gpu_id;
|
||||
|
@@ -74,13 +74,13 @@ IVFConfAdapter::Match(const TempMetaConf& metaconf) {
static constexpr float TYPICAL_COUNT = 1000000.0;

int64_t
IVFConfAdapter::MatchNlist(const int64_t& size, const int64_t& nlist) {
if (size <= TYPICAL_COUNT / 16384 + 1) {
IVFConfAdapter::MatchNlist(const int64_t& size, const int64_t& nlist, const int64_t& per_nlist) {
if (size <= TYPICAL_COUNT / per_nlist + 1) {
// handle less row count, avoid nlist set to 0
return 1;
} else if (int(size / TYPICAL_COUNT) * nlist <= 0) {
// calculate a proper nlist if nlist not specified or size less than TYPICAL_COUNT
return int(size / TYPICAL_COUNT * 16384);
return int(size / TYPICAL_COUNT * per_nlist);
}
return nlist;
}
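
MatchNlist now takes a per_nlist argument so the float IVF adapters can keep the old granularity of 16384 while the binary IVF adapter further below passes 2048. The sketch below mirrors the selection logic and works through a few sample values; the sample sizes are illustrative and not taken from the patch.

#include <cstdint>
#include <iostream>

// Mirrors the MatchNlist logic above, for illustration only.
static int64_t MatchNlistSketch(int64_t size, int64_t nlist, int64_t per_nlist) {
    const double kTypicalCount = 1000000.0;
    if (size <= kTypicalCount / per_nlist + 1) {
        return 1;  // too few rows: avoid nlist == 0
    } else if (static_cast<int64_t>(size / kTypicalCount) * nlist <= 0) {
        return static_cast<int64_t>(size / kTypicalCount * per_nlist);  // derive nlist from data size
    }
    return nlist;  // keep an explicitly configured nlist
}

int main() {
    std::cout << MatchNlistSketch(50, 0, 16384) << "\n";          // 1
    std::cout << MatchNlistSketch(500000, 0, 16384) << "\n";      // 8192
    std::cout << MatchNlistSketch(500000, 0, 2048) << "\n";       // 1024
    std::cout << MatchNlistSketch(2000000, 2048, 16384) << "\n";  // 2048
    return 0;
}
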
@ -112,7 +112,7 @@ IVFConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type)
|
|||
knowhere::Config
|
||||
IVFSQConfAdapter::Match(const TempMetaConf& metaconf) {
|
||||
auto conf = std::make_shared<knowhere::IVFSQCfg>();
|
||||
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist);
|
||||
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist, 16384);
|
||||
conf->d = metaconf.dim;
|
||||
conf->metric_type = metaconf.metric_type;
|
||||
conf->gpu_id = metaconf.gpu_id;
|
||||
|
@ -207,7 +207,7 @@ IVFPQConfAdapter::MatchNlist(const int64_t& size, const int64_t& nlist) {
|
|||
knowhere::Config
|
||||
NSGConfAdapter::Match(const TempMetaConf& metaconf) {
|
||||
auto conf = std::make_shared<knowhere::NSGCfg>();
|
||||
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist);
|
||||
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist, 16384);
|
||||
conf->d = metaconf.dim;
|
||||
conf->metric_type = metaconf.metric_type;
|
||||
conf->gpu_id = metaconf.gpu_id;
|
||||
|
@ -266,5 +266,26 @@ SPTAGBKTConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType&
|
|||
return conf;
|
||||
}
|
||||
|
||||
knowhere::Config
|
||||
BinIDMAPConfAdapter::Match(const TempMetaConf& metaconf) {
|
||||
auto conf = std::make_shared<knowhere::BinIDMAPCfg>();
|
||||
conf->d = metaconf.dim;
|
||||
conf->metric_type = metaconf.metric_type;
|
||||
conf->gpu_id = metaconf.gpu_id;
|
||||
conf->k = metaconf.k;
|
||||
MatchBase(conf, knowhere::METRICTYPE::HAMMING);
|
||||
return conf;
|
||||
}
|
||||
|
||||
knowhere::Config
|
||||
BinIVFConfAdapter::Match(const TempMetaConf& metaconf) {
|
||||
auto conf = std::make_shared<knowhere::IVFBinCfg>();
|
||||
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist, 2048);
|
||||
conf->d = metaconf.dim;
|
||||
conf->metric_type = metaconf.metric_type;
|
||||
conf->gpu_id = metaconf.gpu_id;
|
||||
MatchBase(conf, knowhere::METRICTYPE::HAMMING);
|
||||
return conf;
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace milvus
@ -48,7 +48,7 @@ class ConfAdapter {
|
|||
|
||||
protected:
|
||||
static void
|
||||
MatchBase(knowhere::Config conf);
|
||||
MatchBase(knowhere::Config conf, knowhere::METRICTYPE default_metric = knowhere::METRICTYPE::L2);
|
||||
};
|
||||
|
||||
using ConfAdapterPtr = std::shared_ptr<ConfAdapter>;
|
||||
|
@ -63,7 +63,7 @@ class IVFConfAdapter : public ConfAdapter {
|
|||
|
||||
protected:
|
||||
static int64_t
|
||||
MatchNlist(const int64_t& size, const int64_t& nlist);
|
||||
MatchNlist(const int64_t& size, const int64_t& nlist, const int64_t& per_nlist);
|
||||
};
|
||||
|
||||
class IVFSQConfAdapter : public IVFConfAdapter {
|
||||
|
@ -112,5 +112,17 @@ class SPTAGBKTConfAdapter : public ConfAdapter {
|
|||
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) override;
|
||||
};
|
||||
|
||||
class BinIDMAPConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
knowhere::Config
|
||||
Match(const TempMetaConf& metaconf) override;
|
||||
};
|
||||
|
||||
class BinIVFConfAdapter : public IVFConfAdapter {
|
||||
public:
|
||||
knowhere::Config
|
||||
Match(const TempMetaConf& metaconf) override;
|
||||
};
|
||||
|
||||
} // namespace engine
|
||||
} // namespace milvus
|
||||
|
|
|
@ -41,10 +41,12 @@ AdapterMgr::RegisterAdapter() {
|
|||
init_ = true;
|
||||
|
||||
REGISTER_CONF_ADAPTER(ConfAdapter, IndexType::FAISS_IDMAP, idmap);
|
||||
REGISTER_CONF_ADAPTER(BinIDMAPConfAdapter, IndexType::FAISS_BIN_IDMAP, idmap_bin);
|
||||
|
||||
REGISTER_CONF_ADAPTER(IVFConfAdapter, IndexType::FAISS_IVFFLAT_CPU, ivf_cpu);
|
||||
REGISTER_CONF_ADAPTER(IVFConfAdapter, IndexType::FAISS_IVFFLAT_GPU, ivf_gpu);
|
||||
REGISTER_CONF_ADAPTER(IVFConfAdapter, IndexType::FAISS_IVFFLAT_MIX, ivf_mix);
|
||||
REGISTER_CONF_ADAPTER(BinIVFConfAdapter, IndexType::FAISS_BIN_IVFLAT_CPU, ivf_bin_cpu);
|
||||
|
||||
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexType::FAISS_IVFSQ8_CPU, ivfsq8_cpu);
|
||||
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexType::FAISS_IVFSQ8_GPU, ivfsq8_gpu);
|
||||
|
|
|
@ -148,7 +148,7 @@ VecIndexImpl::Count() {
|
|||
}
|
||||
|
||||
IndexType
|
||||
VecIndexImpl::GetType() {
|
||||
VecIndexImpl::GetType() const {
|
||||
return type;
|
||||
}
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ class VecIndexImpl : public VecIndex {
|
|||
CopyToCpu(const Config& cfg) override;
|
||||
|
||||
IndexType
|
||||
GetType() override;
|
||||
GetType() const override;
|
||||
|
||||
int64_t
|
||||
Dimension() override;
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
#include "wrapper/VecIndex.h"
|
||||
#include "VecImpl.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"
|
||||
#include "knowhere/index/vector_index/IndexBinaryIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexIDMAP.h"
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/IndexIVFPQ.h"
|
||||
|
@ -32,6 +34,7 @@
|
|||
#include "utils/Exception.h"
|
||||
#include "utils/Log.h"
|
||||
#include "utils/TimeRecorder.h"
|
||||
#include "wrapper/BinVecImpl.h"
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
#include <cuda.h>
|
||||
|
@ -68,10 +71,18 @@ GetVecIndexFactory(const IndexType& type, const Config& cfg) {
|
|||
index = std::make_shared<knowhere::IDMAP>();
|
||||
return std::make_shared<BFIndex>(index);
|
||||
}
|
||||
case IndexType::FAISS_BIN_IDMAP: {
|
||||
index = std::make_shared<knowhere::BinaryIDMAP>();
|
||||
return std::make_shared<BinBFIndex>(index);
|
||||
}
|
||||
case IndexType::FAISS_IVFFLAT_CPU: {
|
||||
index = std::make_shared<knowhere::IVF>();
|
||||
break;
|
||||
}
|
||||
case IndexType::FAISS_BIN_IVFLAT_CPU: {
|
||||
index = std::make_shared<knowhere::BinaryIVF>();
|
||||
return std::make_shared<BinVecImpl>(index, type);
|
||||
}
|
||||
case IndexType::FAISS_IVFPQ_CPU: {
|
||||
index = std::make_shared<knowhere::IVFPQ>();
|
||||
break;
|
||||
|
|
|
@ -50,6 +50,8 @@ enum class IndexType {
|
|||
NSG_MIX,
|
||||
FAISS_IVFPQ_MIX,
|
||||
SPTAG_BKT_RNT_CPU,
|
||||
FAISS_BIN_IDMAP = 100,
|
||||
FAISS_BIN_IVFLAT_CPU = 101,
|
||||
};
|
||||
|
||||
class VecIndex;
|
||||
|
@ -62,12 +64,31 @@ class VecIndex : public cache::DataObj {
|
|||
BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt = 0,
|
||||
const float* xt = nullptr) = 0;
|
||||
|
||||
virtual Status
|
||||
BuildAll(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg, const int64_t& nt = 0,
|
||||
const uint8_t* xt = nullptr) {
|
||||
ENGINE_LOG_ERROR << "BuildAll with uint8_t not support";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
virtual Status
|
||||
Add(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg = Config()) = 0;
|
||||
|
||||
virtual Status
|
||||
Add(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg = Config()) {
|
||||
ENGINE_LOG_ERROR << "Add with uint8_t not support";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
virtual Status
|
||||
Search(const int64_t& nq, const float* xq, float* dist, int64_t* ids, const Config& cfg = Config()) = 0;
|
||||
|
||||
virtual Status
|
||||
Search(const int64_t& nq, const uint8_t* xq, float* dist, int64_t* ids, const Config& cfg = Config()) {
|
||||
ENGINE_LOG_ERROR << "Search with uint8_t not support";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
virtual VecIndexPtr
|
||||
CopyToGpu(const int64_t& device_id, const Config& cfg = Config()) = 0;
|
||||
|
||||
|
@ -82,7 +103,7 @@ class VecIndex : public cache::DataObj {
|
|||
GetDeviceId() = 0;
|
||||
|
||||
virtual IndexType
|
||||
GetType() = 0;
|
||||
GetType() const = 0;
|
||||
|
||||
virtual int64_t
|
||||
Dimension() = 0;
|
||||
|
|
|
@ -50,12 +50,13 @@ BuildTableSchema() {
|
|||
}
|
||||
|
||||
void
|
||||
BuildVectors(int64_t n, std::vector<float>& vectors) {
|
||||
vectors.clear();
|
||||
vectors.resize(n * TABLE_DIM);
|
||||
float* data = vectors.data();
|
||||
for (int i = 0; i < n; i++) {
|
||||
for (int j = 0; j < TABLE_DIM; j++) data[TABLE_DIM * i + j] = drand48();
|
||||
BuildVectors(uint64_t n, milvus::engine::VectorsData& vectors) {
|
||||
vectors.vector_count_ = n;
|
||||
vectors.float_data_.clear();
|
||||
vectors.float_data_.resize(n * TABLE_DIM);
|
||||
float* data = vectors.float_data_.data();
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
for (int64_t j = 0; j < TABLE_DIM; j++) data[TABLE_DIM * i + j] = drand48();
|
||||
data[TABLE_DIM * i] += i / 2000.;
|
||||
}
|
||||
}
|
||||
|
@ -161,15 +162,12 @@ TEST_F(DBTest, DB_TEST) {
|
|||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(table_info_get.dimension_, TABLE_DIM);
|
||||
|
||||
milvus::engine::IDNumbers vector_ids;
|
||||
milvus::engine::IDNumbers target_ids;
|
||||
|
||||
int64_t nb = 50;
|
||||
std::vector<float> xb;
|
||||
uint64_t nb = 50;
|
||||
milvus::engine::VectorsData xb;
|
||||
BuildVectors(nb, xb);
|
||||
|
||||
int64_t qb = 5;
|
||||
std::vector<float> qxb;
|
||||
uint64_t qb = 5;
|
||||
milvus::engine::VectorsData qxb;
|
||||
BuildVectors(qb, qxb);
|
||||
|
||||
std::thread search([&]() {
|
||||
|
@ -191,13 +189,12 @@ TEST_F(DBTest, DB_TEST) {
|
|||
START_TIMER;
|
||||
|
||||
std::vector<std::string> tags;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, qb, 10, qxb.data(), result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, qxb, result_ids, result_distances);
|
||||
ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
|
||||
STOP_TIMER(ss.str());
|
||||
|
||||
ASSERT_TRUE(stat.ok());
|
||||
for (auto i = 0; i < qb; ++i) {
|
||||
ASSERT_EQ(result_ids[i * k], target_ids[i]);
|
||||
ss.str("");
|
||||
ss << "Result [" << i << "]:";
|
||||
for (auto t = 0; t < k; t++) {
|
||||
|
@ -214,10 +211,13 @@ TEST_F(DBTest, DB_TEST) {
|
|||
|
||||
for (auto i = 0; i < loop; ++i) {
|
||||
if (i == 40) {
|
||||
db_->InsertVectors(TABLE_NAME, "", qb, qxb.data(), target_ids);
|
||||
ASSERT_EQ(target_ids.size(), qb);
|
||||
qxb.id_array_.clear();
|
||||
db_->InsertVectors(TABLE_NAME, "", qxb);
|
||||
ASSERT_EQ(qxb.id_array_.size(), qb);
|
||||
} else {
|
||||
db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
|
||||
xb.id_array_.clear();
|
||||
db_->InsertVectors(TABLE_NAME, "", xb);
|
||||
ASSERT_EQ(xb.id_array_.size(), nb);
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::microseconds(1));
|
||||
}
|
||||
|
@ -250,21 +250,24 @@ TEST_F(DBTest, SEARCH_TEST) {
|
|||
size_t nb = VECTOR_COUNT;
|
||||
size_t nq = 10;
|
||||
size_t k = 5;
|
||||
std::vector<float> xb(nb * TABLE_DIM);
|
||||
std::vector<float> xq(nq * TABLE_DIM);
|
||||
std::vector<int64_t> ids(nb);
|
||||
milvus::engine::VectorsData xb, xq;
|
||||
xb.vector_count_ = nb;
|
||||
xb.float_data_.resize(nb * TABLE_DIM);
|
||||
xq.vector_count_ = nq;
|
||||
xq.float_data_.resize(nq * TABLE_DIM);
|
||||
xb.id_array_.resize(nb);
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_real_distribution<> dis_xt(-1.0, 1.0);
|
||||
for (size_t i = 0; i < nb * TABLE_DIM; i++) {
|
||||
xb[i] = dis_xt(gen);
|
||||
xb.float_data_[i] = dis_xt(gen);
|
||||
if (i < nb) {
|
||||
ids[i] = i;
|
||||
xb.id_array_[i] = i;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < nq * TABLE_DIM; i++) {
|
||||
xq[i] = dis_xt(gen);
|
||||
xq.float_data_[i] = dis_xt(gen);
|
||||
}
|
||||
|
||||
// result data
|
||||
|
@ -274,14 +277,8 @@ TEST_F(DBTest, SEARCH_TEST) {
|
|||
std::vector<float> dis(k * nq);
|
||||
|
||||
// insert data
|
||||
const int batch_size = 100;
|
||||
for (int j = 0; j < nb / batch_size; ++j) {
|
||||
stat = db_->InsertVectors(TABLE_NAME, "", batch_size, xb.data() + batch_size * j * TABLE_DIM, ids);
|
||||
if (j == 200) {
|
||||
sleep(1);
|
||||
}
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
stat = db_->InsertVectors(TABLE_NAME, "", xb);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
|
||||
milvus::engine::TableIndex index;
|
||||
index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IDMAP;
|
||||
|
@ -291,9 +288,7 @@ TEST_F(DBTest, SEARCH_TEST) {
|
|||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, nq, 10, xq.data(), result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 1100, 10, xq.data(), result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
|
||||
|
@ -304,9 +299,7 @@ TEST_F(DBTest, SEARCH_TEST) {
|
|||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, nq, 10, xq.data(), result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 1100, 10, xq.data(), result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
|
||||
|
@ -317,9 +310,7 @@ TEST_F(DBTest, SEARCH_TEST) {
|
|||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, nq, 10, xq.data(), result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 1100, 10, xq.data(), result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
|
||||
|
@ -331,9 +322,7 @@ TEST_F(DBTest, SEARCH_TEST) {
|
|||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, nq, 10, xq.data(), result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 1100, 10, xq.data(), result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
#endif
|
||||
|
@ -349,7 +338,7 @@ TEST_F(DBTest, SEARCH_TEST) {
|
|||
}
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, result_ids,
|
||||
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, 10, xq, dates, result_ids,
|
||||
result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
|
@ -361,9 +350,7 @@ TEST_F(DBTest, SEARCH_TEST) {
|
|||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, nq, 10, xq.data(), result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 1100, 10, xq.data(), result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
|
||||
|
@ -378,14 +365,7 @@ TEST_F(DBTest, SEARCH_TEST) {
|
|||
{
|
||||
result_ids.clear();
|
||||
result_dists.clear();
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, partition_tag, k, nq, 10, xq.data(), result_ids, result_dists);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
|
||||
{
|
||||
result_ids.clear();
|
||||
result_dists.clear();
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, partition_tag, k, 200, 10, xq.data(), result_ids, result_dists);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, partition_tag, k, 10, xq, result_ids, result_dists);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
|
||||
|
@ -400,7 +380,7 @@ TEST_F(DBTest, SEARCH_TEST) {
|
|||
}
|
||||
result_ids.clear();
|
||||
result_dists.clear();
|
||||
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, result_ids,
|
||||
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, 10, xq, dates, result_ids,
|
||||
result_dists);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
|
@ -418,15 +398,15 @@ TEST_F(DBTest, PRELOADTABLE_TEST) {
|
|||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(table_info_get.dimension_, TABLE_DIM);
|
||||
|
||||
int64_t nb = VECTOR_COUNT;
|
||||
std::vector<float> xb;
|
||||
uint64_t nb = VECTOR_COUNT;
|
||||
milvus::engine::VectorsData xb;
|
||||
BuildVectors(nb, xb);
|
||||
|
||||
int loop = 5;
|
||||
for (auto i = 0; i < loop; ++i) {
|
||||
milvus::engine::IDNumbers vector_ids;
|
||||
db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
|
||||
ASSERT_EQ(vector_ids.size(), nb);
|
||||
xb.id_array_.clear();
|
||||
db_->InsertVectors(TABLE_NAME, "", xb);
|
||||
ASSERT_EQ(xb.id_array_.size(), nb);
|
||||
}
|
||||
|
||||
milvus::engine::TableIndex index;
|
||||
|
@ -454,8 +434,8 @@ TEST_F(DBTest, SHUTDOWN_TEST) {
|
|||
stat = db_->HasTable(table_info.table_id_, has_table);
|
||||
ASSERT_FALSE(stat.ok());
|
||||
|
||||
milvus::engine::IDNumbers ids;
|
||||
stat = db_->InsertVectors(table_info.table_id_, "", 0, nullptr, ids);
|
||||
milvus::engine::VectorsData xb;
|
||||
stat = db_->InsertVectors(table_info.table_id_, "", xb);
|
||||
ASSERT_FALSE(stat.ok());
|
||||
|
||||
stat = db_->PreloadTable(table_info.table_id_);
|
||||
|
@ -477,10 +457,10 @@ TEST_F(DBTest, SHUTDOWN_TEST) {
|
|||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat =
|
||||
db_->Query(dummy_context_, table_info.table_id_, tags, 1, 1, 1, nullptr, dates, result_ids, result_distances);
|
||||
db_->Query(dummy_context_, table_info.table_id_, tags, 1, 1, xb, dates, result_ids, result_distances);
|
||||
ASSERT_FALSE(stat.ok());
|
||||
std::vector<std::string> file_ids;
|
||||
stat = db_->QueryByFileID(dummy_context_, table_info.table_id_, file_ids, 1, 1, 1, nullptr, dates, result_ids,
|
||||
stat = db_->QueryByFileID(dummy_context_, table_info.table_id_, file_ids, 1, 1, xb, dates, result_ids,
|
||||
result_distances);
|
||||
ASSERT_FALSE(stat.ok());
|
||||
|
||||
|
@ -492,13 +472,12 @@ TEST_F(DBTest, INDEX_TEST) {
|
|||
milvus::engine::meta::TableSchema table_info = BuildTableSchema();
|
||||
auto stat = db_->CreateTable(table_info);
|
||||
|
||||
int64_t nb = VECTOR_COUNT;
|
||||
std::vector<float> xb;
|
||||
uint64_t nb = VECTOR_COUNT;
|
||||
milvus::engine::VectorsData xb;
|
||||
BuildVectors(nb, xb);
|
||||
|
||||
milvus::engine::IDNumbers vector_ids;
|
||||
db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
|
||||
ASSERT_EQ(vector_ids.size(), nb);
|
||||
db_->InsertVectors(TABLE_NAME, "", xb);
|
||||
ASSERT_EQ(xb.id_array_.size(), nb);
|
||||
|
||||
milvus::engine::TableIndex index;
|
||||
index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IVFSQ8;
|
||||
|
@ -552,7 +531,7 @@ TEST_F(DBTest, PARTITION_TEST) {
|
|||
stat = db_->CreatePartition(table_name, partition_name, partition_tag);
|
||||
ASSERT_FALSE(stat.ok());
|
||||
|
||||
std::vector<float> xb;
|
||||
milvus::engine::VectorsData xb;
|
||||
BuildVectors(INSERT_BATCH, xb);
|
||||
|
||||
milvus::engine::IDNumbers vector_ids;
|
||||
|
@ -561,7 +540,7 @@ TEST_F(DBTest, PARTITION_TEST) {
|
|||
vector_ids[k] = i * INSERT_BATCH + k;
|
||||
}
|
||||
|
||||
db_->InsertVectors(table_name, partition_tag, INSERT_BATCH, xb.data(), vector_ids);
|
||||
db_->InsertVectors(table_name, partition_tag, xb);
|
||||
ASSERT_EQ(vector_ids.size(), INSERT_BATCH);
|
||||
}
|
||||
|
||||
|
@ -594,14 +573,14 @@ TEST_F(DBTest, PARTITION_TEST) {
|
|||
const int64_t nq = 5;
|
||||
const int64_t topk = 10;
|
||||
const int64_t nprobe = 10;
|
||||
std::vector<float> xq;
|
||||
milvus::engine::VectorsData xq;
|
||||
BuildVectors(nq, xq);
|
||||
|
||||
// specify partition tags
|
||||
std::vector<std::string> tags = {"0", std::to_string(PARTITION_COUNT - 1)};
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nq, nprobe, xq.data(), result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nprobe, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(result_ids.size() / topk, nq);
|
||||
|
||||
|
@ -609,7 +588,7 @@ TEST_F(DBTest, PARTITION_TEST) {
|
|||
tags.clear();
|
||||
result_ids.clear();
|
||||
result_distances.clear();
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nq, nprobe, xq.data(), result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nprobe, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(result_ids.size() / topk, nq);
|
||||
|
||||
|
@ -617,7 +596,7 @@ TEST_F(DBTest, PARTITION_TEST) {
|
|||
tags.push_back("\\d");
|
||||
result_ids.clear();
|
||||
result_distances.clear();
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nq, nprobe, xq.data(), result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nprobe, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(result_ids.size() / topk, nq);
|
||||
}
|
||||
|
@ -661,14 +640,14 @@ TEST_F(DBTest2, ARHIVE_DISK_CHECK) {
|
|||
uint64_t size;
|
||||
db_->Size(size);
|
||||
|
||||
int64_t nb = 10;
|
||||
std::vector<float> xb;
|
||||
uint64_t nb = 10;
|
||||
milvus::engine::VectorsData xb;
|
||||
BuildVectors(nb, xb);
|
||||
|
||||
int loop = INSERT_LOOP;
|
||||
for (auto i = 0; i < loop; ++i) {
|
||||
milvus::engine::IDNumbers vector_ids;
|
||||
db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
|
||||
db_->InsertVectors(TABLE_NAME, "", xb);
|
||||
std::this_thread::sleep_for(std::chrono::microseconds(1));
|
||||
}
|
||||
|
||||
|
@ -695,12 +674,12 @@ TEST_F(DBTest2, DELETE_TEST) {
|
|||
uint64_t size;
|
||||
db_->Size(size);
|
||||
|
||||
int64_t nb = VECTOR_COUNT;
|
||||
std::vector<float> xb;
|
||||
uint64_t nb = VECTOR_COUNT;
|
||||
milvus::engine::VectorsData xb;
|
||||
BuildVectors(nb, xb);
|
||||
|
||||
milvus::engine::IDNumbers vector_ids;
|
||||
stat = db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
|
||||
stat = db_->InsertVectors(TABLE_NAME, "", xb);
|
||||
milvus::engine::TableIndex index;
|
||||
stat = db_->CreateIndex(TABLE_NAME, index);
|
||||
|
||||
|
@ -730,12 +709,12 @@ TEST_F(DBTest2, DELETE_BY_RANGE_TEST) {
|
|||
db_->Size(size);
|
||||
ASSERT_EQ(size, 0UL);
|
||||
|
||||
int64_t nb = VECTOR_COUNT;
|
||||
std::vector<float> xb;
|
||||
uint64_t nb = VECTOR_COUNT;
|
||||
milvus::engine::VectorsData xb;
|
||||
BuildVectors(nb, xb);
|
||||
|
||||
milvus::engine::IDNumbers vector_ids;
|
||||
stat = db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
|
||||
stat = db_->InsertVectors(TABLE_NAME, "", xb);
|
||||
milvus::engine::TableIndex index;
|
||||
stat = db_->CreateIndex(TABLE_NAME, index);
|
||||
|
||||
|
|
|
@ -44,12 +44,13 @@ BuildTableSchema() {
|
|||
}
|
||||
|
||||
void
|
||||
BuildVectors(int64_t n, std::vector<float>& vectors) {
|
||||
vectors.clear();
|
||||
vectors.resize(n * TABLE_DIM);
|
||||
float* data = vectors.data();
|
||||
for (int i = 0; i < n; i++) {
|
||||
for (int j = 0; j < TABLE_DIM; j++) data[TABLE_DIM * i + j] = drand48();
|
||||
BuildVectors(uint64_t n, milvus::engine::VectorsData& vectors) {
|
||||
vectors.vector_count_ = n;
|
||||
vectors.float_data_.clear();
|
||||
vectors.float_data_.resize(n * TABLE_DIM);
|
||||
float* data = vectors.float_data_.data();
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
for (int64_t j = 0; j < TABLE_DIM; j++) data[TABLE_DIM * i + j] = drand48();
|
||||
data[TABLE_DIM * i] += i / 2000.;
|
||||
}
|
||||
}
|
||||
|
@ -66,19 +67,16 @@ TEST_F(MySqlDBTest, DB_TEST) {
|
|||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(table_info_get.dimension_, TABLE_DIM);
|
||||
|
||||
milvus::engine::IDNumbers vector_ids;
|
||||
milvus::engine::IDNumbers target_ids;
|
||||
|
||||
int64_t nb = 50;
|
||||
std::vector<float> xb;
|
||||
uint64_t nb = 50;
|
||||
milvus::engine::VectorsData xb;
|
||||
BuildVectors(nb, xb);
|
||||
|
||||
int64_t qb = 5;
|
||||
std::vector<float> qxb;
|
||||
uint64_t qb = 5;
|
||||
milvus::engine::VectorsData qxb;
|
||||
BuildVectors(qb, qxb);
|
||||
|
db_->InsertVectors(TABLE_NAME, "", qb, qxb.data(), target_ids);
ASSERT_EQ(target_ids.size(), qb);
db_->InsertVectors(TABLE_NAME, "", qxb);
ASSERT_EQ(qxb.id_array_.size(), qb);

std::thread search([&]() {
milvus::engine::ResultIds result_ids;

@ -98,7 +96,7 @@ TEST_F(MySqlDBTest, DB_TEST) {
START_TIMER;
std::vector<std::string> tags;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, qb, 10, qxb.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, qxb, result_ids, result_distances);
ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
STOP_TIMER(ss.str());

@ -106,13 +104,6 @@ TEST_F(MySqlDBTest, DB_TEST) {
for (auto i = 0; i < qb; ++i) {
// std::cout << results[k][0].first << " " << target_ids[k] << std::endl;
// ASSERT_EQ(results[k][0].first, target_ids[k]);
bool exists = false;
for (auto t = 0; t < k; t++) {
if (result_ids[i * k + t] == target_ids[i]) {
exists = true;
}
}
ASSERT_TRUE(exists);
ss.str("");
ss << "Result [" << i << "]:";
for (auto t = 0; t < k; t++) {

@ -136,7 +127,8 @@ TEST_F(MySqlDBTest, DB_TEST) {
// } else {
// db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
// }
db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
xb.id_array_.clear();
db_->InsertVectors(TABLE_NAME, "", xb);
std::this_thread::sleep_for(std::chrono::microseconds(1));
}

@ -157,21 +149,24 @@ TEST_F(MySqlDBTest, SEARCH_TEST) {
size_t nb = VECTOR_COUNT;
size_t nq = 10;
size_t k = 5;
std::vector<float> xb(nb * TABLE_DIM);
std::vector<float> xq(nq * TABLE_DIM);
std::vector<int64_t> ids(nb);
milvus::engine::VectorsData xb, xq;
xb.vector_count_ = nb;
xb.float_data_.resize(nb * TABLE_DIM);
xq.vector_count_ = nq;
xq.float_data_.resize(nq * TABLE_DIM);
xb.id_array_.resize(nb);

std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<> dis_xt(-1.0, 1.0);
for (size_t i = 0; i < nb * TABLE_DIM; i++) {
xb[i] = dis_xt(gen);
xb.float_data_[i] = dis_xt(gen);
if (i < nb) {
ids[i] = i;
xb.id_array_[i] = i;
}
}
for (size_t i = 0; i < nq * TABLE_DIM; i++) {
xq[i] = dis_xt(gen);
xq.float_data_[i] = dis_xt(gen);
}

// result data

@ -181,21 +176,15 @@ TEST_F(MySqlDBTest, SEARCH_TEST) {
std::vector<float> dis(k * nq);

// insert data
const int batch_size = 100;
for (int j = 0; j < nb / batch_size; ++j) {
stat = db_->InsertVectors(TABLE_NAME, "", batch_size, xb.data() + batch_size * j * TABLE_DIM, ids);
if (j == 200) {
sleep(1);
}
ASSERT_TRUE(stat.ok());
}
stat = db_->InsertVectors(TABLE_NAME, "", xb);
ASSERT_TRUE(stat.ok());

sleep(2); // wait until build index finish

std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, nq, 10, xq.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}

@ -228,12 +217,12 @@ TEST_F(MySqlDBTest, ARHIVE_DISK_CHECK) {
db_->Size(size);

int64_t nb = 10;
std::vector<float> xb;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);

int loop = INSERT_LOOP;
for (auto i = 0; i < loop; ++i) {
db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
db_->InsertVectors(TABLE_NAME, "", xb);
std::this_thread::sleep_for(std::chrono::microseconds(1));
}

@ -264,12 +253,12 @@ TEST_F(MySqlDBTest, DELETE_TEST) {
db_->Size(size);

int64_t nb = INSERT_LOOP;
std::vector<float> xb;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);

int loop = 20;
for (auto i = 0; i < loop; ++i) {
db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
db_->InsertVectors(TABLE_NAME, "", xb);
std::this_thread::sleep_for(std::chrono::microseconds(1));
}

@ -307,7 +296,7 @@ TEST_F(MySqlDBTest, PARTITION_TEST) {
stat = db_->CreatePartition(table_name, partition_name, partition_tag);
ASSERT_FALSE(stat.ok());

std::vector<float> xb;
milvus::engine::VectorsData xb;
BuildVectors(INSERT_BATCH, xb);

milvus::engine::IDNumbers vector_ids;

@ -316,7 +305,7 @@ TEST_F(MySqlDBTest, PARTITION_TEST) {
vector_ids[k] = i * INSERT_BATCH + k;
}

db_->InsertVectors(table_name, partition_tag, INSERT_BATCH, xb.data(), vector_ids);
db_->InsertVectors(table_name, partition_tag, xb);
ASSERT_EQ(vector_ids.size(), INSERT_BATCH);
}

@ -349,14 +338,14 @@ TEST_F(MySqlDBTest, PARTITION_TEST) {
const int64_t nq = 5;
const int64_t topk = 10;
const int64_t nprobe = 10;
std::vector<float> xq;
milvus::engine::VectorsData xq;
BuildVectors(nq, xq);

// specify partition tags
std::vector<std::string> tags = {"0", std::to_string(PARTITION_COUNT - 1)};
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, nq, 10, xq.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, 10, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, nq);

@ -364,7 +353,7 @@ TEST_F(MySqlDBTest, PARTITION_TEST) {
tags.clear();
result_ids.clear();
result_distances.clear();
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, nq, 10, xq.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, 10, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, nq);

@ -372,7 +361,7 @@ TEST_F(MySqlDBTest, PARTITION_TEST) {
tags.push_back("\\d");
result_ids.clear();
result_distances.clear();
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, nq, 10, xq.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, 10, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, nq);
}

@ -55,10 +55,11 @@ BuildTableSchema() {
}

void
BuildVectors(int64_t n, std::vector<float>& vectors) {
vectors.clear();
vectors.resize(n * TABLE_DIM);
float* data = vectors.data();
BuildVectors(uint64_t n, milvus::engine::VectorsData& vectors) {
vectors.vector_count_ = n;
vectors.float_data_.clear();
vectors.float_data_.resize(n * TABLE_DIM);
float* data = vectors.float_data_.data();
for (int i = 0; i < n; i++) {
for (int j = 0; j < TABLE_DIM; j++) data[TABLE_DIM * i + j] = drand48();
}

@ -76,10 +77,10 @@ TEST_F(MemManagerTest, VECTOR_SOURCE_TEST) {
ASSERT_TRUE(status.ok());

int64_t n = 100;
std::vector<float> vectors;
milvus::engine::VectorsData vectors;
BuildVectors(n, vectors);

milvus::engine::VectorSource source(n, vectors.data());
milvus::engine::VectorSource source(vectors);

size_t num_vectors_added;
milvus::engine::ExecutionEnginePtr execution_engine_ = milvus::engine::EngineFactory::Build(

@ -87,21 +88,16 @@ TEST_F(MemManagerTest, VECTOR_SOURCE_TEST) {
(milvus::engine::EngineType)table_file_schema.engine_type_,
(milvus::engine::MetricType)table_file_schema.metric_type_, table_schema.nlist_);

milvus::engine::IDNumbers vector_ids;
status = source.Add(execution_engine_, table_file_schema, 50, num_vectors_added, vector_ids);
status = source.Add(execution_engine_, table_file_schema, 50, num_vectors_added);
ASSERT_TRUE(status.ok());
vector_ids = source.GetVectorIds();
ASSERT_EQ(vector_ids.size(), 50);
ASSERT_EQ(num_vectors_added, 50);
ASSERT_EQ(source.GetVectorIds().size(), 50);

vector_ids.clear();
status = source.Add(execution_engine_, table_file_schema, 60, num_vectors_added, vector_ids);
vectors.id_array_.clear();
status = source.Add(execution_engine_, table_file_schema, 60, num_vectors_added);
ASSERT_TRUE(status.ok());

ASSERT_EQ(num_vectors_added, 50);

vector_ids = source.GetVectorIds();
ASSERT_EQ(vector_ids.size(), 100);
ASSERT_EQ(source.GetVectorIds().size(), 100);
}

TEST_F(MemManagerTest, MEM_TABLE_FILE_TEST) {

@ -114,34 +110,30 @@ TEST_F(MemManagerTest, MEM_TABLE_FILE_TEST) {
milvus::engine::MemTableFile mem_table_file(GetTableName(), impl_, options);

int64_t n_100 = 100;
std::vector<float> vectors_100;
milvus::engine::VectorsData vectors_100;
BuildVectors(n_100, vectors_100);

milvus::engine::VectorSourcePtr source = std::make_shared<milvus::engine::VectorSource>(n_100, vectors_100.data());
milvus::engine::VectorSourcePtr source = std::make_shared<milvus::engine::VectorSource>(vectors_100);

milvus::engine::IDNumbers vector_ids;
status = mem_table_file.Add(source, vector_ids);
status = mem_table_file.Add(source);
ASSERT_TRUE(status.ok());

// std::cout << mem_table_file.GetCurrentMem() << " " << mem_table_file.GetMemLeft() << std::endl;

vector_ids = source->GetVectorIds();
ASSERT_EQ(vector_ids.size(), 100);

size_t singleVectorMem = sizeof(float) * TABLE_DIM;
ASSERT_EQ(mem_table_file.GetCurrentMem(), n_100 * singleVectorMem);

int64_t n_max = milvus::engine::MAX_TABLE_FILE_MEM / singleVectorMem;
std::vector<float> vectors_128M;
milvus::engine::VectorsData vectors_128M;
BuildVectors(n_max, vectors_128M);

milvus::engine::VectorSourcePtr source_128M =
std::make_shared<milvus::engine::VectorSource>(n_max, vectors_128M.data());
std::make_shared<milvus::engine::VectorSource>(vectors_128M);
vector_ids.clear();
status = mem_table_file.Add(source_128M, vector_ids);
status = mem_table_file.Add(source_128M);

vector_ids = source_128M->GetVectorIds();
ASSERT_EQ(vector_ids.size(), n_max - n_100);
ASSERT_EQ(source_128M->GetVectorIds().size(), n_max - n_100);

ASSERT_TRUE(mem_table_file.IsFull());
}

@ -154,19 +146,18 @@ TEST_F(MemManagerTest, MEM_TABLE_TEST) {
ASSERT_TRUE(status.ok());

int64_t n_100 = 100;
std::vector<float> vectors_100;
milvus::engine::VectorsData vectors_100;
BuildVectors(n_100, vectors_100);

milvus::engine::VectorSourcePtr source_100 =
std::make_shared<milvus::engine::VectorSource>(n_100, vectors_100.data());
std::make_shared<milvus::engine::VectorSource>(vectors_100);

milvus::engine::MemTable mem_table(GetTableName(), impl_, options);

milvus::engine::IDNumbers vector_ids;
status = mem_table.Add(source_100, vector_ids);
status = mem_table.Add(source_100);
ASSERT_TRUE(status.ok());
vector_ids = source_100->GetVectorIds();
ASSERT_EQ(vector_ids.size(), 100);
ASSERT_EQ(source_100->GetVectorIds().size(), 100);

milvus::engine::MemTableFilePtr mem_table_file;
mem_table.GetCurrentMemTableFile(mem_table_file);

@ -174,17 +165,16 @@ TEST_F(MemManagerTest, MEM_TABLE_TEST) {
ASSERT_EQ(mem_table_file->GetCurrentMem(), n_100 * singleVectorMem);

int64_t n_max = milvus::engine::MAX_TABLE_FILE_MEM / singleVectorMem;
std::vector<float> vectors_128M;
milvus::engine::VectorsData vectors_128M;
BuildVectors(n_max, vectors_128M);

vector_ids.clear();
milvus::engine::VectorSourcePtr source_128M =
std::make_shared<milvus::engine::VectorSource>(n_max, vectors_128M.data());
status = mem_table.Add(source_128M, vector_ids);
std::make_shared<milvus::engine::VectorSource>(vectors_128M);
status = mem_table.Add(source_128M);
ASSERT_TRUE(status.ok());

vector_ids = source_128M->GetVectorIds();
ASSERT_EQ(vector_ids.size(), n_max);
ASSERT_EQ(source_128M->GetVectorIds().size(), n_max);

mem_table.GetCurrentMemTableFile(mem_table_file);
ASSERT_EQ(mem_table_file->GetCurrentMem(), n_100 * singleVectorMem);

@ -192,17 +182,16 @@ TEST_F(MemManagerTest, MEM_TABLE_TEST) {
ASSERT_EQ(mem_table.GetTableFileCount(), 2);

int64_t n_1G = 1024000;
std::vector<float> vectors_1G;
milvus::engine::VectorsData vectors_1G;
BuildVectors(n_1G, vectors_1G);

milvus::engine::VectorSourcePtr source_1G = std::make_shared<milvus::engine::VectorSource>(n_1G, vectors_1G.data());
milvus::engine::VectorSourcePtr source_1G = std::make_shared<milvus::engine::VectorSource>(vectors_1G);

vector_ids.clear();
status = mem_table.Add(source_1G, vector_ids);
status = mem_table.Add(source_1G);
ASSERT_TRUE(status.ok());

vector_ids = source_1G->GetVectorIds();
ASSERT_EQ(vector_ids.size(), n_1G);
ASSERT_EQ(source_1G->GetVectorIds().size(), n_1G);

int expectedTableFileCount = 2 + std::ceil((n_1G - n_100) * singleVectorMem / milvus::engine::MAX_TABLE_FILE_MEM);
ASSERT_EQ(mem_table.GetTableFileCount(), expectedTableFileCount);

@ -222,15 +211,14 @@ TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) {
ASSERT_EQ(table_info_get.dimension_, TABLE_DIM);

int64_t nb = 100000;
std::vector<float> xb;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);

milvus::engine::IDNumbers vector_ids;
for (int64_t i = 0; i < nb; i++) {
vector_ids.push_back(i);
xb.id_array_.push_back(i);
}

stat = db_->InsertVectors(GetTableName(), "", nb, xb.data(), vector_ids);
stat = db_->InsertVectors(GetTableName(), "", xb);
ASSERT_TRUE(stat.ok());

std::this_thread::sleep_for(std::chrono::seconds(3)); // ensure raw data write to disk

@ -240,14 +228,15 @@ TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) {
std::uniform_int_distribution<int64_t> dis(0, nb - 1);

int64_t num_query = 10;
std::map<int64_t, std::vector<float>> search_vectors;
std::map<int64_t, milvus::engine::VectorsData> search_vectors;
for (int64_t i = 0; i < num_query; ++i) {
int64_t index = dis(gen);
std::vector<float> search;
milvus::engine::VectorsData search;
search.vector_count_ = 1;
for (int64_t j = 0; j < TABLE_DIM; j++) {
search.push_back(xb[index * TABLE_DIM + j]);
search.float_data_.push_back(xb.float_data_[index * TABLE_DIM + j]);
}
search_vectors.insert(std::make_pair(vector_ids[index], search));
search_vectors.insert(std::make_pair(xb.id_array_[index], search));
}

int topk = 10, nprobe = 10;

@ -257,7 +246,7 @@ TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) {
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, 1, nprobe, search.data(), result_ids,
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, nprobe, search, result_ids,
result_distances);
ASSERT_EQ(result_ids[0], pair.first);
ASSERT_LT(result_distances[0], 1e-4);

@ -279,10 +268,10 @@ TEST_F(MemManagerTest2, INSERT_TEST) {
int insert_loop = 20;
for (int i = 0; i < insert_loop; ++i) {
int64_t nb = 40960;
std::vector<float> xb;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);
milvus::engine::IDNumbers vector_ids;
stat = db_->InsertVectors(GetTableName(), "", nb, xb.data(), vector_ids);
stat = db_->InsertVectors(GetTableName(), "", xb);
ASSERT_TRUE(stat.ok());
}
auto end_time = METRICS_NOW_TIME;

@ -300,15 +289,12 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) {
ASSERT_TRUE(stat.ok());
ASSERT_EQ(table_info_get.dimension_, TABLE_DIM);

milvus::engine::IDNumbers vector_ids;
milvus::engine::IDNumbers target_ids;

int64_t nb = 40960;
std::vector<float> xb;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);

int64_t qb = 5;
std::vector<float> qxb;
milvus::engine::VectorsData qxb;
BuildVectors(qb, qxb);

std::thread search([&]() {

@ -331,13 +317,12 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) {

std::vector<std::string> tags;
stat =
db_->Query(dummy_context_, GetTableName(), tags, k, qb, 10, qxb.data(), result_ids, result_distances);
db_->Query(dummy_context_, GetTableName(), tags, k, 10, qxb, result_ids, result_distances);
ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
STOP_TIMER(ss.str());

ASSERT_TRUE(stat.ok());
for (auto i = 0; i < qb; ++i) {
ASSERT_EQ(result_ids[i * k], target_ids[i]);
ss.str("");
ss << "Result [" << i << "]:";
for (auto t = 0; t < k; t++) {

@ -354,10 +339,13 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) {

for (auto i = 0; i < loop; ++i) {
if (i == 0) {
db_->InsertVectors(GetTableName(), "", qb, qxb.data(), target_ids);
ASSERT_EQ(target_ids.size(), qb);
qxb.id_array_.clear();
db_->InsertVectors(GetTableName(), "", qxb);
ASSERT_EQ(qxb.id_array_.size(), qb);
} else {
db_->InsertVectors(GetTableName(), "", nb, xb.data(), vector_ids);
xb.id_array_.clear();
db_->InsertVectors(GetTableName(), "", xb);
ASSERT_EQ(xb.id_array_.size(), nb);
}
std::this_thread::sleep_for(std::chrono::microseconds(1));
}

@ -375,62 +363,56 @@ TEST_F(MemManagerTest2, VECTOR_IDS_TEST) {
ASSERT_TRUE(stat.ok());
ASSERT_EQ(table_info_get.dimension_, TABLE_DIM);

milvus::engine::IDNumbers vector_ids;

int64_t nb = 100000;
std::vector<float> xb;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);

vector_ids.resize(nb);
xb.id_array_.resize(nb);
for (auto i = 0; i < nb; i++) {
vector_ids[i] = i;
xb.id_array_[i] = i;
}

stat = db_->InsertVectors(GetTableName(), "", nb, xb.data(), vector_ids);
ASSERT_EQ(vector_ids[0], 0);
stat = db_->InsertVectors(GetTableName(), "", xb);
ASSERT_EQ(xb.id_array_[0], 0);
ASSERT_TRUE(stat.ok());

nb = 25000;
xb.clear();
BuildVectors(nb, xb);
vector_ids.clear();
vector_ids.resize(nb);

xb.id_array_.resize(nb);
for (auto i = 0; i < nb; i++) {
vector_ids[i] = i + nb;
xb.id_array_[i] = i + nb;
}
stat = db_->InsertVectors(GetTableName(), "", nb, xb.data(), vector_ids);
ASSERT_EQ(vector_ids[0], nb);
stat = db_->InsertVectors(GetTableName(), "", xb);
ASSERT_EQ(xb.id_array_[0], nb);
ASSERT_TRUE(stat.ok());

nb = 262144;  // 512M
xb.clear();
BuildVectors(nb, xb);
vector_ids.clear();
vector_ids.resize(nb);

xb.id_array_.resize(nb);
for (auto i = 0; i < nb; i++) {
vector_ids[i] = i + nb / 2;
xb.id_array_[i] = i + nb / 2;
}
stat = db_->InsertVectors(GetTableName(), "", nb, xb.data(), vector_ids);
ASSERT_EQ(vector_ids[0], nb / 2);
stat = db_->InsertVectors(GetTableName(), "", xb);
ASSERT_EQ(xb.id_array_[0], nb / 2);
ASSERT_TRUE(stat.ok());

nb = 65536;  // 128M
xb.clear();
BuildVectors(nb, xb);
vector_ids.clear();
stat = db_->InsertVectors(GetTableName(), "", nb, xb.data(), vector_ids);
xb.id_array_.clear();
stat = db_->InsertVectors(GetTableName(), "", xb);
ASSERT_TRUE(stat.ok());

nb = 100;
xb.clear();
BuildVectors(nb, xb);
vector_ids.clear();
vector_ids.resize(nb);

xb.id_array_.resize(nb);
for (auto i = 0; i < nb; i++) {
vector_ids[i] = i + nb;
xb.id_array_[i] = i + nb;
}
stat = db_->InsertVectors(GetTableName(), "", nb, xb.data(), vector_ids);
stat = db_->InsertVectors(GetTableName(), "", xb);
for (auto i = 0; i < nb; i++) {
ASSERT_EQ(vector_ids[i], i + nb);
ASSERT_EQ(xb.id_array_[i], i + nb);
}
}
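The hunks above all apply the same migration: the separate count / float pointer / IDNumbers arguments are folded into a single milvus::engine::VectorsData that carries vector_count_, float_data_ and the returned id_array_. The sketch below is illustrative only; it uses a hypothetical stand-in struct rather than the engine type, and only mirrors the field names that appear in the tests above.

#include <cstdint>
#include <vector>

namespace example {

// Stand-in for milvus::engine::VectorsData (assumed field names, taken from the diff above).
struct VectorsData {
    uint64_t vector_count_ = 0;
    std::vector<float> float_data_;
    std::vector<int64_t> id_array_;
};

// Prepare one batch in the shape expected by the new single-argument InsertVectors.
VectorsData
MakeBatch(uint64_t n, int64_t dim) {
    VectorsData vectors;
    vectors.vector_count_ = n;                  // number of rows in this batch
    vectors.float_data_.resize(n * dim, 0.0f);  // row-major float payload
    // id_array_ stays empty here; the tests above show that the engine fills it
    // with generated IDs, or honours caller-provided IDs when it is pre-filled.
    return vectors;
}

}  // namespace example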
@ -29,6 +29,22 @@
#include "db/DB.h"
#include "db/meta/SqliteMetaImpl.h"

namespace {
static constexpr int64_t TABLE_DIM = 256;

void
BuildVectors(uint64_t n, milvus::engine::VectorsData& vectors) {
vectors.vector_count_ = n;
vectors.float_data_.clear();
vectors.float_data_.resize(n * TABLE_DIM);
float* data = vectors.float_data_.data();
for (uint64_t i = 0; i < n; i++) {
for (int64_t j = 0; j < TABLE_DIM; j++) data[TABLE_DIM * i + j] = drand48();
data[TABLE_DIM * i] += i / 2000.;
}
}
} // namespace

TEST_F(MetricTest, METRIC_TEST) {
milvus::server::SystemInfo::GetInstance().Init();
milvus::server::Metrics::GetInstance().Init();

@ -48,29 +64,18 @@ TEST_F(MetricTest, METRIC_TEST) {
group_info_get.table_id_ = group_name;
stat = db_->DescribeTable(group_info_get);

milvus::engine::IDNumbers vector_ids;
milvus::engine::IDNumbers target_ids;

int d = 256;
int nb = 50;
float *xb = new float[d * nb];
for (int i = 0; i < nb; i++) {
for (int j = 0; j < d; j++) xb[d * i + j] = drand48();
xb[d * i] += i / 2000.;
}
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);

int qb = 5;
float *qxb = new float[d * qb];
for (int i = 0; i < qb; i++) {
for (int j = 0; j < d; j++) qxb[d * i + j] = drand48();
qxb[d * i] += i / 2000.;
}
milvus::engine::VectorsData xq;
BuildVectors(qb, xq);

std::thread search([&]() {
// std::vector<std::string> tags;
// milvus::engine::ResultIds result_ids;
// milvus::engine::ResultDistances result_distances;
int k = 10;
std::this_thread::sleep_for(std::chrono::seconds(2));

INIT_TIMER;

@ -105,18 +110,18 @@ TEST_F(MetricTest, METRIC_TEST) {

for (auto i = 0; i < loop; ++i) {
if (i == 40) {
db_->InsertVectors(group_name, "", qb, qxb, target_ids);
ASSERT_EQ(target_ids.size(), qb);
xq.id_array_.clear();
db_->InsertVectors(group_name, "", xq);
ASSERT_EQ(xq.id_array_.size(), qb);
} else {
db_->InsertVectors(group_name, "", nb, xb, vector_ids);
xb.id_array_.clear();
db_->InsertVectors(group_name, "", xb);
ASSERT_EQ(xb.id_array_.size(), nb);
}
std::this_thread::sleep_for(std::chrono::microseconds(2000));
}

search.join();

delete[] xb;
delete[] qxb;
}

TEST_F(MetricTest, COLLECTOR_METRICS_TEST) {

@ -48,7 +48,7 @@ class MockVecIndex : public engine::VecIndex {
}

engine::IndexType
GetType() override {
GetType() const override {
return engine::IndexType::INVALID;
}

@ -58,7 +58,7 @@ class MockVecIndex : public milvus::engine::VecIndex {
}

milvus::engine::IndexType
GetType() override {
GetType() const override {
return milvus::engine::IndexType::INVALID;
}

@ -47,7 +47,7 @@ constexpr int64_t SECONDS_EACH_HOUR = 3600;

void
CopyRowRecord(::milvus::grpc::RowRecord* target, const std::vector<float>& src) {
auto vector_data = target->mutable_vector_data();
auto vector_data = target->mutable_float_data();
vector_data->Resize(static_cast<int>(src.size()), 0.0);
memcpy(vector_data->mutable_data(), src.data(), src.size() * sizeof(float));
}

@ -283,9 +283,9 @@ TEST(ValidationUtilTest, VALIDATE_DIMENSION_TEST) {
milvus::SERVER_INVALID_VECTOR_DIMENSION);
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(0).code(),
milvus::SERVER_INVALID_VECTOR_DIMENSION);
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(16385).code(),
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(32769).code(),
milvus::SERVER_INVALID_VECTOR_DIMENSION);
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(16384).code(), milvus::SERVER_SUCCESS);
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(32768).code(), milvus::SERVER_SUCCESS);
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(1).code(), milvus::SERVER_SUCCESS);
}

@ -19,6 +19,7 @@

set(test_files
test_knowhere.cpp
test_binindex.cpp
test_wrapper.cpp)
if (MILVUS_GPU_VERSION)
set(test_files ${test_files}

@ -0,0 +1,117 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <tuple>
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include "wrapper/VecIndex.h"
#include "wrapper/utils.h"

#include <gtest/gtest.h>

using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;

class BinIndexTest : public BinDataGen,
public TestWithParam<::std::tuple<milvus::engine::IndexType, int, int, int, int>> {
protected:
void
SetUp() override {
std::tie(index_type, dim, nb, nq, k) = GetParam();
Generate(dim, nb, nq, k);

milvus::engine::TempMetaConf tempconf;
tempconf.metric_type = knowhere::METRICTYPE::TANIMOTO;
tempconf.size = nb;
tempconf.dim = dim;
tempconf.k = k;
tempconf.nprobe = 16;

index_ = GetVecIndexFactory(index_type);
conf = ParamGenerator::GetInstance().GenBuild(index_type, tempconf);
searchconf = ParamGenerator::GetInstance().GenSearchConf(index_type, tempconf);
}

void
TearDown() override {
}

protected:
milvus::engine::IndexType index_type;
milvus::engine::VecIndexPtr index_ = nullptr;
knowhere::Config conf;
knowhere::Config searchconf;
};

INSTANTIATE_TEST_CASE_P(WrapperParam, BinIndexTest,
Values(
//["Index type", "dim", "nb", "nq", "k", "build config", "search config"]
std::make_tuple(milvus::engine::IndexType::FAISS_BIN_IDMAP, 64, 1000, 10, 10),
std::make_tuple(milvus::engine::IndexType::FAISS_BIN_IVFLAT_CPU, DIM, NB, 10, 10)));

TEST_P(BinIndexTest, BASE_TEST) {
EXPECT_EQ(index_->GetType(), index_type);
conf->Dump();
searchconf->Dump();

auto elems = nq * k;
std::vector<int64_t> res_ids(elems);
std::vector<float> res_dis(elems);

index_->BuildAll(nb, xb.data(), ids.data(), conf);
index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), searchconf);
AssertResult(res_ids, res_dis);
}

TEST_P(BinIndexTest, SERIALIZE_TEST) {
EXPECT_EQ(index_->GetType(), index_type);

auto elems = nq * k;
std::vector<int64_t> res_ids(elems);
std::vector<float> res_dis(elems);
index_->BuildAll(nb, xb.data(), ids.data(), conf);
index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), searchconf);
AssertResult(res_ids, res_dis);

{
auto binary = index_->Serialize();
auto type = index_->GetType();
auto new_index = GetVecIndexFactory(type);
new_index->Load(binary);
EXPECT_EQ(new_index->Dimension(), index_->Dimension());
EXPECT_EQ(new_index->Count(), index_->Count());

std::vector<int64_t> res_ids(elems);
std::vector<float> res_dis(elems);
new_index->Search(nq, xq.data(), res_dis.data(), res_ids.data(), searchconf);
AssertResult(res_ids, res_dis);
}

{
std::string file_location = "/tmp/knowhere";
write_index(index_, file_location);
auto new_index = milvus::engine::read_index(file_location);
EXPECT_EQ(new_index->GetType(), ConvertToCpuIndexType(index_type));
EXPECT_EQ(new_index->Dimension(), index_->Dimension());
EXPECT_EQ(new_index->Count(), index_->Count());

std::vector<int64_t> res_ids(elems);
std::vector<float> res_dis(elems);
new_index->Search(nq, xq.data(), res_dis.data(), res_ids.data(), searchconf);
AssertResult(res_ids, res_dis);
}
}
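The new BinIndexTest above drives the binary indexes with METRICTYPE::TANIMOTO. As a reference for what the binary metrics exercised in this diff (Tanimoto, Jaccard, Hamming) compute, here is a small popcount-based sketch. It is not the knowhere implementation, and the "1 - coefficient" distance convention below is an assumption rather than something fixed by this diff; for bit vectors the Tanimoto coefficient coincides with the Jaccard index.

#include <cstdint>
#include <vector>

// Count the set bits in one byte of a packed binary vector.
static int PopCount(uint8_t b) {
    int n = 0;
    for (; b; b >>= 1) n += b & 1;
    return n;
}

// Jaccard/Tanimoto distance over packed bit vectors (dim/8 bytes each):
// 1 - |x AND y| / |x OR y|.
float BinaryJaccardDistance(const std::vector<uint8_t>& x, const std::vector<uint8_t>& y) {
    int both = 0, any = 0;
    for (size_t i = 0; i < x.size() && i < y.size(); ++i) {
        both += PopCount(x[i] & y[i]);  // bits set in both vectors
        any += PopCount(x[i] | y[i]);   // bits set in either vector
    }
    return any == 0 ? 0.0f : 1.0f - static_cast<float>(both) / static_cast<float>(any);
}

// Hamming distance: number of differing bits.
int HammingDistance(const std::vector<uint8_t>& x, const std::vector<uint8_t>& y) {
    int diff = 0;
    for (size_t i = 0; i < x.size() && i < y.size(); ++i) {
        diff += PopCount(x[i] ^ y[i]);
    }
    return diff;
}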
@ -153,3 +153,73 @@ DataGenBase::AssertResult(const std::vector<int64_t>& ids, const std::vector<flo
EXPECT_GT(precision, 0.5);
std::cout << std::endl << "Precision: " << precision << ", match: " << match << ", total: " << nq * k << std::endl;
}

void
BinDataGen::GenData(const int& dim,
const int& nb,
const int& nq,
uint8_t* xb,
uint8_t* xq,
int64_t* ids,
const int& k,
int64_t* gt_ids,
float* gt_dis) {
for (auto i = 0; i < nb; ++i) {
for (auto j = 0; j < dim; ++j) {
// p_data[i * d + j] = float(base + i);
xb[i * dim + j] = (uint8_t)lrand48();
}
ids[i] = i;
}
for (int64_t i = 0; i < nq * dim; ++i) {
xq[i] = xb[i];
}
}

void
BinDataGen::GenData(const int& dim,
const int& nb,
const int& nq,
std::vector<uint8_t>& xb,
std::vector<uint8_t>& xq,
std::vector<int64_t>& ids,
const int& k,
std::vector<int64_t>& gt_ids,
std::vector<float>& gt_dis) {
xb.clear();
xq.clear();
ids.clear();
gt_ids.clear();
gt_dis.clear();
xb.resize(nb * dim);
xq.resize(nq * dim);
ids.resize(nb);
gt_ids.resize(nq * k);
gt_dis.resize(nq * k);
GenData(dim, nb, nq, xb.data(), xq.data(), ids.data(), k, gt_ids.data(), gt_dis.data());
assert(xb.size() == (size_t)dim * nb);
assert(xq.size() == (size_t)dim * nq);
}

void
BinDataGen::AssertResult(const std::vector<int64_t>& ids, const std::vector<float>& dis) {
EXPECT_EQ(ids.size(), nq * k);
EXPECT_EQ(dis.size(), nq * k);

for (auto i = 0; i < nq; i++) {
EXPECT_EQ(ids[i * k], i);
// EXPECT_EQ(dis[i * k], gt_dis[i * k]);
}
}

void
BinDataGen::Generate(const int& dim, const int& nb, const int& nq, const int& k) {
this->nb = nb;
this->nq = nq;
this->dim = dim;
this->k = k;

int64_t dim_x = dim / 8;

GenData(dim_x, nb, nq, xb, xq, ids, k, gt_ids, gt_dis);
}

@ -82,6 +82,38 @@ class DataGenBase {
std::vector<float> gt_dis;
};

class BinDataGen {
public:
virtual void GenData(const int& dim, const int& nb, const int& nq, uint8_t* xb, uint8_t* xq, int64_t* ids,
const int& k, int64_t* gt_ids, float* gt_dis);

virtual void GenData(const int& dim,
const int& nb,
const int& nq,
std::vector<uint8_t>& xb,
std::vector<uint8_t>& xq,
std::vector<int64_t>& ids,
const int& k,
std::vector<int64_t>& gt_ids,
std::vector<float>& gt_dis);

void AssertResult(const std::vector<int64_t>& ids, const std::vector<float>& dis);

void Generate(const int& dim, const int& nb, const int& nq, const int& k);

int dim = DIM;
int nb = NB;
int nq = NQ;
int k = 10;
std::vector<uint8_t> xb;
std::vector<uint8_t> xq;
std::vector<int64_t> ids;

// Ground Truth
std::vector<int64_t> gt_ids;
std::vector<float> gt_dis;
};

class ParamGenerator {
public:
static ParamGenerator& GetInstance() {

@ -21,3 +21,4 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/utils util_files)

add_subdirectory(simple)
add_subdirectory(partition)
add_subdirectory(binary_vector)

@ -0,0 +1,33 @@
#-------------------------------------------------------------------------------
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http:#www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#-------------------------------------------------------------------------------

aux_source_directory(src src_files)

add_executable(sdk_binary
main.cpp
${src_files}
${util_files}
)

target_link_libraries(sdk_binary
milvus_sdk
pthread
)

install(TARGETS sdk_binary DESTINATION bin)

@ -0,0 +1,79 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <getopt.h>
#include <libgen.h>
#include <cstring>
#include <string>

#include "src/ClientTest.h"

void
print_help(const std::string& app_name);

int
main(int argc, char* argv[]) {
printf("Client start...\n");

std::string app_name = basename(argv[0]);
static struct option long_options[] = {{"server", optional_argument, nullptr, 's'},
{"port", optional_argument, nullptr, 'p'},
{"help", no_argument, nullptr, 'h'},
{nullptr, 0, nullptr, 0}};

int option_index = 0;
std::string address = "127.0.0.1", port = "19530";
app_name = argv[0];

int value;
while ((value = getopt_long(argc, argv, "s:p:h", long_options, &option_index)) != -1) {
switch (value) {
case 's': {
char* address_ptr = strdup(optarg);
address = address_ptr;
free(address_ptr);
break;
}
case 'p': {
char* port_ptr = strdup(optarg);
port = port_ptr;
free(port_ptr);
break;
}
case 'h':
default:
print_help(app_name);
return EXIT_SUCCESS;
}
}

ClientTest test;
test.Test(address, port);

printf("Client stop...\n");
return 0;
}

void
print_help(const std::string& app_name) {
printf("\n Usage: %s [OPTIONS]\n\n", app_name.c_str());
printf(" Options:\n");
printf(" -s --server Server address, default 127.0.0.1\n");
printf(" -p --port Server port, default 19530\n");
printf(" -h --help Print help information\n");
printf("\n");
}
@ -0,0 +1,167 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "examples/simple/src/ClientTest.h"
#include "include/MilvusApi.h"
#include "examples/utils/TimeRecorder.h"
#include "examples/utils/Utils.h"

#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include <random>

namespace {

const char* TABLE_NAME = milvus_sdk::Utils::GenTableName().c_str();

constexpr int64_t TABLE_DIMENSION = 512;
constexpr int64_t TABLE_INDEX_FILE_SIZE = 128;
constexpr milvus::MetricType TABLE_METRIC_TYPE = milvus::MetricType::TANIMOTO;
constexpr int64_t BATCH_ROW_COUNT = 100000;
constexpr int64_t NQ = 5;
constexpr int64_t TOP_K = 10;
constexpr int64_t NPROBE = 32;
constexpr int64_t SEARCH_TARGET = 5000;  // change this value, result is different, ensure less than BATCH_ROW_COUNT
constexpr int64_t ADD_VECTOR_LOOP = 20;
constexpr milvus::IndexType INDEX_TYPE = milvus::IndexType::IVFFLAT;
constexpr int32_t N_LIST = 1024;

milvus::TableSchema
BuildTableSchema() {
milvus::TableSchema tb_schema = {TABLE_NAME, TABLE_DIMENSION, TABLE_INDEX_FILE_SIZE, TABLE_METRIC_TYPE};
return tb_schema;
}

milvus::IndexParam
BuildIndexParam() {
milvus::IndexParam index_param = {TABLE_NAME, INDEX_TYPE, N_LIST};
return index_param;
}

void
BuildBinaryVectors(int64_t from, int64_t to, std::vector<milvus::RowRecord>& vector_record_array,
std::vector<int64_t>& record_ids, int64_t dimension) {
if (to <= from) {
return;
}

vector_record_array.clear();
record_ids.clear();

int64_t dim_byte = dimension / 8;
for (int64_t k = from; k < to; k++) {
milvus::RowRecord record;
record.binary_data.resize(dim_byte);
for (int64_t i = 0; i < dim_byte; i++) {
record.binary_data[i] = (uint8_t)lrand48();
}

vector_record_array.emplace_back(record);
record_ids.push_back(k);
}
}

} // namespace

void
ClientTest::Test(const std::string& address, const std::string& port) {
std::shared_ptr<milvus::Connection> conn = milvus::Connection::Create();

milvus::Status stat;
{ // connect server
milvus::ConnectParam param = {address, port};
stat = conn->Connect(param);
std::cout << "Connect function call status: " << stat.message() << std::endl;
}

{ // create table
milvus::TableSchema tb_schema = BuildTableSchema();
stat = conn->CreateTable(tb_schema);
std::cout << "CreateTable function call status: " << stat.message() << std::endl;
milvus_sdk::Utils::PrintTableSchema(tb_schema);

bool has_table = conn->HasTable(tb_schema.table_name);
if (has_table) {
std::cout << "Table is created" << std::endl;
}
}

std::vector<std::pair<int64_t, milvus::RowRecord>> search_record_array;
{ // insert vectors
for (int i = 0; i < ADD_VECTOR_LOOP; i++) {
std::vector<milvus::RowRecord> record_array;
std::vector<int64_t> record_ids;
int64_t begin_index = i * BATCH_ROW_COUNT;
{ // generate vectors
milvus_sdk::TimeRecorder rc("Build vectors No." + std::to_string(i));
BuildBinaryVectors(begin_index,
begin_index + BATCH_ROW_COUNT,
record_array,
record_ids,
TABLE_DIMENSION);
}

if (search_record_array.size() < NQ) {
search_record_array.push_back(std::make_pair(record_ids[SEARCH_TARGET], record_array[SEARCH_TARGET]));
}

std::string title = "Insert " + std::to_string(record_array.size()) + " vectors No." + std::to_string(i);
milvus_sdk::TimeRecorder rc(title);
stat = conn->Insert(TABLE_NAME, "", record_array, record_ids);
std::cout << "InsertVector function call status: " << stat.message() << std::endl;
std::cout << "Returned id array count: " << record_ids.size() << std::endl;
}
}

milvus_sdk::Utils::Sleep(3);
{ // search vectors
std::vector<std::string> partition_tags;
milvus::TopKQueryResult topk_query_result;
milvus_sdk::Utils::DoSearch(conn, TABLE_NAME, partition_tags, TOP_K, NPROBE, search_record_array,
topk_query_result);
}

{ // wait until build index finish
milvus_sdk::TimeRecorder rc("Create index");
std::cout << "Wait until create all index done" << std::endl;
milvus::IndexParam index1 = BuildIndexParam();
milvus_sdk::Utils::PrintIndexParam(index1);
stat = conn->CreateIndex(index1);
std::cout << "CreateIndex function call status: " << stat.message() << std::endl;

milvus::IndexParam index2;
stat = conn->DescribeIndex(TABLE_NAME, index2);
std::cout << "DescribeIndex function call status: " << stat.message() << std::endl;
milvus_sdk::Utils::PrintIndexParam(index2);
}

{ // search vectors
std::vector<std::string> partition_tags;
milvus::TopKQueryResult topk_query_result;
milvus_sdk::Utils::DoSearch(conn, TABLE_NAME, partition_tags, TOP_K, NPROBE, search_record_array,
topk_query_result);
}

{ // drop table
stat = conn->DropTable(TABLE_NAME);
std::cout << "DropTable function call status: " << stat.message() << std::endl;
}

milvus::Connection::Destroy(conn);
}
@ -0,0 +1,26 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <string>

class ClientTest {
public:
void
Test(const std::string& address, const std::string& port);
};

@ -76,36 +76,27 @@ Utils::GenTableName() {
std::string
Utils::MetricTypeName(const milvus::MetricType& metric_type) {
switch (metric_type) {
case milvus::MetricType::L2:
return "L2 distance";
case milvus::MetricType::IP:
return "Inner product";
default:
return "Unknown metric type";
case milvus::MetricType::L2:return "L2 distance";
case milvus::MetricType::IP:return "Inner product";
case milvus::MetricType::HAMMING:return "Hamming distance";
case milvus::MetricType::JACCARD:return "Jaccard distance";
case milvus::MetricType::TANIMOTO:return "Tanimoto distance";
default:return "Unknown metric type";
}
}

std::string
Utils::IndexTypeName(const milvus::IndexType& index_type) {
switch (index_type) {
case milvus::IndexType::FLAT:
return "FLAT";
case milvus::IndexType::IVFFLAT:
return "IVFFLAT";
case milvus::IndexType::IVFSQ8:
return "IVFSQ8";
case milvus::IndexType::RNSG:
return "NSG";
case milvus::IndexType::IVFSQ8H:
return "IVFSQ8H";
case milvus::IndexType::IVFPQ:
return "IVFPQ";
case milvus::IndexType::SPTAGKDT:
return "SPTAGKDT";
case milvus::IndexType::SPTAGBKT:
return "SPTAGBKT";
default:
return "Unknown index type";
case milvus::IndexType::FLAT:return "FLAT";
case milvus::IndexType::IVFFLAT:return "IVFFLAT";
case milvus::IndexType::IVFSQ8:return "IVFSQ8";
case milvus::IndexType::RNSG:return "NSG";
case milvus::IndexType::IVFSQ8H:return "IVFSQ8H";
case milvus::IndexType::IVFPQ:return "IVFPQ";
case milvus::IndexType::SPTAGKDT:return "SPTAGKDT";
case milvus::IndexType::SPTAGBKT:return "SPTAGBKT";
default:return "Unknown index type";
}
}

@ -148,9 +139,9 @@ Utils::BuildVectors(int64_t from, int64_t to, std::vector<milvus::RowRecord>& ve
record_ids.clear();
for (int64_t k = from; k < to; k++) {
milvus::RowRecord record;
record.data.resize(dimension);
record.float_data.resize(dimension);
for (int64_t i = 0; i < dimension; i++) {
record.data[i] = (float)(k % (i + 1));
record.float_data[i] = (float)(k % (i + 1));
}

vector_record_array.emplace_back(record);

@ -189,11 +180,20 @@ Utils::CheckSearchResult(const std::vector<std::pair<int64_t, milvus::RowRecord>
for (size_t i = 0; i < nq; i++) {
const milvus::QueryResult& one_result = topk_query_result[i];
auto search_id = search_record_array[i].first;
int64_t result_id = one_result.ids[0];
if (result_id != search_id) {
std::cout << "The top 1 result is wrong: " << result_id << " vs. " << search_id << std::endl;

uint64_t match_index = one_result.ids.size();
for (uint64_t index = 0; index < one_result.ids.size(); index++) {
if (search_id == one_result.ids[index]) {
match_index = index;
break;
}
}

if (match_index >= one_result.ids.size()) {
std::cout << "The topk result is wrong: not return search target in result set" << std::endl;
} else {
std::cout << "No." << i << " Check result successfully" << std::endl;
std::cout << "No." << i << " Check result successfully for target: " << search_id << " at top "
<< match_index << std::endl;
}
}
BLOCK_SPLITER