mirror of https://github.com/milvus-io/milvus.git
Merge branch 'master' into caiyd_codec_opt
commit
5afef85466
19
CHANGELOG.md
19
CHANGELOG.md
|
@ -11,16 +11,16 @@ Please mark all change in change log and use the issue from GitHub
|
|||
- \#805 IVFTest.gpu_seal_test unittest failed
|
||||
- \#831 Judge branch error in CommonUtil.cpp
|
||||
- \#977 Server crash when create tables concurrently
|
||||
- \#990 check gpu resources setting when assign repeated value
|
||||
- \#990 Check gpu resources setting when assign repeated value
|
||||
- \#995 Table count set to 0 if no tables found
|
||||
- \#1010 improve error message when offset or page_size is equal 0
|
||||
- \#1010 Improve error message when offset or page_size is equal 0
|
||||
- \#1022 check if partition name is legal
|
||||
- \#1028 check if table exists when show partitions
|
||||
- \#1029 check if table exists when try to delete partition
|
||||
- \#1066 optimize http insert and search speed
|
||||
- \#1067 Add binary vectors support in http server
|
||||
- \#1075 improve error message when page size or offset is illegal
|
||||
- \#1082 check page_size or offset value to avoid float
|
||||
- \#1075 Improve error message when page size or offset is illegal
|
||||
- \#1082 Check page_size or offset value to avoid float
|
||||
- \#1115 HTTP server supports loading table into memory
|
||||
- \#1152 Error log output continuously after server start
|
||||
- \#1211 Server down caused by searching with index_type: HNSW
|
||||
|
@ -29,18 +29,21 @@ Please mark all change in change log and use the issue from GitHub
|
|||
- \#1359 Negative distance value returned when searching with HNSW index type
|
||||
- \#1429 Server crashed when searching vectors with GPU
|
||||
- \#1476 Fix vectors results bug when getting vectors from segments
|
||||
- \#1484 Index type changed to IDMAP after compacted
|
||||
- \#1484 Index type changed to IDMAP after compacted
|
||||
- \#1491 Server crashed during adding vectors
|
||||
- \#1499 Fix duplicated ID number issue
|
||||
- \#1491 Server crashed during adding vectors
|
||||
- \#1504 Avoid possible race condition between delete and search
|
||||
- \#1504 Avoid possible race condition between delete and search
|
||||
- \#1507 set_config for insert_buffer_size is wrong
|
||||
- \#1510 Add set interfaces for WAL configurations
|
||||
- \#1511 Fix big integer cannot pass to server correctly
|
||||
- \#1518 Table count did not match after deleting vectors and compact
|
||||
- \#1521 Make cache_insert_data take effect in-service
|
||||
- \#1525 Add setter API for config preload_table
|
||||
- \#1529 Fix server crash when cache_insert_data enabled
|
||||
- \#1530 Set table file with correct engine type in meta
|
||||
- \#1532 Search with ivf_flat failed with open-dataset: sift-256-hamming
|
||||
- \#1535 Search performance degradation with metric_type: binary_idmap
|
||||
- \#1556 Index file not created after table and index created
|
||||
|
||||
## Feature
|
||||
- \#216 Add CLI to get server info
|
||||
|
@ -86,9 +89,11 @@ Please mark all change in change log and use the issue from GitHub
|
|||
- \#1320 Remove debug logging from faiss
|
||||
- \#1426 Support configuring whether autoflush is enabled and the autoflush interval
|
||||
- \#1444 Improve delete
|
||||
- \#1448 General proto api for NNS libraries
|
||||
- \#1480 Add return code for AVX512 selection
|
||||
- \#1524 Update config "preload_table" description
|
||||
- \#1537 Optimize raw vector and uids read/write
|
||||
- \#1544 Update resources name in HTTP module
|
||||
|
||||
## Task
|
||||
- \#1327 Exclude third-party code from codebeat
|
||||
|
|
|
@ -2,6 +2,7 @@ timeout(time: 60, unit: 'MINUTES') {
|
|||
dir ("tests/milvus_python_test") {
|
||||
// sh 'python3 -m pip install -r requirements.txt -i http://pypi.douban.com/simple --trusted-host pypi.douban.com'
|
||||
sh 'python3 -m pip install -r requirements.txt'
|
||||
sh 'python3 -m pip install git+https://github.com/BossZou/pymilvus.git@nns'
|
||||
sh "pytest . --alluredir=\"test_out/dev/single/sqlite\" --level=1 --ip ${env.HELM_RELEASE_NAME}.milvus.svc.cluster.local"
|
||||
}
|
||||
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
#include <string>
|
||||
|
||||
#include "server/Config.h"
|
||||
#include "utils/Log.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace server {
|
||||
|
|
|
@ -112,18 +112,18 @@ class DB {
|
|||
|
||||
virtual Status
|
||||
QueryByID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
|
||||
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, IDNumber vector_id,
|
||||
ResultIds& result_ids, ResultDistances& result_distances) = 0;
|
||||
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
|
||||
IDNumber vector_id, ResultIds& result_ids, ResultDistances& result_distances) = 0;
|
||||
|
||||
virtual Status
|
||||
Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
|
||||
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
|
||||
ResultIds& result_ids, ResultDistances& result_distances) = 0;
|
||||
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
|
||||
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) = 0;
|
||||
|
||||
virtual Status
|
||||
QueryByFileID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
|
||||
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
|
||||
ResultIds& result_ids, ResultDistances& result_distances) = 0;
|
||||
const std::vector<std::string>& file_ids, uint64_t k, const milvus::json& extra_params,
|
||||
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) = 0;
|
||||
|
||||
virtual Status
|
||||
Size(uint64_t& result) = 0;
|
||||
|
|
|
@ -371,8 +371,10 @@ DBImpl::PreloadTable(const std::string& table_id) {
|
|||
} else {
|
||||
engine_type = (EngineType)file.engine_type_;
|
||||
}
|
||||
ExecutionEnginePtr engine = EngineFactory::Build(file.dimension_, file.location_, engine_type,
|
||||
(MetricType)file.metric_type_, file.nlist_);
|
||||
|
||||
auto json = milvus::json::parse(file.index_params_);
|
||||
ExecutionEnginePtr engine =
|
||||
EngineFactory::Build(file.dimension_, file.location_, engine_type, (MetricType)file.metric_type_, json);
|
||||
fiu_do_on("DBImpl.PreloadTable.null_engine", engine = nullptr);
|
||||
if (engine == nullptr) {
|
||||
ENGINE_LOG_ERROR << "Invalid engine type";
|
||||
|
@ -382,7 +384,7 @@ DBImpl::PreloadTable(const std::string& table_id) {
|
|||
size += engine->PhysicalSize();
|
||||
fiu_do_on("DBImpl.PreloadTable.exceed_cache", size = available_size + 1);
|
||||
if (size > available_size) {
|
||||
ENGINE_LOG_DEBUG << "Pre-load canceled since cache almost full";
|
||||
ENGINE_LOG_DEBUG << "Pre-load cancelled since cache is almost full";
|
||||
return Status(SERVER_CACHE_FULL, "Cache is full");
|
||||
} else {
|
||||
try {
|
||||
|
@ -810,7 +812,7 @@ DBImpl::CompactFile(const std::string& table_id, const meta::TableFileSchema& fi
|
|||
// Update table files state
|
||||
// if index type isn't IDMAP, set file type to TO_INDEX if file size exceed index_file_size
|
||||
// else set file type to RAW, no need to build index
|
||||
if (compacted_file.engine_type_ != (int)EngineType::FAISS_IDMAP) {
|
||||
if (!utils::IsRawIndexType(compacted_file.engine_type_)) {
|
||||
compacted_file.file_type_ = (segment_writer_ptr->Size() >= compacted_file.index_file_size_)
|
||||
? meta::TableFileSchema::TO_INDEX
|
||||
: meta::TableFileSchema::RAW;
|
||||
|
@ -1110,8 +1112,8 @@ DBImpl::DropIndex(const std::string& table_id) {
|
|||
|
||||
Status
|
||||
DBImpl::QueryByID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
|
||||
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, IDNumber vector_id,
|
||||
ResultIds& result_ids, ResultDistances& result_distances) {
|
||||
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
|
||||
IDNumber vector_id, ResultIds& result_ids, ResultDistances& result_distances) {
|
||||
if (!initialized_.load(std::memory_order_acquire)) {
|
||||
return SHUTDOWN_ERROR;
|
||||
}
|
||||
|
@ -1119,14 +1121,15 @@ DBImpl::QueryByID(const std::shared_ptr<server::Context>& context, const std::st
|
|||
VectorsData vectors_data = VectorsData();
|
||||
vectors_data.id_array_.emplace_back(vector_id);
|
||||
vectors_data.vector_count_ = 1;
|
||||
Status result = Query(context, table_id, partition_tags, k, nprobe, vectors_data, result_ids, result_distances);
|
||||
Status result =
|
||||
Query(context, table_id, partition_tags, k, extra_params, vectors_data, result_ids, result_distances);
|
||||
return result;
|
||||
}
|
||||
|
||||
Status
|
||||
DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
|
||||
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
|
||||
ResultIds& result_ids, ResultDistances& result_distances) {
|
||||
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
|
||||
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) {
|
||||
auto query_ctx = context->Child("Query");
|
||||
|
||||
if (!initialized_.load(std::memory_order_acquire)) {
|
||||
|
@ -1169,7 +1172,7 @@ DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string
|
|||
}
|
||||
|
||||
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
|
||||
status = QueryAsync(query_ctx, table_id, files_array, k, nprobe, vectors, result_ids, result_distances);
|
||||
status = QueryAsync(query_ctx, table_id, files_array, k, extra_params, vectors, result_ids, result_distances);
|
||||
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query
|
||||
|
||||
query_ctx->GetTraceContext()->GetSpan()->Finish();
|
||||
|
@ -1179,8 +1182,8 @@ DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string
|
|||
|
||||
Status
|
||||
DBImpl::QueryByFileID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
|
||||
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
|
||||
ResultIds& result_ids, ResultDistances& result_distances) {
|
||||
const std::vector<std::string>& file_ids, uint64_t k, const milvus::json& extra_params,
|
||||
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) {
|
||||
auto query_ctx = context->Child("Query by file id");
|
||||
|
||||
if (!initialized_.load(std::memory_order_acquire)) {
|
||||
|
@ -1208,7 +1211,7 @@ DBImpl::QueryByFileID(const std::shared_ptr<server::Context>& context, const std
|
|||
}
|
||||
|
||||
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
|
||||
status = QueryAsync(query_ctx, table_id, files_array, k, nprobe, vectors, result_ids, result_distances);
|
||||
status = QueryAsync(query_ctx, table_id, files_array, k, extra_params, vectors, result_ids, result_distances);
|
||||
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query
|
||||
|
||||
query_ctx->GetTraceContext()->GetSpan()->Finish();
|
||||
|
@ -1230,8 +1233,8 @@ DBImpl::Size(uint64_t& result) {
|
|||
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
Status
|
||||
DBImpl::QueryAsync(const std::shared_ptr<server::Context>& context, const std::string& table_id,
|
||||
const meta::TableFilesSchema& files, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
|
||||
ResultIds& result_ids, ResultDistances& result_distances) {
|
||||
const meta::TableFilesSchema& files, uint64_t k, const milvus::json& extra_params,
|
||||
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) {
|
||||
auto query_async_ctx = context->Child("Query Async");
|
||||
|
||||
server::CollectQueryMetrics metrics(vectors.vector_count_);
|
||||
|
@ -1242,7 +1245,7 @@ DBImpl::QueryAsync(const std::shared_ptr<server::Context>& context, const std::s
|
|||
auto status = OngoingFileChecker::GetInstance().MarkOngoingFiles(files);
|
||||
|
||||
ENGINE_LOG_DEBUG << "Engine query begin, index file count: " << files.size();
|
||||
scheduler::SearchJobPtr job = std::make_shared<scheduler::SearchJob>(query_async_ctx, k, nprobe, vectors);
|
||||
scheduler::SearchJobPtr job = std::make_shared<scheduler::SearchJob>(query_async_ctx, k, extra_params, vectors);
|
||||
for (auto& file : files) {
|
||||
scheduler::TableFileSchemaPtr file_ptr = std::make_shared<meta::TableFileSchema>(file);
|
||||
job->AddIndexFile(file_ptr);
|
||||
|
@ -1465,7 +1468,7 @@ DBImpl::MergeFiles(const std::string& table_id, const meta::TableFilesSchema& fi
|
|||
// step 4: update table files state
|
||||
// if index type isn't IDMAP, set file type to TO_INDEX if file size exceed index_file_size
|
||||
// else set file type to RAW, no need to build index
|
||||
if (table_file.engine_type_ != (int)EngineType::FAISS_IDMAP) {
|
||||
if (!utils::IsRawIndexType(table_file.engine_type_)) {
|
||||
table_file.file_type_ = (segment_writer_ptr->Size() >= table_file.index_file_size_)
|
||||
? meta::TableFileSchema::TO_INDEX
|
||||
: meta::TableFileSchema::RAW;
|
||||
|
@ -1767,7 +1770,7 @@ DBImpl::BuildTableIndexRecursively(const std::string& table_id, const TableIndex
|
|||
// for IDMAP type, only wait all NEW file converted to RAW file
|
||||
// for other type, wait NEW/RAW/NEW_MERGE/NEW_INDEX/TO_INDEX files converted to INDEX files
|
||||
std::vector<int> file_types;
|
||||
if (index.engine_type_ == static_cast<int32_t>(EngineType::FAISS_IDMAP)) {
|
||||
if (utils::IsRawIndexType(index.engine_type_)) {
|
||||
file_types = {
|
||||
static_cast<int32_t>(meta::TableFileSchema::NEW),
|
||||
static_cast<int32_t>(meta::TableFileSchema::NEW_MERGE),
|
||||
|
@ -1789,7 +1792,7 @@ DBImpl::BuildTableIndexRecursively(const std::string& table_id, const TableIndex
|
|||
|
||||
while (!table_files.empty()) {
|
||||
ENGINE_LOG_DEBUG << "Non index files detected! Will build index " << times;
|
||||
if (index.engine_type_ != (int)EngineType::FAISS_IDMAP) {
|
||||
if (!utils::IsRawIndexType(index.engine_type_)) {
|
||||
status = meta_ptr_->UpdateTableFilesToIndex(table_id);
|
||||
}
|
||||
|
||||
|
|
|
@ -131,18 +131,18 @@ class DBImpl : public DB, public server::CacheConfigHandler {
|
|||
|
||||
Status
|
||||
QueryByID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
|
||||
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, IDNumber vector_id,
|
||||
ResultIds& result_ids, ResultDistances& result_distances) override;
|
||||
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
|
||||
IDNumber vector_id, ResultIds& result_ids, ResultDistances& result_distances) override;
|
||||
|
||||
Status
|
||||
Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
|
||||
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
|
||||
ResultIds& result_ids, ResultDistances& result_distances) override;
|
||||
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
|
||||
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) override;
|
||||
|
||||
Status
|
||||
QueryByFileID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
|
||||
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
|
||||
ResultIds& result_ids, ResultDistances& result_distances) override;
|
||||
const std::vector<std::string>& file_ids, uint64_t k, const milvus::json& extra_params,
|
||||
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) override;
|
||||
|
||||
Status
|
||||
Size(uint64_t& result) override;
|
||||
|
@ -154,8 +154,8 @@ class DBImpl : public DB, public server::CacheConfigHandler {
|
|||
private:
|
||||
Status
|
||||
QueryAsync(const std::shared_ptr<server::Context>& context, const std::string& table_id,
|
||||
const meta::TableFilesSchema& files, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
|
||||
ResultIds& result_ids, ResultDistances& result_distances);
|
||||
const meta::TableFilesSchema& files, uint64_t k, const milvus::json& extra_params,
|
||||
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances);
|
||||
|
||||
Status
|
||||
GetVectorByIdHelper(const std::string& table_id, IDNumber vector_id, VectorsData& vector,
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
#include "db/engine/ExecutionEngine.h"
|
||||
#include "segment/Types.h"
|
||||
#include "utils/Json.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
|
@ -35,8 +36,8 @@ typedef std::vector<faiss::Index::distance_t> ResultDistances;
|
|||
|
||||
// Plain aggregate describing an index configuration: the engine type, the
// metric type and the index-building parameters, each with a default value.
// NOTE(review): this listing appears to interleave pre- and post-change lines
// of a diff; `nlist_` and `extra_params_` are presumably alternatives (the
// JSON default carries the same 16384 value under the "nlist" key) — confirm
// against the repository before relying on both members existing.
struct TableIndex {
|
||||
// Defaults to the FAISS_IDMAP engine type, stored as a plain integer.
int32_t engine_type_ = (int)EngineType::FAISS_IDMAP;
|
||||
// Integer index parameter, default 16384.
int32_t nlist_ = 16384;
|
||||
// Defaults to the L2 metric type, stored as a plain integer.
int32_t metric_type_ = (int)MetricType::L2;
|
||||
// JSON object of index parameters; by default holds only {"nlist": 16384}.
milvus::json extra_params_ = {{"nlist", 16384}};
|
||||
};
|
||||
|
||||
struct VectorsData {
|
||||
|
|
|
@ -211,10 +211,15 @@ GetParentPath(const std::string& path, std::string& parent_path) {
|
|||
|
||||
// Compares two TableIndex values member by member (engine type, index
// parameters, metric type) and reports whether they all match.
// NOTE(review): two `return` lines are shown because this listing appears to
// interleave the pre- and post-change lines of a diff; presumably only the
// `extra_params_` comparison is live after the change — confirm against the
// repository.
bool
|
||||
IsSameIndex(const TableIndex& index1, const TableIndex& index2) {
|
||||
return index1.engine_type_ == index2.engine_type_ && index1.nlist_ == index2.nlist_ &&
|
||||
return index1.engine_type_ == index2.engine_type_ && index1.extra_params_ == index2.extra_params_ &&
|
||||
index1.metric_type_ == index2.metric_type_;
|
||||
}
|
||||
|
||||
bool
|
||||
IsRawIndexType(int32_t type) {
|
||||
return (type == (int32_t)EngineType::FAISS_IDMAP) || (type == (int32_t)EngineType::FAISS_BIN_IDMAP);
|
||||
}
|
||||
|
||||
meta::DateT
|
||||
GetDate(const std::time_t& t, int day_delta) {
|
||||
struct tm ltm;
|
||||
|
|
|
@ -45,6 +45,9 @@ GetParentPath(const std::string& path, std::string& parent_path);
|
|||
bool
|
||||
IsSameIndex(const TableIndex& index1, const TableIndex& index2);
|
||||
|
||||
bool
|
||||
IsRawIndexType(int32_t type);
|
||||
|
||||
meta::DateT
|
||||
GetDate(const std::time_t& t, int day_delta = 0);
|
||||
meta::DateT
|
||||
|
|
|
@ -20,7 +20,7 @@ namespace engine {
|
|||
|
||||
ExecutionEnginePtr
|
||||
EngineFactory::Build(uint16_t dimension, const std::string& location, EngineType index_type, MetricType metric_type,
|
||||
int32_t nlist) {
|
||||
const milvus::json& index_params) {
|
||||
if (index_type == EngineType::INVALID) {
|
||||
ENGINE_LOG_ERROR << "Unsupported engine type";
|
||||
return nullptr;
|
||||
|
@ -28,7 +28,7 @@ EngineFactory::Build(uint16_t dimension, const std::string& location, EngineType
|
|||
|
||||
ENGINE_LOG_DEBUG << "EngineFactory index type: " << (int)index_type;
|
||||
ExecutionEnginePtr execution_engine_ptr =
|
||||
std::make_shared<ExecutionEngineImpl>(dimension, location, index_type, metric_type, nlist);
|
||||
std::make_shared<ExecutionEngineImpl>(dimension, location, index_type, metric_type, index_params);
|
||||
|
||||
execution_engine_ptr->Init();
|
||||
return execution_engine_ptr;
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "ExecutionEngine.h"
|
||||
#include "utils/Json.h"
|
||||
#include "utils/Status.h"
|
||||
|
||||
#include <string>
|
||||
|
@ -23,7 +24,7 @@ class EngineFactory {
|
|||
public:
|
||||
static ExecutionEnginePtr
|
||||
Build(uint16_t dimension, const std::string& location, EngineType index_type, MetricType metric_type,
|
||||
int32_t nlist);
|
||||
const milvus::json& index_params);
|
||||
};
|
||||
|
||||
} // namespace engine
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "utils/Json.h"
|
||||
#include "utils/Status.h"
|
||||
|
||||
namespace milvus {
|
||||
|
@ -94,15 +95,16 @@ class ExecutionEngine {
|
|||
GetVectorByID(const int64_t& id, uint8_t* vector, bool hybrid) = 0;
|
||||
|
||||
virtual Status
|
||||
Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels, bool hybrid) = 0;
|
||||
|
||||
virtual Status
|
||||
Search(int64_t n, const uint8_t* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
|
||||
Search(int64_t n, const float* data, int64_t k, const milvus::json& extra_params, float* distances, int64_t* labels,
|
||||
bool hybrid) = 0;
|
||||
|
||||
virtual Status
|
||||
Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
|
||||
bool hybrid) = 0;
|
||||
Search(int64_t n, const uint8_t* data, int64_t k, const milvus::json& extra_params, float* distances,
|
||||
int64_t* labels, bool hybrid) = 0;
|
||||
|
||||
virtual Status
|
||||
Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, const milvus::json& extra_params, float* distances,
|
||||
int64_t* labels, bool hybrid) = 0;
|
||||
|
||||
virtual std::shared_ptr<ExecutionEngine>
|
||||
BuildIndex(const std::string& location, EngineType engine_type) = 0;
|
||||
|
|
|
@ -43,22 +43,22 @@ namespace engine {
|
|||
namespace {
|
||||
|
||||
Status
|
||||
MappingMetricType(MetricType metric_type, knowhere::METRICTYPE& kw_type) {
|
||||
MappingMetricType(MetricType metric_type, milvus::json& conf) {
|
||||
switch (metric_type) {
|
||||
case MetricType::IP:
|
||||
kw_type = knowhere::METRICTYPE::IP;
|
||||
conf[knowhere::Metric::TYPE] = knowhere::Metric::IP;
|
||||
break;
|
||||
case MetricType::L2:
|
||||
kw_type = knowhere::METRICTYPE::L2;
|
||||
conf[knowhere::Metric::TYPE] = knowhere::Metric::L2;
|
||||
break;
|
||||
case MetricType::HAMMING:
|
||||
kw_type = knowhere::METRICTYPE::HAMMING;
|
||||
conf[knowhere::Metric::TYPE] = knowhere::Metric::HAMMING;
|
||||
break;
|
||||
case MetricType::JACCARD:
|
||||
kw_type = knowhere::METRICTYPE::JACCARD;
|
||||
conf[knowhere::Metric::TYPE] = knowhere::Metric::JACCARD;
|
||||
break;
|
||||
case MetricType::TANIMOTO:
|
||||
kw_type = knowhere::METRICTYPE::TANIMOTO;
|
||||
conf[knowhere::Metric::TYPE] = knowhere::Metric::TANIMOTO;
|
||||
break;
|
||||
default:
|
||||
return Status(DB_ERROR, "Unsupported metric type");
|
||||
|
@ -94,8 +94,12 @@ class CachedQuantizer : public cache::DataObj {
|
|||
};
|
||||
|
||||
ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string& location, EngineType index_type,
|
||||
MetricType metric_type, int32_t nlist)
|
||||
: location_(location), dim_(dimension), index_type_(index_type), metric_type_(metric_type), nlist_(nlist) {
|
||||
MetricType metric_type, const milvus::json& index_params)
|
||||
: location_(location),
|
||||
dim_(dimension),
|
||||
index_type_(index_type),
|
||||
metric_type_(metric_type),
|
||||
index_params_(index_params) {
|
||||
EngineType tmp_index_type = server::ValidationUtil::IsBinaryMetricType((int32_t)metric_type)
|
||||
? EngineType::FAISS_BIN_IDMAP
|
||||
: EngineType::FAISS_IDMAP;
|
||||
|
@ -104,16 +108,15 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string&
|
|||
throw Exception(DB_ERROR, "Unsupported index type");
|
||||
}
|
||||
|
||||
TempMetaConf temp_conf;
|
||||
temp_conf.gpu_id = gpu_num_;
|
||||
temp_conf.dim = dimension;
|
||||
auto status = MappingMetricType(metric_type, temp_conf.metric_type);
|
||||
if (!status.ok()) {
|
||||
throw Exception(DB_ERROR, status.message());
|
||||
}
|
||||
|
||||
milvus::json conf = index_params;
|
||||
conf[knowhere::meta::DEVICEID] = gpu_num_;
|
||||
conf[knowhere::meta::DIM] = dimension;
|
||||
MappingMetricType(metric_type, conf);
|
||||
ENGINE_LOG_DEBUG << "Index params: " << conf.dump();
|
||||
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
|
||||
auto conf = adapter->Match(temp_conf);
|
||||
if (!adapter->CheckTrain(conf)) {
|
||||
throw Exception(DB_ERROR, "Illegal index params");
|
||||
}
|
||||
|
||||
ErrorCode ec = KNOWHERE_UNEXPECTED_ERROR;
|
||||
if (auto bf_index = std::dynamic_pointer_cast<BFIndex>(index_)) {
|
||||
|
@ -127,8 +130,12 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string&
|
|||
}
|
||||
|
||||
ExecutionEngineImpl::ExecutionEngineImpl(VecIndexPtr index, const std::string& location, EngineType index_type,
|
||||
MetricType metric_type, int32_t nlist)
|
||||
: index_(std::move(index)), location_(location), index_type_(index_type), metric_type_(metric_type), nlist_(nlist) {
|
||||
MetricType metric_type, const milvus::json& index_params)
|
||||
: index_(std::move(index)),
|
||||
location_(location),
|
||||
index_type_(index_type),
|
||||
metric_type_(metric_type),
|
||||
index_params_(index_params) {
|
||||
}
|
||||
|
||||
VecIndexPtr
|
||||
|
@ -273,10 +280,9 @@ ExecutionEngineImpl::HybridLoad() const {
|
|||
auto best_index = std::distance(all_free_mem.begin(), max_e);
|
||||
auto best_device_id = gpus[best_index];
|
||||
|
||||
auto quantizer_conf = std::make_shared<knowhere::QuantizerCfg>();
|
||||
quantizer_conf->mode = 1;
|
||||
quantizer_conf->gpu_id = best_device_id;
|
||||
milvus::json quantizer_conf{{knowhere::meta::DEVICEID, best_device_id}, {"mode", 1}};
|
||||
auto quantizer = index_->LoadQuantizer(quantizer_conf);
|
||||
ENGINE_LOG_DEBUG << "Quantizer params: " << quantizer_conf.dump();
|
||||
if (quantizer == nullptr) {
|
||||
ENGINE_LOG_ERROR << "quantizer is nullptr";
|
||||
}
|
||||
|
@ -400,22 +406,18 @@ ExecutionEngineImpl::Load(bool to_cache) {
|
|||
utils::GetParentPath(location_, segment_dir);
|
||||
auto segment_reader_ptr = std::make_shared<segment::SegmentReader>(segment_dir);
|
||||
|
||||
if (index_type_ == EngineType::FAISS_IDMAP || index_type_ == EngineType::FAISS_BIN_IDMAP) {
|
||||
if (utils::IsRawIndexType((int32_t)index_type_)) {
|
||||
index_ = index_type_ == EngineType::FAISS_IDMAP ? GetVecIndexFactory(IndexType::FAISS_IDMAP)
|
||||
: GetVecIndexFactory(IndexType::FAISS_BIN_IDMAP);
|
||||
|
||||
TempMetaConf temp_conf;
|
||||
temp_conf.gpu_id = gpu_num_;
|
||||
temp_conf.dim = dim_;
|
||||
auto status = MappingMetricType(metric_type_, temp_conf.metric_type);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
milvus::json conf{{knowhere::meta::DEVICEID, gpu_num_}, {knowhere::meta::DIM, dim_}};
|
||||
MappingMetricType(metric_type_, conf);
|
||||
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
|
||||
ENGINE_LOG_DEBUG << "Index params: " << conf.dump();
|
||||
if (!adapter->CheckTrain(conf)) {
|
||||
throw Exception(DB_ERROR, "Illegal index params");
|
||||
}
|
||||
|
||||
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
|
||||
auto conf = adapter->Match(temp_conf);
|
||||
|
||||
status = segment_reader_ptr->Load();
|
||||
auto status = segment_reader_ptr->Load();
|
||||
if (!status.ok()) {
|
||||
std::string msg = "Failed to load segment from " + location_;
|
||||
ENGINE_LOG_ERROR << msg;
|
||||
|
@ -429,6 +431,7 @@ ExecutionEngineImpl::Load(bool to_cache) {
|
|||
|
||||
auto vectors_uids = vectors->GetUids();
|
||||
index_->SetUids(vectors_uids);
|
||||
ENGINE_LOG_DEBUG << "set uids " << index_->GetUids().size() << " for index " << location_;
|
||||
|
||||
auto vectors_data = vectors->GetData();
|
||||
|
||||
|
@ -453,7 +456,7 @@ ExecutionEngineImpl::Load(bool to_cache) {
|
|||
float_vectors.data(), Config());
|
||||
status = std::static_pointer_cast<BFIndex>(index_)->SetBlacklist(concurrent_bitset_ptr);
|
||||
|
||||
int64_t index_size = vectors->GetCount() * conf->d * sizeof(float);
|
||||
int64_t index_size = vectors->GetCount() * dim_ * sizeof(float);
|
||||
int64_t bitset_size = vectors->GetCount() / 8;
|
||||
index_->set_size(index_size + bitset_size);
|
||||
} else if (index_type_ == EngineType::FAISS_BIN_IDMAP) {
|
||||
|
@ -465,7 +468,7 @@ ExecutionEngineImpl::Load(bool to_cache) {
|
|||
vectors_data.data(), Config());
|
||||
status = std::static_pointer_cast<BinBFIndex>(index_)->SetBlacklist(concurrent_bitset_ptr);
|
||||
|
||||
int64_t index_size = vectors->GetCount() * conf->d * sizeof(uint8_t);
|
||||
int64_t index_size = vectors->GetCount() * dim_ * sizeof(uint8_t);
|
||||
int64_t bitset_size = vectors->GetCount() / 8;
|
||||
index_->set_size(index_size + bitset_size);
|
||||
}
|
||||
|
@ -508,6 +511,7 @@ ExecutionEngineImpl::Load(bool to_cache) {
|
|||
std::vector<segment::doc_id_t> uids;
|
||||
segment_reader_ptr->LoadUids(uids);
|
||||
index_->SetUids(uids);
|
||||
ENGINE_LOG_DEBUG << "set uids " << index_->GetUids().size() << " for index " << location_;
|
||||
|
||||
ENGINE_LOG_DEBUG << "Finished loading index file from segment " << segment_dir;
|
||||
}
|
||||
|
@ -548,9 +552,7 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id, bool hybrid) {
|
|||
|
||||
if (device_id != NOT_FOUND) {
|
||||
// cache hit
|
||||
auto config = std::make_shared<knowhere::QuantizerCfg>();
|
||||
config->gpu_id = device_id;
|
||||
config->mode = 2;
|
||||
milvus::json quantizer_conf{{knowhere::meta::DEVICEID : device_id}, {"mode" : 2}};
|
||||
auto new_index = index_->LoadData(quantizer, config);
|
||||
index_ = new_index;
|
||||
}
|
||||
|
@ -723,30 +725,36 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t
|
|||
throw Exception(DB_ERROR, "Unsupported index type");
|
||||
}
|
||||
|
||||
TempMetaConf temp_conf;
|
||||
temp_conf.gpu_id = gpu_num_;
|
||||
temp_conf.dim = Dimension();
|
||||
temp_conf.nlist = nlist_;
|
||||
temp_conf.size = Count();
|
||||
auto status = MappingMetricType(metric_type_, temp_conf.metric_type);
|
||||
if (!status.ok()) {
|
||||
throw Exception(DB_ERROR, status.message());
|
||||
}
|
||||
|
||||
milvus::json conf = index_params_;
|
||||
conf[knowhere::meta::DIM] = Dimension();
|
||||
conf[knowhere::meta::ROWS] = Count();
|
||||
conf[knowhere::meta::DEVICEID] = gpu_num_;
|
||||
MappingMetricType(metric_type_, conf);
|
||||
ENGINE_LOG_DEBUG << "Index params: " << conf.dump();
|
||||
auto adapter = AdapterMgr::GetInstance().GetAdapter(to_index->GetType());
|
||||
auto conf = adapter->Match(temp_conf);
|
||||
if (!adapter->CheckTrain(conf)) {
|
||||
throw Exception(DB_ERROR, "Illegal index params");
|
||||
}
|
||||
ENGINE_LOG_DEBUG << "Index config: " << conf.dump();
|
||||
|
||||
auto status = Status::OK();
|
||||
std::vector<segment::doc_id_t> uids;
|
||||
if (from_index) {
|
||||
status = to_index->BuildAll(Count(), from_index->GetRawVectors(), from_index->GetRawIds(), conf);
|
||||
uids = from_index->GetUids();
|
||||
} else if (bin_from_index) {
|
||||
status = to_index->BuildAll(Count(), bin_from_index->GetRawVectors(), bin_from_index->GetRawIds(), conf);
|
||||
uids = bin_from_index->GetUids();
|
||||
}
|
||||
to_index->SetUids(uids);
|
||||
ENGINE_LOG_DEBUG << "set uids " << to_index->GetUids().size() << " for " << location;
|
||||
|
||||
if (!status.ok()) {
|
||||
throw Exception(DB_ERROR, status.message());
|
||||
}
|
||||
|
||||
ENGINE_LOG_DEBUG << "Finish build index file: " << location << " size: " << to_index->Size();
|
||||
return std::make_shared<ExecutionEngineImpl>(to_index, location, engine_type, metric_type_, nlist_);
|
||||
return std::make_shared<ExecutionEngineImpl>(to_index, location, engine_type, metric_type_, index_params_);
|
||||
}
|
||||
|
||||
// map offsets to ids
|
||||
|
@ -761,8 +769,8 @@ MapUids(const std::vector<segment::doc_id_t>& uids, int64_t* labels, size_t num)
|
|||
}
|
||||
|
||||
Status
|
||||
ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
|
||||
bool hybrid) {
|
||||
ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, const milvus::json& extra_params, float* distances,
|
||||
int64_t* labels, bool hybrid) {
|
||||
#if 0
|
||||
if (index_type_ == EngineType::FAISS_IVFSQ8H) {
|
||||
if (!hybrid) {
|
||||
|
@ -786,9 +794,7 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
|
|||
|
||||
if (device_id != NOT_FOUND) {
|
||||
// cache hit
|
||||
auto config = std::make_shared<knowhere::QuantizerCfg>();
|
||||
config->gpu_id = device_id;
|
||||
config->mode = 2;
|
||||
milvus::json quantizer_conf{{knowhere::meta::DEVICEID : device_id}, {"mode" : 2}};
|
||||
auto new_index = index_->LoadData(quantizer, config);
|
||||
index_ = new_index;
|
||||
}
|
||||
|
@ -824,15 +830,13 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
|
|||
return Status(DB_ERROR, "index is null");
|
||||
}
|
||||
|
||||
ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe;
|
||||
|
||||
// TODO(linxj): remove here. Get conf from function
|
||||
TempMetaConf temp_conf;
|
||||
temp_conf.k = k;
|
||||
temp_conf.nprobe = nprobe;
|
||||
|
||||
milvus::json conf = extra_params;
|
||||
conf[knowhere::meta::TOPK] = k;
|
||||
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
|
||||
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
|
||||
ENGINE_LOG_DEBUG << "Search params: " << conf.dump();
|
||||
if (!adapter->CheckSearch(conf, index_->GetType())) {
|
||||
throw Exception(DB_ERROR, "Illegal search params");
|
||||
}
|
||||
|
||||
if (hybrid) {
|
||||
HybridLoad();
|
||||
|
@ -843,6 +847,7 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
|
|||
rc.RecordSection("search done");
|
||||
|
||||
// map offsets to ids
|
||||
ENGINE_LOG_DEBUG << "get uids " << index_->GetUids().size() << " from index " << location_;
|
||||
MapUids(index_->GetUids(), labels, n * k);
|
||||
|
||||
rc.RecordSection("map uids " + std::to_string(n * k));
|
||||
|
@ -858,8 +863,8 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
|
|||
}
|
||||
|
||||
Status
|
||||
ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, int64_t nprobe, float* distances,
|
||||
int64_t* labels, bool hybrid) {
|
||||
ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, const milvus::json& extra_params,
|
||||
float* distances, int64_t* labels, bool hybrid) {
|
||||
TimeRecorder rc("ExecutionEngineImpl::Search uint8");
|
||||
|
||||
if (index_ == nullptr) {
|
||||
|
@ -867,15 +872,13 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, int64_t n
|
|||
return Status(DB_ERROR, "index is null");
|
||||
}
|
||||
|
||||
ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe;
|
||||
|
||||
// TODO(linxj): remove here. Get conf from function
|
||||
TempMetaConf temp_conf;
|
||||
temp_conf.k = k;
|
||||
temp_conf.nprobe = nprobe;
|
||||
|
||||
milvus::json conf = extra_params;
|
||||
conf[knowhere::meta::TOPK] = k;
|
||||
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
|
||||
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
|
||||
ENGINE_LOG_DEBUG << "Search params: " << conf.dump();
|
||||
if (!adapter->CheckSearch(conf, index_->GetType())) {
|
||||
throw Exception(DB_ERROR, "Illegal search params");
|
||||
}
|
||||
|
||||
if (hybrid) {
|
||||
HybridLoad();
|
||||
|
@ -886,6 +889,7 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, int64_t n
|
|||
rc.RecordSection("search done");
|
||||
|
||||
// map offsets to ids
|
||||
ENGINE_LOG_DEBUG << "get uids " << index_->GetUids().size() << " from index " << location_;
|
||||
MapUids(index_->GetUids(), labels, n * k);
|
||||
|
||||
rc.RecordSection("map uids " + std::to_string(n * k));
|
||||
|
@ -901,8 +905,8 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, int64_t n
|
|||
}
|
||||
|
||||
Status
|
||||
ExecutionEngineImpl::Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, int64_t nprobe, float* distances,
|
||||
int64_t* labels, bool hybrid) {
|
||||
ExecutionEngineImpl::Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, const milvus::json& extra_params,
|
||||
float* distances, int64_t* labels, bool hybrid) {
|
||||
TimeRecorder rc("ExecutionEngineImpl::Search vector of ids");
|
||||
|
||||
if (index_ == nullptr) {
|
||||
|
@ -910,15 +914,13 @@ ExecutionEngineImpl::Search(int64_t n, const std::vector<int64_t>& ids, int64_t
|
|||
return Status(DB_ERROR, "index is null");
|
||||
}
|
||||
|
||||
ENGINE_LOG_DEBUG << "Search by ids Params: [k] " << k << " [nprobe] " << nprobe;
|
||||
|
||||
// TODO(linxj): remove here. Get conf from function
|
||||
TempMetaConf temp_conf;
|
||||
temp_conf.k = k;
|
||||
temp_conf.nprobe = nprobe;
|
||||
|
||||
milvus::json conf = extra_params;
|
||||
conf[knowhere::meta::TOPK] = k;
|
||||
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
|
||||
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
|
||||
ENGINE_LOG_DEBUG << "Search params: " << conf.dump();
|
||||
if (!adapter->CheckSearch(conf, index_->GetType())) {
|
||||
throw Exception(DB_ERROR, "Illegal search params");
|
||||
}
|
||||
|
||||
if (hybrid) {
|
||||
HybridLoad();
|
||||
|
@ -971,6 +973,7 @@ ExecutionEngineImpl::Search(int64_t n, const std::vector<int64_t>& ids, int64_t
|
|||
rc.RecordSection("search done");
|
||||
|
||||
// map offsets to ids
|
||||
ENGINE_LOG_DEBUG << "get uids " << index_->GetUids().size() << " from index " << location_;
|
||||
MapUids(uids, labels, offsets.size() * k);
|
||||
|
||||
rc.RecordSection("map uids " + std::to_string(offsets.size() * k));
|
||||
|
@ -993,19 +996,13 @@ ExecutionEngineImpl::GetVectorByID(const int64_t& id, float* vector, bool hybrid
|
|||
return Status(DB_ERROR, "index is null");
|
||||
}
|
||||
|
||||
// TODO(linxj): remove here. Get conf from function
|
||||
TempMetaConf temp_conf;
|
||||
|
||||
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
|
||||
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
|
||||
|
||||
if (hybrid) {
|
||||
HybridLoad();
|
||||
}
|
||||
|
||||
// Only one id for now
|
||||
std::vector<int64_t> ids{id};
|
||||
auto status = index_->GetVectorById(1, ids.data(), vector, conf);
|
||||
auto status = index_->GetVectorById(1, ids.data(), vector, milvus::json());
|
||||
|
||||
if (hybrid) {
|
||||
HybridUnset();
|
||||
|
@ -1026,19 +1023,13 @@ ExecutionEngineImpl::GetVectorByID(const int64_t& id, uint8_t* vector, bool hybr
|
|||
|
||||
ENGINE_LOG_DEBUG << "Get binary vector by id: " << id;
|
||||
|
||||
// TODO(linxj): remove here. Get conf from function
|
||||
TempMetaConf temp_conf;
|
||||
|
||||
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
|
||||
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
|
||||
|
||||
if (hybrid) {
|
||||
HybridLoad();
|
||||
}
|
||||
|
||||
// Only one id for now
|
||||
std::vector<int64_t> ids{id};
|
||||
auto status = index_->GetVectorById(1, ids.data(), vector, conf);
|
||||
auto status = index_->GetVectorById(1, ids.data(), vector, milvus::json());
|
||||
|
||||
if (hybrid) {
|
||||
HybridUnset();
|
||||
|
@ -1075,7 +1066,7 @@ ExecutionEngineImpl::Init() {
|
|||
std::vector<int64_t> gpu_ids;
|
||||
Status s = config.GetGpuResourceConfigBuildIndexResources(gpu_ids);
|
||||
if (!s.ok()) {
|
||||
gpu_num_ = knowhere::INVALID_VALUE;
|
||||
gpu_num_ = -1;
|
||||
return s;
|
||||
}
|
||||
for (auto id : gpu_ids) {
|
||||
|
|
|
@ -11,7 +11,8 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <src/segment/SegmentReader.h>
|
||||
#include "segment/SegmentReader.h"
|
||||
#include "utils/Json.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -26,10 +27,10 @@ namespace engine {
|
|||
class ExecutionEngineImpl : public ExecutionEngine {
|
||||
public:
|
||||
ExecutionEngineImpl(uint16_t dimension, const std::string& location, EngineType index_type, MetricType metric_type,
|
||||
int32_t nlist);
|
||||
const milvus::json& index_params);
|
||||
|
||||
ExecutionEngineImpl(VecIndexPtr index, const std::string& location, EngineType index_type, MetricType metric_type,
|
||||
int32_t nlist);
|
||||
const milvus::json& index_params);
|
||||
|
||||
Status
|
||||
AddWithIds(int64_t n, const float* xdata, const int64_t* xids) override;
|
||||
|
@ -77,16 +78,16 @@ class ExecutionEngineImpl : public ExecutionEngine {
|
|||
GetVectorByID(const int64_t& id, uint8_t* vector, bool hybrid) override;
|
||||
|
||||
Status
|
||||
Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
|
||||
Search(int64_t n, const float* data, int64_t k, const milvus::json& extra_params, float* distances, int64_t* labels,
|
||||
bool hybrid = false) override;
|
||||
|
||||
Status
|
||||
Search(int64_t n, const uint8_t* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
|
||||
bool hybrid = false) override;
|
||||
Search(int64_t n, const uint8_t* data, int64_t k, const milvus::json& extra_params, float* distances,
|
||||
int64_t* labels, bool hybrid = false) override;
|
||||
|
||||
Status
|
||||
Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
|
||||
bool hybrid) override;
|
||||
Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, const milvus::json& extra_params, float* distances,
|
||||
int64_t* labels, bool hybrid) override;
|
||||
|
||||
ExecutionEnginePtr
|
||||
BuildIndex(const std::string& location, EngineType engine_type) override;
|
||||
|
@ -136,7 +137,7 @@ class ExecutionEngineImpl : public ExecutionEngine {
|
|||
int64_t dim_;
|
||||
std::string location_;
|
||||
|
||||
int64_t nlist_ = 0;
|
||||
milvus::json index_params_;
|
||||
int64_t gpu_num_ = 0;
|
||||
};
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ struct TableSchema {
|
|||
int64_t flag_ = 0;
|
||||
int64_t index_file_size_ = DEFAULT_INDEX_FILE_SIZE;
|
||||
int32_t engine_type_ = DEFAULT_ENGINE_TYPE;
|
||||
int32_t nlist_ = DEFAULT_NLIST;
|
||||
std::string index_params_ = "{ \"nlist\": 16384 }";
|
||||
int32_t metric_type_ = DEFAULT_METRIC_TYPE;
|
||||
std::string owner_table_;
|
||||
std::string partition_tag_;
|
||||
|
@ -89,7 +89,7 @@ struct TableFileSchema {
|
|||
int64_t created_on_ = 0;
|
||||
int64_t index_file_size_ = DEFAULT_INDEX_FILE_SIZE; // not persist to meta
|
||||
int32_t engine_type_ = DEFAULT_ENGINE_TYPE;
|
||||
int32_t nlist_ = DEFAULT_NLIST; // not persist to meta
|
||||
std::string index_params_; // not persist to meta
|
||||
int32_t metric_type_ = DEFAULT_METRIC_TYPE; // not persist to meta
|
||||
uint64_t flush_lsn_ = 0;
|
||||
}; // TableFileSchema
|
||||
|
|
|
@ -144,7 +144,7 @@ static const MetaSchema TABLES_SCHEMA(META_TABLES, {
|
|||
MetaField("flag", "BIGINT", "DEFAULT 0 NOT NULL"),
|
||||
MetaField("index_file_size", "BIGINT", "DEFAULT 1024 NOT NULL"),
|
||||
MetaField("engine_type", "INT", "DEFAULT 1 NOT NULL"),
|
||||
MetaField("nlist", "INT", "DEFAULT 16384 NOT NULL"),
|
||||
MetaField("index_params", "VARCHAR(512)", "NOT NULL"),
|
||||
MetaField("metric_type", "INT", "DEFAULT 1 NOT NULL"),
|
||||
MetaField("owner_table", "VARCHAR(255)", "NOT NULL"),
|
||||
MetaField("partition_tag", "VARCHAR(255)", "NOT NULL"),
|
||||
|
@ -398,7 +398,7 @@ MySQLMetaImpl::CreateTable(TableSchema& table_schema) {
|
|||
std::string flag = std::to_string(table_schema.flag_);
|
||||
std::string index_file_size = std::to_string(table_schema.index_file_size_);
|
||||
std::string engine_type = std::to_string(table_schema.engine_type_);
|
||||
std::string nlist = std::to_string(table_schema.nlist_);
|
||||
std::string& index_params = table_schema.index_params_;
|
||||
std::string metric_type = std::to_string(table_schema.metric_type_);
|
||||
std::string& owner_table = table_schema.owner_table_;
|
||||
std::string& partition_tag = table_schema.partition_tag_;
|
||||
|
@ -407,9 +407,9 @@ MySQLMetaImpl::CreateTable(TableSchema& table_schema) {
|
|||
|
||||
createTableQuery << "INSERT INTO " << META_TABLES << " VALUES(" << id << ", " << mysqlpp::quote << table_id
|
||||
<< ", " << state << ", " << dimension << ", " << created_on << ", " << flag << ", "
|
||||
<< index_file_size << ", " << engine_type << ", " << nlist << ", " << metric_type << ", "
|
||||
<< mysqlpp::quote << owner_table << ", " << mysqlpp::quote << partition_tag << ", "
|
||||
<< mysqlpp::quote << version << ", " << flush_lsn << ");";
|
||||
<< index_file_size << ", " << engine_type << ", " << mysqlpp::quote << index_params << ", "
|
||||
<< metric_type << ", " << mysqlpp::quote << owner_table << ", " << mysqlpp::quote
|
||||
<< partition_tag << ", " << mysqlpp::quote << version << ", " << flush_lsn << ");";
|
||||
|
||||
ENGINE_LOG_DEBUG << "MySQLMetaImpl::CreateTable: " << createTableQuery.str();
|
||||
|
||||
|
@ -446,8 +446,8 @@ MySQLMetaImpl::DescribeTable(TableSchema& table_schema) {
|
|||
|
||||
mysqlpp::Query describeTableQuery = connectionPtr->query();
|
||||
describeTableQuery
|
||||
<< "SELECT id, state, dimension, created_on, flag, index_file_size, engine_type, nlist, metric_type"
|
||||
<< " ,owner_table, partition_tag, version, flush_lsn"
|
||||
<< "SELECT id, state, dimension, created_on, flag, index_file_size, engine_type, index_params"
|
||||
<< " , metric_type ,owner_table, partition_tag, version, flush_lsn"
|
||||
<< " FROM " << META_TABLES << " WHERE table_id = " << mysqlpp::quote << table_schema.table_id_
|
||||
<< " AND state <> " << std::to_string(TableSchema::TO_DELETE) << ";";
|
||||
|
||||
|
@ -465,7 +465,7 @@ MySQLMetaImpl::DescribeTable(TableSchema& table_schema) {
|
|||
table_schema.flag_ = resRow["flag"];
|
||||
table_schema.index_file_size_ = resRow["index_file_size"];
|
||||
table_schema.engine_type_ = resRow["engine_type"];
|
||||
table_schema.nlist_ = resRow["nlist"];
|
||||
resRow["index_params"].to_string(table_schema.index_params_);
|
||||
table_schema.metric_type_ = resRow["metric_type"];
|
||||
resRow["owner_table"].to_string(table_schema.owner_table_);
|
||||
resRow["partition_tag"].to_string(table_schema.partition_tag_);
|
||||
|
@ -534,7 +534,7 @@ MySQLMetaImpl::AllTables(std::vector<TableSchema>& table_schema_array) {
|
|||
}
|
||||
|
||||
mysqlpp::Query allTablesQuery = connectionPtr->query();
|
||||
allTablesQuery << "SELECT id, table_id, dimension, engine_type, nlist, index_file_size, metric_type"
|
||||
allTablesQuery << "SELECT id, table_id, dimension, engine_type, index_params, index_file_size, metric_type"
|
||||
<< " ,owner_table, partition_tag, version, flush_lsn"
|
||||
<< " FROM " << META_TABLES << " WHERE state <> " << std::to_string(TableSchema::TO_DELETE)
|
||||
<< " AND owner_table = \"\";";
|
||||
|
@ -551,7 +551,7 @@ MySQLMetaImpl::AllTables(std::vector<TableSchema>& table_schema_array) {
|
|||
table_schema.dimension_ = resRow["dimension"];
|
||||
table_schema.index_file_size_ = resRow["index_file_size"];
|
||||
table_schema.engine_type_ = resRow["engine_type"];
|
||||
table_schema.nlist_ = resRow["nlist"];
|
||||
resRow["index_params"].to_string(table_schema.index_params_);
|
||||
table_schema.metric_type_ = resRow["metric_type"];
|
||||
resRow["owner_table"].to_string(table_schema.owner_table_);
|
||||
resRow["partition_tag"].to_string(table_schema.partition_tag_);
|
||||
|
@ -673,17 +673,8 @@ MySQLMetaImpl::CreateTableFile(TableFileSchema& file_schema) {
|
|||
file_schema.created_on_ = utils::GetMicroSecTimeStamp();
|
||||
file_schema.updated_time_ = file_schema.created_on_;
|
||||
file_schema.index_file_size_ = table_schema.index_file_size_;
|
||||
|
||||
if (file_schema.file_type_ == TableFileSchema::FILE_TYPE::NEW ||
|
||||
file_schema.file_type_ == TableFileSchema::FILE_TYPE::NEW_MERGE) {
|
||||
file_schema.engine_type_ = server::ValidationUtil::IsBinaryMetricType(table_schema.metric_type_)
|
||||
? (int32_t)EngineType::FAISS_BIN_IDMAP
|
||||
: (int32_t)EngineType::FAISS_IDMAP;
|
||||
} else {
|
||||
file_schema.engine_type_ = table_schema.engine_type_;
|
||||
}
|
||||
|
||||
file_schema.nlist_ = table_schema.nlist_;
|
||||
file_schema.index_params_ = table_schema.index_params_;
|
||||
file_schema.engine_type_ = table_schema.engine_type_;
|
||||
file_schema.metric_type_ = table_schema.metric_type_;
|
||||
|
||||
std::string id = "NULL"; // auto-increment
|
||||
|
@ -785,7 +776,7 @@ MySQLMetaImpl::GetTableFiles(const std::string& table_id, const std::vector<size
|
|||
resRow["segment_id"].to_string(file_schema.segment_id_);
|
||||
file_schema.index_file_size_ = table_schema.index_file_size_;
|
||||
file_schema.engine_type_ = resRow["engine_type"];
|
||||
file_schema.nlist_ = table_schema.nlist_;
|
||||
file_schema.index_params_ = table_schema.index_params_;
|
||||
file_schema.metric_type_ = table_schema.metric_type_;
|
||||
resRow["file_id"].to_string(file_schema.file_id_);
|
||||
file_schema.file_type_ = resRow["file_type"];
|
||||
|
@ -844,7 +835,7 @@ MySQLMetaImpl::GetTableFilesBySegmentId(const std::string& segment_id,
|
|||
resRow["segment_id"].to_string(file_schema.segment_id_);
|
||||
file_schema.index_file_size_ = table_schema.index_file_size_;
|
||||
file_schema.engine_type_ = resRow["engine_type"];
|
||||
file_schema.nlist_ = table_schema.nlist_;
|
||||
file_schema.index_params_ = table_schema.index_params_;
|
||||
file_schema.metric_type_ = table_schema.metric_type_;
|
||||
resRow["file_id"].to_string(file_schema.file_id_);
|
||||
file_schema.file_type_ = resRow["file_type"];
|
||||
|
@ -900,7 +891,8 @@ MySQLMetaImpl::UpdateTableIndex(const std::string& table_id, const TableIndex& i
|
|||
|
||||
updateTableIndexParamQuery << "UPDATE " << META_TABLES << " SET id = " << id << " ,state = " << state
|
||||
<< " ,dimension = " << dimension << " ,created_on = " << created_on
|
||||
<< " ,engine_type = " << index.engine_type_ << " ,nlist = " << index.nlist_
|
||||
<< " ,engine_type = " << index.engine_type_
|
||||
<< " ,index_params = " << mysqlpp::quote << index.extra_params_.dump()
|
||||
<< " ,metric_type = " << index.metric_type_
|
||||
<< " WHERE table_id = " << mysqlpp::quote << table_id << ";";
|
||||
|
||||
|
@ -1044,7 +1036,7 @@ MySQLMetaImpl::GetTableFilesByFlushLSN(uint64_t flush_lsn, TableFilesSchema& tab
|
|||
}
|
||||
table_file.dimension_ = groups[table_file.table_id_].dimension_;
|
||||
table_file.index_file_size_ = groups[table_file.table_id_].index_file_size_;
|
||||
table_file.nlist_ = groups[table_file.table_id_].nlist_;
|
||||
table_file.index_params_ = groups[table_file.table_id_].index_params_;
|
||||
table_file.metric_type_ = groups[table_file.table_id_].metric_type_;
|
||||
|
||||
auto status = utils::GetTableFilePath(options_, table_file);
|
||||
|
@ -1263,7 +1255,7 @@ MySQLMetaImpl::DescribeTableIndex(const std::string& table_id, TableIndex& index
|
|||
}
|
||||
|
||||
mysqlpp::Query describeTableIndexQuery = connectionPtr->query();
|
||||
describeTableIndexQuery << "SELECT engine_type, nlist, index_file_size, metric_type"
|
||||
describeTableIndexQuery << "SELECT engine_type, index_params, index_file_size, metric_type"
|
||||
<< " FROM " << META_TABLES << " WHERE table_id = " << mysqlpp::quote << table_id
|
||||
<< " AND state <> " << std::to_string(TableSchema::TO_DELETE) << ";";
|
||||
|
||||
|
@ -1275,7 +1267,9 @@ MySQLMetaImpl::DescribeTableIndex(const std::string& table_id, TableIndex& index
|
|||
const mysqlpp::Row& resRow = res[0];
|
||||
|
||||
index.engine_type_ = resRow["engine_type"];
|
||||
index.nlist_ = resRow["nlist"];
|
||||
std::string str_index_params;
|
||||
resRow["index_params"].to_string(str_index_params);
|
||||
index.extra_params_ = milvus::json::parse(str_index_params);
|
||||
index.metric_type_ = resRow["metric_type"];
|
||||
} else {
|
||||
return Status(DB_NOT_FOUND, "Table " + table_id + " not found");
|
||||
|
@ -1334,7 +1328,7 @@ MySQLMetaImpl::DropTableIndex(const std::string& table_id) {
|
|||
// set table index type to raw
|
||||
dropTableIndexQuery << "UPDATE " << META_TABLES
|
||||
<< " SET engine_type = " << std::to_string(DEFAULT_ENGINE_TYPE)
|
||||
<< " ,nlist = " << std::to_string(DEFAULT_NLIST)
|
||||
<< " , index_params = '{}'"
|
||||
<< " WHERE table_id = " << mysqlpp::quote << table_id << ";";
|
||||
|
||||
ENGINE_LOG_DEBUG << "MySQLMetaImpl::DropTableIndex: " << dropTableIndexQuery.str();
|
||||
|
@ -1426,7 +1420,7 @@ MySQLMetaImpl::ShowPartitions(const std::string& table_id, std::vector<meta::Tab
|
|||
|
||||
mysqlpp::Query allPartitionsQuery = connectionPtr->query();
|
||||
allPartitionsQuery << "SELECT table_id, id, state, dimension, created_on, flag, index_file_size,"
|
||||
<< " engine_type, nlist, metric_type, partition_tag, version FROM " << META_TABLES
|
||||
<< " engine_type, index_params, metric_type, partition_tag, version FROM " << META_TABLES
|
||||
<< " WHERE owner_table = " << mysqlpp::quote << table_id << " AND state <> "
|
||||
<< std::to_string(TableSchema::TO_DELETE) << ";";
|
||||
|
||||
|
@ -1445,7 +1439,7 @@ MySQLMetaImpl::ShowPartitions(const std::string& table_id, std::vector<meta::Tab
|
|||
partition_schema.flag_ = resRow["flag"];
|
||||
partition_schema.index_file_size_ = resRow["index_file_size"];
|
||||
partition_schema.engine_type_ = resRow["engine_type"];
|
||||
partition_schema.nlist_ = resRow["nlist"];
|
||||
resRow["index_params"].to_string(partition_schema.index_params_);
|
||||
partition_schema.metric_type_ = resRow["metric_type"];
|
||||
partition_schema.owner_table_ = table_id;
|
||||
resRow["partition_tag"].to_string(partition_schema.partition_tag_);
|
||||
|
@ -1562,7 +1556,7 @@ MySQLMetaImpl::FilesToSearch(const std::string& table_id, const std::vector<size
|
|||
resRow["segment_id"].to_string(table_file.segment_id_);
|
||||
table_file.index_file_size_ = table_schema.index_file_size_;
|
||||
table_file.engine_type_ = resRow["engine_type"];
|
||||
table_file.nlist_ = table_schema.nlist_;
|
||||
table_file.index_params_ = table_schema.index_params_;
|
||||
table_file.metric_type_ = table_schema.metric_type_;
|
||||
resRow["file_id"].to_string(table_file.file_id_);
|
||||
table_file.file_type_ = resRow["file_type"];
|
||||
|
@ -1644,7 +1638,7 @@ MySQLMetaImpl::FilesToMerge(const std::string& table_id, TableFilesSchema& files
|
|||
table_file.date_ = resRow["date"];
|
||||
table_file.index_file_size_ = table_schema.index_file_size_;
|
||||
table_file.engine_type_ = resRow["engine_type"];
|
||||
table_file.nlist_ = table_schema.nlist_;
|
||||
table_file.index_params_ = table_schema.index_params_;
|
||||
table_file.metric_type_ = table_schema.metric_type_;
|
||||
table_file.created_on_ = resRow["created_on"];
|
||||
table_file.dimension_ = table_schema.dimension_;
|
||||
|
@ -1722,7 +1716,7 @@ MySQLMetaImpl::FilesToIndex(TableFilesSchema& files) {
|
|||
}
|
||||
table_file.dimension_ = groups[table_file.table_id_].dimension_;
|
||||
table_file.index_file_size_ = groups[table_file.table_id_].index_file_size_;
|
||||
table_file.nlist_ = groups[table_file.table_id_].nlist_;
|
||||
table_file.index_params_ = groups[table_file.table_id_].index_params_;
|
||||
table_file.metric_type_ = groups[table_file.table_id_].metric_type_;
|
||||
|
||||
auto status = utils::GetTableFilePath(options_, table_file);
|
||||
|
@ -1809,7 +1803,7 @@ MySQLMetaImpl::FilesByType(const std::string& table_id, const std::vector<int>&
|
|||
file_schema.created_on_ = resRow["created_on"];
|
||||
|
||||
file_schema.index_file_size_ = table_schema.index_file_size_;
|
||||
file_schema.nlist_ = table_schema.nlist_;
|
||||
file_schema.index_params_ = table_schema.index_params_;
|
||||
file_schema.metric_type_ = table_schema.metric_type_;
|
||||
file_schema.dimension_ = table_schema.dimension_;
|
||||
|
||||
|
@ -2083,8 +2077,7 @@ MySQLMetaImpl::CleanUpFilesWithTTL(uint64_t seconds /*, CleanUpFilter* filter*/)
|
|||
// If we are deleting a raw table file, it means it's okay to delete the entire segment directory.
|
||||
// Else, we can only delete the single file
|
||||
// TODO(zhiru): We determine whether a table file is raw by its engine type. This is a bit hacky
|
||||
if (table_file.engine_type_ == (int32_t)EngineType::FAISS_IDMAP ||
|
||||
table_file.engine_type_ == (int32_t)EngineType::FAISS_BIN_IDMAP) {
|
||||
if (utils::IsRawIndexType(table_file.engine_type_)) {
|
||||
utils::DeleteSegment(options_, table_file);
|
||||
std::string segment_dir;
|
||||
utils::GetParentPath(table_file.location_, segment_dir);
|
||||
|
|
|
@ -68,7 +68,8 @@ StoragePrototype(const std::string& path) {
|
|||
make_column("created_on", &TableSchema::created_on_),
|
||||
make_column("flag", &TableSchema::flag_, default_value(0)),
|
||||
make_column("index_file_size", &TableSchema::index_file_size_),
|
||||
make_column("engine_type", &TableSchema::engine_type_), make_column("nlist", &TableSchema::nlist_),
|
||||
make_column("engine_type", &TableSchema::engine_type_),
|
||||
make_column("index_params", &TableSchema::index_params_),
|
||||
make_column("metric_type", &TableSchema::metric_type_),
|
||||
make_column("owner_table", &TableSchema::owner_table_, default_value("")),
|
||||
make_column("partition_tag", &TableSchema::partition_tag_, default_value("")),
|
||||
|
@ -213,7 +214,7 @@ SqliteMetaImpl::DescribeTable(TableSchema& table_schema) {
|
|||
auto groups = ConnectorPtr->select(
|
||||
columns(&TableSchema::id_, &TableSchema::state_, &TableSchema::dimension_, &TableSchema::created_on_,
|
||||
&TableSchema::flag_, &TableSchema::index_file_size_, &TableSchema::engine_type_,
|
||||
&TableSchema::nlist_, &TableSchema::metric_type_, &TableSchema::owner_table_,
|
||||
&TableSchema::index_params_, &TableSchema::metric_type_, &TableSchema::owner_table_,
|
||||
&TableSchema::partition_tag_, &TableSchema::version_, &TableSchema::flush_lsn_),
|
||||
where(c(&TableSchema::table_id_) == table_schema.table_id_ and
|
||||
c(&TableSchema::state_) != (int)TableSchema::TO_DELETE));
|
||||
|
@ -226,7 +227,7 @@ SqliteMetaImpl::DescribeTable(TableSchema& table_schema) {
|
|||
table_schema.flag_ = std::get<4>(groups[0]);
|
||||
table_schema.index_file_size_ = std::get<5>(groups[0]);
|
||||
table_schema.engine_type_ = std::get<6>(groups[0]);
|
||||
table_schema.nlist_ = std::get<7>(groups[0]);
|
||||
table_schema.index_params_ = std::get<7>(groups[0]);
|
||||
table_schema.metric_type_ = std::get<8>(groups[0]);
|
||||
table_schema.owner_table_ = std::get<9>(groups[0]);
|
||||
table_schema.partition_tag_ = std::get<10>(groups[0]);
|
||||
|
@ -272,7 +273,7 @@ SqliteMetaImpl::AllTables(std::vector<TableSchema>& table_schema_array) {
|
|||
auto selected = ConnectorPtr->select(
|
||||
columns(&TableSchema::id_, &TableSchema::table_id_, &TableSchema::dimension_, &TableSchema::created_on_,
|
||||
&TableSchema::flag_, &TableSchema::index_file_size_, &TableSchema::engine_type_,
|
||||
&TableSchema::nlist_, &TableSchema::metric_type_, &TableSchema::owner_table_,
|
||||
&TableSchema::index_params_, &TableSchema::metric_type_, &TableSchema::owner_table_,
|
||||
&TableSchema::partition_tag_, &TableSchema::version_, &TableSchema::flush_lsn_),
|
||||
where(c(&TableSchema::state_) != (int)TableSchema::TO_DELETE and c(&TableSchema::owner_table_) == ""));
|
||||
for (auto& table : selected) {
|
||||
|
@ -284,7 +285,7 @@ SqliteMetaImpl::AllTables(std::vector<TableSchema>& table_schema_array) {
|
|||
schema.flag_ = std::get<4>(table);
|
||||
schema.index_file_size_ = std::get<5>(table);
|
||||
schema.engine_type_ = std::get<6>(table);
|
||||
schema.nlist_ = std::get<7>(table);
|
||||
schema.index_params_ = std::get<7>(table);
|
||||
schema.metric_type_ = std::get<8>(table);
|
||||
schema.owner_table_ = std::get<9>(table);
|
||||
schema.partition_tag_ = std::get<10>(table);
|
||||
|
@ -373,17 +374,8 @@ SqliteMetaImpl::CreateTableFile(TableFileSchema& file_schema) {
|
|||
file_schema.created_on_ = utils::GetMicroSecTimeStamp();
|
||||
file_schema.updated_time_ = file_schema.created_on_;
|
||||
file_schema.index_file_size_ = table_schema.index_file_size_;
|
||||
|
||||
if (file_schema.file_type_ == TableFileSchema::FILE_TYPE::NEW ||
|
||||
file_schema.file_type_ == TableFileSchema::FILE_TYPE::NEW_MERGE) {
|
||||
file_schema.engine_type_ = server::ValidationUtil::IsBinaryMetricType(table_schema.metric_type_)
|
||||
? (int32_t)EngineType::FAISS_BIN_IDMAP
|
||||
: (int32_t)EngineType::FAISS_IDMAP;
|
||||
} else {
|
||||
file_schema.engine_type_ = table_schema.engine_type_;
|
||||
}
|
||||
|
||||
file_schema.nlist_ = table_schema.nlist_;
|
||||
file_schema.index_params_ = table_schema.index_params_;
|
||||
file_schema.engine_type_ = table_schema.engine_type_;
|
||||
file_schema.metric_type_ = table_schema.metric_type_;
|
||||
|
||||
// multi-threads call sqlite update may get exception('bad logic', etc), so we add a lock here
|
||||
|
@ -436,7 +428,7 @@ SqliteMetaImpl::GetTableFiles(const std::string& table_id, const std::vector<siz
|
|||
file_schema.created_on_ = std::get<8>(file);
|
||||
file_schema.dimension_ = table_schema.dimension_;
|
||||
file_schema.index_file_size_ = table_schema.index_file_size_;
|
||||
file_schema.nlist_ = table_schema.nlist_;
|
||||
file_schema.index_params_ = table_schema.index_params_;
|
||||
file_schema.metric_type_ = table_schema.metric_type_;
|
||||
|
||||
utils::GetTableFilePath(options_, file_schema);
|
||||
|
@ -486,7 +478,7 @@ SqliteMetaImpl::GetTableFilesBySegmentId(const std::string& segment_id,
|
|||
file_schema.created_on_ = std::get<9>(file);
|
||||
file_schema.dimension_ = table_schema.dimension_;
|
||||
file_schema.index_file_size_ = table_schema.index_file_size_;
|
||||
file_schema.nlist_ = table_schema.nlist_;
|
||||
file_schema.index_params_ = table_schema.index_params_;
|
||||
file_schema.metric_type_ = table_schema.metric_type_;
|
||||
|
||||
utils::GetTableFilePath(options_, file_schema);
|
||||
|
@ -601,7 +593,7 @@ SqliteMetaImpl::GetTableFilesByFlushLSN(uint64_t flush_lsn, TableFilesSchema& ta
|
|||
}
|
||||
table_file.dimension_ = groups[table_file.table_id_].dimension_;
|
||||
table_file.index_file_size_ = groups[table_file.table_id_].index_file_size_;
|
||||
table_file.nlist_ = groups[table_file.table_id_].nlist_;
|
||||
table_file.index_params_ = groups[table_file.table_id_].index_params_;
|
||||
table_file.metric_type_ = groups[table_file.table_id_].metric_type_;
|
||||
table_files.push_back(table_file);
|
||||
}
|
||||
|
@ -721,7 +713,7 @@ SqliteMetaImpl::UpdateTableIndex(const std::string& table_id, const TableIndex&
|
|||
table_schema.partition_tag_ = std::get<7>(tables[0]);
|
||||
table_schema.version_ = std::get<8>(tables[0]);
|
||||
table_schema.engine_type_ = index.engine_type_;
|
||||
table_schema.nlist_ = index.nlist_;
|
||||
table_schema.index_params_ = index.extra_params_.dump();
|
||||
table_schema.metric_type_ = index.metric_type_;
|
||||
|
||||
ConnectorPtr->update(table_schema);
|
||||
|
@ -773,12 +765,12 @@ SqliteMetaImpl::DescribeTableIndex(const std::string& table_id, TableIndex& inde
|
|||
fiu_do_on("SqliteMetaImpl.DescribeTableIndex.throw_exception", throw std::exception());
|
||||
|
||||
auto groups = ConnectorPtr->select(
|
||||
columns(&TableSchema::engine_type_, &TableSchema::nlist_, &TableSchema::metric_type_),
|
||||
columns(&TableSchema::engine_type_, &TableSchema::index_params_, &TableSchema::metric_type_),
|
||||
where(c(&TableSchema::table_id_) == table_id and c(&TableSchema::state_) != (int)TableSchema::TO_DELETE));
|
||||
|
||||
if (groups.size() == 1) {
|
||||
index.engine_type_ = std::get<0>(groups[0]);
|
||||
index.nlist_ = std::get<1>(groups[0]);
|
||||
index.extra_params_ = milvus::json::parse(std::get<1>(groups[0]));
|
||||
index.metric_type_ = std::get<2>(groups[0]);
|
||||
} else {
|
||||
return Status(DB_NOT_FOUND, "Table " + table_id + " not found");
|
||||
|
@ -813,7 +805,7 @@ SqliteMetaImpl::DropTableIndex(const std::string& table_id) {
|
|||
|
||||
// set table index type to raw
|
||||
ConnectorPtr->update_all(
|
||||
set(c(&TableSchema::engine_type_) = DEFAULT_ENGINE_TYPE, c(&TableSchema::nlist_) = DEFAULT_NLIST),
|
||||
set(c(&TableSchema::engine_type_) = DEFAULT_ENGINE_TYPE, c(&TableSchema::index_params_) = "{}"),
|
||||
where(c(&TableSchema::table_id_) == table_id));
|
||||
|
||||
ENGINE_LOG_DEBUG << "Successfully drop table index, table id = " << table_id;
|
||||
|
@ -886,13 +878,14 @@ SqliteMetaImpl::ShowPartitions(const std::string& table_id, std::vector<meta::Ta
|
|||
server::MetricCollector metric;
|
||||
fiu_do_on("SqliteMetaImpl.ShowPartitions.throw_exception", throw std::exception());
|
||||
|
||||
auto partitions =
|
||||
ConnectorPtr->select(columns(&TableSchema::id_, &TableSchema::state_, &TableSchema::dimension_,
|
||||
&TableSchema::created_on_, &TableSchema::flag_, &TableSchema::index_file_size_,
|
||||
&TableSchema::engine_type_, &TableSchema::nlist_, &TableSchema::metric_type_,
|
||||
&TableSchema::partition_tag_, &TableSchema::version_, &TableSchema::table_id_),
|
||||
where(c(&TableSchema::owner_table_) == table_id and
|
||||
c(&TableSchema::state_) != (int)TableSchema::TO_DELETE));
|
||||
auto partitions = ConnectorPtr->select(
|
||||
columns(&TableSchema::id_, &TableSchema::state_, &TableSchema::dimension_, &TableSchema::created_on_,
|
||||
&TableSchema::flag_, &TableSchema::index_file_size_, &TableSchema::engine_type_,
|
||||
&TableSchema::index_params_, &TableSchema::metric_type_, &TableSchema::partition_tag_,
|
||||
&TableSchema::version_, &TableSchema::table_id_),
|
||||
where(c(&TableSchema::owner_table_) == table_id and
|
||||
c(&TableSchema::state_) != (int)TableSchema::TO_DELETE));
|
||||
|
||||
for (size_t i = 0; i < partitions.size(); i++) {
|
||||
meta::TableSchema partition_schema;
|
||||
partition_schema.id_ = std::get<0>(partitions[i]);
|
||||
|
@ -902,7 +895,7 @@ SqliteMetaImpl::ShowPartitions(const std::string& table_id, std::vector<meta::Ta
|
|||
partition_schema.flag_ = std::get<4>(partitions[i]);
|
||||
partition_schema.index_file_size_ = std::get<5>(partitions[i]);
|
||||
partition_schema.engine_type_ = std::get<6>(partitions[i]);
|
||||
partition_schema.nlist_ = std::get<7>(partitions[i]);
|
||||
partition_schema.index_params_ = std::get<7>(partitions[i]);
|
||||
partition_schema.metric_type_ = std::get<8>(partitions[i]);
|
||||
partition_schema.owner_table_ = table_id;
|
||||
partition_schema.partition_tag_ = std::get<9>(partitions[i]);
|
||||
|
@ -995,7 +988,7 @@ SqliteMetaImpl::FilesToSearch(const std::string& table_id, const std::vector<siz
|
|||
table_file.engine_type_ = std::get<8>(file);
|
||||
table_file.dimension_ = table_schema.dimension_;
|
||||
table_file.index_file_size_ = table_schema.index_file_size_;
|
||||
table_file.nlist_ = table_schema.nlist_;
|
||||
table_file.index_params_ = table_schema.index_params_;
|
||||
table_file.metric_type_ = table_schema.metric_type_;
|
||||
|
||||
auto status = utils::GetTableFilePath(options_, table_file);
|
||||
|
@ -1063,7 +1056,7 @@ SqliteMetaImpl::FilesToMerge(const std::string& table_id, TableFilesSchema& file
|
|||
table_file.created_on_ = std::get<8>(file);
|
||||
table_file.dimension_ = table_schema.dimension_;
|
||||
table_file.index_file_size_ = table_schema.index_file_size_;
|
||||
table_file.nlist_ = table_schema.nlist_;
|
||||
table_file.index_params_ = table_schema.index_params_;
|
||||
table_file.metric_type_ = table_schema.metric_type_;
|
||||
|
||||
auto status = utils::GetTableFilePath(options_, table_file);
|
||||
|
@ -1134,7 +1127,7 @@ SqliteMetaImpl::FilesToIndex(TableFilesSchema& files) {
|
|||
}
|
||||
table_file.dimension_ = groups[table_file.table_id_].dimension_;
|
||||
table_file.index_file_size_ = groups[table_file.table_id_].index_file_size_;
|
||||
table_file.nlist_ = groups[table_file.table_id_].nlist_;
|
||||
table_file.index_params_ = groups[table_file.table_id_].index_params_;
|
||||
table_file.metric_type_ = groups[table_file.table_id_].metric_type_;
|
||||
files.push_back(table_file);
|
||||
}
|
||||
|
@ -1192,7 +1185,7 @@ SqliteMetaImpl::FilesByType(const std::string& table_id, const std::vector<int>&
|
|||
|
||||
file_schema.dimension_ = table_schema.dimension_;
|
||||
file_schema.index_file_size_ = table_schema.index_file_size_;
|
||||
file_schema.nlist_ = table_schema.nlist_;
|
||||
file_schema.index_params_ = table_schema.index_params_;
|
||||
file_schema.metric_type_ = table_schema.metric_type_;
|
||||
|
||||
switch (file_schema.file_type_) {
|
||||
|
@ -1423,8 +1416,7 @@ SqliteMetaImpl::CleanUpFilesWithTTL(uint64_t seconds /*, CleanUpFilter* filter*/
|
|||
// If we are deleting a raw table file, it means it's okay to delete the entire segment directory.
|
||||
// Else, we can only delete the single file
|
||||
// TODO(zhiru): We determine whether a table file is raw by its engine type. This is a bit hacky
|
||||
if (table_file.engine_type_ == (int32_t)EngineType::FAISS_IDMAP ||
|
||||
table_file.engine_type_ == (int32_t)EngineType::FAISS_BIN_IDMAP) {
|
||||
if (utils::IsRawIndexType(table_file.engine_type_)) {
|
||||
utils::DeleteSegment(options_, table_file);
|
||||
std::string segment_dir;
|
||||
utils::GetParentPath(table_file.location_, segment_dir);
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -4,6 +4,14 @@ import "status.proto";
|
|||
|
||||
package milvus.grpc;
|
||||
|
||||
/**
|
||||
* @brief general usage
|
||||
*/
|
||||
// Generic string key/value entry. In this file it is the element type of the
// `repeated KeyValuePair extra_params` fields on the schema/request messages.
message KeyValuePair {
|
||||
// Parameter name.
string key = 1;
|
||||
// Parameter value, carried as a string.
string value = 2;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Table name
|
||||
*/
|
||||
|
@ -21,6 +29,7 @@ message TableNameList {
|
|||
|
||||
/**
|
||||
* @brief Table schema
|
||||
* metric_type: 1-L2, 2-IP
|
||||
*/
|
||||
message TableSchema {
|
||||
Status status = 1;
|
||||
|
@ -28,6 +37,7 @@ message TableSchema {
|
|||
int64 dimension = 3;
|
||||
int64 index_file_size = 4;
|
||||
int32 metric_type = 5;
|
||||
repeated KeyValuePair extra_params = 6;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -62,6 +72,7 @@ message InsertParam {
|
|||
repeated RowRecord row_record_array = 2;
|
||||
repeated int64 row_id_array = 3; //optional
|
||||
string partition_tag = 4;
|
||||
repeated KeyValuePair extra_params = 5;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -77,10 +88,10 @@ message VectorIds {
|
|||
*/
|
||||
// Search request: a table name plus query records, a topk and partition tags.
// NOTE(review): this hunk shows two field layouts back to back with the +/- markers lost.
// Presumably the four fields right after table_name (query_record_array=2, topk=3, nprobe=4,
// partition_tag_array=5) are the removed side and the following four are the added side —
// confirm against the real diff.
// NOTE(review): between the two layouts the tags of query_record_array, topk and
// partition_tag_array are renumbered and tag 5 is reused for extra_params with a different
// type. Renumbering/reusing tags changes the protobuf wire format, so peers built against
// the other layout will not decode this message correctly — confirm that is acceptable.
message SearchParam {
|
||||
string table_name = 1;
|
||||
repeated RowRecord query_record_array = 2;
|
||||
int64 topk = 3;
|
||||
int64 nprobe = 4;
|
||||
repeated string partition_tag_array = 5;
|
||||
repeated string partition_tag_array = 2;
|
||||
repeated RowRecord query_record_array = 3;
|
||||
int64 topk = 4;
|
||||
// Free-form key/value parameters; in the second layout no dedicated nprobe field remains.
repeated KeyValuePair extra_params = 5;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -96,10 +107,10 @@ message SearchInFilesParam {
|
|||
*/
|
||||
// Search-by-id request: a table name plus an id, a topk and partition tags.
// NOTE(review): as rendered, two field layouts are interleaved with the +/- markers lost.
// Presumably id=2, topk=3, nprobe=4, partition_tag_array=5 are the removed side and
// partition_tag_array=2, id=3, topk=4, extra_params=5 the added side — confirm.
// NOTE(review): tags 2-5 are reassigned to different fields/types between the two layouts,
// which changes the protobuf wire format for peers built against the other layout.
message SearchByIDParam {
|
||||
string table_name = 1;
|
||||
int64 id = 2;
|
||||
int64 topk = 3;
|
||||
int64 nprobe = 4;
|
||||
repeated string partition_tag_array = 5;
|
||||
repeated string partition_tag_array = 2;
|
||||
int64 id = 3;
|
||||
int64 topk = 4;
|
||||
// Free-form key/value parameters; in the second layout no dedicated nprobe field remains.
repeated KeyValuePair extra_params = 5;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -143,23 +154,15 @@ message Command {
|
|||
string cmd = 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Index
|
||||
* @index_type: 0-invalid, 1-idmap, 2-ivflat, 3-ivfsq8, 4-nsgmix
|
||||
* @metric_type: 1-L2, 2-IP
|
||||
*/
|
||||
message Index {
|
||||
int32 index_type = 1;
|
||||
int32 nlist = 2;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Index params
|
||||
* @index_type: 0-invalid, 1-idmap, 2-ivflat, 3-ivfsq8, 4-nsgmix
|
||||
*/
|
||||
// Index parameters for one table: a status, the table name and the index description.
// NOTE(review): two declarations for tag 3 appear here because the +/- markers were lost.
// Presumably `Index index = 3` is the removed line and `int32 index_type = 3` plus
// `extra_params = 4` are the added lines — confirm.
// NOTE(review): tag 3 changes from a message type (Index) to int32, which is not a
// wire-compatible change for peers built against the other definition.
message IndexParam {
|
||||
Status status = 1;
|
||||
string table_name = 2;
|
||||
Index index = 3;
|
||||
// Index type code; see the @index_type list in the comment above this message.
int32 index_type = 3;
|
||||
// Free-form key/value index parameters.
repeated KeyValuePair extra_params = 4;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -22,7 +22,6 @@ endif ()
|
|||
|
||||
set(external_srcs
|
||||
knowhere/adapter/SptagAdapter.cpp
|
||||
knowhere/adapter/VectorAdapter.cpp
|
||||
knowhere/common/Exception.cpp
|
||||
knowhere/common/Timer.cpp
|
||||
)
|
||||
|
@ -117,4 +116,4 @@ set(INDEX_INCLUDE_DIRS
|
|||
${LAPACK_INCLUDE_DIR}
|
||||
)
|
||||
|
||||
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)
|
||||
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)
|
||||
|
|
|
@ -43,7 +43,8 @@ std::vector<SPTAG::QueryResult>
|
|||
ConvertToQueryResult(const DatasetPtr& dataset, const Config& config) {
|
||||
GETTENSOR(dataset);
|
||||
|
||||
std::vector<SPTAG::QueryResult> query_results(rows, SPTAG::QueryResult(nullptr, config->k, true));
|
||||
std::vector<SPTAG::QueryResult> query_results(rows,
|
||||
SPTAG::QueryResult(nullptr, config[meta::TOPK].get<int64_t>(), true));
|
||||
for (auto i = 0; i < rows; ++i) {
|
||||
query_results[i].SetTarget(&p_data[i * dim]);
|
||||
}
|
||||
|
|
|
@ -1,24 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
namespace meta {
|
||||
const char* DIM = "dim";
|
||||
const char* TENSOR = "tensor";
|
||||
const char* ROWS = "rows";
|
||||
const char* IDS = "ids";
|
||||
const char* DISTANCE = "distance";
|
||||
}; // namespace meta
|
||||
|
||||
} // namespace knowhere
|
|
@ -13,17 +13,10 @@
|
|||
|
||||
#include <string>
|
||||
#include "knowhere/common/Dataset.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
namespace meta {
|
||||
extern const char* DIM;
|
||||
extern const char* TENSOR;
|
||||
extern const char* ROWS;
|
||||
extern const char* IDS;
|
||||
extern const char* DISTANCE;
|
||||
}; // namespace meta
|
||||
|
||||
#define GETTENSOR(dataset) \
|
||||
auto dim = dataset->Get<int64_t>(meta::DIM); \
|
||||
auto rows = dataset->Get<int64_t>(meta::ROWS); \
|
||||
|
|
|
@ -11,64 +11,10 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include "Log.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "src/utils/Json.h"
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
enum class METRICTYPE {
|
||||
INVALID = 0,
|
||||
L2 = 1,
|
||||
IP = 2,
|
||||
HAMMING = 20,
|
||||
JACCARD = 21,
|
||||
TANIMOTO = 22,
|
||||
};
|
||||
|
||||
// General Config
|
||||
constexpr int64_t INVALID_VALUE = -1;
|
||||
constexpr int64_t DEFAULT_K = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_DIM = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_GPUID = INVALID_VALUE;
|
||||
constexpr METRICTYPE DEFAULT_TYPE = METRICTYPE::INVALID;
|
||||
|
||||
struct Cfg {
|
||||
METRICTYPE metric_type = DEFAULT_TYPE;
|
||||
int64_t k = DEFAULT_K;
|
||||
int64_t gpu_id = DEFAULT_GPUID;
|
||||
int64_t d = DEFAULT_DIM;
|
||||
|
||||
Cfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, METRICTYPE type)
|
||||
: metric_type(type), k(k), gpu_id(gpu_id), d(dim) {
|
||||
}
|
||||
|
||||
Cfg() = default;
|
||||
|
||||
virtual bool
|
||||
CheckValid() {
|
||||
if (metric_type == METRICTYPE::IP || metric_type == METRICTYPE::L2) {
|
||||
return true;
|
||||
}
|
||||
std::stringstream ss;
|
||||
ss << "MetricType: " << int(metric_type) << " not support!";
|
||||
KNOWHERE_THROW_MSG(ss.str());
|
||||
return false;
|
||||
}
|
||||
|
||||
void
|
||||
Dump() {
|
||||
KNOWHERE_LOG_DEBUG << DumpImpl().str();
|
||||
}
|
||||
|
||||
virtual std::stringstream
|
||||
DumpImpl() {
|
||||
std::stringstream ss;
|
||||
ss << "dim: " << d << ", metric: " << int(metric_type) << ", gpuid: " << gpu_id << ", k: " << k;
|
||||
return ss;
|
||||
}
|
||||
};
|
||||
using Config = std::shared_ptr<Cfg>;
|
||||
using Config = milvus::json;
|
||||
|
||||
} // namespace knowhere
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "IndexModel.h"
|
||||
#include "IndexType.h"
|
||||
#include "knowhere/common/BinarySet.h"
|
||||
#include "knowhere/common/Config.h"
|
||||
#include "knowhere/common/Dataset.h"
|
||||
#include "knowhere/index/preprocessor/Preprocessor.h"
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include <utility>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/index/vector_index/FaissBaseIndex.h"
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/helpers/FaissIO.h"
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
#include <faiss/MetaIndexes.h>
|
||||
#include <faiss/index_factory.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
|
||||
|
@ -43,13 +45,13 @@ BinaryIDMAP::Search(const DatasetPtr& dataset, const Config& config) {
|
|||
}
|
||||
GETBINARYTENSOR(dataset)
|
||||
|
||||
auto elems = rows * config->k;
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
search_impl(rows, (uint8_t*)p_data, config->k, p_dist, p_id, Config());
|
||||
search_impl(rows, (uint8_t*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, Config());
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
if (index_->metric_type == faiss::METRIC_Hamming) {
|
||||
|
@ -90,14 +92,9 @@ BinaryIDMAP::Add(const DatasetPtr& dataset, const Config& config) {
|
|||
|
||||
// Creates the underlying faiss binary index from `config` and stores it in index_.
// The index is built with faiss::index_binary_factory using the fixed type string "BFlat".
// NOTE(review): this hunk has lost its +/- markers. The build_cfg / `config->` lines look
// like the removed shared_ptr<Cfg> version and the `config[...]` lines like the added JSON
// version (Config is re-aliased to milvus::json elsewhere in this commit) — confirm.
void
|
||||
BinaryIDMAP::Train(const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<BinIDMAPCfg>(config);
|
||||
if (build_cfg == nullptr) {
|
||||
KNOWHERE_THROW_MSG("not support this kind of config");
|
||||
}
|
||||
config->CheckValid();
|
||||
|
||||
const char* type = "BFlat";
|
||||
auto index = faiss::index_binary_factory(config->d, type, GetMetricType(config->metric_type));
|
||||
// JSON variant: dimension is read from config[meta::DIM] and the metric name from
// config[Metric::TYPE]. NOTE(review): no validation of these keys is visible here once
// CheckValid() is gone; presumably a missing/mistyped key fails inside the JSON
// accessors — TODO confirm where the config is validated.
auto index = faiss::index_binary_factory(config[meta::DIM].get<int64_t>(), type,
|
||||
GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
// index_ takes over the raw pointer returned by the factory.
index_.reset(index);
|
||||
}
|
||||
|
||||
|
@ -181,26 +178,18 @@ BinaryIDMAP::SearchById(const DatasetPtr& dataset, const Config& config) {
|
|||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
|
||||
// auto search_cfg = std::dynamic_pointer_cast<BinIDMAPCfg>(config);
|
||||
// if (search_cfg == nullptr) {
|
||||
// KNOWHERE_THROW_MSG("not support this kind of config");
|
||||
// }
|
||||
|
||||
// GETBINARYTENSOR(dataset)
|
||||
auto dim = dataset->Get<int64_t>(meta::DIM);
|
||||
auto rows = dataset->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
auto elems = rows * config->k;
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
auto* pdistances = (int32_t*)p_dist;
|
||||
// index_->searchById(rows, (uint8_t*)p_data, config->k, pdistances, p_id, bitset_);
|
||||
// auto blacklist = dataset->Get<faiss::ConcurrentBitsetPtr>("bitset");
|
||||
index_->search_by_id(rows, p_data, config->k, pdistances, p_id, bitset_);
|
||||
index_->search_by_id(rows, p_data, config[meta::TOPK].get<int64_t>(), pdistances, p_id, bitset_);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
if (index_->metric_type == faiss::METRIC_Hamming) {
|
||||
|
|
|
@ -15,9 +15,11 @@
|
|||
#include <faiss/IndexBinaryIVF.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
|
@ -45,22 +47,17 @@ BinaryIVF::Search(const DatasetPtr& dataset, const Config& config) {
|
|||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
// auto search_cfg = std::dynamic_pointer_cast<IVFBinCfg>(config);
|
||||
// if (search_cfg == nullptr) {
|
||||
// KNOWHERE_THROW_MSG("not support this kind of config");
|
||||
// }
|
||||
|
||||
GETBINARYTENSOR(dataset)
|
||||
|
||||
try {
|
||||
auto elems = rows * config->k;
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
search_impl(rows, (uint8_t*)p_data, config->k, p_dist, p_id, config);
|
||||
search_impl(rows, (uint8_t*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, config);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
if (index_->metric_type == faiss::METRIC_Hamming) {
|
||||
|
@ -108,29 +105,20 @@ BinaryIVF::search_impl(int64_t n, const uint8_t* data, int64_t k, float* distanc
|
|||
// Builds the faiss IVF search parameters for a query: allocates a fresh
// faiss::IVFSearchParameters, sets its nprobe from `config`, and returns it.
// NOTE(review): the hunk interleaves two versions with the +/- markers lost; the
// search_cfg lines look like the removed Cfg-pointer version and the
// config[IndexParams::nprobe] line like the added JSON version — confirm.
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
BinaryIVF::GenParams(const Config& config) {
|
||||
auto params = std::make_shared<faiss::IVFSearchParameters>();
|
||||
|
||||
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
|
||||
params->nprobe = search_cfg->nprobe;
|
||||
// params->max_codes = config.get_with_default("max_codes", size_t(0));
|
||||
|
||||
// JSON variant: nprobe is read straight from the config with no visible presence check.
params->nprobe = config[IndexParams::nprobe];
|
||||
// params->max_codes = config["max_code"];
|
||||
return params;
|
||||
}
|
||||
|
||||
IndexModelPtr
|
||||
BinaryIVF::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
auto build_cfg = std::dynamic_pointer_cast<IVFBinCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
|
||||
GETBINARYTENSOR(dataset)
|
||||
auto p_ids = dataset->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
faiss::IndexBinary* coarse_quantizer = new faiss::IndexBinaryFlat(dim, GetMetricType(build_cfg->metric_type));
|
||||
auto index = std::make_shared<faiss::IndexBinaryIVF>(coarse_quantizer, dim, build_cfg->nlist,
|
||||
GetMetricType(build_cfg->metric_type));
|
||||
faiss::IndexBinary* coarse_quantizer =
|
||||
new faiss::IndexBinaryFlat(dim, GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
auto index = std::make_shared<faiss::IndexBinaryIVF>(coarse_quantizer, dim, config[IndexParams::nlist],
|
||||
GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
index->train(rows, (uint8_t*)p_data);
|
||||
index->add_with_ids(rows, (uint8_t*)p_data, p_ids);
|
||||
index_ = index;
|
||||
|
@ -190,17 +178,11 @@ BinaryIVF::SearchById(const DatasetPtr& dataset, const Config& config) {
|
|||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
// auto search_cfg = std::dynamic_pointer_cast<IVFBinCfg>(config);
|
||||
// if (search_cfg == nullptr) {
|
||||
// KNOWHERE_THROW_MSG("not support this kind of config");
|
||||
// }
|
||||
|
||||
// GETBINARYTENSOR(dataset)
|
||||
auto rows = dataset->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
try {
|
||||
auto elems = rows * config->k;
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
|
@ -208,9 +190,7 @@ BinaryIVF::SearchById(const DatasetPtr& dataset, const Config& config) {
|
|||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
int32_t* pdistances = (int32_t*)p_dist;
|
||||
// auto blacklist = dataset->Get<faiss::ConcurrentBitsetPtr>("bitset");
|
||||
// index_->searchById(rows, (uint8_t*)p_data, config->k, pdistances, p_id, blacklist);
|
||||
index_->search_by_id(rows, p_data, config->k, pdistances, p_id, bitset_);
|
||||
index_->search_by_id(rows, p_data, config[meta::TOPK].get<int64_t>(), pdistances, p_id, bitset_);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
if (index_->metric_type == faiss::METRIC_Hamming) {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include <faiss/MetaIndexes.h>
|
||||
#include <faiss/index_io.h>
|
||||
#include <fiu-local.h>
|
||||
#include <string>
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
|
||||
|
@ -127,7 +128,7 @@ GPUIDMAP::GenGraph(const float* data, const int64_t& k, Graph& graph, const Conf
|
|||
int64_t K = k + 1;
|
||||
auto ntotal = Count();
|
||||
|
||||
size_t dim = config->d;
|
||||
size_t dim = config[meta::DIM];
|
||||
auto batch_size = 1000;
|
||||
auto tail_batch_size = ntotal % batch_size;
|
||||
auto batch_search_count = ntotal / batch_size;
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include <faiss/gpu/GpuCloner.h>
|
||||
#include <faiss/gpu/GpuIndexIVF.h>
|
||||
|
@ -28,21 +29,16 @@ namespace knowhere {
|
|||
|
||||
IndexModelPtr
|
||||
GPUIVF::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
gpu_id_ = build_cfg->gpu_id;
|
||||
|
||||
GETTENSOR(dataset)
|
||||
gpu_id_ = config[knowhere::meta::DEVICEID];
|
||||
|
||||
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||
if (temp_resource != nullptr) {
|
||||
ResScope rs(temp_resource, gpu_id_, true);
|
||||
faiss::gpu::GpuIndexIVFFlatConfig idx_config;
|
||||
idx_config.device = gpu_id_;
|
||||
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, build_cfg->nlist,
|
||||
GetMetricType(build_cfg->metric_type), idx_config);
|
||||
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, config[IndexParams::nlist],
|
||||
GetMetricType(config[Metric::TYPE].get<std::string>()), idx_config);
|
||||
device_index.train(rows, (float*)p_data);
|
||||
|
||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
||||
|
@ -121,15 +117,13 @@ GPUIVF::LoadImpl(const BinarySet& index_binary) {
|
|||
}
|
||||
|
||||
void
|
||||
GPUIVF::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) {
|
||||
GPUIVF::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
|
||||
std::lock_guard<std::mutex> lk(mutex_);
|
||||
|
||||
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
|
||||
fiu_do_on("GPUIVF.search_impl.invald_index", device_index = nullptr);
|
||||
if (device_index) {
|
||||
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(cfg);
|
||||
device_index->nprobe = search_cfg->nprobe;
|
||||
// assert(device_index->getNumProbes() == search_cfg->nprobe);
|
||||
device_index->nprobe = config[IndexParams::nprobe];
|
||||
ResScope rs(res_, gpu_id_);
|
||||
device_index->search(n, (float*)data, k, distances, labels);
|
||||
} else {
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include <faiss/index_factory.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
|
@ -25,20 +26,16 @@ namespace knowhere {
|
|||
|
||||
IndexModelPtr
|
||||
GPUIVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
gpu_id_ = build_cfg->gpu_id;
|
||||
|
||||
GETTENSOR(dataset)
|
||||
gpu_id_ = config[knowhere::meta::DEVICEID];
|
||||
|
||||
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||
if (temp_resource != nullptr) {
|
||||
ResScope rs(temp_resource, gpu_id_, true);
|
||||
auto device_index = new faiss::gpu::GpuIndexIVFPQ(temp_resource->faiss_res.get(), dim, build_cfg->nlist,
|
||||
build_cfg->m, build_cfg->nbits,
|
||||
GetMetricType(build_cfg->metric_type)); // IP not support
|
||||
auto device_index = new faiss::gpu::GpuIndexIVFPQ(
|
||||
temp_resource->faiss_res.get(), dim, config[IndexParams::nlist].get<int64_t>(), config[IndexParams::m],
|
||||
config[IndexParams::nbits],
|
||||
GetMetricType(config[Metric::TYPE].get<std::string>())); // IP not support
|
||||
device_index->train(rows, (float*)p_data);
|
||||
std::shared_ptr<faiss::Index> host_index = nullptr;
|
||||
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
|
||||
|
@ -51,11 +48,10 @@ GPUIVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
|
|||
// Builds the search parameters for an IVFPQ query: allocates a
// faiss::IVFPQSearchParameters, sets only nprobe from `config`, and returns it through the
// IVFSearchParameters pointer type declared in the signature.
// NOTE(review): the hunk interleaves two versions with the +/- markers lost; the
// search_cfg / `conf->` lines look like the removed version and the `config[...]` lines
// like the added JSON version — confirm.
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
GPUIVFPQ::GenParams(const Config& config) {
|
||||
auto params = std::make_shared<faiss::IVFPQSearchParameters>();
|
||||
auto search_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
|
||||
params->nprobe = search_cfg->nprobe;
|
||||
// params->scan_table_threshold = conf->scan_table_threhold;
|
||||
// params->polysemous_ht = conf->polysemous_ht;
|
||||
// params->max_codes = conf->max_codes;
|
||||
// JSON variant: nprobe is read straight from the config with no visible presence check.
params->nprobe = config[IndexParams::nprobe];
|
||||
// The remaining PQ-specific settings are left commented out, so they are not set here.
// params->scan_table_threshold = config["scan_table_threhold"]
|
||||
// params->polysemous_ht = config["polysemous_ht"]
|
||||
// params->max_codes = config["max_codes"]
|
||||
|
||||
return params;
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include <faiss/index_factory.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
|
@ -23,18 +24,14 @@ namespace knowhere {
|
|||
|
||||
IndexModelPtr
|
||||
GPUIVFSQ::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
gpu_id_ = build_cfg->gpu_id;
|
||||
|
||||
GETTENSOR(dataset)
|
||||
gpu_id_ = config[knowhere::meta::DEVICEID];
|
||||
|
||||
std::stringstream index_type;
|
||||
index_type << "IVF" << build_cfg->nlist << ","
|
||||
<< "SQ" << build_cfg->nbits;
|
||||
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type));
|
||||
index_type << "IVF" << config[IndexParams::nlist] << ","
|
||||
<< "SQ" << config[IndexParams::nbits];
|
||||
auto build_index =
|
||||
faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
|
||||
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||
if (temp_resource != nullptr) {
|
||||
|
|
|
@ -79,20 +79,15 @@ IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) {
|
|||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
auto search_cfg = std::dynamic_pointer_cast<HNSWCfg>(config);
|
||||
if (search_cfg == nullptr) {
|
||||
KNOWHERE_THROW_MSG("search conf is null");
|
||||
}
|
||||
index_->setEf(search_cfg->ef);
|
||||
|
||||
GETTENSOR(dataset)
|
||||
|
||||
size_t id_size = sizeof(int64_t) * config->k;
|
||||
size_t dist_size = sizeof(float) * config->k;
|
||||
size_t id_size = sizeof(int64_t) * config[meta::TOPK].get<int64_t>();
|
||||
size_t dist_size = sizeof(float) * config[meta::TOPK].get<int64_t>();
|
||||
auto p_id = (int64_t*)malloc(id_size * rows);
|
||||
auto p_dist = (float*)malloc(dist_size * rows);
|
||||
|
||||
index_->setEf(config[IndexParams::ef]);
|
||||
|
||||
using P = std::pair<float, int64_t>;
|
||||
auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; };
|
||||
#pragma omp parallel for
|
||||
|
@ -103,13 +98,13 @@ IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) {
|
|||
// if (normalize) {
|
||||
// std::vector<float> norm_vector(Dimension());
|
||||
// normalize_vector((float*)(single_query), norm_vector.data(), Dimension());
|
||||
// ret = index_->searchKnn((float*)(norm_vector.data()), config->k, compare);
|
||||
// ret = index_->searchKnn((float*)(norm_vector.data()), config[meta::TOPK].get<int64_t>(), compare);
|
||||
// } else {
|
||||
// ret = index_->searchKnn((float*)single_query, config->k, compare);
|
||||
// ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get<int64_t>(), compare);
|
||||
// }
|
||||
ret = index_->searchKnn((float*)single_query, config->k, compare);
|
||||
ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get<int64_t>(), compare);
|
||||
|
||||
while (ret.size() < config->k) {
|
||||
while (ret.size() < config[meta::TOPK]) {
|
||||
ret.push_back(std::make_pair(-1, -1));
|
||||
}
|
||||
std::vector<float> dist;
|
||||
|
@ -125,8 +120,8 @@ IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) {
|
|||
std::transform(ret.begin(), ret.end(), std::back_inserter(ids),
|
||||
[](const std::pair<float, int64_t>& e) { return e.second; });
|
||||
|
||||
memcpy(p_dist + i * config->k, dist.data(), dist_size);
|
||||
memcpy(p_id + i * config->k, ids.data(), id_size);
|
||||
memcpy(p_dist + i * config[meta::TOPK].get<int64_t>(), dist.data(), dist_size);
|
||||
memcpy(p_id + i * config[meta::TOPK].get<int64_t>(), ids.data(), id_size);
|
||||
}
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
|
@ -137,21 +132,17 @@ IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) {
|
|||
|
||||
// Creates the hnswlib index for `dataset` and stores it in index_; always returns nullptr.
// The distance space is chosen from the configured metric (L2 or IP) and the index is
// constructed with `rows` as its element-count argument plus the M / ef-construction values.
// NOTE(review): the hunk interleaves two versions with the +/- markers lost; the
// build_cfg / `config->` lines look like the removed Cfg-pointer version and the
// `config[...]` lines like the added JSON version — confirm.
IndexModelPtr
|
||||
IndexHNSW::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<HNSWCfg>(config);
|
||||
if (build_cfg == nullptr) {
|
||||
KNOWHERE_THROW_MSG("build conf is null");
|
||||
}
|
||||
|
||||
// GETTENSOR introduces the locals (dim, rows, ...) read from the dataset.
GETTENSOR(dataset)
|
||||
|
||||
// NOTE(review): `space` has no initializer and is assigned only in the L2 and IP branches
// below; for any other metric value it is passed to the index constructor uninitialized.
// NOTE(review): the spaces are allocated with `new` and no matching release is visible in
// this function — confirm who owns them.
hnswlib::SpaceInterface<float>* space;
|
||||
if (config->metric_type == METRICTYPE::L2) {
|
||||
if (config[Metric::TYPE] == Metric::L2) {
|
||||
space = new hnswlib::L2Space(dim);
|
||||
} else if (config->metric_type == METRICTYPE::IP) {
|
||||
} else if (config[Metric::TYPE] == Metric::IP) {
|
||||
space = new hnswlib::InnerProductSpace(dim);
|
||||
// Flag set only for the inner-product metric.
normalize = true;
|
||||
}
|
||||
index_ = std::make_shared<hnswlib::HierarchicalNSW<float>>(space, rows, build_cfg->M, build_cfg->ef);
|
||||
index_ = std::make_shared<hnswlib::HierarchicalNSW<float>>(space, rows, config[IndexParams::M].get<int64_t>(),
|
||||
config[IndexParams::efConstruction].get<int64_t>());
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
#endif
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
|
@ -61,13 +62,13 @@ IDMAP::Search(const DatasetPtr& dataset, const Config& config) {
|
|||
}
|
||||
GETTENSOR(dataset)
|
||||
|
||||
auto elems = rows * config->k;
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
search_impl(rows, (float*)p_data, config->k, p_dist, p_id, Config());
|
||||
search_impl(rows, (float*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, Config());
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
|
@ -144,10 +145,9 @@ IDMAP::GetRawIds() {
|
|||
|
||||
// Creates the underlying faiss index from `config` and stores it in index_.
// The index is built with faiss::index_factory using the fixed type string "IDMap,Flat".
// NOTE(review): this hunk has lost its +/- markers. The `config->` lines look like the
// removed shared_ptr<Cfg> version and the `config[...]` lines like the added JSON
// version — confirm.
void
|
||||
IDMAP::Train(const Config& config) {
|
||||
config->CheckValid();
|
||||
|
||||
const char* type = "IDMap,Flat";
|
||||
auto index = faiss::index_factory(config->d, type, GetMetricType(config->metric_type));
|
||||
// JSON variant: dimension is read from config[meta::DIM] and the metric name from
// config[Metric::TYPE]. NOTE(review): no validation of these keys is visible here once
// CheckValid() is gone — TODO confirm where the config is validated.
auto index = faiss::index_factory(config[meta::DIM].get<int64_t>(), type,
|
||||
GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
// index_ takes over the raw pointer returned by the factory.
index_.reset(index);
|
||||
}
|
||||
|
||||
|
@ -214,7 +214,7 @@ IDMAP::SearchById(const DatasetPtr& dataset, const Config& config) {
|
|||
auto rows = dataset->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
auto elems = rows * config->k;
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
|
@ -222,8 +222,8 @@ IDMAP::SearchById(const DatasetPtr& dataset, const Config& config) {
|
|||
|
||||
// todo: enable search by id (zhiru)
|
||||
// auto blacklist = dataset->Get<faiss::ConcurrentBitsetPtr>("bitset");
|
||||
// index_->searchById(rows, (float*)p_data, config->k, p_dist, p_id, blacklist);
|
||||
index_->search_by_id(rows, p_data, config->k, p_dist, p_id, bitset_);
|
||||
// index_->searchById(rows, (float*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, blacklist);
|
||||
index_->search_by_id(rows, p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, bitset_);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include <fiu-local.h>
|
||||
#include <chrono>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
|
@ -43,16 +44,11 @@ using stdclock = std::chrono::high_resolution_clock;
|
|||
|
||||
IndexModelPtr
|
||||
IVF::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
|
||||
GETTENSOR(dataset)
|
||||
|
||||
faiss::Index* coarse_quantizer = new faiss::IndexFlatL2(dim);
|
||||
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, build_cfg->nlist,
|
||||
GetMetricType(build_cfg->metric_type));
|
||||
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(),
|
||||
GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
index->train(rows, (float*)p_data);
|
||||
|
||||
// TODO(linxj): override here. train return model or not.
|
||||
|
@ -106,24 +102,19 @@ IVF::Search(const DatasetPtr& dataset, const Config& config) {
|
|||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
|
||||
if (search_cfg == nullptr) {
|
||||
KNOWHERE_THROW_MSG("not support this kind of config");
|
||||
}
|
||||
|
||||
GETTENSOR(dataset)
|
||||
|
||||
try {
|
||||
fiu_do_on("IVF.Search.throw_std_exception", throw std::exception());
|
||||
fiu_do_on("IVF.Search.throw_faiss_exception", throw faiss::FaissException(""));
|
||||
auto elems = rows * search_cfg->k;
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
search_impl(rows, (float*)p_data, search_cfg->k, p_dist, p_id, config);
|
||||
search_impl(rows, (float*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, config);
|
||||
|
||||
// std::stringstream ss_res_id, ss_res_dist;
|
||||
// for (int i = 0; i < 10; ++i) {
|
||||
|
@ -163,9 +154,8 @@ std::shared_ptr<faiss::IVFSearchParameters>
|
|||
IVF::GenParams(const Config& config) {
|
||||
auto params = std::make_shared<faiss::IVFSearchParameters>();
|
||||
|
||||
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
|
||||
params->nprobe = search_cfg->nprobe;
|
||||
// params->max_codes = config.get_with_default("max_codes", size_t(0));
|
||||
params->nprobe = config[IndexParams::nprobe];
|
||||
// params->max_codes = config["max_codes"];
|
||||
|
||||
return params;
|
||||
}
|
||||
|
@ -185,7 +175,7 @@ IVF::GenGraph(const float* data, const int64_t& k, Graph& graph, const Config& c
|
|||
int64_t K = k + 1;
|
||||
auto ntotal = Count();
|
||||
|
||||
size_t dim = config->d;
|
||||
size_t dim = config[meta::DIM];
|
||||
auto batch_size = 1000;
|
||||
auto tail_batch_size = ntotal % batch_size;
|
||||
auto batch_search_count = ntotal / batch_size;
|
||||
|
@ -279,12 +269,6 @@ IVF::GetVectorById(const DatasetPtr& dataset, const Config& config) {
|
|||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
|
||||
if (search_cfg == nullptr) {
|
||||
KNOWHERE_THROW_MSG("not support this kind of config");
|
||||
}
|
||||
|
||||
// auto rows = dataset->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
|
||||
auto elems = dataset->Get<int64_t>(meta::DIM);
|
||||
|
||||
|
@ -311,16 +295,11 @@ IVF::SearchById(const DatasetPtr& dataset, const Config& config) {
|
|||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
|
||||
if (search_cfg == nullptr) {
|
||||
KNOWHERE_THROW_MSG("not support this kind of config");
|
||||
}
|
||||
|
||||
auto rows = dataset->Get<int64_t>(meta::ROWS);
|
||||
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
try {
|
||||
auto elems = rows * search_cfg->k;
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
|
@ -330,7 +309,7 @@ IVF::SearchById(const DatasetPtr& dataset, const Config& config) {
|
|||
// todo: enable search by id (zhiru)
|
||||
// auto blacklist = dataset->Get<faiss::ConcurrentBitsetPtr>("bitset");
|
||||
auto index_ivf = std::static_pointer_cast<faiss::IndexIVF>(index_);
|
||||
index_ivf->search_by_id(rows, p_data, search_cfg->k, p_dist, p_id, bitset_);
|
||||
index_ivf->search_by_id(rows, p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, bitset_);
|
||||
|
||||
// std::stringstream ss_res_id, ss_res_dist;
|
||||
// for (int i = 0; i < 10; ++i) {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#endif
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
|
@ -30,16 +31,12 @@ namespace knowhere {
|
|||
|
||||
IndexModelPtr
|
||||
IVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
|
||||
GETTENSOR(dataset)
|
||||
|
||||
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, GetMetricType(build_cfg->metric_type));
|
||||
auto index =
|
||||
std::make_shared<faiss::IndexIVFPQ>(coarse_quantizer, dim, build_cfg->nlist, build_cfg->m, build_cfg->nbits);
|
||||
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
auto index = std::make_shared<faiss::IndexIVFPQ>(coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(),
|
||||
config[IndexParams::m].get<int64_t>(),
|
||||
config[IndexParams::nbits].get<int64_t>());
|
||||
index->train(rows, (float*)p_data);
|
||||
|
||||
return std::make_shared<IVFIndexModel>(index);
|
||||
|
@ -48,11 +45,10 @@ IVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
|
|||
std::shared_ptr<faiss::IVFSearchParameters>
|
||||
IVFPQ::GenParams(const Config& config) {
|
||||
auto params = std::make_shared<faiss::IVFPQSearchParameters>();
|
||||
auto search_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
|
||||
params->nprobe = search_cfg->nprobe;
|
||||
// params->scan_table_threshold = conf->scan_table_threhold;
|
||||
// params->polysemous_ht = conf->polysemous_ht;
|
||||
// params->max_codes = conf->max_codes;
|
||||
params->nprobe = config[IndexParams::nprobe];
|
||||
// params->scan_table_threshold = config["scan_table_threhold"]
|
||||
// params->polysemous_ht = config["polysemous_ht"]
|
||||
// params->max_codes = config["max_codes"]
|
||||
|
||||
return params;
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include <faiss/index_factory.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
|
@ -30,17 +31,13 @@ namespace knowhere {
|
|||
|
||||
IndexModelPtr
|
||||
IVFSQ::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
|
||||
GETTENSOR(dataset)
|
||||
|
||||
std::stringstream index_type;
|
||||
index_type << "IVF" << build_cfg->nlist << ","
|
||||
<< "SQ" << build_cfg->nbits;
|
||||
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type));
|
||||
index_type << "IVF" << config[IndexParams::nlist] << ","
|
||||
<< "SQ" << config[IndexParams::nbits];
|
||||
auto build_index =
|
||||
faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
build_index->train(rows, (float*)p_data);
|
||||
|
||||
std::shared_ptr<faiss::Index> ret_index;
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <faiss/gpu/GpuIndexIVF.h>
|
||||
#include <faiss/index_factory.h>
|
||||
#include <fiu-local.h>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
namespace knowhere {
|
||||
|
@ -30,19 +31,14 @@ namespace knowhere {
|
|||
IndexModelPtr
|
||||
IVFSQHybrid::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
// std::lock_guard<std::mutex> lk(g_mutex);
|
||||
|
||||
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
gpu_id_ = build_cfg->gpu_id;
|
||||
|
||||
GETTENSOR(dataset)
|
||||
gpu_id_ = config[knowhere::meta::DEVICEID];
|
||||
|
||||
std::stringstream index_type;
|
||||
index_type << "IVF" << build_cfg->nlist << ","
|
||||
index_type << "IVF" << config[IndexParams::nlist] << ","
|
||||
<< "SQ8Hybrid";
|
||||
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type));
|
||||
auto build_index =
|
||||
faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get<std::string>()));
|
||||
|
||||
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
|
||||
if (temp_resource != nullptr) {
|
||||
|
@ -133,17 +129,10 @@ IVFSQHybrid::search_impl(int64_t n, const float* data, int64_t k, float* distanc
|
|||
}
|
||||
|
||||
QuantizerPtr
|
||||
IVFSQHybrid::LoadQuantizer(const Config& conf) {
|
||||
IVFSQHybrid::LoadQuantizer(const Config& config) {
|
||||
// std::lock_guard<std::mutex> lk(g_mutex);
|
||||
|
||||
auto quantizer_conf = std::dynamic_pointer_cast<QuantizerCfg>(conf);
|
||||
if (quantizer_conf != nullptr) {
|
||||
if (quantizer_conf->mode != 1) {
|
||||
KNOWHERE_THROW_MSG("mode only support 1 in this func");
|
||||
}
|
||||
}
|
||||
auto gpu_id = quantizer_conf->gpu_id;
|
||||
|
||||
auto gpu_id = config[knowhere::meta::DEVICEID].get<int64_t>();
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) {
|
||||
ResScope rs(res, gpu_id, false);
|
||||
faiss::gpu::GpuClonerOptions option;
|
||||
|
@ -152,7 +141,7 @@ IVFSQHybrid::LoadQuantizer(const Config& conf) {
|
|||
auto index_composition = new faiss::IndexComposition;
|
||||
index_composition->index = index_.get();
|
||||
index_composition->quantizer = nullptr;
|
||||
index_composition->mode = quantizer_conf->mode; // only 1
|
||||
index_composition->mode = 1; // only 1
|
||||
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id, index_composition, &option);
|
||||
delete gpu_index;
|
||||
|
@ -205,19 +194,10 @@ IVFSQHybrid::UnsetQuantizer() {
|
|||
}
|
||||
|
||||
VectorIndexPtr
|
||||
IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
|
||||
IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& config) {
|
||||
// std::lock_guard<std::mutex> lk(g_mutex);
|
||||
|
||||
auto quantizer_conf = std::dynamic_pointer_cast<QuantizerCfg>(conf);
|
||||
if (quantizer_conf != nullptr) {
|
||||
if (quantizer_conf->mode != 2) {
|
||||
KNOWHERE_THROW_MSG("mode only support 2 in this func");
|
||||
}
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("conf error");
|
||||
}
|
||||
|
||||
auto gpu_id = quantizer_conf->gpu_id;
|
||||
int64_t gpu_id = config[knowhere::meta::DEVICEID];
|
||||
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) {
|
||||
ResScope rs(res, gpu_id, false);
|
||||
|
@ -231,7 +211,7 @@ IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
|
|||
auto index_composition = new faiss::IndexComposition;
|
||||
index_composition->index = index_.get();
|
||||
index_composition->quantizer = ivf_quantizer->quantizer;
|
||||
index_composition->mode = quantizer_conf->mode; // only 2
|
||||
index_composition->mode = 2; // only 2
|
||||
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id, index_composition, &option);
|
||||
std::shared_ptr<faiss::Index> new_idx;
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "knowhere/index/vector_index/IndexNSG.h"
|
||||
|
||||
#include "knowhere/adapter/VectorAdapter.h"
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Timer.h"
|
||||
|
@ -23,6 +24,7 @@
|
|||
#endif
|
||||
|
||||
#include <fiu-local.h>
|
||||
|
||||
#include "knowhere/index/vector_index/IndexIDMAP.h"
|
||||
#include "knowhere/index/vector_index/IndexIVF.h"
|
||||
#include "knowhere/index/vector_index/nsg/NSG.h"
|
||||
|
@ -72,23 +74,21 @@ NSG::Load(const BinarySet& index_binary) {
|
|||
|
||||
DatasetPtr
|
||||
NSG::Search(const DatasetPtr& dataset, const Config& config) {
|
||||
auto build_cfg = std::dynamic_pointer_cast<NSGCfg>(config);
|
||||
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
|
||||
GETTENSOR(dataset)
|
||||
|
||||
auto elems = rows * build_cfg->k;
|
||||
auto elems = rows * config[meta::TOPK].get<int64_t>();
|
||||
size_t p_id_size = sizeof(int64_t) * elems;
|
||||
size_t p_dist_size = sizeof(float) * elems;
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
algo::SearchParams s_params;
|
||||
s_params.search_length = build_cfg->search_length;
|
||||
index_->Search((float*)p_data, rows, dim, build_cfg->k, p_dist, p_id, s_params);
|
||||
s_params.search_length = config[IndexParams::search_length];
|
||||
index_->Search((float*)p_data, rows, dim, config[meta::TOPK].get<int64_t>(), p_dist, p_id, s_params);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
|
@ -98,41 +98,35 @@ NSG::Search(const DatasetPtr& dataset, const Config& config) {
|
|||
|
||||
IndexModelPtr
|
||||
NSG::Train(const DatasetPtr& dataset, const Config& config) {
|
||||
config->Dump();
|
||||
auto build_cfg = std::dynamic_pointer_cast<NSGCfg>(config);
|
||||
if (build_cfg != nullptr) {
|
||||
build_cfg->CheckValid(); // throw exception
|
||||
}
|
||||
|
||||
auto idmap = std::make_shared<IDMAP>();
|
||||
idmap->Train(config);
|
||||
idmap->AddWithoutId(dataset, config);
|
||||
Graph knng;
|
||||
const float* raw_data = idmap->GetRawVectors();
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
if (build_cfg->gpu_id == knowhere::INVALID_VALUE) {
|
||||
if (config[knowhere::meta::DEVICEID].get<int64_t>() == -1) {
|
||||
auto preprocess_index = std::make_shared<IVF>();
|
||||
auto model = preprocess_index->Train(dataset, config);
|
||||
preprocess_index->set_index_model(model);
|
||||
preprocess_index->Add(dataset, config);
|
||||
preprocess_index->GenGraph(raw_data, build_cfg->knng, knng, config);
|
||||
preprocess_index->AddWithoutIds(dataset, config);
|
||||
preprocess_index->GenGraph(raw_data, config[IndexParams::knng].get<int64_t>(), knng, config);
|
||||
} else {
|
||||
auto gpu_idx = cloner::CopyCpuToGpu(idmap, build_cfg->gpu_id, config);
|
||||
auto gpu_idx = cloner::CopyCpuToGpu(idmap, config[knowhere::meta::DEVICEID].get<int64_t>(), config);
|
||||
auto gpu_idmap = std::dynamic_pointer_cast<GPUIDMAP>(gpu_idx);
|
||||
gpu_idmap->GenGraph(raw_data, build_cfg->knng, knng, config);
|
||||
gpu_idmap->GenGraph(raw_data, config[IndexParams::knng].get<int64_t>(), knng, config);
|
||||
}
|
||||
#else
|
||||
auto preprocess_index = std::make_shared<IVF>();
|
||||
auto model = preprocess_index->Train(dataset, config);
|
||||
preprocess_index->set_index_model(model);
|
||||
preprocess_index->AddWithoutIds(dataset, config);
|
||||
preprocess_index->GenGraph(raw_data, build_cfg->knng, knng, config);
|
||||
preprocess_index->GenGraph(raw_data, config[IndexParams::knng].get<int64_t>(), knng, config);
|
||||
#endif
|
||||
|
||||
algo::BuildParams b_params;
|
||||
b_params.candidate_pool_size = build_cfg->candidate_pool_size;
|
||||
b_params.out_degree = build_cfg->out_degree;
|
||||
b_params.search_length = build_cfg->search_length;
|
||||
b_params.candidate_pool_size = config[IndexParams::candidate];
|
||||
b_params.out_degree = config[IndexParams::out_degree];
|
||||
b_params.search_length = config[IndexParams::search_length];
|
||||
|
||||
auto p_ids = dataset->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
|
|
|
@ -123,9 +123,6 @@ CPUSPTAGRNG::Load(const BinarySet& binary_set) {
|
|||
IndexModelPtr
|
||||
CPUSPTAGRNG::Train(const DatasetPtr& origin, const Config& train_config) {
|
||||
SetParameters(train_config);
|
||||
if (train_config != nullptr) {
|
||||
train_config->CheckValid(); // throw exception
|
||||
}
|
||||
|
||||
DatasetPtr dataset = origin; // TODO(linxj): copy or reference?
|
||||
|
||||
|
@ -159,62 +156,56 @@ CPUSPTAGRNG::Add(const DatasetPtr& origin, const Config& add_config) {
|
|||
|
||||
void
|
||||
CPUSPTAGRNG::SetParameters(const Config& config) {
|
||||
#define Assign(param_name, str_name) \
|
||||
conf->param_name == INVALID_VALUE ? index_ptr_->SetParameter(str_name, std::to_string(build_cfg->param_name)) \
|
||||
: index_ptr_->SetParameter(str_name, std::to_string(conf->param_name))
|
||||
#define Assign(param_name, str_name) \
|
||||
index_ptr_->SetParameter(str_name, std::to_string(build_cfg[param_name].get<int64_t>()))
|
||||
|
||||
if (index_type_ == SPTAG::IndexAlgoType::KDT) {
|
||||
auto conf = std::dynamic_pointer_cast<KDTCfg>(config);
|
||||
auto build_cfg = SPTAGParameterMgr::GetInstance().GetKDTParameters();
|
||||
|
||||
Assign(kdtnumber, "KDTNumber");
|
||||
Assign(numtopdimensionkdtsplit, "NumTopDimensionKDTSplit");
|
||||
Assign(samples, "Samples");
|
||||
Assign(tptnumber, "TPTNumber");
|
||||
Assign(tptleafsize, "TPTLeafSize");
|
||||
Assign(numtopdimensiontptsplit, "NumTopDimensionTPTSplit");
|
||||
Assign(neighborhoodsize, "NeighborhoodSize");
|
||||
Assign(graphneighborhoodscale, "GraphNeighborhoodScale");
|
||||
Assign(graphcefscale, "GraphCEFScale");
|
||||
Assign(refineiterations, "RefineIterations");
|
||||
Assign(cef, "CEF");
|
||||
Assign(maxcheckforrefinegraph, "MaxCheckForRefineGraph");
|
||||
Assign(numofthreads, "NumberOfThreads");
|
||||
Assign(maxcheck, "MaxCheck");
|
||||
Assign(thresholdofnumberofcontinuousnobetterpropagation, "ThresholdOfNumberOfContinuousNoBetterPropagation");
|
||||
Assign(numberofinitialdynamicpivots, "NumberOfInitialDynamicPivots");
|
||||
Assign(numberofotherdynamicpivots, "NumberOfOtherDynamicPivots");
|
||||
Assign("kdtnumber", "KDTNumber");
|
||||
Assign("numtopdimensionkdtsplit", "NumTopDimensionKDTSplit");
|
||||
Assign("samples", "Samples");
|
||||
Assign("tptnumber", "TPTNumber");
|
||||
Assign("tptleafsize", "TPTLeafSize");
|
||||
Assign("numtopdimensiontptsplit", "NumTopDimensionTPTSplit");
|
||||
Assign("neighborhoodsize", "NeighborhoodSize");
|
||||
Assign("graphneighborhoodscale", "GraphNeighborhoodScale");
|
||||
Assign("graphcefscale", "GraphCEFScale");
|
||||
Assign("refineiterations", "RefineIterations");
|
||||
Assign("cef", "CEF");
|
||||
Assign("maxcheckforrefinegraph", "MaxCheckForRefineGraph");
|
||||
Assign("numofthreads", "NumberOfThreads");
|
||||
Assign("maxcheck", "MaxCheck");
|
||||
Assign("thresholdofnumberofcontinuousnobetterpropagation", "ThresholdOfNumberOfContinuousNoBetterPropagation");
|
||||
Assign("numberofinitialdynamicpivots", "NumberOfInitialDynamicPivots");
|
||||
Assign("numberofotherdynamicpivots", "NumberOfOtherDynamicPivots");
|
||||
} else {
|
||||
auto conf = std::dynamic_pointer_cast<BKTCfg>(config);
|
||||
auto build_cfg = SPTAGParameterMgr::GetInstance().GetBKTParameters();
|
||||
|
||||
Assign(bktnumber, "BKTNumber");
|
||||
Assign(bktkmeansk, "BKTKMeansK");
|
||||
Assign(bktleafsize, "BKTLeafSize");
|
||||
Assign(samples, "Samples");
|
||||
Assign(tptnumber, "TPTNumber");
|
||||
Assign(tptleafsize, "TPTLeafSize");
|
||||
Assign(numtopdimensiontptsplit, "NumTopDimensionTPTSplit");
|
||||
Assign(neighborhoodsize, "NeighborhoodSize");
|
||||
Assign(graphneighborhoodscale, "GraphNeighborhoodScale");
|
||||
Assign(graphcefscale, "GraphCEFScale");
|
||||
Assign(refineiterations, "RefineIterations");
|
||||
Assign(cef, "CEF");
|
||||
Assign(maxcheckforrefinegraph, "MaxCheckForRefineGraph");
|
||||
Assign(numofthreads, "NumberOfThreads");
|
||||
Assign(maxcheck, "MaxCheck");
|
||||
Assign(thresholdofnumberofcontinuousnobetterpropagation, "ThresholdOfNumberOfContinuousNoBetterPropagation");
|
||||
Assign(numberofinitialdynamicpivots, "NumberOfInitialDynamicPivots");
|
||||
Assign(numberofotherdynamicpivots, "NumberOfOtherDynamicPivots");
|
||||
Assign("bktnumber", "BKTNumber");
|
||||
Assign("bktkmeansk", "BKTKMeansK");
|
||||
Assign("bktleafsize", "BKTLeafSize");
|
||||
Assign("samples", "Samples");
|
||||
Assign("tptnumber", "TPTNumber");
|
||||
Assign("tptleafsize", "TPTLeafSize");
|
||||
Assign("numtopdimensiontptsplit", "NumTopDimensionTPTSplit");
|
||||
Assign("neighborhoodsize", "NeighborhoodSize");
|
||||
Assign("graphneighborhoodscale", "GraphNeighborhoodScale");
|
||||
Assign("graphcefscale", "GraphCEFScale");
|
||||
Assign("refineiterations", "RefineIterations");
|
||||
Assign("cef", "CEF");
|
||||
Assign("maxcheckforrefinegraph", "MaxCheckForRefineGraph");
|
||||
Assign("numofthreads", "NumberOfThreads");
|
||||
Assign("maxcheck", "MaxCheck");
|
||||
Assign("thresholdofnumberofcontinuousnobetterpropagation", "ThresholdOfNumberOfContinuousNoBetterPropagation");
|
||||
Assign("numberofinitialdynamicpivots", "NumberOfInitialDynamicPivots");
|
||||
Assign("numberofotherdynamicpivots", "NumberOfOtherDynamicPivots");
|
||||
}
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
CPUSPTAGRNG::Search(const DatasetPtr& dataset, const Config& config) {
|
||||
SetParameters(config);
|
||||
// if (config != nullptr) {
|
||||
// config->CheckValid(); // throw exception
|
||||
// }
|
||||
|
||||
auto p_data = dataset->Get<const float*>(meta::TENSOR);
|
||||
for (auto i = 0; i < 10; ++i) {
|
||||
|
|
|
@ -23,9 +23,9 @@ struct Quantizer {
|
|||
};
|
||||
using QuantizerPtr = std::shared_ptr<Quantizer>;
|
||||
|
||||
struct QuantizerCfg : Cfg {
|
||||
int64_t mode = -1; // 0: all data, 1: copy quantizer, 2: copy data
|
||||
};
|
||||
using QuantizerConfig = std::shared_ptr<QuantizerCfg>;
|
||||
// struct QuantizerCfg : Cfg {
|
||||
// int64_t mode = -1; // 0: all data, 1: copy quantizer, 2: copy data
|
||||
// };
|
||||
// using QuantizerConfig = std::shared_ptr<QuantizerCfg>;
|
||||
|
||||
} // namespace knowhere
|
||||
|
|
|
@ -24,7 +24,10 @@ namespace cloner {
|
|||
VectorIndexPtr
|
||||
CopyGpuToCpu(const VectorIndexPtr& index, const Config& config) {
|
||||
if (auto device_index = std::dynamic_pointer_cast<GPUIndex>(index)) {
|
||||
return device_index->CopyGpuToCpu(config);
|
||||
VectorIndexPtr result = device_index->CopyGpuToCpu(config);
|
||||
auto uids = index->GetUids();
|
||||
result->SetUids(uids);
|
||||
return result;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("index type is not gpuindex");
|
||||
}
|
||||
|
|
|
@ -17,47 +17,23 @@
|
|||
namespace knowhere {
|
||||
|
||||
faiss::MetricType
|
||||
GetMetricType(METRICTYPE& type) {
|
||||
if (type == METRICTYPE::L2) {
|
||||
GetMetricType(const std::string& type) {
|
||||
if (type == Metric::L2) {
|
||||
return faiss::METRIC_L2;
|
||||
}
|
||||
if (type == METRICTYPE::IP) {
|
||||
if (type == Metric::IP) {
|
||||
return faiss::METRIC_INNER_PRODUCT;
|
||||
}
|
||||
// binary only
|
||||
if (type == METRICTYPE::JACCARD) {
|
||||
if (type == Metric::JACCARD) {
|
||||
return faiss::METRIC_Jaccard;
|
||||
}
|
||||
if (type == METRICTYPE::TANIMOTO) {
|
||||
if (type == Metric::TANIMOTO) {
|
||||
return faiss::METRIC_Tanimoto;
|
||||
}
|
||||
if (type == METRICTYPE::HAMMING) {
|
||||
if (type == Metric::HAMMING) {
|
||||
return faiss::METRIC_Hamming;
|
||||
}
|
||||
|
||||
KNOWHERE_THROW_MSG("Metric type is invalid");
|
||||
}
|
||||
|
||||
std::stringstream
|
||||
IVFCfg::DumpImpl() {
|
||||
auto ss = Cfg::DumpImpl();
|
||||
ss << ", nlist: " << nlist << ", nprobe: " << nprobe;
|
||||
return ss;
|
||||
}
|
||||
|
||||
std::stringstream
|
||||
IVFSQCfg::DumpImpl() {
|
||||
auto ss = IVFCfg::DumpImpl();
|
||||
ss << ", nbits: " << nbits;
|
||||
return ss;
|
||||
}
|
||||
|
||||
std::stringstream
|
||||
NSGCfg::DumpImpl() {
|
||||
auto ss = IVFCfg::DumpImpl();
|
||||
ss << ", knng: " << knng << ", search_length: " << search_length << ", out_degree: " << out_degree
|
||||
<< ", candidate: " << candidate_pool_size;
|
||||
return ss;
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
|
|
|
@ -12,240 +12,49 @@
|
|||
#pragma once
|
||||
|
||||
#include <faiss/Index.h>
|
||||
#include <memory>
|
||||
|
||||
#include "knowhere/common/Config.h"
|
||||
#include <string>
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
namespace meta {
|
||||
constexpr const char* DIM = "dim";
|
||||
constexpr const char* TENSOR = "tensor";
|
||||
constexpr const char* ROWS = "rows";
|
||||
constexpr const char* IDS = "ids";
|
||||
constexpr const char* DISTANCE = "distance";
|
||||
constexpr const char* TOPK = "k";
|
||||
constexpr const char* DEVICEID = "gpu_id";
|
||||
}; // namespace meta
|
||||
|
||||
namespace IndexParams {
|
||||
// IVF Params
|
||||
constexpr const char* nprobe = "nprobe";
|
||||
constexpr const char* nlist = "nlist";
|
||||
constexpr const char* m = "m"; // PQ
|
||||
constexpr const char* nbits = "nbits"; // PQ/SQ
|
||||
|
||||
// NSG Params
|
||||
constexpr const char* knng = "knng";
|
||||
constexpr const char* search_length = "search_length";
|
||||
constexpr const char* out_degree = "out_degree";
|
||||
constexpr const char* candidate = "candidate_pool_size";
|
||||
|
||||
// HNSW Params
|
||||
constexpr const char* efConstruction = "efConstruction";
|
||||
constexpr const char* M = "M";
|
||||
constexpr const char* ef = "ef";
|
||||
} // namespace IndexParams
|
||||
|
||||
namespace Metric {
|
||||
constexpr const char* TYPE = "metric_type";
|
||||
constexpr const char* IP = "IP";
|
||||
constexpr const char* L2 = "L2";
|
||||
constexpr const char* HAMMING = "HAMMING";
|
||||
constexpr const char* JACCARD = "JACCARD";
|
||||
constexpr const char* TANIMOTO = "TANIMOTO";
|
||||
} // namespace Metric
|
||||
|
||||
extern faiss::MetricType
|
||||
GetMetricType(METRICTYPE& type);
|
||||
|
||||
// IVF Config
|
||||
constexpr int64_t DEFAULT_NLIST = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_NPROBE = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_NSUBVECTORS = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_NBITS = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_SCAN_TABLE_THREHOLD = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_POLYSEMOUS_HT = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_MAX_CODES = INVALID_VALUE;
|
||||
|
||||
// NSG Config
|
||||
constexpr int64_t DEFAULT_SEARCH_LENGTH = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_OUT_DEGREE = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_CANDIDATE_SISE = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_NNG_K = INVALID_VALUE;
|
||||
|
||||
// SPTAG Config
|
||||
constexpr int64_t DEFAULT_SAMPLES = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_TPTNUMBER = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_TPTLEAFSIZE = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_NUMTOPDIMENSIONTPTSPLIT = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_NEIGHBORHOODSIZE = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_GRAPHNEIGHBORHOODSCALE = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_GRAPHCEFSCALE = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_REFINEITERATIONS = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_CEF = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_MAXCHECKFORREFINEGRAPH = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_NUMOFTHREADS = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_MAXCHECK = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_THRESHOLDOFNUMBEROFCONTINUOUSNOBETTERPROPAGATION = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_NUMBEROFINITIALDYNAMICPIVOTS = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_NUMBEROFOTHERDYNAMICPIVOTS = INVALID_VALUE;
|
||||
|
||||
// KDT Config
|
||||
constexpr int64_t DEFAULT_KDTNUMBER = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_NUMTOPDIMENSIONKDTSPLIT = INVALID_VALUE;
|
||||
|
||||
// BKT Config
|
||||
constexpr int64_t DEFAULT_BKTNUMBER = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_BKTKMEANSK = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_BKTLEAFSIZE = INVALID_VALUE;
|
||||
|
||||
// HNSW Config
|
||||
constexpr int64_t DEFAULT_M = INVALID_VALUE;
|
||||
constexpr int64_t DEFAULT_EF = INVALID_VALUE;
|
||||
|
||||
struct IVFCfg : public Cfg {
|
||||
int64_t nlist = DEFAULT_NLIST;
|
||||
int64_t nprobe = DEFAULT_NPROBE;
|
||||
|
||||
IVFCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe,
|
||||
METRICTYPE type)
|
||||
: Cfg(dim, k, gpu_id, type), nlist(nlist), nprobe(nprobe) {
|
||||
}
|
||||
|
||||
IVFCfg() = default;
|
||||
|
||||
std::stringstream
|
||||
DumpImpl() override;
|
||||
|
||||
// bool
|
||||
// CheckValid() override {
|
||||
// return true;
|
||||
// };
|
||||
};
|
||||
using IVFConfig = std::shared_ptr<IVFCfg>;
|
||||
|
||||
struct IVFBinCfg : public IVFCfg {
|
||||
bool
|
||||
CheckValid() override {
|
||||
if (metric_type == METRICTYPE::HAMMING || metric_type == METRICTYPE::TANIMOTO ||
|
||||
metric_type == METRICTYPE::JACCARD) {
|
||||
return true;
|
||||
}
|
||||
std::stringstream ss;
|
||||
ss << "MetricType: " << int(metric_type) << " not support!";
|
||||
KNOWHERE_THROW_MSG(ss.str());
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
struct IVFSQCfg : public IVFCfg {
|
||||
// TODO(linxj): cpu only support SQ4 SQ6 SQ8 SQ16, gpu only support SQ4, SQ8, SQ16
|
||||
int64_t nbits = DEFAULT_NBITS;
|
||||
|
||||
IVFSQCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe,
|
||||
const int64_t& nbits, METRICTYPE type)
|
||||
: IVFCfg(dim, k, gpu_id, nlist, nprobe, type), nbits(nbits) {
|
||||
}
|
||||
|
||||
std::stringstream
|
||||
DumpImpl() override;
|
||||
|
||||
IVFSQCfg() = default;
|
||||
|
||||
// bool
|
||||
// CheckValid() override {
|
||||
// return true;
|
||||
// };
|
||||
};
|
||||
using IVFSQConfig = std::shared_ptr<IVFSQCfg>;
|
||||
|
||||
struct IVFPQCfg : public IVFCfg {
|
||||
int64_t m = DEFAULT_NSUBVECTORS; // number of subquantizers(subvector)
|
||||
int64_t nbits = DEFAULT_NBITS; // number of bit per subvector index
|
||||
|
||||
// TODO(linxj): not use yet
|
||||
int64_t scan_table_threhold = DEFAULT_SCAN_TABLE_THREHOLD;
|
||||
int64_t polysemous_ht = DEFAULT_POLYSEMOUS_HT;
|
||||
int64_t max_codes = DEFAULT_MAX_CODES;
|
||||
|
||||
IVFPQCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe,
|
||||
const int64_t& nbits, const int64_t& m, METRICTYPE type)
|
||||
: IVFCfg(dim, k, gpu_id, nlist, nprobe, type), m(m), nbits(nbits) {
|
||||
}
|
||||
|
||||
IVFPQCfg() = default;
|
||||
|
||||
// bool
|
||||
// CheckValid() override {
|
||||
// return true;
|
||||
// };
|
||||
};
|
||||
using IVFPQConfig = std::shared_ptr<IVFPQCfg>;
|
||||
|
||||
struct NSGCfg : public IVFCfg {
|
||||
int64_t knng = DEFAULT_NNG_K;
|
||||
int64_t search_length = DEFAULT_SEARCH_LENGTH;
|
||||
int64_t out_degree = DEFAULT_OUT_DEGREE;
|
||||
int64_t candidate_pool_size = DEFAULT_CANDIDATE_SISE;
|
||||
|
||||
NSGCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe,
|
||||
const int64_t& knng, const int64_t& search_length, const int64_t& out_degree, const int64_t& candidate_size,
|
||||
METRICTYPE type)
|
||||
: IVFCfg(dim, k, gpu_id, nlist, nprobe, type),
|
||||
knng(knng),
|
||||
search_length(search_length),
|
||||
out_degree(out_degree),
|
||||
candidate_pool_size(candidate_size) {
|
||||
}
|
||||
|
||||
NSGCfg() = default;
|
||||
|
||||
std::stringstream
|
||||
DumpImpl() override;
|
||||
|
||||
// bool
|
||||
// CheckValid() override {
|
||||
// return true;
|
||||
// };
|
||||
};
|
||||
using NSGConfig = std::shared_ptr<NSGCfg>;
|
||||
|
||||
struct SPTAGCfg : public Cfg {
|
||||
int64_t samples = DEFAULT_SAMPLES;
|
||||
int64_t tptnumber = DEFAULT_TPTNUMBER;
|
||||
int64_t tptleafsize = DEFAULT_TPTLEAFSIZE;
|
||||
int64_t numtopdimensiontptsplit = DEFAULT_NUMTOPDIMENSIONTPTSPLIT;
|
||||
int64_t neighborhoodsize = DEFAULT_NEIGHBORHOODSIZE;
|
||||
int64_t graphneighborhoodscale = DEFAULT_GRAPHNEIGHBORHOODSCALE;
|
||||
int64_t graphcefscale = DEFAULT_GRAPHCEFSCALE;
|
||||
int64_t refineiterations = DEFAULT_REFINEITERATIONS;
|
||||
int64_t cef = DEFAULT_CEF;
|
||||
int64_t maxcheckforrefinegraph = DEFAULT_MAXCHECKFORREFINEGRAPH;
|
||||
int64_t numofthreads = DEFAULT_NUMOFTHREADS;
|
||||
int64_t maxcheck = DEFAULT_MAXCHECK;
|
||||
int64_t thresholdofnumberofcontinuousnobetterpropagation = DEFAULT_THRESHOLDOFNUMBEROFCONTINUOUSNOBETTERPROPAGATION;
|
||||
int64_t numberofinitialdynamicpivots = DEFAULT_NUMBEROFINITIALDYNAMICPIVOTS;
|
||||
int64_t numberofotherdynamicpivots = DEFAULT_NUMBEROFOTHERDYNAMICPIVOTS;
|
||||
|
||||
SPTAGCfg() = default;
|
||||
|
||||
// bool
|
||||
// CheckValid() override {
|
||||
// return true;
|
||||
// };
|
||||
};
|
||||
using SPTAGConfig = std::shared_ptr<SPTAGCfg>;
|
||||
|
||||
struct KDTCfg : public SPTAGCfg {
|
||||
int64_t kdtnumber = DEFAULT_KDTNUMBER;
|
||||
int64_t numtopdimensionkdtsplit = DEFAULT_NUMTOPDIMENSIONKDTSPLIT;
|
||||
|
||||
KDTCfg() = default;
|
||||
|
||||
// bool
|
||||
// CheckValid() override {
|
||||
// return true;
|
||||
// };
|
||||
};
|
||||
using KDTConfig = std::shared_ptr<KDTCfg>;
|
||||
|
||||
struct BKTCfg : public SPTAGCfg {
|
||||
int64_t bktnumber = DEFAULT_BKTNUMBER;
|
||||
int64_t bktkmeansk = DEFAULT_BKTKMEANSK;
|
||||
int64_t bktleafsize = DEFAULT_BKTLEAFSIZE;
|
||||
|
||||
BKTCfg() = default;
|
||||
|
||||
// bool
|
||||
// CheckValid() override {
|
||||
// return true;
|
||||
// };
|
||||
};
|
||||
using BKTConfig = std::shared_ptr<BKTCfg>;
|
||||
|
||||
struct BinIDMAPCfg : public Cfg {
|
||||
bool
|
||||
CheckValid() override {
|
||||
if (metric_type == METRICTYPE::HAMMING || metric_type == METRICTYPE::TANIMOTO ||
|
||||
metric_type == METRICTYPE::JACCARD) {
|
||||
return true;
|
||||
}
|
||||
std::stringstream ss;
|
||||
ss << "MetricType: " << int(metric_type) << " not support!";
|
||||
KNOWHERE_THROW_MSG(ss.str());
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
struct HNSWCfg : public Cfg {
|
||||
int64_t M = DEFAULT_M;
|
||||
int64_t ef = DEFAULT_EF;
|
||||
|
||||
HNSWCfg() = default;
|
||||
};
|
||||
using HNSWConfig = std::shared_ptr<HNSWCfg>;
|
||||
GetMetricType(const std::string& type);
|
||||
|
||||
} // namespace knowhere
|
||||
|
|
|
@ -15,55 +15,53 @@
|
|||
|
||||
namespace knowhere {
|
||||
|
||||
const KDTConfig&
|
||||
const Config&
|
||||
SPTAGParameterMgr::GetKDTParameters() {
|
||||
return kdt_config_;
|
||||
}
|
||||
|
||||
const BKTConfig&
|
||||
const Config&
|
||||
SPTAGParameterMgr::GetBKTParameters() {
|
||||
return bkt_config_;
|
||||
}
|
||||
|
||||
SPTAGParameterMgr::SPTAGParameterMgr() {
|
||||
kdt_config_ = std::make_shared<KDTCfg>();
|
||||
kdt_config_->kdtnumber = 1;
|
||||
kdt_config_->numtopdimensionkdtsplit = 5;
|
||||
kdt_config_->samples = 100;
|
||||
kdt_config_->tptnumber = 1;
|
||||
kdt_config_->tptleafsize = 2000;
|
||||
kdt_config_->numtopdimensiontptsplit = 5;
|
||||
kdt_config_->neighborhoodsize = 32;
|
||||
kdt_config_->graphneighborhoodscale = 2;
|
||||
kdt_config_->graphcefscale = 2;
|
||||
kdt_config_->refineiterations = 0;
|
||||
kdt_config_->cef = 1000;
|
||||
kdt_config_->maxcheckforrefinegraph = 10000;
|
||||
kdt_config_->numofthreads = 1;
|
||||
kdt_config_->maxcheck = 8192;
|
||||
kdt_config_->thresholdofnumberofcontinuousnobetterpropagation = 3;
|
||||
kdt_config_->numberofinitialdynamicpivots = 50;
|
||||
kdt_config_->numberofotherdynamicpivots = 4;
|
||||
kdt_config_["kdtnumber"] = 1;
|
||||
kdt_config_["numtopdimensionkdtsplit"] = 5;
|
||||
kdt_config_["samples"] = 100;
|
||||
kdt_config_["tptnumber"] = 1;
|
||||
kdt_config_["tptleafsize"] = 2000;
|
||||
kdt_config_["numtopdimensiontptsplit"] = 5;
|
||||
kdt_config_["neighborhoodsize"] = 32;
|
||||
kdt_config_["graphneighborhoodscale"] = 2;
|
||||
kdt_config_["graphcefscale"] = 2;
|
||||
kdt_config_["refineiterations"] = 0;
|
||||
kdt_config_["cef"] = 1000;
|
||||
kdt_config_["maxcheckforrefinegraph"] = 10000;
|
||||
kdt_config_["numofthreads"] = 1;
|
||||
kdt_config_["maxcheck"] = 8192;
|
||||
kdt_config_["thresholdofnumberofcontinuousnobetterpropagation"] = 3;
|
||||
kdt_config_["numberofinitialdynamicpivots"] = 50;
|
||||
kdt_config_["numberofotherdynamicpivots"] = 4;
|
||||
|
||||
bkt_config_ = std::make_shared<BKTCfg>();
|
||||
bkt_config_->bktnumber = 1;
|
||||
bkt_config_->bktkmeansk = 32;
|
||||
bkt_config_->bktleafsize = 8;
|
||||
bkt_config_->samples = 100;
|
||||
bkt_config_->tptnumber = 1;
|
||||
bkt_config_->tptleafsize = 2000;
|
||||
bkt_config_->numtopdimensiontptsplit = 5;
|
||||
bkt_config_->neighborhoodsize = 32;
|
||||
bkt_config_->graphneighborhoodscale = 2;
|
||||
bkt_config_->graphcefscale = 2;
|
||||
bkt_config_->refineiterations = 0;
|
||||
bkt_config_->cef = 1000;
|
||||
bkt_config_->maxcheckforrefinegraph = 10000;
|
||||
bkt_config_->numofthreads = 1;
|
||||
bkt_config_->maxcheck = 8192;
|
||||
bkt_config_->thresholdofnumberofcontinuousnobetterpropagation = 3;
|
||||
bkt_config_->numberofinitialdynamicpivots = 50;
|
||||
bkt_config_->numberofotherdynamicpivots = 4;
|
||||
bkt_config_["bktnumber"] = 1;
|
||||
bkt_config_["bktkmeansk"] = 32;
|
||||
bkt_config_["bktleafsize"] = 8;
|
||||
bkt_config_["samples"] = 100;
|
||||
bkt_config_["tptnumber"] = 1;
|
||||
bkt_config_["tptleafsize"] = 2000;
|
||||
bkt_config_["numtopdimensiontptsplit"] = 5;
|
||||
bkt_config_["neighborhoodsize"] = 32;
|
||||
bkt_config_["graphneighborhoodscale"] = 2;
|
||||
bkt_config_["graphcefscale"] = 2;
|
||||
bkt_config_["refineiterations"] = 0;
|
||||
bkt_config_["cef"] = 1000;
|
||||
bkt_config_["maxcheckforrefinegraph"] = 10000;
|
||||
bkt_config_["numofthreads"] = 1;
|
||||
bkt_config_["maxcheck"] = 8192;
|
||||
bkt_config_["thresholdofnumberofcontinuousnobetterpropagation"] = 3;
|
||||
bkt_config_["numberofinitialdynamicpivots"] = 50;
|
||||
bkt_config_["numberofotherdynamicpivots"] = 4;
|
||||
}
|
||||
|
||||
} // namespace knowhere
|
||||
|
|
|
@ -18,18 +18,15 @@
|
|||
|
||||
#include <SPTAG/AnnService/inc/Core/Common.h>
|
||||
#include "IndexParameter.h"
|
||||
#include "knowhere/common/Config.h"
|
||||
|
||||
namespace knowhere {
|
||||
|
||||
using KDTConfig = std::shared_ptr<KDTCfg>;
|
||||
using BKTConfig = std::shared_ptr<BKTCfg>;
|
||||
|
||||
class SPTAGParameterMgr {
|
||||
public:
|
||||
const KDTConfig&
|
||||
const Config&
|
||||
GetKDTParameters();
|
||||
|
||||
const BKTConfig&
|
||||
const Config&
|
||||
GetBKTParameters();
|
||||
|
||||
public:
|
||||
|
@ -48,8 +45,8 @@ class SPTAGParameterMgr {
|
|||
SPTAGParameterMgr();
|
||||
|
||||
private:
|
||||
KDTConfig kdt_config_;
|
||||
BKTConfig bkt_config_;
|
||||
Config kdt_config_;
|
||||
Config bkt_config_;
|
||||
};
|
||||
|
||||
} // namespace knowhere
|
||||
|
|
|
@ -29,16 +29,9 @@ namespace algo {
|
|||
|
||||
unsigned int seed = 100;
|
||||
|
||||
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, METRICTYPE metric)
|
||||
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, const std::string& metric)
|
||||
: dimension(dimension), ntotal(n), metric_type(metric) {
|
||||
switch (metric) {
|
||||
case METRICTYPE::L2:
|
||||
distance_ = new DistanceL2;
|
||||
break;
|
||||
case METRICTYPE::IP:
|
||||
distance_ = new DistanceIP;
|
||||
break;
|
||||
}
|
||||
distance_ = new DistanceL2; // hardcode here
|
||||
}
|
||||
|
||||
NsgIndex::~NsgIndex() {
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
#include <cstddef>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <boost/dynamic_bitset.hpp>
|
||||
|
@ -41,8 +42,8 @@ using Graph = std::vector<std::vector<node_t>>;
|
|||
class NsgIndex {
|
||||
public:
|
||||
size_t dimension;
|
||||
size_t ntotal; // totabl nb of indexed vectors
|
||||
METRICTYPE metric_type; // L2 | IP
|
||||
size_t ntotal; // totabl nb of indexed vectors
|
||||
std::string metric_type; // todo(linxj) IP
|
||||
Distance* distance_;
|
||||
|
||||
float* ori_data_;
|
||||
|
@ -62,7 +63,7 @@ class NsgIndex {
|
|||
size_t out_degree;
|
||||
|
||||
public:
|
||||
explicit NsgIndex(const size_t& dimension, const size_t& n, METRICTYPE metric = METRICTYPE::L2);
|
||||
explicit NsgIndex(const size_t& dimension, const size_t& n, const std::string& metric = "L2");
|
||||
|
||||
NsgIndex() = default;
|
||||
|
||||
|
|
|
@ -14,10 +14,6 @@
|
|||
|
||||
#include <omp.h>
|
||||
|
||||
#ifdef __SSE__
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
|
||||
#include <faiss/utils/utils.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
|
|
|
@ -30,7 +30,6 @@ set(util_srcs
|
|||
${MILVUS_THIRDPARTY_SRC}/easyloggingpp/easylogging++.cc
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/adapter/VectorAdapter.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Exception.cpp
|
||||
${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Timer.cpp
|
||||
${INDEX_SOURCE_DIR}/unittest/utils.cpp
|
||||
|
|
|
@ -72,35 +72,32 @@ class ParamGenerator {
|
|||
knowhere::Config
|
||||
Gen(const ParameterType& type) {
|
||||
if (type == ParameterType::ivf) {
|
||||
auto tempconf = std::make_shared<knowhere::IVFCfg>();
|
||||
tempconf->d = DIM;
|
||||
tempconf->gpu_id = DEVICEID;
|
||||
tempconf->nlist = 100;
|
||||
tempconf->nprobe = 4;
|
||||
tempconf->k = K;
|
||||
tempconf->metric_type = knowhere::METRICTYPE::L2;
|
||||
return tempconf;
|
||||
return knowhere::Config{
|
||||
{knowhere::meta::DIM, DIM},
|
||||
{knowhere::meta::TOPK, K},
|
||||
{knowhere::IndexParams::nlist, 100},
|
||||
{knowhere::IndexParams::nprobe, 4},
|
||||
{knowhere::Metric::TYPE, knowhere::Metric::L2},
|
||||
{knowhere::meta::DEVICEID, DEVICEID},
|
||||
};
|
||||
} else if (type == ParameterType::ivfpq) {
|
||||
auto tempconf = std::make_shared<knowhere::IVFPQCfg>();
|
||||
tempconf->d = DIM;
|
||||
tempconf->gpu_id = DEVICEID;
|
||||
tempconf->nlist = 100;
|
||||
tempconf->nprobe = 4;
|
||||
tempconf->k = K;
|
||||
tempconf->m = 4;
|
||||
tempconf->nbits = 8;
|
||||
tempconf->metric_type = knowhere::METRICTYPE::L2;
|
||||
return tempconf;
|
||||
return knowhere::Config{
|
||||
{knowhere::meta::DIM, DIM},
|
||||
{knowhere::meta::TOPK, K},
|
||||
{knowhere::IndexParams::nlist, 100},
|
||||
{knowhere::IndexParams::nprobe, 4},
|
||||
{knowhere::IndexParams::m, 4},
|
||||
{knowhere::IndexParams::nbits, 8},
|
||||
{knowhere::Metric::TYPE, knowhere::Metric::L2},
|
||||
{knowhere::meta::DEVICEID, DEVICEID},
|
||||
};
|
||||
} else if (type == ParameterType::ivfsq) {
|
||||
auto tempconf = std::make_shared<knowhere::IVFSQCfg>();
|
||||
tempconf->d = DIM;
|
||||
tempconf->gpu_id = DEVICEID;
|
||||
tempconf->nlist = 100;
|
||||
tempconf->nprobe = 4;
|
||||
tempconf->k = K;
|
||||
tempconf->nbits = 8;
|
||||
tempconf->metric_type = knowhere::METRICTYPE::L2;
|
||||
return tempconf;
|
||||
return knowhere::Config{
|
||||
{knowhere::meta::DIM, DIM}, {knowhere::meta::TOPK, K},
|
||||
{knowhere::IndexParams::nlist, 100}, {knowhere::IndexParams::nprobe, 4},
|
||||
{knowhere::IndexParams::nbits, 8}, {knowhere::Metric::TYPE, knowhere::Metric::L2},
|
||||
{knowhere::meta::DEVICEID, DEVICEID},
|
||||
};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -21,7 +21,7 @@ using ::testing::Combine;
|
|||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
|
||||
class BinaryIDMAPTest : public BinaryDataGen, public TestWithParam<knowhere::METRICTYPE> {
|
||||
class BinaryIDMAPTest : public BinaryDataGen, public TestWithParam<std::string> {
|
||||
protected:
|
||||
void
|
||||
SetUp() override {
|
||||
|
@ -37,17 +37,17 @@ class BinaryIDMAPTest : public BinaryDataGen, public TestWithParam<knowhere::MET
|
|||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIDMAPTest,
|
||||
Values(knowhere::METRICTYPE::JACCARD, knowhere::METRICTYPE::TANIMOTO,
|
||||
knowhere::METRICTYPE::HAMMING));
|
||||
Values(std::string("JACCARD"), std::string("TANIMOTO"), std::string("HAMMING")));
|
||||
|
||||
TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
|
||||
ASSERT_TRUE(!xb.empty());
|
||||
|
||||
knowhere::METRICTYPE MetricType = GetParam();
|
||||
auto conf = std::make_shared<knowhere::BinIDMAPCfg>();
|
||||
conf->d = dim;
|
||||
conf->k = k;
|
||||
conf->metric_type = MetricType;
|
||||
std::string MetricType = GetParam();
|
||||
knowhere::Config conf{
|
||||
{knowhere::meta::DIM, dim},
|
||||
{knowhere::meta::TOPK, k},
|
||||
{knowhere::Metric::TYPE, MetricType},
|
||||
};
|
||||
|
||||
index_->Train(conf);
|
||||
index_->Add(base_dataset, conf);
|
||||
|
@ -88,11 +88,12 @@ TEST_P(BinaryIDMAPTest, binaryidmap_serialize) {
|
|||
reader(ret, bin->size);
|
||||
};
|
||||
|
||||
knowhere::METRICTYPE MetricType = GetParam();
|
||||
auto conf = std::make_shared<knowhere::BinIDMAPCfg>();
|
||||
conf->d = dim;
|
||||
conf->k = k;
|
||||
conf->metric_type = MetricType;
|
||||
std::string MetricType = GetParam();
|
||||
knowhere::Config conf{
|
||||
{knowhere::meta::DIM, dim},
|
||||
{knowhere::meta::TOPK, k},
|
||||
{knowhere::Metric::TYPE, MetricType},
|
||||
};
|
||||
|
||||
{
|
||||
// serialize index
|
||||
|
|
|
@ -27,25 +27,24 @@ using ::testing::Combine;
|
|||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
|
||||
class BinaryIVFTest : public BinaryDataGen, public TestWithParam<knowhere::METRICTYPE> {
|
||||
class BinaryIVFTest : public BinaryDataGen, public TestWithParam<std::string> {
|
||||
protected:
|
||||
void
|
||||
SetUp() override {
|
||||
knowhere::METRICTYPE MetricType = GetParam();
|
||||
std::string MetricType = GetParam();
|
||||
Init_with_binary_default();
|
||||
// nb = 1000000;
|
||||
// nq = 1000;
|
||||
// k = 1000;
|
||||
// Generate(DIM, NB, NQ);
|
||||
index_ = std::make_shared<knowhere::BinaryIVF>();
|
||||
auto x_conf = std::make_shared<knowhere::IVFBinCfg>();
|
||||
x_conf->d = dim;
|
||||
x_conf->k = k;
|
||||
x_conf->metric_type = MetricType;
|
||||
x_conf->nlist = 100;
|
||||
x_conf->nprobe = 10;
|
||||
conf = x_conf;
|
||||
conf->Dump();
|
||||
|
||||
knowhere::Config temp_conf{
|
||||
{knowhere::meta::DIM, dim}, {knowhere::meta::TOPK, k},
|
||||
{knowhere::IndexParams::nlist, 100}, {knowhere::IndexParams::nprobe, 10},
|
||||
{knowhere::Metric::TYPE, MetricType},
|
||||
};
|
||||
conf = temp_conf;
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -59,8 +58,7 @@ class BinaryIVFTest : public BinaryDataGen, public TestWithParam<knowhere::METRI
|
|||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIVFTest,
|
||||
Values(knowhere::METRICTYPE::JACCARD, knowhere::METRICTYPE::TANIMOTO,
|
||||
knowhere::METRICTYPE::HAMMING));
|
||||
Values(std::string("JACCARD"), std::string("TANIMOTO"), std::string("HAMMING")));
|
||||
|
||||
TEST_P(BinaryIVFTest, binaryivf_basic) {
|
||||
assert(!xb.empty());
|
||||
|
@ -75,7 +73,7 @@ TEST_P(BinaryIVFTest, binaryivf_basic) {
|
|||
EXPECT_EQ(index_->Dimension(), dim);
|
||||
|
||||
auto result = index_->Search(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf->k);
|
||||
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
|
||||
faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared<faiss::ConcurrentBitset>(nb);
|
||||
|
@ -123,7 +121,7 @@ TEST_P(BinaryIVFTest, binaryivf_serialize) {
|
|||
// index_->set_index_model(model);
|
||||
// index_->Add(base_dataset, conf);
|
||||
// auto result = index_->Search(query_dataset, conf);
|
||||
// AssertAnns(result, nq, conf->k);
|
||||
// AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
|
||||
// }
|
||||
|
||||
{
|
||||
|
@ -147,7 +145,7 @@ TEST_P(BinaryIVFTest, binaryivf_serialize) {
|
|||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dimension(), dim);
|
||||
auto result = index_->Search(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf->k);
|
||||
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -69,7 +69,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
|
|||
for (int i = 0; i < 3; ++i) {
|
||||
auto gpu_idx = cpu_idx->CopyCpuToGpu(DEVICEID, conf);
|
||||
auto result = gpu_idx->Search(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf->k);
|
||||
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
}
|
||||
}
|
||||
|
@ -86,30 +86,18 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
|
|||
auto quantization = pair.second;
|
||||
|
||||
auto result = gpu_idx->Search(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf->k);
|
||||
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
|
||||
auto quantizer_conf = std::make_shared<knowhere::QuantizerCfg>();
|
||||
quantizer_conf->mode = 2; // only copy data
|
||||
quantizer_conf->gpu_id = DEVICEID;
|
||||
milvus::json quantizer_conf{{knowhere::meta::DEVICEID, DEVICEID}, {"mode", 2}};
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
auto hybrid_idx = std::make_shared<knowhere::IVFSQHybrid>(DEVICEID);
|
||||
hybrid_idx->Load(binaryset);
|
||||
auto new_idx = hybrid_idx->LoadData(quantization, quantizer_conf);
|
||||
auto result = new_idx->Search(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf->k);
|
||||
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
}
|
||||
|
||||
{
|
||||
// invalid quantizer config
|
||||
quantizer_conf = std::make_shared<knowhere::QuantizerCfg>();
|
||||
auto hybrid_idx = std::make_shared<knowhere::IVFSQHybrid>(DEVICEID);
|
||||
ASSERT_ANY_THROW(hybrid_idx->LoadData(quantization, nullptr));
|
||||
ASSERT_ANY_THROW(hybrid_idx->LoadData(quantization, quantizer_conf));
|
||||
quantizer_conf->mode = 2; // only copy data
|
||||
ASSERT_ANY_THROW(hybrid_idx->LoadData(quantization, quantizer_conf));
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -126,7 +114,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
|
|||
|
||||
hybrid_idx->SetQuantizer(quantization);
|
||||
auto result = hybrid_idx->Search(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf->k);
|
||||
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
hybrid_idx->UnsetQuantizer();
|
||||
}
|
||||
|
|
|
@ -45,10 +45,8 @@ class IDMAPTest : public DataGen, public TestGpuIndexBase {
|
|||
TEST_F(IDMAPTest, idmap_basic) {
|
||||
ASSERT_TRUE(!xb.empty());
|
||||
|
||||
auto conf = std::make_shared<knowhere::Cfg>();
|
||||
conf->d = dim;
|
||||
conf->k = k;
|
||||
conf->metric_type = knowhere::METRICTYPE::L2;
|
||||
knowhere::Config conf{
|
||||
{knowhere::meta::DIM, dim}, {knowhere::meta::TOPK, k}, {knowhere::Metric::TYPE, knowhere::Metric::L2}};
|
||||
|
||||
// null faiss index
|
||||
{
|
||||
|
@ -107,10 +105,8 @@ TEST_F(IDMAPTest, idmap_serialize) {
|
|||
reader(ret, bin->size);
|
||||
};
|
||||
|
||||
auto conf = std::make_shared<knowhere::Cfg>();
|
||||
conf->d = dim;
|
||||
conf->k = k;
|
||||
conf->metric_type = knowhere::METRICTYPE::L2;
|
||||
knowhere::Config conf{
|
||||
{knowhere::meta::DIM, dim}, {knowhere::meta::TOPK, k}, {knowhere::Metric::TYPE, knowhere::Metric::L2}};
|
||||
|
||||
{
|
||||
// serialize index
|
||||
|
@ -146,10 +142,8 @@ TEST_F(IDMAPTest, idmap_serialize) {
|
|||
TEST_F(IDMAPTest, copy_test) {
|
||||
ASSERT_TRUE(!xb.empty());
|
||||
|
||||
auto conf = std::make_shared<knowhere::Cfg>();
|
||||
conf->d = dim;
|
||||
conf->k = k;
|
||||
conf->metric_type = knowhere::METRICTYPE::L2;
|
||||
knowhere::Config conf{
|
||||
{knowhere::meta::DIM, dim}, {knowhere::meta::TOPK, k}, {knowhere::Metric::TYPE, knowhere::Metric::L2}};
|
||||
|
||||
index_->Train(conf);
|
||||
index_->Add(base_dataset, conf);
|
||||
|
|
|
@ -62,7 +62,7 @@ class IVFTest : public DataGen, public TestWithParam<::std::tuple<std::string, P
|
|||
Generate(DIM, NB, NQ);
|
||||
index_ = IndexFactory(index_type);
|
||||
conf = ParamGenerator::GetInstance().Gen(parameter_type_);
|
||||
conf->Dump();
|
||||
// KNOWHERE_LOG_DEBUG << "conf: " << conf->dump();
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -109,7 +109,7 @@ TEST_P(IVFTest, ivf_basic) {
|
|||
EXPECT_EQ(index_->Dimension(), dim);
|
||||
|
||||
auto result = index_->Search(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf->k);
|
||||
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
|
||||
if (index_type.find("GPU") == std::string::npos && index_type.find("Hybrid") == std::string::npos &&
|
||||
|
@ -190,7 +190,7 @@ TEST_P(IVFTest, ivf_serialize) {
|
|||
index_->set_index_model(model);
|
||||
index_->Add(base_dataset, conf);
|
||||
auto result = index_->Search(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf->k);
|
||||
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -214,7 +214,7 @@ TEST_P(IVFTest, ivf_serialize) {
|
|||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dimension(), dim);
|
||||
auto result = index_->Search(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf->k);
|
||||
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -232,7 +232,7 @@ TEST_P(IVFTest, clone_test) {
|
|||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dimension(), dim);
|
||||
auto result = index_->Search(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf->k);
|
||||
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
|
||||
auto AssertEqual = [&](knowhere::DatasetPtr p1, knowhere::DatasetPtr p2) {
|
||||
|
@ -254,7 +254,7 @@ TEST_P(IVFTest, clone_test) {
|
|||
// EXPECT_NO_THROW({
|
||||
// auto clone_index = index_->Clone();
|
||||
// auto clone_result = clone_index->Search(query_dataset, conf);
|
||||
// //AssertAnns(result, nq, conf->k);
|
||||
// //AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
|
||||
// AssertEqual(result, clone_result);
|
||||
// std::cout << "inplace clone [" << index_type << "] success" << std::endl;
|
||||
// });
|
||||
|
@ -339,7 +339,7 @@ TEST_P(IVFTest, gpu_seal_test) {
|
|||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dimension(), dim);
|
||||
auto result = index_->Search(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf->k);
|
||||
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
|
||||
|
||||
fiu_init(0);
|
||||
fiu_enable("IVF.Search.throw_std_exception", 1, nullptr, 0);
|
||||
|
@ -374,7 +374,7 @@ TEST_P(IVFTest, invalid_gpu_source) {
|
|||
}
|
||||
|
||||
auto invalid_conf = ParamGenerator::GetInstance().Gen(parameter_type_);
|
||||
invalid_conf->gpu_id = -1;
|
||||
invalid_conf[knowhere::meta::DEVICEID] = -1;
|
||||
|
||||
if (index_type == "GPUIVF") {
|
||||
// null faiss index
|
||||
|
@ -430,15 +430,6 @@ TEST_P(IVFTest, IVFSQHybrid_test) {
|
|||
ASSERT_TRUE(index != nullptr);
|
||||
ASSERT_ANY_THROW(index->UnsetQuantizer());
|
||||
|
||||
knowhere::QuantizerConfig config = std::make_shared<knowhere::QuantizerCfg>();
|
||||
config->gpu_id = knowhere::INVALID_VALUE;
|
||||
|
||||
// mode = -1
|
||||
ASSERT_ANY_THROW(index->LoadQuantizer(config));
|
||||
config->mode = 1;
|
||||
ASSERT_ANY_THROW(index->LoadQuantizer(config));
|
||||
config->gpu_id = DEVICEID;
|
||||
// index->LoadQuantizer(config);
|
||||
ASSERT_ANY_THROW(index->SetQuantizer(nullptr));
|
||||
}
|
||||
|
||||
|
|
|
@ -46,24 +46,19 @@ class NSGInterfaceTest : public DataGen, public ::testing::Test {
|
|||
Generate(256, 1000000 / 100, 1);
|
||||
index_ = std::make_shared<knowhere::NSG>();
|
||||
|
||||
auto tmp_conf = std::make_shared<knowhere::NSGCfg>();
|
||||
tmp_conf->gpu_id = DEVICEID;
|
||||
tmp_conf->d = 256;
|
||||
tmp_conf->knng = 20;
|
||||
tmp_conf->nprobe = 8;
|
||||
tmp_conf->nlist = 163;
|
||||
tmp_conf->search_length = 40;
|
||||
tmp_conf->out_degree = 30;
|
||||
tmp_conf->candidate_pool_size = 100;
|
||||
tmp_conf->metric_type = knowhere::METRICTYPE::L2;
|
||||
train_conf = tmp_conf;
|
||||
train_conf->Dump();
|
||||
train_conf = knowhere::Config{{knowhere::meta::DIM, 256},
|
||||
{knowhere::IndexParams::nlist, 163},
|
||||
{knowhere::IndexParams::nprobe, 8},
|
||||
{knowhere::IndexParams::knng, 20},
|
||||
{knowhere::IndexParams::search_length, 40},
|
||||
{knowhere::IndexParams::out_degree, 30},
|
||||
{knowhere::IndexParams::candidate, 100},
|
||||
{knowhere::Metric::TYPE, knowhere::Metric::L2}};
|
||||
|
||||
auto tmp2_conf = std::make_shared<knowhere::NSGCfg>();
|
||||
tmp2_conf->k = k;
|
||||
tmp2_conf->search_length = 30;
|
||||
search_conf = tmp2_conf;
|
||||
search_conf->Dump();
|
||||
search_conf = knowhere::Config{
|
||||
{knowhere::meta::TOPK, k},
|
||||
{knowhere::IndexParams::search_length, 30},
|
||||
};
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -87,9 +82,9 @@ TEST_F(NSGInterfaceTest, basic_test) {
|
|||
ASSERT_ANY_THROW(index_->Search(query_dataset, search_conf));
|
||||
ASSERT_ANY_THROW(index_->Serialize());
|
||||
}
|
||||
train_conf->gpu_id = knowhere::INVALID_VALUE;
|
||||
auto model_invalid_gpu = index_->Train(base_dataset, train_conf);
|
||||
train_conf->gpu_id = DEVICEID;
|
||||
// train_conf->gpu_id = knowhere::INVALID_VALUE;
|
||||
// auto model_invalid_gpu = index_->Train(base_dataset, train_conf);
|
||||
train_conf[knowhere::meta::DEVICEID] = DEVICEID;
|
||||
auto model = index_->Train(base_dataset, train_conf);
|
||||
auto result = index_->Search(query_dataset, search_conf);
|
||||
AssertAnns(result, nq, k);
|
||||
|
|
|
@ -33,17 +33,17 @@ class SPTAGTest : public DataGen, public TestWithParam<std::string> {
|
|||
Generate(128, 100, 5);
|
||||
index_ = std::make_shared<knowhere::CPUSPTAGRNG>(IndexType);
|
||||
if (IndexType == "KDT") {
|
||||
auto tempconf = std::make_shared<knowhere::KDTCfg>();
|
||||
tempconf->tptnumber = 1;
|
||||
tempconf->k = 10;
|
||||
tempconf->metric_type = knowhere::METRICTYPE::L2;
|
||||
conf = tempconf;
|
||||
conf = knowhere::Config{
|
||||
{knowhere::meta::DIM, dim},
|
||||
{knowhere::meta::TOPK, 10},
|
||||
{knowhere::Metric::TYPE, knowhere::Metric::L2},
|
||||
};
|
||||
} else {
|
||||
auto tempconf = std::make_shared<knowhere::BKTCfg>();
|
||||
tempconf->tptnumber = 1;
|
||||
tempconf->k = 10;
|
||||
tempconf->metric_type = knowhere::METRICTYPE::L2;
|
||||
conf = tempconf;
|
||||
conf = knowhere::Config{
|
||||
{knowhere::meta::DIM, dim},
|
||||
{knowhere::meta::TOPK, 10},
|
||||
{knowhere::Metric::TYPE, knowhere::Metric::L2},
|
||||
};
|
||||
}
|
||||
|
||||
Init_with_default();
|
||||
|
|
|
@ -16,9 +16,9 @@
|
|||
namespace milvus {
|
||||
namespace scheduler {
|
||||
|
||||
SearchJob::SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, uint64_t nprobe,
|
||||
SearchJob::SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, const milvus::json& extra_params,
|
||||
const engine::VectorsData& vectors)
|
||||
: Job(JobType::SEARCH), context_(context), topk_(topk), nprobe_(nprobe), vectors_(vectors) {
|
||||
: Job(JobType::SEARCH), context_(context), topk_(topk), extra_params_(extra_params), vectors_(vectors) {
|
||||
}
|
||||
|
||||
bool
|
||||
|
@ -72,7 +72,7 @@ SearchJob::Dump() const {
|
|||
json ret{
|
||||
{"topk", topk_},
|
||||
{"nq", vectors_.vector_count_},
|
||||
{"nprobe", nprobe_},
|
||||
{"extra_params", extra_params_.dump()},
|
||||
};
|
||||
auto base = Job::Dump();
|
||||
ret.insert(base.begin(), base.end());
|
||||
|
|
|
@ -40,7 +40,7 @@ using ResultDistances = engine::ResultDistances;
|
|||
|
||||
class SearchJob : public Job {
|
||||
public:
|
||||
SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, uint64_t nprobe,
|
||||
SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, const milvus::json& extra_params,
|
||||
const engine::VectorsData& vectors);
|
||||
|
||||
public:
|
||||
|
@ -79,9 +79,9 @@ class SearchJob : public Job {
|
|||
return vectors_.vector_count_;
|
||||
}
|
||||
|
||||
uint64_t
|
||||
nprobe() const {
|
||||
return nprobe_;
|
||||
const milvus::json&
|
||||
extra_params() const {
|
||||
return extra_params_;
|
||||
}
|
||||
|
||||
const engine::VectorsData&
|
||||
|
@ -103,7 +103,7 @@ class SearchJob : public Job {
|
|||
const std::shared_ptr<server::Context> context_;
|
||||
|
||||
uint64_t topk_ = 0;
|
||||
uint64_t nprobe_ = 0;
|
||||
milvus::json extra_params_;
|
||||
// TODO: smart pointer
|
||||
const engine::VectorsData& vectors_;
|
||||
|
||||
|
|
|
@ -41,8 +41,9 @@ XBuildIndexTask::XBuildIndexTask(TableFileSchemaPtr file, TaskLabelPtr label)
|
|||
engine_type = (EngineType)file->engine_type_;
|
||||
}
|
||||
|
||||
auto json = milvus::json::parse(file_->index_params_);
|
||||
to_index_engine_ = EngineFactory::Build(file_->dimension_, file_->location_, engine_type,
|
||||
(MetricType)file_->metric_type_, file_->nlist_);
|
||||
(MetricType)file_->metric_type_, json);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -115,8 +115,12 @@ XSearchTask::XSearchTask(const std::shared_ptr<server::Context>& context, TableF
|
|||
engine_type = (EngineType)file->engine_type_;
|
||||
}
|
||||
|
||||
milvus::json json_params;
|
||||
if (!file_->index_params_.empty()) {
|
||||
json_params = milvus::json::parse(file_->index_params_);
|
||||
}
|
||||
index_engine_ = EngineFactory::Build(file_->dimension_, file_->location_, engine_type,
|
||||
(MetricType)file_->metric_type_, file_->nlist_);
|
||||
(MetricType)file_->metric_type_, json_params);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -217,7 +221,8 @@ XSearchTask::Execute() {
|
|||
// step 1: allocate memory
|
||||
uint64_t nq = search_job->nq();
|
||||
uint64_t topk = search_job->topk();
|
||||
uint64_t nprobe = search_job->nprobe();
|
||||
const milvus::json& extra_params = search_job->extra_params();
|
||||
ENGINE_LOG_DEBUG << "Search job extra params: " << extra_params.dump();
|
||||
const engine::VectorsData& vectors = search_job->vectors();
|
||||
|
||||
output_ids.resize(topk * nq);
|
||||
|
@ -235,13 +240,13 @@ XSearchTask::Execute() {
|
|||
}
|
||||
Status s;
|
||||
if (!vectors.float_data_.empty()) {
|
||||
s = index_engine_->Search(nq, vectors.float_data_.data(), topk, nprobe, output_distance.data(),
|
||||
s = index_engine_->Search(nq, vectors.float_data_.data(), topk, extra_params, output_distance.data(),
|
||||
output_ids.data(), hybrid);
|
||||
} else if (!vectors.binary_data_.empty()) {
|
||||
s = index_engine_->Search(nq, vectors.binary_data_.data(), topk, nprobe, output_distance.data(),
|
||||
s = index_engine_->Search(nq, vectors.binary_data_.data(), topk, extra_params, output_distance.data(),
|
||||
output_ids.data(), hybrid);
|
||||
} else if (!vectors.id_array_.empty()) {
|
||||
s = index_engine_->Search(nq, vectors.id_array_, topk, nprobe, output_distance.data(),
|
||||
s = index_engine_->Search(nq, vectors.id_array_, topk, extra_params, output_distance.data(),
|
||||
output_ids.data(), hybrid);
|
||||
}
|
||||
|
||||
|
|
|
@ -549,9 +549,14 @@ Config::UpdateFileConfigFromMem(const std::string& parent_key, const std::string
|
|||
// convert value string to standard string stored in yaml file
|
||||
std::string value_str;
|
||||
if (child_key == CONFIG_CACHE_CACHE_INSERT_DATA || child_key == CONFIG_STORAGE_S3_ENABLE ||
|
||||
child_key == CONFIG_METRIC_ENABLE_MONITOR || child_key == CONFIG_GPU_RESOURCE_ENABLE) {
|
||||
value_str =
|
||||
(value == "True" || value == "true" || value == "On" || value == "on" || value == "1") ? "true" : "false";
|
||||
child_key == CONFIG_METRIC_ENABLE_MONITOR || child_key == CONFIG_GPU_RESOURCE_ENABLE ||
|
||||
child_key == CONFIG_WAL_ENABLE || child_key == CONFIG_WAL_RECOVERY_ERROR_IGNORE) {
|
||||
bool ok = false;
|
||||
status = StringHelpFunctions::ConvertToBoolean(value, ok);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
value_str = ok ? "true" : "false";
|
||||
} else if (child_key == CONFIG_GPU_RESOURCE_SEARCH_RESOURCES ||
|
||||
child_key == CONFIG_GPU_RESOURCE_BUILD_INDEX_RESOURCES) {
|
||||
std::vector<std::string> vec;
|
||||
|
@ -593,7 +598,6 @@ Config::UpdateFileConfigFromMem(const std::string& parent_key, const std::string
|
|||
}
|
||||
|
||||
// values of gpu resources are sequences, need to remove old here
|
||||
std::regex reg("\\S*");
|
||||
if (child_key == CONFIG_GPU_RESOURCE_SEARCH_RESOURCES || child_key == CONFIG_GPU_RESOURCE_BUILD_INDEX_RESOURCES) {
|
||||
while (getline(conf_fin, line)) {
|
||||
if (line.find("- gpu") != std::string::npos)
|
||||
|
@ -1009,8 +1013,7 @@ Config::CheckCacheConfigCpuCacheCapacity(const std::string& value) {
|
|||
std::cerr << "WARNING: cpu cache capacity value is too big" << std::endl;
|
||||
}
|
||||
|
||||
std::string str =
|
||||
GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_INSERT_BUFFER_SIZE, CONFIG_CACHE_INSERT_BUFFER_SIZE_DEFAULT);
|
||||
std::string str = GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_INSERT_BUFFER_SIZE, "0");
|
||||
int64_t buffer_value = std::stoll(str);
|
||||
|
||||
int64_t insert_buffer_size = buffer_value * GB;
|
||||
|
@ -1059,9 +1062,8 @@ Config::CheckCacheConfigInsertBufferSize(const std::string& value) {
|
|||
return Status(SERVER_INVALID_ARGUMENT, msg);
|
||||
}
|
||||
|
||||
std::string str =
|
||||
GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_CPU_CACHE_CAPACITY, CONFIG_CACHE_CPU_CACHE_CAPACITY_DEFAULT);
|
||||
int64_t cache_size = std::stoll(str);
|
||||
std::string str = GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_CPU_CACHE_CAPACITY, "0");
|
||||
int64_t cache_size = std::stoll(str) * GB;
|
||||
|
||||
uint64_t total_mem = 0, free_mem = 0;
|
||||
CommonUtil::GetSystemMemInfo(total_mem, free_mem);
|
||||
|
@ -1855,7 +1857,8 @@ Config::SetDBConfigBackendUrl(const std::string& value) {
|
|||
Status
|
||||
Config::SetDBConfigPreloadTable(const std::string& value) {
|
||||
CONFIG_CHECK(CheckDBConfigPreloadTable(value));
|
||||
return SetConfigValueInMem(CONFIG_DB, CONFIG_DB_PRELOAD_TABLE, value);
|
||||
std::string cor_value = value == "*" ? "\'*\'" : value;
|
||||
return SetConfigValueInMem(CONFIG_DB, CONFIG_DB_PRELOAD_TABLE, cor_value);
|
||||
}
|
||||
|
||||
Status
|
||||
|
|
|
@ -70,8 +70,8 @@ RequestHandler::DropTable(const std::shared_ptr<Context>& context, const std::st
|
|||
|
||||
Status
|
||||
RequestHandler::CreateIndex(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type,
|
||||
int64_t nlist) {
|
||||
BaseRequestPtr request_ptr = CreateIndexRequest::Create(context, table_name, index_type, nlist);
|
||||
const milvus::json& json_params) {
|
||||
BaseRequestPtr request_ptr = CreateIndexRequest::Create(context, table_name, index_type, json_params);
|
||||
RequestScheduler::ExecRequest(request_ptr);
|
||||
|
||||
return request_ptr->status();
|
||||
|
@ -123,11 +123,11 @@ RequestHandler::ShowTableInfo(const std::shared_ptr<Context>& context, const std
|
|||
|
||||
Status
|
||||
RequestHandler::Search(const std::shared_ptr<Context>& context, const std::string& table_name,
|
||||
const engine::VectorsData& vectors, int64_t topk, int64_t nprobe,
|
||||
const engine::VectorsData& vectors, int64_t topk, const milvus::json& extra_params,
|
||||
const std::vector<std::string>& partition_list, const std::vector<std::string>& file_id_list,
|
||||
TopKQueryResult& result) {
|
||||
BaseRequestPtr request_ptr =
|
||||
SearchRequest::Create(context, table_name, vectors, topk, nprobe, partition_list, file_id_list, result);
|
||||
SearchRequest::Create(context, table_name, vectors, topk, extra_params, partition_list, file_id_list, result);
|
||||
RequestScheduler::ExecRequest(request_ptr);
|
||||
|
||||
return request_ptr->status();
|
||||
|
@ -135,10 +135,10 @@ RequestHandler::Search(const std::shared_ptr<Context>& context, const std::strin
|
|||
|
||||
Status
|
||||
RequestHandler::SearchByID(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t vector_id,
|
||||
int64_t topk, int64_t nprobe, const std::vector<std::string>& partition_list,
|
||||
TopKQueryResult& result) {
|
||||
int64_t topk, const milvus::json& extra_params,
|
||||
const std::vector<std::string>& partition_list, TopKQueryResult& result) {
|
||||
BaseRequestPtr request_ptr =
|
||||
SearchByIDRequest::Create(context, table_name, vector_id, topk, nprobe, partition_list, result);
|
||||
SearchByIDRequest::Create(context, table_name, vector_id, topk, extra_params, partition_list, result);
|
||||
RequestScheduler::ExecRequest(request_ptr);
|
||||
|
||||
return request_ptr->status();
|
||||
|
|
|
@ -38,7 +38,7 @@ class RequestHandler {
|
|||
|
||||
Status
|
||||
CreateIndex(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type,
|
||||
int64_t nlist);
|
||||
const milvus::json& json_params);
|
||||
|
||||
Status
|
||||
Insert(const std::shared_ptr<Context>& context, const std::string& table_name, engine::VectorsData& vectors,
|
||||
|
@ -60,12 +60,13 @@ class RequestHandler {
|
|||
|
||||
Status
|
||||
Search(const std::shared_ptr<Context>& context, const std::string& table_name, const engine::VectorsData& vectors,
|
||||
int64_t topk, int64_t nprobe, const std::vector<std::string>& partition_list,
|
||||
int64_t topk, const milvus::json& extra_params, const std::vector<std::string>& partition_list,
|
||||
const std::vector<std::string>& file_id_list, TopKQueryResult& result);
|
||||
|
||||
Status
|
||||
SearchByID(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t vector_id, int64_t topk,
|
||||
int64_t nprobe, const std::vector<std::string>& partition_list, TopKQueryResult& result);
|
||||
const milvus::json& extra_params, const std::vector<std::string>& partition_list,
|
||||
TopKQueryResult& result);
|
||||
|
||||
Status
|
||||
DescribeTable(const std::shared_ptr<Context>& context, const std::string& table_name, TableSchema& table_schema);
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "grpc/gen-status/status.grpc.pb.h"
|
||||
#include "grpc/gen-status/status.pb.h"
|
||||
#include "server/context/Context.h"
|
||||
#include "utils/Json.h"
|
||||
#include "utils/Status.h"
|
||||
|
||||
#include <condition_variable>
|
||||
|
@ -73,17 +74,15 @@ struct TopKQueryResult {
|
|||
struct IndexParam {
|
||||
std::string table_name_;
|
||||
int64_t index_type_;
|
||||
int64_t nlist_;
|
||||
std::string extra_params_;
|
||||
|
||||
IndexParam() {
|
||||
index_type_ = 0;
|
||||
nlist_ = 0;
|
||||
}
|
||||
|
||||
IndexParam(const std::string& table_name, int64_t index_type, int64_t nlist) {
|
||||
IndexParam(const std::string& table_name, int64_t index_type) {
|
||||
table_name_ = table_name;
|
||||
index_type_ = index_type;
|
||||
nlist_ = nlist;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -24,14 +24,17 @@ namespace milvus {
|
|||
namespace server {
|
||||
|
||||
CreateIndexRequest::CreateIndexRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
|
||||
int64_t index_type, int64_t nlist)
|
||||
: BaseRequest(context, DDL_DML_REQUEST_GROUP), table_name_(table_name), index_type_(index_type), nlist_(nlist) {
|
||||
int64_t index_type, const milvus::json& json_params)
|
||||
: BaseRequest(context, DDL_DML_REQUEST_GROUP),
|
||||
table_name_(table_name),
|
||||
index_type_(index_type),
|
||||
json_params_(json_params) {
|
||||
}
|
||||
|
||||
BaseRequestPtr
|
||||
CreateIndexRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type,
|
||||
int64_t nlist) {
|
||||
return std::shared_ptr<BaseRequest>(new CreateIndexRequest(context, table_name, index_type, nlist));
|
||||
const milvus::json& json_params) {
|
||||
return std::shared_ptr<BaseRequest>(new CreateIndexRequest(context, table_name, index_type, json_params));
|
||||
}
|
||||
|
||||
Status
|
||||
|
@ -69,7 +72,7 @@ CreateIndexRequest::OnExecute() {
|
|||
return status;
|
||||
}
|
||||
|
||||
status = ValidationUtil::ValidateTableIndexNlist(nlist_);
|
||||
status = ValidationUtil::ValidateIndexParams(json_params_, table_schema, index_type_);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
@ -109,7 +112,7 @@ CreateIndexRequest::OnExecute() {
|
|||
// step 3: create index
|
||||
engine::TableIndex index;
|
||||
index.engine_type_ = adapter_index_type;
|
||||
index.nlist_ = nlist_;
|
||||
index.extra_params_ = json_params_;
|
||||
status = DBWrapper::DB()->CreateIndex(table_name_, index);
|
||||
fiu_do_on("CreateIndexRequest.OnExecute.create_index_fail",
|
||||
status = Status(milvus::SERVER_UNEXPECTED_ERROR, ""));
|
||||
|
|
|
@ -21,11 +21,12 @@ namespace server {
|
|||
class CreateIndexRequest : public BaseRequest {
|
||||
public:
|
||||
static BaseRequestPtr
|
||||
Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type, int64_t nlist);
|
||||
Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type,
|
||||
const milvus::json& json_params);
|
||||
|
||||
protected:
|
||||
CreateIndexRequest(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type,
|
||||
int64_t nlist);
|
||||
const milvus::json& json_params);
|
||||
|
||||
Status
|
||||
OnExecute() override;
|
||||
|
@ -33,7 +34,7 @@ class CreateIndexRequest : public BaseRequest {
|
|||
private:
|
||||
const std::string table_name_;
|
||||
const int64_t index_type_;
|
||||
const int64_t nlist_;
|
||||
milvus::json json_params_;
|
||||
};
|
||||
|
||||
} // namespace server
|
||||
|
|
|
@ -78,7 +78,7 @@ DescribeIndexRequest::OnExecute() {
|
|||
|
||||
index_param_.table_name_ = table_name_;
|
||||
index_param_.index_type_ = index.engine_type_;
|
||||
index_param_.nlist_ = index.nlist_;
|
||||
index_param_.extra_params_ = index.extra_params_.dump();
|
||||
} catch (std::exception& ex) {
|
||||
return Status(SERVER_UNEXPECTED_ERROR, ex.what());
|
||||
}
|
||||
|
|
|
@ -34,23 +34,23 @@ namespace milvus {
|
|||
namespace server {
|
||||
|
||||
SearchByIDRequest::SearchByIDRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
|
||||
int64_t vector_id, int64_t topk, int64_t nprobe,
|
||||
int64_t vector_id, int64_t topk, const milvus::json& extra_params,
|
||||
const std::vector<std::string>& partition_list, TopKQueryResult& result)
|
||||
: BaseRequest(context, DQL_REQUEST_GROUP),
|
||||
table_name_(table_name),
|
||||
vector_id_(vector_id),
|
||||
topk_(topk),
|
||||
nprobe_(nprobe),
|
||||
extra_params_(extra_params),
|
||||
partition_list_(partition_list),
|
||||
result_(result) {
|
||||
}
|
||||
|
||||
BaseRequestPtr
|
||||
SearchByIDRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t vector_id,
|
||||
int64_t topk, int64_t nprobe, const std::vector<std::string>& partition_list,
|
||||
TopKQueryResult& result) {
|
||||
int64_t topk, const milvus::json& extra_params,
|
||||
const std::vector<std::string>& partition_list, TopKQueryResult& result) {
|
||||
return std::shared_ptr<BaseRequest>(
|
||||
new SearchByIDRequest(context, table_name, vector_id, topk, nprobe, partition_list, result));
|
||||
new SearchByIDRequest(context, table_name, vector_id, topk, extra_params, partition_list, result));
|
||||
}
|
||||
|
||||
Status
|
||||
|
@ -59,7 +59,7 @@ SearchByIDRequest::OnExecute() {
|
|||
auto pre_query_ctx = context_->Child("Pre query");
|
||||
|
||||
std::string hdr = "SearchByIDRequest(table=" + table_name_ + ", id=" + std::to_string(vector_id_) +
|
||||
", k=" + std::to_string(topk_) + ", nprob=" + std::to_string(nprobe_) + ")";
|
||||
", k=" + std::to_string(topk_) + ", extra_params=" + extra_params_.dump() + ")";
|
||||
|
||||
TimeRecorder rc(hdr);
|
||||
|
||||
|
@ -88,6 +88,11 @@ SearchByIDRequest::OnExecute() {
|
|||
}
|
||||
}
|
||||
|
||||
status = ValidationUtil::ValidateSearchParams(extra_params_, table_schema, topk_);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// Check whether GPU search resource is enabled
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
Config& config = Config::GetInstance();
|
||||
|
@ -122,11 +127,6 @@ SearchByIDRequest::OnExecute() {
|
|||
return status;
|
||||
}
|
||||
|
||||
status = ValidationUtil::ValidateSearchNprobe(nprobe_, table_schema);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
rc.RecordSection("check validation");
|
||||
|
||||
// step 5: search vectors
|
||||
|
@ -140,8 +140,8 @@ SearchByIDRequest::OnExecute() {
|
|||
|
||||
pre_query_ctx->GetTraceContext()->GetSpan()->Finish();
|
||||
|
||||
status = DBWrapper::DB()->QueryByID(context_, table_name_, partition_list_, (size_t)topk_, nprobe_, vector_id_,
|
||||
result_ids, result_distances);
|
||||
status = DBWrapper::DB()->QueryByID(context_, table_name_, partition_list_, (size_t)topk_, extra_params_,
|
||||
vector_id_, result_ids, result_distances);
|
||||
|
||||
#ifdef MILVUS_ENABLE_PROFILING
|
||||
ProfilerStop();
|
||||
|
|
|
@ -30,11 +30,11 @@ class SearchByIDRequest : public BaseRequest {
|
|||
public:
|
||||
static BaseRequestPtr
|
||||
Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t vector_id, int64_t topk,
|
||||
int64_t nprobe, const std::vector<std::string>& partition_list, TopKQueryResult& result);
|
||||
const milvus::json& extra_params, const std::vector<std::string>& partition_list, TopKQueryResult& result);
|
||||
|
||||
protected:
|
||||
SearchByIDRequest(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t vector_id,
|
||||
int64_t topk, int64_t nprobe, const std::vector<std::string>& partition_list,
|
||||
int64_t topk, const milvus::json& extra_params, const std::vector<std::string>& partition_list,
|
||||
TopKQueryResult& result);
|
||||
|
||||
Status
|
||||
|
@ -44,7 +44,7 @@ class SearchByIDRequest : public BaseRequest {
|
|||
const std::string table_name_;
|
||||
const int64_t vector_id_;
|
||||
int64_t topk_;
|
||||
int64_t nprobe_;
|
||||
milvus::json extra_params_;
|
||||
const std::vector<std::string> partition_list_;
|
||||
|
||||
TopKQueryResult& result_;
|
||||
|
|
|
@ -26,14 +26,14 @@ namespace milvus {
|
|||
namespace server {
|
||||
|
||||
SearchRequest::SearchRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
|
||||
const engine::VectorsData& vectors, int64_t topk, int64_t nprobe,
|
||||
const engine::VectorsData& vectors, int64_t topk, const milvus::json& extra_params,
|
||||
const std::vector<std::string>& partition_list,
|
||||
const std::vector<std::string>& file_id_list, TopKQueryResult& result)
|
||||
: BaseRequest(context, DQL_REQUEST_GROUP),
|
||||
table_name_(table_name),
|
||||
vectors_data_(vectors),
|
||||
topk_(topk),
|
||||
nprobe_(nprobe),
|
||||
extra_params_(extra_params),
|
||||
partition_list_(partition_list),
|
||||
file_id_list_(file_id_list),
|
||||
result_(result) {
|
||||
|
@ -41,11 +41,11 @@ SearchRequest::SearchRequest(const std::shared_ptr<Context>& context, const std:
|
|||
|
||||
BaseRequestPtr
|
||||
SearchRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name,
|
||||
const engine::VectorsData& vectors, int64_t topk, int64_t nprobe,
|
||||
const engine::VectorsData& vectors, int64_t topk, const milvus::json& extra_params,
|
||||
const std::vector<std::string>& partition_list, const std::vector<std::string>& file_id_list,
|
||||
TopKQueryResult& result) {
|
||||
return std::shared_ptr<BaseRequest>(
|
||||
new SearchRequest(context, table_name, vectors, topk, nprobe, partition_list, file_id_list, result));
|
||||
new SearchRequest(context, table_name, vectors, topk, extra_params, partition_list, file_id_list, result));
|
||||
}
|
||||
|
||||
Status
|
||||
|
@ -56,7 +56,7 @@ SearchRequest::OnExecute() {
|
|||
auto pre_query_ctx = context_->Child("Pre query");
|
||||
|
||||
std::string hdr = "SearchRequest(table=" + table_name_ + ", nq=" + std::to_string(vector_count) +
|
||||
", k=" + std::to_string(topk_) + ", nprob=" + std::to_string(nprobe_) + ")";
|
||||
", k=" + std::to_string(topk_) + ", extra_params=" + extra_params_.dump() + ")";
|
||||
|
||||
TimeRecorder rc(hdr);
|
||||
|
||||
|
@ -84,13 +84,13 @@ SearchRequest::OnExecute() {
|
|||
}
|
||||
}
|
||||
|
||||
// step 3: check search parameter
|
||||
status = ValidationUtil::ValidateSearchTopk(topk_, table_schema);
|
||||
status = ValidationUtil::ValidateSearchParams(extra_params_, table_schema, topk_);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = ValidationUtil::ValidateSearchNprobe(nprobe_, table_schema);
|
||||
// step 3: check search parameter
|
||||
status = ValidationUtil::ValidateSearchTopk(topk_, table_schema);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
@ -150,10 +150,10 @@ SearchRequest::OnExecute() {
|
|||
return status;
|
||||
}
|
||||
|
||||
status = DBWrapper::DB()->Query(context_, table_name_, partition_list_, (size_t)topk_, nprobe_,
|
||||
status = DBWrapper::DB()->Query(context_, table_name_, partition_list_, (size_t)topk_, extra_params_,
|
||||
vectors_data_, result_ids, result_distances);
|
||||
} else {
|
||||
status = DBWrapper::DB()->QueryByFileID(context_, table_name_, file_id_list_, (size_t)topk_, nprobe_,
|
||||
status = DBWrapper::DB()->QueryByFileID(context_, table_name_, file_id_list_, (size_t)topk_, extra_params_,
|
||||
vectors_data_, result_ids, result_distances);
|
||||
}
|
||||
|
||||
|
|
|
@ -24,12 +24,12 @@ class SearchRequest : public BaseRequest {
|
|||
public:
|
||||
static BaseRequestPtr
|
||||
Create(const std::shared_ptr<Context>& context, const std::string& table_name, const engine::VectorsData& vectors,
|
||||
int64_t topk, int64_t nprobe, const std::vector<std::string>& partition_list,
|
||||
int64_t topk, const milvus::json& extra_params, const std::vector<std::string>& partition_list,
|
||||
const std::vector<std::string>& file_id_list, TopKQueryResult& result);
|
||||
|
||||
protected:
|
||||
SearchRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
|
||||
const engine::VectorsData& vectors, int64_t topk, int64_t nprobe,
|
||||
const engine::VectorsData& vectors, int64_t topk, const milvus::json& extra_params,
|
||||
const std::vector<std::string>& partition_list, const std::vector<std::string>& file_id_list,
|
||||
TopKQueryResult& result);
|
||||
|
||||
|
@ -40,7 +40,7 @@ class SearchRequest : public BaseRequest {
|
|||
const std::string table_name_;
|
||||
const engine::VectorsData& vectors_data_;
|
||||
int64_t topk_;
|
||||
int64_t nprobe_;
|
||||
milvus::json extra_params_;
|
||||
const std::vector<std::string> partition_list_;
|
||||
const std::vector<std::string> file_id_list_;
|
||||
|
||||
|
|
|
@ -280,8 +280,16 @@ GrpcRequestHandler::CreateIndex(::grpc::ServerContext* context, const ::milvus::
|
|||
::milvus::grpc::Status* response) {
|
||||
CHECK_NULLPTR_RETURN(request);
|
||||
|
||||
Status status = request_handler_.CreateIndex(context_map_[context], request->table_name(),
|
||||
request->index().index_type(), request->index().nlist());
|
||||
milvus::json json_params;
|
||||
for (int i = 0; i < request->extra_params_size(); i++) {
|
||||
const ::milvus::grpc::KeyValuePair& extra = request->extra_params(i);
|
||||
if (extra.key() == EXTRA_PARAM_KEY) {
|
||||
json_params = json::parse(extra.value());
|
||||
}
|
||||
}
|
||||
|
||||
Status status =
|
||||
request_handler_.CreateIndex(context_map_[context], request->table_name(), request->index_type(), json_params);
|
||||
|
||||
SET_RESPONSE(response, status, context);
|
||||
return ::grpc::Status::OK;
|
||||
|
@ -366,14 +374,23 @@ GrpcRequestHandler::Search(::grpc::ServerContext* context, const ::milvus::grpc:
|
|||
partitions.emplace_back(partition);
|
||||
}
|
||||
|
||||
// step 3: search vectors
|
||||
// step 3: parse extra parameters
|
||||
milvus::json json_params;
|
||||
for (int i = 0; i < request->extra_params_size(); i++) {
|
||||
const ::milvus::grpc::KeyValuePair& extra = request->extra_params(i);
|
||||
if (extra.key() == EXTRA_PARAM_KEY) {
|
||||
json_params = json::parse(extra.value());
|
||||
}
|
||||
}
|
||||
|
||||
// step 4: search vectors
|
||||
std::vector<std::string> file_ids;
|
||||
TopKQueryResult result;
|
||||
fiu_do_on("GrpcRequestHandler.Search.not_empty_file_ids", file_ids.emplace_back("test_file_id"));
|
||||
Status status = request_handler_.Search(context_map_[context], request->table_name(), vectors, request->topk(),
|
||||
request->nprobe(), partitions, file_ids, result);
|
||||
json_params, partitions, file_ids, result);
|
||||
|
||||
// step 4: construct and return result
|
||||
// step 5: construct and return result
|
||||
ConstructResults(result, response);
|
||||
|
||||
SET_RESPONSE(response->mutable_status(), status, context);
|
||||
|
@ -392,12 +409,21 @@ GrpcRequestHandler::SearchByID(::grpc::ServerContext* context, const ::milvus::g
|
|||
partitions.emplace_back(partition);
|
||||
}
|
||||
|
||||
// step 2: search vectors
|
||||
// step 2: parse extra parameters
|
||||
milvus::json json_params;
|
||||
for (int i = 0; i < request->extra_params_size(); i++) {
|
||||
const ::milvus::grpc::KeyValuePair& extra = request->extra_params(i);
|
||||
if (extra.key() == EXTRA_PARAM_KEY) {
|
||||
json_params = json::parse(extra.value());
|
||||
}
|
||||
}
|
||||
|
||||
// step 3: search vectors
|
||||
TopKQueryResult result;
|
||||
Status status = request_handler_.SearchByID(context_map_[context], request->table_name(), request->id(),
|
||||
request->topk(), request->nprobe(), partitions, result);
|
||||
request->topk(), json_params, partitions, result);
|
||||
|
||||
// step 3: construct and return result
|
||||
// step 4: construct and return result
|
||||
ConstructResults(result, response);
|
||||
|
||||
SET_RESPONSE(response->mutable_status(), status, context);
|
||||
|
@ -429,13 +455,21 @@ GrpcRequestHandler::SearchInFiles(::grpc::ServerContext* context, const ::milvus
|
|||
partitions.emplace_back(partition);
|
||||
}
|
||||
|
||||
// step 4: search vectors
|
||||
TopKQueryResult result;
|
||||
Status status =
|
||||
request_handler_.Search(context_map_[context], search_request->table_name(), vectors, search_request->topk(),
|
||||
search_request->nprobe(), partitions, file_ids, result);
|
||||
// step 4: parse extra parameters
|
||||
milvus::json json_params;
|
||||
for (int i = 0; i < search_request->extra_params_size(); i++) {
|
||||
const ::milvus::grpc::KeyValuePair& extra = search_request->extra_params(i);
|
||||
if (extra.key() == EXTRA_PARAM_KEY) {
|
||||
json_params = json::parse(extra.value());
|
||||
}
|
||||
}
|
||||
|
||||
// step 5: construct and return result
|
||||
// step 5: search vectors
|
||||
TopKQueryResult result;
|
||||
Status status = request_handler_.Search(context_map_[context], search_request->table_name(), vectors,
|
||||
search_request->topk(), json_params, partitions, file_ids, result);
|
||||
|
||||
// step 6: construct and return result
|
||||
ConstructResults(result, response);
|
||||
|
||||
SET_RESPONSE(response->mutable_status(), status, context);
|
||||
|
@ -549,8 +583,10 @@ GrpcRequestHandler::DescribeIndex(::grpc::ServerContext* context, const ::milvus
|
|||
IndexParam param;
|
||||
Status status = request_handler_.DescribeIndex(context_map_[context], request->table_name(), param);
|
||||
response->set_table_name(param.table_name_);
|
||||
response->mutable_index()->set_index_type(param.index_type_);
|
||||
response->mutable_index()->set_nlist(param.nlist_);
|
||||
response->set_index_type(param.index_type_);
|
||||
::milvus::grpc::KeyValuePair* kv = response->add_extra_params();
|
||||
kv->set_key(EXTRA_PARAM_KEY);
|
||||
kv->set_value(param.extra_params_);
|
||||
SET_RESPONSE(response->mutable_status(), status, context);
|
||||
|
||||
return ::grpc::Status::OK;
|
||||
|
|
|
@ -56,6 +56,8 @@ namespace grpc {
|
|||
::milvus::grpc::ErrorCode
|
||||
ErrorMap(ErrorCode code);
|
||||
|
||||
static const char* EXTRA_PARAM_KEY = "params";
|
||||
|
||||
class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service, public GrpcInterceptorHookHandler {
|
||||
public:
|
||||
explicit GrpcRequestHandler(const std::shared_ptr<opentracing::Tracer>& tracer);
|
||||
|
|
|
@ -208,14 +208,14 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
ADD_CORS(TablesOptions)
|
||||
|
||||
ENDPOINT("OPTIONS", "/tables", TablesOptions) {
|
||||
ENDPOINT("OPTIONS", "/collections", TablesOptions) {
|
||||
return createResponse(Status::CODE_204, "No Content");
|
||||
}
|
||||
|
||||
ADD_CORS(CreateTable)
|
||||
|
||||
ENDPOINT("POST", "/tables", CreateTable, BODY_DTO(TableRequestDto::ObjectWrapper, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables\'");
|
||||
ENDPOINT("POST", "/collections", CreateTable, BODY_DTO(TableRequestDto::ObjectWrapper, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
WebRequestHandler handler = WebRequestHandler();
|
||||
|
@ -238,8 +238,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
ADD_CORS(ShowTables)
|
||||
|
||||
ENDPOINT("GET", "/tables", ShowTables, QUERIES(const QueryParams&, query_params)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/tables\'");
|
||||
ENDPOINT("GET", "/collections", ShowTables, QUERIES(const QueryParams&, query_params)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
WebRequestHandler handler = WebRequestHandler();
|
||||
|
@ -265,21 +265,21 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
ADD_CORS(TableOptions)
|
||||
|
||||
ENDPOINT("OPTIONS", "/tables/{table_name}", TableOptions) {
|
||||
ENDPOINT("OPTIONS", "/collections/{collection_name}", TableOptions) {
|
||||
return createResponse(Status::CODE_204, "No Content");
|
||||
}
|
||||
|
||||
ADD_CORS(GetTable)
|
||||
|
||||
ENDPOINT("GET", "/tables/{table_name}", GetTable,
|
||||
PATH(String, table_name), QUERIES(const QueryParams&, query_params)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/tables/" + table_name->std_str() + "\'");
|
||||
ENDPOINT("GET", "/collections/{collection_name}", GetTable,
|
||||
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
WebRequestHandler handler = WebRequestHandler();
|
||||
|
||||
String response_str;
|
||||
auto status_dto = handler.GetTable(table_name, query_params, response_str);
|
||||
auto status_dto = handler.GetTable(collection_name, query_params, response_str);
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
switch (status_dto->code->getValue()) {
|
||||
|
@ -302,14 +302,14 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
ADD_CORS(DropTable)
|
||||
|
||||
ENDPOINT("DELETE", "/tables/{table_name}", DropTable, PATH(String, table_name)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/tables/" + table_name->std_str() + "\'");
|
||||
ENDPOINT("DELETE", "/collections/{collection_name}", DropTable, PATH(String, collection_name)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() + "\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
WebRequestHandler handler = WebRequestHandler();
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
auto status_dto = handler.DropTable(table_name);
|
||||
auto status_dto = handler.DropTable(collection_name);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createDtoResponse(Status::CODE_204, status_dto);
|
||||
|
@ -330,21 +330,21 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
ADD_CORS(IndexOptions)
|
||||
|
||||
ENDPOINT("OPTIONS", "/tables/{table_name}/indexes", IndexOptions) {
|
||||
ENDPOINT("OPTIONS", "/collections/{collection_name}/indexes", IndexOptions) {
|
||||
return createResponse(Status::CODE_204, "No Content");
|
||||
}
|
||||
|
||||
ADD_CORS(CreateIndex)
|
||||
|
||||
ENDPOINT("POST", "/tables/{table_name}/indexes", CreateIndex,
|
||||
PATH(String, table_name), BODY_DTO(IndexRequestDto::ObjectWrapper, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + table_name->std_str() + "/indexes\'");
|
||||
ENDPOINT("POST", "/tables/{collection_name}/indexes", CreateIndex,
|
||||
PATH(String, collection_name), BODY_STRING(String, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + collection_name->std_str() + "/indexes\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
auto handler = WebRequestHandler();
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
auto status_dto = handler.CreateIndex(table_name, body);
|
||||
auto status_dto = handler.CreateIndex(collection_name, body);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createDtoResponse(Status::CODE_201, status_dto);
|
||||
|
@ -365,19 +365,19 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
ADD_CORS(GetIndex)
|
||||
|
||||
ENDPOINT("GET", "/tables/{table_name}/indexes", GetIndex, PATH(String, table_name)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/tables/" + table_name->std_str() + "/indexes\'");
|
||||
ENDPOINT("GET", "/collections/{collection_name}/indexes", GetIndex, PATH(String, collection_name)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "/indexes\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
auto index_dto = IndexDto::createShared();
|
||||
auto handler = WebRequestHandler();
|
||||
|
||||
auto status_dto = handler.GetIndex(table_name, index_dto);
|
||||
OString result;
|
||||
auto status_dto = handler.GetIndex(collection_name, result);
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createDtoResponse(Status::CODE_200, index_dto);
|
||||
response = createResponse(Status::CODE_200, result);
|
||||
break;
|
||||
case StatusCode::TABLE_NOT_EXISTS:
|
||||
response = createDtoResponse(Status::CODE_404, status_dto);
|
||||
|
@ -395,14 +395,14 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
ADD_CORS(DropIndex)
|
||||
|
||||
ENDPOINT("DELETE", "/tables/{table_name}/indexes", DropIndex, PATH(String, table_name)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/tables/" + table_name->std_str() + "/indexes\'");
|
||||
ENDPOINT("DELETE", "/collections/{collection_name}/indexes", DropIndex, PATH(String, collection_name)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() + "/indexes\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
auto handler = WebRequestHandler();
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
auto status_dto = handler.DropIndex(table_name);
|
||||
auto status_dto = handler.DropIndex(collection_name);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createDtoResponse(Status::CODE_204, status_dto);
|
||||
|
@ -423,21 +423,21 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
ADD_CORS(PartitionsOptions)
|
||||
|
||||
ENDPOINT("OPTIONS", "/tables/{table_name}/partitions", PartitionsOptions) {
|
||||
ENDPOINT("OPTIONS", "/collections/{collection_name}/partitions", PartitionsOptions) {
|
||||
return createResponse(Status::CODE_204, "No Content");
|
||||
}
|
||||
|
||||
ADD_CORS(CreatePartition)
|
||||
|
||||
ENDPOINT("POST", "/tables/{table_name}/partitions",
|
||||
CreatePartition, PATH(String, table_name), BODY_DTO(PartitionRequestDto::ObjectWrapper, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + table_name->std_str() + "/partitions\'");
|
||||
ENDPOINT("POST", "/collections/{collection_name}/partitions",
|
||||
CreatePartition, PATH(String, collection_name), BODY_DTO(PartitionRequestDto::ObjectWrapper, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() + "/partitions\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
auto handler = WebRequestHandler();
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
auto status_dto = handler.CreatePartition(table_name, body);
|
||||
auto status_dto = handler.CreatePartition(collection_name, body);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createDtoResponse(Status::CODE_201, status_dto);
|
||||
|
@ -457,9 +457,9 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
ADD_CORS(ShowPartitions)
|
||||
|
||||
ENDPOINT("GET", "/tables/{table_name}/partitions", ShowPartitions,
|
||||
PATH(String, table_name), QUERIES(const QueryParams&, query_params)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/tables/" + table_name->std_str() + "/partitions\'");
|
||||
ENDPOINT("GET", "/collections/{collection_name}/partitions", ShowPartitions,
|
||||
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "/partitions\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
auto offset = query_params.get("offset");
|
||||
|
@ -469,7 +469,7 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
auto handler = WebRequestHandler();
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
auto status_dto = handler.ShowPartitions(table_name, query_params, partition_list_dto);
|
||||
auto status_dto = handler.ShowPartitions(collection_name, query_params, partition_list_dto);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createDtoResponse(Status::CODE_200, partition_list_dto);
|
||||
|
@ -488,16 +488,16 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
ADD_CORS(DropPartition)
|
||||
|
||||
ENDPOINT("DELETE", "/tables/{table_name}/partitions", DropPartition,
|
||||
PATH(String, table_name), BODY_STRING(String, body)) {
|
||||
ENDPOINT("DELETE", "/collections/{collection_name}/partitions", DropPartition,
|
||||
PATH(String, collection_name), BODY_STRING(String, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) +
|
||||
"DELETE \'/tables/" + table_name->std_str() + "/partitions\'");
|
||||
"DELETE \'/collections/" + collection_name->std_str() + "/partitions\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
auto handler = WebRequestHandler();
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
auto status_dto = handler.DropPartition(table_name, body);
|
||||
auto status_dto = handler.DropPartition(collection_name, body);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createDtoResponse(Status::CODE_204, status_dto);
|
||||
|
@ -517,14 +517,14 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
ADD_CORS(ShowSegments)
|
||||
|
||||
ENDPOINT("GET", "/tables/{table_name}/segments", ShowSegments,
|
||||
PATH(String, table_name), QUERIES(const QueryParams&, query_params)) {
|
||||
ENDPOINT("GET", "/collections/{collection_name}/segments", ShowSegments,
|
||||
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
|
||||
auto offset = query_params.get("offset");
|
||||
auto page_size = query_params.get("page_size");
|
||||
|
||||
auto handler = WebRequestHandler();
|
||||
String response;
|
||||
auto status_dto = handler.ShowSegments(table_name, query_params, response);
|
||||
auto status_dto = handler.ShowSegments(collection_name, query_params, response);
|
||||
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
|
@ -541,14 +541,14 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
*
|
||||
* GetSegmentVector
|
||||
*/
|
||||
ENDPOINT("GET", "/tables/{table_name}/segments/{segment_name}/{info}", GetSegmentInfo,
|
||||
PATH(String, table_name), PATH(String, segment_name), PATH(String, info), QUERIES(const QueryParams&, query_params)) {
|
||||
ENDPOINT("GET", "/collections/{collection_name}/segments/{segment_name}/{info}", GetSegmentInfo,
|
||||
PATH(String, collection_name), PATH(String, segment_name), PATH(String, info), QUERIES(const QueryParams&, query_params)) {
|
||||
auto offset = query_params.get("offset");
|
||||
auto page_size = query_params.get("page_size");
|
||||
|
||||
auto handler = WebRequestHandler();
|
||||
String response;
|
||||
auto status_dto = handler.GetSegmentInfo(table_name, segment_name, info, query_params, response);
|
||||
auto status_dto = handler.GetSegmentInfo(collection_name, segment_name, info, query_params, response);
|
||||
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
|
@ -562,7 +562,7 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
ADD_CORS(VectorsOptions)
|
||||
|
||||
ENDPOINT("OPTIONS", "/tables/{table_name}/vectors", VectorsOptions) {
|
||||
ENDPOINT("OPTIONS", "/collections/{collection_name}/vectors", VectorsOptions) {
|
||||
return createResponse(Status::CODE_204, "No Content");
|
||||
}
|
||||
|
||||
|
@ -571,11 +571,11 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
*
|
||||
* GetVectorByID ?id=
|
||||
*/
|
||||
ENDPOINT("GET", "/tables/{table_name}/vectors", GetVectors,
|
||||
PATH(String, table_name), QUERIES(const QueryParams&, query_params)) {
|
||||
ENDPOINT("GET", "/collections/{collection_name}/vectors", GetVectors,
|
||||
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
|
||||
auto handler = WebRequestHandler();
|
||||
String response;
|
||||
auto status_dto = handler.GetVector(table_name, query_params, response);
|
||||
auto status_dto = handler.GetVector(collection_name, query_params, response);
|
||||
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
|
@ -589,16 +589,16 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
ADD_CORS(Insert)
|
||||
|
||||
ENDPOINT("POST", "/tables/{table_name}/vectors", Insert,
|
||||
PATH(String, table_name), BODY_STRING(String, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + table_name->std_str() + "/vectors\'");
|
||||
ENDPOINT("POST", "/collections/{collection_name}/vectors", Insert,
|
||||
PATH(String, collection_name), BODY_STRING(String, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() + "/vectors\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
auto ids_dto = VectorIdsDto::createShared();
|
||||
WebRequestHandler handler = WebRequestHandler();
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
auto status_dto = handler.Insert(table_name, body, ids_dto);
|
||||
auto status_dto = handler.Insert(collection_name, body, ids_dto);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createDtoResponse(Status::CODE_201, ids_dto);
|
||||
|
@ -621,16 +621,16 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
* Search
|
||||
* Delete by ID
|
||||
* */
|
||||
ENDPOINT("PUT", "/tables/{table_name}/vectors", VectorsOp,
|
||||
PATH(String, table_name), BODY_STRING(String, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/tables/" + table_name->std_str() + "/vectors\'");
|
||||
ENDPOINT("PUT", "/collections/{collection_name}/vectors", VectorsOp,
|
||||
PATH(String, collection_name), BODY_STRING(String, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/collections/" + collection_name->std_str() + "/vectors\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
WebRequestHandler handler = WebRequestHandler();
|
||||
|
||||
OString result;
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
auto status_dto = handler.VectorsOp(table_name, body, result);
|
||||
auto status_dto = handler.VectorsOp(collection_name, body, result);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createResponse(Status::CODE_200, result);
|
||||
|
@ -674,11 +674,6 @@ class WebController : public oatpp::web::server::api::ApiController {
|
|||
|
||||
ADD_CORS(SystemOp)
|
||||
|
||||
/**
|
||||
* Load
|
||||
* Compact
|
||||
* Flush
|
||||
*/
|
||||
ENDPOINT("PUT", "/system/{Op}", SystemOp, PATH(String, Op), BODY_STRING(String, body_str)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/system/" + Op->std_str() + "\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
|
|
@ -26,7 +26,7 @@ class IndexRequestDto : public oatpp::data::mapping::type::Object {
|
|||
|
||||
DTO_FIELD(String, index_type) = VALUE_INDEX_INDEX_TYPE_DEFAULT;
|
||||
|
||||
DTO_FIELD(Int64, nlist) = VALUE_INDEX_NLIST_DEFAULT;
|
||||
DTO_FIELD(String, params) = VALUE_INDEX_NLIST_DEFAULT;
|
||||
};
|
||||
|
||||
using IndexDto = IndexRequestDto;
|
||||
|
|
|
@ -25,7 +25,7 @@ namespace web {
|
|||
class TableRequestDto : public oatpp::data::mapping::type::Object {
|
||||
DTO_INIT(TableRequestDto, Object)
|
||||
|
||||
DTO_FIELD(String, table_name, "table_name");
|
||||
DTO_FIELD(String, collection_name, "collection_name");
|
||||
DTO_FIELD(Int64, dimension, "dimension");
|
||||
DTO_FIELD(Int64, index_file_size, "index_file_size") = VALUE_TABLE_INDEX_FILE_SIZE_DEFAULT;
|
||||
DTO_FIELD(String, metric_type, "metric_type") = VALUE_TABLE_METRIC_TYPE_DEFAULT;
|
||||
|
@ -34,25 +34,25 @@ class TableRequestDto : public oatpp::data::mapping::type::Object {
|
|||
class TableFieldsDto : public oatpp::data::mapping::type::Object {
|
||||
DTO_INIT(TableFieldsDto, Object)
|
||||
|
||||
DTO_FIELD(String, table_name);
|
||||
DTO_FIELD(String, collection_name);
|
||||
DTO_FIELD(Int64, dimension);
|
||||
DTO_FIELD(Int64, index_file_size);
|
||||
DTO_FIELD(String, metric_type);
|
||||
DTO_FIELD(Int64, count);
|
||||
DTO_FIELD(String, index);
|
||||
DTO_FIELD(Int64, nlist);
|
||||
DTO_FIELD(String, index_params);
|
||||
};
|
||||
|
||||
class TableListDto : public OObject {
|
||||
DTO_INIT(TableListDto, Object)
|
||||
|
||||
DTO_FIELD(List<String>::ObjectWrapper, table_names);
|
||||
DTO_FIELD(List<String>::ObjectWrapper, collection_names);
|
||||
};
|
||||
|
||||
class TableListFieldsDto : public OObject {
|
||||
DTO_INIT(TableListFieldsDto, Object)
|
||||
|
||||
DTO_FIELD(List<TableFieldsDto::ObjectWrapper>::ObjectWrapper, tables);
|
||||
DTO_FIELD(List<TableFieldsDto::ObjectWrapper>::ObjectWrapper, collections);
|
||||
DTO_FIELD(Int64, count) = 0;
|
||||
};
|
||||
|
||||
|
|
|
@ -20,26 +20,6 @@ namespace web {
|
|||
|
||||
#include OATPP_CODEGEN_BEGIN(DTO)
|
||||
|
||||
// DTO declaring the fields of a search request: two Int64 scalars (topk, nprobe),
// two string lists (tags, file_ids) and two nested-list payloads (records, records_bin).
// NOTE(review): the enclosing hunk header (-20,26 +20,6) suggests this class is removed
// by the change being shown - confirm before adding new users of it.
class SearchRequestDto : public OObject {
|
||||
DTO_INIT(SearchRequestDto, Object)
|
||||
|
||||
DTO_FIELD(Int64, topk);
|
||||
DTO_FIELD(Int64, nprobe);
|
||||
DTO_FIELD(List<String>::ObjectWrapper, tags);
|
||||
DTO_FIELD(List<String>::ObjectWrapper, file_ids);
|
||||
// List of Float32 lists; presumably one inner list per float query vector - TODO confirm.
DTO_FIELD(List<List<Float32>::ObjectWrapper>::ObjectWrapper, records);
|
||||
// List of Int64 lists; presumably the binary-vector counterpart of `records` - TODO confirm.
DTO_FIELD(List<List<Int64>::ObjectWrapper>::ObjectWrapper, records_bin);
|
||||
};
|
||||
|
||||
// DTO declaring the fields of an insert request: a partition tag string, two nested-list
// payloads (records, records_bin) and a list of Int64 ids.
// NOTE(review): the enclosing hunk header (-20,26 +20,6) suggests this class is removed
// by the change being shown - confirm before adding new users of it.
class InsertRequestDto : public oatpp::data::mapping::type::Object {
|
||||
DTO_INIT(InsertRequestDto, Object)
|
||||
|
||||
// Initialised to VALUE_PARTITION_TAG_DEFAULT when the field is not supplied.
DTO_FIELD(String, tag) = VALUE_PARTITION_TAG_DEFAULT;
|
||||
// List of Float32 lists; presumably one inner list per float vector - TODO confirm.
DTO_FIELD(List<List<Float32>::ObjectWrapper>::ObjectWrapper, records);
|
||||
// List of Int64 lists; presumably the binary-vector counterpart of `records` - TODO confirm.
DTO_FIELD(List<List<Int64>::ObjectWrapper>::ObjectWrapper, records_bin);
|
||||
// Presumably caller-supplied vector ids matching the records - verify against the handler.
DTO_FIELD(List<Int64>::ObjectWrapper, ids);
|
||||
};
|
||||
|
||||
class VectorIdsDto : public oatpp::data::mapping::type::Object {
|
||||
DTO_INIT(VectorIdsDto, Object)
|
||||
|
||||
|
|
|
@ -109,6 +109,26 @@ WebRequestHandler::ParseQueryStr(const OQueryParams& query_params, const std::st
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Reads the query parameter `key` from `query_params` and interprets it as a boolean.
//
//  - Parameter present and non-empty: the text must pass ValidationUtil::ValidateStringIsBool,
//    otherwise ILLEGAL_QUERY_PARAM is returned. On success `value` is set to true only for the
//    exact spellings "True" / "true"; every other accepted spelling yields false.
//  - Parameter absent or empty: `value` is left untouched. QUERY_PARAM_LOSS is returned when
//    `nullable` is false, Status::OK() otherwise.
Status
WebRequestHandler::ParseQueryBool(const OQueryParams& query_params, const std::string& key, bool& value,
                                  bool nullable) {
    auto query = query_params.get(key.c_str());
    if (nullptr != query.get() && query->getSize() > 0) {
        std::string value_str = query->std_str();
        if (!ValidationUtil::ValidateStringIsBool(value_str).ok()) {
            // Report the parameter that was actually parsed. The message used to hard-code
            // 'all_required', which was wrong for every other key.
            return Status(ILLEGAL_QUERY_PARAM, "Query param \'" + key + "\' must be a bool");
        }
        value = value_str == "True" || value_str == "true";
        return Status::OK();
    }

    if (!nullable) {
        return Status(QUERY_PARAM_LOSS, "Query param \"" + key + "\" is required");
    }

    return Status::OK();
}
|
||||
|
||||
void
|
||||
WebRequestHandler::AddStatusToJson(nlohmann::json& json, int64_t code, const std::string& msg) {
|
||||
json["code"] = (int64_t)code;
|
||||
|
@ -142,9 +162,9 @@ WebRequestHandler::ParsePartitionStat(const milvus::server::PartitionStat& par_s
|
|||
}
|
||||
|
||||
Status
|
||||
WebRequestHandler::IsBinaryTable(const std::string& table_name, bool& bin) {
|
||||
WebRequestHandler::IsBinaryTable(const std::string& collection_name, bool& bin) {
|
||||
TableSchema schema;
|
||||
auto status = request_handler_.DescribeTable(context_ptr_, table_name, schema);
|
||||
auto status = request_handler_.DescribeTable(context_ptr_, collection_name, schema);
|
||||
if (status.ok()) {
|
||||
auto metric = engine::MetricType(schema.metric_type_);
|
||||
bin = engine::MetricType::HAMMING == metric || engine::MetricType::JACCARD == metric ||
|
||||
|
@ -187,30 +207,30 @@ WebRequestHandler::CopyRecordsFromJson(const nlohmann::json& json, engine::Vecto
|
|||
|
||||
///////////////////////// WebRequestHandler methods ///////////////////////////////////////
|
||||
Status
|
||||
WebRequestHandler::GetTableMetaInfo(const std::string& table_name, nlohmann::json& json_out) {
|
||||
WebRequestHandler::GetTableMetaInfo(const std::string& collection_name, nlohmann::json& json_out) {
|
||||
TableSchema schema;
|
||||
auto status = request_handler_.DescribeTable(context_ptr_, table_name, schema);
|
||||
auto status = request_handler_.DescribeTable(context_ptr_, collection_name, schema);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
int64_t count;
|
||||
status = request_handler_.CountTable(context_ptr_, table_name, count);
|
||||
status = request_handler_.CountTable(context_ptr_, collection_name, count);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
IndexParam index_param;
|
||||
status = request_handler_.DescribeIndex(context_ptr_, table_name, index_param);
|
||||
status = request_handler_.DescribeIndex(context_ptr_, collection_name, index_param);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
json_out["table_name"] = schema.table_name_;
|
||||
json_out["collection_name"] = schema.table_name_;
|
||||
json_out["dimension"] = schema.dimension_;
|
||||
json_out["index_file_size"] = schema.index_file_size_;
|
||||
json_out["index"] = IndexMap.at(engine::EngineType(index_param.index_type_));
|
||||
json_out["nlist"] = index_param.nlist_;
|
||||
json_out["index_params"] = index_param.extra_params_;
|
||||
json_out["metric_type"] = MetricMap.at(engine::MetricType(schema.metric_type_));
|
||||
json_out["count"] = count;
|
||||
|
||||
|
@ -218,15 +238,15 @@ WebRequestHandler::GetTableMetaInfo(const std::string& table_name, nlohmann::jso
|
|||
}
|
||||
|
||||
Status
|
||||
WebRequestHandler::GetTableStat(const std::string& table_name, nlohmann::json& json_out) {
|
||||
struct TableInfo table_info;
|
||||
auto status = request_handler_.ShowTableInfo(context_ptr_, table_name, table_info);
|
||||
WebRequestHandler::GetTableStat(const std::string& collection_name, nlohmann::json& json_out) {
|
||||
struct TableInfo collection_info;
|
||||
auto status = request_handler_.ShowTableInfo(context_ptr_, collection_name, collection_info);
|
||||
|
||||
if (status.ok()) {
|
||||
json_out["count"] = table_info.total_row_num_;
|
||||
json_out["count"] = collection_info.total_row_num_;
|
||||
|
||||
std::vector<nlohmann::json> par_stat_json;
|
||||
for (auto& par : table_info.partitions_stat_) {
|
||||
for (auto& par : collection_info.partitions_stat_) {
|
||||
nlohmann::json par_json;
|
||||
ParsePartitionStat(par, par_json);
|
||||
par_stat_json.push_back(par_json);
|
||||
|
@ -238,10 +258,10 @@ WebRequestHandler::GetTableStat(const std::string& table_name, nlohmann::json& j
|
|||
}
|
||||
|
||||
Status
|
||||
WebRequestHandler::GetSegmentVectors(const std::string& table_name, const std::string& segment_name, int64_t page_size,
|
||||
int64_t offset, nlohmann::json& json_out) {
|
||||
WebRequestHandler::GetSegmentVectors(const std::string& collection_name, const std::string& segment_name,
|
||||
int64_t page_size, int64_t offset, nlohmann::json& json_out) {
|
||||
std::vector<int64_t> vector_ids;
|
||||
auto status = request_handler_.GetVectorIDs(context_ptr_, table_name, segment_name, vector_ids);
|
||||
auto status = request_handler_.GetVectorIDs(context_ptr_, collection_name, segment_name, vector_ids);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
@ -251,7 +271,7 @@ WebRequestHandler::GetSegmentVectors(const std::string& table_name, const std::s
|
|||
|
||||
auto ids = std::vector<int64_t>(vector_ids.begin() + ids_begin, vector_ids.begin() + ids_end);
|
||||
nlohmann::json vectors_json;
|
||||
status = GetVectorsByIDs(table_name, ids, vectors_json);
|
||||
status = GetVectorsByIDs(collection_name, ids, vectors_json);
|
||||
|
||||
nlohmann::json result_json;
|
||||
if (vectors_json.empty()) {
|
||||
|
@ -267,10 +287,10 @@ WebRequestHandler::GetSegmentVectors(const std::string& table_name, const std::s
|
|||
}
|
||||
|
||||
Status
|
||||
WebRequestHandler::GetSegmentIds(const std::string& table_name, const std::string& segment_name, int64_t page_size,
|
||||
WebRequestHandler::GetSegmentIds(const std::string& collection_name, const std::string& segment_name, int64_t page_size,
|
||||
int64_t offset, nlohmann::json& json_out) {
|
||||
std::vector<int64_t> vector_ids;
|
||||
auto status = request_handler_.GetVectorIDs(context_ptr_, table_name, segment_name, vector_ids);
|
||||
auto status = request_handler_.GetVectorIDs(context_ptr_, collection_name, segment_name, vector_ids);
|
||||
if (status.ok()) {
|
||||
auto ids_begin = std::min(vector_ids.size(), (size_t)offset);
|
||||
auto ids_end = std::min(vector_ids.size(), (size_t)(offset + page_size));
|
||||
|
@ -310,12 +330,12 @@ WebRequestHandler::Cmd(const std::string& cmd, std::string& result_str) {
|
|||
|
||||
Status
|
||||
WebRequestHandler::PreLoadTable(const nlohmann::json& json, std::string& result_str) {
|
||||
if (!json.contains("table_name")) {
|
||||
return Status(BODY_FIELD_LOSS, "Field \"load\" must contains table_name");
|
||||
if (!json.contains("collection_name")) {
|
||||
return Status(BODY_FIELD_LOSS, "Field \"load\" must contains collection_name");
|
||||
}
|
||||
|
||||
auto table_name = json["table_name"];
|
||||
auto status = request_handler_.PreloadTable(context_ptr_, table_name.get<std::string>());
|
||||
auto collection_name = json["collection_name"];
|
||||
auto status = request_handler_.PreloadTable(context_ptr_, collection_name.get<std::string>());
|
||||
if (status.ok()) {
|
||||
nlohmann::json result;
|
||||
AddStatusToJson(result, status.code(), status.message());
|
||||
|
@ -327,17 +347,17 @@ WebRequestHandler::PreLoadTable(const nlohmann::json& json, std::string& result_
|
|||
|
||||
Status
|
||||
WebRequestHandler::Flush(const nlohmann::json& json, std::string& result_str) {
|
||||
if (!json.contains("table_names")) {
|
||||
return Status(BODY_FIELD_LOSS, "Field \"flush\" must contains table_names");
|
||||
if (!json.contains("collection_names")) {
|
||||
return Status(BODY_FIELD_LOSS, "Field \"flush\" must contains collection_names");
|
||||
}
|
||||
|
||||
auto table_names = json["table_names"];
|
||||
if (!table_names.is_array()) {
|
||||
return Status(BODY_FIELD_LOSS, "Field \"table_names\" must be and array");
|
||||
auto collection_names = json["collection_names"];
|
||||
if (!collection_names.is_array()) {
|
||||
return Status(BODY_FIELD_LOSS, "Field \"collection_names\" must be and array");
|
||||
}
|
||||
|
||||
std::vector<std::string> names;
|
||||
for (auto& name : table_names) {
|
||||
for (auto& name : collection_names) {
|
||||
names.emplace_back(name.get<std::string>());
|
||||
}
|
||||
|
||||
|
@ -353,16 +373,16 @@ WebRequestHandler::Flush(const nlohmann::json& json, std::string& result_str) {
|
|||
|
||||
Status
|
||||
WebRequestHandler::Compact(const nlohmann::json& json, std::string& result_str) {
|
||||
if (!json.contains("table_name")) {
|
||||
return Status(BODY_FIELD_LOSS, "Field \"compact\" must contains table_names");
|
||||
if (!json.contains("collection_name")) {
|
||||
return Status(BODY_FIELD_LOSS, "Field \"compact\" must contains collection_names");
|
||||
}
|
||||
|
||||
auto table_name = json["table_name"];
|
||||
if (!table_name.is_string()) {
|
||||
return Status(BODY_FIELD_LOSS, "Field \"table_names\" must be a string");
|
||||
auto collection_name = json["collection_name"];
|
||||
if (!collection_name.is_string()) {
|
||||
return Status(BODY_FIELD_LOSS, "Field \"collection_names\" must be a string");
|
||||
}
|
||||
|
||||
auto name = table_name.get<std::string>();
|
||||
auto name = collection_name.get<std::string>();
|
||||
|
||||
auto status = request_handler_.Compact(context_ptr_, name);
|
||||
|
||||
|
@ -463,17 +483,12 @@ WebRequestHandler::SetConfig(const nlohmann::json& json, std::string& result_str
|
|||
}
|
||||
|
||||
Status
|
||||
WebRequestHandler::Search(const std::string& table_name, const nlohmann::json& json, std::string& result_str) {
|
||||
WebRequestHandler::Search(const std::string& collection_name, const nlohmann::json& json, std::string& result_str) {
|
||||
if (!json.contains("topk")) {
|
||||
return Status(BODY_FIELD_LOSS, "Field \'topk\' is required");
|
||||
}
|
||||
int64_t topk = json["topk"];
|
||||
|
||||
if (!json.contains("nprobe")) {
|
||||
return Status(BODY_FIELD_LOSS, "Field \'nprobe\' is required");
|
||||
}
|
||||
int64_t nprobe = json["nprobe"];
|
||||
|
||||
std::vector<std::string> partition_tags;
|
||||
if (json.contains("partition_tags")) {
|
||||
auto tags = json["partition_tags"];
|
||||
|
@ -497,8 +512,12 @@ WebRequestHandler::Search(const std::string& table_name, const nlohmann::json& j
|
|||
}
|
||||
}
|
||||
|
||||
if (!json.contains("params")) {
|
||||
return Status(BODY_FIELD_LOSS, "Field \'params\' is required");
|
||||
}
|
||||
|
||||
bool bin_flag = false;
|
||||
auto status = IsBinaryTable(table_name, bin_flag);
|
||||
auto status = IsBinaryTable(collection_name, bin_flag);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
@ -514,8 +533,9 @@ WebRequestHandler::Search(const std::string& table_name, const nlohmann::json& j
|
|||
}
|
||||
|
||||
TopKQueryResult result;
|
||||
status = request_handler_.Search(context_ptr_, table_name, vectors_data, topk, nprobe, partition_tags, file_id_vec,
|
||||
result);
|
||||
status = request_handler_.Search(context_ptr_, collection_name, vectors_data, topk, json["params"], partition_tags,
|
||||
file_id_vec, result);
|
||||
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
@ -547,7 +567,8 @@ WebRequestHandler::Search(const std::string& table_name, const nlohmann::json& j
|
|||
}
|
||||
|
||||
Status
|
||||
WebRequestHandler::DeleteByIDs(const std::string& table_name, const nlohmann::json& json, std::string& result_str) {
|
||||
WebRequestHandler::DeleteByIDs(const std::string& collection_name, const nlohmann::json& json,
|
||||
std::string& result_str) {
|
||||
std::vector<int64_t> vector_ids;
|
||||
if (!json.contains("ids")) {
|
||||
return Status(BODY_FIELD_LOSS, "Field \"delete\" must contains \"ids\"");
|
||||
|
@ -565,7 +586,7 @@ WebRequestHandler::DeleteByIDs(const std::string& table_name, const nlohmann::js
|
|||
vector_ids.emplace_back(std::stol(id_str));
|
||||
}
|
||||
|
||||
auto status = request_handler_.DeleteByID(context_ptr_, table_name, vector_ids);
|
||||
auto status = request_handler_.DeleteByID(context_ptr_, collection_name, vector_ids);
|
||||
|
||||
nlohmann::json result_json;
|
||||
AddStatusToJson(result_json, status.code(), status.message());
|
||||
|
@ -575,13 +596,13 @@ WebRequestHandler::DeleteByIDs(const std::string& table_name, const nlohmann::js
|
|||
}
|
||||
|
||||
Status
|
||||
WebRequestHandler::GetVectorsByIDs(const std::string& table_name, const std::vector<int64_t>& ids,
|
||||
WebRequestHandler::GetVectorsByIDs(const std::string& collection_name, const std::vector<int64_t>& ids,
|
||||
nlohmann::json& json_out) {
|
||||
std::vector<engine::VectorsData> vector_batch;
|
||||
for (size_t i = 0; i < ids.size(); i++) {
|
||||
auto vec_ids = std::vector<int64_t>(ids.begin() + i, ids.begin() + i + 1);
|
||||
engine::VectorsData vectors_data;
|
||||
auto status = request_handler_.GetVectorByID(context_ptr_, table_name, vec_ids, vectors_data);
|
||||
auto status = request_handler_.GetVectorByID(context_ptr_, collection_name, vec_ids, vectors_data);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
@ -589,7 +610,7 @@ WebRequestHandler::GetVectorsByIDs(const std::string& table_name, const std::vec
|
|||
}
|
||||
|
||||
bool bin;
|
||||
auto status = IsBinaryTable(table_name, bin);
|
||||
auto status = IsBinaryTable(collection_name, bin);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
@ -879,30 +900,31 @@ WebRequestHandler::SetGpuConfig(const GPUConfigDto::ObjectWrapper& gpu_config_dt
|
|||
* Table {
|
||||
*/
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::CreateTable(const TableRequestDto::ObjectWrapper& table_schema) {
|
||||
if (nullptr == table_schema->table_name.get()) {
|
||||
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'table_name\' is missing")
|
||||
WebRequestHandler::CreateTable(const TableRequestDto::ObjectWrapper& collection_schema) {
|
||||
if (nullptr == collection_schema->collection_name.get()) {
|
||||
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'collection_name\' is missing")
|
||||
}
|
||||
|
||||
if (nullptr == table_schema->dimension.get()) {
|
||||
if (nullptr == collection_schema->dimension.get()) {
|
||||
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'dimension\' is missing")
|
||||
}
|
||||
|
||||
if (nullptr == table_schema->index_file_size.get()) {
|
||||
if (nullptr == collection_schema->index_file_size.get()) {
|
||||
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'index_file_size\' is missing")
|
||||
}
|
||||
|
||||
if (nullptr == table_schema->metric_type.get()) {
|
||||
if (nullptr == collection_schema->metric_type.get()) {
|
||||
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'metric_type\' is missing")
|
||||
}
|
||||
|
||||
if (MetricNameMap.find(table_schema->metric_type->std_str()) == MetricNameMap.end()) {
|
||||
if (MetricNameMap.find(collection_schema->metric_type->std_str()) == MetricNameMap.end()) {
|
||||
RETURN_STATUS_DTO(ILLEGAL_METRIC_TYPE, "metric_type is illegal")
|
||||
}
|
||||
|
||||
auto status = request_handler_.CreateTable(
|
||||
context_ptr_, table_schema->table_name->std_str(), table_schema->dimension, table_schema->index_file_size,
|
||||
static_cast<int64_t>(MetricNameMap.at(table_schema->metric_type->std_str())));
|
||||
auto status =
|
||||
request_handler_.CreateTable(context_ptr_, collection_schema->collection_name->std_str(),
|
||||
collection_schema->dimension, collection_schema->index_file_size,
|
||||
static_cast<int64_t>(MetricNameMap.at(collection_schema->metric_type->std_str())));
|
||||
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
|
@ -926,45 +948,41 @@ WebRequestHandler::ShowTables(const OQueryParams& query_params, OString& result)
|
|||
}
|
||||
|
||||
bool all_required = false;
|
||||
auto required = query_params.get("all_required");
|
||||
if (nullptr != required.get()) {
|
||||
auto required_str = required->std_str();
|
||||
if (!ValidationUtil::ValidateStringIsBool(required_str).ok()) {
|
||||
RETURN_STATUS_DTO(ILLEGAL_QUERY_PARAM, "Query param \'all_required\' must be a bool")
|
||||
}
|
||||
all_required = required_str == "True" || required_str == "true";
|
||||
ParseQueryBool(query_params, "all_required", all_required);
|
||||
if (!status.ok()) {
|
||||
RETURN_STATUS_DTO(status.code(), status.message().c_str());
|
||||
}
|
||||
|
||||
std::vector<std::string> tables;
|
||||
status = request_handler_.ShowTables(context_ptr_, tables);
|
||||
std::vector<std::string> collections;
|
||||
status = request_handler_.ShowTables(context_ptr_, collections);
|
||||
if (!status.ok()) {
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
|
||||
if (all_required) {
|
||||
offset = 0;
|
||||
page_size = tables.size();
|
||||
page_size = collections.size();
|
||||
} else {
|
||||
offset = std::min((size_t)offset, tables.size());
|
||||
page_size = std::min(tables.size() - offset, (size_t)page_size);
|
||||
offset = std::min((size_t)offset, collections.size());
|
||||
page_size = std::min(collections.size() - offset, (size_t)page_size);
|
||||
}
|
||||
|
||||
nlohmann::json tables_json;
|
||||
nlohmann::json collections_json;
|
||||
for (int64_t i = offset; i < page_size + offset; i++) {
|
||||
nlohmann::json table_json;
|
||||
status = GetTableMetaInfo(tables.at(i), table_json);
|
||||
nlohmann::json collection_json;
|
||||
status = GetTableMetaInfo(collections.at(i), collection_json);
|
||||
if (!status.ok()) {
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
tables_json.push_back(table_json);
|
||||
collections_json.push_back(collection_json);
|
||||
}
|
||||
|
||||
nlohmann::json result_json;
|
||||
result_json["count"] = tables.size();
|
||||
if (tables_json.empty()) {
|
||||
result_json["tables"] = std::vector<int64_t>();
|
||||
result_json["count"] = collections.size();
|
||||
if (collections_json.empty()) {
|
||||
result_json["collections"] = std::vector<int64_t>();
|
||||
} else {
|
||||
result_json["tables"] = tables_json;
|
||||
result_json["collections"] = collections_json;
|
||||
}
|
||||
|
||||
result = result_json.dump().c_str();
|
||||
|
@ -973,9 +991,9 @@ WebRequestHandler::ShowTables(const OQueryParams& query_params, OString& result)
|
|||
}
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::GetTable(const OString& table_name, const OQueryParams& query_params, OString& result) {
|
||||
if (nullptr == table_name.get()) {
|
||||
RETURN_STATUS_DTO(PATH_PARAM_LOSS, "Path param \'table_name\' is required!");
|
||||
WebRequestHandler::GetTable(const OString& collection_name, const OQueryParams& query_params, OString& result) {
|
||||
if (nullptr == collection_name.get()) {
|
||||
RETURN_STATUS_DTO(PATH_PARAM_LOSS, "Path param \'collection_name\' is required!");
|
||||
}
|
||||
|
||||
std::string stat;
|
||||
|
@ -986,11 +1004,11 @@ WebRequestHandler::GetTable(const OString& table_name, const OQueryParams& query
|
|||
|
||||
if (!stat.empty() && stat == "stat") {
|
||||
nlohmann::json json;
|
||||
status = GetTableStat(table_name->std_str(), json);
|
||||
status = GetTableStat(collection_name->std_str(), json);
|
||||
result = status.ok() ? json.dump().c_str() : "NULL";
|
||||
} else {
|
||||
nlohmann::json json;
|
||||
status = GetTableMetaInfo(table_name->std_str(), json);
|
||||
status = GetTableMetaInfo(collection_name->std_str(), json);
|
||||
result = status.ok() ? json.dump().c_str() : "NULL";
|
||||
}
|
||||
|
||||
|
@ -998,8 +1016,8 @@ WebRequestHandler::GetTable(const OString& table_name, const OQueryParams& query
|
|||
}
|
||||
|
||||
// Drop a collection.
//
// Forwards `collection_name` to the request handler's DropTable call and hands the
// resulting status to ASSIGN_RETURN_STATUS_DTO, which produces the returned DTO.
StatusDto::ObjectWrapper
WebRequestHandler::DropTable(const OString& collection_name) {
    const auto target = collection_name->std_str();
    auto drop_status = request_handler_.DropTable(context_ptr_, target);

    ASSIGN_RETURN_STATUS_DTO(drop_status)
}
|
||||
|
@ -1010,41 +1028,49 @@ WebRequestHandler::DropTable(const OString& table_name) {
|
|||
*/
|
||||
|
||||
// Create an index on a table from a JSON request body.
//
// The body must be a JSON object containing:
//   - "index_type": a name that is present in IndexNameMap
//   - "params":     forwarded as-is to the request handler's CreateIndex call
//
// Outcomes:
//   - BODY_FIELD_LOSS     the payload is missing/empty, or a required field is absent
//   - ILLEGAL_INDEX_TYPE  "index_type" is not a key of IndexNameMap
//   - BODY_PARSE_FAIL     the payload is not valid JSON, or a field has the wrong JSON type
//   - otherwise           the status returned by the request handler
StatusDto::ObjectWrapper
WebRequestHandler::CreateIndex(const OString& table_name, const OString& body) {
    // Same payload guard as Insert(): dereferencing a null body below would crash.
    if (nullptr == body.get() || body->getSize() == 0) {
        RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Request payload is required.")
    }

    try {
        auto request_json = nlohmann::json::parse(body->std_str());
        if (!request_json.contains("index_type")) {
            RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'index_type\' is required");
        }

        // Throws type_error when "index_type" is not a JSON string; handled below.
        std::string index_type = request_json["index_type"];
        if (IndexNameMap.find(index_type) == IndexNameMap.end()) {
            RETURN_STATUS_DTO(ILLEGAL_INDEX_TYPE, "The index type is invalid.")
        }
        auto index = static_cast<int64_t>(IndexNameMap.at(index_type));
        if (!request_json.contains("params")) {
            RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'params\' is required")
        }
        auto status = request_handler_.CreateIndex(context_ptr_, table_name->std_str(), index, request_json["params"]);
        ASSIGN_RETURN_STATUS_DTO(status);
    } catch (const nlohmann::detail::parse_error& e) {
        // Previously swallowed, which made a malformed body report success.
        RETURN_STATUS_DTO(BODY_PARSE_FAIL, e.what())
    } catch (const nlohmann::detail::type_error& e) {
        // Previously swallowed, which made a wrong-typed field report success.
        RETURN_STATUS_DTO(BODY_PARSE_FAIL, e.what())
    }

    // Not expected to be reached: every path above returns. Kept so the function
    // still ends with an explicit return.
    ASSIGN_RETURN_STATUS_DTO(Status::OK())
}
|
||||
|
||||
// Describe the index of a collection.
//
// On success `result` receives a JSON document with two keys: "index_type" (the name
// looked up in IndexMap) and "params" (the index's extra_params_ parsed as JSON).
// When DescribeIndex fails, `result` is left untouched and only the failure status
// is reported.
StatusDto::ObjectWrapper
WebRequestHandler::GetIndex(const OString& collection_name, OString& result) {
    IndexParam index_param;
    auto describe_status = request_handler_.DescribeIndex(context_ptr_, collection_name->std_str(), index_param);

    if (!describe_status.ok()) {
        ASSIGN_RETURN_STATUS_DTO(describe_status)
    }

    nlohmann::json index_json;
    index_json["index_type"] = IndexMap.at(engine::EngineType(index_param.index_type_));
    index_json["params"] = nlohmann::json::parse(index_param.extra_params_);
    result = index_json.dump().c_str();

    ASSIGN_RETURN_STATUS_DTO(describe_status)
}
|
||||
|
||||
// Drop the index of a collection.
//
// Forwards `collection_name` to the request handler's DropIndex call and hands the
// resulting status to ASSIGN_RETURN_STATUS_DTO, which produces the returned DTO.
StatusDto::ObjectWrapper
WebRequestHandler::DropIndex(const OString& collection_name) {
    const auto target = collection_name->std_str();
    auto drop_status = request_handler_.DropIndex(context_ptr_, target);

    ASSIGN_RETURN_STATUS_DTO(drop_status)
}
|
||||
|
@ -1054,19 +1080,19 @@ WebRequestHandler::DropIndex(const OString& table_name) {
|
|||
* Partition {
|
||||
*/
|
||||
// Create a partition in a collection.
//
// `param->partition_tag` is mandatory: when it is null the call is rejected with
// BODY_FIELD_LOSS before the request handler is contacted. Otherwise the collection
// name and tag are forwarded and the handler's status is reported.
StatusDto::ObjectWrapper
WebRequestHandler::CreatePartition(const OString& collection_name, const PartitionRequestDto::ObjectWrapper& param) {
    const bool tag_missing = (param->partition_tag.get() == nullptr);
    if (tag_missing) {
        RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'partition_tag\' is required")
    }

    const auto collection = collection_name->std_str();
    const auto tag = param->partition_tag->std_str();
    auto create_status = request_handler_.CreatePartition(context_ptr_, collection, tag);

    ASSIGN_RETURN_STATUS_DTO(create_status)
}
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::ShowPartitions(const OString& table_name, const OQueryParams& query_params,
|
||||
WebRequestHandler::ShowPartitions(const OString& collection_name, const OQueryParams& query_params,
|
||||
PartitionListDto::ObjectWrapper& partition_list_dto) {
|
||||
int64_t offset = 0;
|
||||
auto status = ParseQueryInteger(query_params, "offset", offset);
|
||||
|
@ -1096,7 +1122,7 @@ WebRequestHandler::ShowPartitions(const OString& table_name, const OQueryParams&
|
|||
}
|
||||
|
||||
std::vector<PartitionParam> partitions;
|
||||
status = request_handler_.ShowPartitions(context_ptr_, table_name->std_str(), partitions);
|
||||
status = request_handler_.ShowPartitions(context_ptr_, collection_name->std_str(), partitions);
|
||||
if (!status.ok()) {
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
|
@ -1124,7 +1150,7 @@ WebRequestHandler::ShowPartitions(const OString& table_name, const OQueryParams&
|
|||
}
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::DropPartition(const OString& table_name, const OString& body) {
|
||||
WebRequestHandler::DropPartition(const OString& collection_name, const OString& body) {
|
||||
std::string tag;
|
||||
try {
|
||||
auto json = nlohmann::json::parse(body->std_str());
|
||||
|
@ -1134,7 +1160,7 @@ WebRequestHandler::DropPartition(const OString& table_name, const OString& body)
|
|||
} catch (nlohmann::detail::type_error& e) {
|
||||
RETURN_STATUS_DTO(BODY_PARSE_FAIL, e.what())
|
||||
}
|
||||
auto status = request_handler_.DropPartition(context_ptr_, table_name->std_str(), tag);
|
||||
auto status = request_handler_.DropPartition(context_ptr_, collection_name->std_str(), tag);
|
||||
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
|
@ -1144,7 +1170,7 @@ WebRequestHandler::DropPartition(const OString& table_name, const OString& body)
|
|||
* Segment {
|
||||
*/
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::ShowSegments(const OString& table_name, const OQueryParams& query_params, OString& response) {
|
||||
WebRequestHandler::ShowSegments(const OString& collection_name, const OQueryParams& query_params, OString& response) {
|
||||
int64_t offset = 0;
|
||||
auto status = ParseQueryInteger(query_params, "offset", offset);
|
||||
if (!status.ok()) {
|
||||
|
@ -1177,7 +1203,7 @@ WebRequestHandler::ShowSegments(const OString& table_name, const OQueryParams& q
|
|||
}
|
||||
|
||||
TableInfo info;
|
||||
status = request_handler_.ShowTableInfo(context_ptr_, table_name->std_str(), info);
|
||||
status = request_handler_.ShowTableInfo(context_ptr_, collection_name->std_str(), info);
|
||||
if (!status.ok()) {
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
|
@ -1228,7 +1254,7 @@ WebRequestHandler::ShowSegments(const OString& table_name, const OQueryParams& q
|
|||
}
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::GetSegmentInfo(const OString& table_name, const OString& segment_name, const OString& info,
|
||||
WebRequestHandler::GetSegmentInfo(const OString& collection_name, const OString& segment_name, const OString& info,
|
||||
const OQueryParams& query_params, OString& result) {
|
||||
int64_t offset = 0;
|
||||
auto status = ParseQueryInteger(query_params, "offset", offset);
|
||||
|
@ -1252,10 +1278,10 @@ WebRequestHandler::GetSegmentInfo(const OString& table_name, const OString& segm
|
|||
nlohmann::json json;
|
||||
// Get vectors
|
||||
if (re == "vectors") {
|
||||
status = GetSegmentVectors(table_name->std_str(), segment_name->std_str(), page_size, offset, json);
|
||||
status = GetSegmentVectors(collection_name->std_str(), segment_name->std_str(), page_size, offset, json);
|
||||
// Get vector ids
|
||||
} else if (re == "ids") {
|
||||
status = GetSegmentIds(table_name->std_str(), segment_name->std_str(), page_size, offset, json);
|
||||
status = GetSegmentIds(collection_name->std_str(), segment_name->std_str(), page_size, offset, json);
|
||||
}
|
||||
|
||||
result = status.ok() ? json.dump().c_str() : "NULL";
|
||||
|
@ -1268,14 +1294,14 @@ WebRequestHandler::GetSegmentInfo(const OString& table_name, const OString& segm
|
|||
* Vector {
|
||||
*/
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::Insert(const OString& table_name, const OString& body, VectorIdsDto::ObjectWrapper& ids_dto) {
|
||||
WebRequestHandler::Insert(const OString& collection_name, const OString& body, VectorIdsDto::ObjectWrapper& ids_dto) {
|
||||
if (nullptr == body.get() || body->getSize() == 0) {
|
||||
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Request payload is required.")
|
||||
}
|
||||
|
||||
// step 1: copy vectors
|
||||
bool bin_flag;
|
||||
auto status = IsBinaryTable(table_name->std_str(), bin_flag);
|
||||
auto status = IsBinaryTable(collection_name->std_str(), bin_flag);
|
||||
if (!status.ok()) {
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
|
@ -1310,7 +1336,7 @@ WebRequestHandler::Insert(const OString& table_name, const OString& body, Vector
|
|||
}
|
||||
|
||||
// step 4: construct result
|
||||
status = request_handler_.Insert(context_ptr_, table_name->std_str(), vectors, tag);
|
||||
status = request_handler_.Insert(context_ptr_, collection_name->std_str(), vectors, tag);
|
||||
if (status.ok()) {
|
||||
ids_dto->ids = ids_dto->ids->createShared();
|
||||
for (auto& id : vectors.id_array_) {
|
||||
|
@ -1322,7 +1348,7 @@ WebRequestHandler::Insert(const OString& table_name, const OString& body, Vector
|
|||
}
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::GetVector(const OString& table_name, const OQueryParams& query_params, OString& response) {
|
||||
WebRequestHandler::GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response) {
|
||||
int64_t id = 0;
|
||||
auto status = ParseQueryInteger(query_params, "id", id, false);
|
||||
if (!status.ok()) {
|
||||
|
@ -1332,7 +1358,7 @@ WebRequestHandler::GetVector(const OString& table_name, const OQueryParams& quer
|
|||
std::vector<int64_t> ids = {id};
|
||||
engine::VectorsData vectors;
|
||||
nlohmann::json vectors_json;
|
||||
status = GetVectorsByIDs(table_name->std_str(), ids, vectors_json);
|
||||
status = GetVectorsByIDs(collection_name->std_str(), ids, vectors_json);
|
||||
if (!status.ok()) {
|
||||
response = "NULL";
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
|
@ -1352,7 +1378,7 @@ WebRequestHandler::GetVector(const OString& table_name, const OQueryParams& quer
|
|||
}
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::VectorsOp(const OString& table_name, const OString& payload, OString& response) {
|
||||
WebRequestHandler::VectorsOp(const OString& collection_name, const OString& payload, OString& response) {
|
||||
auto status = Status::OK();
|
||||
std::string result_str;
|
||||
|
||||
|
@ -1360,9 +1386,9 @@ WebRequestHandler::VectorsOp(const OString& table_name, const OString& payload,
|
|||
nlohmann::json payload_json = nlohmann::json::parse(payload->std_str());
|
||||
|
||||
if (payload_json.contains("delete")) {
|
||||
status = DeleteByIDs(table_name->std_str(), payload_json["delete"], result_str);
|
||||
status = DeleteByIDs(collection_name->std_str(), payload_json["delete"], result_str);
|
||||
} else if (payload_json.contains("search")) {
|
||||
status = Search(table_name->std_str(), payload_json["search"], result_str);
|
||||
status = Search(collection_name->std_str(), payload_json["search"], result_str);
|
||||
} else {
|
||||
status = Status(ILLEGAL_BODY, "Unknown body");
|
||||
}
|
||||
|
|
|
@ -84,6 +84,9 @@ class WebRequestHandler {
|
|||
Status
|
||||
ParseQueryStr(const OQueryParams& query_params, const std::string& key, std::string& value, bool nullable = true);
|
||||
|
||||
Status
|
||||
ParseQueryBool(const OQueryParams& query_params, const std::string& key, bool& value, bool nullable = true);
|
||||
|
||||
private:
|
||||
void
|
||||
AddStatusToJson(nlohmann::json& json, int64_t code, const std::string& msg);
|
||||
|
@ -189,10 +192,10 @@ class WebRequestHandler {
|
|||
* Index
|
||||
*/
|
||||
StatusDto::ObjectWrapper
|
||||
CreateIndex(const OString& table_name, const IndexRequestDto::ObjectWrapper& index_param);
|
||||
CreateIndex(const OString& table_name, const OString& body);
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
GetIndex(const OString& table_name, IndexDto::ObjectWrapper& index_dto);
|
||||
GetIndex(const OString& table_name, OString& result);
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
DropIndex(const OString& table_name);
|
||||
|
|
|
@ -12,9 +12,12 @@
|
|||
#include "utils/StringHelpFunctions.h"
|
||||
|
||||
#include <fiu-local.h>
|
||||
#include <algorithm>
|
||||
#include <regex>
|
||||
#include <string>
|
||||
|
||||
#include "utils/ValidationUtil.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace server {
|
||||
|
||||
|
@ -148,11 +151,21 @@ StringHelpFunctions::IsRegexMatch(const std::string& target_str, const std::stri
|
|||
// regex match
|
||||
std::regex pattern(pattern_str);
|
||||
std::smatch results;
|
||||
if (std::regex_match(target_str, results, pattern)) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
return std::regex_match(target_str, results, pattern);
|
||||
}
|
||||
|
||||
// Convert a boolean-like string into a bool.
//
// The string is first checked with ValidationUtil::ValidateStringIsBool(); if that
// check fails its status is returned and `value` is left untouched. Otherwise the
// comparison is case-insensitive: "true", "on", "yes" and "1" yield true, and every
// other string that passed validation yields false.
Status
StringHelpFunctions::ConvertToBoolean(const std::string& str, bool& value) {
    auto status = ValidationUtil::ValidateStringIsBool(str);
    if (!status.ok()) {
        return status;
    }

    std::string s = str;
    // tolower() is only defined for values representable as unsigned char (or EOF);
    // feeding it a plain, possibly negative, char is undefined behaviour.
    std::transform(s.begin(), s.end(), s.begin(),
                   [](unsigned char c) { return static_cast<char>(::tolower(c)); });
    value = s == "true" || s == "on" || s == "yes" || s == "1";

    return Status::OK();
}
|
||||
|
||||
} // namespace server
|
||||
|
|
|
@ -64,6 +64,12 @@ class StringHelpFunctions {
|
|||
// regex grammar reference: http://www.cplusplus.com/reference/regex/ECMAScript/
|
||||
static bool
|
||||
IsRegexMatch(const std::string& target_str, const std::string& pattern);
|
||||
|
||||
// conversion rules refer to ValidationUtil::ValidateStringIsBool()
|
||||
// "true", "on", "yes", "1" ==> true
|
||||
// "false", "off", "no", "0", "" ==> false
|
||||
static Status
|
||||
ConvertToBoolean(const std::string& str, bool& value);
|
||||
};
|
||||
|
||||
} // namespace server
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "utils/ValidationUtil.h"
|
||||
#include "Log.h"
|
||||
#include "db/engine/ExecutionEngine.h"
|
||||
#include "index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
#include "utils/StringHelpFunctions.h"
|
||||
|
||||
#include <arpa/inet.h>
|
||||
|
@ -32,10 +33,62 @@
|
|||
namespace milvus {
|
||||
namespace server {
|
||||
|
||||
namespace {
|
||||
constexpr size_t TABLE_NAME_SIZE_LIMIT = 255;
|
||||
constexpr int64_t TABLE_DIMENSION_LIMIT = 32768;
|
||||
constexpr int32_t INDEX_FILE_SIZE_LIMIT = 4096; // index trigger size max = 4096 MB
|
||||
|
||||
// Verify that `param_name` is present in `json_params` and that its value, read as
// int64_t, lies between `min` and `max`.
//
// `min_close` / `max_closed` select whether the respective bound is inclusive (true)
// or exclusive (false). Returns SERVER_INVALID_ARGUMENT when the key is missing, when
// the value is out of range (this case is also logged), or when reading the value
// throws; Status::OK() otherwise.
Status
CheckParameterRange(const milvus::json& json_params, const std::string& param_name, int64_t min, int64_t max,
                    bool min_close = true, bool max_closed = true) {
    const bool missing = json_params.find(param_name) == json_params.end();
    if (missing) {
        return Status(SERVER_INVALID_ARGUMENT, std::string("Parameter list must contain: ") + param_name);
    }

    try {
        const int64_t value = json_params[param_name];
        const bool lower_ok = min_close ? value >= min : value > min;
        const bool upper_ok = max_closed ? value <= max : value < max;
        if (lower_ok && upper_ok) {
            return Status::OK();
        }

        // Build the message piecewise: "Invalid <name> value: <v>. Valid range is [min, max]".
        std::string msg = "Invalid " + param_name + " value: " + std::to_string(value) + ". Valid range is ";
        msg += min_close ? "[" : "(";
        msg += std::to_string(min) + ", " + std::to_string(max);
        msg += max_closed ? "]" : ")";
        SERVER_LOG_ERROR << msg;
        return Status(SERVER_INVALID_ARGUMENT, msg);
    } catch (std::exception& e) {
        // Raised e.g. when the stored JSON value cannot be converted to an integer.
        return Status(SERVER_INVALID_ARGUMENT, "Invalid " + param_name + ": " + e.what());
    }
}
|
||||
|
||||
// Verify that `param_name` is present in `json_params` and that its value, read as
// int64_t, is not negative.
//
// Returns SERVER_INVALID_ARGUMENT when the key is missing, when the value is negative
// (this case is also logged), or when reading the value throws; Status::OK() otherwise.
Status
CheckParameterExistence(const milvus::json& json_params, const std::string& param_name) {
    const bool missing = json_params.find(param_name) == json_params.end();
    if (missing) {
        return Status(SERVER_INVALID_ARGUMENT, std::string("Parameter list must contain: ") + param_name);
    }

    try {
        const int64_t value = json_params[param_name];
        if (value >= 0) {
            return Status::OK();
        }

        const std::string msg = "Invalid " + param_name + " value: " + std::to_string(value);
        SERVER_LOG_ERROR << msg;
        return Status(SERVER_INVALID_ARGUMENT, msg);
    } catch (std::exception& e) {
        // Raised e.g. when the stored JSON value cannot be converted to an integer.
        return Status(SERVER_INVALID_ARGUMENT, "Invalid " + param_name + ": " + e.what());
    }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status
|
||||
ValidationUtil::ValidateTableName(const std::string& table_name) {
|
||||
// Table name shouldn't be empty.
|
||||
|
@ -109,24 +162,114 @@ ValidationUtil::ValidateTableIndexType(int32_t index_type) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Validate the build parameters in `index_params` for the given `index_type`.
//
// Checks are run in order and the first failing status is returned; later checks are
// not evaluated after a failure. Index types without a case below are accepted
// without any check. `table_schema` is not consulted by this function.
Status
ValidationUtil::ValidateIndexParams(const milvus::json& index_params, const engine::meta::TableSchema& table_schema,
                                    int32_t index_type) {
    Status result = Status::OK();

    switch (index_type) {
        case static_cast<int32_t>(engine::EngineType::FAISS_IDMAP):
        case static_cast<int32_t>(engine::EngineType::FAISS_BIN_IDMAP):
            // No parameters are validated for these types.
            break;

        case static_cast<int32_t>(engine::EngineType::FAISS_IVFFLAT):
        case static_cast<int32_t>(engine::EngineType::FAISS_IVFSQ8):
        case static_cast<int32_t>(engine::EngineType::FAISS_IVFSQ8H):
        case static_cast<int32_t>(engine::EngineType::FAISS_BIN_IVFFLAT):
            // nlist in (0, 999999]
            result = CheckParameterRange(index_params, knowhere::IndexParams::nlist, 0, 999999, false);
            break;

        case static_cast<int32_t>(engine::EngineType::FAISS_PQ):
            // nlist in (0, 999999], and m must be present and non-negative
            result = CheckParameterRange(index_params, knowhere::IndexParams::nlist, 0, 999999, false);
            if (result.ok()) {
                result = CheckParameterExistence(index_params, knowhere::IndexParams::m);
            }
            break;

        case static_cast<int32_t>(engine::EngineType::NSG_MIX):
            result = CheckParameterRange(index_params, knowhere::IndexParams::search_length, 10, 300);
            if (result.ok()) {
                result = CheckParameterRange(index_params, knowhere::IndexParams::out_degree, 5, 300);
            }
            if (result.ok()) {
                result = CheckParameterRange(index_params, knowhere::IndexParams::candidate, 50, 1000);
            }
            if (result.ok()) {
                result = CheckParameterRange(index_params, knowhere::IndexParams::knng, 5, 300);
            }
            break;

        case static_cast<int32_t>(engine::EngineType::HNSW):
            result = CheckParameterRange(index_params, knowhere::IndexParams::M, 5, 48);
            if (result.ok()) {
                result = CheckParameterRange(index_params, knowhere::IndexParams::efConstruction, 100, 500);
            }
            break;
    }

    return result;
}
|
||||
|
||||
// Validate the search parameters in `search_params` against the engine type recorded
// in `table_schema`.
//
// Engine types without a case below are accepted without any check. For HNSW the
// lower bound of `ef` is the requested `topk`.
Status
ValidationUtil::ValidateSearchParams(const milvus::json& search_params, const engine::meta::TableSchema& table_schema,
                                     int64_t topk) {
    Status result = Status::OK();

    switch (table_schema.engine_type_) {
        case static_cast<int32_t>(engine::EngineType::FAISS_IDMAP):
        case static_cast<int32_t>(engine::EngineType::FAISS_BIN_IDMAP):
            // No parameters are validated for these types.
            break;

        case static_cast<int32_t>(engine::EngineType::FAISS_IVFFLAT):
        case static_cast<int32_t>(engine::EngineType::FAISS_IVFSQ8):
        case static_cast<int32_t>(engine::EngineType::FAISS_IVFSQ8H):
        case static_cast<int32_t>(engine::EngineType::FAISS_BIN_IVFFLAT):
        case static_cast<int32_t>(engine::EngineType::FAISS_PQ):
            // nprobe in [1, 999999]
            result = CheckParameterRange(search_params, knowhere::IndexParams::nprobe, 1, 999999);
            break;

        case static_cast<int32_t>(engine::EngineType::NSG_MIX):
            // search_length in [10, 300]
            result = CheckParameterRange(search_params, knowhere::IndexParams::search_length, 10, 300);
            break;

        case static_cast<int32_t>(engine::EngineType::HNSW):
            // ef in [topk, 1000]
            result = CheckParameterRange(search_params, knowhere::IndexParams::ef, topk, 1000);
            break;
    }

    return result;
}
|
||||
|
||||
bool
|
||||
ValidationUtil::IsBinaryIndexType(int32_t index_type) {
|
||||
return (index_type == static_cast<int32_t>(engine::EngineType::FAISS_BIN_IDMAP)) ||
|
||||
(index_type == static_cast<int32_t>(engine::EngineType::FAISS_BIN_IVFFLAT));
|
||||
}
|
||||
|
||||
// Check that an index `nlist` is strictly positive.
//
// Returns Status::OK() for nlist > 0; otherwise logs the problem and returns
// SERVER_INVALID_INDEX_NLIST with the same message.
Status
ValidationUtil::ValidateTableIndexNlist(int32_t nlist) {
    if (nlist > 0) {
        return Status::OK();
    }

    std::string msg = "Invalid index nlist: ";
    msg += std::to_string(nlist);
    msg += ". ";
    msg += "The index nlist must be greater than 0.";
    SERVER_LOG_ERROR << msg;
    return Status(SERVER_INVALID_INDEX_NLIST, msg);
}
|
||||
|
||||
Status
|
||||
ValidationUtil::ValidateTableIndexFileSize(int64_t index_file_size) {
|
||||
if (index_file_size <= 0 || index_file_size > INDEX_FILE_SIZE_LIMIT) {
|
||||
|
@ -170,18 +313,6 @@ ValidationUtil::ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchem
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
ValidationUtil::ValidateSearchNprobe(int64_t nprobe, const engine::meta::TableSchema& table_schema) {
|
||||
if (nprobe <= 0 || nprobe > table_schema.nlist_) {
|
||||
std::string msg = "Invalid nprobe: " + std::to_string(nprobe) + ". " +
|
||||
"The nprobe must be within the range of 1 ~ index nlist.";
|
||||
SERVER_LOG_ERROR << msg;
|
||||
return Status(SERVER_INVALID_NPROBE, msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
ValidationUtil::ValidatePartitionName(const std::string& partition_name) {
|
||||
if (partition_name.empty()) {
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "db/meta/MetaTypes.h"
|
||||
#include "utils/Json.h"
|
||||
#include "utils/Status.h"
|
||||
|
||||
#include <string>
|
||||
|
@ -34,11 +35,16 @@ class ValidationUtil {
|
|||
static Status
|
||||
ValidateTableIndexType(int32_t index_type);
|
||||
|
||||
static bool
|
||||
IsBinaryIndexType(int32_t index_type);
|
||||
static Status
|
||||
ValidateIndexParams(const milvus::json& index_params, const engine::meta::TableSchema& table_schema,
|
||||
int32_t index_type);
|
||||
|
||||
static Status
|
||||
ValidateTableIndexNlist(int32_t nlist);
|
||||
ValidateSearchParams(const milvus::json& search_params, const engine::meta::TableSchema& table_schema,
|
||||
int64_t topk);
|
||||
|
||||
static bool
|
||||
IsBinaryIndexType(int32_t index_type);
|
||||
|
||||
static Status
|
||||
ValidateTableIndexFileSize(int64_t index_file_size);
|
||||
|
@ -52,9 +58,6 @@ class ValidationUtil {
|
|||
static Status
|
||||
ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchema& table_schema);
|
||||
|
||||
static Status
|
||||
ValidateSearchNprobe(int64_t nprobe, const engine::meta::TableSchema& table_schema);
|
||||
|
||||
static Status
|
||||
ValidatePartitionName(const std::string& partition_name);
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ Status
|
|||
BinVecImpl::BuildAll(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
|
||||
const uint8_t* xt) {
|
||||
try {
|
||||
dim = cfg->d;
|
||||
dim = cfg[knowhere::meta::DIM];
|
||||
|
||||
auto ret_ds = std::make_shared<knowhere::Dataset>();
|
||||
ret_ds->Set(knowhere::meta::ROWS, nb);
|
||||
|
@ -47,15 +47,13 @@ BinVecImpl::BuildAll(const int64_t& nb, const uint8_t* xb, const int64_t* ids, c
|
|||
Status
|
||||
BinVecImpl::Search(const int64_t& nq, const uint8_t* xq, float* dist, int64_t* ids, const Config& cfg) {
|
||||
try {
|
||||
auto k = cfg->k;
|
||||
int64_t k = cfg[knowhere::meta::TOPK];
|
||||
auto ret_ds = std::make_shared<knowhere::Dataset>();
|
||||
ret_ds->Set(knowhere::meta::ROWS, nq);
|
||||
ret_ds->Set(knowhere::meta::DIM, dim);
|
||||
ret_ds->Set(knowhere::meta::TENSOR, xq);
|
||||
|
||||
Config search_cfg = cfg;
|
||||
|
||||
auto res = index_->Search(ret_ds, search_cfg);
|
||||
auto res = index_->Search(ret_ds, cfg);
|
||||
//{
|
||||
// auto& ids = ids_array;
|
||||
// auto& dists = dis_array;
|
||||
|
@ -150,9 +148,7 @@ BinVecImpl::GetVectorById(const int64_t n, const int64_t* ids, uint8_t* x, const
|
|||
ret_ds->Set(knowhere::meta::DIM, dim);
|
||||
ret_ds->Set(knowhere::meta::IDS, ids);
|
||||
|
||||
Config search_cfg = cfg;
|
||||
|
||||
auto res = index_->GetVectorById(ret_ds, search_cfg);
|
||||
auto res = index_->GetVectorById(ret_ds, cfg);
|
||||
|
||||
// TODO(linxj): avoid copy here.
|
||||
auto res_x = res->Get<uint8_t*>(knowhere::meta::TENSOR);
|
||||
|
@ -176,15 +172,13 @@ BinVecImpl::SearchById(const int64_t& nq, const int64_t* xq, float* dist, int64_
|
|||
throw WrapperException("not support");
|
||||
}
|
||||
try {
|
||||
auto k = cfg->k;
|
||||
int64_t k = cfg[knowhere::meta::TOPK];
|
||||
auto ret_ds = std::make_shared<knowhere::Dataset>();
|
||||
ret_ds->Set(knowhere::meta::ROWS, nq);
|
||||
ret_ds->Set(knowhere::meta::DIM, dim);
|
||||
ret_ds->Set(knowhere::meta::IDS, xq);
|
||||
|
||||
Config search_cfg = cfg;
|
||||
|
||||
auto res = index_->SearchById(ret_ds, search_cfg);
|
||||
auto res = index_->SearchById(ret_ds, cfg);
|
||||
//{
|
||||
// auto& ids = ids_array;
|
||||
// auto& dists = dis_array;
|
||||
|
@ -235,7 +229,7 @@ BinVecImpl::GetBlacklist(faiss::ConcurrentBitsetPtr& list) {
|
|||
ErrorCode
|
||||
BinBFIndex::Build(const Config& cfg) {
|
||||
try {
|
||||
dim = cfg->d;
|
||||
dim = cfg[knowhere::meta::DIM];
|
||||
std::static_pointer_cast<knowhere::BinaryIDMAP>(index_)->Train(cfg);
|
||||
} catch (knowhere::KnowhereException& e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
|
@ -251,7 +245,7 @@ Status
|
|||
BinBFIndex::BuildAll(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
|
||||
const uint8_t* xt) {
|
||||
try {
|
||||
dim = cfg->d;
|
||||
dim = cfg[knowhere::meta::DIM];
|
||||
auto ret_ds = std::make_shared<knowhere::Dataset>();
|
||||
ret_ds->Set(knowhere::meta::ROWS, nb);
|
||||
ret_ds->Set(knowhere::meta::DIM, dim);
|
||||
|
|
|
@ -8,12 +8,13 @@
|
|||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#include "wrapper/ConfAdapter.h"
|
||||
|
||||
#include <fiu-local.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "WrapperException.h"
|
||||
|
@ -21,8 +22,6 @@
|
|||
#include "server/Config.h"
|
||||
#include "utils/Log.h"
|
||||
|
||||
// TODO(lxj): add conf checker
|
||||
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
|
||||
|
@ -32,45 +31,64 @@ namespace engine {
|
|||
#define GPU_MAX_NRPOBE 1024
|
||||
#endif
|
||||
|
||||
void
|
||||
ConfAdapter::MatchBase(knowhere::Config conf, knowhere::METRICTYPE default_metric) {
|
||||
if (conf->metric_type == knowhere::DEFAULT_TYPE)
|
||||
conf->metric_type = default_metric;
|
||||
#define DEFAULT_MAX_DIM 16384
|
||||
#define DEFAULT_MIN_DIM 1
|
||||
#define DEFAULT_MAX_K 16384
|
||||
#define DEFAULT_MIN_K 1
|
||||
#define DEFAULT_MIN_ROWS 1 // minimum size for build index
|
||||
#define DEFAULT_MAX_ROWS 50000000
|
||||
|
||||
#define CheckIntByRange(key, min, max) \
|
||||
if (!oricfg.contains(key) || !oricfg[key].is_number_integer() || oricfg[key].get<int64_t>() > max || \
|
||||
oricfg[key].get<int64_t>() < min) { \
|
||||
return false; \
|
||||
}
|
||||
|
||||
// #define checkfloat(key, min, max) \
|
||||
// if (!oricfg.contains(key) || !oricfg[key].is_number_float() || oricfg[key] >= max || oricfg[key] <= min) { \
|
||||
// return false; \
|
||||
// }
|
||||
|
||||
#define CheckIntByValues(key, container) \
|
||||
if (!oricfg.contains(key) || !oricfg[key].is_number_integer()) { \
|
||||
return false; \
|
||||
} else { \
|
||||
auto finder = std::find(std::begin(container), std::end(container), oricfg[key].get<int64_t>()); \
|
||||
if (finder == std::end(container)) { \
|
||||
return false; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define CheckStrByValues(key, container) \
|
||||
if (!oricfg.contains(key) || !oricfg[key].is_string()) { \
|
||||
return false; \
|
||||
} else { \
|
||||
auto finder = std::find(std::begin(container), std::end(container), oricfg[key].get<std::string>()); \
|
||||
if (finder == std::end(container)) { \
|
||||
return false; \
|
||||
} \
|
||||
}
|
||||
|
||||
bool
|
||||
ConfAdapter::CheckTrain(milvus::json& oricfg) {
|
||||
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::IP};
|
||||
|
||||
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
|
||||
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
knowhere::Config
|
||||
ConfAdapter::Match(const TempMetaConf& metaconf) {
|
||||
auto conf = std::make_shared<knowhere::Cfg>();
|
||||
conf->d = metaconf.dim;
|
||||
conf->metric_type = metaconf.metric_type;
|
||||
conf->gpu_id = metaconf.gpu_id;
|
||||
conf->k = metaconf.k;
|
||||
MatchBase(conf);
|
||||
return conf;
|
||||
}
|
||||
bool
|
||||
ConfAdapter::CheckSearch(milvus::json& oricfg, const IndexType& type) {
|
||||
CheckIntByRange(knowhere::meta::TOPK, DEFAULT_MIN_K - 1, DEFAULT_MAX_K);
|
||||
|
||||
knowhere::Config
|
||||
ConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
|
||||
auto conf = std::make_shared<knowhere::Cfg>();
|
||||
conf->k = metaconf.k;
|
||||
return conf;
|
||||
return true;
|
||||
}
|
||||
|
||||
knowhere::Config
|
||||
IVFConfAdapter::Match(const TempMetaConf& metaconf) {
|
||||
auto conf = std::make_shared<knowhere::IVFCfg>();
|
||||
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist, 16384);
|
||||
conf->d = metaconf.dim;
|
||||
conf->metric_type = metaconf.metric_type;
|
||||
conf->gpu_id = metaconf.gpu_id;
|
||||
MatchBase(conf);
|
||||
return conf;
|
||||
}
|
||||
|
||||
static constexpr float TYPICAL_COUNT = 1000000.0;
|
||||
|
||||
int64_t
|
||||
IVFConfAdapter::MatchNlist(const int64_t& size, const int64_t& nlist, const int64_t& per_nlist) {
|
||||
MatchNlist(const int64_t& size, const int64_t& nlist, const int64_t& per_nlist) {
|
||||
static float TYPICAL_COUNT = 1000000.0;
|
||||
if (size <= TYPICAL_COUNT / per_nlist + 1) {
|
||||
// handle less row count, avoid nlist set to 0
|
||||
return 1;
|
||||
|
@ -81,62 +99,88 @@ IVFConfAdapter::MatchNlist(const int64_t& size, const int64_t& nlist, const int6
|
|||
return nlist;
|
||||
}
|
||||
|
||||
knowhere::Config
|
||||
IVFConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
|
||||
auto conf = std::make_shared<knowhere::IVFCfg>();
|
||||
conf->k = metaconf.k;
|
||||
bool
|
||||
IVFConfAdapter::CheckTrain(milvus::json& oricfg) {
|
||||
static int64_t MAX_NLIST = 999999;
|
||||
static int64_t MIN_NLIST = 1;
|
||||
|
||||
if (metaconf.nprobe <= 0)
|
||||
conf->nprobe = 16; // hardcode here
|
||||
else
|
||||
conf->nprobe = metaconf.nprobe;
|
||||
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
|
||||
switch (type) {
|
||||
case IndexType::FAISS_IVFFLAT_GPU:
|
||||
case IndexType::FAISS_IVFSQ8_GPU:
|
||||
case IndexType::FAISS_IVFPQ_GPU:
|
||||
if (conf->nprobe > GPU_MAX_NRPOBE) {
|
||||
WRAPPER_LOG_WARNING << "When search with GPU, nprobe shoud be no more than " << GPU_MAX_NRPOBE
|
||||
<< ", but you passed " << conf->nprobe << ". Search with " << GPU_MAX_NRPOBE
|
||||
<< " instead";
|
||||
conf->nprobe = GPU_MAX_NRPOBE;
|
||||
}
|
||||
// int64_t nlist = oricfg[knowhere::IndexParams::nlist];
|
||||
// CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);
|
||||
|
||||
// auto tune params
|
||||
oricfg[knowhere::IndexParams::nlist] = MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(),
|
||||
oricfg[knowhere::IndexParams::nlist].get<int64_t>(), 16384);
|
||||
|
||||
// Best Practice
|
||||
// static int64_t MIN_POINTS_PER_CENTROID = 40;
|
||||
// static int64_t MAX_POINTS_PER_CENTROID = 256;
|
||||
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
|
||||
|
||||
return ConfAdapter::CheckTrain(oricfg);
|
||||
}
|
||||
|
||||
bool
|
||||
IVFConfAdapter::CheckSearch(milvus::json& oricfg, const IndexType& type) {
|
||||
static int64_t MIN_NPROBE = 1;
|
||||
static int64_t MAX_NPROBE = 999999; // todo(linxj): [1, nlist]
|
||||
|
||||
if (type == IndexType::FAISS_IVFPQ_GPU || type == IndexType::FAISS_IVFSQ8_GPU ||
|
||||
type == IndexType::FAISS_IVFSQ8_HYBRID || type == IndexType::FAISS_IVFFLAT_GPU) {
|
||||
CheckIntByRange(knowhere::IndexParams::nprobe, MIN_NPROBE, GPU_MAX_NRPOBE);
|
||||
} else {
|
||||
CheckIntByRange(knowhere::IndexParams::nprobe, MIN_NPROBE, MAX_NPROBE);
|
||||
}
|
||||
return conf;
|
||||
|
||||
return ConfAdapter::CheckSearch(oricfg, type);
|
||||
}
|
||||
|
||||
knowhere::Config
|
||||
IVFSQConfAdapter::Match(const TempMetaConf& metaconf) {
|
||||
auto conf = std::make_shared<knowhere::IVFSQCfg>();
|
||||
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist, 16384);
|
||||
conf->d = metaconf.dim;
|
||||
conf->metric_type = metaconf.metric_type;
|
||||
conf->gpu_id = metaconf.gpu_id;
|
||||
conf->nbits = 8;
|
||||
MatchBase(conf);
|
||||
return conf;
|
||||
// Train-parameter check for the IVFSQ adapter.
// Unconditionally writes nbits = 8 into `oricfg` (any caller-supplied nbits value is
// overwritten), then returns the result of IVFConfAdapter::CheckTrain on the same object.
bool
|
||||
IVFSQConfAdapter::CheckTrain(milvus::json& oricfg) {
|
||||
static int64_t DEFAULT_NBITS = 8;
|
||||
// Mutates the caller's config: nbits is not user-tunable through this path.
oricfg[knowhere::IndexParams::nbits] = DEFAULT_NBITS;
|
||||
|
||||
return IVFConfAdapter::CheckTrain(oricfg);
|
||||
}
|
||||
|
||||
knowhere::Config
|
||||
IVFPQConfAdapter::Match(const TempMetaConf& metaconf) {
|
||||
auto conf = std::make_shared<knowhere::IVFPQCfg>();
|
||||
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist);
|
||||
conf->d = metaconf.dim;
|
||||
conf->metric_type = metaconf.metric_type;
|
||||
conf->gpu_id = metaconf.gpu_id;
|
||||
conf->nbits = 8;
|
||||
MatchBase(conf);
|
||||
bool
|
||||
IVFPQConfAdapter::CheckTrain(milvus::json& oricfg) {
|
||||
static int64_t DEFAULT_NBITS = 8;
|
||||
static int64_t MAX_NLIST = 999999;
|
||||
static int64_t MIN_NLIST = 1;
|
||||
static std::vector<std::string> CPU_METRICS{knowhere::Metric::L2, knowhere::Metric::IP};
|
||||
static std::vector<std::string> GPU_METRICS{knowhere::Metric::L2};
|
||||
|
||||
oricfg[knowhere::IndexParams::nbits] = DEFAULT_NBITS;
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
Status s;
|
||||
bool enable_gpu = false;
|
||||
server::Config& config = server::Config::GetInstance();
|
||||
s = config.GetGpuResourceConfigEnable(enable_gpu);
|
||||
if (s.ok() && conf->metric_type == knowhere::METRICTYPE::IP) {
|
||||
WRAPPER_LOG_ERROR << "PQ not support IP in GPU version!";
|
||||
throw WrapperException("PQ not support IP in GPU version!");
|
||||
if (s.ok()) {
|
||||
CheckStrByValues(knowhere::Metric::TYPE, GPU_METRICS);
|
||||
} else {
|
||||
CheckStrByValues(knowhere::Metric::TYPE, CPU_METRICS);
|
||||
}
|
||||
#endif
|
||||
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
|
||||
|
||||
// int64_t nlist = oricfg[knowhere::IndexParams::nlist];
|
||||
// CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);
|
||||
|
||||
// auto tune params
|
||||
oricfg[knowhere::IndexParams::nlist] = MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(),
|
||||
oricfg[knowhere::IndexParams::nlist].get<int64_t>(), 16384);
|
||||
|
||||
// Best Practice
|
||||
// static int64_t MIN_POINTS_PER_CENTROID = 40;
|
||||
// static int64_t MAX_POINTS_PER_CENTROID = 256;
|
||||
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
|
||||
|
||||
/*
|
||||
* Faiss 1.6
|
||||
|
@ -147,165 +191,110 @@ IVFPQConfAdapter::Match(const TempMetaConf& metaconf) {
|
|||
static std::vector<int64_t> support_subquantizer{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1};
|
||||
std::vector<int64_t> resset;
|
||||
for (const auto& dimperquantizer : support_dim_per_subquantizer) {
|
||||
if (!(conf->d % dimperquantizer)) {
|
||||
auto subquantzier_num = conf->d / dimperquantizer;
|
||||
if (!(oricfg[knowhere::meta::DIM].get<int64_t>() % dimperquantizer)) {
|
||||
auto subquantzier_num = oricfg[knowhere::meta::DIM].get<int64_t>() / dimperquantizer;
|
||||
auto finder = std::find(support_subquantizer.begin(), support_subquantizer.end(), subquantzier_num);
|
||||
if (finder != support_subquantizer.end()) {
|
||||
resset.push_back(subquantzier_num);
|
||||
}
|
||||
}
|
||||
}
|
||||
fiu_do_on("IVFPQConfAdapter.Match.empty_resset", resset.clear());
|
||||
if (resset.empty()) {
|
||||
// todo(linxj): throw exception here.
|
||||
WRAPPER_LOG_ERROR << "The dims of PQ is wrong : only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims per sub-"
|
||||
"quantizer are currently supported with no precomputed codes.";
|
||||
throw WrapperException(
|
||||
"The dims of PQ is wrong : only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims "
|
||||
"per sub-quantizer are currently supported with no precomputed codes.");
|
||||
// return nullptr;
|
||||
}
|
||||
static int64_t compression_level = 1; // 1:low, 2:high
|
||||
if (compression_level == 1) {
|
||||
conf->m = resset[int(resset.size() / 2)];
|
||||
WRAPPER_LOG_DEBUG << "PQ m = " << conf->m << ", compression radio = " << conf->d / conf->m * 4;
|
||||
}
|
||||
return conf;
|
||||
CheckIntByValues(knowhere::IndexParams::m, resset);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
knowhere::Config
|
||||
IVFPQConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
|
||||
auto conf = std::make_shared<knowhere::IVFPQCfg>();
|
||||
conf->k = metaconf.k;
|
||||
bool
|
||||
NSGConfAdapter::CheckTrain(milvus::json& oricfg) {
|
||||
static int64_t MIN_KNNG = 5;
|
||||
static int64_t MAX_KNNG = 300;
|
||||
static int64_t MIN_SEARCH_LENGTH = 10;
|
||||
static int64_t MAX_SEARCH_LENGTH = 300;
|
||||
static int64_t MIN_OUT_DEGREE = 5;
|
||||
static int64_t MAX_OUT_DEGREE = 300;
|
||||
static int64_t MIN_CANDIDATE_POOL_SIZE = 50;
|
||||
static int64_t MAX_CANDIDATE_POOL_SIZE = 1000;
|
||||
static std::vector<std::string> METRICS{knowhere::Metric::L2};
|
||||
|
||||
if (metaconf.nprobe <= 0) {
|
||||
WRAPPER_LOG_ERROR << "The nprobe of PQ is wrong!";
|
||||
throw WrapperException("The nprobe of PQ is wrong!");
|
||||
} else {
|
||||
conf->nprobe = metaconf.nprobe;
|
||||
}
|
||||
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::IndexParams::knng, MIN_KNNG, MAX_KNNG);
|
||||
CheckIntByRange(knowhere::IndexParams::search_length, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH);
|
||||
CheckIntByRange(knowhere::IndexParams::out_degree, MIN_OUT_DEGREE, MAX_OUT_DEGREE);
|
||||
CheckIntByRange(knowhere::IndexParams::candidate, MIN_CANDIDATE_POOL_SIZE, MAX_CANDIDATE_POOL_SIZE);
|
||||
|
||||
return conf;
|
||||
// auto tune params
|
||||
oricfg[knowhere::IndexParams::nlist] = MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(), 8192, 8192);
|
||||
oricfg[knowhere::IndexParams::nprobe] = int(oricfg[knowhere::IndexParams::nlist].get<int64_t>() * 0.01);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int64_t
|
||||
IVFPQConfAdapter::MatchNlist(const int64_t& size, const int64_t& nlist) {
|
||||
if (size <= TYPICAL_COUNT / 16384 + 1) {
|
||||
// handle less row count, avoid nlist set to 0
|
||||
return 1;
|
||||
} else if (int(size / TYPICAL_COUNT) * nlist <= 0) {
|
||||
// calculate a proper nlist if nlist not specified or size less than TYPICAL_COUNT
|
||||
return int(size / TYPICAL_COUNT * 16384);
|
||||
}
|
||||
return nlist;
|
||||
// Search-parameter check for the NSG adapter.
// The search_length entry of `oricfg` is checked with CheckIntByRange using the bounds
// 1 and 300 (the macro returns false from this function on failure); afterwards the
// result of the base ConfAdapter::CheckSearch is returned.
// NOTE(review): bound inclusivity is defined by CheckIntByRange, not visible in full here.
bool
|
||||
NSGConfAdapter::CheckSearch(milvus::json& oricfg, const IndexType& type) {
|
||||
static int64_t MIN_SEARCH_LENGTH = 1;
|
||||
static int64_t MAX_SEARCH_LENGTH = 300;
|
||||
|
||||
CheckIntByRange(knowhere::IndexParams::search_length, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH);
|
||||
|
||||
return ConfAdapter::CheckSearch(oricfg, type);
|
||||
}
|
||||
|
||||
knowhere::Config
|
||||
NSGConfAdapter::Match(const TempMetaConf& metaconf) {
|
||||
auto conf = std::make_shared<knowhere::NSGCfg>();
|
||||
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist, 16384);
|
||||
conf->d = metaconf.dim;
|
||||
conf->metric_type = metaconf.metric_type;
|
||||
conf->gpu_id = metaconf.gpu_id;
|
||||
conf->k = metaconf.k;
|
||||
bool
|
||||
HNSWConfAdapter::CheckTrain(milvus::json& oricfg) {
|
||||
static int64_t MIN_EFCONSTRUCTION = 100;
|
||||
static int64_t MAX_EFCONSTRUCTION = 500;
|
||||
static int64_t MIN_M = 5;
|
||||
static int64_t MAX_M = 48;
|
||||
|
||||
auto scale_factor = round(metaconf.dim / 128.0);
|
||||
scale_factor = scale_factor >= 4 ? 4 : scale_factor;
|
||||
conf->nprobe = int64_t(conf->nlist * 0.01);
|
||||
// conf->knng = 40 + 10 * scale_factor; // the size of knng
|
||||
conf->knng = 50;
|
||||
conf->search_length = 50 + 5 * scale_factor;
|
||||
conf->out_degree = 50 + 5 * scale_factor;
|
||||
conf->candidate_pool_size = 300;
|
||||
MatchBase(conf);
|
||||
return conf;
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
|
||||
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
|
||||
|
||||
return ConfAdapter::CheckTrain(oricfg);
|
||||
}
|
||||
|
||||
knowhere::Config
|
||||
NSGConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
|
||||
auto conf = std::make_shared<knowhere::NSGCfg>();
|
||||
conf->k = metaconf.k;
|
||||
conf->search_length = metaconf.search_length;
|
||||
if (metaconf.search_length == TEMPMETA_DEFAULT_VALUE) {
|
||||
conf->search_length = 30; // TODO(linxj): hardcode here.
|
||||
}
|
||||
return conf;
|
||||
// Search-parameter check for the HNSW adapter.
// The `ef` entry is checked with CheckIntByRange, using the TOPK value stored in the same
// `oricfg` as the lower bound and MAX_EF (4096) as the upper bound. The result of the
// base ConfAdapter::CheckSearch is returned afterwards.
// NOTE(review): oricfg[TOPK] is read here before the base CheckSearch has run; confirm
// that callers guarantee TOPK is present and numeric at this point.
bool
|
||||
HNSWConfAdapter::CheckSearch(milvus::json& oricfg, const IndexType& type) {
|
||||
static int64_t MAX_EF = 4096;
|
||||
|
||||
// Lower bound is taken from the request itself (its TOPK value), not a constant.
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
|
||||
|
||||
return ConfAdapter::CheckSearch(oricfg, type);
|
||||
}
|
||||
|
||||
knowhere::Config
|
||||
SPTAGKDTConfAdapter::Match(const TempMetaConf& metaconf) {
|
||||
auto conf = std::make_shared<knowhere::KDTCfg>();
|
||||
conf->d = metaconf.dim;
|
||||
conf->metric_type = metaconf.metric_type;
|
||||
return conf;
|
||||
// Train-parameter check for the binary IDMAP adapter.
// The dimension entry is checked with CheckIntByRange against DEFAULT_MIN_DIM and
// DEFAULT_MAX_DIM, and the metric type must be a string equal to HAMMING, JACCARD or
// TANIMOTO. Each macro returns false from this function on failure; otherwise the
// function returns true. The base ConfAdapter::CheckTrain is not called here.
bool
|
||||
BinIDMAPConfAdapter::CheckTrain(milvus::json& oricfg) {
|
||||
// Metric types accepted by this check.
static std::vector<std::string> METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD,
|
||||
knowhere::Metric::TANIMOTO};
|
||||
|
||||
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
|
||||
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
knowhere::Config
|
||||
SPTAGKDTConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
|
||||
auto conf = std::make_shared<knowhere::KDTCfg>();
|
||||
conf->k = metaconf.k;
|
||||
return conf;
|
||||
}
|
||||
bool
|
||||
BinIVFConfAdapter::CheckTrain(milvus::json& oricfg) {
|
||||
static std::vector<std::string> METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD,
|
||||
knowhere::Metric::TANIMOTO};
|
||||
static int64_t MAX_NLIST = 999999;
|
||||
static int64_t MIN_NLIST = 1;
|
||||
|
||||
knowhere::Config
|
||||
SPTAGBKTConfAdapter::Match(const TempMetaConf& metaconf) {
|
||||
auto conf = std::make_shared<knowhere::BKTCfg>();
|
||||
conf->d = metaconf.dim;
|
||||
conf->metric_type = metaconf.metric_type;
|
||||
return conf;
|
||||
}
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
|
||||
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
|
||||
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
|
||||
|
||||
knowhere::Config
|
||||
SPTAGBKTConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
|
||||
auto conf = std::make_shared<knowhere::BKTCfg>();
|
||||
conf->k = metaconf.k;
|
||||
return conf;
|
||||
}
|
||||
int64_t nlist = oricfg[knowhere::IndexParams::nlist];
|
||||
CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);
|
||||
|
||||
knowhere::Config
|
||||
HNSWConfAdapter::Match(const TempMetaConf& metaconf) {
|
||||
auto conf = std::make_shared<knowhere::HNSWCfg>();
|
||||
conf->d = metaconf.dim;
|
||||
conf->metric_type = metaconf.metric_type;
|
||||
// Best Practice
|
||||
// static int64_t MIN_POINTS_PER_CENTROID = 40;
|
||||
// static int64_t MAX_POINTS_PER_CENTROID = 256;
|
||||
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
|
||||
|
||||
conf->ef = 500; // ef can be auto-configured by using sample data.
|
||||
conf->M = 24; // A reasonable range of M is from 5 to 48.
|
||||
return conf;
|
||||
}
|
||||
|
||||
// Builds the HNSW search configuration from `metaconf`.
// k is copied as-is. ef is taken from metaconf.nprobe, unless nprobe is smaller than k,
// in which case ef is set to k + 32 instead. The `type` argument is not used in this body.
knowhere::Config
|
||||
HNSWConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
|
||||
auto conf = std::make_shared<knowhere::HNSWCfg>();
|
||||
conf->k = metaconf.k;
|
||||
|
||||
// nprobe doubles as the requested ef; fall back to k + 32 when it is below k.
if (metaconf.nprobe < metaconf.k) {
|
||||
conf->ef = metaconf.k + 32;
|
||||
} else {
|
||||
conf->ef = metaconf.nprobe;
|
||||
}
|
||||
return conf;
|
||||
}
|
||||
|
||||
knowhere::Config
|
||||
BinIDMAPConfAdapter::Match(const TempMetaConf& metaconf) {
|
||||
auto conf = std::make_shared<knowhere::BinIDMAPCfg>();
|
||||
conf->d = metaconf.dim;
|
||||
conf->metric_type = metaconf.metric_type;
|
||||
conf->gpu_id = metaconf.gpu_id;
|
||||
conf->k = metaconf.k;
|
||||
MatchBase(conf, knowhere::METRICTYPE::HAMMING);
|
||||
return conf;
|
||||
}
|
||||
|
||||
knowhere::Config
|
||||
BinIVFConfAdapter::Match(const TempMetaConf& metaconf) {
|
||||
auto conf = std::make_shared<knowhere::IVFBinCfg>();
|
||||
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist, 2048);
|
||||
conf->d = metaconf.dim;
|
||||
conf->metric_type = metaconf.metric_type;
|
||||
conf->gpu_id = metaconf.gpu_id;
|
||||
MatchBase(conf, knowhere::METRICTYPE::HAMMING);
|
||||
return conf;
|
||||
return true;
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace milvus
|
||||
|
|
|
@ -14,117 +14,78 @@
|
|||
#include <memory>
|
||||
|
||||
#include "VecIndex.h"
|
||||
#include "knowhere/common/Config.h"
|
||||
#include "utils/Json.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
|
||||
// TODO(linxj): remove later, replace with real metaconf
|
||||
constexpr int64_t TEMPMETA_DEFAULT_VALUE = -1;
|
||||
// Plain aggregate of index/search settings passed to the adapters' Match/MatchSearch
// functions. Every integer field defaults to TEMPMETA_DEFAULT_VALUE; metric_type
// defaults to knowhere::DEFAULT_TYPE.
struct TempMetaConf {
|
||||
int64_t size = TEMPMETA_DEFAULT_VALUE;
|
||||
int64_t nlist = TEMPMETA_DEFAULT_VALUE;
|
||||
int64_t dim = TEMPMETA_DEFAULT_VALUE;
|
||||
int64_t gpu_id = TEMPMETA_DEFAULT_VALUE;
|
||||
int64_t k = TEMPMETA_DEFAULT_VALUE;
|
||||
int64_t nprobe = TEMPMETA_DEFAULT_VALUE;
|
||||
int64_t search_length = TEMPMETA_DEFAULT_VALUE;
|
||||
knowhere::METRICTYPE metric_type = knowhere::DEFAULT_TYPE;
|
||||
};
|
||||
|
||||
class ConfAdapter {
|
||||
public:
|
||||
virtual knowhere::Config
|
||||
Match(const TempMetaConf& metaconf);
|
||||
virtual bool
|
||||
CheckTrain(milvus::json& oricfg);
|
||||
|
||||
virtual knowhere::Config
|
||||
MatchSearch(const TempMetaConf& metaconf, const IndexType& type);
|
||||
virtual bool
|
||||
CheckSearch(milvus::json& oricfg, const IndexType& type);
|
||||
|
||||
protected:
|
||||
static void
|
||||
MatchBase(knowhere::Config conf, knowhere::METRICTYPE defalut_metric = knowhere::METRICTYPE::L2);
|
||||
// todo(linxj): refactor in next release.
|
||||
//
|
||||
// virtual bool
|
||||
// CheckTrain(milvus::json&, IndexMode&) = 0;
|
||||
//
|
||||
// virtual bool
|
||||
// CheckSearch(milvus::json&, const IndexType&, IndexMode&) = 0;
|
||||
};
|
||||
|
||||
using ConfAdapterPtr = std::shared_ptr<ConfAdapter>;
|
||||
|
||||
class IVFConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
knowhere::Config
|
||||
Match(const TempMetaConf& metaconf) override;
|
||||
bool
|
||||
CheckTrain(milvus::json& oricfg) override;
|
||||
|
||||
knowhere::Config
|
||||
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) override;
|
||||
|
||||
protected:
|
||||
static int64_t
|
||||
MatchNlist(const int64_t& size, const int64_t& nlist, const int64_t& per_nlist);
|
||||
bool
|
||||
CheckSearch(milvus::json& oricfg, const IndexType& type) override;
|
||||
};
|
||||
|
||||
// Adapter for IVFSQ index configuration. Derives from IVFConfAdapter and declares
// overrides of Match (TempMetaConf -> knowhere::Config) and CheckTrain (check of a
// milvus::json parameter object, returning bool).
class IVFSQConfAdapter : public IVFConfAdapter {
|
||||
public:
|
||||
knowhere::Config
|
||||
Match(const TempMetaConf& metaconf) override;
|
||||
bool
|
||||
CheckTrain(milvus::json& oricfg) override;
|
||||
};
|
||||
|
||||
class IVFPQConfAdapter : public IVFConfAdapter {
|
||||
public:
|
||||
knowhere::Config
|
||||
Match(const TempMetaConf& metaconf) override;
|
||||
|
||||
knowhere::Config
|
||||
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) override;
|
||||
|
||||
protected:
|
||||
static int64_t
|
||||
MatchNlist(const int64_t& size, const int64_t& nlist);
|
||||
bool
|
||||
CheckTrain(milvus::json& oricfg) override;
|
||||
};
|
||||
|
||||
class NSGConfAdapter : public IVFConfAdapter {
|
||||
public:
|
||||
knowhere::Config
|
||||
Match(const TempMetaConf& metaconf) override;
|
||||
bool
|
||||
CheckTrain(milvus::json& oricfg) override;
|
||||
|
||||
knowhere::Config
|
||||
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) final;
|
||||
};
|
||||
|
||||
class SPTAGKDTConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
knowhere::Config
|
||||
Match(const TempMetaConf& metaconf) override;
|
||||
|
||||
knowhere::Config
|
||||
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) override;
|
||||
};
|
||||
|
||||
class SPTAGBKTConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
knowhere::Config
|
||||
Match(const TempMetaConf& metaconf) override;
|
||||
|
||||
knowhere::Config
|
||||
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) override;
|
||||
bool
|
||||
CheckSearch(milvus::json& oricfg, const IndexType& type) override;
|
||||
};
|
||||
|
||||
class BinIDMAPConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
knowhere::Config
|
||||
Match(const TempMetaConf& metaconf) override;
|
||||
bool
|
||||
CheckTrain(milvus::json& oricfg) override;
|
||||
};
|
||||
|
||||
class BinIVFConfAdapter : public IVFConfAdapter {
|
||||
public:
|
||||
knowhere::Config
|
||||
Match(const TempMetaConf& metaconf) override;
|
||||
bool
|
||||
CheckTrain(milvus::json& oricfg) override;
|
||||
};
|
||||
|
||||
// Adapter for HNSW index configuration. Derives directly from ConfAdapter and declares
// overrides of four virtuals: Match and MatchSearch (TempMetaConf -> knowhere::Config)
// and CheckTrain and CheckSearch (checks of a milvus::json parameter object, returning bool).
class HNSWConfAdapter : public ConfAdapter {
|
||||
public:
|
||||
knowhere::Config
|
||||
Match(const TempMetaConf& metaconf) override;
|
||||
bool
|
||||
CheckTrain(milvus::json& oricfg) override;
|
||||
|
||||
knowhere::Config
|
||||
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) override;
|
||||
bool
|
||||
CheckSearch(milvus::json& oricfg, const IndexType& type) override;
|
||||
};
|
||||
|
||||
} // namespace engine
|
||||
|
|
|
@ -54,8 +54,8 @@ AdapterMgr::RegisterAdapter() {
|
|||
|
||||
REGISTER_CONF_ADAPTER(NSGConfAdapter, IndexType::NSG_MIX, nsg_mix);
|
||||
|
||||
REGISTER_CONF_ADAPTER(SPTAGKDTConfAdapter, IndexType::SPTAG_KDT_RNT_CPU, sptag_kdt);
|
||||
REGISTER_CONF_ADAPTER(SPTAGBKTConfAdapter, IndexType::SPTAG_BKT_RNT_CPU, sptag_bkt);
|
||||
REGISTER_CONF_ADAPTER(ConfAdapter, IndexType::SPTAG_KDT_RNT_CPU, sptag_kdt);
|
||||
REGISTER_CONF_ADAPTER(ConfAdapter, IndexType::SPTAG_BKT_RNT_CPU, sptag_bkt);
|
||||
|
||||
REGISTER_CONF_ADAPTER(HNSWConfAdapter, IndexType::HNSW, hnsw);
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ Status
|
|||
VecIndexImpl::BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
|
||||
const float* xt) {
|
||||
try {
|
||||
dim = cfg->d;
|
||||
dim = cfg[knowhere::meta::DIM];
|
||||
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
|
||||
fiu_do_on("VecIndexImpl.BuildAll.throw_knowhere_exception", throw knowhere::KnowhereException(""));
|
||||
fiu_do_on("VecIndexImpl.BuildAll.throw_std_exception", throw std::exception());
|
||||
|
@ -80,15 +80,13 @@ VecIndexImpl::Add(const int64_t& nb, const float* xb, const int64_t* ids, const
|
|||
Status
|
||||
VecIndexImpl::Search(const int64_t& nq, const float* xq, float* dist, int64_t* ids, const Config& cfg) {
|
||||
try {
|
||||
auto k = cfg->k;
|
||||
int64_t k = cfg[knowhere::meta::TOPK];
|
||||
auto dataset = GenDataset(nq, dim, xq);
|
||||
|
||||
Config search_cfg = cfg;
|
||||
|
||||
fiu_do_on("VecIndexImpl.Search.throw_knowhere_exception", throw knowhere::KnowhereException(""));
|
||||
fiu_do_on("VecIndexImpl.Search.throw_std_exception", throw std::exception());
|
||||
|
||||
auto res = index_->Search(dataset, search_cfg);
|
||||
auto res = index_->Search(dataset, cfg);
|
||||
//{
|
||||
// auto& ids = ids_array;
|
||||
// auto& dists = dis_array;
|
||||
|
@ -216,8 +214,7 @@ VecIndexImpl::GetVectorById(const int64_t n, const int64_t* xid, float* x, const
|
|||
dataset->Set(knowhere::meta::DIM, dim);
|
||||
dataset->Set(knowhere::meta::IDS, xid);
|
||||
|
||||
Config search_cfg = cfg;
|
||||
auto res = index_->GetVectorById(dataset, search_cfg);
|
||||
auto res = index_->GetVectorById(dataset, cfg);
|
||||
|
||||
// TODO(linxj): avoid copy here.
|
||||
auto res_x = res->Get<float*>(knowhere::meta::TENSOR);
|
||||
|
@ -242,14 +239,13 @@ VecIndexImpl::SearchById(const int64_t& nq, const int64_t* xq, float* dist, int6
|
|||
}
|
||||
|
||||
try {
|
||||
auto k = cfg->k;
|
||||
int64_t k = cfg[knowhere::meta::TOPK];
|
||||
auto dataset = std::make_shared<knowhere::Dataset>();
|
||||
dataset->Set(knowhere::meta::ROWS, nq);
|
||||
dataset->Set(knowhere::meta::DIM, dim);
|
||||
dataset->Set(knowhere::meta::IDS, xq);
|
||||
|
||||
Config search_cfg = cfg;
|
||||
auto res = index_->SearchById(dataset, search_cfg);
|
||||
auto res = index_->SearchById(dataset, cfg);
|
||||
//{
|
||||
// auto& ids = ids_array;
|
||||
// auto& dists = dis_array;
|
||||
|
@ -337,7 +333,7 @@ BFIndex::Build(const Config& cfg) {
|
|||
try {
|
||||
fiu_do_on("BFIndex.Build.throw_knowhere_exception", throw knowhere::KnowhereException(""));
|
||||
fiu_do_on("BFIndex.Build.throw_std_exception", throw std::exception());
|
||||
dim = cfg->d;
|
||||
dim = cfg[knowhere::meta::DIM];
|
||||
std::static_pointer_cast<knowhere::IDMAP>(index_)->Train(cfg);
|
||||
} catch (knowhere::KnowhereException& e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
|
@ -353,7 +349,7 @@ Status
|
|||
BFIndex::BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
|
||||
const float* xt) {
|
||||
try {
|
||||
dim = cfg->d;
|
||||
dim = cfg[knowhere::meta::DIM];
|
||||
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
|
||||
fiu_do_on("BFIndex.BuildAll.throw_knowhere_exception", throw knowhere::KnowhereException(""));
|
||||
fiu_do_on("BFIndex.BuildAll.throw_std_exception", throw std::exception());
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <faiss/utils/ConcurrentBitset.h>
|
||||
#include <thirdparty/nlohmann/json.hpp>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -29,7 +30,8 @@
|
|||
namespace milvus {
|
||||
namespace engine {
|
||||
|
||||
using Config = knowhere::Config;
|
||||
using json = nlohmann::json;
|
||||
using Config = json;
|
||||
|
||||
// TODO(linxj): replace with string, Do refactor serialization
|
||||
enum class IndexType {
|
||||
|
|
|
@ -38,7 +38,7 @@ IVFMixIndex::BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, co
|
|||
fiu_do_on("IVFMixIndex.BuildAll.throw_knowhere_exception", throw knowhere::KnowhereException(""));
|
||||
fiu_do_on("IVFMixIndex.BuildAll.throw_std_exception", throw std::exception());
|
||||
|
||||
dim = cfg->d;
|
||||
dim = cfg[knowhere::meta::DIM];
|
||||
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
|
||||
auto preprocessor = index_->BuildPreprocessor(dataset, cfg);
|
||||
index_->set_preprocessor(preprocessor);
|
||||
|
|
|
@ -186,16 +186,20 @@ TEST_F(DBTest, DB_TEST) {
|
|||
std::stringstream ss;
|
||||
uint64_t count = 0;
|
||||
uint64_t prev_count = 0;
|
||||
milvus::json json_params = {{"nprobe", 10}};
|
||||
|
||||
for (auto j = 0; j < 10; ++j) {
|
||||
ss.str("");
|
||||
db_->Size(count);
|
||||
prev_count = count;
|
||||
if (count == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
START_TIMER;
|
||||
|
||||
std::vector<std::string> tags;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, qxb, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, qxb, result_ids, result_distances);
|
||||
ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
|
||||
STOP_TIMER(ss.str());
|
||||
|
||||
|
@ -306,37 +310,41 @@ TEST_F(DBTest, SEARCH_TEST) {
|
|||
stat = db_->InsertVectors(TABLE_NAME, "", xb);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
|
||||
milvus::json json_params = {{"nprobe", 10}};
|
||||
milvus::engine::TableIndex index;
|
||||
index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IDMAP;
|
||||
db_->CreateIndex(TABLE_NAME, index); // wait until build index finish
|
||||
|
||||
{
|
||||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
|
||||
index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IVFFLAT;
|
||||
db_->CreateIndex(TABLE_NAME, index); // wait until build index finish
|
||||
|
||||
{
|
||||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
index.extra_params_ = {{"nlist", 16384}};
|
||||
// db_->CreateIndex(TABLE_NAME, index); // wait until build index finish
|
||||
//
|
||||
// {
|
||||
// std::vector<std::string> tags;
|
||||
// milvus::engine::ResultIds result_ids;
|
||||
// milvus::engine::ResultDistances result_distances;
|
||||
// stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, xq, result_ids, result_distances);
|
||||
// ASSERT_TRUE(stat.ok());
|
||||
// }
|
||||
//
|
||||
// index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IVFFLAT;
|
||||
// index.extra_params_ = {{"nlist", 16384}};
|
||||
// db_->CreateIndex(TABLE_NAME, index); // wait until build index finish
|
||||
//
|
||||
// {
|
||||
// std::vector<std::string> tags;
|
||||
// milvus::engine::ResultIds result_ids;
|
||||
// milvus::engine::ResultDistances result_distances;
|
||||
// stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, xq, result_ids, result_distances);
|
||||
// ASSERT_TRUE(stat.ok());
|
||||
// }
|
||||
|
||||
index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IVFSQ8;
|
||||
index.extra_params_ = {{"nlist", 16384}};
|
||||
db_->CreateIndex(TABLE_NAME, index); // wait until build index finish
|
||||
|
||||
{
|
||||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
|
||||
|
@ -349,7 +357,7 @@ TEST_F(DBTest, SEARCH_TEST) {
|
|||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
#endif
|
||||
|
@ -365,16 +373,22 @@ TEST_F(DBTest, SEARCH_TEST) {
|
|||
}
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, 10, xq, result_ids, result_distances);
|
||||
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k,
|
||||
json_params,
|
||||
xq,
|
||||
result_ids,
|
||||
result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
|
||||
FIU_ENABLE_FIU("SqliteMetaImpl.FilesToSearch.throw_exception");
|
||||
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, 10, xq, result_ids, result_distances);
|
||||
stat =
|
||||
db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, json_params, xq, result_ids, result_distances);
|
||||
ASSERT_FALSE(stat.ok());
|
||||
fiu_disable("SqliteMetaImpl.FilesToSearch.throw_exception");
|
||||
|
||||
FIU_ENABLE_FIU("DBImpl.QueryByFileID.empty_files_array");
|
||||
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, 10, xq, result_ids, result_distances);
|
||||
stat =
|
||||
db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, json_params, xq, result_ids, result_distances);
|
||||
ASSERT_FALSE(stat.ok());
|
||||
fiu_disable("DBImpl.QueryByFileID.empty_files_array");
|
||||
}
|
||||
|
@ -385,13 +399,13 @@ TEST_F(DBTest, SEARCH_TEST) {
|
|||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
|
||||
FIU_ENABLE_FIU("SqliteMetaImpl.FilesToSearch.throw_exception");
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, xq, result_ids, result_distances);
|
||||
ASSERT_FALSE(stat.ok());
|
||||
fiu_disable("SqliteMetaImpl.FilesToSearch.throw_exception");
|
||||
}
|
||||
|
@ -409,7 +423,7 @@ TEST_F(DBTest, SEARCH_TEST) {
|
|||
{
|
||||
result_ids.clear();
|
||||
result_dists.clear();
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, partition_tag, k, 10, xq, result_ids, result_dists);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, partition_tag, k, json_params, xq, result_ids, result_dists);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
|
||||
|
@ -423,7 +437,7 @@ TEST_F(DBTest, SEARCH_TEST) {
|
|||
}
|
||||
result_ids.clear();
|
||||
result_dists.clear();
|
||||
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, 10, xq, result_ids, result_dists);
|
||||
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, json_params, xq, result_ids, result_dists);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
#endif
|
||||
|
@ -564,13 +578,27 @@ TEST_F(DBTest, SHUTDOWN_TEST) {
|
|||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, table_info.table_id_, tags, 1, 1, xb, result_ids, result_distances);
|
||||
milvus::json json_params = {{"nprobe", 1}};
|
||||
stat = db_->Query(dummy_context_, table_info.table_id_, tags, 1, json_params, xb, result_ids, result_distances);
|
||||
ASSERT_FALSE(stat.ok());
|
||||
std::vector<std::string> file_ids;
|
||||
stat = db_->QueryByFileID(dummy_context_, table_info.table_id_, file_ids, 1, 1, xb, result_ids, result_distances);
|
||||
stat = db_->QueryByFileID(dummy_context_,
|
||||
table_info.table_id_,
|
||||
file_ids,
|
||||
1,
|
||||
json_params,
|
||||
xb,
|
||||
result_ids,
|
||||
result_distances);
|
||||
ASSERT_FALSE(stat.ok());
|
||||
|
||||
stat = db_->Query(dummy_context_, table_info.table_id_, tags, 1, 1, milvus::engine::VectorsData(), result_ids,
|
||||
stat = db_->Query(dummy_context_,
|
||||
table_info.table_id_,
|
||||
tags,
|
||||
1,
|
||||
json_params,
|
||||
milvus::engine::VectorsData(),
|
||||
result_ids,
|
||||
result_distances);
|
||||
ASSERT_FALSE(stat.ok());
|
||||
|
||||
|
@ -731,7 +759,7 @@ TEST_F(DBTest, INDEX_TEST) {
|
|||
stat = db_->DescribeIndex(table_info.table_id_, index_out);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(index.engine_type_, index_out.engine_type_);
|
||||
ASSERT_EQ(index.nlist_, index_out.nlist_);
|
||||
ASSERT_EQ(index.extra_params_, index_out.extra_params_);
|
||||
ASSERT_EQ(table_info.metric_type_, index_out.metric_type_);
|
||||
|
||||
stat = db_->DropIndex(table_info.table_id_);
|
||||
|
@ -845,7 +873,9 @@ TEST_F(DBTest, PARTITION_TEST) {
|
|||
std::vector<std::string> tags = {"0", std::to_string(PARTITION_COUNT - 1)};
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nprobe, xq, result_ids, result_distances);
|
||||
milvus::json json_params = {{"nprobe", nprobe}};
|
||||
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, json_params, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(result_ids.size() / topk, nq);
|
||||
|
||||
|
@ -853,7 +883,7 @@ TEST_F(DBTest, PARTITION_TEST) {
|
|||
tags.clear();
|
||||
result_ids.clear();
|
||||
result_distances.clear();
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nprobe, xq, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, json_params, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(result_ids.size() / topk, nq);
|
||||
|
||||
|
@ -861,7 +891,7 @@ TEST_F(DBTest, PARTITION_TEST) {
|
|||
tags.push_back("\\d");
|
||||
result_ids.clear();
|
||||
result_distances.clear();
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nprobe, xq, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, json_params, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(result_ids.size() / topk, nq);
|
||||
}
|
||||
|
@ -1074,11 +1104,12 @@ TEST_F(DBTestWAL, DB_STOP_TEST) {
|
|||
|
||||
const int64_t topk = 10;
|
||||
const int64_t nprobe = 10;
|
||||
milvus::json json_params = {{"nprobe", nprobe}};
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
milvus::engine::VectorsData qxb;
|
||||
BuildVectors(qb, 0, qxb);
|
||||
stat = db_->Query(dummy_context_, table_info.table_id_, {}, topk, nprobe, qxb, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, table_info.table_id_, {}, topk, json_params, qxb, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(result_ids.size() / topk, qb);
|
||||
|
||||
|
@ -1102,11 +1133,12 @@ TEST_F(DBTestWALRecovery, RECOVERY_WITH_NO_ERROR) {
|
|||
|
||||
const int64_t topk = 10;
|
||||
const int64_t nprobe = 10;
|
||||
milvus::json json_params = {{"nprobe", nprobe}};
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
milvus::engine::VectorsData qxb;
|
||||
BuildVectors(qb, 0, qxb);
|
||||
stat = db_->Query(dummy_context_, table_info.table_id_, {}, topk, nprobe, qxb, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, table_info.table_id_, {}, topk, json_params, qxb, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_NE(result_ids.size() / topk, qb);
|
||||
|
||||
|
@ -1119,14 +1151,14 @@ TEST_F(DBTestWALRecovery, RECOVERY_WITH_NO_ERROR) {
|
|||
|
||||
result_ids.clear();
|
||||
result_distances.clear();
|
||||
stat = db_->Query(dummy_context_, table_info.table_id_, {}, topk, nprobe, qxb, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, table_info.table_id_, {}, topk, json_params, qxb, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(result_ids.size(), 0);
|
||||
|
||||
db_->Flush();
|
||||
result_ids.clear();
|
||||
result_distances.clear();
|
||||
stat = db_->Query(dummy_context_, table_info.table_id_, {}, topk, nprobe, qxb, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, table_info.table_id_, {}, topk, json_params, qxb, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(result_ids.size() / topk, qb);
|
||||
}
|
||||
|
@ -1312,6 +1344,7 @@ TEST_F(DBTest2, SEARCH_WITH_DIFFERENT_INDEX) {
|
|||
ASSERT_TRUE(stat.ok());
|
||||
|
||||
int topk = 10, nprobe = 10;
|
||||
milvus::json json_params = {{"nprobe", nprobe}};
|
||||
|
||||
for (auto id : ids_to_search) {
|
||||
// std::cout << "xxxxxxxxxxxxxxxxxxxx " << i << std::endl;
|
||||
|
@ -1319,7 +1352,7 @@ TEST_F(DBTest2, SEARCH_WITH_DIFFERENT_INDEX) {
|
|||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
|
||||
stat = db_->QueryByID(dummy_context_, table_info.table_id_, tags, topk, nprobe, id, result_ids,
|
||||
stat = db_->QueryByID(dummy_context_, table_info.table_id_, tags, topk, json_params, id, result_ids,
|
||||
result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(result_ids[0], id);
|
||||
|
@ -1341,7 +1374,7 @@ result_distances);
|
|||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
|
||||
stat = db_->QueryByID(dummy_context_, table_info.table_id_, tags, topk, nprobe, id, result_ids,
|
||||
stat = db_->QueryByID(dummy_context_, table_info.table_id_, tags, topk, json_params, id, result_ids,
|
||||
result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(result_ids[0], id);
|
||||
|
|
|
@ -78,16 +78,20 @@ TEST_F(MySqlDBTest, DB_TEST) {
|
|||
std::stringstream ss;
|
||||
uint64_t count = 0;
|
||||
uint64_t prev_count = 0;
|
||||
milvus::json json_params = {{"nprobe", 10}};
|
||||
|
||||
for (auto j = 0; j < 10; ++j) {
|
||||
ss.str("");
|
||||
db_->Size(count);
|
||||
prev_count = count;
|
||||
if (count == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
START_TIMER;
|
||||
|
||||
std::vector<std::string> tags;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, qxb, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, qxb, result_ids, result_distances);
|
||||
ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
|
||||
STOP_TIMER(ss.str());
|
||||
|
||||
|
@ -186,7 +190,8 @@ TEST_F(MySqlDBTest, SEARCH_TEST) {
|
|||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
|
||||
milvus::json json_params = {{"nprobe", 10}};
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
}
|
||||
|
||||
|
@ -377,7 +382,8 @@ TEST_F(MySqlDBTest, PARTITION_TEST) {
|
|||
std::vector<std::string> tags = {"0", std::to_string(PARTITION_COUNT - 1)};
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, 10, xq, result_ids, result_distances);
|
||||
milvus::json json_params = {{"nprobe", nprobe}};
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, json_params, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(result_ids.size() / topk, nq);
|
||||
|
||||
|
@ -385,7 +391,7 @@ TEST_F(MySqlDBTest, PARTITION_TEST) {
|
|||
tags.clear();
|
||||
result_ids.clear();
|
||||
result_distances.clear();
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, 10, xq, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, json_params, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(result_ids.size() / topk, nq);
|
||||
|
||||
|
@ -393,7 +399,7 @@ TEST_F(MySqlDBTest, PARTITION_TEST) {
|
|||
tags.push_back("\\d");
|
||||
result_ids.clear();
|
||||
result_distances.clear();
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, 10, xq, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, json_params, xq, result_ids, result_distances);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(result_ids.size() / topk, nq);
|
||||
}
|
||||
|
|
|
@ -301,6 +301,7 @@ TEST_F(DeleteTest, delete_with_index) {
|
|||
|
||||
milvus::engine::TableIndex index;
|
||||
index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IVFSQ8;
|
||||
index.extra_params_ = {{"nlist", 100}};
|
||||
stat = db_->CreateIndex(GetTableName(), index);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
|
||||
|
@ -368,12 +369,13 @@ TEST_F(DeleteTest, delete_single_vector) {
|
|||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(row_count, 0);
|
||||
|
||||
int topk = 1, nprobe = 1;
|
||||
const int topk = 1, nprobe = 1;
|
||||
milvus::json json_params = {{"nprobe", nprobe}};
|
||||
|
||||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, nprobe, xb, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, json_params, xb, result_ids, result_distances);
|
||||
ASSERT_TRUE(result_ids.empty());
|
||||
ASSERT_TRUE(result_distances.empty());
|
||||
// ASSERT_EQ(result_ids[0], -1);
|
||||
|
@ -402,6 +404,7 @@ TEST_F(DeleteTest, delete_add_create_index) {
|
|||
// ASSERT_TRUE(stat.ok());
|
||||
milvus::engine::TableIndex index;
|
||||
index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IVFSQ8;
|
||||
index.extra_params_ = {{"nlist", 100}};
|
||||
stat = db_->CreateIndex(GetTableName(), index);
|
||||
ASSERT_TRUE(stat.ok());
|
||||
|
||||
|
@ -426,7 +429,8 @@ TEST_F(DeleteTest, delete_add_create_index) {
|
|||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(row_count, nb * 2 - 1);
|
||||
|
||||
int topk = 10, nprobe = 10;
|
||||
const int topk = 10, nprobe = 10;
|
||||
milvus::json json_params = {{"nprobe", nprobe}};
|
||||
|
||||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
|
@ -435,14 +439,14 @@ TEST_F(DeleteTest, delete_add_create_index) {
|
|||
qb.float_data_.resize(TABLE_DIM);
|
||||
qb.vector_count_ = 1;
|
||||
qb.id_array_.clear();
|
||||
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, nprobe, qb, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, json_params, qb, result_ids, result_distances);
|
||||
|
||||
ASSERT_EQ(result_ids[0], xb2.id_array_.front());
|
||||
ASSERT_LT(result_distances[0], 1e-4);
|
||||
|
||||
result_ids.clear();
|
||||
result_distances.clear();
|
||||
stat = db_->QueryByID(dummy_context_, GetTableName(), tags, topk, nprobe, ids_to_delete.front(), result_ids,
|
||||
stat = db_->QueryByID(dummy_context_, GetTableName(), tags, topk, json_params, ids_to_delete.front(), result_ids,
|
||||
result_distances);
|
||||
ASSERT_EQ(result_ids[0], -1);
|
||||
ASSERT_EQ(result_distances[0], std::numeric_limits<float>::max());
|
||||
|
@ -496,7 +500,8 @@ TEST_F(DeleteTest, delete_add_auto_flush) {
|
|||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_EQ(row_count, nb * 2 - 1);
|
||||
|
||||
int topk = 10, nprobe = 10;
|
||||
const int topk = 10, nprobe = 10;
|
||||
milvus::json json_params = {{"nprobe", nprobe}};
|
||||
|
||||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
|
@ -505,7 +510,7 @@ TEST_F(DeleteTest, delete_add_auto_flush) {
|
|||
qb.float_data_.resize(TABLE_DIM);
|
||||
qb.vector_count_ = 1;
|
||||
qb.id_array_.clear();
|
||||
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, nprobe, qb, result_ids, result_distances);
|
||||
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, json_params, qb, result_ids, result_distances);
|
||||
|
||||
ASSERT_EQ(result_ids[0], xb2.id_array_.front());
|
||||
ASSERT_LT(result_distances[0], 1e-4);
|
||||
|
@ -555,7 +560,8 @@ TEST_F(CompactTest, compact_basic) {
|
|||
stat = db_->Compact(GetTableName());
|
||||
ASSERT_TRUE(stat.ok());
|
||||
|
||||
int topk = 1, nprobe = 1;
|
||||
const int topk = 1, nprobe = 1;
|
||||
milvus::json json_params = {{"nprobe", nprobe}};
|
||||
|
||||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
|
@ -563,7 +569,8 @@ TEST_F(CompactTest, compact_basic) {
|
|||
milvus::engine::VectorsData qb = xb;
|
||||
|
||||
for (auto& id : ids_to_delete) {
|
||||
stat = db_->QueryByID(dummy_context_, GetTableName(), tags, topk, nprobe, id, result_ids, result_distances);
|
||||
stat =
|
||||
db_->QueryByID(dummy_context_, GetTableName(), tags, topk, json_params, id, result_ids, result_distances);
|
||||
ASSERT_EQ(result_ids[0], -1);
|
||||
ASSERT_EQ(result_distances[0], std::numeric_limits<float>::max());
|
||||
}
|
||||
|
@ -643,14 +650,17 @@ TEST_F(CompactTest, compact_with_index) {
|
|||
ASSERT_TRUE(stat.ok());
|
||||
ASSERT_FLOAT_EQ(table_index.engine_type_, index.engine_type_);
|
||||
|
||||
int topk = 10, nprobe = 10;
|
||||
const int topk = 10, nprobe = 10;
|
||||
milvus::json json_params = {{"nprobe", nprobe}};
|
||||
|
||||
for (auto& pair : search_vectors) {
|
||||
auto& search = pair.second;
|
||||
|
||||
std::vector<std::string> tags;
|
||||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, nprobe, search, result_ids, result_distances);
|
||||
stat =
|
||||
db_->Query(dummy_context_, GetTableName(), tags, topk, json_params, search, result_ids, result_distances);
|
||||
ASSERT_NE(result_ids[0], pair.first);
|
||||
// ASSERT_LT(result_distances[0], 1e-4);
|
||||
ASSERT_GT(result_distances[0], 1);
|
||||
|
|
|
@ -19,17 +19,58 @@
|
|||
#include <fiu-local.h>
|
||||
#include <fiu-control.h>
|
||||
|
||||
namespace {
|
||||
|
||||
static constexpr uint16_t DIMENSION = 64;
|
||||
static constexpr int64_t ROW_COUNT = 1000;
|
||||
static const char* INIT_PATH = "/tmp/milvus_index_1";
|
||||
|
||||
milvus::engine::ExecutionEnginePtr
|
||||
CreateExecEngine(const milvus::json& json_params, milvus::engine::MetricType metric = milvus::engine::MetricType::IP) {
|
||||
auto engine_ptr = milvus::engine::EngineFactory::Build(
|
||||
DIMENSION,
|
||||
INIT_PATH,
|
||||
milvus::engine::EngineType::FAISS_IDMAP,
|
||||
metric,
|
||||
json_params);
|
||||
|
||||
std::vector<float> data;
|
||||
std::vector<int64_t> ids;
|
||||
data.reserve(ROW_COUNT * DIMENSION);
|
||||
ids.reserve(ROW_COUNT);
|
||||
for (int64_t i = 0; i < ROW_COUNT; i++) {
|
||||
ids.push_back(i);
|
||||
for (uint16_t k = 0; k < DIMENSION; k++) {
|
||||
data.push_back(i * DIMENSION + k);
|
||||
}
|
||||
}
|
||||
|
||||
auto status = engine_ptr->AddWithIds((int64_t)ids.size(), data.data(), ids.data());
|
||||
return engine_ptr;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST_F(EngineTest, FACTORY_TEST) {
|
||||
const milvus::json index_params = {{"nlist", 1024}};
|
||||
{
|
||||
auto engine_ptr = milvus::engine::EngineFactory::Build(
|
||||
512, "/tmp/milvus_index_1", milvus::engine::EngineType::INVALID, milvus::engine::MetricType::IP, 1024);
|
||||
512,
|
||||
"/tmp/milvus_index_1",
|
||||
milvus::engine::EngineType::INVALID,
|
||||
milvus::engine::MetricType::IP,
|
||||
index_params);
|
||||
|
||||
ASSERT_TRUE(engine_ptr == nullptr);
|
||||
}
|
||||
|
||||
{
|
||||
auto engine_ptr = milvus::engine::EngineFactory::Build(
|
||||
512, "/tmp/milvus_index_1", milvus::engine::EngineType::FAISS_IDMAP, milvus::engine::MetricType::IP, 1024);
|
||||
512,
|
||||
"/tmp/milvus_index_1",
|
||||
milvus::engine::EngineType::FAISS_IDMAP,
|
||||
milvus::engine::MetricType::IP,
|
||||
index_params);
|
||||
|
||||
ASSERT_TRUE(engine_ptr != nullptr);
|
||||
}
|
||||
|
@ -37,28 +78,40 @@ TEST_F(EngineTest, FACTORY_TEST) {
|
|||
{
|
||||
auto engine_ptr =
|
||||
milvus::engine::EngineFactory::Build(512, "/tmp/milvus_index_1", milvus::engine::EngineType::FAISS_IVFFLAT,
|
||||
milvus::engine::MetricType::IP, 1024);
|
||||
milvus::engine::MetricType::IP, index_params);
|
||||
|
||||
ASSERT_TRUE(engine_ptr != nullptr);
|
||||
}
|
||||
|
||||
{
|
||||
auto engine_ptr = milvus::engine::EngineFactory::Build(
|
||||
512, "/tmp/milvus_index_1", milvus::engine::EngineType::FAISS_IVFSQ8, milvus::engine::MetricType::IP, 1024);
|
||||
512,
|
||||
"/tmp/milvus_index_1",
|
||||
milvus::engine::EngineType::FAISS_IVFSQ8,
|
||||
milvus::engine::MetricType::IP,
|
||||
index_params);
|
||||
|
||||
ASSERT_TRUE(engine_ptr != nullptr);
|
||||
}
|
||||
|
||||
{
|
||||
auto engine_ptr = milvus::engine::EngineFactory::Build(
|
||||
512, "/tmp/milvus_index_1", milvus::engine::EngineType::NSG_MIX, milvus::engine::MetricType::IP, 1024);
|
||||
512,
|
||||
"/tmp/milvus_index_1",
|
||||
milvus::engine::EngineType::NSG_MIX,
|
||||
milvus::engine::MetricType::IP,
|
||||
index_params);
|
||||
|
||||
ASSERT_TRUE(engine_ptr != nullptr);
|
||||
}
|
||||
|
||||
{
|
||||
auto engine_ptr = milvus::engine::EngineFactory::Build(
|
||||
512, "/tmp/milvus_index_1", milvus::engine::EngineType::FAISS_PQ, milvus::engine::MetricType::IP, 1024);
|
||||
512,
|
||||
"/tmp/milvus_index_1",
|
||||
milvus::engine::EngineType::FAISS_PQ,
|
||||
milvus::engine::MetricType::IP,
|
||||
index_params);
|
||||
|
||||
ASSERT_TRUE(engine_ptr != nullptr);
|
||||
}
|
||||
|
@ -66,7 +119,7 @@ TEST_F(EngineTest, FACTORY_TEST) {
|
|||
{
|
||||
auto engine_ptr = milvus::engine::EngineFactory::Build(
|
||||
512, "/tmp/milvus_index_1", milvus::engine::EngineType::SPTAG_KDT,
|
||||
milvus::engine::MetricType::L2, 1024);
|
||||
milvus::engine::MetricType::L2, index_params);
|
||||
|
||||
ASSERT_TRUE(engine_ptr != nullptr);
|
||||
}
|
||||
|
@ -74,7 +127,7 @@ TEST_F(EngineTest, FACTORY_TEST) {
|
|||
{
|
||||
auto engine_ptr = milvus::engine::EngineFactory::Build(
|
||||
512, "/tmp/milvus_index_1", milvus::engine::EngineType::SPTAG_KDT,
|
||||
milvus::engine::MetricType::L2, 1024);
|
||||
milvus::engine::MetricType::L2, index_params);
|
||||
|
||||
ASSERT_TRUE(engine_ptr != nullptr);
|
||||
}
|
||||
|
@ -85,7 +138,7 @@ TEST_F(EngineTest, FACTORY_TEST) {
|
|||
FIU_ENABLE_FIU("ExecutionEngineImpl.CreatetVecIndex.invalid_type");
|
||||
ASSERT_ANY_THROW(milvus::engine::EngineFactory::Build(
|
||||
512, "/tmp/milvus_index_1", milvus::engine::EngineType::SPTAG_KDT,
|
||||
milvus::engine::MetricType::L2, 1024));
|
||||
milvus::engine::MetricType::L2, index_params));
|
||||
fiu_disable("ExecutionEngineImpl.CreatetVecIndex.invalid_type");
|
||||
}
|
||||
|
||||
|
@ -94,91 +147,88 @@ TEST_F(EngineTest, FACTORY_TEST) {
|
|||
FIU_ENABLE_FIU("BFIndex.Build.throw_knowhere_exception");
|
||||
ASSERT_ANY_THROW(milvus::engine::EngineFactory::Build(
|
||||
512, "/tmp/milvus_index_1", milvus::engine::EngineType::SPTAG_KDT,
|
||||
milvus::engine::MetricType::L2, 1024));
|
||||
milvus::engine::MetricType::L2, index_params));
|
||||
fiu_disable("BFIndex.Build.throw_knowhere_exception");
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(EngineTest, ENGINE_IMPL_TEST) {
|
||||
fiu_init(0);
|
||||
uint16_t dimension = 64;
|
||||
std::string file_path = "/tmp/milvus_index_1";
|
||||
auto engine_ptr = milvus::engine::EngineFactory::Build(
|
||||
dimension, file_path, milvus::engine::EngineType::FAISS_IVFFLAT, milvus::engine::MetricType::IP, 1024);
|
||||
|
||||
std::vector<float> data;
|
||||
std::vector<int64_t> ids;
|
||||
const int row_count = 500;
|
||||
data.reserve(row_count * dimension);
|
||||
ids.reserve(row_count);
|
||||
for (int64_t i = 0; i < row_count; i++) {
|
||||
ids.push_back(i);
|
||||
for (uint16_t k = 0; k < dimension; k++) {
|
||||
data.push_back(i * dimension + k);
|
||||
}
|
||||
{
|
||||
milvus::json index_params = {{"nlist", 10}};
|
||||
auto engine_ptr = CreateExecEngine(index_params);
|
||||
|
||||
ASSERT_EQ(engine_ptr->Dimension(), DIMENSION);
|
||||
ASSERT_EQ(engine_ptr->Count(), ROW_COUNT);
|
||||
ASSERT_EQ(engine_ptr->GetLocation(), INIT_PATH);
|
||||
ASSERT_EQ(engine_ptr->IndexMetricType(), milvus::engine::MetricType::IP);
|
||||
|
||||
ASSERT_ANY_THROW(engine_ptr->BuildIndex(INIT_PATH, milvus::engine::EngineType::INVALID));
|
||||
FIU_ENABLE_FIU("VecIndexImpl.BuildAll.throw_knowhere_exception");
|
||||
ASSERT_ANY_THROW(engine_ptr->BuildIndex(INIT_PATH, milvus::engine::EngineType::SPTAG_KDT));
|
||||
fiu_disable("VecIndexImpl.BuildAll.throw_knowhere_exception");
|
||||
|
||||
auto engine_build = engine_ptr->BuildIndex("/tmp/milvus_index_2", milvus::engine::EngineType::FAISS_IVFSQ8);
|
||||
ASSERT_NE(engine_build, nullptr);
|
||||
}
|
||||
|
||||
auto status = engine_ptr->AddWithIds((int64_t)ids.size(), data.data(), ids.data());
|
||||
ASSERT_TRUE(status.ok());
|
||||
|
||||
ASSERT_EQ(engine_ptr->Dimension(), dimension);
|
||||
ASSERT_EQ(engine_ptr->Count(), ids.size());
|
||||
|
||||
ASSERT_EQ(engine_ptr->GetLocation(), file_path);
|
||||
ASSERT_EQ(engine_ptr->IndexMetricType(), milvus::engine::MetricType::IP);
|
||||
|
||||
ASSERT_ANY_THROW(engine_ptr->BuildIndex(file_path, milvus::engine::EngineType::INVALID));
|
||||
FIU_ENABLE_FIU("VecIndexImpl.BuildAll.throw_knowhere_exception");
|
||||
ASSERT_ANY_THROW(engine_ptr->BuildIndex(file_path, milvus::engine::EngineType::SPTAG_KDT));
|
||||
fiu_disable("VecIndexImpl.BuildAll.throw_knowhere_exception");
|
||||
|
||||
auto engine_build = engine_ptr->BuildIndex("/tmp/milvus_index_2", milvus::engine::EngineType::FAISS_IVFSQ8);
|
||||
{
|
||||
#ifndef MILVUS_GPU_VERSION
|
||||
//PQ don't support IP In gpu version
|
||||
engine_build = engine_ptr->BuildIndex("/tmp/milvus_index_3", milvus::engine::EngineType::FAISS_PQ);
|
||||
milvus::json index_params = {{"nlist", 10}, {"m", 16}};
|
||||
auto engine_ptr = CreateExecEngine(index_params);
|
||||
//PQ don't support IP In gpu version
|
||||
auto engine_build = engine_ptr->BuildIndex("/tmp/milvus_index_3", milvus::engine::EngineType::FAISS_PQ);
|
||||
ASSERT_NE(engine_build, nullptr);
|
||||
#endif
|
||||
engine_build = engine_ptr->BuildIndex("/tmp/milvus_index_4", milvus::engine::EngineType::SPTAG_KDT);
|
||||
engine_build = engine_ptr->BuildIndex("/tmp/milvus_index_5", milvus::engine::EngineType::SPTAG_BKT);
|
||||
engine_ptr->BuildIndex("/tmp/milvus_index_SPTAG_BKT", milvus::engine::EngineType::SPTAG_BKT);
|
||||
}
|
||||
|
||||
{
|
||||
milvus::json index_params = {{"nlist", 10}};
|
||||
auto engine_ptr = CreateExecEngine(index_params);
|
||||
auto engine_build = engine_ptr->BuildIndex("/tmp/milvus_index_4", milvus::engine::EngineType::SPTAG_KDT);
|
||||
engine_build = engine_ptr->BuildIndex("/tmp/milvus_index_5", milvus::engine::EngineType::SPTAG_BKT);
|
||||
engine_ptr->BuildIndex("/tmp/milvus_index_SPTAG_BKT", milvus::engine::EngineType::SPTAG_BKT);
|
||||
|
||||
//CPU version invoke CopyToCpu will fail
|
||||
auto status = engine_ptr->CopyToCpu();
|
||||
ASSERT_FALSE(status.ok());
|
||||
}
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
FIU_ENABLE_FIU("ExecutionEngineImpl.CreatetVecIndex.gpu_res_disabled");
|
||||
engine_ptr->BuildIndex("/tmp/milvus_index_NSG_MIX", milvus::engine::EngineType::NSG_MIX);
|
||||
engine_ptr->BuildIndex("/tmp/milvus_index_6", milvus::engine::EngineType::FAISS_IVFFLAT);
|
||||
engine_ptr->BuildIndex("/tmp/milvus_index_7", milvus::engine::EngineType::FAISS_IVFSQ8);
|
||||
ASSERT_ANY_THROW(engine_ptr->BuildIndex("/tmp/milvus_index_8", milvus::engine::EngineType::FAISS_IVFSQ8H));
|
||||
ASSERT_ANY_THROW(engine_ptr->BuildIndex("/tmp/milvus_index_9", milvus::engine::EngineType::FAISS_PQ));
|
||||
fiu_disable("ExecutionEngineImpl.CreatetVecIndex.gpu_res_disabled");
|
||||
#endif
|
||||
{
|
||||
FIU_ENABLE_FIU("ExecutionEngineImpl.CreatetVecIndex.gpu_res_disabled");
|
||||
milvus::json index_params = {{"search_length", 100}, {"out_degree", 40}, {"pool_size", 100}, {"knng", 200},
|
||||
{"candidate_pool_size", 500}};
|
||||
auto engine_ptr = CreateExecEngine(index_params, milvus::engine::MetricType::L2);
|
||||
engine_ptr->BuildIndex("/tmp/milvus_index_NSG_MIX", milvus::engine::EngineType::NSG_MIX);
|
||||
fiu_disable("ExecutionEngineImpl.CreatetVecIndex.gpu_res_disabled");
|
||||
|
||||
//CPU version invoke CopyToCpu will fail
|
||||
status = engine_ptr->CopyToCpu();
|
||||
ASSERT_FALSE(status.ok());
|
||||
auto status = engine_ptr->CopyToGpu(0, false);
|
||||
ASSERT_TRUE(status.ok());
|
||||
status = engine_ptr->GpuCache(0);
|
||||
ASSERT_TRUE(status.ok());
|
||||
status = engine_ptr->CopyToGpu(0, false);
|
||||
ASSERT_TRUE(status.ok());
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
status = engine_ptr->CopyToGpu(0, false);
|
||||
ASSERT_TRUE(status.ok());
|
||||
status = engine_ptr->GpuCache(0);
|
||||
ASSERT_TRUE(status.ok());
|
||||
status = engine_ptr->CopyToGpu(0, false);
|
||||
ASSERT_TRUE(status.ok());
|
||||
// auto new_engine = engine_ptr->Clone();
|
||||
// ASSERT_EQ(new_engine->Dimension(), dimension);
|
||||
// ASSERT_EQ(new_engine->Count(), ids.size());
|
||||
|
||||
// auto new_engine = engine_ptr->Clone();
|
||||
// ASSERT_EQ(new_engine->Dimension(), dimension);
|
||||
// ASSERT_EQ(new_engine->Count(), ids.size());
|
||||
|
||||
status = engine_ptr->CopyToCpu();
|
||||
ASSERT_TRUE(status.ok());
|
||||
engine_ptr->CopyToCpu();
|
||||
ASSERT_TRUE(status.ok());
|
||||
status = engine_ptr->CopyToCpu();
|
||||
ASSERT_TRUE(status.ok());
|
||||
engine_ptr->CopyToCpu();
|
||||
ASSERT_TRUE(status.ok());
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST_F(EngineTest, ENGINE_IMPL_NULL_INDEX_TEST) {
|
||||
uint16_t dimension = 64;
|
||||
std::string file_path = "/tmp/milvus_index_1";
|
||||
milvus::json index_params = {{"nlist", 1024}};
|
||||
auto engine_ptr = milvus::engine::EngineFactory::Build(
|
||||
dimension, file_path, milvus::engine::EngineType::FAISS_IVFFLAT, milvus::engine::MetricType::IP, 1024);
|
||||
dimension, file_path, milvus::engine::EngineType::FAISS_IVFFLAT, milvus::engine::MetricType::IP, index_params);
|
||||
|
||||
fiu_init(0); // init
|
||||
fiu_enable("read_null_index", 1, NULL, 0);
|
||||
|
@ -209,11 +259,13 @@ TEST_F(EngineTest, ENGINE_IMPL_NULL_INDEX_TEST) {
|
|||
TEST_F(EngineTest, ENGINE_IMPL_THROW_EXCEPTION_TEST) {
|
||||
uint16_t dimension = 64;
|
||||
std::string file_path = "/tmp/invalid_file";
|
||||
milvus::json index_params = {{"nlist", 1024}};
|
||||
|
||||
fiu_init(0); // init
|
||||
fiu_enable("ValidateStringNotBool", 1, NULL, 0);
|
||||
|
||||
auto engine_ptr = milvus::engine::EngineFactory::Build(
|
||||
dimension, file_path, milvus::engine::EngineType::FAISS_IVFFLAT, milvus::engine::MetricType::IP, 1024);
|
||||
dimension, file_path, milvus::engine::EngineType::FAISS_IVFFLAT, milvus::engine::MetricType::IP, index_params);
|
||||
|
||||
fiu_disable("ValidateStringNotBool");
|
||||
|
||||
|
|
|
@ -285,7 +285,8 @@ TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) {
|
|||
search_vectors.insert(std::make_pair(xb.id_array_[index], search));
|
||||
}
|
||||
|
||||
int topk = 10, nprobe = 10;
|
||||
const int topk = 10, nprobe = 10;
|
||||
milvus::json json_params = {{"nprobe", nprobe}};
|
||||
for (auto& pair : search_vectors) {
|
||||
auto& search = pair.second;
|
||||
|
||||
|
@ -293,7 +294,8 @@ TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) {
|
|||
milvus::engine::ResultIds result_ids;
|
||||
milvus::engine::ResultDistances result_distances;
|
||||
|
||||
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, nprobe, search, result_ids, result_distances);
|
||||
stat =
|
||||
db_->Query(dummy_context_, GetTableName(), tags, topk, json_params, search, result_ids, result_distances);
|
||||
ASSERT_EQ(result_ids[0], pair.first);
|
||||
ASSERT_LT(result_distances[0], 1e-4);
|
||||
}
|
||||
|
@ -388,6 +390,7 @@ TEST_F(MemManagerTest2, INSERT_BINARY_TEST) {
|
|||
// std::stringstream ss;
|
||||
// uint64_t count = 0;
|
||||
// uint64_t prev_count = 0;
|
||||
// milvus::json json_params = {{"nprobe", 10}};
|
||||
//
|
||||
// for (auto j = 0; j < 10; ++j) {
|
||||
// ss.str("");
|
||||
|
@ -397,7 +400,8 @@ TEST_F(MemManagerTest2, INSERT_BINARY_TEST) {
|
|||
// START_TIMER;
|
||||
//
|
||||
// std::vector<std::string> tags;
|
||||
// stat = db_->Query(dummy_context_, GetTableName(), tags, k, 10, qxb, result_ids, result_distances);
|
||||
// stat =
|
||||
// db_->Query(dummy_context_, GetTableName(), tags, k, json_params, qxb, result_ids, result_distances);
|
||||
// ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
|
||||
// STOP_TIMER(ss.str());
|
||||
//
|
||||
|
|
|
@ -647,7 +647,7 @@ TEST_F(MetaTest, INDEX_TEST) {
|
|||
|
||||
milvus::engine::TableIndex index;
|
||||
index.metric_type_ = 2;
|
||||
index.nlist_ = 1234;
|
||||
index.extra_params_ = {{"nlist", 1234}};
|
||||
index.engine_type_ = 3;
|
||||
status = impl_->UpdateTableIndex(table_id, index);
|
||||
ASSERT_TRUE(status.ok());
|
||||
|
@ -664,14 +664,13 @@ TEST_F(MetaTest, INDEX_TEST) {
|
|||
milvus::engine::TableIndex index_out;
|
||||
status = impl_->DescribeTableIndex(table_id, index_out);
|
||||
ASSERT_EQ(index_out.metric_type_, index.metric_type_);
|
||||
ASSERT_EQ(index_out.nlist_, index.nlist_);
|
||||
ASSERT_EQ(index_out.extra_params_, index.extra_params_);
|
||||
ASSERT_EQ(index_out.engine_type_, index.engine_type_);
|
||||
|
||||
status = impl_->DropTableIndex(table_id);
|
||||
ASSERT_TRUE(status.ok());
|
||||
status = impl_->DescribeTableIndex(table_id, index_out);
|
||||
ASSERT_EQ(index_out.metric_type_, index.metric_type_);
|
||||
ASSERT_NE(index_out.nlist_, index.nlist_);
|
||||
ASSERT_NE(index_out.engine_type_, index.engine_type_);
|
||||
|
||||
status = impl_->UpdateTableFilesToIndex(table_id);
|
||||
|
|
|
@ -700,7 +700,7 @@ TEST_F(MySqlMetaTest, INDEX_TEST) {
|
|||
|
||||
milvus::engine::TableIndex index;
|
||||
index.metric_type_ = 2;
|
||||
index.nlist_ = 1234;
|
||||
index.extra_params_ = {{"nlist", 1234}};
|
||||
index.engine_type_ = 3;
|
||||
status = impl_->UpdateTableIndex(table_id, index);
|
||||
ASSERT_TRUE(status.ok());
|
||||
|
@ -740,14 +740,13 @@ TEST_F(MySqlMetaTest, INDEX_TEST) {
|
|||
milvus::engine::TableIndex index_out;
|
||||
status = impl_->DescribeTableIndex(table_id, index_out);
|
||||
ASSERT_EQ(index_out.metric_type_, index.metric_type_);
|
||||
ASSERT_EQ(index_out.nlist_, index.nlist_);
|
||||
ASSERT_EQ(index_out.extra_params_, index.extra_params_);
|
||||
ASSERT_EQ(index_out.engine_type_, index.engine_type_);
|
||||
|
||||
status = impl_->DropTableIndex(table_id);
|
||||
ASSERT_TRUE(status.ok());
|
||||
status = impl_->DescribeTableIndex(table_id, index_out);
|
||||
ASSERT_EQ(index_out.metric_type_, index.metric_type_);
|
||||
ASSERT_NE(index_out.nlist_, index.nlist_);
|
||||
ASSERT_NE(index_out.engine_type_, index.engine_type_);
|
||||
|
||||
FIU_ENABLE_FIU("MySQLMetaImpl.DescribeTableIndex.null_connection");
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue