Tanimoto distance (#1016)

* Add log to debug #678

* Rename nsg_mix to RNSG in C++ sdk #735

* [skip ci] change __function__

* clang-format

* #766 If partition tag is similar, wrong partition is searched

* #766 If partition tag is similar, wrong partition is searched

* reorder changelog id

* typo

* define interface

* Define interface (#832)

* If partition tag is similar, wrong partition is searched  (#825)

* #766 If partition tag is similar, wrong partition is searched

* #766 If partition tag is similar, wrong partition is searched

* reorder changelog id

* typo

* define interface

Co-authored-by: groot <yihua.mo@zilliz.com>

* faiss & knowhere

* faiss & knowhere (#842)

* Add log to debug #678

* Rename nsg_mix to RNSG in C++ sdk #735

* [skip ci] change __function__

* clang-format

* If partition tag is similar, wrong partition is searched  (#825)

* #766 If partition tag is similar, wrong partition is searched

* #766 If partition tag is similar, wrong partition is searched

* reorder changelog id

* typo

* faiss & knowhere

Co-authored-by: groot <yihua.mo@zilliz.com>

* support binary input

* code lint

* add wrapper interface

* add knowhere unittest

* sdk support binary

* support using metric tanimoto and hamming

* sdk binary insert/query example

* fix bug

* fix bug

* update wrapper

* format

* Improve unittest and fix bugs

* delete printresult

* fix bug

* #823 Support binary vector tanimoto metric

* fix typo

* dimension limit to 32768

* fix

* dimension limit to 32768

* fix describe index bug

* fix #886

* fix #889

* add jaccard cases

* hamming dev-test case

* change test_connect

* Add tanimoto cases

* change the output type of hamming

* add abs

* merge master

* rearrange changelog id

* modify feature description

Co-authored-by: Yukikaze-CZR <48198922+Yukikaze-CZR@users.noreply.github.com>
Co-authored-by: Tinkerrr <linxiaojun.cn@outlook.com>
pull/1049/head^2
groot 2020-01-14 19:22:27 +08:00 committed by Jin Hai
parent 297e7e8831
commit 0f1aa5f8bb
113 changed files with 5191 additions and 1051 deletions

View File

@ -16,12 +16,13 @@ Please mark all change in change log and use the issue from GitHub
- \#216 - Add CLI to get server info
- \#343 - Add Opentracing
- \#665 - Support get/set config via CLI
- \#759 - Put C++ sdk out of milvus/core
- \#766 - If partition tag is similar, wrong partition is searched
- \#771 - Add server build commit info interface
- \#759 - Put C++ sdk out of milvus/core
- \#788 - Add web server into server module
- \#813 - Add push mode for prometheus monitor
- \#815 - Support MinIO storage
- \#823 - Support binary vector tanimoto/jaccard/hamming metric
- \#910 - Change Milvus c++ standard to c++17
## Improvement

View File

@ -29,7 +29,7 @@ constexpr uint64_t T = K * G;
constexpr uint64_t MAX_TABLE_FILE_MEM = 128 * M;
constexpr int VECTOR_TYPE_SIZE = sizeof(float);
constexpr int FLOAT_TYPE_SIZE = sizeof(float);
static constexpr uint64_t ONE_KB = K;
static constexpr uint64_t ONE_MB = ONE_KB * ONE_KB;

View File

@ -84,25 +84,22 @@ class DB {
ShowPartitions(const std::string& table_id, std::vector<meta::TableSchema>& partition_schema_array) = 0;
virtual Status
InsertVectors(const std::string& table_id, const std::string& partition_tag, uint64_t n, const float* vectors,
IDNumbers& vector_ids_) = 0;
InsertVectors(const std::string& table_id, const std::string& partition_tag, VectorsData& vectors) = 0;
virtual Status
Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, ResultIds& result_ids, ResultDistances& result_distances) = 0;
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) = 0;
virtual Status
Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) = 0;
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) = 0;
virtual Status
QueryByFileID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) = 0;
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) = 0;
virtual Status
Size(uint64_t& result) = 0;

View File

@ -317,8 +317,7 @@ DBImpl::ShowPartitions(const std::string& table_id, std::vector<meta::TableSchem
}
Status
DBImpl::InsertVectors(const std::string& table_id, const std::string& partition_tag, uint64_t n, const float* vectors,
IDNumbers& vector_ids) {
DBImpl::InsertVectors(const std::string& table_id, const std::string& partition_tag, VectorsData& vectors) {
// ENGINE_LOG_DEBUG << "Insert " << n << " vectors to cache";
if (!initialized_.load(std::memory_order_acquire)) {
return SHUTDOWN_ERROR;
@ -337,8 +336,8 @@ DBImpl::InsertVectors(const std::string& table_id, const std::string& partition_
}
// insert vectors into target table
milvus::server::CollectInsertMetrics metrics(n, status);
status = mem_mgr_->InsertVectors(target_table_name, n, vectors, vector_ids);
milvus::server::CollectInsertMetrics metrics(vectors.vector_count_, status);
status = mem_mgr_->InsertVectors(target_table_name, vectors);
return status;
}
@ -407,23 +406,21 @@ DBImpl::DropIndex(const std::string& table_id) {
Status
DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, ResultIds& result_ids, ResultDistances& result_distances) {
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) {
if (!initialized_.load(std::memory_order_acquire)) {
return SHUTDOWN_ERROR;
}
meta::DatesT dates = {utils::GetDate()};
Status result =
Query(context, table_id, partition_tags, k, nq, nprobe, vectors, dates, result_ids, result_distances);
Status result = Query(context, table_id, partition_tags, k, nprobe, vectors, dates, result_ids, result_distances);
return result;
}
Status
DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) {
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) {
auto query_ctx = context->Child("Query");
if (!initialized_.load(std::memory_order_acquire)) {
@ -460,7 +457,7 @@ DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string
}
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
status = QueryAsync(query_ctx, table_id, files_array, k, nq, nprobe, vectors, result_ids, result_distances);
status = QueryAsync(query_ctx, table_id, files_array, k, nprobe, vectors, result_ids, result_distances);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query
query_ctx->GetTraceContext()->GetSpan()->Finish();
@ -470,9 +467,8 @@ DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string
Status
DBImpl::QueryByFileID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) {
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) {
auto query_ctx = context->Child("Query by file id");
if (!initialized_.load(std::memory_order_acquire)) {
@ -501,7 +497,7 @@ DBImpl::QueryByFileID(const std::shared_ptr<server::Context>& context, const std
}
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
status = QueryAsync(query_ctx, table_id, files_array, k, nq, nprobe, vectors, result_ids, result_distances);
status = QueryAsync(query_ctx, table_id, files_array, k, nprobe, vectors, result_ids, result_distances);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query
query_ctx->GetTraceContext()->GetSpan()->Finish();
@ -523,11 +519,11 @@ DBImpl::Size(uint64_t& result) {
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Status
DBImpl::QueryAsync(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const meta::TableFilesSchema& files, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::TableFilesSchema& files, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) {
auto query_async_ctx = context->Child("Query Async");
server::CollectQueryMetrics metrics(nq);
server::CollectQueryMetrics metrics(vectors.vector_count_);
TimeRecorder rc("");
@ -535,7 +531,7 @@ DBImpl::QueryAsync(const std::shared_ptr<server::Context>& context, const std::s
auto status = ongoing_files_checker_.MarkOngoingFiles(files);
ENGINE_LOG_DEBUG << "Engine query begin, index file count: " << files.size();
scheduler::SearchJobPtr job = std::make_shared<scheduler::SearchJob>(query_async_ctx, k, nq, nprobe, vectors);
scheduler::SearchJobPtr job = std::make_shared<scheduler::SearchJob>(query_async_ctx, k, nprobe, vectors);
for (auto& file : files) {
scheduler::TableFileSchemaPtr file_ptr = std::make_shared<meta::TableFileSchema>(file);
job->AddIndexFile(file_ptr);

View File

@ -92,8 +92,7 @@ class DBImpl : public DB {
ShowPartitions(const std::string& table_id, std::vector<meta::TableSchema>& partition_schema_array) override;
Status
InsertVectors(const std::string& table_id, const std::string& partition_tag, uint64_t n, const float* vectors,
IDNumbers& vector_ids) override;
InsertVectors(const std::string& table_id, const std::string& partition_tag, VectorsData& vectors) override;
Status
CreateIndex(const std::string& table_id, const TableIndex& index) override;
@ -106,20 +105,18 @@ class DBImpl : public DB {
Status
Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, ResultIds& result_ids, ResultDistances& result_distances) override;
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) override;
Status
Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) override;
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) override;
Status
QueryByFileID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) override;
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) override;
Status
Size(uint64_t& result) override;
@ -127,7 +124,7 @@ class DBImpl : public DB {
private:
Status
QueryAsync(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const meta::TableFilesSchema& files, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::TableFilesSchema& files, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances);
void

View File

@ -43,6 +43,13 @@ struct TableIndex {
int32_t metric_type_ = (int)MetricType::L2;
};
struct VectorsData {
uint64_t vector_count_ = 0;
std::vector<float> float_data_;
std::vector<uint8_t> binary_data_;
IDNumbers id_array_;
};
using File2RefCount = std::map<std::string, int64_t>;
using Table2Files = std::map<std::string, File2RefCount>;

View File

@ -37,12 +37,18 @@ enum class EngineType {
FAISS_PQ,
SPTAG_KDT,
SPTAG_BKT,
MAX_VALUE = SPTAG_BKT,
FAISS_BIN_IDMAP,
FAISS_BIN_IVFFLAT,
MAX_VALUE = FAISS_BIN_IVFFLAT,
};
enum class MetricType {
L2 = 1,
IP = 2,
L2 = 1, // Euclidean Distance
IP = 2, // Cosine Similarity
HAMMING = 3, // Hamming Distance
JACCARD = 4, // Jaccard Distance
TANIMOTO = 5, // Tanimoto Distance
MAX_VALUE = TANIMOTO,
};
class ExecutionEngine {
@ -50,6 +56,9 @@ class ExecutionEngine {
virtual Status
AddWithIds(int64_t n, const float* xdata, const int64_t* xids) = 0;
virtual Status
AddWithIds(int64_t n, const uint8_t* xdata, const int64_t* xids) = 0;
virtual size_t
Count() const = 0;
@ -86,6 +95,10 @@ class ExecutionEngine {
virtual Status
Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels, bool hybrid) = 0;
virtual Status
Search(int64_t n, const uint8_t* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
bool hybrid) = 0;
virtual std::shared_ptr<ExecutionEngine>
BuildIndex(const std::string& location, EngineType engine_type) = 0;

View File

@ -25,7 +25,9 @@
#include "utils/CommonUtil.h"
#include "utils/Exception.h"
#include "utils/Log.h"
#include "utils/ValidationUtil.h"
#include "wrapper/BinVecImpl.h"
#include "wrapper/ConfAdapter.h"
#include "wrapper/ConfAdapterMgr.h"
#include "wrapper/VecImpl.h"
@ -39,6 +41,40 @@
namespace milvus {
namespace engine {
namespace {
Status
MappingMetricType(MetricType metric_type, knowhere::METRICTYPE& kw_type) {
switch (metric_type) {
case MetricType::IP:
kw_type = knowhere::METRICTYPE::IP;
break;
case MetricType::L2:
kw_type = knowhere::METRICTYPE::L2;
break;
case MetricType::HAMMING:
kw_type = knowhere::METRICTYPE::HAMMING;
break;
case MetricType::JACCARD:
kw_type = knowhere::METRICTYPE::JACCARD;
break;
case MetricType::TANIMOTO:
kw_type = knowhere::METRICTYPE::TANIMOTO;
break;
default:
return Status(DB_ERROR, "Unsupported metric type");
}
return Status::OK();
}
bool
IsBinaryIndexType(IndexType type) {
return type == IndexType::FAISS_BIN_IDMAP || type == IndexType::FAISS_BIN_IVFLAT_CPU;
}
} // namespace
class CachedQuantizer : public cache::DataObj {
public:
explicit CachedQuantizer(knowhere::QuantizerPtr data) : data_(std::move(data)) {
@ -61,7 +97,10 @@ class CachedQuantizer : public cache::DataObj {
ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string& location, EngineType index_type,
MetricType metric_type, int32_t nlist)
: location_(location), dim_(dimension), index_type_(index_type), metric_type_(metric_type), nlist_(nlist) {
index_ = CreatetVecIndex(EngineType::FAISS_IDMAP);
EngineType tmp_index_type = server::ValidationUtil::IsBinaryMetricType((int32_t)metric_type)
? EngineType::FAISS_BIN_IDMAP
: EngineType::FAISS_IDMAP;
index_ = CreatetVecIndex(tmp_index_type);
if (!index_) {
throw Exception(DB_ERROR, "Unsupported index type");
}
@ -69,11 +108,20 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string&
TempMetaConf temp_conf;
temp_conf.gpu_id = gpu_num_;
temp_conf.dim = dimension;
temp_conf.metric_type = (metric_type_ == MetricType::IP) ? knowhere::METRICTYPE::IP : knowhere::METRICTYPE::L2;
auto status = MappingMetricType(metric_type, temp_conf.metric_type);
if (!status.ok()) {
throw Exception(DB_ERROR, status.message());
}
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->Match(temp_conf);
auto ec = std::static_pointer_cast<BFIndex>(index_)->Build(conf);
ErrorCode ec = KNOWHERE_UNEXPECTED_ERROR;
if (auto bf_index = std::dynamic_pointer_cast<BFIndex>(index_)) {
ec = bf_index->Build(conf);
} else if (auto bf_bin_index = std::dynamic_pointer_cast<BinBFIndex>(index_)) {
ec = bf_bin_index->Build(conf);
}
if (ec != KNOWHERE_SUCCESS) {
throw Exception(DB_ERROR, "Build index error");
}
@ -148,6 +196,14 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
index = GetVecIndexFactory(IndexType::SPTAG_BKT_RNT_CPU);
break;
}
case EngineType::FAISS_BIN_IDMAP: {
index = GetVecIndexFactory(IndexType::FAISS_BIN_IDMAP);
break;
}
case EngineType::FAISS_BIN_IVFFLAT: {
index = GetVecIndexFactory(IndexType::FAISS_BIN_IVFLAT_CPU);
break;
}
default: {
ENGINE_LOG_ERROR << "Unsupported index type";
return nullptr;
@ -242,6 +298,12 @@ ExecutionEngineImpl::AddWithIds(int64_t n, const float* xdata, const int64_t* xi
return status;
}
Status
ExecutionEngineImpl::AddWithIds(int64_t n, const uint8_t* xdata, const int64_t* xids) {
auto status = index_->Add(n, xdata, xids);
return status;
}
size_t
ExecutionEngineImpl::Count() const {
if (index_ == nullptr) {
@ -253,7 +315,11 @@ ExecutionEngineImpl::Count() const {
size_t
ExecutionEngineImpl::Size() const {
return (size_t)(Count() * Dimension()) * sizeof(float);
if (IsBinaryIndexType(index_->GetType())) {
return (size_t)(Count() * Dimension() / 8);
} else {
return (size_t)(Count() * Dimension()) * sizeof(float);
}
}
size_t
@ -483,6 +549,14 @@ ExecutionEngineImpl::Merge(const std::string& location) {
ENGINE_LOG_DEBUG << "Finish merge index file: " << location;
}
return status;
} else if (auto bin_index = std::dynamic_pointer_cast<BinBFIndex>(to_merge)) {
auto status = index_->Add(bin_index->Count(), bin_index->GetRawVectors(), bin_index->GetRawIds());
if (!status.ok()) {
ENGINE_LOG_ERROR << "Failed to merge: " << location << " to: " << location_;
} else {
ENGINE_LOG_DEBUG << "Finish merge index file: " << location;
}
return status;
} else {
return Status(DB_ERROR, "file index type is not idmap");
}
@ -493,7 +567,8 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t
ENGINE_LOG_DEBUG << "Build index file: " << location << " from: " << location_;
auto from_index = std::dynamic_pointer_cast<BFIndex>(index_);
if (from_index == nullptr) {
auto bin_from_index = std::dynamic_pointer_cast<BinBFIndex>(index_);
if (from_index == nullptr && bin_from_index == nullptr) {
ENGINE_LOG_ERROR << "ExecutionEngineImpl: from_index is null, failed to build index";
return nullptr;
}
@ -507,13 +582,20 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t
temp_conf.gpu_id = gpu_num_;
temp_conf.dim = Dimension();
temp_conf.nlist = nlist_;
temp_conf.metric_type = (metric_type_ == MetricType::IP) ? knowhere::METRICTYPE::IP : knowhere::METRICTYPE::L2;
temp_conf.size = Count();
auto status = MappingMetricType(metric_type_, temp_conf.metric_type);
if (!status.ok()) {
throw Exception(DB_ERROR, status.message());
}
auto adapter = AdapterMgr::GetInstance().GetAdapter(to_index->GetType());
auto conf = adapter->Match(temp_conf);
auto status = to_index->BuildAll(Count(), from_index->GetRawVectors(), from_index->GetRawIds(), conf);
if (from_index) {
status = to_index->BuildAll(Count(), from_index->GetRawVectors(), from_index->GetRawIds(), conf);
} else if (bin_from_index) {
status = to_index->BuildAll(Count(), bin_from_index->GetRawVectors(), bin_from_index->GetRawIds(), conf);
}
if (!status.ok()) {
throw Exception(DB_ERROR, status.message());
}
@ -611,6 +693,40 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
return status;
}
Status
ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, int64_t nprobe, float* distances,
int64_t* labels, bool hybrid) {
if (index_ == nullptr) {
ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, failed to search";
return Status(DB_ERROR, "index is null");
}
ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe;
// TODO(linxj): remove here. Get conf from function
TempMetaConf temp_conf;
temp_conf.k = k;
temp_conf.nprobe = nprobe;
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
if (hybrid) {
HybridLoad();
}
auto status = index_->Search(n, data, distances, labels, conf);
if (hybrid) {
HybridUnset();
}
if (!status.ok()) {
ENGINE_LOG_ERROR << "Search error:" << status.message();
}
return status;
}
Status
ExecutionEngineImpl::Cache() {
cache::DataObjPtr obj = std::static_pointer_cast<cache::DataObj>(index_);

View File

@ -37,6 +37,9 @@ class ExecutionEngineImpl : public ExecutionEngine {
Status
AddWithIds(int64_t n, const float* xdata, const int64_t* xids) override;
Status
AddWithIds(int64_t n, const uint8_t* xdata, const int64_t* xids) override;
size_t
Count() const override;
@ -74,6 +77,10 @@ class ExecutionEngineImpl : public ExecutionEngine {
Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
bool hybrid = false) override;
Status
Search(int64_t n, const uint8_t* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
bool hybrid = false) override;
ExecutionEnginePtr
BuildIndex(const std::string& location, EngineType engine_type) override;

View File

@ -30,7 +30,7 @@ namespace engine {
class MemManager {
public:
virtual Status
InsertVectors(const std::string& table_id, size_t n, const float* vectors, IDNumbers& vector_ids) = 0;
InsertVectors(const std::string& table_id, VectorsData& vectors) = 0;
virtual Status
Serialize(std::set<std::string>& table_ids) = 0;

View File

@ -37,26 +37,25 @@ MemManagerImpl::GetMemByTable(const std::string& table_id) {
}
Status
MemManagerImpl::InsertVectors(const std::string& table_id_, size_t n_, const float* vectors_, IDNumbers& vector_ids_) {
MemManagerImpl::InsertVectors(const std::string& table_id, VectorsData& vectors) {
while (GetCurrentMem() > options_.insert_buffer_size_) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
std::unique_lock<std::mutex> lock(mutex_);
return InsertVectorsNoLock(table_id_, n_, vectors_, vector_ids_);
return InsertVectorsNoLock(table_id, vectors);
}
Status
MemManagerImpl::InsertVectorsNoLock(const std::string& table_id, size_t n, const float* vectors,
IDNumbers& vector_ids) {
MemManagerImpl::InsertVectorsNoLock(const std::string& table_id, VectorsData& vectors) {
MemTablePtr mem = GetMemByTable(table_id);
VectorSourcePtr source = std::make_shared<VectorSource>(n, vectors);
VectorSourcePtr source = std::make_shared<VectorSource>(vectors);
auto status = mem->Add(source, vector_ids);
auto status = mem->Add(source);
if (status.ok()) {
if (vector_ids.empty()) {
vector_ids = source->GetVectorIds();
if (vectors.id_array_.empty()) {
vectors.id_array_ = source->GetVectorIds();
}
}
return status;

View File

@ -41,7 +41,7 @@ class MemManagerImpl : public MemManager {
}
Status
InsertVectors(const std::string& table_id, size_t n, const float* vectors, IDNumbers& vector_ids) override;
InsertVectors(const std::string& table_id, VectorsData& vectors) override;
Status
Serialize(std::set<std::string>& table_ids) override;
@ -63,7 +63,7 @@ class MemManagerImpl : public MemManager {
GetMemByTable(const std::string& table_id);
Status
InsertVectorsNoLock(const std::string& table_id, size_t n, const float* vectors, IDNumbers& vector_ids);
InsertVectorsNoLock(const std::string& table_id, VectorsData& vectors);
Status
ToImmutable();

View File

@ -29,7 +29,7 @@ MemTable::MemTable(const std::string& table_id, const meta::MetaPtr& meta, const
}
Status
MemTable::Add(VectorSourcePtr& source, IDNumbers& vector_ids) {
MemTable::Add(VectorSourcePtr& source) {
while (!source->AllAdded()) {
MemTableFilePtr current_mem_table_file;
if (!mem_table_file_list_.empty()) {
@ -39,12 +39,12 @@ MemTable::Add(VectorSourcePtr& source, IDNumbers& vector_ids) {
Status status;
if (mem_table_file_list_.empty() || current_mem_table_file->IsFull()) {
MemTableFilePtr new_mem_table_file = std::make_shared<MemTableFile>(table_id_, meta_, options_);
status = new_mem_table_file->Add(source, vector_ids);
status = new_mem_table_file->Add(source);
if (status.ok()) {
mem_table_file_list_.emplace_back(new_mem_table_file);
}
} else {
status = current_mem_table_file->Add(source, vector_ids);
status = current_mem_table_file->Add(source);
}
if (!status.ok()) {

View File

@ -36,7 +36,7 @@ class MemTable {
MemTable(const std::string& table_id, const meta::MetaPtr& meta, const DBOptions& options);
Status
Add(VectorSourcePtr& source, IDNumbers& vector_ids);
Add(VectorSourcePtr& source);
void
GetCurrentMemTableFile(MemTableFilePtr& mem_table_file);

View File

@ -53,7 +53,7 @@ MemTableFile::CreateTableFile() {
}
Status
MemTableFile::Add(const VectorSourcePtr& source, IDNumbers& vector_ids) {
MemTableFile::Add(VectorSourcePtr& source) {
if (table_file_schema_.dimension_ <= 0) {
std::string err_msg =
"MemTableFile::Add: table_file_schema dimension = " + std::to_string(table_file_schema_.dimension_) +
@ -62,13 +62,12 @@ MemTableFile::Add(const VectorSourcePtr& source, IDNumbers& vector_ids) {
return Status(DB_ERROR, "Not able to create table file");
}
size_t single_vector_mem_size = table_file_schema_.dimension_ * VECTOR_TYPE_SIZE;
size_t single_vector_mem_size = source->SingleVectorSize(table_file_schema_.dimension_);
size_t mem_left = GetMemLeft();
if (mem_left >= single_vector_mem_size) {
size_t num_vectors_to_add = std::ceil(mem_left / single_vector_mem_size);
size_t num_vectors_added;
auto status =
source->Add(execution_engine_, table_file_schema_, num_vectors_to_add, num_vectors_added, vector_ids);
auto status = source->Add(execution_engine_, table_file_schema_, num_vectors_to_add, num_vectors_added);
if (status.ok()) {
current_mem_ += (num_vectors_added * single_vector_mem_size);
}
@ -89,7 +88,7 @@ MemTableFile::GetMemLeft() {
bool
MemTableFile::IsFull() {
size_t single_vector_mem_size = table_file_schema_.dimension_ * VECTOR_TYPE_SIZE;
size_t single_vector_mem_size = table_file_schema_.dimension_ * FLOAT_TYPE_SIZE;
return (GetMemLeft() < single_vector_mem_size);
}
@ -104,7 +103,8 @@ MemTableFile::Serialize() {
// if index type isn't IDMAP, set file type to TO_INDEX if file size execeed index_file_size
// else set file type to RAW, no need to build index
if (table_file_schema_.engine_type_ != (int)EngineType::FAISS_IDMAP) {
if (table_file_schema_.engine_type_ != (int)EngineType::FAISS_IDMAP &&
table_file_schema_.engine_type_ != (int)EngineType::FAISS_BIN_IDMAP) {
table_file_schema_.file_type_ = (size >= table_file_schema_.index_file_size_) ? meta::TableFileSchema::TO_INDEX
: meta::TableFileSchema::RAW;
} else {

View File

@ -33,7 +33,7 @@ class MemTableFile {
MemTableFile(const std::string& table_id, const meta::MetaPtr& meta, const DBOptions& options);
Status
Add(const VectorSourcePtr& source, IDNumbers& vector_ids);
Add(VectorSourcePtr& source);
size_t
GetCurrentMem();

View File

@ -24,30 +24,41 @@
namespace milvus {
namespace engine {
VectorSource::VectorSource(const size_t& n, const float* vectors)
: n_(n), vectors_(vectors), id_generator_(std::make_shared<SimpleIDGenerator>()) {
VectorSource::VectorSource(VectorsData& vectors)
: vectors_(vectors), id_generator_(std::make_shared<SimpleIDGenerator>()) {
current_num_vectors_added = 0;
}
Status
VectorSource::Add(const ExecutionEnginePtr& execution_engine, const meta::TableFileSchema& table_file_schema,
const size_t& num_vectors_to_add, size_t& num_vectors_added, IDNumbers& vector_ids) {
server::CollectAddMetrics metrics(n_, table_file_schema.dimension_);
const size_t& num_vectors_to_add, size_t& num_vectors_added) {
uint64_t n = vectors_.vector_count_;
server::CollectAddMetrics metrics(n, table_file_schema.dimension_);
num_vectors_added =
current_num_vectors_added + num_vectors_to_add <= n_ ? num_vectors_to_add : n_ - current_num_vectors_added;
current_num_vectors_added + num_vectors_to_add <= n ? num_vectors_to_add : n - current_num_vectors_added;
IDNumbers vector_ids_to_add;
if (vector_ids.empty()) {
if (vectors_.id_array_.empty()) {
id_generator_->GetNextIDNumbers(num_vectors_added, vector_ids_to_add);
} else {
vector_ids_to_add.resize(num_vectors_added);
for (int pos = current_num_vectors_added; pos < current_num_vectors_added + num_vectors_added; pos++) {
vector_ids_to_add[pos - current_num_vectors_added] = vector_ids[pos];
vector_ids_to_add[pos - current_num_vectors_added] = vectors_.id_array_[pos];
}
}
Status status = execution_engine->AddWithIds(num_vectors_added,
vectors_ + current_num_vectors_added * table_file_schema.dimension_,
vector_ids_to_add.data());
Status status;
if (!vectors_.float_data_.empty()) {
status = execution_engine->AddWithIds(
num_vectors_added, vectors_.float_data_.data() + current_num_vectors_added * table_file_schema.dimension_,
vector_ids_to_add.data());
} else if (!vectors_.binary_data_.empty()) {
status = execution_engine->AddWithIds(
num_vectors_added,
vectors_.binary_data_.data() + current_num_vectors_added * SingleVectorSize(table_file_schema.dimension_),
vector_ids_to_add.data());
}
if (status.ok()) {
current_num_vectors_added += num_vectors_added;
vector_ids_.insert(vector_ids_.end(), std::make_move_iterator(vector_ids_to_add.begin()),
@ -64,9 +75,20 @@ VectorSource::GetNumVectorsAdded() {
return current_num_vectors_added;
}
size_t
VectorSource::SingleVectorSize(uint16_t dimension) {
if (!vectors_.float_data_.empty()) {
return dimension * FLOAT_TYPE_SIZE;
} else if (!vectors_.binary_data_.empty()) {
return dimension / 8;
}
return 0;
}
bool
VectorSource::AllAdded() {
return (current_num_vectors_added == n_);
return (current_num_vectors_added == vectors_.vector_count_);
}
IDNumbers

View File

@ -29,15 +29,18 @@ namespace engine {
class VectorSource {
public:
VectorSource(const size_t& n, const float* vectors);
explicit VectorSource(VectorsData& vectors);
Status
Add(const ExecutionEnginePtr& execution_engine, const meta::TableFileSchema& table_file_schema,
const size_t& num_vectors_to_add, size_t& num_vectors_added, IDNumbers& vector_ids);
const size_t& num_vectors_to_add, size_t& num_vectors_added);
size_t
GetNumVectorsAdded();
size_t
SingleVectorSize(uint16_t dimension);
bool
AllAdded();
@ -45,8 +48,7 @@ class VectorSource {
GetVectorIds();
private:
const size_t n_;
const float* vectors_;
VectorsData& vectors_;
IDNumbers vector_ids_;
size_t current_num_vectors_added;

View File

@ -462,7 +462,8 @@ const ::PROTOBUF_NAMESPACE_ID::uint32 TableStruct_milvus_2eproto::offsets[] PROT
~0u, // no _extensions_
~0u, // no _oneof_case_
~0u, // no _weak_field_map_
PROTOBUF_FIELD_OFFSET(::milvus::grpc::RowRecord, vector_data_),
PROTOBUF_FIELD_OFFSET(::milvus::grpc::RowRecord, float_data_),
PROTOBUF_FIELD_OFFSET(::milvus::grpc::RowRecord, binary_data_),
~0u, // no _has_bits_
PROTOBUF_FIELD_OFFSET(::milvus::grpc::InsertParam, _internal_metadata_),
~0u, // no _extensions_
@ -565,18 +566,18 @@ static const ::PROTOBUF_NAMESPACE_ID::internal::MigrationSchema schemas[] PROTOB
{ 37, -1, sizeof(::milvus::grpc::PartitionList)},
{ 44, -1, sizeof(::milvus::grpc::Range)},
{ 51, -1, sizeof(::milvus::grpc::RowRecord)},
{ 57, -1, sizeof(::milvus::grpc::InsertParam)},
{ 66, -1, sizeof(::milvus::grpc::VectorIds)},
{ 73, -1, sizeof(::milvus::grpc::SearchParam)},
{ 84, -1, sizeof(::milvus::grpc::SearchInFilesParam)},
{ 91, -1, sizeof(::milvus::grpc::TopKQueryResult)},
{ 100, -1, sizeof(::milvus::grpc::StringReply)},
{ 107, -1, sizeof(::milvus::grpc::BoolReply)},
{ 114, -1, sizeof(::milvus::grpc::TableRowCount)},
{ 121, -1, sizeof(::milvus::grpc::Command)},
{ 127, -1, sizeof(::milvus::grpc::Index)},
{ 134, -1, sizeof(::milvus::grpc::IndexParam)},
{ 142, -1, sizeof(::milvus::grpc::DeleteByDateParam)},
{ 58, -1, sizeof(::milvus::grpc::InsertParam)},
{ 67, -1, sizeof(::milvus::grpc::VectorIds)},
{ 74, -1, sizeof(::milvus::grpc::SearchParam)},
{ 85, -1, sizeof(::milvus::grpc::SearchInFilesParam)},
{ 92, -1, sizeof(::milvus::grpc::TopKQueryResult)},
{ 101, -1, sizeof(::milvus::grpc::StringReply)},
{ 108, -1, sizeof(::milvus::grpc::BoolReply)},
{ 115, -1, sizeof(::milvus::grpc::TableRowCount)},
{ 122, -1, sizeof(::milvus::grpc::Command)},
{ 128, -1, sizeof(::milvus::grpc::Index)},
{ 135, -1, sizeof(::milvus::grpc::IndexParam)},
{ 143, -1, sizeof(::milvus::grpc::DeleteByDateParam)},
};
static ::PROTOBUF_NAMESPACE_ID::Message const * const file_default_instances[] = {
@ -617,65 +618,65 @@ const char descriptor_table_protodef_milvus_2eproto[] PROTOBUF_SECTION_VARIABLE(
"ilvus.grpc.Status\0224\n\017partition_array\030\002 \003"
"(\0132\033.milvus.grpc.PartitionParam\"/\n\005Range"
"\022\023\n\013start_value\030\001 \001(\t\022\021\n\tend_value\030\002 \001(\t"
"\" \n\tRowRecord\022\023\n\013vector_data\030\001 \003(\002\"\200\001\n\013I"
"nsertParam\022\022\n\ntable_name\030\001 \001(\t\0220\n\020row_re"
"cord_array\030\002 \003(\0132\026.milvus.grpc.RowRecord"
"\022\024\n\014row_id_array\030\003 \003(\003\022\025\n\rpartition_tag\030"
"\004 \001(\t\"I\n\tVectorIds\022#\n\006status\030\001 \001(\0132\023.mil"
"vus.grpc.Status\022\027\n\017vector_id_array\030\002 \003(\003"
"\"\277\001\n\013SearchParam\022\022\n\ntable_name\030\001 \001(\t\0222\n\022"
"query_record_array\030\002 \003(\0132\026.milvus.grpc.R"
"owRecord\022-\n\021query_range_array\030\003 \003(\0132\022.mi"
"lvus.grpc.Range\022\014\n\004topk\030\004 \001(\003\022\016\n\006nprobe\030"
"\005 \001(\003\022\033\n\023partition_tag_array\030\006 \003(\t\"[\n\022Se"
"archInFilesParam\022\025\n\rfile_id_array\030\001 \003(\t\022"
".\n\014search_param\030\002 \001(\0132\030.milvus.grpc.Sear"
"chParam\"g\n\017TopKQueryResult\022#\n\006status\030\001 \001"
"(\0132\023.milvus.grpc.Status\022\017\n\007row_num\030\002 \001(\003"
"\022\013\n\003ids\030\003 \003(\003\022\021\n\tdistances\030\004 \003(\002\"H\n\013Stri"
"ngReply\022#\n\006status\030\001 \001(\0132\023.milvus.grpc.St"
"atus\022\024\n\014string_reply\030\002 \001(\t\"D\n\tBoolReply\022"
"#\n\006status\030\001 \001(\0132\023.milvus.grpc.Status\022\022\n\n"
"bool_reply\030\002 \001(\010\"M\n\rTableRowCount\022#\n\006sta"
"tus\030\001 \001(\0132\023.milvus.grpc.Status\022\027\n\017table_"
"row_count\030\002 \001(\003\"\026\n\007Command\022\013\n\003cmd\030\001 \001(\t\""
"*\n\005Index\022\022\n\nindex_type\030\001 \001(\005\022\r\n\005nlist\030\002 "
"\001(\005\"h\n\nIndexParam\022#\n\006status\030\001 \001(\0132\023.milv"
"us.grpc.Status\022\022\n\ntable_name\030\002 \001(\t\022!\n\005in"
"dex\030\003 \001(\0132\022.milvus.grpc.Index\"J\n\021DeleteB"
"yDateParam\022!\n\005range\030\001 \001(\0132\022.milvus.grpc."
"Range\022\022\n\ntable_name\030\002 \001(\t2\272\t\n\rMilvusServ"
"ice\022>\n\013CreateTable\022\030.milvus.grpc.TableSc"
"hema\032\023.milvus.grpc.Status\"\000\022<\n\010HasTable\022"
"\026.milvus.grpc.TableName\032\026.milvus.grpc.Bo"
"olReply\"\000\022C\n\rDescribeTable\022\026.milvus.grpc"
".TableName\032\030.milvus.grpc.TableSchema\"\000\022B"
"\n\nCountTable\022\026.milvus.grpc.TableName\032\032.m"
"ilvus.grpc.TableRowCount\"\000\022@\n\nShowTables"
"\022\024.milvus.grpc.Command\032\032.milvus.grpc.Tab"
"leNameList\"\000\022:\n\tDropTable\022\026.milvus.grpc."
"TableName\032\023.milvus.grpc.Status\"\000\022=\n\013Crea"
"teIndex\022\027.milvus.grpc.IndexParam\032\023.milvu"
"s.grpc.Status\"\000\022B\n\rDescribeIndex\022\026.milvu"
"s.grpc.TableName\032\027.milvus.grpc.IndexPara"
"m\"\000\022:\n\tDropIndex\022\026.milvus.grpc.TableName"
"\032\023.milvus.grpc.Status\"\000\022E\n\017CreatePartiti"
"on\022\033.milvus.grpc.PartitionParam\032\023.milvus"
".grpc.Status\"\000\022F\n\016ShowPartitions\022\026.milvu"
"s.grpc.TableName\032\032.milvus.grpc.Partition"
"List\"\000\022C\n\rDropPartition\022\033.milvus.grpc.Pa"
"rtitionParam\032\023.milvus.grpc.Status\"\000\022<\n\006I"
"nsert\022\030.milvus.grpc.InsertParam\032\026.milvus"
".grpc.VectorIds\"\000\022B\n\006Search\022\030.milvus.grp"
"c.SearchParam\032\034.milvus.grpc.TopKQueryRes"
"ult\"\000\022P\n\rSearchInFiles\022\037.milvus.grpc.Sea"
"rchInFilesParam\032\034.milvus.grpc.TopKQueryR"
"esult\"\000\0227\n\003Cmd\022\024.milvus.grpc.Command\032\030.m"
"ilvus.grpc.StringReply\"\000\022E\n\014DeleteByDate"
"\022\036.milvus.grpc.DeleteByDateParam\032\023.milvu"
"s.grpc.Status\"\000\022=\n\014PreloadTable\022\026.milvus"
".grpc.TableName\032\023.milvus.grpc.Status\"\000b\006"
"proto3"
"\"4\n\tRowRecord\022\022\n\nfloat_data\030\001 \003(\002\022\023\n\013bin"
"ary_data\030\002 \001(\014\"\200\001\n\013InsertParam\022\022\n\ntable_"
"name\030\001 \001(\t\0220\n\020row_record_array\030\002 \003(\0132\026.m"
"ilvus.grpc.RowRecord\022\024\n\014row_id_array\030\003 \003"
"(\003\022\025\n\rpartition_tag\030\004 \001(\t\"I\n\tVectorIds\022#"
"\n\006status\030\001 \001(\0132\023.milvus.grpc.Status\022\027\n\017v"
"ector_id_array\030\002 \003(\003\"\277\001\n\013SearchParam\022\022\n\n"
"table_name\030\001 \001(\t\0222\n\022query_record_array\030\002"
" \003(\0132\026.milvus.grpc.RowRecord\022-\n\021query_ra"
"nge_array\030\003 \003(\0132\022.milvus.grpc.Range\022\014\n\004t"
"opk\030\004 \001(\003\022\016\n\006nprobe\030\005 \001(\003\022\033\n\023partition_t"
"ag_array\030\006 \003(\t\"[\n\022SearchInFilesParam\022\025\n\r"
"file_id_array\030\001 \003(\t\022.\n\014search_param\030\002 \001("
"\0132\030.milvus.grpc.SearchParam\"g\n\017TopKQuery"
"Result\022#\n\006status\030\001 \001(\0132\023.milvus.grpc.Sta"
"tus\022\017\n\007row_num\030\002 \001(\003\022\013\n\003ids\030\003 \003(\003\022\021\n\tdis"
"tances\030\004 \003(\002\"H\n\013StringReply\022#\n\006status\030\001 "
"\001(\0132\023.milvus.grpc.Status\022\024\n\014string_reply"
"\030\002 \001(\t\"D\n\tBoolReply\022#\n\006status\030\001 \001(\0132\023.mi"
"lvus.grpc.Status\022\022\n\nbool_reply\030\002 \001(\010\"M\n\r"
"TableRowCount\022#\n\006status\030\001 \001(\0132\023.milvus.g"
"rpc.Status\022\027\n\017table_row_count\030\002 \001(\003\"\026\n\007C"
"ommand\022\013\n\003cmd\030\001 \001(\t\"*\n\005Index\022\022\n\nindex_ty"
"pe\030\001 \001(\005\022\r\n\005nlist\030\002 \001(\005\"h\n\nIndexParam\022#\n"
"\006status\030\001 \001(\0132\023.milvus.grpc.Status\022\022\n\nta"
"ble_name\030\002 \001(\t\022!\n\005index\030\003 \001(\0132\022.milvus.g"
"rpc.Index\"J\n\021DeleteByDateParam\022!\n\005range\030"
"\001 \001(\0132\022.milvus.grpc.Range\022\022\n\ntable_name\030"
"\002 \001(\t2\272\t\n\rMilvusService\022>\n\013CreateTable\022\030"
".milvus.grpc.TableSchema\032\023.milvus.grpc.S"
"tatus\"\000\022<\n\010HasTable\022\026.milvus.grpc.TableN"
"ame\032\026.milvus.grpc.BoolReply\"\000\022C\n\rDescrib"
"eTable\022\026.milvus.grpc.TableName\032\030.milvus."
"grpc.TableSchema\"\000\022B\n\nCountTable\022\026.milvu"
"s.grpc.TableName\032\032.milvus.grpc.TableRowC"
"ount\"\000\022@\n\nShowTables\022\024.milvus.grpc.Comma"
"nd\032\032.milvus.grpc.TableNameList\"\000\022:\n\tDrop"
"Table\022\026.milvus.grpc.TableName\032\023.milvus.g"
"rpc.Status\"\000\022=\n\013CreateIndex\022\027.milvus.grp"
"c.IndexParam\032\023.milvus.grpc.Status\"\000\022B\n\rD"
"escribeIndex\022\026.milvus.grpc.TableName\032\027.m"
"ilvus.grpc.IndexParam\"\000\022:\n\tDropIndex\022\026.m"
"ilvus.grpc.TableName\032\023.milvus.grpc.Statu"
"s\"\000\022E\n\017CreatePartition\022\033.milvus.grpc.Par"
"titionParam\032\023.milvus.grpc.Status\"\000\022F\n\016Sh"
"owPartitions\022\026.milvus.grpc.TableName\032\032.m"
"ilvus.grpc.PartitionList\"\000\022C\n\rDropPartit"
"ion\022\033.milvus.grpc.PartitionParam\032\023.milvu"
"s.grpc.Status\"\000\022<\n\006Insert\022\030.milvus.grpc."
"InsertParam\032\026.milvus.grpc.VectorIds\"\000\022B\n"
"\006Search\022\030.milvus.grpc.SearchParam\032\034.milv"
"us.grpc.TopKQueryResult\"\000\022P\n\rSearchInFil"
"es\022\037.milvus.grpc.SearchInFilesParam\032\034.mi"
"lvus.grpc.TopKQueryResult\"\000\0227\n\003Cmd\022\024.mil"
"vus.grpc.Command\032\030.milvus.grpc.StringRep"
"ly\"\000\022E\n\014DeleteByDate\022\036.milvus.grpc.Delet"
"eByDateParam\032\023.milvus.grpc.Status\"\000\022=\n\014P"
"reloadTable\022\026.milvus.grpc.TableName\032\023.mi"
"lvus.grpc.Status\"\000b\006proto3"
;
static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_milvus_2eproto_deps[1] = {
&::descriptor_table_status_2eproto,
@ -705,7 +706,7 @@ static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_mil
static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_milvus_2eproto_once;
static bool descriptor_table_milvus_2eproto_initialized = false;
const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_milvus_2eproto = {
&descriptor_table_milvus_2eproto_initialized, descriptor_table_protodef_milvus_2eproto, "milvus.proto", 2886,
&descriptor_table_milvus_2eproto_initialized, descriptor_table_protodef_milvus_2eproto, "milvus.proto", 2906,
&descriptor_table_milvus_2eproto_once, descriptor_table_milvus_2eproto_sccs, descriptor_table_milvus_2eproto_deps, 20, 1,
schemas, file_default_instances, TableStruct_milvus_2eproto::offsets,
file_level_metadata_milvus_2eproto, 20, file_level_enum_descriptors_milvus_2eproto, file_level_service_descriptors_milvus_2eproto,
@ -3122,12 +3123,18 @@ RowRecord::RowRecord()
RowRecord::RowRecord(const RowRecord& from)
: ::PROTOBUF_NAMESPACE_ID::Message(),
_internal_metadata_(nullptr),
vector_data_(from.vector_data_) {
float_data_(from.float_data_) {
_internal_metadata_.MergeFrom(from._internal_metadata_);
binary_data_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
if (!from.binary_data().empty()) {
binary_data_.AssignWithDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from.binary_data_);
}
// @@protoc_insertion_point(copy_constructor:milvus.grpc.RowRecord)
}
void RowRecord::SharedCtor() {
::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_RowRecord_milvus_2eproto.base);
binary_data_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
}
RowRecord::~RowRecord() {
@ -3136,6 +3143,7 @@ RowRecord::~RowRecord() {
}
void RowRecord::SharedDtor() {
binary_data_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
}
void RowRecord::SetCachedSize(int size) const {
@ -3153,7 +3161,8 @@ void RowRecord::Clear() {
// Prevent compiler warnings about cached_has_bits being unused
(void) cached_has_bits;
vector_data_.Clear();
float_data_.Clear();
binary_data_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
_internal_metadata_.Clear();
}
@ -3165,16 +3174,23 @@ const char* RowRecord::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::
ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag);
CHK_(ptr);
switch (tag >> 3) {
// repeated float vector_data = 1;
// repeated float float_data = 1;
case 1:
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) {
ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(mutable_vector_data(), ptr, ctx);
ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(mutable_float_data(), ptr, ctx);
CHK_(ptr);
} else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 13) {
add_vector_data(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad<float>(ptr));
add_float_data(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad<float>(ptr));
ptr += sizeof(float);
} else goto handle_unusual;
continue;
// bytes binary_data = 2;
case 2:
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) {
ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(mutable_binary_data(), ptr, ctx);
CHK_(ptr);
} else goto handle_unusual;
continue;
default: {
handle_unusual:
if ((tag & 7) == 4 || tag == 0) {
@ -3205,16 +3221,27 @@ bool RowRecord::MergePartialFromCodedStream(
tag = p.first;
if (!p.second) goto handle_unusual;
switch (::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::GetTagFieldNumber(tag)) {
// repeated float vector_data = 1;
// repeated float float_data = 1;
case 1: {
if (static_cast< ::PROTOBUF_NAMESPACE_ID::uint8>(tag) == (10 & 0xFF)) {
DO_((::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::ReadPackedPrimitive<
float, ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_FLOAT>(
input, this->mutable_vector_data())));
input, this->mutable_float_data())));
} else if (static_cast< ::PROTOBUF_NAMESPACE_ID::uint8>(tag) == (13 & 0xFF)) {
DO_((::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::ReadRepeatedPrimitiveNoInline<
float, ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_FLOAT>(
1, 10u, input, this->mutable_vector_data())));
1, 10u, input, this->mutable_float_data())));
} else {
goto handle_unusual;
}
break;
}
// bytes binary_data = 2;
case 2: {
if (static_cast< ::PROTOBUF_NAMESPACE_ID::uint8>(tag) == (18 & 0xFF)) {
DO_(::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::ReadBytes(
input, this->mutable_binary_data()));
} else {
goto handle_unusual;
}
@ -3248,13 +3275,19 @@ void RowRecord::SerializeWithCachedSizes(
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
(void) cached_has_bits;
// repeated float vector_data = 1;
if (this->vector_data_size() > 0) {
// repeated float float_data = 1;
if (this->float_data_size() > 0) {
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteTag(1, ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED, output);
output->WriteVarint32(_vector_data_cached_byte_size_.load(
output->WriteVarint32(_float_data_cached_byte_size_.load(
std::memory_order_relaxed));
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatArray(
this->vector_data().data(), this->vector_data_size(), output);
this->float_data().data(), this->float_data_size(), output);
}
// bytes binary_data = 2;
if (this->binary_data().size() > 0) {
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBytesMaybeAliased(
2, this->binary_data(), output);
}
if (_internal_metadata_.have_unknown_fields()) {
@ -3270,17 +3303,24 @@ void RowRecord::SerializeWithCachedSizes(
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
(void) cached_has_bits;
// repeated float vector_data = 1;
if (this->vector_data_size() > 0) {
// repeated float float_data = 1;
if (this->float_data_size() > 0) {
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteTagToArray(
1,
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
target);
target = ::PROTOBUF_NAMESPACE_ID::io::CodedOutputStream::WriteVarint32ToArray(
_vector_data_cached_byte_size_.load(std::memory_order_relaxed),
_float_data_cached_byte_size_.load(std::memory_order_relaxed),
target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::
WriteFloatNoTagToArray(this->vector_data_, target);
WriteFloatNoTagToArray(this->float_data_, target);
}
// bytes binary_data = 2;
if (this->binary_data().size() > 0) {
target =
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBytesToArray(
2, this->binary_data(), target);
}
if (_internal_metadata_.have_unknown_fields()) {
@ -3304,9 +3344,9 @@ size_t RowRecord::ByteSizeLong() const {
// Prevent compiler warnings about cached_has_bits being unused
(void) cached_has_bits;
// repeated float vector_data = 1;
// repeated float float_data = 1;
{
unsigned int count = static_cast<unsigned int>(this->vector_data_size());
unsigned int count = static_cast<unsigned int>(this->float_data_size());
size_t data_size = 4UL * count;
if (data_size > 0) {
total_size += 1 +
@ -3314,11 +3354,18 @@ size_t RowRecord::ByteSizeLong() const {
static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size));
}
int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size);
_vector_data_cached_byte_size_.store(cached_size,
_float_data_cached_byte_size_.store(cached_size,
std::memory_order_relaxed);
total_size += data_size;
}
// bytes binary_data = 2;
if (this->binary_data().size() > 0) {
total_size += 1 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize(
this->binary_data());
}
int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size);
SetCachedSize(cached_size);
return total_size;
@ -3346,7 +3393,11 @@ void RowRecord::MergeFrom(const RowRecord& from) {
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
(void) cached_has_bits;
vector_data_.MergeFrom(from.vector_data_);
float_data_.MergeFrom(from.float_data_);
if (from.binary_data().size() > 0) {
binary_data_.AssignWithDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from.binary_data_);
}
}
void RowRecord::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) {
@ -3370,7 +3421,9 @@ bool RowRecord::IsInitialized() const {
void RowRecord::InternalSwap(RowRecord* other) {
using std::swap;
_internal_metadata_.Swap(&other->_internal_metadata_);
vector_data_.InternalSwap(&other->vector_data_);
float_data_.InternalSwap(&other->float_data_);
binary_data_.Swap(&other->binary_data_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(),
GetArenaNoVirtual());
}
::PROTOBUF_NAMESPACE_ID::Metadata RowRecord::GetMetadata() const {

View File

@ -1314,26 +1314,39 @@ class RowRecord :
// accessors -------------------------------------------------------
enum : int {
kVectorDataFieldNumber = 1,
kFloatDataFieldNumber = 1,
kBinaryDataFieldNumber = 2,
};
// repeated float vector_data = 1;
int vector_data_size() const;
void clear_vector_data();
float vector_data(int index) const;
void set_vector_data(int index, float value);
void add_vector_data(float value);
// repeated float float_data = 1;
int float_data_size() const;
void clear_float_data();
float float_data(int index) const;
void set_float_data(int index, float value);
void add_float_data(float value);
const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >&
vector_data() const;
float_data() const;
::PROTOBUF_NAMESPACE_ID::RepeatedField< float >*
mutable_vector_data();
mutable_float_data();
// bytes binary_data = 2;
void clear_binary_data();
const std::string& binary_data() const;
void set_binary_data(const std::string& value);
void set_binary_data(std::string&& value);
void set_binary_data(const char* value);
void set_binary_data(const void* value, size_t size);
std::string* mutable_binary_data();
std::string* release_binary_data();
void set_allocated_binary_data(std::string* binary_data);
// @@protoc_insertion_point(class_scope:milvus.grpc.RowRecord)
private:
class _Internal;
::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_;
::PROTOBUF_NAMESPACE_ID::RepeatedField< float > vector_data_;
mutable std::atomic<int> _vector_data_cached_byte_size_;
::PROTOBUF_NAMESPACE_ID::RepeatedField< float > float_data_;
mutable std::atomic<int> _float_data_cached_byte_size_;
::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr binary_data_;
mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_;
friend struct ::TableStruct_milvus_2eproto;
};
@ -3907,34 +3920,85 @@ inline void Range::set_allocated_end_value(std::string* end_value) {
// RowRecord
// repeated float vector_data = 1;
inline int RowRecord::vector_data_size() const {
return vector_data_.size();
// repeated float float_data = 1;
inline int RowRecord::float_data_size() const {
return float_data_.size();
}
inline void RowRecord::clear_vector_data() {
vector_data_.Clear();
inline void RowRecord::clear_float_data() {
float_data_.Clear();
}
inline float RowRecord::vector_data(int index) const {
// @@protoc_insertion_point(field_get:milvus.grpc.RowRecord.vector_data)
return vector_data_.Get(index);
inline float RowRecord::float_data(int index) const {
// @@protoc_insertion_point(field_get:milvus.grpc.RowRecord.float_data)
return float_data_.Get(index);
}
inline void RowRecord::set_vector_data(int index, float value) {
vector_data_.Set(index, value);
// @@protoc_insertion_point(field_set:milvus.grpc.RowRecord.vector_data)
inline void RowRecord::set_float_data(int index, float value) {
float_data_.Set(index, value);
// @@protoc_insertion_point(field_set:milvus.grpc.RowRecord.float_data)
}
inline void RowRecord::add_vector_data(float value) {
vector_data_.Add(value);
// @@protoc_insertion_point(field_add:milvus.grpc.RowRecord.vector_data)
inline void RowRecord::add_float_data(float value) {
float_data_.Add(value);
// @@protoc_insertion_point(field_add:milvus.grpc.RowRecord.float_data)
}
inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >&
RowRecord::vector_data() const {
// @@protoc_insertion_point(field_list:milvus.grpc.RowRecord.vector_data)
return vector_data_;
RowRecord::float_data() const {
// @@protoc_insertion_point(field_list:milvus.grpc.RowRecord.float_data)
return float_data_;
}
inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >*
RowRecord::mutable_vector_data() {
// @@protoc_insertion_point(field_mutable_list:milvus.grpc.RowRecord.vector_data)
return &vector_data_;
RowRecord::mutable_float_data() {
// @@protoc_insertion_point(field_mutable_list:milvus.grpc.RowRecord.float_data)
return &float_data_;
}
// bytes binary_data = 2;
inline void RowRecord::clear_binary_data() {
binary_data_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
}
inline const std::string& RowRecord::binary_data() const {
// @@protoc_insertion_point(field_get:milvus.grpc.RowRecord.binary_data)
return binary_data_.GetNoArena();
}
inline void RowRecord::set_binary_data(const std::string& value) {
binary_data_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value);
// @@protoc_insertion_point(field_set:milvus.grpc.RowRecord.binary_data)
}
inline void RowRecord::set_binary_data(std::string&& value) {
binary_data_.SetNoArena(
&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value));
// @@protoc_insertion_point(field_set_rvalue:milvus.grpc.RowRecord.binary_data)
}
inline void RowRecord::set_binary_data(const char* value) {
GOOGLE_DCHECK(value != nullptr);
binary_data_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value));
// @@protoc_insertion_point(field_set_char:milvus.grpc.RowRecord.binary_data)
}
inline void RowRecord::set_binary_data(const void* value, size_t size) {
binary_data_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(),
::std::string(reinterpret_cast<const char*>(value), size));
// @@protoc_insertion_point(field_set_pointer:milvus.grpc.RowRecord.binary_data)
}
inline std::string* RowRecord::mutable_binary_data() {
// @@protoc_insertion_point(field_mutable:milvus.grpc.RowRecord.binary_data)
return binary_data_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
}
inline std::string* RowRecord::release_binary_data() {
// @@protoc_insertion_point(field_release:milvus.grpc.RowRecord.binary_data)
return binary_data_.ReleaseNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
}
inline void RowRecord::set_allocated_binary_data(std::string* binary_data) {
if (binary_data != nullptr) {
} else {
}
binary_data_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), binary_data);
// @@protoc_insertion_point(field_set_allocated:milvus.grpc.RowRecord.binary_data)
}
// -------------------------------------------------------------------

View File

@ -66,7 +66,8 @@ message Range {
* @brief Record inserted
*/
message RowRecord {
repeated float vector_data = 1; //binary vector data
repeated float float_data = 1; //float vector data
bytes binary_data = 2; //binary vector data
}
/**

View File

@ -32,6 +32,9 @@ set(index_srcs
knowhere/index/vector_index/IndexSPTAG.cpp
knowhere/index/vector_index/IndexIDMAP.cpp
knowhere/index/vector_index/IndexIVF.cpp
knowhere/index/vector_index/IndexBinaryIVF.cpp
knowhere/index/vector_index/FaissBaseBinaryIndex.cpp
knowhere/index/vector_index/IndexBinaryIDMAP.cpp
knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp
knowhere/index/vector_index/IndexNSG.cpp
knowhere/index/vector_index/nsg/NSG.cpp

View File

@ -35,4 +35,9 @@ extern const char* DISTANCE;
auto rows = dataset->Get<int64_t>(meta::ROWS); \
auto p_data = dataset->Get<const float*>(meta::TENSOR);
#define GETBINARYTENSOR(dataset) \
auto dim = dataset->Get<int64_t>(meta::DIM); \
auto rows = dataset->Get<int64_t>(meta::ROWS); \
auto p_data = dataset->Get<const uint8_t*>(meta::TENSOR);
} // namespace knowhere

View File

@ -20,6 +20,7 @@
#include <memory>
#include <sstream>
#include "Log.h"
#include "knowhere/common/Exception.h"
namespace knowhere {
@ -27,6 +28,9 @@ enum class METRICTYPE {
INVALID = 0,
L2 = 1,
IP = 2,
HAMMING = 20,
JACCARD = 21,
TANIMOTO = 22,
};
// General Config
@ -50,7 +54,13 @@ struct Cfg {
virtual bool
CheckValid() {
    // Base config accepts only the float metrics; the unconditional
    // `return true;` that previously preceded this check made the
    // validation below dead code.
    if (metric_type == METRICTYPE::IP || metric_type == METRICTYPE::L2) {
        return true;
    }
    // Binary metrics (HAMMING/JACCARD/TANIMOTO) must be validated by a
    // derived config; reaching here with any other metric is an error.
    std::stringstream ss;
    ss << "MetricType: " << int(metric_type) << " not support!";
    KNOWHERE_THROW_MSG(ss.str());
    return false;  // unreachable; KNOWHERE_THROW_MSG throws
}
void

View File

@ -0,0 +1,65 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <faiss/index_io.h>
#include <utility>
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
namespace knowhere {
// Takes shared ownership of a prebuilt faiss binary index.
FaissBaseBinaryIndex::FaissBaseBinaryIndex(std::shared_ptr<faiss::IndexBinary> index) : index_(std::move(index)) {
}
BinarySet
FaissBaseBinaryIndex::SerializeImpl() {
    // Serialize the owned faiss binary index into an in-memory buffer and
    // wrap it in a BinarySet under the key "BinaryIVF".
    // Throws (via KNOWHERE_THROW_MSG) if faiss serialization fails.
    try {
        faiss::IndexBinary* index = index_.get();
        // SealImpl();

        MemoryIOWriter writer;
        faiss::write_index_binary(index, &writer);

        // Adopt the writer's buffer directly. The previous code called
        // std::make_shared<uint8_t>() and immediately reset() it, wasting a
        // one-byte heap allocation; the deleter semantics are unchanged.
        // NOTE(review): the default deleter uses `delete` — confirm
        // MemoryIOWriter allocates data_ compatibly.
        std::shared_ptr<uint8_t> data(writer.data_);

        BinarySet res_set;
        // TODO(linxj): use virtual func Name() instead of raw string.
        res_set.Append("BinaryIVF", data, writer.rp);
        return res_set;
    } catch (std::exception& e) {
        KNOWHERE_THROW_MSG(e.what());
    }
}
void
FaissBaseBinaryIndex::LoadImpl(const BinarySet& index_binary) {
    // Rebuild the faiss index from the serialized "BinaryIVF" blob.
    // The reader only borrows the blob's buffer; faiss copies what it needs.
    auto blob = index_binary.GetByName("BinaryIVF");

    MemoryIOReader reader;
    reader.total = blob->size;
    reader.data_ = blob->data.get();

    index_.reset(faiss::read_index_binary(&reader));
}
} // namespace knowhere

View File

@ -0,0 +1,43 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <faiss/IndexBinary.h>
#include "knowhere/common/BinarySet.h"
#include "knowhere/common/Dataset.h"
namespace knowhere {
// Thin RAII-style wrapper shared by binary faiss index adapters: holds the
// underlying faiss::IndexBinary and provides (de)serialization to BinarySet.
class FaissBaseBinaryIndex {
 protected:
    // Takes shared ownership of a prebuilt faiss binary index.
    explicit FaissBaseBinaryIndex(std::shared_ptr<faiss::IndexBinary> index);

    // Serialize index_ into a BinarySet (stored under key "BinaryIVF").
    virtual BinarySet
    SerializeImpl();

    // Restore index_ from a BinarySet produced by SerializeImpl().
    virtual void
    LoadImpl(const BinarySet& index_binary);

 public:
    // Underlying faiss binary index; shared with derived classes/callers.
    std::shared_ptr<faiss::IndexBinary> index_ = nullptr;
};
} // namespace knowhere

View File

@ -0,0 +1,146 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <faiss/IndexBinaryFlat.h>
#include <faiss/MetaIndexes.h>
#include <faiss/index_factory.h>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"
namespace knowhere {
BinarySet
BinaryIDMAP::Serialize() {
    // Serialize the flat binary index; requires Train()/Load() to have run.
    if (index_ == nullptr) {
        KNOWHERE_THROW_MSG("index not initialize");
    }

    std::lock_guard<std::mutex> guard(mutex_);
    return SerializeImpl();
}
void
BinaryIDMAP::Load(const BinarySet& index_binary) {
    // Deserialize under the index lock to keep index_ swaps atomic.
    std::lock_guard<std::mutex> guard(mutex_);
    LoadImpl(index_binary);
}
// Brute-force k-NN search over the stored binary vectors.
// Returns a Dataset carrying IDS (int64[rows*k]) and DISTANCE
// (float[rows*k]) buffers. The buffers are raw malloc'd pointers handed to
// the Dataset — presumably the Dataset takes ownership; TODO confirm.
DatasetPtr
BinaryIDMAP::Search(const DatasetPtr& dataset, const Config& config) {
    if (!index_) {
        KNOWHERE_THROW_MSG("index not initialize");
    }
    // Binds `dim`, `rows`, `p_data` from the dataset (see GETBINARYTENSOR).
    GETBINARYTENSOR(dataset)
    auto elems = rows * config->k;
    size_t p_id_size = sizeof(int64_t) * elems;
    size_t p_dist_size = sizeof(float) * elems;
    auto p_id = (int64_t*)malloc(p_id_size);
    auto p_dist = (float*)malloc(p_dist_size);
    // search_impl writes raw int32 distances through the float* buffer
    // (see the cast inside search_impl).
    search_impl(rows, (uint8_t*)p_data, config->k, p_dist, p_id, Config());
    auto ret_ds = std::make_shared<Dataset>();
    if (index_->metric_type == faiss::METRIC_Hamming) {
        // Hamming distances arrive as int32; widen them into a fresh float
        // buffer so callers always receive float distances.
        auto pf_dist = (float*)malloc(p_dist_size);
        int32_t* pi_dist = (int32_t*)p_dist;
        for (int i = 0; i < elems; i++) {
            *(pf_dist + i) = (float)(*(pi_dist + i));
        }
        ret_ds->Set(meta::IDS, p_id);
        ret_ds->Set(meta::DISTANCE, pf_dist);
        free(p_dist);  // int32-laden buffer no longer needed
    } else {
        // Jaccard/Tanimoto paths already produce float distances in place.
        ret_ds->Set(meta::IDS, p_id);
        ret_ds->Set(meta::DISTANCE, p_dist);
    }
    return ret_ds;
}
// Raw search wrapper. faiss binary indexes report int32 distances, so the
// caller's float* buffer is aliased as int32* to avoid an extra allocation;
// callers must reinterpret the results (Search() converts them for hamming).
void
BinaryIDMAP::search_impl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
                         const Config& cfg) {
    int32_t* pdistances = (int32_t*)distances;
    index_->search(n, (uint8_t*)data, k, pdistances, labels);
}
void
BinaryIDMAP::Add(const DatasetPtr& dataset, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
std::lock_guard<std::mutex> lk(mutex_);
GETBINARYTENSOR(dataset)
auto p_ids = dataset->Get<const int64_t*>(meta::IDS);
index_->add_with_ids(rows, (uint8_t*)p_data, p_ids);
}
// "Training" for a flat binary index just allocates the faiss index; the
// dynamic cast only validates that the config is a BinIDMAPCfg (its result
// is otherwise unused).
void
BinaryIDMAP::Train(const Config& config) {
    auto build_cfg = std::dynamic_pointer_cast<BinIDMAPCfg>(config);
    if (build_cfg == nullptr) {
        KNOWHERE_THROW_MSG("not support this kind of config");
    }
    config->CheckValid();  // throws on unsupported metric

    // "BFlat" = brute-force flat binary index in faiss's factory grammar.
    const char* type = "BFlat";
    auto index = faiss::index_binary_factory(config->d, type, GetMetricType(config->metric_type));
    index_.reset(index);
}
// Number of vectors currently stored in the faiss index.
int64_t
BinaryIDMAP::Count() {
    auto total = index_->ntotal;
    return total;
}
// Dimensionality (in bits) of the indexed binary vectors.
int64_t
BinaryIDMAP::Dimension() {
    auto dim = index_->d;
    return dim;
}
// Return a pointer to the raw codes stored in the underlying flat index.
// NOTE: dynamic_cast on a *pointer* returns nullptr on failure instead of
// throwing, so the original try/catch could never fire and a failed cast
// dereferenced nullptr (undefined behavior). Check the casts explicitly.
const uint8_t*
BinaryIDMAP::GetRawVectors() {
    try {
        auto file_index = dynamic_cast<faiss::IndexBinaryIDMap*>(index_.get());
        if (file_index == nullptr) {
            KNOWHERE_THROW_MSG("index is not a faiss::IndexBinaryIDMap");
        }
        auto flat_index = dynamic_cast<faiss::IndexBinaryFlat*>(file_index->index);
        if (flat_index == nullptr) {
            KNOWHERE_THROW_MSG("sub-index is not a faiss::IndexBinaryFlat");
        }
        return flat_index->xb.data();
    } catch (std::exception& e) {
        KNOWHERE_THROW_MSG(e.what());
    }
}
// Return a pointer to the id map of the underlying IDMap index.
// Same defect as GetRawVectors: a pointer dynamic_cast never throws, so
// the failure case must be checked explicitly instead of relying on the
// catch block.
const int64_t*
BinaryIDMAP::GetRawIds() {
    try {
        auto file_index = dynamic_cast<faiss::IndexBinaryIDMap*>(index_.get());
        if (file_index == nullptr) {
            KNOWHERE_THROW_MSG("index is not a faiss::IndexBinaryIDMap");
        }
        return file_index->id_map.data();
    } catch (std::exception& e) {
        KNOWHERE_THROW_MSG(e.what());
    }
}
// No-op: a brute-force flat index needs no post-build finalization.
void
BinaryIDMAP::Seal() {
    // do nothing
}
} // namespace knowhere

View File

@ -0,0 +1,78 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <mutex>
#include <utility>
#include <vector>
#include "FaissBaseBinaryIndex.h"
#include "VectorIndex.h"
namespace knowhere {
// knowhere wrapper around a faiss flat (brute force) binary index.
// Supports binary-only metrics (HAMMING / TANIMOTO / JACCARD — see
// BinIDMAPCfg). Serialize/Load/Add are guarded by mutex_.
class BinaryIDMAP : public VectorIndex, public FaissBaseBinaryIndex {
 public:
    BinaryIDMAP() : FaissBaseBinaryIndex(nullptr) {
    }

    // Wrap an already-constructed faiss binary index.
    explicit BinaryIDMAP(std::shared_ptr<faiss::IndexBinary> index) : FaissBaseBinaryIndex(std::move(index)) {
    }

    // Dump the index into a BinarySet; throws if not initialized.
    BinarySet
    Serialize() override;

    // Restore the index from a BinarySet.
    void
    Load(const BinarySet& index_binary) override;

    // k-NN search; returns ids + distances as a Dataset.
    DatasetPtr
    Search(const DatasetPtr& dataset, const Config& config) override;

    // Append vectors with explicit ids.
    void
    Add(const DatasetPtr& dataset, const Config& config) override;

    // Creates the faiss index; no actual training for a flat index.
    // NOTE(review): not marked override — presumably the base class has no
    // Train(const Config&) overload; confirm against VectorIndex.
    void
    Train(const Config& config);

    int64_t
    Count() override;

    int64_t
    Dimension() override;

    // No-op for a flat index.
    void
    Seal() override;

    // Raw access to the stored codes / ids of the underlying faiss index.
    const uint8_t*
    GetRawVectors();
    const int64_t*
    GetRawIds();

 protected:
    virtual void
    search_impl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& cfg);

 protected:
    // Serializes Serialize/Load/Add against each other.
    std::mutex mutex_;
};
using BinaryIDMAPPtr = std::shared_ptr<BinaryIDMAP>;
} // namespace knowhere

View File

@ -0,0 +1,162 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <faiss/IndexBinaryFlat.h>
#include <faiss/IndexBinaryIVF.h>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexBinaryIVF.h"
#include <chrono>
namespace knowhere {
using stdclock = std::chrono::high_resolution_clock;
// Serialize the trained IVF binary index into a BinarySet.
// Throws when the index is missing or untrained.
BinarySet
BinaryIVF::Serialize() {
    if (index_ == nullptr || !index_->is_trained) {
        KNOWHERE_THROW_MSG("index not initialize or trained");
    }
    std::lock_guard<std::mutex> guard(mutex_);
    return SerializeImpl();
}
// Restore the IVF binary index from a previously serialized BinarySet.
void
BinaryIVF::Load(const BinarySet& index_binary) {
    std::lock_guard<std::mutex> guard(mutex_);
    LoadImpl(index_binary);
}
// k-NN search over the IVF binary index.
// Returns a Dataset owning an ids buffer (int64) and a distances buffer
// (float), each of size rows * k.
DatasetPtr
BinaryIVF::Search(const DatasetPtr& dataset, const Config& config) {
    if (!index_ || !index_->is_trained) {
        KNOWHERE_THROW_MSG("index not initialize or trained");
    }
    GETBINARYTENSOR(dataset)

    // Declared outside the try block so the catch handlers can release
    // them; the original code leaked both buffers when search_impl threw.
    int64_t* p_id = nullptr;
    float* p_dist = nullptr;
    try {
        auto elems = rows * config->k;
        p_id = (int64_t*)malloc(sizeof(int64_t) * elems);
        p_dist = (float*)malloc(sizeof(float) * elems);
        if (p_id == nullptr || p_dist == nullptr) {
            KNOWHERE_THROW_MSG("out of memory");
        }

        // faiss writes int32 distances into the float buffer (search_impl
        // reinterprets the pointer; the element sizes match).
        search_impl(rows, (uint8_t*)p_data, config->k, p_dist, p_id, config);

        // Convert Hamming distances int32 -> float in place; no second
        // buffer is required since the element sizes are identical.
        if (index_->metric_type == faiss::METRIC_Hamming) {
            int32_t* pi_dist = (int32_t*)p_dist;
            for (decltype(elems) i = 0; i < elems; ++i) {
                p_dist[i] = static_cast<float>(pi_dist[i]);
            }
        }

        auto ret_ds = std::make_shared<Dataset>();
        ret_ds->Set(meta::IDS, p_id);
        ret_ds->Set(meta::DISTANCE, p_dist);
        return ret_ds;
    } catch (faiss::FaissException& e) {
        free(p_id);
        free(p_dist);
        KNOWHERE_THROW_MSG(e.what());
    } catch (std::exception& e) {
        free(p_id);
        free(p_dist);
        KNOWHERE_THROW_MSG(e.what());
    }
}
void
BinaryIVF::search_impl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
const Config& cfg) {
auto params = GenParams(cfg);
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
ivf_index->nprobe = params->nprobe;
int32_t* pdistances = (int32_t*)distances;
stdclock::time_point before = stdclock::now();
ivf_index->search(n, (uint8_t*)data, k, pdistances, labels);
stdclock::time_point after = stdclock::now();
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
KNOWHERE_LOG_DEBUG << "IVF search cost: " << search_cost
<< ", quantization cost: " << faiss::indexIVF_stats.quantization_time
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
faiss::indexIVF_stats.quantization_time = 0;
faiss::indexIVF_stats.search_time = 0;
}
std::shared_ptr<faiss::IVFSearchParameters>
BinaryIVF::GenParams(const Config& config) {
auto params = std::make_shared<faiss::IVFSearchParameters>();
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
params->nprobe = search_cfg->nprobe;
// params->max_codes = config.get_with_default("max_codes", size_t(0));
return params;
}
// Build the IVF binary index: cluster with a flat binary quantizer, then
// add all vectors with their ids. Returns nullptr (no separate model).
IndexModelPtr
BinaryIVF::Train(const DatasetPtr& dataset, const Config& config) {
    std::lock_guard<std::mutex> lk(mutex_);
    auto build_cfg = std::dynamic_pointer_cast<IVFBinCfg>(config);
    if (build_cfg == nullptr) {
        // build_cfg is dereferenced unconditionally below (metric_type,
        // nlist); the original only validated when the cast succeeded and
        // crashed otherwise.
        KNOWHERE_THROW_MSG("not support this kind of config");
    }
    build_cfg->CheckValid();  // throws on unsupported metric

    GETBINARYTENSOR(dataset)
    auto p_ids = dataset->Get<const int64_t*>(meta::IDS);

    faiss::IndexBinary* coarse_quantizer = new faiss::IndexBinaryFlat(dim, GetMetricType(build_cfg->metric_type));
    auto index = std::make_shared<faiss::IndexBinaryIVF>(coarse_quantizer, dim, build_cfg->nlist,
                                                         GetMetricType(build_cfg->metric_type));
    // Let the IVF index own (and eventually delete) the quantizer;
    // otherwise coarse_quantizer is leaked when index_ is destroyed.
    index->own_fields = true;
    index->train(rows, (uint8_t*)p_data);
    index->add_with_ids(rows, (uint8_t*)p_data, p_ids);
    index_ = index;
    return nullptr;
}
// Number of vectors currently stored in the faiss index.
int64_t
BinaryIVF::Count() {
    auto total = index_->ntotal;
    return total;
}
// Dimensionality (in bits) of the indexed binary vectors.
int64_t
BinaryIVF::Dimension() {
    auto dim = index_->d;
    return dim;
}
// Incremental add is not supported for the binary IVF index; all vectors
// must be supplied through Train().
void
BinaryIVF::Add(const DatasetPtr& dataset, const Config& config) {
    KNOWHERE_THROW_MSG("not support yet");
}
// No-op: the IVF binary index needs no post-build finalization.
void
BinaryIVF::Seal() {
    // do nothing
}
} // namespace knowhere

View File

@ -0,0 +1,76 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <mutex>
#include <utility>
#include <vector>
#include "FaissBaseBinaryIndex.h"
#include "VectorIndex.h"
#include "faiss/IndexIVF.h"
namespace knowhere {
// knowhere wrapper around faiss::IndexBinaryIVF (inverted-file index over
// binary vectors). Supports binary metrics (HAMMING / TANIMOTO / JACCARD —
// see IVFBinCfg). Serialize/Load/Train are guarded by mutex_.
class BinaryIVF : public VectorIndex, public FaissBaseBinaryIndex {
 public:
    BinaryIVF() : FaissBaseBinaryIndex(nullptr) {
    }

    // Wrap an already-constructed faiss binary index.
    explicit BinaryIVF(std::shared_ptr<faiss::IndexBinary> index) : FaissBaseBinaryIndex(std::move(index)) {
    }

    // Dump the trained index into a BinarySet; throws if untrained.
    BinarySet
    Serialize() override;

    // Restore the index from a BinarySet.
    void
    Load(const BinarySet& index_binary) override;

    // k-NN search; returns ids + distances as a Dataset.
    DatasetPtr
    Search(const DatasetPtr& dataset, const Config& config) override;

    // Not supported; always throws.
    void
    Add(const DatasetPtr& dataset, const Config& config) override;

    // No-op for this index type.
    void
    Seal() override;

    // Clusters and adds the dataset's vectors; returns nullptr.
    IndexModelPtr
    Train(const DatasetPtr& dataset, const Config& config) override;

    int64_t
    Count() override;

    int64_t
    Dimension() override;

 protected:
    // Translate a knowhere config into faiss IVF search parameters.
    virtual std::shared_ptr<faiss::IVFSearchParameters>
    GenParams(const Config& config);

    virtual void
    search_impl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& cfg);

 protected:
    // Serializes Serialize/Load/Train against each other.
    std::mutex mutex_;
};
using BinaryIVFIndexPtr = std::shared_ptr<BinaryIVF>;
} // namespace knowhere

View File

@ -66,7 +66,6 @@ IDMAP::Search(const DatasetPtr& dataset, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
config->CheckValid();
GETTENSOR(dataset)
auto elems = rows * config->k;
@ -149,12 +148,11 @@ IDMAP::GetRawIds() {
}
}
const char* type = "IDMap,Flat";
void
IDMAP::Train(const Config& config) {
config->CheckValid();
const char* type = "IDMap,Flat";
auto index = faiss::index_factory(config->d, type, GetMetricType(config->metric_type));
index_.reset(index);
}

View File

@ -112,8 +112,8 @@ IVF::Search(const DatasetPtr& dataset, const Config& config) {
}
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
if (search_cfg != nullptr) {
search_cfg->CheckValid(); // throw exception
if (search_cfg == nullptr) {
KNOWHERE_THROW_MSG("not support this kind of config");
}
GETTENSOR(dataset)

View File

@ -73,9 +73,9 @@ NSG::Load(const BinarySet& index_binary) {
DatasetPtr
NSG::Search(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<NSGCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
// if (build_cfg != nullptr) {
// build_cfg->CheckValid(); // throw exception
// }
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");

View File

@ -218,9 +218,9 @@ CPUSPTAGRNG::SetParameters(const Config& config) {
DatasetPtr
CPUSPTAGRNG::Search(const DatasetPtr& dataset, const Config& config) {
SetParameters(config);
if (config != nullptr) {
config->CheckValid(); // throw exception
}
// if (config != nullptr) {
// config->CheckValid(); // throw exception
// }
auto p_data = dataset->Get<const float*>(meta::TENSOR);
for (auto i = 0; i < 10; ++i) {

View File

@ -30,6 +30,16 @@ GetMetricType(METRICTYPE& type) {
if (type == METRICTYPE::IP) {
return faiss::METRIC_INNER_PRODUCT;
}
// binary only
if (type == METRICTYPE::JACCARD) {
return faiss::METRIC_Jaccard;
}
if (type == METRICTYPE::TANIMOTO) {
return faiss::METRIC_Tanimoto;
}
if (type == METRICTYPE::HAMMING) {
return faiss::METRIC_Hamming;
}
KNOWHERE_THROW_MSG("Metric type is invalid");
}

View File

@ -82,13 +82,27 @@ struct IVFCfg : public Cfg {
std::stringstream
DumpImpl() override;
bool
CheckValid() override {
return true;
};
// bool
// CheckValid() override {
// return true;
// };
};
using IVFConfig = std::shared_ptr<IVFCfg>;
struct IVFBinCfg : public IVFCfg {
bool
CheckValid() override {
if (metric_type == METRICTYPE::HAMMING || metric_type == METRICTYPE::TANIMOTO ||
metric_type == METRICTYPE::JACCARD) {
return true;
}
std::stringstream ss;
ss << "MetricType: " << int(metric_type) << " not support!";
KNOWHERE_THROW_MSG(ss.str());
return false;
}
};
struct IVFSQCfg : public IVFCfg {
// TODO(linxj): cpu only support SQ4 SQ6 SQ8 SQ16, gpu only support SQ4, SQ8, SQ16
int64_t nbits = DEFAULT_NBITS;
@ -103,10 +117,10 @@ struct IVFSQCfg : public IVFCfg {
IVFSQCfg() = default;
bool
CheckValid() override {
return true;
};
// bool
// CheckValid() override {
// return true;
// };
};
using IVFSQConfig = std::shared_ptr<IVFSQCfg>;
@ -126,10 +140,10 @@ struct IVFPQCfg : public IVFCfg {
IVFPQCfg() = default;
bool
CheckValid() override {
return true;
};
// bool
// CheckValid() override {
// return true;
// };
};
using IVFPQConfig = std::shared_ptr<IVFPQCfg>;
@ -154,10 +168,10 @@ struct NSGCfg : public IVFCfg {
std::stringstream
DumpImpl() override;
bool
CheckValid() override {
return true;
};
// bool
// CheckValid() override {
// return true;
// };
};
using NSGConfig = std::shared_ptr<NSGCfg>;
@ -193,10 +207,10 @@ struct KDTCfg : public SPTAGCfg {
KDTCfg() = default;
bool
CheckValid() override {
return true;
};
// bool
// CheckValid() override {
// return true;
// };
};
using KDTConfig = std::shared_ptr<KDTCfg>;
@ -207,11 +221,25 @@ struct BKTCfg : public SPTAGCfg {
BKTCfg() = default;
bool
CheckValid() override {
return true;
};
// bool
// CheckValid() override {
// return true;
// };
};
using BKTConfig = std::shared_ptr<BKTCfg>;
struct BinIDMAPCfg : public Cfg {
bool
CheckValid() override {
if (metric_type == METRICTYPE::HAMMING || metric_type == METRICTYPE::TANIMOTO ||
metric_type == METRICTYPE::JACCARD) {
return true;
}
std::stringstream ss;
ss << "MetricType: " << int(metric_type) << " not support!";
KNOWHERE_THROW_MSG(ss.str());
return false;
}
};
} // namespace knowhere

View File

@ -47,6 +47,9 @@ enum MetricType {
METRIC_L1, ///< L1 (aka cityblock)
METRIC_Linf, ///< infinity distance
METRIC_Lp, ///< L_p distance, p is given by metric_arg
METRIC_Jaccard,
METRIC_Tanimoto,
METRIC_Hamming,
/// some additional metrics defined in scipy.spatial.distance
METRIC_Canberra = 20,

View File

@ -7,10 +7,14 @@
// -*- c++ -*-
#include <faiss/Index.h>
#include <faiss/IndexBinary.h>
#include <faiss/IndexBinaryFlat.h>
#include <cmath>
#include <cstring>
#include <faiss/utils/hamming.h>
#include <faiss/utils/jaccard.h>
#include <faiss/utils/utils.h>
#include <faiss/utils/Heap.h>
#include <faiss/impl/FaissAssert.h>
@ -21,6 +25,9 @@ namespace faiss {
IndexBinaryFlat::IndexBinaryFlat(idx_t d)
: IndexBinary(d) {}
IndexBinaryFlat::IndexBinaryFlat(idx_t d, MetricType metric)
: IndexBinary(d, metric) {}
void IndexBinaryFlat::add(idx_t n, const uint8_t *x) {
xb.insert(xb.end(), x, x + n * code_size);
ntotal += n;
@ -34,24 +41,54 @@ void IndexBinaryFlat::reset() {
void IndexBinaryFlat::search(idx_t n, const uint8_t *x, idx_t k,
int32_t *distances, idx_t *labels) const {
const idx_t block_size = query_batch_size;
for (idx_t s = 0; s < n; s += block_size) {
idx_t nn = block_size;
if (s + block_size > n) {
nn = n - s;
}
if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
float *D = new float[k * n];
for (idx_t s = 0; s < n; s += block_size) {
idx_t nn = block_size;
if (s + block_size > n) {
nn = n - s;
}
if (use_heap) {
// We see the distances and labels as heaps.
int_maxheap_array_t res = {
size_t(nn), size_t(k), labels + s * k, distances + s * k
};
if (use_heap) {
// We see the distances and labels as heaps.
hammings_knn_hc(&res, x + s * code_size, xb.data(), ntotal, code_size,
float_maxheap_array_t res = {
size_t(nn), size_t(k), labels + s * k, D + s * k
};
jaccard_knn_hc(&res, x + s * code_size, xb.data(), ntotal, code_size,
/* ordered = */ true);
} else {
FAISS_THROW_MSG("tanimoto_knn_mc not implemented");
}
}
if (metric_type == METRIC_Tanimoto) {
for (int i = 0; i < k * n; i++) {
D[i] = -log2(1-D[i]);
}
}
memcpy(distances, D, sizeof(float) * n * k);
delete [] D;
} else {
for (idx_t s = 0; s < n; s += block_size) {
idx_t nn = block_size;
if (s + block_size > n) {
nn = n - s;
}
if (use_heap) {
// We see the distances and labels as heaps.
int_maxheap_array_t res = {
size_t(nn), size_t(k), labels + s * k, distances + s * k
};
hammings_knn_hc(&res, x + s * code_size, xb.data(), ntotal, code_size,
/* ordered = */ true);
} else {
hammings_knn_mc(x + s * code_size, xb.data(), nn, ntotal, k, code_size,
distances + s * k, labels + s * k);
}
} else {
hammings_knn_mc(x + s * code_size, xb.data(), nn, ntotal, k, code_size,
distances + s * k, labels + s * k);
}
}
}
}

View File

@ -31,6 +31,8 @@ struct IndexBinaryFlat : IndexBinary {
explicit IndexBinaryFlat(idx_t d);
IndexBinaryFlat(idx_t d, MetricType metric);
void add(idx_t n, const uint8_t *x) override;
void reset() override;

View File

@ -8,17 +8,21 @@
// Copyright 2004-present Facebook. All Rights Reserved
// -*- c++ -*-
#include <faiss/Index.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexBinaryIVF.h>
#include <cstdio>
#include <memory>
#include <cmath>
#include <faiss/utils/hamming.h>
#include <faiss/utils/jaccard.h>
#include <faiss/utils/utils.h>
#include <faiss/utils/Heap.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/IndexFlat.h>
namespace faiss {
@ -41,6 +45,24 @@ IndexBinaryIVF::IndexBinaryIVF(IndexBinary *quantizer, size_t d, size_t nlist)
cp.niter = 10;
}
IndexBinaryIVF::IndexBinaryIVF(IndexBinary *quantizer, size_t d, size_t nlist, MetricType metric)
: IndexBinary(d, metric),
invlists(new ArrayInvertedLists(nlist, code_size)),
own_invlists(true),
nprobe(1),
max_codes(0),
maintain_direct_map(false),
quantizer(quantizer),
nlist(nlist),
own_fields(false),
clustering_index(nullptr)
{
FAISS_THROW_IF_NOT (d == quantizer->d);
is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
cp.niter = 10;
}
IndexBinaryIVF::IndexBinaryIVF()
: invlists(nullptr),
own_invlists(false),
@ -270,7 +292,13 @@ void IndexBinaryIVF::train(idx_t n, const uint8_t *x) {
std::unique_ptr<float[]> x_f(new float[n * d]);
binary_to_real(n * d, x, x_f.get());
IndexFlatL2 index_tmp(d);
IndexFlat index_tmp;
if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
index_tmp = IndexFlat(d, METRIC_Jaccard);
} else {
index_tmp = IndexFlat(d, METRIC_L2);
}
if (clustering_index && verbose) {
printf("using clustering_index of dimension %d to do the clustering\n",
@ -369,6 +397,50 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
};
template<class JaccardComputer, bool store_pairs>
struct IVFBinaryScannerJaccard: BinaryInvertedListScanner {
JaccardComputer hc;
size_t code_size;
IVFBinaryScannerJaccard (size_t code_size): code_size (code_size)
{}
void set_query (const uint8_t *query_vector) override {
hc.set (query_vector, code_size);
}
idx_t list_no;
void set_list (idx_t list_no, uint8_t /* coarse_dis */) override {
this->list_no = list_no;
}
uint32_t distance_to_code (const uint8_t *code) const override {
}
size_t scan_codes (size_t n,
const uint8_t *codes,
const idx_t *ids,
int32_t *simi, idx_t *idxi,
size_t k) const override
{
using C = CMax<float, idx_t>;
float* psimi = (float*)simi;
size_t nup = 0;
for (size_t j = 0; j < n; j++) {
float dis = hc.jaccard (codes);
if (dis < psimi[0]) {
heap_pop<C> (k, psimi, idxi);
idx_t id = store_pairs ? (list_no << 32 | j) : ids[j];
heap_push<C> (k, psimi, idxi, dis, id);
nup++;
}
codes += code_size;
}
return nup;
}
};
template <bool store_pairs>
BinaryInvertedListScanner *select_IVFBinaryScannerL2 (size_t code_size) {
@ -398,6 +470,23 @@ BinaryInvertedListScanner *select_IVFBinaryScannerL2 (size_t code_size) {
}
}
template <bool store_pairs>
BinaryInvertedListScanner *select_IVFBinaryScannerJaccard (size_t code_size) {
switch (code_size) {
#define HANDLE_CS(cs) \
case cs: \
return new IVFBinaryScannerJaccard<JaccardComputer ## cs, store_pairs> (cs);
HANDLE_CS(16)
HANDLE_CS(32)
HANDLE_CS(64)
HANDLE_CS(128)
#undef HANDLE_CS
default:
return new IVFBinaryScannerJaccard<JaccardComputerDefault,
store_pairs>(code_size);
}
}
void search_knn_hamming_heap(const IndexBinaryIVF& ivf,
size_t n,
@ -491,6 +580,89 @@ void search_knn_hamming_heap(const IndexBinaryIVF& ivf,
}
void search_knn_jaccard_heap(const IndexBinaryIVF& ivf,
size_t n,
const uint8_t *x,
idx_t k,
const idx_t *keys,
const float * coarse_dis,
float *distances, idx_t *labels,
bool store_pairs,
const IVFSearchParameters *params)
{
long nprobe = params ? params->nprobe : ivf.nprobe;
long max_codes = params ? params->max_codes : ivf.max_codes;
MetricType metric_type = ivf.metric_type;
// almost verbatim copy from IndexIVF::search_preassigned
size_t nlistv = 0, ndis = 0, nheap = 0;
using HeapForJaccard = CMax<float, idx_t>;
#pragma omp parallel if(n > 1) reduction(+: nlistv, ndis, nheap)
{
std::unique_ptr<BinaryInvertedListScanner> scanner
(ivf.get_InvertedListScannerJaccard (store_pairs));
#pragma omp for
for (size_t i = 0; i < n; i++) {
const uint8_t *xi = x + i * ivf.code_size;
scanner->set_query(xi);
const idx_t * keysi = keys + i * nprobe;
float * simi = distances + k * i;
idx_t * idxi = labels + k * i;
heap_heapify<HeapForJaccard> (k, simi, idxi);
size_t nscan = 0;
for (size_t ik = 0; ik < nprobe; ik++) {
idx_t key = keysi[ik]; /* select the list */
if (key < 0) {
// not enough centroids for multiprobe
continue;
}
FAISS_THROW_IF_NOT_FMT
(key < (idx_t) ivf.nlist,
"Invalid key=%ld at ik=%ld nlist=%ld\n",
key, ik, ivf.nlist);
scanner->set_list (key, (int32_t)coarse_dis[i * nprobe + ik]);
nlistv++;
size_t list_size = ivf.invlists->list_size(key);
InvertedLists::ScopedCodes scodes (ivf.invlists, key);
std::unique_ptr<InvertedLists::ScopedIds> sids;
const Index::idx_t * ids = nullptr;
if (!store_pairs) {
sids.reset (new InvertedLists::ScopedIds (ivf.invlists, key));
ids = sids->get();
}
nheap += scanner->scan_codes (list_size, scodes.get(),
ids, (int32_t*)simi, idxi, k);
nscan += list_size;
if (max_codes && nscan >= max_codes)
break;
}
ndis += nscan;
heap_reorder<HeapForJaccard> (k, simi, idxi);
} // parallel for
} // parallel
indexIVF_stats.nq += n;
indexIVF_stats.nlist += nlistv;
indexIVF_stats.ndis += ndis;
indexIVF_stats.nheap_updates += nheap;
}
template<class HammingComputer, bool store_pairs>
void search_knn_hamming_count(const IndexBinaryIVF& ivf,
size_t nx,
@ -634,6 +806,16 @@ BinaryInvertedListScanner *IndexBinaryIVF::get_InvertedListScanner
}
}
BinaryInvertedListScanner *IndexBinaryIVF::get_InvertedListScannerJaccard
(bool store_pairs) const
{
if (store_pairs) {
return select_IVFBinaryScannerJaccard<true> (code_size);
} else {
return select_IVFBinaryScannerJaccard<false> (code_size);
}
}
void IndexBinaryIVF::search_preassigned(idx_t n, const uint8_t *x, idx_t k,
const idx_t *idx,
const int32_t * coarse_dis,
@ -642,17 +824,38 @@ void IndexBinaryIVF::search_preassigned(idx_t n, const uint8_t *x, idx_t k,
const IVFSearchParameters *params
) const {
if (use_heap) {
search_knn_hamming_heap (*this, n, x, k, idx, coarse_dis,
distances, labels, store_pairs,
params);
} else {
if (store_pairs) {
search_knn_hamming_count_1<true>
(*this, n, x, idx, k, distances, labels, params);
if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
if (use_heap) {
float *D = new float[k * n];
float *c_dis = new float [n * nprobe];
memcpy(c_dis, coarse_dis, sizeof(float) * n * nprobe);
search_knn_jaccard_heap (*this, n, x, k, idx, c_dis ,
D, labels, store_pairs,
params);
if (metric_type == METRIC_Tanimoto) {
for (int i = 0; i < k * n; i++) {
D[i] = -log2(1-D[i]);
}
}
memcpy(distances, D, sizeof(float) * n * k);
delete [] D;
delete [] c_dis;
} else {
search_knn_hamming_count_1<false>
(*this, n, x, idx, k, distances, labels, params);
//not implemented
}
} else {
if (use_heap) {
search_knn_hamming_heap (*this, n, x, k, idx, coarse_dis,
distances, labels, store_pairs,
params);
} else {
if (store_pairs) {
search_knn_hamming_count_1<true>
(*this, n, x, idx, k, distances, labels, params);
} else {
search_knn_hamming_count_1<false>
(*this, n, x, idx, k, distances, labels, params);
}
}
}
}

View File

@ -64,6 +64,8 @@ struct IndexBinaryIVF : IndexBinary {
*/
IndexBinaryIVF(IndexBinary *quantizer, size_t d, size_t nlist);
IndexBinaryIVF(IndexBinary *quantizer, size_t d, size_t nlist, MetricType metric);
IndexBinaryIVF();
~IndexBinaryIVF() override;
@ -109,6 +111,9 @@ struct IndexBinaryIVF : IndexBinary {
virtual BinaryInvertedListScanner *get_InvertedListScanner (
bool store_pairs=false) const;
virtual BinaryInvertedListScanner *get_InvertedListScannerJaccard (
bool store_pairs=false) const;
/** assign the vectors, then call search_preassign */
virtual void search(idx_t n, const uint8_t *x, idx_t k,
int32_t *distances, idx_t *labels) const override;

View File

@ -52,6 +52,10 @@ void IndexFlat::search (idx_t n, const float *x, idx_t k,
float_maxheap_array_t res = {
size_t(n), size_t(k), labels, distances};
knn_L2sqr (x, xb.data(), d, n, ntotal, &res);
} else if (metric_type == METRIC_Jaccard) {
float_maxheap_array_t res = {
size_t(n), size_t(k), labels, distances};
knn_jaccard (x, xb.data(), d, n, ntotal, &res);
} else {
float_maxheap_array_t res = {
size_t(n), size_t(k), labels, distances};

View File

@ -371,23 +371,24 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
return index;
}
IndexBinary *index_binary_factory(int d, const char *description)
IndexBinary *index_binary_factory(int d, const char *description, MetricType metric = METRIC_L2)
{
IndexBinary *index = nullptr;
int ncentroids = -1;
int M;
ScopeDeleter1<IndexBinary> del_index;
if (sscanf(description, "BIVF%d_HNSW%d", &ncentroids, &M) == 2) {
IndexBinaryIVF *index_ivf = new IndexBinaryIVF(
new IndexBinaryHNSW(d, M), d, ncentroids
new IndexBinaryHNSW(d, M), d, ncentroids
);
index_ivf->own_fields = true;
index = index_ivf;
} else if (sscanf(description, "BIVF%d", &ncentroids) == 1) {
IndexBinaryIVF *index_ivf = new IndexBinaryIVF(
new IndexBinaryFlat(d), d, ncentroids
new IndexBinaryFlat(d), d, ncentroids
);
index_ivf->own_fields = true;
index = index_ivf;
@ -397,13 +398,27 @@ IndexBinary *index_binary_factory(int d, const char *description)
index = index_hnsw;
} else if (std::string(description) == "BFlat") {
index = new IndexBinaryFlat(d);
IndexBinary* index_x = new IndexBinaryFlat(d, metric);
{
IndexBinaryIDMap *idmap = new IndexBinaryIDMap(index_x);
del_index.set (idmap);
idmap->own_fields = true;
index_x = idmap;
}
if (index_x) {
index = index_x;
del_index.set(index);
}
} else {
FAISS_THROW_IF_NOT_FMT(index, "description %s did not generate an index",
description);
description);
}
del_index.release();
return index;
}

View File

@ -19,7 +19,7 @@ namespace faiss {
Index *index_factory (int d, const char *description,
MetricType metric = METRIC_L2);
IndexBinary *index_binary_factory (int d, const char *description);
IndexBinary *index_binary_factory (int d, const char *description, MetricType metric = METRIC_L2);
}

View File

@ -340,7 +340,78 @@ static void knn_L2sqr_blas (const float * x,
}
template<class DistanceCorrection>
static void knn_jaccard_blas (const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
float_maxheap_array_t * res,
const DistanceCorrection &corr)
{
res->heapify ();
// BLAS does not like empty matrices
if (nx == 0 || ny == 0) return;
size_t k = res->k;
/* block sizes */
const size_t bs_x = 4096, bs_y = 1024;
// const size_t bs_x = 16, bs_y = 16;
float *ip_block = new float[bs_x * bs_y];
float *x_norms = new float[nx];
float *y_norms = new float[ny];
ScopeDeleter<float> del1(ip_block), del3(x_norms), del2(y_norms);
fvec_norms_L2sqr (x_norms, x, d, nx);
fvec_norms_L2sqr (y_norms, y, d, ny);
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
size_t i1 = i0 + bs_x;
if(i1 > nx) i1 = nx;
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
size_t j1 = j0 + bs_y;
if (j1 > ny) j1 = ny;
/* compute the actual dot products */
{
float one = 1, zero = 0;
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
y + j0 * d, &di,
x + i0 * d, &di, &zero,
ip_block, &nyi);
}
/* collect minima */
#pragma omp parallel for
for (size_t i = i0; i < i1; i++) {
float * __restrict simi = res->get_val(i);
int64_t * __restrict idxi = res->get_ids (i);
const float *ip_line = ip_block + (i - i0) * (j1 - j0);
for (size_t j = j0; j < j1; j++) {
float ip = *ip_line++;
float dis = 1.0 - ip / (x_norms[i] + y_norms[j] - ip);
// negative values can occur for identical vectors
// due to roundoff errors
if (dis < 0) dis = 0;
dis = corr (dis, i, j);
if (dis < simi[0]) {
maxheap_pop (k, simi, idxi);
maxheap_push (k, simi, idxi, dis, j);
}
}
}
}
InterruptCallback::check ();
}
res->reorder ();
}
@ -387,6 +458,20 @@ void knn_L2sqr (const float * x,
}
}
void knn_jaccard (const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
float_maxheap_array_t * res)
{
if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
// knn_jaccard_sse (x, y, d, nx, ny, res);
printf("sse_not implemented!\n");
} else {
NopDistanceCorrection nop;
knn_jaccard_blas (x, y, d, nx, ny, res, nop);
}
}
struct BaseShiftDistanceCorrection {
const float *base_shift;
float operator()(float dis, size_t /*qno*/, size_t bno) const {

View File

@ -47,6 +47,10 @@ float fvec_Linf (
const float * y,
size_t d);
float fvec_jaccard (
const float * x,
const float * y,
size_t d);
/** Compute pairwise distances between sets of vectors
*
@ -175,7 +179,11 @@ void knn_L2sqr (
size_t d, size_t nx, size_t ny,
float_maxheap_array_t * res);
void knn_jaccard (
const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
float_maxheap_array_t * res);
/** same as knn_L2sqr, but base_shift[bno] is subtracted to all
* computed distances.

View File

@ -112,7 +112,39 @@ struct VectorDistanceJensenShannon {
}
};
struct VectorDistanceJaccard {
size_t d;
float operator () (const float *x, const float *y) const {
float accu_num = 0, accu_den = 0;
const float EPSILON = 0.000001;
for (size_t i = 0; i < d; i++) {
float xi = x[i], yi = y[i];
if (fabs (xi - yi) < EPSILON) {
accu_num += xi;
accu_den += xi;
} else {
accu_den += xi;
accu_den += yi;
}
}
return 1 - accu_num / accu_den;
}
};
struct VectorDistanceTanimoto {
size_t d;
float operator () (const float *x, const float *y) const {
float accu_num = 0, accu_den = 0;
for (size_t i = 0; i < d; i++) {
float xi = x[i], yi = y[i];
accu_num += xi * yi;
accu_den += xi * xi + yi * yi - xi * yi;
}
return -log2(accu_num / accu_den) ;
}
};
@ -263,6 +295,18 @@ void pairwise_extra_distances (
dis, ldq, ldb, ldd);
break;
}
case METRIC_Jaccard: {
VectorDistanceJaccard vd({(size_t) d});
pairwise_extra_distances_template(vd, nq, xq, nb, xb,
dis, ldq, ldb, ldd);
break;
}
case METRIC_Tanimoto: {
VectorDistanceTanimoto vd({(size_t) d});
pairwise_extra_distances_template(vd, nq, xq, nb, xb,
dis, ldq, ldb, ldd);
break;
}
default:
FAISS_THROW_MSG ("metric type not implemented");
}
@ -296,6 +340,16 @@ void knn_extra_metrics (
knn_extra_metrics_template (vd, x, y, nx, ny, res);
break;
}
case METRIC_Jaccard: {
VectorDistanceJaccard vd({(size_t) d});
knn_extra_metrics_template(vd, x, y, nx, ny, res);
break;
}
case METRIC_Tanimoto: {
VectorDistanceTanimoto vd({(size_t) d});
knn_extra_metrics_template(vd, x, y, nx, ny, res);
break;
}
default:
FAISS_THROW_MSG ("metric type not implemented");
}

View File

@ -0,0 +1,196 @@
//
// Created by czr on 2019/12/19.
//
namespace faiss {
// Jaccard distance between a cached 16-byte query code and candidate codes.
// Distance is 1 - |a AND b| / |a OR b|; two codes with an empty
// intersection are treated as maximally distant (1.0).
struct JaccardComputer16 {
    uint64_t a0, a1;

    JaccardComputer16 () {}

    JaccardComputer16 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    // Cache the query code as two 64-bit words.
    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 16);
        const uint64_t *a = (uint64_t *)a8;
        a0 = a[0]; a1 = a[1];
    }

    inline float jaccard (const uint8_t *b8) const {
        const uint64_t *b = (uint64_t *)b8;
        const uint64_t q[2] = {a0, a1};
        int inter = 0, uni = 0;
        for (int i = 0; i < 2; i++) {
            inter += popcount64 (b[i] & q[i]);
            uni   += popcount64 (b[i] | q[i]);
        }
        return inter == 0 ? 1.0f : 1.0f - (float)inter / (float)uni;
    }
};
// Jaccard distance computer for 32-byte (256-bit) binary codes.
struct JaccardComputer32 {
    uint64_t a0, a1, a2, a3;

    JaccardComputer32 () {}

    JaccardComputer32 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 32);
        const uint64_t *words = (const uint64_t *)a8;
        a0 = words[0]; a1 = words[1]; a2 = words[2]; a3 = words[3];
    }

    // 1 - popcount(a & b) / popcount(a | b); an empty intersection yields
    // 1.0 (this also avoids 0/0 for two all-zero codes).
    inline float jaccard (const uint8_t *b8) const {
        const uint64_t *b = (const uint64_t *)b8;
        const uint64_t self[4] = {a0, a1, a2, a3};
        int inter = 0, uni = 0;
        for (int i = 0; i < 4; i++) {
            inter += popcount64 (b[i] & self[i]);
            uni += popcount64 (b[i] | self[i]);
        }
        if (inter == 0) return 1.0;
        return 1.0 - (float)(inter) / (float)(uni);
    }
};
// Jaccard distance computer for 64-byte (512-bit) binary codes.
struct JaccardComputer64 {
    uint64_t a0, a1, a2, a3, a4, a5, a6, a7;

    JaccardComputer64 () {}

    JaccardComputer64 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 64);
        const uint64_t *words = (const uint64_t *)a8;
        a0 = words[0]; a1 = words[1]; a2 = words[2]; a3 = words[3];
        a4 = words[4]; a5 = words[5]; a6 = words[6]; a7 = words[7];
    }

    // 1 - popcount(a & b) / popcount(a | b); an empty intersection yields
    // 1.0 (this also avoids 0/0 for two all-zero codes).
    inline float jaccard (const uint8_t *b8) const {
        const uint64_t *b = (const uint64_t *)b8;
        const uint64_t self[8] = {a0, a1, a2, a3, a4, a5, a6, a7};
        int inter = 0, uni = 0;
        for (int i = 0; i < 8; i++) {
            inter += popcount64 (b[i] & self[i]);
            uni += popcount64 (b[i] | self[i]);
        }
        if (inter == 0) return 1.0;
        return 1.0 - (float)(inter) / (float)(uni);
    }
};
// Jaccard distance computer for 128-byte (1024-bit) binary codes.
struct JaccardComputer128 {
    uint64_t a0, a1, a2, a3, a4, a5, a6, a7,
             a8, a9, a10, a11, a12, a13, a14, a15;

    JaccardComputer128 () {}

    JaccardComputer128 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    void set (const uint8_t *a16, int code_size) {
        assert (code_size == 128 );
        const uint64_t *words = (const uint64_t *)a16;
        a0 = words[0]; a1 = words[1]; a2 = words[2]; a3 = words[3];
        a4 = words[4]; a5 = words[5]; a6 = words[6]; a7 = words[7];
        a8 = words[8]; a9 = words[9]; a10 = words[10]; a11 = words[11];
        a12 = words[12]; a13 = words[13]; a14 = words[14]; a15 = words[15];
    }

    // 1 - popcount(a & b) / popcount(a | b); an empty intersection yields
    // 1.0 (this also avoids 0/0 for two all-zero codes).
    inline float jaccard (const uint8_t *b16) const {
        const uint64_t *b = (const uint64_t *)b16;
        const uint64_t self[16] = {a0, a1, a2, a3, a4, a5, a6, a7,
                                   a8, a9, a10, a11, a12, a13, a14, a15};
        int inter = 0, uni = 0;
        for (int i = 0; i < 16; i++) {
            inter += popcount64 (b[i] & self[i]);
            uni += popcount64 (b[i] | self[i]);
        }
        if (inter == 0) return 1.0;
        return 1.0 - (float)(inter) / (float)(uni);
    }
};
// Byte-wise Jaccard distance computer for arbitrary code sizes. Keeps a
// non-owning pointer to the query code, so the caller must keep it alive.
struct JaccardComputerDefault {
    const uint8_t *a;
    int n;

    JaccardComputerDefault () {}

    JaccardComputerDefault (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    void set (const uint8_t *a8, int code_size) {
        a = a8;
        n = code_size;
    }

    // 1 - popcount(a & b) / popcount(a | b); an empty intersection yields
    // 1.0 (this also avoids 0/0 for two all-zero codes).
    float jaccard (const uint8_t *b8) const {
        int inter = 0, uni = 0;
        for (int i = 0; i != n; ++i) {
            const uint8_t lhs = a[i], rhs = b8[i];
            inter += popcount64(lhs & rhs);
            uni += popcount64(lhs | rhs);
        }
        if (inter == 0) return 1.0;
        return 1.0 - (float)(inter) / (float)(uni);
    }
};
// default template
// Fallback for arbitrary code sizes: inherits the generic byte-wise loop
// from JaccardComputerDefault.
template<int CODE_SIZE>
struct JaccardComputer: JaccardComputerDefault {
    JaccardComputer (const uint8_t *a, int code_size):
        JaccardComputerDefault(a, code_size) {}
};

// Specialize JaccardComputer<16/32/64/128> so the common code sizes use the
// fully unrolled fixed-size computers defined above instead of the generic
// loop.
#define SPECIALIZED_HC(CODE_SIZE) \
    template<> struct JaccardComputer<CODE_SIZE>: \
            JaccardComputer ## CODE_SIZE { \
        JaccardComputer (const uint8_t *a): \
            JaccardComputer ## CODE_SIZE(a, CODE_SIZE) {} \
    }

SPECIALIZED_HC(16);
SPECIALIZED_HC(32);
SPECIALIZED_HC(64);
SPECIALIZED_HC(128);

#undef SPECIALIZED_HC
}

View File

@ -0,0 +1,89 @@
#include <faiss/utils/jaccard.h>
#include <vector>
#include <memory>
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <assert.h>
#include <limits.h>
#include <faiss/utils/Heap.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/utils.h>
namespace faiss {
// Number of database codes processed per block in jaccard_knn_hc below.
size_t jaccard_batch_size = 65536;
/** Exhaustive Jaccard k-NN core.
 *  For each of the ha->nh query codes in bs1, scans all n2 database codes in
 *  bs2 and keeps the k smallest distances in the per-query max-heap stored
 *  in ha. The database is walked in blocks of jaccard_batch_size codes, and
 *  queries within a block are processed in parallel (one OpenMP thread per
 *  query, so each heap is touched by a single thread). */
template <class JaccardComputer>
static
void jaccard_knn_hc (
        int bytes_per_code,
        float_maxheap_array_t * ha,
        const uint8_t * bs1,
        const uint8_t * bs2,
        size_t n2,
        bool order = true,
        bool init_heap = true)
{
    size_t k = ha->k;
    if (init_heap) ha->heapify ();

    const size_t block_size = jaccard_batch_size;
    for (size_t j0 = 0; j0 < n2; j0 += block_size) {
        const size_t j1 = std::min(j0 + block_size, n2);
#pragma omp parallel for
        for (size_t i = 0; i < ha->nh; i++) {
            JaccardComputer hc (bs1 + i * bytes_per_code, bytes_per_code);

            const uint8_t * bs2_ = bs2 + j0 * bytes_per_code;
            tadis_t dis;
            tadis_t * __restrict bh_val_ = ha->val + i * k;
            int64_t * __restrict bh_ids_ = ha->ids + i * k;
            size_t j;
            for (j = j0; j < j1; j++, bs2_+= bytes_per_code) {
                dis = hc.jaccard (bs2_);
                // bh_val_[0] is the current worst (largest) of the k best;
                // replace it when this candidate is closer.
                if (dis < bh_val_[0]) {
                    faiss::maxheap_pop<tadis_t> (k, bh_val_, bh_ids_);
                    faiss::maxheap_push<tadis_t> (k, bh_val_, bh_ids_, dis, j);
                }
            }
        }
    }
    // Sort each query's results by increasing distance.
    if (order) ha->reorder ();
}
// Public entry point: dispatch on the code size so the common sizes use the
// unrolled specialized computers; any other size falls back to the generic
// byte-wise computer.
void jaccard_knn_hc (
        float_maxheap_array_t * ha,
        const uint8_t * a,
        const uint8_t * b,
        size_t nb,
        size_t ncodes,
        int order)
{
    if (ncodes == 16) {
        jaccard_knn_hc<faiss::JaccardComputer16>
            (16, ha, a, b, nb, order, true);
    } else if (ncodes == 32) {
        jaccard_knn_hc<faiss::JaccardComputer32>
            (32, ha, a, b, nb, order, true);
    } else if (ncodes == 64) {
        jaccard_knn_hc<faiss::JaccardComputer64>
            (64, ha, a, b, nb, order, true);
    } else if (ncodes == 128) {
        jaccard_knn_hc<faiss::JaccardComputer128>
            (128, ha, a, b, nb, order, true);
    } else {
        jaccard_knn_hc<faiss::JaccardComputerDefault>
            (ncodes, ha, a, b, nb, order, true);
    }
}
}

View File

@ -0,0 +1,37 @@
#ifndef FAISS_JACCARD_H
#define FAISS_JACCARD_H

#include <faiss/utils/hamming.h>

#include <stdint.h>

#include <faiss/utils/Heap.h>

/* The Jaccard distance type */
typedef float tadis_t;

namespace faiss {

// Database scan block size used by jaccard_knn_hc (defined in the .cpp).
extern size_t jaccard_batch_size;

/** Return the k smallest Jaccard distances for a set of binary query vectors,
 * using a max heap.
 * @param ha result heaps: ha->nh queries, ha->k results per query
 * @param a  queries, size ha->nh * ncodes
 * @param b  database, size nb * ncodes
 * @param nb number of database vectors
 * @param ncodes size of the binary codes (bytes)
 * @param ordered if != 0: order the results by increasing distance
 *                (may be bottleneck for k/n > 0.01) */
void jaccard_knn_hc (
        float_maxheap_array_t * ha,
        const uint8_t * a,
        const uint8_t * b,
        size_t nb,
        size_t ncodes,
        int ordered);

} //namespace faiss

#include <faiss/utils/jaccard-inl.h>

#endif //FAISS_JACCARD_H

View File

@ -58,6 +58,9 @@ set(ivf_srcs
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/FaissBaseBinaryIndex.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp
)
if (KNOWHERE_GPU_VERSION)
set(ivf_srcs ${ivf_srcs}
@ -74,6 +77,11 @@ if (NOT TARGET test_ivf)
endif ()
target_link_libraries(test_ivf ${depend_libs} ${unittest_libs} ${basic_libs})
if (NOT TARGET test_binaryivf)
add_executable(test_binaryivf test_binaryivf.cpp ${ivf_srcs} ${util_srcs})
endif ()
target_link_libraries(test_binaryivf ${depend_libs} ${unittest_libs} ${basic_libs})
#<IDMAP-TEST>
if (NOT TARGET test_idmap)
@ -81,6 +89,12 @@ if (NOT TARGET test_idmap)
endif ()
target_link_libraries(test_idmap ${depend_libs} ${unittest_libs} ${basic_libs})
#<BinaryIDMAP-TEST>
if (NOT TARGET test_binaryidmap)
add_executable(test_binaryidmap test_binaryidmap.cpp ${ivf_srcs} ${util_srcs})
endif ()
target_link_libraries(test_binaryidmap ${depend_libs} ${unittest_libs} ${basic_libs})
#<SPTAG-TEST>
set(sptag_srcs
${INDEX_SOURCE_DIR}/knowhere/knowhere/adapter/SptagAdapter.cpp
@ -104,7 +118,9 @@ if (KNOWHERE_GPU_VERSION)
endif ()
install(TARGETS test_ivf DESTINATION unittest)
install(TARGETS test_binaryivf DESTINATION unittest)
install(TARGETS test_idmap DESTINATION unittest)
install(TARGETS test_binaryidmap DESTINATION unittest)
install(TARGETS test_sptag DESTINATION unittest)
if (KNOWHERE_GPU_VERSION)
install(TARGETS test_gpuresource DESTINATION unittest)

View File

@ -0,0 +1,119 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <gtest/gtest.h>
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"
#include "Helper.h"
#include "unittest/utils.h"
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
// Fixture parameterized over the binary metric types; each test body runs
// once per metric (JACCARD, TANIMOTO, HAMMING).
class BinaryIDMAPTest : public BinaryDataGen, public TestWithParam<knowhere::METRICTYPE> {
 protected:
    void
    SetUp() override {
        // Generates the random binary base/query datasets (see BinaryDataGen).
        Init_with_binary_default();
        index_ = std::make_shared<knowhere::BinaryIDMAP>();
    }

    void
    TearDown() override{};

 protected:
    knowhere::BinaryIDMAPPtr index_ = nullptr;
};

INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIDMAPTest,
                        Values(knowhere::METRICTYPE::JACCARD, knowhere::METRICTYPE::TANIMOTO,
                               knowhere::METRICTYPE::HAMMING));
// Build a flat binary index, search it, then round-trip it through
// serialization and verify the reloaded index returns valid results.
TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
    ASSERT_TRUE(!xb.empty());

    knowhere::METRICTYPE MetricType = GetParam();
    auto conf = std::make_shared<knowhere::BinIDMAPCfg>();
    conf->d = dim;
    conf->k = k;
    conf->metric_type = MetricType;

    index_->Train(conf);
    index_->Add(base_dataset, conf);
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dimension(), dim);
    ASSERT_TRUE(index_->GetRawVectors() != nullptr);
    ASSERT_TRUE(index_->GetRawIds() != nullptr);
    auto result = index_->Search(query_dataset, conf);
    AssertAnns(result, nq, k);
    // PrintResult(result, nq, k);

    auto binaryset = index_->Serialize();
    auto new_index = std::make_shared<knowhere::BinaryIDMAP>();
    new_index->Load(binaryset);
    // Search the RELOADED index: the original code searched index_ again,
    // leaving new_index completely untested.
    auto re_result = new_index->Search(query_dataset, conf);
    AssertAnns(re_result, nq, k);
    // PrintResult(re_result, nq, k);
}
// Serialize the index to disk, read the blob back, load it into the index
// and verify the search results survive the round trip.
TEST_P(BinaryIDMAPTest, binaryidmap_serialize) {
    // Writes the binary blob to filename and reads it back into ret.
    auto serialize = [](const std::string& filename, knowhere::BinaryPtr& bin, uint8_t* ret) {
        FileIOWriter writer(filename);
        writer(static_cast<void*>(bin->data.get()), bin->size);

        FileIOReader reader(filename);
        reader(ret, bin->size);
    };

    knowhere::METRICTYPE MetricType = GetParam();
    auto conf = std::make_shared<knowhere::BinIDMAPCfg>();
    conf->d = dim;
    conf->k = k;
    conf->metric_type = MetricType;

    {
        // serialize index
        index_->Train(conf);
        index_->Add(base_dataset, knowhere::Config());
        auto re_result = index_->Search(query_dataset, conf);
        AssertAnns(re_result, nq, k);
        // PrintResult(re_result, nq, k);
        EXPECT_EQ(index_->Count(), nb);
        EXPECT_EQ(index_->Dimension(), dim);
        auto binaryset = index_->Serialize();
        // NOTE(review): the blob name "BinaryIVF" looks copied from the IVF
        // test -- confirm it matches what BinaryIDMAP::Serialize() produces.
        auto bin = binaryset.GetByName("BinaryIVF");

        std::string filename = "/tmp/bianryidmap_test_serialize.bin";
        auto load_data = new uint8_t[bin->size];
        serialize(filename, bin, load_data);

        binaryset.clear();
        // load_data was allocated with new[]; the previous code wrapped it
        // via make_shared<uint8_t>() + reset(), which frees it with scalar
        // delete (undefined behavior). Attach an array deleter instead.
        std::shared_ptr<uint8_t> data(load_data, std::default_delete<uint8_t[]>());
        binaryset.Append("BinaryIVF", data, bin->size);

        index_->Load(binaryset);
        EXPECT_EQ(index_->Count(), nb);
        EXPECT_EQ(index_->Dimension(), dim);
        auto result = index_->Search(query_dataset, conf);
        AssertAnns(result, nq, k);
        // PrintResult(result, nq, k);
    }
}

View File

@ -0,0 +1,144 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <gtest/gtest.h>
#include <iostream>
#include <thread>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Timer.h"
#include "knowhere/index/vector_index/IndexBinaryIVF.h"
#include "unittest/Helper.h"
#include "unittest/utils.h"
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
// Fixture parameterized over the binary metric types; each test body runs
// once per metric (JACCARD, TANIMOTO, HAMMING).
class BinaryIVFTest : public BinaryDataGen, public TestWithParam<knowhere::METRICTYPE> {
 protected:
    void
    SetUp() override {
        knowhere::METRICTYPE MetricType = GetParam();
        // Random binary base/query data with the BinaryDataGen defaults.
        Init_with_binary_default();
        //        nb = 1000000;
        //        nq = 1000;
        //        k = 1000;
        //        Generate(DIM, NB, NQ);
        index_ = std::make_shared<knowhere::BinaryIVF>();

        auto x_conf = std::make_shared<knowhere::IVFBinCfg>();
        x_conf->d = dim;
        x_conf->k = k;
        x_conf->metric_type = MetricType;
        x_conf->nlist = 100;
        x_conf->nprobe = 10;
        conf = x_conf;
        conf->Dump();
    }

    void
    TearDown() override {
    }

 protected:
    std::string index_type;
    knowhere::Config conf;
    knowhere::BinaryIVFIndexPtr index_ = nullptr;
};

INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIVFTest,
                        Values(knowhere::METRICTYPE::JACCARD, knowhere::METRICTYPE::TANIMOTO,
                               knowhere::METRICTYPE::HAMMING));
// Train a binary IVF index and check that searching the query set yields
// valid ids for every query.
TEST_P(BinaryIVFTest, binaryivf_basic) {
    assert(!xb.empty());

    // auto preprocessor = index_->BuildPreprocessor(base_dataset, conf);
    // index_->set_preprocessor(preprocessor);

    // No separate Add call follows -- NOTE(review): this relies on
    // BinaryIVF::Train ingesting the base vectors itself; confirm against
    // the index implementation.
    index_->Train(base_dataset, conf);
    // index_->set_index_model(model);
    // index_->Add(base_dataset, conf);
    EXPECT_EQ(index_->Count(), nb);
    EXPECT_EQ(index_->Dimension(), dim);

    auto result = index_->Search(query_dataset, conf);
    AssertAnns(result, nq, conf->k);
    // PrintResult(result, nq, k);
}
// Serialize the trained IVF index to disk, read the blob back, reload it
// and verify search results survive the round trip.
TEST_P(BinaryIVFTest, binaryivf_serialize) {
    // Writes the binary blob to filename and reads it back into ret.
    auto serialize = [](const std::string& filename, knowhere::BinaryPtr& bin, uint8_t* ret) {
        FileIOWriter writer(filename);
        writer(static_cast<void*>(bin->data.get()), bin->size);

        FileIOReader reader(filename);
        reader(ret, bin->size);
    };

    //    {
    //        // serialize index-model
    //        auto model = index_->Train(base_dataset, conf);
    //        auto binaryset = model->Serialize();
    //        auto bin = binaryset.GetByName("BinaryIVF");
    //
    //        std::string filename = "/tmp/binaryivf_test_model_serialize.bin";
    //        auto load_data = new uint8_t[bin->size];
    //        serialize(filename, bin, load_data);
    //
    //        binaryset.clear();
    //        auto data = std::make_shared<uint8_t>();
    //        data.reset(load_data);
    //        binaryset.Append("BinaryIVF", data, bin->size);
    //
    //        model->Load(binaryset);
    //
    //        index_->set_index_model(model);
    //        index_->Add(base_dataset, conf);
    //        auto result = index_->Search(query_dataset, conf);
    //        AssertAnns(result, nq, conf->k);
    //    }

    {
        // serialize index
        index_->Train(base_dataset, conf);
        // index_->set_index_model(model);
        // index_->Add(base_dataset, conf);
        auto binaryset = index_->Serialize();
        auto bin = binaryset.GetByName("BinaryIVF");

        std::string filename = "/tmp/binaryivf_test_serialize.bin";
        auto load_data = new uint8_t[bin->size];
        serialize(filename, bin, load_data);

        binaryset.clear();
        // load_data was allocated with new[]; the previous code wrapped it
        // via make_shared<uint8_t>() + reset(), which frees it with scalar
        // delete (undefined behavior). Attach an array deleter instead.
        std::shared_ptr<uint8_t> data(load_data, std::default_delete<uint8_t[]>());
        binaryset.Append("BinaryIVF", data, bin->size);

        index_->Load(binaryset);
        EXPECT_EQ(index_->Count(), nb);
        EXPECT_EQ(index_->Dimension(), dim);
        auto result = index_->Search(query_dataset, conf);
        AssertAnns(result, nq, conf->k);
        // PrintResult(result, nq, k);
    }
}

View File

@ -42,11 +42,13 @@ class SPTAGTest : public DataGen, public TestWithParam<std::string> {
auto tempconf = std::make_shared<knowhere::KDTCfg>();
tempconf->tptnumber = 1;
tempconf->k = 10;
tempconf->metric_type = knowhere::METRICTYPE::L2;
conf = tempconf;
} else {
auto tempconf = std::make_shared<knowhere::BKTCfg>();
tempconf->tptnumber = 1;
tempconf->k = 10;
tempconf->metric_type = knowhere::METRICTYPE::L2;
conf = tempconf;
}

View File

@ -38,6 +38,11 @@ DataGen::Init_with_default() {
Generate(dim, nb, nq);
}
// Build the binary base/query datasets using the BinaryDataGen member
// defaults (nb, nq, dim declared in the class).
void
BinaryDataGen::Init_with_binary_default() {
    Generate(dim, nb, nq);
}
void
DataGen::Generate(const int& dim, const int& nb, const int& nq) {
this->nb = nb;
@ -52,6 +57,21 @@ DataGen::Generate(const int& dim, const int& nb, const int& nq) {
query_dataset = generate_query_dataset(nq, dim, xq.data());
}
// Generate random binary base vectors (queries are copies of the first nq
// base vectors) and wrap them in knowhere datasets.
void
BinaryDataGen::Generate(const int& dim, const int& nb, const int& nq) {
    this->nb = nb;
    this->nq = nq;
    this->dim = dim;

    // Binary vectors pack 8 dimensions per byte; assumes dim is a multiple
    // of 8 (the division truncates otherwise) -- TODO confirm callers only
    // pass such dims.
    int64_t dim_x = dim / 8;
    GenBinaryAll(dim_x, nb, xb, ids, nq, xq);
    assert(xb.size() == (size_t)dim_x * nb);
    assert(xq.size() == (size_t)dim_x * nq);

    // The datasets carry the logical (bit) dimension, not the byte count.
    base_dataset = generate_binary_dataset(nb, dim, xb.data(), ids.data());
    query_dataset = generate_binary_query_dataset(nq, dim, xq.data());
}
knowhere::DatasetPtr
DataGen::GenQuery(const int& nq) {
xq.resize(nq * dim);
@ -78,6 +98,23 @@ GenAll(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids, const int
}
}
// Size the output buffers (dim here is bytes per code), then delegate to the
// raw-pointer overload to fill them.
void
GenBinaryAll(const int64_t dim, const int64_t& nb, std::vector<uint8_t>& xb, std::vector<int64_t>& ids,
             const int64_t& nq, std::vector<uint8_t>& xq) {
    ids.resize(nb);
    xb.resize(nb * dim);
    xq.resize(nq * dim);
    GenBinaryAll(dim, nb, xb.data(), ids.data(), nq, xq.data());
}
// Fill the base set with random codes, then reuse its first nq codes as the
// query set so every query has an exact match in the base.
void
GenBinaryAll(const int64_t& dim, const int64_t& nb, uint8_t* xb, int64_t* ids, const int64_t& nq, uint8_t* xq) {
    GenBinaryBase(dim, nb, xb, ids);
    const int64_t total = nq * dim;
    for (int64_t idx = 0; idx != total; ++idx) {
        xq[idx] = xb[idx];
    }
}
void
GenBase(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids) {
for (auto i = 0; i < nb; ++i) {
@ -90,6 +127,17 @@ GenBase(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids) {
}
}
// Fill nb codes of `dim` random bytes each (via lrand48); ids are simply the
// row numbers 0..nb-1.
void
GenBinaryBase(const int64_t& dim, const int64_t& nb, uint8_t* xb, int64_t* ids) {
    for (int64_t row = 0; row < nb; ++row) {
        uint8_t* code = xb + row * dim;
        for (int64_t byte = 0; byte < dim; ++byte) {
            // p_data[i * d + j] = float(base + i);
            code[byte] = (uint8_t)lrand48();
        }
        ids[row] = row;
    }
}
FileIOReader::FileIOReader(const std::string& fname) {
name = fname;
fs = std::fstream(name, std::ios::in | std::ios::binary);
@ -130,6 +178,16 @@ generate_dataset(int64_t nb, int64_t dim, const float* xb, const int64_t* ids) {
return ret_ds;
}
// Wrap raw binary vectors and their ids into a knowhere Dataset (the dataset
// does not take ownership of xb/ids).
knowhere::DatasetPtr
generate_binary_dataset(int64_t nb, int64_t dim, const uint8_t* xb, const int64_t* ids) {
    auto dataset = std::make_shared<knowhere::Dataset>();
    dataset->Set(knowhere::meta::ROWS, nb);
    dataset->Set(knowhere::meta::DIM, dim);
    dataset->Set(knowhere::meta::TENSOR, xb);
    dataset->Set(knowhere::meta::IDS, ids);
    return dataset;
}
knowhere::DatasetPtr
generate_query_dataset(int64_t nb, int64_t dim, const float* xb) {
auto ret_ds = std::make_shared<knowhere::Dataset>();
@ -139,6 +197,15 @@ generate_query_dataset(int64_t nb, int64_t dim, const float* xb) {
return ret_ds;
}
// Wrap raw binary query vectors into a knowhere Dataset (no ids; the dataset
// does not take ownership of xb).
knowhere::DatasetPtr
generate_binary_query_dataset(int64_t nb, int64_t dim, const uint8_t* xb) {
    auto dataset = std::make_shared<knowhere::Dataset>();
    dataset->Set(knowhere::meta::ROWS, nb);
    dataset->Set(knowhere::meta::DIM, dim);
    dataset->Set(knowhere::meta::TENSOR, xb);
    return dataset;
}
void
AssertAnns(const knowhere::DatasetPtr& result, const int& nq, const int& k) {
auto ids = result->Get<int64_t*>(knowhere::meta::IDS);

View File

@ -49,6 +49,29 @@ class DataGen {
knowhere::DatasetPtr query_dataset = nullptr;
};
// Test mix-in that owns a randomly generated binary dataset (base + query
// codes, 8 dimensions packed per byte).
class BinaryDataGen {
 protected:
    void
    Init_with_binary_default();

    // Regenerates base/query data for the given logical (bit) dimension.
    void
    Generate(const int& dim, const int& nb, const int& nq);

    knowhere::DatasetPtr
    GenQuery(const int& nq);

 protected:
    int nb = 10000;  // number of base vectors
    int nq = 10;     // number of query vectors
    int dim = 512;   // logical dimension in bits (dim / 8 bytes per code)
    int k = 10;      // default topk used by the tests
    std::vector<uint8_t> xb;   // base codes, nb * dim/8 bytes
    std::vector<uint8_t> xq;   // query codes, nq * dim/8 bytes
    std::vector<int64_t> ids;  // base vector ids (0..nb-1)
    knowhere::DatasetPtr base_dataset = nullptr;
    knowhere::DatasetPtr query_dataset = nullptr;
};
extern void
GenAll(const int64_t dim, const int64_t& nb, std::vector<float>& xb, std::vector<int64_t>& ids, const int64_t& nq,
std::vector<float>& xq);
@ -56,18 +79,34 @@ GenAll(const int64_t dim, const int64_t& nb, std::vector<float>& xb, std::vector
extern void
GenAll(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids, const int64_t& nq, float* xq);
extern void
GenBinaryAll(const int64_t dim, const int64_t& nb, std::vector<uint8_t>& xb, std::vector<int64_t>& ids,
const int64_t& nq, std::vector<uint8_t>& xq);
extern void
GenBinaryAll(const int64_t& dim, const int64_t& nb, uint8_t* xb, int64_t* ids, const int64_t& nq, uint8_t* xq);
extern void
GenBase(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids);
extern void
GenBinaryBase(const int64_t& dim, const int64_t& nb, uint8_t* xb, int64_t* ids);
extern void
InitLog();
knowhere::DatasetPtr
generate_dataset(int64_t nb, int64_t dim, const float* xb, const int64_t* ids);
knowhere::DatasetPtr
generate_binary_dataset(int64_t nb, int64_t dim, const uint8_t* xb, const int64_t* ids);
knowhere::DatasetPtr
generate_query_dataset(int64_t nb, int64_t dim, const float* xb);
knowhere::DatasetPtr
generate_binary_query_dataset(int64_t nb, int64_t dim, const uint8_t* xb);
void
AssertAnns(const knowhere::DatasetPtr& result, const int& nq, const int& k);

View File

@ -22,9 +22,9 @@
namespace milvus {
namespace scheduler {
SearchJob::SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, uint64_t nq, uint64_t nprobe,
const float* vectors)
: Job(JobType::SEARCH), context_(context), topk_(topk), nq_(nq), nprobe_(nprobe), vectors_(vectors) {
SearchJob::SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, uint64_t nprobe,
const engine::VectorsData& vectors)
: Job(JobType::SEARCH), context_(context), topk_(topk), nprobe_(nprobe), vectors_(vectors) {
}
bool
@ -77,7 +77,7 @@ json
SearchJob::Dump() const {
json ret{
{"topk", topk_},
{"nq", nq_},
{"nq", vectors_.vector_count_},
{"nprobe", nprobe_},
};
auto base = Job::Dump();

View File

@ -46,8 +46,8 @@ using ResultDistances = engine::ResultDistances;
class SearchJob : public Job {
public:
SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, uint64_t nq, uint64_t nprobe,
const float* vectors);
SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, uint64_t nprobe,
const engine::VectorsData& vectors);
public:
bool
@ -82,7 +82,7 @@ class SearchJob : public Job {
uint64_t
nq() const {
return nq_;
return vectors_.vector_count_;
}
uint64_t
@ -90,7 +90,7 @@ class SearchJob : public Job {
return nprobe_;
}
const float*
const engine::VectorsData&
vectors() const {
return vectors_;
}
@ -109,10 +109,9 @@ class SearchJob : public Job {
const std::shared_ptr<server::Context> context_;
uint64_t topk_ = 0;
uint64_t nq_ = 0;
uint64_t nprobe_ = 0;
// TODO: smart pointer
const float* vectors_ = nullptr;
const engine::VectorsData& vectors_;
Id2IndexMap index_files_;
// TODO: column-base better ?

View File

@ -103,8 +103,10 @@ CollectFileMetrics(int file_type, size_t file_size) {
XSearchTask::XSearchTask(const std::shared_ptr<server::Context>& context, TableFileSchemaPtr file, TaskLabelPtr label)
: Task(TaskType::SearchTask, std::move(label)), context_(context), file_(file) {
if (file_) {
if (file_->metric_type_ != static_cast<int>(MetricType::L2)) {
metric_l2 = false;
// distance -- value 0 means two vectors equal, ascending reduce, L2/HAMMING/JACCARD/TANIMOTO ...
// similarity -- infinity value means two vectors equal, descending reduce, IP
if (file_->metric_type_ == static_cast<int>(MetricType::IP)) {
ascending_reduce = false;
}
index_engine_ = EngineFactory::Build(file_->dimension_, file_->location_, (EngineType)file_->engine_type_,
(MetricType)file_->metric_type_, file_->nlist_);
@ -207,7 +209,7 @@ XSearchTask::Execute() {
uint64_t nq = search_job->nq();
uint64_t topk = search_job->topk();
uint64_t nprobe = search_job->nprobe();
const float* vectors = search_job->vectors();
const engine::VectorsData& vectors = search_job->vectors();
output_ids.resize(topk * nq);
output_distance.resize(topk * nq);
@ -221,8 +223,14 @@ XSearchTask::Execute() {
ResMgrInst::GetInstance()->GetResource(path().Last())->type() == ResourceType::CPU) {
hybrid = true;
}
Status s =
index_engine_->Search(nq, vectors, topk, nprobe, output_distance.data(), output_ids.data(), hybrid);
Status s;
if (!vectors.float_data_.empty()) {
s = index_engine_->Search(nq, vectors.float_data_.data(), topk, nprobe, output_distance.data(),
output_ids.data(), hybrid);
} else if (!vectors.binary_data_.empty()) {
s = index_engine_->Search(nq, vectors.binary_data_.data(), topk, nprobe, output_distance.data(),
output_ids.data(), hybrid);
}
if (!s.ok()) {
search_job->GetStatus() = s;
search_job->SearchDone(index_id_);
@ -236,7 +244,7 @@ XSearchTask::Execute() {
auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk;
{
std::unique_lock<std::mutex> lock(search_job->mutex());
XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2,
XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, ascending_reduce,
search_job->GetResultIds(), search_job->GetResultDistances());
}

View File

@ -57,7 +57,10 @@ class XSearchTask : public Task {
size_t index_id_ = 0;
int index_type_ = 0;
ExecutionEnginePtr index_engine_ = nullptr;
bool metric_l2 = true;
// distance -- value 0 means two vectors equal, ascending reduce, L2/HAMMING/JACCARD/TANIMOTO ...
// similarity -- infinity value means two vectors equal, descending reduce, IP
bool ascending_reduce = true;
};
} // namespace scheduler

View File

@ -75,11 +75,9 @@ RequestHandler::CreateIndex(const std::shared_ptr<Context>& context, const std::
}
Status
RequestHandler::Insert(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
std::vector<float>& data_list, const std::string& partition_tag,
std::vector<int64_t>& id_array) {
BaseRequestPtr request_ptr =
InsertRequest::Create(context, table_name, record_size, data_list, partition_tag, id_array);
RequestHandler::Insert(const std::shared_ptr<Context>& context, const std::string& table_name,
engine::VectorsData& vectors, const std::string& partition_tag) {
BaseRequestPtr request_ptr = InsertRequest::Create(context, table_name, vectors, partition_tag);
RequestScheduler::ExecRequest(request_ptr);
return request_ptr->status();
@ -94,13 +92,13 @@ RequestHandler::ShowTables(const std::shared_ptr<Context>& context, std::vector<
}
Status
RequestHandler::Search(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
const std::vector<float>& data_list,
RequestHandler::Search(const std::shared_ptr<Context>& context, const std::string& table_name,
const engine::VectorsData& vectors,
const std::vector<std::pair<std::string, std::string>>& range_list, int64_t topk, int64_t nprobe,
const std::vector<std::string>& partition_list, const std::vector<std::string>& file_id_list,
TopKQueryResult& result) {
BaseRequestPtr request_ptr = SearchRequest::Create(context, table_name, record_size, data_list, range_list, topk,
nprobe, partition_list, file_id_list, result);
BaseRequestPtr request_ptr = SearchRequest::Create(context, table_name, vectors, range_list, topk, nprobe,
partition_list, file_id_list, result);
RequestScheduler::ExecRequest(request_ptr);
return request_ptr->status();

View File

@ -47,15 +47,15 @@ class RequestHandler {
int64_t nlist);
Status
Insert(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
std::vector<float>& data_list, const std::string& partition_tag, std::vector<int64_t>& id_array);
Insert(const std::shared_ptr<Context>& context, const std::string& table_name, engine::VectorsData& vectors,
const std::string& partition_tag);
Status
ShowTables(const std::shared_ptr<Context>& context, std::vector<std::string>& tables);
Status
Search(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
const std::vector<float>& data_list, const std::vector<Range>& range_list, int64_t topk, int64_t nprobe,
Search(const std::shared_ptr<Context>& context, const std::string& table_name, const engine::VectorsData& vectors,
const std::vector<Range>& range_list, int64_t topk, int64_t nprobe,
const std::vector<std::string>& partition_list, const std::vector<std::string>& file_id_list,
TopKQueryResult& result);

View File

@ -71,23 +71,36 @@ CreateIndexRequest::OnExecute() {
return status;
}
// step 2: binary and float vector support different index/metric type, need to adapt here
engine::meta::TableSchema table_info;
table_info.table_id_ = table_name_;
status = DBWrapper::DB()->DescribeTable(table_info);
int32_t adapter_index_type = index_type_;
if (ValidationUtil::IsBinaryMetricType(table_info.metric_type_)) { // binary vector not allow
if (adapter_index_type == static_cast<int32_t>(engine::EngineType::FAISS_IDMAP)) {
adapter_index_type = static_cast<int32_t>(engine::EngineType::FAISS_BIN_IDMAP);
} else if (adapter_index_type == static_cast<int32_t>(engine::EngineType::FAISS_IVFFLAT)) {
adapter_index_type = static_cast<int32_t>(engine::EngineType::FAISS_BIN_IVFFLAT);
} else {
return Status(SERVER_INVALID_INDEX_TYPE, "Invalid index type for table metric type");
}
}
#ifdef MILVUS_GPU_VERSION
Status s;
bool enable_gpu = false;
server::Config& config = server::Config::GetInstance();
s = config.GetGpuResourceConfigEnable(enable_gpu);
engine::meta::TableSchema table_info;
table_info.table_id_ = table_name_;
status = DBWrapper::DB()->DescribeTable(table_info);
if (s.ok() && index_type_ == (int)engine::EngineType::FAISS_PQ &&
if (s.ok() && adapter_index_type == (int)engine::EngineType::FAISS_PQ &&
table_info.metric_type_ == (int)engine::MetricType::IP) {
return Status(SERVER_UNEXPECTED_ERROR, "PQ not support IP in GPU version!");
}
#endif
// step 2: check table existence
// step 3: create index
engine::TableIndex index;
index.engine_type_ = index_type_;
index.engine_type_ = adapter_index_type;
index.nlist_ = nlist_;
status = DBWrapper::DB()->CreateIndex(table_name_, index);
if (!status.ok()) {

View File

@ -78,6 +78,15 @@ CreateTableRequest::OnExecute() {
table_info.index_file_size_ = index_file_size_;
table_info.metric_type_ = metric_type_;
// some metric type only support binary vector, adapt the index type
if (ValidationUtil::IsBinaryMetricType(metric_type_)) {
if (table_info.engine_type_ == static_cast<int32_t>(engine::EngineType::FAISS_IDMAP)) {
table_info.engine_type_ = static_cast<int32_t>(engine::EngineType::FAISS_BIN_IDMAP);
} else if (table_info.engine_type_ == static_cast<int32_t>(engine::EngineType::FAISS_IVFFLAT)) {
table_info.engine_type_ = static_cast<int32_t>(engine::EngineType::FAISS_BIN_IVFFLAT);
}
}
// step 3: create table
status = DBWrapper::DB()->CreateTable(table_info);
if (!status.ok()) {

View File

@ -56,6 +56,14 @@ DescribeIndexRequest::OnExecute() {
return status;
}
// for binary vector, IDMAP and IVFLAT will be treated as BIN_IDMAP and BIN_IVFLAT internally
// return IDMAP and IVFLAT for outside caller
if (index.engine_type_ == (int32_t)engine::EngineType::FAISS_BIN_IDMAP) {
index.engine_type_ = (int32_t)engine::EngineType::FAISS_IDMAP;
} else if (index.engine_type_ == (int32_t)engine::EngineType::FAISS_BIN_IVFFLAT) {
index.engine_type_ = (int32_t)engine::EngineType::FAISS_IVFFLAT;
}
index_param_.table_name_ = table_name_;
index_param_.index_type_ = index.engine_type_;
index_param_.nlist_ = index.nlist_;

View File

@ -29,27 +29,24 @@ namespace milvus {
namespace server {
InsertRequest::InsertRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
int64_t record_size, std::vector<float>& data_list, const std::string& partition_tag,
std::vector<int64_t>& id_array)
engine::VectorsData& vectors, const std::string& partition_tag)
: BaseRequest(context, DDL_DML_REQUEST_GROUP),
table_name_(table_name),
record_size_(record_size),
data_list_(data_list),
partition_tag_(partition_tag),
id_array_(id_array) {
vectors_data_(vectors),
partition_tag_(partition_tag) {
}
BaseRequestPtr
InsertRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
std::vector<float>& data_list, const std::string& partition_tag, std::vector<int64_t>& id_array) {
return std::shared_ptr<BaseRequest>(
new InsertRequest(context, table_name, record_size, data_list, partition_tag, id_array));
InsertRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name,
engine::VectorsData& vectors, const std::string& partition_tag) {
return std::shared_ptr<BaseRequest>(new InsertRequest(context, table_name, vectors, partition_tag));
}
Status
InsertRequest::OnExecute() {
try {
std::string hdr = "InsertRequest(table=" + table_name_ + ", n=" + std::to_string(record_size_) +
int64_t vector_count = vectors_data_.vector_count_;
std::string hdr = "InsertRequest(table=" + table_name_ + ", n=" + std::to_string(vector_count) +
", partition_tag=" + partition_tag_ + ")";
TimeRecorder rc(hdr);
@ -58,13 +55,13 @@ InsertRequest::OnExecute() {
if (!status.ok()) {
return status;
}
if (data_list_.empty()) {
if (vectors_data_.float_data_.empty() && vectors_data_.binary_data_.empty()) {
return Status(SERVER_INVALID_ROWRECORD_ARRAY,
"The vector array is empty. Make sure you have entered vector records.");
}
if (!id_array_.empty()) {
if (id_array_.size() != record_size_) {
if (!vectors_data_.id_array_.empty()) {
if (vectors_data_.id_array_.size() != vector_count) {
return Status(SERVER_ILLEGAL_VECTOR_ID,
"The size of vector ID array must be equal to the size of the vector.");
}
@ -84,7 +81,7 @@ InsertRequest::OnExecute() {
// step 3: check table flag
// all user provide id, or all internal id
bool user_provide_ids = !id_array_.empty();
bool user_provide_ids = !vectors_data_.id_array_.empty();
// user already provided id before, all insert action require user id
if ((table_info.flag_ & engine::meta::FLAG_MASK_HAS_USERID) != 0 && !user_provide_ids) {
return Status(SERVER_ILLEGAL_VECTOR_ID,
@ -105,27 +102,49 @@ InsertRequest::OnExecute() {
"/tmp/insert_" + std::to_string(this->insert_param_->row_record_array_size()) + ".profiling";
ProfilerStart(fname.c_str());
#endif
// step 4: some metric type doesn't support float vectors
if (!vectors_data_.float_data_.empty()) { // insert float vectors
if (ValidationUtil::IsBinaryMetricType(table_info.metric_type_)) {
return Status(SERVER_INVALID_ROWRECORD_ARRAY, "Table metric type doesn't support float vectors.");
}
// step 4: check prepared float data
if (data_list_.size() % record_size_ != 0) {
return Status(SERVER_INVALID_ROWRECORD_ARRAY, "The vector dimension must be equal to the table dimension.");
}
// check prepared float data
if (vectors_data_.float_data_.size() % vector_count != 0) {
return Status(SERVER_INVALID_ROWRECORD_ARRAY,
"The vector dimension must be equal to the table dimension.");
}
if (data_list_.size() / record_size_ != table_info.dimension_) {
return Status(SERVER_INVALID_VECTOR_DIMENSION,
"The vector dimension must be equal to the table dimension.");
if (vectors_data_.float_data_.size() / vector_count != table_info.dimension_) {
return Status(SERVER_INVALID_VECTOR_DIMENSION,
"The vector dimension must be equal to the table dimension.");
}
} else if (!vectors_data_.binary_data_.empty()) { // insert binary vectors
if (!ValidationUtil::IsBinaryMetricType(table_info.metric_type_)) {
return Status(SERVER_INVALID_ROWRECORD_ARRAY, "Table metric type doesn't support binary vectors.");
}
// check prepared binary data
if (vectors_data_.binary_data_.size() % vector_count != 0) {
return Status(SERVER_INVALID_ROWRECORD_ARRAY,
"The vector dimension must be equal to the table dimension.");
}
if (vectors_data_.binary_data_.size() * 8 / vector_count != table_info.dimension_) {
return Status(SERVER_INVALID_VECTOR_DIMENSION,
"The vector dimension must be equal to the table dimension.");
}
}
// step 5: insert vectors
auto vec_count = static_cast<uint64_t>(record_size_);
auto vec_count = static_cast<uint64_t>(vector_count);
rc.RecordSection("prepare vectors data");
status = DBWrapper::DB()->InsertVectors(table_name_, partition_tag_, vec_count, data_list_.data(), id_array_);
status = DBWrapper::DB()->InsertVectors(table_name_, partition_tag_, vectors_data_);
if (!status.ok()) {
return status;
}
auto ids_size = id_array_.size();
auto ids_size = vectors_data_.id_array_.size();
if (ids_size != vec_count) {
std::string msg =
"Add " + std::to_string(vec_count) + " vectors but only return " + std::to_string(ids_size) + " id";

View File

@ -29,23 +29,20 @@ namespace server {
class InsertRequest : public BaseRequest {
public:
static BaseRequestPtr
Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
std::vector<float>& data_list, const std::string& partition_tag, std::vector<int64_t>& id_array);
Create(const std::shared_ptr<Context>& context, const std::string& table_name, engine::VectorsData& vectors,
const std::string& partition_tag);
protected:
InsertRequest(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
std::vector<float>& data_list, const std::string& partition_tag, std::vector<int64_t>& id_array);
InsertRequest(const std::shared_ptr<Context>& context, const std::string& table_name, engine::VectorsData& vectors,
const std::string& partition_tag);
Status
OnExecute() override;
private:
const std::string table_name_;
int64_t record_size_;
const std::vector<float>& data_list_;
engine::VectorsData& vectors_data_;
const std::string partition_tag_;
std::vector<int64_t>& id_array_;
};
} // namespace server

View File

@ -23,20 +23,16 @@
#include <memory>
#
namespace milvus {
namespace server {
SearchRequest::SearchRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
int64_t record_size, const std::vector<float>& data_list,
const std::vector<Range>& range_list, int64_t topk, int64_t nprobe,
const std::vector<std::string>& partition_list,
const engine::VectorsData& vectors, const std::vector<Range>& range_list, int64_t topk,
int64_t nprobe, const std::vector<std::string>& partition_list,
const std::vector<std::string>& file_id_list, TopKQueryResult& result)
: BaseRequest(context, DQL_REQUEST_GROUP),
table_name_(table_name),
record_size_(record_size),
data_list_(data_list),
vectors_data_(vectors),
range_list_(range_list),
topk_(topk),
nprobe_(nprobe),
@ -46,20 +42,21 @@ SearchRequest::SearchRequest(const std::shared_ptr<Context>& context, const std:
}
BaseRequestPtr
SearchRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
const std::vector<float>& data_list, const std::vector<Range>& range_list, int64_t topk,
SearchRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name,
const engine::VectorsData& vectors, const std::vector<Range>& range_list, int64_t topk,
int64_t nprobe, const std::vector<std::string>& partition_list,
const std::vector<std::string>& file_id_list, TopKQueryResult& result) {
return std::shared_ptr<BaseRequest>(new SearchRequest(context, table_name, record_size, data_list, range_list, topk,
nprobe, partition_list, file_id_list, result));
return std::shared_ptr<BaseRequest>(new SearchRequest(context, table_name, vectors, range_list, topk, nprobe,
partition_list, file_id_list, result));
}
Status
SearchRequest::OnExecute() {
try {
uint64_t vector_count = vectors_data_.vector_count_;
auto pre_query_ctx = context_->Child("Pre query");
std::string hdr = "SearchRequest(table=" + table_name_ + ", nq=" + std::to_string(record_size_) +
std::string hdr = "SearchRequest(table=" + table_name_ + ", nq=" + std::to_string(vector_count) +
", k=" + std::to_string(topk_) + ", nprob=" + std::to_string(nprobe_) + ")";
TimeRecorder rc(hdr);
@ -93,7 +90,7 @@ SearchRequest::OnExecute() {
return status;
}
if (data_list_.empty()) {
if (vectors_data_.float_data_.empty() && vectors_data_.binary_data_.empty()) {
return Status(SERVER_INVALID_ROWRECORD_ARRAY,
"The vector array is empty. Make sure you have entered vector records.");
}
@ -107,14 +104,28 @@ SearchRequest::OnExecute() {
rc.RecordSection("check validation");
// step 5: check prepared float data
if (data_list_.size() % record_size_ != 0) {
return Status(SERVER_INVALID_ROWRECORD_ARRAY, "The vector dimension must be equal to the table dimension.");
}
if (ValidationUtil::IsBinaryMetricType(table_info.metric_type_)) {
// check prepared binary data
if (vectors_data_.binary_data_.size() % vector_count != 0) {
return Status(SERVER_INVALID_ROWRECORD_ARRAY,
"The vector dimension must be equal to the table dimension.");
}
if (data_list_.size() / record_size_ != table_info.dimension_) {
return Status(SERVER_INVALID_VECTOR_DIMENSION,
"The vector dimension must be equal to the table dimension.");
if (vectors_data_.binary_data_.size() * 8 / vector_count != table_info.dimension_) {
return Status(SERVER_INVALID_VECTOR_DIMENSION,
"The vector dimension must be equal to the table dimension.");
}
} else {
// check prepared float data
if (vectors_data_.float_data_.size() % vector_count != 0) {
return Status(SERVER_INVALID_ROWRECORD_ARRAY,
"The vector dimension must be equal to the table dimension.");
}
if (vectors_data_.float_data_.size() / vector_count != table_info.dimension_) {
return Status(SERVER_INVALID_VECTOR_DIMENSION,
"The vector dimension must be equal to the table dimension.");
}
}
rc.RecordSection("prepare vector data");
@ -122,7 +133,6 @@ SearchRequest::OnExecute() {
// step 6: search vectors
engine::ResultIds result_ids;
engine::ResultDistances result_distances;
auto record_count = static_cast<uint64_t>(record_size_);
#ifdef MILVUS_ENABLE_PROFILING
std::string fname =
@ -138,11 +148,11 @@ SearchRequest::OnExecute() {
return status;
}
status = DBWrapper::DB()->Query(context_, table_name_, partition_list_, (size_t)topk_, record_count,
nprobe_, data_list_.data(), dates, result_ids, result_distances);
status = DBWrapper::DB()->Query(context_, table_name_, partition_list_, (size_t)topk_, nprobe_,
vectors_data_, dates, result_ids, result_distances);
} else {
status = DBWrapper::DB()->QueryByFileID(context_, table_name_, file_id_list_, (size_t)topk_, record_count,
nprobe_, data_list_.data(), dates, result_ids, result_distances);
status = DBWrapper::DB()->QueryByFileID(context_, table_name_, file_id_list_, (size_t)topk_, nprobe_,
vectors_data_, dates, result_ids, result_distances);
}
#ifdef MILVUS_ENABLE_PROFILING
@ -161,7 +171,7 @@ SearchRequest::OnExecute() {
auto post_query_ctx = context_->Child("Constructing result");
// step 7: construct result array
result_.row_num_ = record_count;
result_.row_num_ = vector_count;
result_.distance_list_ = result_distances;
result_.id_list_ = result_ids;

View File

@ -29,14 +29,14 @@ namespace server {
class SearchRequest : public BaseRequest {
public:
static BaseRequestPtr
Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
const std::vector<float>& data_list, const std::vector<Range>& range_list, int64_t topk, int64_t nprobe,
Create(const std::shared_ptr<Context>& context, const std::string& table_name, const engine::VectorsData& vectors,
const std::vector<Range>& range_list, int64_t topk, int64_t nprobe,
const std::vector<std::string>& partition_list, const std::vector<std::string>& file_id_list,
TopKQueryResult& result);
protected:
SearchRequest(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t record_size,
const std::vector<float>& data_list, const std::vector<Range>& range_list, int64_t topk,
SearchRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
const engine::VectorsData& vectors, const std::vector<Range>& range_list, int64_t topk,
int64_t nprobe, const std::vector<std::string>& partition_list,
const std::vector<std::string>& file_id_list, TopKQueryResult& result);
@ -45,8 +45,7 @@ class SearchRequest : public BaseRequest {
private:
const std::string table_name_;
int64_t record_size_;
const std::vector<float>& data_list_;
const engine::VectorsData& vectors_data_;
const std::vector<Range> range_list_;
int64_t topk_;
int64_t nprobe_;

View File

@ -72,6 +72,49 @@ ErrorMap(ErrorCode code) {
}
}
namespace {
// Flatten gRPC row records (and an optional caller-supplied id list) into an
// engine::VectorsData. Float and binary payloads are copied into single
// contiguous arrays; a request is expected to carry one kind or the other.
void
CopyRowRecords(const google::protobuf::RepeatedPtrField<::milvus::grpc::RowRecord>& grpc_records,
               const google::protobuf::RepeatedField<google::protobuf::int64>& grpc_id_array,
               engine::VectorsData& vectors) {
    // step 1: size both payloads up front so each is allocated exactly once
    int64_t total_float = 0;
    int64_t total_binary = 0;
    for (const auto& record : grpc_records) {
        total_float += record.float_data_size();
        total_binary += record.binary_data().size();
    }

    std::vector<float> float_array(total_float, 0.0f);
    std::vector<uint8_t> binary_array(total_binary, 0);
    int64_t float_offset = 0;
    int64_t binary_offset = 0;
    if (total_float > 0) {
        // concatenate all float vectors back-to-back
        for (const auto& record : grpc_records) {
            memcpy(&float_array[float_offset], record.float_data().data(),
                   record.float_data_size() * sizeof(float));
            float_offset += record.float_data_size();
        }
    } else if (total_binary > 0) {
        // concatenate all binary vectors back-to-back
        for (const auto& record : grpc_records) {
            memcpy(&binary_array[binary_offset], record.binary_data().data(), record.binary_data().size());
            binary_offset += record.binary_data().size();
        }
    }

    // step 2: copy explicit vector ids, if the caller provided any
    std::vector<int64_t> id_array;
    if (grpc_id_array.size() > 0) {
        id_array.resize(grpc_id_array.size());
        memcpy(id_array.data(), grpc_id_array.data(), grpc_id_array.size() * sizeof(int64_t));
    }

    // step 3: hand the buffers to the output structure without re-copying
    vectors.vector_count_ = grpc_records.size();
    vectors.float_data_.swap(float_array);
    vectors.binary_data_.swap(binary_array);
    vectors.id_array_.swap(id_array);
}
} // namespace
GrpcRequestHandler::GrpcRequestHandler(const std::shared_ptr<opentracing::Tracer>& tracer)
: tracer_(tracer), random_num_generator_() {
std::random_device random_device;
@ -206,30 +249,18 @@ GrpcRequestHandler::Insert(::grpc::ServerContext* context, const ::milvus::grpc:
::milvus::grpc::VectorIds* response) {
CHECK_NULLPTR_RETURN(request);
int64_t record_data_size = 0;
for (auto& record : request->row_record_array()) {
record_data_size += record.vector_data_size();
}
std::vector<float> record_array(record_data_size, 0.0f);
int64_t offset = 0;
for (auto& record : request->row_record_array()) {
memcpy(&record_array[offset], record.vector_data().data(), record.vector_data_size() * sizeof(float));
offset += record.vector_data_size();
}
std::vector<int64_t> id_array;
if (request->row_id_array_size() > 0) {
id_array.resize(request->row_id_array_size());
memcpy(id_array.data(), request->row_id_array().data(), request->row_id_array_size() * sizeof(int64_t));
}
// step 1: copy vector data
engine::VectorsData vectors;
CopyRowRecords(request->row_record_array(), request->row_id_array(), vectors);
// step 2: insert vectors
Status status =
request_handler_.Insert(context_map_[context], request->table_name(), request->row_record_array_size(),
record_array, request->partition_tag(), id_array);
request_handler_.Insert(context_map_[context], request->table_name(), vectors, request->partition_tag());
response->mutable_vector_id_array()->Resize(static_cast<int>(id_array.size()), 0);
memcpy(response->mutable_vector_id_array()->mutable_data(), id_array.data(), id_array.size() * sizeof(int64_t));
// step 3: return id array
response->mutable_vector_id_array()->Resize(static_cast<int>(vectors.id_array_.size()), 0);
memcpy(response->mutable_vector_id_array()->mutable_data(), vectors.id_array_.data(),
vectors.id_array_.size() * sizeof(int64_t));
SET_RESPONSE(response->mutable_status(), status, context);
return ::grpc::Status::OK;
@ -240,36 +271,29 @@ GrpcRequestHandler::Search(::grpc::ServerContext* context, const ::milvus::grpc:
::milvus::grpc::TopKQueryResult* response) {
CHECK_NULLPTR_RETURN(request);
int64_t record_data_size = 0;
for (auto& record : request->query_record_array()) {
record_data_size += record.vector_data_size();
}
std::vector<float> record_array(record_data_size);
int64_t offset = 0;
for (auto& record : request->query_record_array()) {
memcpy(&record_array[offset], record.vector_data().data(), record.vector_data_size() * sizeof(float));
offset += record.vector_data_size();
}
// step 1: copy vector data
engine::VectorsData vectors;
CopyRowRecords(request->query_record_array(), google::protobuf::RepeatedField<google::protobuf::int64>(), vectors);
// deprecated
std::vector<Range> ranges;
for (auto& range : request->query_range_array()) {
ranges.emplace_back(range.start_value(), range.end_value());
}
// step 2: partition tags
std::vector<std::string> partitions;
for (auto& partition : request->partition_tag_array()) {
partitions.emplace_back(partition);
}
// step 3: search vectors
std::vector<std::string> file_ids;
TopKQueryResult result;
Status status = request_handler_.Search(context_map_[context], request->table_name(), vectors, ranges,
request->topk(), request->nprobe(), partitions, file_ids, result);
Status status =
request_handler_.Search(context_map_[context], request->table_name(), request->query_record_array_size(),
record_array, ranges, request->topk(), request->nprobe(), partitions, file_ids, result);
// construct result
// step 4: construct and return result
response->set_row_num(result.row_num_);
response->mutable_ids()->Resize(static_cast<int>(result.id_list_.size()), 0);
@ -289,42 +313,38 @@ GrpcRequestHandler::SearchInFiles(::grpc::ServerContext* context, const ::milvus
::milvus::grpc::TopKQueryResult* response) {
CHECK_NULLPTR_RETURN(request);
std::vector<std::string> file_ids;
for (auto& file_id : request->file_id_array()) {
file_ids.emplace_back(file_id);
}
auto* search_request = &request->search_param();
int64_t record_data_size = 0;
for (auto& record : search_request->query_record_array()) {
record_data_size += record.vector_data_size();
}
std::vector<float> record_array(record_data_size);
int64_t offset = 0;
for (auto& record : search_request->query_record_array()) {
memcpy(&record_array[offset], record.vector_data().data(), record.vector_data_size() * sizeof(float));
offset += record.vector_data_size();
}
// step 1: copy vector data
engine::VectorsData vectors;
CopyRowRecords(search_request->query_record_array(), google::protobuf::RepeatedField<google::protobuf::int64>(),
vectors);
// deprecated
std::vector<Range> ranges;
for (auto& range : search_request->query_range_array()) {
ranges.emplace_back(range.start_value(), range.end_value());
}
// step 2: copy file id array
std::vector<std::string> file_ids;
for (auto& file_id : request->file_id_array()) {
file_ids.emplace_back(file_id);
}
// step 3: partition tags
std::vector<std::string> partitions;
for (auto& partition : search_request->partition_tag_array()) {
partitions.emplace_back(partition);
}
// step 4: search vectors
TopKQueryResult result;
Status status =
request_handler_.Search(context_map_[context], search_request->table_name(), vectors, ranges,
search_request->topk(), search_request->nprobe(), partitions, file_ids, result);
Status status = request_handler_.Search(
context_map_[context], search_request->table_name(), search_request->query_record_array_size(), record_array,
ranges, search_request->topk(), search_request->nprobe(), partitions, file_ids, result);
// construct result
// step 5: construct and return result
response->set_row_num(result.row_num_);
response->mutable_ids()->Resize(static_cast<int>(result.id_list_.size()), 0);

View File

@ -79,6 +79,78 @@ WebErrorMap(ErrorCode code) {
}
}
namespace {
// Convert a REST insert payload into engine::VectorsData (float vectors plus
// the mandatory id list).
// Returns SERVER_INVALID_ROWRECORD_ARRAY when 'records' is absent and
// SERVER_ILLEGAL_VECTOR_ID when 'ids' is absent.
Status
CopyRowRecords(const InsertRequestDto::ObjectWrapper& param, engine::VectorsData& vectors) {
    vectors.float_data_.clear();
    vectors.binary_data_.clear();
    vectors.id_array_.clear();
    vectors.vector_count_ = 0;

    // step 1: copy vector data
    // Fix: validate 'records' BEFORE dereferencing it. The original read
    // param->records->count() first, which dereferences a null wrapper when
    // the field is missing from the request body.
    if (nullptr == param->records.get()) {
        return Status(SERVER_INVALID_ROWRECORD_ARRAY, "");
    }
    vectors.vector_count_ = param->records->count();

    size_t total_size = 0;
    for (int64_t i = 0; i < param->records->count(); i++) {
        total_size += param->records->get(i)->count();
    }

    std::vector<float>& datas = vectors.float_data_;
    datas.resize(total_size);
    size_t index_offset = 0;
    param->records->forEach([&datas, &index_offset](const OList<OFloat32>::ObjectWrapper& row_item) {
        row_item->forEach([&datas, &index_offset](const OFloat32& item) {
            datas[index_offset] = item->getValue();
            index_offset++;
        });
    });

    // step 2: copy id array (required for this endpoint)
    if (nullptr == param->ids.get()) {
        return Status(SERVER_ILLEGAL_VECTOR_ID, "");
    }
    for (int64_t i = 0; i < param->ids->count(); i++) {
        vectors.id_array_.emplace_back(param->ids->get(i)->getValue());
    }

    return Status::OK();
}
// Convert a REST search payload into engine::VectorsData (float query vectors
// only; search requests carry no ids).
// Returns SERVER_INVALID_ROWRECORD_ARRAY when 'records' is absent.
Status
CopyRowRecords(const SearchRequestDto::ObjectWrapper& param, engine::VectorsData& vectors) {
    vectors.float_data_.clear();
    vectors.binary_data_.clear();
    vectors.id_array_.clear();
    vectors.vector_count_ = 0;

    // step 1: copy vector data
    // Fix: validate 'records' BEFORE dereferencing it. The original read
    // param->records->count() first, which dereferences a null wrapper when
    // the field is missing from the request body.
    if (nullptr == param->records.get()) {
        return Status(SERVER_INVALID_ROWRECORD_ARRAY, "");
    }
    vectors.vector_count_ = param->records->count();

    size_t total_size = 0;
    for (int64_t i = 0; i < param->records->count(); i++) {
        total_size += param->records->get(i)->count();
    }

    std::vector<float>& datas = vectors.float_data_;
    datas.resize(total_size);
    size_t index_offset = 0;
    param->records->forEach([&datas, &index_offset](const OList<OFloat32>::ObjectWrapper& row_item) {
        row_item->forEach([&datas, &index_offset](const OFloat32& item) {
            datas[index_offset] = item->getValue();
            index_offset++;
        });
    });

    return Status::OK();
}
} // namespace
///////////////////////// WebRequestHandler methods ///////////////////////////////////////
Status
@ -567,37 +639,17 @@ WebRequestHandler::DropPartition(const OString& table_name, const OString& tag)
StatusDto::ObjectWrapper
WebRequestHandler::Insert(const OString& table_name, const InsertRequestDto::ObjectWrapper& param,
VectorIdsDto::ObjectWrapper& ids_dto) {
std::vector<int64_t> ids;
if (nullptr != param->ids.get() && param->ids->count() > 0) {
for (int64_t i = 0; i < param->ids->count(); i++) {
ids.emplace_back(param->ids->get(i)->getValue());
}
}
if (nullptr == param->records.get()) {
engine::VectorsData vectors;
auto status = CopyRowRecords(param, vectors);
if (status.code() == SERVER_INVALID_ROWRECORD_ARRAY) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill vectors")
}
size_t tal_size = 0;
for (int64_t i = 0; i < param->records->count(); i++) {
tal_size += param->records->get(i)->count();
}
std::vector<float> datas(tal_size);
size_t index_offset = 0;
param->records->forEach([&datas, &index_offset](const OList<OFloat32>::ObjectWrapper& row_item) {
row_item->forEach([&datas, &index_offset](const OFloat32& item) {
datas[index_offset] = item->getValue();
index_offset++;
});
});
auto status = request_handler_.Insert(context_ptr_, table_name->std_str(), param->records->count(), datas,
param->tag->std_str(), ids);
status = request_handler_.Insert(context_ptr_, table_name->std_str(), vectors, param->tag->std_str());
if (status.ok()) {
ids_dto->ids = ids_dto->ids->createShared();
for (auto& id : ids) {
for (auto& id : vectors.id_array_) {
ids_dto->ids->pushBack(std::to_string(id).c_str());
}
}
@ -634,25 +686,18 @@ WebRequestHandler::Search(const OString& table_name, const SearchRequestDto::Obj
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill query vectors")
}
size_t tal_size = 0;
search_request->records->forEach(
[&tal_size](const OList<OFloat32>::ObjectWrapper& item) { tal_size += item->count(); });
std::vector<float> datas(tal_size);
size_t index_offset = 0;
search_request->records->forEach([&datas, &index_offset](const OList<OFloat32>::ObjectWrapper& elem) {
elem->forEach([&datas, &index_offset](const OFloat32& item) {
datas[index_offset] = item->getValue();
index_offset++;
});
});
engine::VectorsData vectors;
auto status = CopyRowRecords(search_request, vectors);
if (status.code() == SERVER_INVALID_ROWRECORD_ARRAY) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill vectors")
}
std::vector<Range> range_list;
TopKQueryResult result;
auto context_ptr = GenContextPtr("Web Handler");
auto status = request_handler_.Search(context_ptr, table_name->std_str(), search_request->records->count(), datas,
range_list, topk_t, nprobe_t, tag_list, file_id_list, result);
status = request_handler_.Search(context_ptr, table_name->std_str(), vectors, range_list, topk_t, nprobe_t,
tag_list, file_id_list, result);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}

View File

@ -21,9 +21,13 @@
#include "utils/StringHelpFunctions.h"
#include <arpa/inet.h>
#ifdef MILVUS_GPU_VERSION
#include <cuda_runtime.h>
#endif
#include <algorithm>
#include <cmath>
#include <regex>
@ -33,7 +37,7 @@ namespace milvus {
namespace server {
constexpr size_t TABLE_NAME_SIZE_LIMIT = 255;
constexpr int64_t TABLE_DIMENSION_LIMIT = 16384;
constexpr int64_t TABLE_DIMENSION_LIMIT = 32768;
constexpr int32_t INDEX_FILE_SIZE_LIMIT = 4096; // index trigger size max = 4096 MB
Status
@ -78,7 +82,8 @@ Status
ValidationUtil::ValidateTableDimension(int64_t dimension) {
if (dimension <= 0 || dimension > TABLE_DIMENSION_LIMIT) {
std::string msg = "Invalid table dimension: " + std::to_string(dimension) + ". " +
"The table dimension must be within the range of 1 ~ 16384.";
"The table dimension must be within the range of 1 ~ " +
std::to_string(TABLE_DIMENSION_LIMIT) + ".";
SERVER_LOG_ERROR << msg;
return Status(SERVER_INVALID_VECTOR_DIMENSION, msg);
} else {
@ -108,6 +113,12 @@ ValidationUtil::ValidateTableIndexType(int32_t index_type) {
return Status::OK();
}
// True when the engine type is one of the FAISS binary-vector index types.
bool
ValidationUtil::IsBinaryIndexType(int32_t index_type) {
    const int32_t bin_idmap = static_cast<int32_t>(engine::EngineType::FAISS_BIN_IDMAP);
    const int32_t bin_ivfflat = static_cast<int32_t>(engine::EngineType::FAISS_BIN_IVFFLAT);
    return index_type == bin_idmap || index_type == bin_ivfflat;
}
Status
ValidationUtil::ValidateTableIndexNlist(int32_t nlist) {
if (nlist <= 0) {
@ -135,16 +146,22 @@ ValidationUtil::ValidateTableIndexFileSize(int64_t index_file_size) {
Status
ValidationUtil::ValidateTableIndexMetricType(int32_t metric_type) {
if (metric_type != static_cast<int32_t>(engine::MetricType::L2) &&
metric_type != static_cast<int32_t>(engine::MetricType::IP)) {
if (metric_type <= 0 || metric_type > static_cast<int32_t>(engine::MetricType::MAX_VALUE)) {
std::string msg = "Invalid index metric type: " + std::to_string(metric_type) + ". " +
"Make sure the metric type is either MetricType.L2 or MetricType.IP.";
"Make sure the metric type is in MetricType list.";
SERVER_LOG_ERROR << msg;
return Status(SERVER_INVALID_INDEX_METRIC_TYPE, msg);
}
return Status::OK();
}
bool
ValidationUtil::IsBinaryMetricType(int32_t metric_type) {
return (metric_type == static_cast<int32_t>(engine::MetricType::HAMMING)) ||
(metric_type == static_cast<int32_t>(engine::MetricType::JACCARD)) ||
(metric_type == static_cast<int32_t>(engine::MetricType::TANIMOTO));
}
Status
ValidationUtil::ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchema& table_schema) {
if (top_k <= 0 || top_k > 2048) {
@ -219,6 +236,7 @@ ValidationUtil::ValidateGpuIndex(int32_t gpu_index) {
}
#ifdef MILVUS_GPU_VERSION
Status
ValidationUtil::GetGpuMemory(int32_t gpu_index, size_t& memory) {
cudaDeviceProp deviceProp;
@ -233,6 +251,7 @@ ValidationUtil::GetGpuMemory(int32_t gpu_index, size_t& memory) {
memory = deviceProp.totalGlobalMem;
return Status::OK();
}
#endif
Status

View File

@ -40,6 +40,9 @@ class ValidationUtil {
static Status
ValidateTableIndexType(int32_t index_type);
static bool
IsBinaryIndexType(int32_t index_type);
static Status
ValidateTableIndexNlist(int32_t nlist);
@ -49,6 +52,9 @@ class ValidationUtil {
static Status
ValidateTableIndexMetricType(int32_t metric_type);
static bool
IsBinaryMetricType(int32_t metric_type);
static Status
ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchema& table_schema);

View File

@ -0,0 +1,187 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "wrapper/BinVecImpl.h"
#include "WrapperException.h"
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"
#include "utils/Log.h"
namespace milvus {
namespace engine {
// Build the binary index from nb vectors (xb) tagged with ids.
// nt/xt are accepted for interface parity but not used by this implementation.
Status
BinVecImpl::BuildAll(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
                     const uint8_t* xt) {
    try {
        dim = cfg->d;
        auto build_ds = std::make_shared<knowhere::Dataset>();
        build_ds->Set(knowhere::meta::ROWS, nb);
        build_ds->Set(knowhere::meta::DIM, dim);
        build_ds->Set(knowhere::meta::TENSOR, xb);
        build_ds->Set(knowhere::meta::IDS, ids);
        index_->Train(build_ds, cfg);
    } catch (knowhere::KnowhereException& e) {
        WRAPPER_LOG_ERROR << e.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
    } catch (std::exception& e) {
        WRAPPER_LOG_ERROR << e.what();
        return Status(KNOWHERE_ERROR, e.what());
    }
    return Status::OK();
}
// Run a top-k search for nq binary query vectors (xq); results are copied into
// the caller-provided dist/ids buffers (each sized nq * cfg->k).
Status
BinVecImpl::Search(const int64_t& nq, const uint8_t* xq, float* dist, int64_t* ids, const Config& cfg) {
    try {
        auto topk = cfg->k;
        auto query_ds = std::make_shared<knowhere::Dataset>();
        query_ds->Set(knowhere::meta::ROWS, nq);
        query_ds->Set(knowhere::meta::DIM, dim);
        query_ds->Set(knowhere::meta::TENSOR, xq);

        Config search_cfg = cfg;
        auto res = index_->Search(query_ds, search_cfg);

        // TODO(linxj): avoid copy here.
        auto res_ids = res->Get<int64_t*>(knowhere::meta::IDS);
        auto res_dist = res->Get<float*>(knowhere::meta::DISTANCE);
        memcpy(ids, res_ids, sizeof(int64_t) * nq * topk);
        memcpy(dist, res_dist, sizeof(float) * nq * topk);
        // the raw result buffers are released here (as in the original free() calls)
        free(res_ids);
        free(res_dist);
    } catch (knowhere::KnowhereException& e) {
        WRAPPER_LOG_ERROR << e.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
    } catch (std::exception& e) {
        WRAPPER_LOG_ERROR << e.what();
        return Status(KNOWHERE_ERROR, e.what());
    }
    return Status::OK();
}
// Append nb binary vectors (xb) with the given ids to an already-built index.
Status
BinVecImpl::Add(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg) {
    try {
        auto add_ds = std::make_shared<knowhere::Dataset>();
        add_ds->Set(knowhere::meta::ROWS, nb);
        add_ds->Set(knowhere::meta::DIM, dim);
        add_ds->Set(knowhere::meta::TENSOR, xb);
        add_ds->Set(knowhere::meta::IDS, ids);
        index_->Add(add_ds, cfg);
    } catch (knowhere::KnowhereException& e) {
        WRAPPER_LOG_ERROR << e.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
    } catch (std::exception& e) {
        WRAPPER_LOG_ERROR << e.what();
        return Status(KNOWHERE_ERROR, e.what());
    }
    return Status::OK();
}
VecIndexPtr
BinVecImpl::CopyToGpu(const int64_t& device_id, const Config& cfg) {
    // Binary indices have no GPU implementation; always fail loudly.
    // Bug fix: previously threw the literal string "errmsg" instead of the
    // actual message; also use const char* for the string literal.
    const char* errmsg = "Binary Index not support CopyToGpu";
    WRAPPER_LOG_ERROR << errmsg;
    throw WrapperException(errmsg);
}
VecIndexPtr
BinVecImpl::CopyToCpu(const Config& cfg) {
    // Binary indices are CPU-only; there is nothing to copy back from GPU.
    // Bug fix: previously threw the literal string "errmsg" instead of the
    // actual message; also use const char* for the string literal.
    const char* errmsg = "Binary Index not support CopyToCpu";
    WRAPPER_LOG_ERROR << errmsg;
    throw WrapperException(errmsg);
}
ErrorCode
BinBFIndex::Build(const Config& cfg) {
    // Trains the brute-force binary index; for IDMAP "training" just
    // initializes internal structures from the config.
    try {
        dim = cfg->d;
        auto bf_index = std::static_pointer_cast<knowhere::BinaryIDMAP>(index_);
        bf_index->Train(cfg);
    } catch (knowhere::KnowhereException& e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_UNEXPECTED_ERROR;
    } catch (std::exception& e) {
        WRAPPER_LOG_ERROR << e.what();
        return KNOWHERE_ERROR;
    }
    return KNOWHERE_SUCCESS;
}
Status
BinBFIndex::BuildAll(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
                     const uint8_t* xt) {
    // Builds the brute-force binary index in one shot: train, then add all
    // `nb` vectors. The training-set parameters (nt, xt) are unused by IDMAP.
    try {
        dim = cfg->d;

        // Assemble the base vectors and ids into a knowhere dataset.
        auto dataset = std::make_shared<knowhere::Dataset>();
        dataset->Set(knowhere::meta::ROWS, nb);
        dataset->Set(knowhere::meta::DIM, dim);
        dataset->Set(knowhere::meta::TENSOR, xb);
        dataset->Set(knowhere::meta::IDS, ids);

        auto bf_index = std::static_pointer_cast<knowhere::BinaryIDMAP>(index_);
        bf_index->Train(cfg);
        index_->Add(dataset, cfg);
    } catch (knowhere::KnowhereException& e) {
        WRAPPER_LOG_ERROR << e.what();
        return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
    } catch (std::exception& e) {
        WRAPPER_LOG_ERROR << e.what();
        return Status(KNOWHERE_ERROR, e.what());
    }
    return Status::OK();
}
const uint8_t*
BinBFIndex::GetRawVectors() {
    // dynamic_pointer_cast yields nullptr when index_ is not a BinaryIDMAP,
    // in which case no raw vector buffer is available.
    if (auto bf_index = std::dynamic_pointer_cast<knowhere::BinaryIDMAP>(index_)) {
        return bf_index->GetRawVectors();
    }
    return nullptr;
}
const int64_t*
BinBFIndex::GetRawIds() {
    // BinBFIndex only ever wraps a BinaryIDMAP, so a static cast is safe here.
    auto bf_index = std::static_pointer_cast<knowhere::BinaryIDMAP>(index_);
    return bf_index->GetRawIds();
}
} // namespace engine
} // namespace milvus

View File

@ -0,0 +1,71 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <utility>
#include "VecImpl.h"
namespace milvus {
namespace engine {
// Wrapper over a knowhere binary vector index (e.g. BinaryIVF). Overrides the
// uint8_t-based BuildAll/Search/Add entry points of VecIndexImpl; GPU copy
// operations are declared but (per the .cpp) binary indices do not support them.
class BinVecImpl : public VecIndexImpl {
public:
// Takes shared ownership of an already-constructed knowhere index; `type`
// identifies which IndexType this wrapper represents.
explicit BinVecImpl(std::shared_ptr<knowhere::VectorIndex> index, const IndexType& type)
: VecIndexImpl(std::move(index), type) {
}
// Train on (nt, xt) if applicable, then add all nb vectors with their ids.
Status
BuildAll(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
const uint8_t* xt) override;
// Top-k search over nq binary queries; writes nq*k entries into dist/ids.
Status
Search(const int64_t& nq, const uint8_t* xq, float* dist, int64_t* ids, const Config& cfg) override;
// Append nb binary vectors with caller-supplied ids.
Status
Add(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg) override;
// Not supported for binary indices (throws WrapperException).
VecIndexPtr
CopyToGpu(const int64_t& device_id, const Config& cfg) override;
// Not supported for binary indices (throws WrapperException).
VecIndexPtr
CopyToCpu(const Config& cfg) override;
};
// Brute-force ("BF") binary index backed by knowhere::BinaryIDMAP; exposes the
// raw stored vectors and ids so other components can read them directly.
class BinBFIndex : public BinVecImpl {
public:
// Always wraps a BinaryIDMAP and registers itself as FAISS_BIN_IDMAP.
explicit BinBFIndex(std::shared_ptr<knowhere::VectorIndex> index)
: BinVecImpl(std::move(index), IndexType::FAISS_BIN_IDMAP) {
}
// Train-only initialization from cfg (sets dim); note: not virtual.
ErrorCode
Build(const Config& cfg);
// One-shot build: train then add all nb vectors (nt/xt unused by IDMAP).
Status
BuildAll(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
const uint8_t* xt) override;
// Raw packed vector buffer, or nullptr if the index is not a BinaryIDMAP.
const uint8_t*
GetRawVectors();
// Raw id array of the stored vectors.
const int64_t*
GetRawIds();
};
} // namespace engine
} // namespace milvus

View File

@ -37,9 +37,9 @@ namespace engine {
#endif
void
ConfAdapter::MatchBase(knowhere::Config conf) {
ConfAdapter::MatchBase(knowhere::Config conf, knowhere::METRICTYPE default_metric) {
if (conf->metric_type == knowhere::DEFAULT_TYPE)
conf->metric_type = knowhere::METRICTYPE::L2;
conf->metric_type = default_metric;
}
knowhere::Config
@ -63,7 +63,7 @@ ConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
knowhere::Config
IVFConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::IVFCfg>();
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist);
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist, 16384);
conf->d = metaconf.dim;
conf->metric_type = metaconf.metric_type;
conf->gpu_id = metaconf.gpu_id;
@ -74,13 +74,13 @@ IVFConfAdapter::Match(const TempMetaConf& metaconf) {
static constexpr float TYPICAL_COUNT = 1000000.0;
int64_t
IVFConfAdapter::MatchNlist(const int64_t& size, const int64_t& nlist) {
if (size <= TYPICAL_COUNT / 16384 + 1) {
IVFConfAdapter::MatchNlist(const int64_t& size, const int64_t& nlist, const int64_t& per_nlist) {
if (size <= TYPICAL_COUNT / per_nlist + 1) {
// handle less row count, avoid nlist set to 0
return 1;
} else if (int(size / TYPICAL_COUNT) * nlist <= 0) {
// calculate a proper nlist if nlist not specified or size less than TYPICAL_COUNT
return int(size / TYPICAL_COUNT * 16384);
return int(size / TYPICAL_COUNT * per_nlist);
}
return nlist;
}
@ -112,7 +112,7 @@ IVFConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type)
knowhere::Config
IVFSQConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::IVFSQCfg>();
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist);
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist, 16384);
conf->d = metaconf.dim;
conf->metric_type = metaconf.metric_type;
conf->gpu_id = metaconf.gpu_id;
@ -207,7 +207,7 @@ IVFPQConfAdapter::MatchNlist(const int64_t& size, const int64_t& nlist) {
knowhere::Config
NSGConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::NSGCfg>();
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist);
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist, 16384);
conf->d = metaconf.dim;
conf->metric_type = metaconf.metric_type;
conf->gpu_id = metaconf.gpu_id;
@ -266,5 +266,26 @@ SPTAGBKTConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType&
return conf;
}
knowhere::Config
BinIDMAPConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::BinIDMAPCfg>();
conf->d = metaconf.dim;
conf->metric_type = metaconf.metric_type;
conf->gpu_id = metaconf.gpu_id;
conf->k = metaconf.k;
MatchBase(conf, knowhere::METRICTYPE::HAMMING);
return conf;
}
knowhere::Config
BinIVFConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::IVFBinCfg>();
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist, 2048);
conf->d = metaconf.dim;
conf->metric_type = metaconf.metric_type;
conf->gpu_id = metaconf.gpu_id;
MatchBase(conf, knowhere::METRICTYPE::HAMMING);
return conf;
}
} // namespace engine
} // namespace milvus

View File

@ -48,7 +48,7 @@ class ConfAdapter {
protected:
static void
MatchBase(knowhere::Config conf);
MatchBase(knowhere::Config conf, knowhere::METRICTYPE defalut_metric = knowhere::METRICTYPE::L2);
};
using ConfAdapterPtr = std::shared_ptr<ConfAdapter>;
@ -63,7 +63,7 @@ class IVFConfAdapter : public ConfAdapter {
protected:
static int64_t
MatchNlist(const int64_t& size, const int64_t& nlist);
MatchNlist(const int64_t& size, const int64_t& nlist, const int64_t& per_nlist);
};
class IVFSQConfAdapter : public IVFConfAdapter {
@ -112,5 +112,17 @@ class SPTAGBKTConfAdapter : public ConfAdapter {
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) override;
};
class BinIDMAPConfAdapter : public ConfAdapter {
public:
knowhere::Config
Match(const TempMetaConf& metaconf) override;
};
class BinIVFConfAdapter : public IVFConfAdapter {
public:
knowhere::Config
Match(const TempMetaConf& metaconf) override;
};
} // namespace engine
} // namespace milvus

View File

@ -41,10 +41,12 @@ AdapterMgr::RegisterAdapter() {
init_ = true;
REGISTER_CONF_ADAPTER(ConfAdapter, IndexType::FAISS_IDMAP, idmap);
REGISTER_CONF_ADAPTER(BinIDMAPConfAdapter, IndexType::FAISS_BIN_IDMAP, idmap_bin);
REGISTER_CONF_ADAPTER(IVFConfAdapter, IndexType::FAISS_IVFFLAT_CPU, ivf_cpu);
REGISTER_CONF_ADAPTER(IVFConfAdapter, IndexType::FAISS_IVFFLAT_GPU, ivf_gpu);
REGISTER_CONF_ADAPTER(IVFConfAdapter, IndexType::FAISS_IVFFLAT_MIX, ivf_mix);
REGISTER_CONF_ADAPTER(BinIVFConfAdapter, IndexType::FAISS_BIN_IVFLAT_CPU, ivf_bin_cpu);
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexType::FAISS_IVFSQ8_CPU, ivfsq8_cpu);
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexType::FAISS_IVFSQ8_GPU, ivfsq8_gpu);

View File

@ -148,7 +148,7 @@ VecIndexImpl::Count() {
}
IndexType
VecIndexImpl::GetType() {
VecIndexImpl::GetType() const {
return type;
}

View File

@ -43,7 +43,7 @@ class VecIndexImpl : public VecIndex {
CopyToCpu(const Config& cfg) override;
IndexType
GetType() override;
GetType() const override;
int64_t
Dimension() override;

View File

@ -18,6 +18,8 @@
#include "wrapper/VecIndex.h"
#include "VecImpl.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"
#include "knowhere/index/vector_index/IndexBinaryIVF.h"
#include "knowhere/index/vector_index/IndexIDMAP.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
@ -32,6 +34,7 @@
#include "utils/Exception.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
#include "wrapper/BinVecImpl.h"
#ifdef MILVUS_GPU_VERSION
#include <cuda.h>
@ -68,10 +71,18 @@ GetVecIndexFactory(const IndexType& type, const Config& cfg) {
index = std::make_shared<knowhere::IDMAP>();
return std::make_shared<BFIndex>(index);
}
case IndexType::FAISS_BIN_IDMAP: {
index = std::make_shared<knowhere::BinaryIDMAP>();
return std::make_shared<BinBFIndex>(index);
}
case IndexType::FAISS_IVFFLAT_CPU: {
index = std::make_shared<knowhere::IVF>();
break;
}
case IndexType::FAISS_BIN_IVFLAT_CPU: {
index = std::make_shared<knowhere::BinaryIVF>();
return std::make_shared<BinVecImpl>(index, type);
}
case IndexType::FAISS_IVFPQ_CPU: {
index = std::make_shared<knowhere::IVFPQ>();
break;

View File

@ -50,6 +50,8 @@ enum class IndexType {
NSG_MIX,
FAISS_IVFPQ_MIX,
SPTAG_BKT_RNT_CPU,
FAISS_BIN_IDMAP = 100,
FAISS_BIN_IVFLAT_CPU = 101,
};
class VecIndex;
@ -62,12 +64,31 @@ class VecIndex : public cache::DataObj {
BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt = 0,
const float* xt = nullptr) = 0;
virtual Status
BuildAll(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg, const int64_t& nt = 0,
const uint8_t* xt = nullptr) {
ENGINE_LOG_ERROR << "BuildAll with uint8_t not support";
return Status::OK();
}
virtual Status
Add(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg = Config()) = 0;
virtual Status
Add(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg = Config()) {
ENGINE_LOG_ERROR << "Add with uint8_t not support";
return Status::OK();
}
virtual Status
Search(const int64_t& nq, const float* xq, float* dist, int64_t* ids, const Config& cfg = Config()) = 0;
virtual Status
Search(const int64_t& nq, const uint8_t* xq, float* dist, int64_t* ids, const Config& cfg = Config()) {
ENGINE_LOG_ERROR << "Search with uint8_t not support";
return Status::OK();
}
virtual VecIndexPtr
CopyToGpu(const int64_t& device_id, const Config& cfg = Config()) = 0;
@ -82,7 +103,7 @@ class VecIndex : public cache::DataObj {
GetDeviceId() = 0;
virtual IndexType
GetType() = 0;
GetType() const = 0;
virtual int64_t
Dimension() = 0;

View File

@ -50,12 +50,13 @@ BuildTableSchema() {
}
void
BuildVectors(int64_t n, std::vector<float>& vectors) {
vectors.clear();
vectors.resize(n * TABLE_DIM);
float* data = vectors.data();
for (int i = 0; i < n; i++) {
for (int j = 0; j < TABLE_DIM; j++) data[TABLE_DIM * i + j] = drand48();
BuildVectors(uint64_t n, milvus::engine::VectorsData& vectors) {
vectors.vector_count_ = n;
vectors.float_data_.clear();
vectors.float_data_.resize(n * TABLE_DIM);
float* data = vectors.float_data_.data();
for (uint64_t i = 0; i < n; i++) {
for (int64_t j = 0; j < TABLE_DIM; j++) data[TABLE_DIM * i + j] = drand48();
data[TABLE_DIM * i] += i / 2000.;
}
}
@ -161,15 +162,12 @@ TEST_F(DBTest, DB_TEST) {
ASSERT_TRUE(stat.ok());
ASSERT_EQ(table_info_get.dimension_, TABLE_DIM);
milvus::engine::IDNumbers vector_ids;
milvus::engine::IDNumbers target_ids;
int64_t nb = 50;
std::vector<float> xb;
uint64_t nb = 50;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);
int64_t qb = 5;
std::vector<float> qxb;
uint64_t qb = 5;
milvus::engine::VectorsData qxb;
BuildVectors(qb, qxb);
std::thread search([&]() {
@ -191,13 +189,12 @@ TEST_F(DBTest, DB_TEST) {
START_TIMER;
std::vector<std::string> tags;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, qb, 10, qxb.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, qxb, result_ids, result_distances);
ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
STOP_TIMER(ss.str());
ASSERT_TRUE(stat.ok());
for (auto i = 0; i < qb; ++i) {
ASSERT_EQ(result_ids[i * k], target_ids[i]);
ss.str("");
ss << "Result [" << i << "]:";
for (auto t = 0; t < k; t++) {
@ -214,10 +211,13 @@ TEST_F(DBTest, DB_TEST) {
for (auto i = 0; i < loop; ++i) {
if (i == 40) {
db_->InsertVectors(TABLE_NAME, "", qb, qxb.data(), target_ids);
ASSERT_EQ(target_ids.size(), qb);
qxb.id_array_.clear();
db_->InsertVectors(TABLE_NAME, "", qxb);
ASSERT_EQ(qxb.id_array_.size(), qb);
} else {
db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
xb.id_array_.clear();
db_->InsertVectors(TABLE_NAME, "", xb);
ASSERT_EQ(xb.id_array_.size(), nb);
}
std::this_thread::sleep_for(std::chrono::microseconds(1));
}
@ -250,21 +250,24 @@ TEST_F(DBTest, SEARCH_TEST) {
size_t nb = VECTOR_COUNT;
size_t nq = 10;
size_t k = 5;
std::vector<float> xb(nb * TABLE_DIM);
std::vector<float> xq(nq * TABLE_DIM);
std::vector<int64_t> ids(nb);
milvus::engine::VectorsData xb, xq;
xb.vector_count_ = nb;
xb.float_data_.resize(nb * TABLE_DIM);
xq.vector_count_ = nq;
xq.float_data_.resize(nq * TABLE_DIM);
xb.id_array_.resize(nb);
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<> dis_xt(-1.0, 1.0);
for (size_t i = 0; i < nb * TABLE_DIM; i++) {
xb[i] = dis_xt(gen);
xb.float_data_[i] = dis_xt(gen);
if (i < nb) {
ids[i] = i;
xb.id_array_[i] = i;
}
}
for (size_t i = 0; i < nq * TABLE_DIM; i++) {
xq[i] = dis_xt(gen);
xq.float_data_[i] = dis_xt(gen);
}
// result data
@ -274,14 +277,8 @@ TEST_F(DBTest, SEARCH_TEST) {
std::vector<float> dis(k * nq);
// insert data
const int batch_size = 100;
for (int j = 0; j < nb / batch_size; ++j) {
stat = db_->InsertVectors(TABLE_NAME, "", batch_size, xb.data() + batch_size * j * TABLE_DIM, ids);
if (j == 200) {
sleep(1);
}
ASSERT_TRUE(stat.ok());
}
stat = db_->InsertVectors(TABLE_NAME, "", xb);
ASSERT_TRUE(stat.ok());
milvus::engine::TableIndex index;
index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IDMAP;
@ -291,9 +288,7 @@ TEST_F(DBTest, SEARCH_TEST) {
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, nq, 10, xq.data(), result_ids, result_distances);
ASSERT_TRUE(stat.ok());
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 1100, 10, xq.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}
@ -304,9 +299,7 @@ TEST_F(DBTest, SEARCH_TEST) {
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, nq, 10, xq.data(), result_ids, result_distances);
ASSERT_TRUE(stat.ok());
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 1100, 10, xq.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}
@ -317,9 +310,7 @@ TEST_F(DBTest, SEARCH_TEST) {
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, nq, 10, xq.data(), result_ids, result_distances);
ASSERT_TRUE(stat.ok());
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 1100, 10, xq.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}
@ -331,9 +322,7 @@ TEST_F(DBTest, SEARCH_TEST) {
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, nq, 10, xq.data(), result_ids, result_distances);
ASSERT_TRUE(stat.ok());
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 1100, 10, xq.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}
#endif
@ -349,7 +338,7 @@ TEST_F(DBTest, SEARCH_TEST) {
}
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, result_ids,
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, 10, xq, dates, result_ids,
result_distances);
ASSERT_TRUE(stat.ok());
}
@ -361,9 +350,7 @@ TEST_F(DBTest, SEARCH_TEST) {
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, nq, 10, xq.data(), result_ids, result_distances);
ASSERT_TRUE(stat.ok());
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 1100, 10, xq.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}
@ -378,14 +365,7 @@ TEST_F(DBTest, SEARCH_TEST) {
{
result_ids.clear();
result_dists.clear();
stat = db_->Query(dummy_context_, TABLE_NAME, partition_tag, k, nq, 10, xq.data(), result_ids, result_dists);
ASSERT_TRUE(stat.ok());
}
{
result_ids.clear();
result_dists.clear();
stat = db_->Query(dummy_context_, TABLE_NAME, partition_tag, k, 200, 10, xq.data(), result_ids, result_dists);
stat = db_->Query(dummy_context_, TABLE_NAME, partition_tag, k, 10, xq, result_ids, result_dists);
ASSERT_TRUE(stat.ok());
}
@ -400,7 +380,7 @@ TEST_F(DBTest, SEARCH_TEST) {
}
result_ids.clear();
result_dists.clear();
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, result_ids,
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, 10, xq, dates, result_ids,
result_dists);
ASSERT_TRUE(stat.ok());
}
@ -418,15 +398,15 @@ TEST_F(DBTest, PRELOADTABLE_TEST) {
ASSERT_TRUE(stat.ok());
ASSERT_EQ(table_info_get.dimension_, TABLE_DIM);
int64_t nb = VECTOR_COUNT;
std::vector<float> xb;
uint64_t nb = VECTOR_COUNT;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);
int loop = 5;
for (auto i = 0; i < loop; ++i) {
milvus::engine::IDNumbers vector_ids;
db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
ASSERT_EQ(vector_ids.size(), nb);
xb.id_array_.clear();
db_->InsertVectors(TABLE_NAME, "", xb);
ASSERT_EQ(xb.id_array_.size(), nb);
}
milvus::engine::TableIndex index;
@ -454,8 +434,8 @@ TEST_F(DBTest, SHUTDOWN_TEST) {
stat = db_->HasTable(table_info.table_id_, has_table);
ASSERT_FALSE(stat.ok());
milvus::engine::IDNumbers ids;
stat = db_->InsertVectors(table_info.table_id_, "", 0, nullptr, ids);
milvus::engine::VectorsData xb;
stat = db_->InsertVectors(table_info.table_id_, "", xb);
ASSERT_FALSE(stat.ok());
stat = db_->PreloadTable(table_info.table_id_);
@ -477,10 +457,10 @@ TEST_F(DBTest, SHUTDOWN_TEST) {
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat =
db_->Query(dummy_context_, table_info.table_id_, tags, 1, 1, 1, nullptr, dates, result_ids, result_distances);
db_->Query(dummy_context_, table_info.table_id_, tags, 1, 1, xb, dates, result_ids, result_distances);
ASSERT_FALSE(stat.ok());
std::vector<std::string> file_ids;
stat = db_->QueryByFileID(dummy_context_, table_info.table_id_, file_ids, 1, 1, 1, nullptr, dates, result_ids,
stat = db_->QueryByFileID(dummy_context_, table_info.table_id_, file_ids, 1, 1, xb, dates, result_ids,
result_distances);
ASSERT_FALSE(stat.ok());
@ -492,13 +472,12 @@ TEST_F(DBTest, INDEX_TEST) {
milvus::engine::meta::TableSchema table_info = BuildTableSchema();
auto stat = db_->CreateTable(table_info);
int64_t nb = VECTOR_COUNT;
std::vector<float> xb;
uint64_t nb = VECTOR_COUNT;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);
milvus::engine::IDNumbers vector_ids;
db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
ASSERT_EQ(vector_ids.size(), nb);
db_->InsertVectors(TABLE_NAME, "", xb);
ASSERT_EQ(xb.id_array_.size(), nb);
milvus::engine::TableIndex index;
index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IVFSQ8;
@ -552,7 +531,7 @@ TEST_F(DBTest, PARTITION_TEST) {
stat = db_->CreatePartition(table_name, partition_name, partition_tag);
ASSERT_FALSE(stat.ok());
std::vector<float> xb;
milvus::engine::VectorsData xb;
BuildVectors(INSERT_BATCH, xb);
milvus::engine::IDNumbers vector_ids;
@ -561,7 +540,7 @@ TEST_F(DBTest, PARTITION_TEST) {
vector_ids[k] = i * INSERT_BATCH + k;
}
db_->InsertVectors(table_name, partition_tag, INSERT_BATCH, xb.data(), vector_ids);
db_->InsertVectors(table_name, partition_tag, xb);
ASSERT_EQ(vector_ids.size(), INSERT_BATCH);
}
@ -594,14 +573,14 @@ TEST_F(DBTest, PARTITION_TEST) {
const int64_t nq = 5;
const int64_t topk = 10;
const int64_t nprobe = 10;
std::vector<float> xq;
milvus::engine::VectorsData xq;
BuildVectors(nq, xq);
// specify partition tags
std::vector<std::string> tags = {"0", std::to_string(PARTITION_COUNT - 1)};
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nq, nprobe, xq.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nprobe, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, nq);
@ -609,7 +588,7 @@ TEST_F(DBTest, PARTITION_TEST) {
tags.clear();
result_ids.clear();
result_distances.clear();
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nq, nprobe, xq.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nprobe, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, nq);
@ -617,7 +596,7 @@ TEST_F(DBTest, PARTITION_TEST) {
tags.push_back("\\d");
result_ids.clear();
result_distances.clear();
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nq, nprobe, xq.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nprobe, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, nq);
}
@ -661,14 +640,14 @@ TEST_F(DBTest2, ARHIVE_DISK_CHECK) {
uint64_t size;
db_->Size(size);
int64_t nb = 10;
std::vector<float> xb;
uint64_t nb = 10;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);
int loop = INSERT_LOOP;
for (auto i = 0; i < loop; ++i) {
milvus::engine::IDNumbers vector_ids;
db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
db_->InsertVectors(TABLE_NAME, "", xb);
std::this_thread::sleep_for(std::chrono::microseconds(1));
}
@ -695,12 +674,12 @@ TEST_F(DBTest2, DELETE_TEST) {
uint64_t size;
db_->Size(size);
int64_t nb = VECTOR_COUNT;
std::vector<float> xb;
uint64_t nb = VECTOR_COUNT;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);
milvus::engine::IDNumbers vector_ids;
stat = db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
stat = db_->InsertVectors(TABLE_NAME, "", xb);
milvus::engine::TableIndex index;
stat = db_->CreateIndex(TABLE_NAME, index);
@ -730,12 +709,12 @@ TEST_F(DBTest2, DELETE_BY_RANGE_TEST) {
db_->Size(size);
ASSERT_EQ(size, 0UL);
int64_t nb = VECTOR_COUNT;
std::vector<float> xb;
uint64_t nb = VECTOR_COUNT;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);
milvus::engine::IDNumbers vector_ids;
stat = db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
stat = db_->InsertVectors(TABLE_NAME, "", xb);
milvus::engine::TableIndex index;
stat = db_->CreateIndex(TABLE_NAME, index);

View File

@ -44,12 +44,13 @@ BuildTableSchema() {
}
void
BuildVectors(int64_t n, std::vector<float>& vectors) {
vectors.clear();
vectors.resize(n * TABLE_DIM);
float* data = vectors.data();
for (int i = 0; i < n; i++) {
for (int j = 0; j < TABLE_DIM; j++) data[TABLE_DIM * i + j] = drand48();
BuildVectors(uint64_t n, milvus::engine::VectorsData& vectors) {
vectors.vector_count_ = n;
vectors.float_data_.clear();
vectors.float_data_.resize(n * TABLE_DIM);
float* data = vectors.float_data_.data();
for (uint64_t i = 0; i < n; i++) {
for (int64_t j = 0; j < TABLE_DIM; j++) data[TABLE_DIM * i + j] = drand48();
data[TABLE_DIM * i] += i / 2000.;
}
}
@ -66,19 +67,16 @@ TEST_F(MySqlDBTest, DB_TEST) {
ASSERT_TRUE(stat.ok());
ASSERT_EQ(table_info_get.dimension_, TABLE_DIM);
milvus::engine::IDNumbers vector_ids;
milvus::engine::IDNumbers target_ids;
int64_t nb = 50;
std::vector<float> xb;
uint64_t nb = 50;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);
int64_t qb = 5;
std::vector<float> qxb;
uint64_t qb = 5;
milvus::engine::VectorsData qxb;
BuildVectors(qb, qxb);
db_->InsertVectors(TABLE_NAME, "", qb, qxb.data(), target_ids);
ASSERT_EQ(target_ids.size(), qb);
db_->InsertVectors(TABLE_NAME, "", qxb);
ASSERT_EQ(qxb.id_array_.size(), qb);
std::thread search([&]() {
milvus::engine::ResultIds result_ids;
@ -98,7 +96,7 @@ TEST_F(MySqlDBTest, DB_TEST) {
START_TIMER;
std::vector<std::string> tags;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, qb, 10, qxb.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, qxb, result_ids, result_distances);
ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
STOP_TIMER(ss.str());
@ -106,13 +104,6 @@ TEST_F(MySqlDBTest, DB_TEST) {
for (auto i = 0; i < qb; ++i) {
// std::cout << results[k][0].first << " " << target_ids[k] << std::endl;
// ASSERT_EQ(results[k][0].first, target_ids[k]);
bool exists = false;
for (auto t = 0; t < k; t++) {
if (result_ids[i * k + t] == target_ids[i]) {
exists = true;
}
}
ASSERT_TRUE(exists);
ss.str("");
ss << "Result [" << i << "]:";
for (auto t = 0; t < k; t++) {
@ -136,7 +127,8 @@ TEST_F(MySqlDBTest, DB_TEST) {
// } else {
// db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
// }
db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
xb.id_array_.clear();
db_->InsertVectors(TABLE_NAME, "", xb);
std::this_thread::sleep_for(std::chrono::microseconds(1));
}
@ -157,21 +149,24 @@ TEST_F(MySqlDBTest, SEARCH_TEST) {
size_t nb = VECTOR_COUNT;
size_t nq = 10;
size_t k = 5;
std::vector<float> xb(nb * TABLE_DIM);
std::vector<float> xq(nq * TABLE_DIM);
std::vector<int64_t> ids(nb);
milvus::engine::VectorsData xb, xq;
xb.vector_count_ = nb;
xb.float_data_.resize(nb * TABLE_DIM);
xq.vector_count_ = nq;
xq.float_data_.resize(nq * TABLE_DIM);
xb.id_array_.resize(nb);
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<> dis_xt(-1.0, 1.0);
for (size_t i = 0; i < nb * TABLE_DIM; i++) {
xb[i] = dis_xt(gen);
xb.float_data_[i] = dis_xt(gen);
if (i < nb) {
ids[i] = i;
xb.id_array_[i] = i;
}
}
for (size_t i = 0; i < nq * TABLE_DIM; i++) {
xq[i] = dis_xt(gen);
xq.float_data_[i] = dis_xt(gen);
}
// result data
@ -181,21 +176,15 @@ TEST_F(MySqlDBTest, SEARCH_TEST) {
std::vector<float> dis(k * nq);
// insert data
const int batch_size = 100;
for (int j = 0; j < nb / batch_size; ++j) {
stat = db_->InsertVectors(TABLE_NAME, "", batch_size, xb.data() + batch_size * j * TABLE_DIM, ids);
if (j == 200) {
sleep(1);
}
ASSERT_TRUE(stat.ok());
}
stat = db_->InsertVectors(TABLE_NAME, "", xb);
ASSERT_TRUE(stat.ok());
sleep(2); // wait until build index finish
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, nq, 10, xq.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}
@ -228,12 +217,12 @@ TEST_F(MySqlDBTest, ARHIVE_DISK_CHECK) {
db_->Size(size);
int64_t nb = 10;
std::vector<float> xb;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);
int loop = INSERT_LOOP;
for (auto i = 0; i < loop; ++i) {
db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
db_->InsertVectors(TABLE_NAME, "", xb);
std::this_thread::sleep_for(std::chrono::microseconds(1));
}
@ -264,12 +253,12 @@ TEST_F(MySqlDBTest, DELETE_TEST) {
db_->Size(size);
int64_t nb = INSERT_LOOP;
std::vector<float> xb;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);
int loop = 20;
for (auto i = 0; i < loop; ++i) {
db_->InsertVectors(TABLE_NAME, "", nb, xb.data(), vector_ids);
db_->InsertVectors(TABLE_NAME, "", xb);
std::this_thread::sleep_for(std::chrono::microseconds(1));
}
@ -307,7 +296,7 @@ TEST_F(MySqlDBTest, PARTITION_TEST) {
stat = db_->CreatePartition(table_name, partition_name, partition_tag);
ASSERT_FALSE(stat.ok());
std::vector<float> xb;
milvus::engine::VectorsData xb;
BuildVectors(INSERT_BATCH, xb);
milvus::engine::IDNumbers vector_ids;
@ -316,7 +305,7 @@ TEST_F(MySqlDBTest, PARTITION_TEST) {
vector_ids[k] = i * INSERT_BATCH + k;
}
db_->InsertVectors(table_name, partition_tag, INSERT_BATCH, xb.data(), vector_ids);
db_->InsertVectors(table_name, partition_tag, xb);
ASSERT_EQ(vector_ids.size(), INSERT_BATCH);
}
@ -349,14 +338,14 @@ TEST_F(MySqlDBTest, PARTITION_TEST) {
const int64_t nq = 5;
const int64_t topk = 10;
const int64_t nprobe = 10;
std::vector<float> xq;
milvus::engine::VectorsData xq;
BuildVectors(nq, xq);
// specify partition tags
std::vector<std::string> tags = {"0", std::to_string(PARTITION_COUNT - 1)};
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, nq, 10, xq.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, 10, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, nq);
@ -364,7 +353,7 @@ TEST_F(MySqlDBTest, PARTITION_TEST) {
tags.clear();
result_ids.clear();
result_distances.clear();
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, nq, 10, xq.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, 10, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, nq);
@ -372,7 +361,7 @@ TEST_F(MySqlDBTest, PARTITION_TEST) {
tags.push_back("\\d");
result_ids.clear();
result_distances.clear();
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, nq, 10, xq.data(), result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, 10, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, nq);
}

View File

@ -55,10 +55,11 @@ BuildTableSchema() {
}
void
BuildVectors(int64_t n, std::vector<float>& vectors) {
vectors.clear();
vectors.resize(n * TABLE_DIM);
float* data = vectors.data();
BuildVectors(uint64_t n, milvus::engine::VectorsData& vectors) {
vectors.vector_count_ = n;
vectors.float_data_.clear();
vectors.float_data_.resize(n * TABLE_DIM);
float* data = vectors.float_data_.data();
for (int i = 0; i < n; i++) {
for (int j = 0; j < TABLE_DIM; j++) data[TABLE_DIM * i + j] = drand48();
}
@ -76,10 +77,10 @@ TEST_F(MemManagerTest, VECTOR_SOURCE_TEST) {
ASSERT_TRUE(status.ok());
int64_t n = 100;
std::vector<float> vectors;
milvus::engine::VectorsData vectors;
BuildVectors(n, vectors);
milvus::engine::VectorSource source(n, vectors.data());
milvus::engine::VectorSource source(vectors);
size_t num_vectors_added;
milvus::engine::ExecutionEnginePtr execution_engine_ = milvus::engine::EngineFactory::Build(
@ -87,21 +88,16 @@ TEST_F(MemManagerTest, VECTOR_SOURCE_TEST) {
(milvus::engine::EngineType)table_file_schema.engine_type_,
(milvus::engine::MetricType)table_file_schema.metric_type_, table_schema.nlist_);
milvus::engine::IDNumbers vector_ids;
status = source.Add(execution_engine_, table_file_schema, 50, num_vectors_added, vector_ids);
status = source.Add(execution_engine_, table_file_schema, 50, num_vectors_added);
ASSERT_TRUE(status.ok());
vector_ids = source.GetVectorIds();
ASSERT_EQ(vector_ids.size(), 50);
ASSERT_EQ(num_vectors_added, 50);
ASSERT_EQ(source.GetVectorIds().size(), 50);
vector_ids.clear();
status = source.Add(execution_engine_, table_file_schema, 60, num_vectors_added, vector_ids);
vectors.id_array_.clear();
status = source.Add(execution_engine_, table_file_schema, 60, num_vectors_added);
ASSERT_TRUE(status.ok());
ASSERT_EQ(num_vectors_added, 50);
vector_ids = source.GetVectorIds();
ASSERT_EQ(vector_ids.size(), 100);
ASSERT_EQ(source.GetVectorIds().size(), 100);
}
TEST_F(MemManagerTest, MEM_TABLE_FILE_TEST) {
@ -114,34 +110,30 @@ TEST_F(MemManagerTest, MEM_TABLE_FILE_TEST) {
milvus::engine::MemTableFile mem_table_file(GetTableName(), impl_, options);
int64_t n_100 = 100;
std::vector<float> vectors_100;
milvus::engine::VectorsData vectors_100;
BuildVectors(n_100, vectors_100);
milvus::engine::VectorSourcePtr source = std::make_shared<milvus::engine::VectorSource>(n_100, vectors_100.data());
milvus::engine::VectorSourcePtr source = std::make_shared<milvus::engine::VectorSource>(vectors_100);
milvus::engine::IDNumbers vector_ids;
status = mem_table_file.Add(source, vector_ids);
status = mem_table_file.Add(source);
ASSERT_TRUE(status.ok());
// std::cout << mem_table_file.GetCurrentMem() << " " << mem_table_file.GetMemLeft() << std::endl;
vector_ids = source->GetVectorIds();
ASSERT_EQ(vector_ids.size(), 100);
size_t singleVectorMem = sizeof(float) * TABLE_DIM;
ASSERT_EQ(mem_table_file.GetCurrentMem(), n_100 * singleVectorMem);
int64_t n_max = milvus::engine::MAX_TABLE_FILE_MEM / singleVectorMem;
std::vector<float> vectors_128M;
milvus::engine::VectorsData vectors_128M;
BuildVectors(n_max, vectors_128M);
milvus::engine::VectorSourcePtr source_128M =
std::make_shared<milvus::engine::VectorSource>(n_max, vectors_128M.data());
std::make_shared<milvus::engine::VectorSource>(vectors_128M);
vector_ids.clear();
status = mem_table_file.Add(source_128M, vector_ids);
status = mem_table_file.Add(source_128M);
vector_ids = source_128M->GetVectorIds();
ASSERT_EQ(vector_ids.size(), n_max - n_100);
ASSERT_EQ(source_128M->GetVectorIds().size(), n_max - n_100);
ASSERT_TRUE(mem_table_file.IsFull());
}
@ -154,19 +146,18 @@ TEST_F(MemManagerTest, MEM_TABLE_TEST) {
ASSERT_TRUE(status.ok());
int64_t n_100 = 100;
std::vector<float> vectors_100;
milvus::engine::VectorsData vectors_100;
BuildVectors(n_100, vectors_100);
milvus::engine::VectorSourcePtr source_100 =
std::make_shared<milvus::engine::VectorSource>(n_100, vectors_100.data());
std::make_shared<milvus::engine::VectorSource>(vectors_100);
milvus::engine::MemTable mem_table(GetTableName(), impl_, options);
milvus::engine::IDNumbers vector_ids;
status = mem_table.Add(source_100, vector_ids);
status = mem_table.Add(source_100);
ASSERT_TRUE(status.ok());
vector_ids = source_100->GetVectorIds();
ASSERT_EQ(vector_ids.size(), 100);
ASSERT_EQ(source_100->GetVectorIds().size(), 100);
milvus::engine::MemTableFilePtr mem_table_file;
mem_table.GetCurrentMemTableFile(mem_table_file);
@ -174,17 +165,16 @@ TEST_F(MemManagerTest, MEM_TABLE_TEST) {
ASSERT_EQ(mem_table_file->GetCurrentMem(), n_100 * singleVectorMem);
int64_t n_max = milvus::engine::MAX_TABLE_FILE_MEM / singleVectorMem;
std::vector<float> vectors_128M;
milvus::engine::VectorsData vectors_128M;
BuildVectors(n_max, vectors_128M);
vector_ids.clear();
milvus::engine::VectorSourcePtr source_128M =
std::make_shared<milvus::engine::VectorSource>(n_max, vectors_128M.data());
status = mem_table.Add(source_128M, vector_ids);
std::make_shared<milvus::engine::VectorSource>(vectors_128M);
status = mem_table.Add(source_128M);
ASSERT_TRUE(status.ok());
vector_ids = source_128M->GetVectorIds();
ASSERT_EQ(vector_ids.size(), n_max);
ASSERT_EQ(source_128M->GetVectorIds().size(), n_max);
mem_table.GetCurrentMemTableFile(mem_table_file);
ASSERT_EQ(mem_table_file->GetCurrentMem(), n_100 * singleVectorMem);
@ -192,17 +182,16 @@ TEST_F(MemManagerTest, MEM_TABLE_TEST) {
ASSERT_EQ(mem_table.GetTableFileCount(), 2);
int64_t n_1G = 1024000;
std::vector<float> vectors_1G;
milvus::engine::VectorsData vectors_1G;
BuildVectors(n_1G, vectors_1G);
milvus::engine::VectorSourcePtr source_1G = std::make_shared<milvus::engine::VectorSource>(n_1G, vectors_1G.data());
milvus::engine::VectorSourcePtr source_1G = std::make_shared<milvus::engine::VectorSource>(vectors_1G);
vector_ids.clear();
status = mem_table.Add(source_1G, vector_ids);
status = mem_table.Add(source_1G);
ASSERT_TRUE(status.ok());
vector_ids = source_1G->GetVectorIds();
ASSERT_EQ(vector_ids.size(), n_1G);
ASSERT_EQ(source_1G->GetVectorIds().size(), n_1G);
int expectedTableFileCount = 2 + std::ceil((n_1G - n_100) * singleVectorMem / milvus::engine::MAX_TABLE_FILE_MEM);
ASSERT_EQ(mem_table.GetTableFileCount(), expectedTableFileCount);
@ -222,15 +211,14 @@ TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) {
ASSERT_EQ(table_info_get.dimension_, TABLE_DIM);
int64_t nb = 100000;
std::vector<float> xb;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);
milvus::engine::IDNumbers vector_ids;
for (int64_t i = 0; i < nb; i++) {
vector_ids.push_back(i);
xb.id_array_.push_back(i);
}
stat = db_->InsertVectors(GetTableName(), "", nb, xb.data(), vector_ids);
stat = db_->InsertVectors(GetTableName(), "", xb);
ASSERT_TRUE(stat.ok());
std::this_thread::sleep_for(std::chrono::seconds(3)); // ensure raw data write to disk
@ -240,14 +228,15 @@ TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) {
std::uniform_int_distribution<int64_t> dis(0, nb - 1);
int64_t num_query = 10;
std::map<int64_t, std::vector<float>> search_vectors;
std::map<int64_t, milvus::engine::VectorsData> search_vectors;
for (int64_t i = 0; i < num_query; ++i) {
int64_t index = dis(gen);
std::vector<float> search;
milvus::engine::VectorsData search;
search.vector_count_ = 1;
for (int64_t j = 0; j < TABLE_DIM; j++) {
search.push_back(xb[index * TABLE_DIM + j]);
search.float_data_.push_back(xb.float_data_[index * TABLE_DIM + j]);
}
search_vectors.insert(std::make_pair(vector_ids[index], search));
search_vectors.insert(std::make_pair(xb.id_array_[index], search));
}
int topk = 10, nprobe = 10;
@ -257,7 +246,7 @@ TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) {
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, 1, nprobe, search.data(), result_ids,
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, nprobe, search, result_ids,
result_distances);
ASSERT_EQ(result_ids[0], pair.first);
ASSERT_LT(result_distances[0], 1e-4);
@ -279,10 +268,10 @@ TEST_F(MemManagerTest2, INSERT_TEST) {
int insert_loop = 20;
for (int i = 0; i < insert_loop; ++i) {
int64_t nb = 40960;
std::vector<float> xb;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);
milvus::engine::IDNumbers vector_ids;
stat = db_->InsertVectors(GetTableName(), "", nb, xb.data(), vector_ids);
stat = db_->InsertVectors(GetTableName(), "", xb);
ASSERT_TRUE(stat.ok());
}
auto end_time = METRICS_NOW_TIME;
@ -300,15 +289,12 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) {
ASSERT_TRUE(stat.ok());
ASSERT_EQ(table_info_get.dimension_, TABLE_DIM);
milvus::engine::IDNumbers vector_ids;
milvus::engine::IDNumbers target_ids;
int64_t nb = 40960;
std::vector<float> xb;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);
int64_t qb = 5;
std::vector<float> qxb;
milvus::engine::VectorsData qxb;
BuildVectors(qb, qxb);
std::thread search([&]() {
@ -331,13 +317,12 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) {
std::vector<std::string> tags;
stat =
db_->Query(dummy_context_, GetTableName(), tags, k, qb, 10, qxb.data(), result_ids, result_distances);
db_->Query(dummy_context_, GetTableName(), tags, k, 10, qxb, result_ids, result_distances);
ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
STOP_TIMER(ss.str());
ASSERT_TRUE(stat.ok());
for (auto i = 0; i < qb; ++i) {
ASSERT_EQ(result_ids[i * k], target_ids[i]);
ss.str("");
ss << "Result [" << i << "]:";
for (auto t = 0; t < k; t++) {
@ -354,10 +339,13 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) {
for (auto i = 0; i < loop; ++i) {
if (i == 0) {
db_->InsertVectors(GetTableName(), "", qb, qxb.data(), target_ids);
ASSERT_EQ(target_ids.size(), qb);
qxb.id_array_.clear();
db_->InsertVectors(GetTableName(), "", qxb);
ASSERT_EQ(qxb.id_array_.size(), qb);
} else {
db_->InsertVectors(GetTableName(), "", nb, xb.data(), vector_ids);
xb.id_array_.clear();
db_->InsertVectors(GetTableName(), "", xb);
ASSERT_EQ(xb.id_array_.size(), nb);
}
std::this_thread::sleep_for(std::chrono::microseconds(1));
}
@ -375,62 +363,56 @@ TEST_F(MemManagerTest2, VECTOR_IDS_TEST) {
ASSERT_TRUE(stat.ok());
ASSERT_EQ(table_info_get.dimension_, TABLE_DIM);
milvus::engine::IDNumbers vector_ids;
int64_t nb = 100000;
std::vector<float> xb;
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);
vector_ids.resize(nb);
xb.id_array_.resize(nb);
for (auto i = 0; i < nb; i++) {
vector_ids[i] = i;
xb.id_array_[i] = i;
}
stat = db_->InsertVectors(GetTableName(), "", nb, xb.data(), vector_ids);
ASSERT_EQ(vector_ids[0], 0);
stat = db_->InsertVectors(GetTableName(), "", xb);
ASSERT_EQ(xb.id_array_[0], 0);
ASSERT_TRUE(stat.ok());
nb = 25000;
xb.clear();
BuildVectors(nb, xb);
vector_ids.clear();
vector_ids.resize(nb);
xb.id_array_.resize(nb);
for (auto i = 0; i < nb; i++) {
vector_ids[i] = i + nb;
xb.id_array_[i] = i + nb;
}
stat = db_->InsertVectors(GetTableName(), "", nb, xb.data(), vector_ids);
ASSERT_EQ(vector_ids[0], nb);
stat = db_->InsertVectors(GetTableName(), "", xb);
ASSERT_EQ(xb.id_array_[0], nb);
ASSERT_TRUE(stat.ok());
nb = 262144; // 512M
xb.clear();
BuildVectors(nb, xb);
vector_ids.clear();
vector_ids.resize(nb);
xb.id_array_.resize(nb);
for (auto i = 0; i < nb; i++) {
vector_ids[i] = i + nb / 2;
xb.id_array_[i] = i + nb / 2;
}
stat = db_->InsertVectors(GetTableName(), "", nb, xb.data(), vector_ids);
ASSERT_EQ(vector_ids[0], nb / 2);
stat = db_->InsertVectors(GetTableName(), "", xb);
ASSERT_EQ(xb.id_array_[0], nb / 2);
ASSERT_TRUE(stat.ok());
nb = 65536; // 128M
xb.clear();
BuildVectors(nb, xb);
vector_ids.clear();
stat = db_->InsertVectors(GetTableName(), "", nb, xb.data(), vector_ids);
xb.id_array_.clear();
stat = db_->InsertVectors(GetTableName(), "", xb);
ASSERT_TRUE(stat.ok());
nb = 100;
xb.clear();
BuildVectors(nb, xb);
vector_ids.clear();
vector_ids.resize(nb);
xb.id_array_.resize(nb);
for (auto i = 0; i < nb; i++) {
vector_ids[i] = i + nb;
xb.id_array_[i] = i + nb;
}
stat = db_->InsertVectors(GetTableName(), "", nb, xb.data(), vector_ids);
stat = db_->InsertVectors(GetTableName(), "", xb);
for (auto i = 0; i < nb; i++) {
ASSERT_EQ(vector_ids[i], i + nb);
ASSERT_EQ(xb.id_array_[i], i + nb);
}
}

View File

@ -29,6 +29,22 @@
#include "db/DB.h"
#include "db/meta/SqliteMetaImpl.h"
namespace {

constexpr int64_t TABLE_DIM = 256;

// Fill `vectors` with `n` random float vectors of TABLE_DIM components.
// The first component of each row is skewed by the row index so that rows
// remain distinguishable from one another.
void
BuildVectors(uint64_t n, milvus::engine::VectorsData& vectors) {
    vectors.vector_count_ = n;
    vectors.float_data_.clear();
    vectors.float_data_.resize(n * TABLE_DIM);
    float* data = vectors.float_data_.data();
    for (uint64_t row = 0; row < n; ++row) {
        float* vec = data + row * TABLE_DIM;
        for (int64_t col = 0; col < TABLE_DIM; ++col) {
            vec[col] = drand48();
        }
        vec[0] += row / 2000.;
    }
}

}  // namespace
TEST_F(MetricTest, METRIC_TEST) {
milvus::server::SystemInfo::GetInstance().Init();
milvus::server::Metrics::GetInstance().Init();
@ -48,29 +64,18 @@ TEST_F(MetricTest, METRIC_TEST) {
group_info_get.table_id_ = group_name;
stat = db_->DescribeTable(group_info_get);
milvus::engine::IDNumbers vector_ids;
milvus::engine::IDNumbers target_ids;
int d = 256;
int nb = 50;
float *xb = new float[d * nb];
for (int i = 0; i < nb; i++) {
for (int j = 0; j < d; j++) xb[d * i + j] = drand48();
xb[d * i] += i / 2000.;
}
milvus::engine::VectorsData xb;
BuildVectors(nb, xb);
int qb = 5;
float *qxb = new float[d * qb];
for (int i = 0; i < qb; i++) {
for (int j = 0; j < d; j++) qxb[d * i + j] = drand48();
qxb[d * i] += i / 2000.;
}
milvus::engine::VectorsData xq;
BuildVectors(qb, xq);
std::thread search([&]() {
// std::vector<std::string> tags;
// milvus::engine::ResultIds result_ids;
// milvus::engine::ResultDistances result_distances;
int k = 10;
std::this_thread::sleep_for(std::chrono::seconds(2));
INIT_TIMER;
@ -105,18 +110,18 @@ TEST_F(MetricTest, METRIC_TEST) {
for (auto i = 0; i < loop; ++i) {
if (i == 40) {
db_->InsertVectors(group_name, "", qb, qxb, target_ids);
ASSERT_EQ(target_ids.size(), qb);
xq.id_array_.clear();
db_->InsertVectors(group_name, "", xq);
ASSERT_EQ(xq.id_array_.size(), qb);
} else {
db_->InsertVectors(group_name, "", nb, xb, vector_ids);
xb.id_array_.clear();
db_->InsertVectors(group_name, "", xb);
ASSERT_EQ(xb.id_array_.size(), nb);
}
std::this_thread::sleep_for(std::chrono::microseconds(2000));
}
search.join();
delete[] xb;
delete[] qxb;
}
TEST_F(MetricTest, COLLECTOR_METRICS_TEST) {

View File

@ -48,7 +48,7 @@ class MockVecIndex : public engine::VecIndex {
}
engine::IndexType
GetType() override {
GetType() const override {
return engine::IndexType::INVALID;
}

View File

@ -58,7 +58,7 @@ class MockVecIndex : public milvus::engine::VecIndex {
}
milvus::engine::IndexType
GetType() override {
GetType() const override {
return milvus::engine::IndexType::INVALID;
}

View File

@ -47,7 +47,7 @@ constexpr int64_t SECONDS_EACH_HOUR = 3600;
// Copy a float vector into a gRPC RowRecord.
// NOTE: RowRecord stores float vectors in `float_data` (renamed from
// `vector_data` when binary-vector support was added); the stale duplicate
// line referencing mutable_vector_data() is removed here.
void
CopyRowRecord(::milvus::grpc::RowRecord* target, const std::vector<float>& src) {
    auto vector_data = target->mutable_float_data();
    vector_data->Resize(static_cast<int>(src.size()), 0.0);
    memcpy(vector_data->mutable_data(), src.data(), src.size() * sizeof(float));
}

View File

@ -283,9 +283,9 @@ TEST(ValidationUtilTest, VALIDATE_DIMENSION_TEST) {
milvus::SERVER_INVALID_VECTOR_DIMENSION);
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(0).code(),
milvus::SERVER_INVALID_VECTOR_DIMENSION);
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(16385).code(),
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(32769).code(),
milvus::SERVER_INVALID_VECTOR_DIMENSION);
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(16384).code(), milvus::SERVER_SUCCESS);
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(32768).code(), milvus::SERVER_SUCCESS);
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(1).code(), milvus::SERVER_SUCCESS);
}

View File

@ -19,6 +19,7 @@
set(test_files
test_knowhere.cpp
test_binindex.cpp
test_wrapper.cpp)
if (MILVUS_GPU_VERSION)
set(test_files ${test_files}

View File

@ -0,0 +1,117 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <tuple>
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include "wrapper/VecIndex.h"
#include "wrapper/utils.h"
#include <gtest/gtest.h>
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
// Value-parameterized fixture for binary (bit-packed) vector indexes.
// Tuple layout: (index type, dimension in bits, nb, nq, topk).
class BinIndexTest : public BinDataGen,
                     public TestWithParam<::std::tuple<milvus::engine::IndexType, int, int, int, int>> {
 protected:
    void
    SetUp() override {
        std::tie(index_type, dim, nb, nq, k) = GetParam();
        // Generates random bit-packed data; buffers hold dim/8 bytes per row.
        Generate(dim, nb, nq, k);

        // Build knowhere configs from a temporary meta description using the
        // Tanimoto metric (binary metric under test).
        milvus::engine::TempMetaConf tempconf;
        tempconf.metric_type = knowhere::METRICTYPE::TANIMOTO;
        tempconf.size = nb;
        tempconf.dim = dim;
        tempconf.k = k;
        tempconf.nprobe = 16;

        index_ = GetVecIndexFactory(index_type);
        conf = ParamGenerator::GetInstance().GenBuild(index_type, tempconf);
        searchconf = ParamGenerator::GetInstance().GenSearchConf(index_type, tempconf);
    }

    void
    TearDown() override {
    }

 protected:
    milvus::engine::IndexType index_type;
    milvus::engine::VecIndexPtr index_ = nullptr;
    knowhere::Config conf;        // build-time configuration
    knowhere::Config searchconf;  // query-time configuration
};
// Instantiates BinIndexTest for each supported binary index type; build and
// search configs are derived from these values in SetUp().
INSTANTIATE_TEST_CASE_P(WrapperParam, BinIndexTest,
                        Values(
                            //["Index type", "dim", "nb", "nq", "k", "build config", "search config"]
                            std::make_tuple(milvus::engine::IndexType::FAISS_BIN_IDMAP, 64, 1000, 10, 10),
                            std::make_tuple(milvus::engine::IndexType::FAISS_BIN_IVFLAT_CPU, DIM, NB, 10, 10)));
// Build the index and verify a basic top-k search returns each query's own
// base vector as the nearest neighbour.
TEST_P(BinIndexTest, BASE_TEST) {
    EXPECT_EQ(index_->GetType(), index_type);
    conf->Dump();
    searchconf->Dump();

    const auto result_count = nq * k;
    std::vector<int64_t> labels(result_count);
    std::vector<float> distances(result_count);

    index_->BuildAll(nb, xb.data(), ids.data(), conf);
    index_->Search(nq, xq.data(), distances.data(), labels.data(), searchconf);
    AssertResult(labels, distances);
}
// Verify the index survives two round-trips: in-memory serialization and
// write/read through a file on disk.
TEST_P(BinIndexTest, SERIALIZE_TEST) {
    EXPECT_EQ(index_->GetType(), index_type);

    const auto result_count = nq * k;
    std::vector<int64_t> labels(result_count);
    std::vector<float> distances(result_count);

    index_->BuildAll(nb, xb.data(), ids.data(), conf);
    index_->Search(nq, xq.data(), distances.data(), labels.data(), searchconf);
    AssertResult(labels, distances);

    {
        // Round-trip through the in-memory binary representation.
        auto binary = index_->Serialize();
        auto restored = GetVecIndexFactory(index_->GetType());
        restored->Load(binary);
        EXPECT_EQ(restored->Dimension(), index_->Dimension());
        EXPECT_EQ(restored->Count(), index_->Count());

        std::vector<int64_t> rt_labels(result_count);
        std::vector<float> rt_distances(result_count);
        restored->Search(nq, xq.data(), rt_distances.data(), rt_labels.data(), searchconf);
        AssertResult(rt_labels, rt_distances);
    }

    {
        // Round-trip through a file; reading back may convert to a CPU index.
        std::string file_location = "/tmp/knowhere";
        write_index(index_, file_location);
        auto restored = milvus::engine::read_index(file_location);
        EXPECT_EQ(restored->GetType(), ConvertToCpuIndexType(index_type));
        EXPECT_EQ(restored->Dimension(), index_->Dimension());
        EXPECT_EQ(restored->Count(), index_->Count());

        std::vector<int64_t> rt_labels(result_count);
        std::vector<float> rt_distances(result_count);
        restored->Search(nq, xq.data(), rt_distances.data(), rt_labels.data(), searchconf);
        AssertResult(rt_labels, rt_distances);
    }
}

View File

@ -153,3 +153,73 @@ DataGenBase::AssertResult(const std::vector<int64_t>& ids, const std::vector<flo
EXPECT_GT(precision, 0.5);
std::cout << std::endl << "Precision: " << precision << ", match: " << match << ", total: " << nq * k << std::endl;
}
// Fill pre-allocated buffers with test data: nb rows of `dim` random bytes
// labelled 0..nb-1, with queries copied from the first nq base rows so that
// the expected nearest neighbour of query i is base vector i itself.
// k / gt_ids / gt_dis belong to the virtual interface but are unused here.
void
BinDataGen::GenData(const int& dim,
                    const int& nb,
                    const int& nq,
                    uint8_t* xb,
                    uint8_t* xq,
                    int64_t* ids,
                    const int& k,
                    int64_t* gt_ids,
                    float* gt_dis) {
    for (int row = 0; row < nb; ++row) {
        for (int col = 0; col < dim; ++col) {
            xb[row * dim + col] = static_cast<uint8_t>(lrand48());
        }
        ids[row] = row;
    }
    for (int64_t pos = 0; pos < static_cast<int64_t>(nq) * dim; ++pos) {
        xq[pos] = xb[pos];
    }
}
void
BinDataGen::GenData(const int& dim,
const int& nb,
const int& nq,
std::vector<uint8_t>& xb,
std::vector<uint8_t>& xq,
std::vector<int64_t>& ids,
const int& k,
std::vector<int64_t>& gt_ids,
std::vector<float>& gt_dis) {
xb.clear();
xq.clear();
ids.clear();
gt_ids.clear();
gt_dis.clear();
xb.resize(nb * dim);
xq.resize(nq * dim);
ids.resize(nb);
gt_ids.resize(nq * k);
gt_dis.resize(nq * k);
GenData(dim, nb, nq, xb.data(), xq.data(), ids.data(), k, gt_ids.data(), gt_dis.data());
assert(xb.size() == (size_t)dim * nb);
assert(xq.size() == (size_t)dim * nq);
}
// Check a search result set: exactly nq*k entries, and the top hit of each
// query is the query's own base vector (queries are copies of base rows).
void
BinDataGen::AssertResult(const std::vector<int64_t>& ids, const std::vector<float>& dis) {
    const auto expected = nq * k;
    EXPECT_EQ(ids.size(), expected);
    EXPECT_EQ(dis.size(), expected);
    for (int query = 0; query < nq; ++query) {
        EXPECT_EQ(ids[query * k], query);
        // Distances are not compared: no ground-truth distances are generated.
    }
}
// Record the test dimensions and generate data. `dim` is in bits; binary
// vectors pack 8 dimensions per byte, so buffers are built with dim/8 bytes.
void
BinDataGen::Generate(const int& dim, const int& nb, const int& nq, const int& k) {
    this->dim = dim;
    this->nb = nb;
    this->nq = nq;
    this->k = k;

    const int64_t byte_dim = dim / 8;
    GenData(byte_dim, nb, nq, xb, xq, ids, k, gt_ids, gt_dis);
}

View File

@ -82,6 +82,38 @@ class DataGenBase {
std::vector<float> gt_dis;
};
// Generates random binary (bit-packed) vectors for wrapper index tests and
// checks search results against the implied self-match ground truth
// (queries are copies of the first nq base vectors).
class BinDataGen {
 public:
    // GenData is virtual and this class is used as a base (e.g. by test
    // fixtures), so a virtual destructor is required for safe polymorphic use.
    virtual ~BinDataGen() = default;

    // Fill pre-allocated buffers: nb rows of `dim` random bytes, ids 0..nb-1,
    // queries copied from the first nq rows. k/gt_* are currently unused.
    virtual void GenData(const int& dim, const int& nb, const int& nq, uint8_t* xb, uint8_t* xq, int64_t* ids,
                         const int& k, int64_t* gt_ids, float* gt_dis);

    // Vector-based overload: sizes the containers then delegates to the
    // raw-pointer overload above.
    virtual void GenData(const int& dim,
                         const int& nb,
                         const int& nq,
                         std::vector<uint8_t>& xb,
                         std::vector<uint8_t>& xq,
                         std::vector<int64_t>& ids,
                         const int& k,
                         std::vector<int64_t>& gt_ids,
                         std::vector<float>& gt_dis);

    // Expect nq*k results whose top hit per query is the query's own row.
    void AssertResult(const std::vector<int64_t>& ids, const std::vector<float>& dis);

    // `dim` is in bits; internally data is generated with dim/8 bytes per row.
    void Generate(const int& dim, const int& nb, const int& nq, const int& k);

    int dim = DIM;
    int nb = NB;
    int nq = NQ;
    int k = 10;
    std::vector<uint8_t> xb;
    std::vector<uint8_t> xq;
    std::vector<int64_t> ids;
    // Ground Truth (allocated but not currently filled by GenData)
    std::vector<int64_t> gt_ids;
    std::vector<float> gt_dis;
};
class ParamGenerator {
public:
static ParamGenerator& GetInstance() {

View File

@ -21,3 +21,4 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/utils util_files)
add_subdirectory(simple)
add_subdirectory(partition)
add_subdirectory(binary_vector)

View File

@ -0,0 +1,33 @@
#-------------------------------------------------------------------------------
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http:#www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#-------------------------------------------------------------------------------
# Collect the binary-vector example sources.
aux_source_directory(src src_files)

add_executable(sdk_binary
    main.cpp
    ${src_files}
    # shared example utilities provided by the parent directory's listfile
    ${util_files}
)

target_link_libraries(sdk_binary
    milvus_sdk
    pthread
)

install(TARGETS sdk_binary DESTINATION bin)

View File

@ -0,0 +1,79 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <getopt.h>
#include <libgen.h>
#include <cstring>
#include <string>
#include "src/ClientTest.h"
void
print_help(const std::string& app_name);
// Parse -s/--server and -p/--port, then run the binary-vector client demo.
int
main(int argc, char* argv[]) {
    printf("Client start...\n");

    // Keep only the executable's basename for help output.
    // (The original code computed basename() and then overwrote app_name with
    // the full argv[0] path, discarding the result.)
    std::string app_name = basename(argv[0]);
    static struct option long_options[] = {{"server", optional_argument, nullptr, 's'},
                                           {"port", optional_argument, nullptr, 'p'},
                                           {"help", no_argument, nullptr, 'h'},
                                           {nullptr, 0, nullptr, 0}};

    int option_index = 0;
    std::string address = "127.0.0.1", port = "19530";
    int value;
    while ((value = getopt_long(argc, argv, "s:p:h", long_options, &option_index)) != -1) {
        switch (value) {
            case 's':
                // optarg is owned by getopt; assigning to std::string copies it,
                // so no strdup/free round-trip is needed. With optional_argument
                // optarg may be null ("--server" with no value) — guard it.
                if (optarg != nullptr) {
                    address = optarg;
                }
                break;
            case 'p':
                if (optarg != nullptr) {
                    port = optarg;
                }
                break;
            case 'h':
            default:
                print_help(app_name);
                return EXIT_SUCCESS;
        }
    }

    ClientTest test;
    test.Test(address, port);

    printf("Client stop...\n");
    return 0;
}
// Print usage information for the command-line options accepted by main().
void
print_help(const std::string& app_name) {
    printf("\n Usage: %s [OPTIONS]\n\n", app_name.c_str());
    printf(" Options:\n");
    printf(" -s --server Server address, default 127.0.0.1\n");
    printf(" -p --port Server port, default 19530\n");
    printf(" -h --help Print help information\n");
    printf("\n");
}

View File

@ -0,0 +1,167 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "examples/simple/src/ClientTest.h"
#include "include/MilvusApi.h"
#include "examples/utils/TimeRecorder.h"
#include "examples/utils/Utils.h"
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include <random>
namespace {

// BUG FIX: GenTableName() returns a temporary std::string; the original code
// stored only its .c_str(), which dangles as soon as the temporary is
// destroyed. Keeping the std::string itself fixes the lifetime bug; call
// sites taking const std::string& (or brace-initializing TableSchema) are
// unaffected.
const std::string TABLE_NAME = milvus_sdk::Utils::GenTableName();

constexpr int64_t TABLE_DIMENSION = 512;  // in bits; each record holds dimension/8 bytes
constexpr int64_t TABLE_INDEX_FILE_SIZE = 128;
constexpr milvus::MetricType TABLE_METRIC_TYPE = milvus::MetricType::TANIMOTO;
constexpr int64_t BATCH_ROW_COUNT = 100000;
constexpr int64_t NQ = 5;
constexpr int64_t TOP_K = 10;
constexpr int64_t NPROBE = 32;
constexpr int64_t SEARCH_TARGET = 5000;  // change this value, result is different, ensure less than BATCH_ROW_COUNT
constexpr int64_t ADD_VECTOR_LOOP = 20;
constexpr milvus::IndexType INDEX_TYPE = milvus::IndexType::IVFFLAT;
constexpr int32_t N_LIST = 1024;

// Schema for a binary-vector table using the Tanimoto metric.
milvus::TableSchema
BuildTableSchema() {
    milvus::TableSchema tb_schema = {TABLE_NAME, TABLE_DIMENSION, TABLE_INDEX_FILE_SIZE, TABLE_METRIC_TYPE};
    return tb_schema;
}

milvus::IndexParam
BuildIndexParam() {
    milvus::IndexParam index_param = {TABLE_NAME, INDEX_TYPE, N_LIST};
    return index_param;
}

// Generate records [from, to) of random bit-packed vectors. `dimension` is in
// bits, so each record carries dimension/8 bytes; record ids are row indexes.
void
BuildBinaryVectors(int64_t from, int64_t to, std::vector<milvus::RowRecord>& vector_record_array,
                   std::vector<int64_t>& record_ids, int64_t dimension) {
    if (to <= from) {
        return;
    }

    vector_record_array.clear();
    record_ids.clear();

    int64_t dim_byte = dimension / 8;
    for (int64_t k = from; k < to; k++) {
        milvus::RowRecord record;
        record.binary_data.resize(dim_byte);
        for (int64_t i = 0; i < dim_byte; i++) {
            record.binary_data[i] = (uint8_t)lrand48();
        }

        vector_record_array.emplace_back(record);
        record_ids.push_back(k);
    }
}

}  // namespace
// Drives the end-to-end binary-vector demo against the server at address:port:
// connect -> create table -> insert ADD_VECTOR_LOOP batches -> search ->
// build index -> search again -> drop table. Progress and statuses are
// printed to stdout; failures are reported but do not abort the sequence.
void
ClientTest::Test(const std::string& address, const std::string& port) {
    std::shared_ptr<milvus::Connection> conn = milvus::Connection::Create();

    milvus::Status stat;
    {  // connect server
        milvus::ConnectParam param = {address, port};
        stat = conn->Connect(param);
        std::cout << "Connect function call status: " << stat.message() << std::endl;
    }

    {  // create table
        milvus::TableSchema tb_schema = BuildTableSchema();
        stat = conn->CreateTable(tb_schema);
        std::cout << "CreateTable function call status: " << stat.message() << std::endl;
        milvus_sdk::Utils::PrintTableSchema(tb_schema);

        bool has_table = conn->HasTable(tb_schema.table_name);
        if (has_table) {
            std::cout << "Table is created" << std::endl;
        }
    }

    // One (id, vector) pair per batch is remembered (up to NQ of them) so the
    // later searches can use known-inserted vectors as queries.
    std::vector<std::pair<int64_t, milvus::RowRecord>> search_record_array;
    {  // insert vectors
        for (int i = 0; i < ADD_VECTOR_LOOP; i++) {
            std::vector<milvus::RowRecord> record_array;
            std::vector<int64_t> record_ids;
            int64_t begin_index = i * BATCH_ROW_COUNT;
            {  // generate vectors
                milvus_sdk::TimeRecorder rc("Build vectors No." + std::to_string(i));
                BuildBinaryVectors(begin_index,
                                   begin_index + BATCH_ROW_COUNT,
                                   record_array,
                                   record_ids,
                                   TABLE_DIMENSION);
            }

            if (search_record_array.size() < NQ) {
                search_record_array.push_back(std::make_pair(record_ids[SEARCH_TARGET], record_array[SEARCH_TARGET]));
            }

            std::string title = "Insert " + std::to_string(record_array.size()) + " vectors No." + std::to_string(i);
            milvus_sdk::TimeRecorder rc(title);
            // record_ids is pre-filled by BuildBinaryVectors, i.e. ids are
            // client-provided here — presumably the server echoes them back;
            // verify against the SDK's Insert contract.
            stat = conn->Insert(TABLE_NAME, "", record_array, record_ids);
            std::cout << "InsertVector function call status: " << stat.message() << std::endl;
            std::cout << "Returned id array count: " << record_ids.size() << std::endl;
        }
    }

    // NOTE(review): looks like this gives the server time to make the inserted
    // data searchable before querying — confirm the required delay.
    milvus_sdk::Utils::Sleep(3);

    {  // search vectors
        std::vector<std::string> partition_tags;
        milvus::TopKQueryResult topk_query_result;
        milvus_sdk::Utils::DoSearch(conn, TABLE_NAME, partition_tags, TOP_K, NPROBE, search_record_array,
                                    topk_query_result);
    }

    {  // wait unit build index finish
        milvus_sdk::TimeRecorder rc("Create index");
        std::cout << "Wait until create all index done" << std::endl;
        milvus::IndexParam index1 = BuildIndexParam();
        milvus_sdk::Utils::PrintIndexParam(index1);
        stat = conn->CreateIndex(index1);
        std::cout << "CreateIndex function call status: " << stat.message() << std::endl;

        milvus::IndexParam index2;
        stat = conn->DescribeIndex(TABLE_NAME, index2);
        std::cout << "DescribeIndex function call status: " << stat.message() << std::endl;
        milvus_sdk::Utils::PrintIndexParam(index2);
    }

    {  // search vectors, now served by the freshly built index
        std::vector<std::string> partition_tags;
        milvus::TopKQueryResult topk_query_result;
        milvus_sdk::Utils::DoSearch(conn, TABLE_NAME, partition_tags, TOP_K, NPROBE, search_record_array,
                                    topk_query_result);
    }

    {  // drop table
        stat = conn->DropTable(TABLE_NAME);
        std::cout << "DropTable function call status: " << stat.message() << std::endl;
    }

    milvus::Connection::Destroy(conn);
}

View File

@ -0,0 +1,26 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <string>
// Example/driver class for the Milvus C++ SDK demo.
// NOTE(review): based on the surrounding diff, Test() appears to run an
// end-to-end scenario (create table/index, insert, search, drop) against a
// Milvus server — confirm against the .cpp implementation.
class ClientTest {
 public:
    // Connects to the Milvus server at the given address and port (both
    // passed as strings, e.g. "127.0.0.1" / "19530") and runs the demo.
    void
    Test(const std::string& address, const std::string& port);
};

View File

@@ -76,36 +76,27 @@ Utils::GenTableName() {
// Return a human-readable name for a metric type, including the binary
// metrics (Hamming/Jaccard/Tanimoto) added for binary-vector support.
// Unrecognized values yield "Unknown metric type".
std::string
Utils::MetricTypeName(const milvus::MetricType& metric_type) {
    switch (metric_type) {
        case milvus::MetricType::L2:
            return "L2 distance";
        case milvus::MetricType::IP:
            return "Inner product";
        case milvus::MetricType::HAMMING:
            return "Hamming distance";
        case milvus::MetricType::JACCARD:
            return "Jaccard distance";
        case milvus::MetricType::TANIMOTO:
            return "Tanimoto distance";
        default:
            return "Unknown metric type";
    }
}
// Return a human-readable name for an index type.
// Note: RNSG is reported as "NSG" (the SDK enum was renamed from nsg_mix to
// RNSG; the display name stays "NSG"). Unrecognized values yield
// "Unknown index type".
std::string
Utils::IndexTypeName(const milvus::IndexType& index_type) {
    switch (index_type) {
        case milvus::IndexType::FLAT:
            return "FLAT";
        case milvus::IndexType::IVFFLAT:
            return "IVFFLAT";
        case milvus::IndexType::IVFSQ8:
            return "IVFSQ8";
        case milvus::IndexType::RNSG:
            return "NSG";
        case milvus::IndexType::IVFSQ8H:
            return "IVFSQ8H";
        case milvus::IndexType::IVFPQ:
            return "IVFPQ";
        case milvus::IndexType::SPTAGKDT:
            return "SPTAGKDT";
        case milvus::IndexType::SPTAGBKT:
            return "SPTAGBKT";
        default:
            return "Unknown index type";
    }
}
@@ -148,9 +139,9 @@ Utils::BuildVectors(int64_t from, int64_t to, std::vector<milvus::RowRecord>& ve
record_ids.clear();
for (int64_t k = from; k < to; k++) {
milvus::RowRecord record;
record.data.resize(dimension);
record.float_data.resize(dimension);
for (int64_t i = 0; i < dimension; i++) {
record.data[i] = (float)(k % (i + 1));
record.float_data[i] = (float)(k % (i + 1));
}
vector_record_array.emplace_back(record);
@@ -189,11 +180,20 @@ Utils::CheckSearchResult(const std::vector<std::pair<int64_t, milvus::RowRecord>
for (size_t i = 0; i < nq; i++) {
const milvus::QueryResult& one_result = topk_query_result[i];
auto search_id = search_record_array[i].first;
int64_t result_id = one_result.ids[0];
if (result_id != search_id) {
std::cout << "The top 1 result is wrong: " << result_id << " vs. " << search_id << std::endl;
uint64_t match_index = one_result.ids.size();
for (uint64_t index = 0; index < one_result.ids.size(); index++) {
if (search_id == one_result.ids[index]) {
match_index = index;
break;
}
}
if (match_index >= one_result.ids.size()) {
std::cout << "The topk result is wrong: not return search target in result set" << std::endl;
} else {
std::cout << "No." << i << " Check result successfully" << std::endl;
std::cout << "No." << i << " Check result successfully for target: " << search_id << " at top "
<< match_index << std::endl;
}
}
BLOCK_SPLITER

Some files were not shown because too many files have changed in this diff Show More