Merge branch 'master' into caiyd_codec_opt

pull/1538/head
Cai Yudong 2020-03-08 14:17:07 +08:00 committed by GitHub
commit 5afef85466
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
145 changed files with 29248 additions and 5783 deletions

View File

@ -11,16 +11,16 @@ Please mark all change in change log and use the issue from GitHub
- \#805 IVFTest.gpu_seal_test unittest failed
- \#831 Judge branch error in CommonUtil.cpp
- \#977 Server crash when create tables concurrently
- \#990 check gpu resources setting when assign repeated value
- \#990 Check gpu resources setting when assign repeated value
- \#995 table count set to 0 if no tables found
- \#1010 improve error message when offset or page_size is equal 0
- \#1010 Improve error message when offset or page_size is equal 0
- \#1022 check if partition name is legal
- \#1028 check if table exists when show partitions
- \#1029 check if table exists when try to delete partition
- \#1066 optimize http insert and search speed
- \#1067 Add binary vectors support in http server
- \#1075 improve error message when page size or offset is illegal
- \#1082 check page_size or offset value to avoid float
- \#1075 Improve error message when page size or offset is illegal
- \#1082 Check page_size or offset value to avoid float
- \#1115 http server support load table into memory
- \#1152 Error log output continuously after server start
- \#1211 Server down caused by searching with index_type: HNSW
@ -29,18 +29,21 @@ Please mark all change in change log and use the issue from GitHub
- \#1359 Negative distance value returned when searching with HNSW index type
- \#1429 Server crashed when searching vectors with GPU
- \#1476 Fix vectors results bug when getting vectors from segments
- \#1484 Index type changed to IDMAP after compacted
- \#1484 Index type changed to IDMAP after compacted
- \#1491 Server crashed during adding vectors
- \#1499 Fix duplicated ID number issue
- \#1491 Server crashed during adding vectors
- \#1504 Avoid possible race condition between delete and search
- \#1504 Avoid possible race condition between delete and search
- \#1507 set_config for insert_buffer_size is wrong
- \#1510 Add set interfaces for WAL configurations
- \#1511 Fix big integer cannot pass to server correctly
- \#1518 Table count did not match after deleting vectors and compact
- \#1521 Make cache_insert_data take effect in-service
- \#1525 Add setter API for config preload_table
- \#1529 Fix server crash when cache_insert_data enabled
- \#1530 Set table file with correct engine type in meta
- \#1532 Search with ivf_flat failed with open-dataset: sift-256-hamming
- \#1535 Degradation searching performance with metric_type: binary_idmap
- \#1556 Index file not created after table and index created
## Feature
- \#216 Add CLI to get server info
@ -86,9 +89,11 @@ Please mark all change in change log and use the issue from GitHub
- \#1320 Remove debug logging from faiss
- \#1426 Support to configure whether to enabled autoflush and the autoflush interval
- \#1444 Improve delete
- \#1448 General proto api for NNS libraries
- \#1480 Add return code for AVX512 selection
- \#1524 Update config "preload_table" description
- \#1537 Optimize raw vector and uids read/write
- \#1544 Update resources name in HTTP module
## Task
- \#1327 Exclude third-party code from codebeat

View File

@ -2,6 +2,7 @@ timeout(time: 60, unit: 'MINUTES') {
dir ("tests/milvus_python_test") {
// sh 'python3 -m pip install -r requirements.txt -i http://pypi.douban.com/simple --trusted-host pypi.douban.com'
sh 'python3 -m pip install -r requirements.txt'
sh 'python3 -m pip install git+https://github.com/BossZou/pymilvus.git@nns'
sh "pytest . --alluredir=\"test_out/dev/single/sqlite\" --level=1 --ip ${env.HELM_RELEASE_NAME}.milvus.svc.cluster.local"
}

View File

@ -14,7 +14,6 @@
#include <string>
#include "server/Config.h"
#include "utils/Log.h"
namespace milvus {
namespace server {

View File

@ -112,18 +112,18 @@ class DB {
virtual Status
QueryByID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, IDNumber vector_id,
ResultIds& result_ids, ResultDistances& result_distances) = 0;
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
IDNumber vector_id, ResultIds& result_ids, ResultDistances& result_distances) = 0;
virtual Status
Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) = 0;
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) = 0;
virtual Status
QueryByFileID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) = 0;
const std::vector<std::string>& file_ids, uint64_t k, const milvus::json& extra_params,
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) = 0;
virtual Status
Size(uint64_t& result) = 0;

View File

@ -371,8 +371,10 @@ DBImpl::PreloadTable(const std::string& table_id) {
} else {
engine_type = (EngineType)file.engine_type_;
}
ExecutionEnginePtr engine = EngineFactory::Build(file.dimension_, file.location_, engine_type,
(MetricType)file.metric_type_, file.nlist_);
auto json = milvus::json::parse(file.index_params_);
ExecutionEnginePtr engine =
EngineFactory::Build(file.dimension_, file.location_, engine_type, (MetricType)file.metric_type_, json);
fiu_do_on("DBImpl.PreloadTable.null_engine", engine = nullptr);
if (engine == nullptr) {
ENGINE_LOG_ERROR << "Invalid engine type";
@ -382,7 +384,7 @@ DBImpl::PreloadTable(const std::string& table_id) {
size += engine->PhysicalSize();
fiu_do_on("DBImpl.PreloadTable.exceed_cache", size = available_size + 1);
if (size > available_size) {
ENGINE_LOG_DEBUG << "Pre-load canceled since cache almost full";
ENGINE_LOG_DEBUG << "Pre-load cancelled since cache is almost full";
return Status(SERVER_CACHE_FULL, "Cache is full");
} else {
try {
@ -810,7 +812,7 @@ DBImpl::CompactFile(const std::string& table_id, const meta::TableFileSchema& fi
// Update table files state
// if index type isn't IDMAP, set file type to TO_INDEX if file size exceed index_file_size
// else set file type to RAW, no need to build index
if (compacted_file.engine_type_ != (int)EngineType::FAISS_IDMAP) {
if (!utils::IsRawIndexType(compacted_file.engine_type_)) {
compacted_file.file_type_ = (segment_writer_ptr->Size() >= compacted_file.index_file_size_)
? meta::TableFileSchema::TO_INDEX
: meta::TableFileSchema::RAW;
@ -1110,8 +1112,8 @@ DBImpl::DropIndex(const std::string& table_id) {
Status
DBImpl::QueryByID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, IDNumber vector_id,
ResultIds& result_ids, ResultDistances& result_distances) {
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
IDNumber vector_id, ResultIds& result_ids, ResultDistances& result_distances) {
if (!initialized_.load(std::memory_order_acquire)) {
return SHUTDOWN_ERROR;
}
@ -1119,14 +1121,15 @@ DBImpl::QueryByID(const std::shared_ptr<server::Context>& context, const std::st
VectorsData vectors_data = VectorsData();
vectors_data.id_array_.emplace_back(vector_id);
vectors_data.vector_count_ = 1;
Status result = Query(context, table_id, partition_tags, k, nprobe, vectors_data, result_ids, result_distances);
Status result =
Query(context, table_id, partition_tags, k, extra_params, vectors_data, result_ids, result_distances);
return result;
}
Status
DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) {
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) {
auto query_ctx = context->Child("Query");
if (!initialized_.load(std::memory_order_acquire)) {
@ -1169,7 +1172,7 @@ DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string
}
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
status = QueryAsync(query_ctx, table_id, files_array, k, nprobe, vectors, result_ids, result_distances);
status = QueryAsync(query_ctx, table_id, files_array, k, extra_params, vectors, result_ids, result_distances);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query
query_ctx->GetTraceContext()->GetSpan()->Finish();
@ -1179,8 +1182,8 @@ DBImpl::Query(const std::shared_ptr<server::Context>& context, const std::string
Status
DBImpl::QueryByFileID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) {
const std::vector<std::string>& file_ids, uint64_t k, const milvus::json& extra_params,
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) {
auto query_ctx = context->Child("Query by file id");
if (!initialized_.load(std::memory_order_acquire)) {
@ -1208,7 +1211,7 @@ DBImpl::QueryByFileID(const std::shared_ptr<server::Context>& context, const std
}
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
status = QueryAsync(query_ctx, table_id, files_array, k, nprobe, vectors, result_ids, result_distances);
status = QueryAsync(query_ctx, table_id, files_array, k, extra_params, vectors, result_ids, result_distances);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query
query_ctx->GetTraceContext()->GetSpan()->Finish();
@ -1230,8 +1233,8 @@ DBImpl::Size(uint64_t& result) {
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Status
DBImpl::QueryAsync(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const meta::TableFilesSchema& files, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) {
const meta::TableFilesSchema& files, uint64_t k, const milvus::json& extra_params,
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) {
auto query_async_ctx = context->Child("Query Async");
server::CollectQueryMetrics metrics(vectors.vector_count_);
@ -1242,7 +1245,7 @@ DBImpl::QueryAsync(const std::shared_ptr<server::Context>& context, const std::s
auto status = OngoingFileChecker::GetInstance().MarkOngoingFiles(files);
ENGINE_LOG_DEBUG << "Engine query begin, index file count: " << files.size();
scheduler::SearchJobPtr job = std::make_shared<scheduler::SearchJob>(query_async_ctx, k, nprobe, vectors);
scheduler::SearchJobPtr job = std::make_shared<scheduler::SearchJob>(query_async_ctx, k, extra_params, vectors);
for (auto& file : files) {
scheduler::TableFileSchemaPtr file_ptr = std::make_shared<meta::TableFileSchema>(file);
job->AddIndexFile(file_ptr);
@ -1465,7 +1468,7 @@ DBImpl::MergeFiles(const std::string& table_id, const meta::TableFilesSchema& fi
// step 4: update table files state
// if index type isn't IDMAP, set file type to TO_INDEX if file size exceed index_file_size
// else set file type to RAW, no need to build index
if (table_file.engine_type_ != (int)EngineType::FAISS_IDMAP) {
if (!utils::IsRawIndexType(table_file.engine_type_)) {
table_file.file_type_ = (segment_writer_ptr->Size() >= table_file.index_file_size_)
? meta::TableFileSchema::TO_INDEX
: meta::TableFileSchema::RAW;
@ -1767,7 +1770,7 @@ DBImpl::BuildTableIndexRecursively(const std::string& table_id, const TableIndex
// for IDMAP type, only wait all NEW file converted to RAW file
// for other type, wait NEW/RAW/NEW_MERGE/NEW_INDEX/TO_INDEX files converted to INDEX files
std::vector<int> file_types;
if (index.engine_type_ == static_cast<int32_t>(EngineType::FAISS_IDMAP)) {
if (utils::IsRawIndexType(index.engine_type_)) {
file_types = {
static_cast<int32_t>(meta::TableFileSchema::NEW),
static_cast<int32_t>(meta::TableFileSchema::NEW_MERGE),
@ -1789,7 +1792,7 @@ DBImpl::BuildTableIndexRecursively(const std::string& table_id, const TableIndex
while (!table_files.empty()) {
ENGINE_LOG_DEBUG << "Non index files detected! Will build index " << times;
if (index.engine_type_ != (int)EngineType::FAISS_IDMAP) {
if (!utils::IsRawIndexType(index.engine_type_)) {
status = meta_ptr_->UpdateTableFilesToIndex(table_id);
}

View File

@ -131,18 +131,18 @@ class DBImpl : public DB, public server::CacheConfigHandler {
Status
QueryByID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, IDNumber vector_id,
ResultIds& result_ids, ResultDistances& result_distances) override;
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
IDNumber vector_id, ResultIds& result_ids, ResultDistances& result_distances) override;
Status
Query(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& partition_tags, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) override;
const std::vector<std::string>& partition_tags, uint64_t k, const milvus::json& extra_params,
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) override;
Status
QueryByFileID(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const std::vector<std::string>& file_ids, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances) override;
const std::vector<std::string>& file_ids, uint64_t k, const milvus::json& extra_params,
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances) override;
Status
Size(uint64_t& result) override;
@ -154,8 +154,8 @@ class DBImpl : public DB, public server::CacheConfigHandler {
private:
Status
QueryAsync(const std::shared_ptr<server::Context>& context, const std::string& table_id,
const meta::TableFilesSchema& files, uint64_t k, uint64_t nprobe, const VectorsData& vectors,
ResultIds& result_ids, ResultDistances& result_distances);
const meta::TableFilesSchema& files, uint64_t k, const milvus::json& extra_params,
const VectorsData& vectors, ResultIds& result_ids, ResultDistances& result_distances);
Status
GetVectorByIdHelper(const std::string& table_id, IDNumber vector_id, VectorsData& vector,

View File

@ -22,6 +22,7 @@
#include "db/engine/ExecutionEngine.h"
#include "segment/Types.h"
#include "utils/Json.h"
namespace milvus {
namespace engine {
@ -35,8 +36,8 @@ typedef std::vector<faiss::Index::distance_t> ResultDistances;
struct TableIndex {
int32_t engine_type_ = (int)EngineType::FAISS_IDMAP;
int32_t nlist_ = 16384;
int32_t metric_type_ = (int)MetricType::L2;
milvus::json extra_params_ = {{"nlist", 16384}};
};
struct VectorsData {

View File

@ -211,10 +211,15 @@ GetParentPath(const std::string& path, std::string& parent_path) {
bool
IsSameIndex(const TableIndex& index1, const TableIndex& index2) {
return index1.engine_type_ == index2.engine_type_ && index1.nlist_ == index2.nlist_ &&
return index1.engine_type_ == index2.engine_type_ && index1.extra_params_ == index2.extra_params_ &&
index1.metric_type_ == index2.metric_type_;
}
bool
IsRawIndexType(int32_t type) {
return (type == (int32_t)EngineType::FAISS_IDMAP) || (type == (int32_t)EngineType::FAISS_BIN_IDMAP);
}
meta::DateT
GetDate(const std::time_t& t, int day_delta) {
struct tm ltm;

View File

@ -45,6 +45,9 @@ GetParentPath(const std::string& path, std::string& parent_path);
bool
IsSameIndex(const TableIndex& index1, const TableIndex& index2);
bool
IsRawIndexType(int32_t type);
meta::DateT
GetDate(const std::time_t& t, int day_delta = 0);
meta::DateT

View File

@ -20,7 +20,7 @@ namespace engine {
ExecutionEnginePtr
EngineFactory::Build(uint16_t dimension, const std::string& location, EngineType index_type, MetricType metric_type,
int32_t nlist) {
const milvus::json& index_params) {
if (index_type == EngineType::INVALID) {
ENGINE_LOG_ERROR << "Unsupported engine type";
return nullptr;
@ -28,7 +28,7 @@ EngineFactory::Build(uint16_t dimension, const std::string& location, EngineType
ENGINE_LOG_DEBUG << "EngineFactory index type: " << (int)index_type;
ExecutionEnginePtr execution_engine_ptr =
std::make_shared<ExecutionEngineImpl>(dimension, location, index_type, metric_type, nlist);
std::make_shared<ExecutionEngineImpl>(dimension, location, index_type, metric_type, index_params);
execution_engine_ptr->Init();
return execution_engine_ptr;

View File

@ -12,6 +12,7 @@
#pragma once
#include "ExecutionEngine.h"
#include "utils/Json.h"
#include "utils/Status.h"
#include <string>
@ -23,7 +24,7 @@ class EngineFactory {
public:
static ExecutionEnginePtr
Build(uint16_t dimension, const std::string& location, EngineType index_type, MetricType metric_type,
int32_t nlist);
const milvus::json& index_params);
};
} // namespace engine

View File

@ -15,6 +15,7 @@
#include <string>
#include <vector>
#include "utils/Json.h"
#include "utils/Status.h"
namespace milvus {
@ -94,15 +95,16 @@ class ExecutionEngine {
GetVectorByID(const int64_t& id, uint8_t* vector, bool hybrid) = 0;
virtual Status
Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels, bool hybrid) = 0;
virtual Status
Search(int64_t n, const uint8_t* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
Search(int64_t n, const float* data, int64_t k, const milvus::json& extra_params, float* distances, int64_t* labels,
bool hybrid) = 0;
virtual Status
Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
bool hybrid) = 0;
Search(int64_t n, const uint8_t* data, int64_t k, const milvus::json& extra_params, float* distances,
int64_t* labels, bool hybrid) = 0;
virtual Status
Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, const milvus::json& extra_params, float* distances,
int64_t* labels, bool hybrid) = 0;
virtual std::shared_ptr<ExecutionEngine>
BuildIndex(const std::string& location, EngineType engine_type) = 0;

View File

@ -43,22 +43,22 @@ namespace engine {
namespace {
Status
MappingMetricType(MetricType metric_type, knowhere::METRICTYPE& kw_type) {
MappingMetricType(MetricType metric_type, milvus::json& conf) {
switch (metric_type) {
case MetricType::IP:
kw_type = knowhere::METRICTYPE::IP;
conf[knowhere::Metric::TYPE] = knowhere::Metric::IP;
break;
case MetricType::L2:
kw_type = knowhere::METRICTYPE::L2;
conf[knowhere::Metric::TYPE] = knowhere::Metric::L2;
break;
case MetricType::HAMMING:
kw_type = knowhere::METRICTYPE::HAMMING;
conf[knowhere::Metric::TYPE] = knowhere::Metric::HAMMING;
break;
case MetricType::JACCARD:
kw_type = knowhere::METRICTYPE::JACCARD;
conf[knowhere::Metric::TYPE] = knowhere::Metric::JACCARD;
break;
case MetricType::TANIMOTO:
kw_type = knowhere::METRICTYPE::TANIMOTO;
conf[knowhere::Metric::TYPE] = knowhere::Metric::TANIMOTO;
break;
default:
return Status(DB_ERROR, "Unsupported metric type");
@ -94,8 +94,12 @@ class CachedQuantizer : public cache::DataObj {
};
ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string& location, EngineType index_type,
MetricType metric_type, int32_t nlist)
: location_(location), dim_(dimension), index_type_(index_type), metric_type_(metric_type), nlist_(nlist) {
MetricType metric_type, const milvus::json& index_params)
: location_(location),
dim_(dimension),
index_type_(index_type),
metric_type_(metric_type),
index_params_(index_params) {
EngineType tmp_index_type = server::ValidationUtil::IsBinaryMetricType((int32_t)metric_type)
? EngineType::FAISS_BIN_IDMAP
: EngineType::FAISS_IDMAP;
@ -104,16 +108,15 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string&
throw Exception(DB_ERROR, "Unsupported index type");
}
TempMetaConf temp_conf;
temp_conf.gpu_id = gpu_num_;
temp_conf.dim = dimension;
auto status = MappingMetricType(metric_type, temp_conf.metric_type);
if (!status.ok()) {
throw Exception(DB_ERROR, status.message());
}
milvus::json conf = index_params;
conf[knowhere::meta::DEVICEID] = gpu_num_;
conf[knowhere::meta::DIM] = dimension;
MappingMetricType(metric_type, conf);
ENGINE_LOG_DEBUG << "Index params: " << conf.dump();
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->Match(temp_conf);
if (!adapter->CheckTrain(conf)) {
throw Exception(DB_ERROR, "Illegal index params");
}
ErrorCode ec = KNOWHERE_UNEXPECTED_ERROR;
if (auto bf_index = std::dynamic_pointer_cast<BFIndex>(index_)) {
@ -127,8 +130,12 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string&
}
ExecutionEngineImpl::ExecutionEngineImpl(VecIndexPtr index, const std::string& location, EngineType index_type,
MetricType metric_type, int32_t nlist)
: index_(std::move(index)), location_(location), index_type_(index_type), metric_type_(metric_type), nlist_(nlist) {
MetricType metric_type, const milvus::json& index_params)
: index_(std::move(index)),
location_(location),
index_type_(index_type),
metric_type_(metric_type),
index_params_(index_params) {
}
VecIndexPtr
@ -273,10 +280,9 @@ ExecutionEngineImpl::HybridLoad() const {
auto best_index = std::distance(all_free_mem.begin(), max_e);
auto best_device_id = gpus[best_index];
auto quantizer_conf = std::make_shared<knowhere::QuantizerCfg>();
quantizer_conf->mode = 1;
quantizer_conf->gpu_id = best_device_id;
milvus::json quantizer_conf{{knowhere::meta::DEVICEID, best_device_id}, {"mode", 1}};
auto quantizer = index_->LoadQuantizer(quantizer_conf);
ENGINE_LOG_DEBUG << "Quantizer params: " << quantizer_conf.dump();
if (quantizer == nullptr) {
ENGINE_LOG_ERROR << "quantizer is nullptr";
}
@ -400,22 +406,18 @@ ExecutionEngineImpl::Load(bool to_cache) {
utils::GetParentPath(location_, segment_dir);
auto segment_reader_ptr = std::make_shared<segment::SegmentReader>(segment_dir);
if (index_type_ == EngineType::FAISS_IDMAP || index_type_ == EngineType::FAISS_BIN_IDMAP) {
if (utils::IsRawIndexType((int32_t)index_type_)) {
index_ = index_type_ == EngineType::FAISS_IDMAP ? GetVecIndexFactory(IndexType::FAISS_IDMAP)
: GetVecIndexFactory(IndexType::FAISS_BIN_IDMAP);
TempMetaConf temp_conf;
temp_conf.gpu_id = gpu_num_;
temp_conf.dim = dim_;
auto status = MappingMetricType(metric_type_, temp_conf.metric_type);
if (!status.ok()) {
return status;
milvus::json conf{{knowhere::meta::DEVICEID, gpu_num_}, {knowhere::meta::DIM, dim_}};
MappingMetricType(metric_type_, conf);
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
ENGINE_LOG_DEBUG << "Index params: " << conf.dump();
if (!adapter->CheckTrain(conf)) {
throw Exception(DB_ERROR, "Illegal index params");
}
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->Match(temp_conf);
status = segment_reader_ptr->Load();
auto status = segment_reader_ptr->Load();
if (!status.ok()) {
std::string msg = "Failed to load segment from " + location_;
ENGINE_LOG_ERROR << msg;
@ -429,6 +431,7 @@ ExecutionEngineImpl::Load(bool to_cache) {
auto vectors_uids = vectors->GetUids();
index_->SetUids(vectors_uids);
ENGINE_LOG_DEBUG << "set uids " << index_->GetUids().size() << " for index " << location_;
auto vectors_data = vectors->GetData();
@ -453,7 +456,7 @@ ExecutionEngineImpl::Load(bool to_cache) {
float_vectors.data(), Config());
status = std::static_pointer_cast<BFIndex>(index_)->SetBlacklist(concurrent_bitset_ptr);
int64_t index_size = vectors->GetCount() * conf->d * sizeof(float);
int64_t index_size = vectors->GetCount() * dim_ * sizeof(float);
int64_t bitset_size = vectors->GetCount() / 8;
index_->set_size(index_size + bitset_size);
} else if (index_type_ == EngineType::FAISS_BIN_IDMAP) {
@ -465,7 +468,7 @@ ExecutionEngineImpl::Load(bool to_cache) {
vectors_data.data(), Config());
status = std::static_pointer_cast<BinBFIndex>(index_)->SetBlacklist(concurrent_bitset_ptr);
int64_t index_size = vectors->GetCount() * conf->d * sizeof(uint8_t);
int64_t index_size = vectors->GetCount() * dim_ * sizeof(uint8_t);
int64_t bitset_size = vectors->GetCount() / 8;
index_->set_size(index_size + bitset_size);
}
@ -508,6 +511,7 @@ ExecutionEngineImpl::Load(bool to_cache) {
std::vector<segment::doc_id_t> uids;
segment_reader_ptr->LoadUids(uids);
index_->SetUids(uids);
ENGINE_LOG_DEBUG << "set uids " << index_->GetUids().size() << " for index " << location_;
ENGINE_LOG_DEBUG << "Finished loading index file from segment " << segment_dir;
}
@ -548,9 +552,7 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id, bool hybrid) {
if (device_id != NOT_FOUND) {
// cache hit
auto config = std::make_shared<knowhere::QuantizerCfg>();
config->gpu_id = device_id;
config->mode = 2;
milvus::json quantizer_conf{{knowhere::meta::DEVICEID : device_id}, {"mode" : 2}};
auto new_index = index_->LoadData(quantizer, config);
index_ = new_index;
}
@ -723,30 +725,36 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t
throw Exception(DB_ERROR, "Unsupported index type");
}
TempMetaConf temp_conf;
temp_conf.gpu_id = gpu_num_;
temp_conf.dim = Dimension();
temp_conf.nlist = nlist_;
temp_conf.size = Count();
auto status = MappingMetricType(metric_type_, temp_conf.metric_type);
if (!status.ok()) {
throw Exception(DB_ERROR, status.message());
}
milvus::json conf = index_params_;
conf[knowhere::meta::DIM] = Dimension();
conf[knowhere::meta::ROWS] = Count();
conf[knowhere::meta::DEVICEID] = gpu_num_;
MappingMetricType(metric_type_, conf);
ENGINE_LOG_DEBUG << "Index params: " << conf.dump();
auto adapter = AdapterMgr::GetInstance().GetAdapter(to_index->GetType());
auto conf = adapter->Match(temp_conf);
if (!adapter->CheckTrain(conf)) {
throw Exception(DB_ERROR, "Illegal index params");
}
ENGINE_LOG_DEBUG << "Index config: " << conf.dump();
auto status = Status::OK();
std::vector<segment::doc_id_t> uids;
if (from_index) {
status = to_index->BuildAll(Count(), from_index->GetRawVectors(), from_index->GetRawIds(), conf);
uids = from_index->GetUids();
} else if (bin_from_index) {
status = to_index->BuildAll(Count(), bin_from_index->GetRawVectors(), bin_from_index->GetRawIds(), conf);
uids = bin_from_index->GetUids();
}
to_index->SetUids(uids);
ENGINE_LOG_DEBUG << "set uids " << to_index->GetUids().size() << " for " << location;
if (!status.ok()) {
throw Exception(DB_ERROR, status.message());
}
ENGINE_LOG_DEBUG << "Finish build index file: " << location << " size: " << to_index->Size();
return std::make_shared<ExecutionEngineImpl>(to_index, location, engine_type, metric_type_, nlist_);
return std::make_shared<ExecutionEngineImpl>(to_index, location, engine_type, metric_type_, index_params_);
}
// map offsets to ids
@ -761,8 +769,8 @@ MapUids(const std::vector<segment::doc_id_t>& uids, int64_t* labels, size_t num)
}
Status
ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
bool hybrid) {
ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, const milvus::json& extra_params, float* distances,
int64_t* labels, bool hybrid) {
#if 0
if (index_type_ == EngineType::FAISS_IVFSQ8H) {
if (!hybrid) {
@ -786,9 +794,7 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
if (device_id != NOT_FOUND) {
// cache hit
auto config = std::make_shared<knowhere::QuantizerCfg>();
config->gpu_id = device_id;
config->mode = 2;
milvus::json quantizer_conf{{knowhere::meta::DEVICEID : device_id}, {"mode" : 2}};
auto new_index = index_->LoadData(quantizer, config);
index_ = new_index;
}
@ -824,15 +830,13 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
return Status(DB_ERROR, "index is null");
}
ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe;
// TODO(linxj): remove here. Get conf from function
TempMetaConf temp_conf;
temp_conf.k = k;
temp_conf.nprobe = nprobe;
milvus::json conf = extra_params;
conf[knowhere::meta::TOPK] = k;
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
ENGINE_LOG_DEBUG << "Search params: " << conf.dump();
if (!adapter->CheckSearch(conf, index_->GetType())) {
throw Exception(DB_ERROR, "Illegal search params");
}
if (hybrid) {
HybridLoad();
@ -843,6 +847,7 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
rc.RecordSection("search done");
// map offsets to ids
ENGINE_LOG_DEBUG << "get uids " << index_->GetUids().size() << " from index " << location_;
MapUids(index_->GetUids(), labels, n * k);
rc.RecordSection("map uids " + std::to_string(n * k));
@ -858,8 +863,8 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
}
Status
ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, int64_t nprobe, float* distances,
int64_t* labels, bool hybrid) {
ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, const milvus::json& extra_params,
float* distances, int64_t* labels, bool hybrid) {
TimeRecorder rc("ExecutionEngineImpl::Search uint8");
if (index_ == nullptr) {
@ -867,15 +872,13 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, int64_t n
return Status(DB_ERROR, "index is null");
}
ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe;
// TODO(linxj): remove here. Get conf from function
TempMetaConf temp_conf;
temp_conf.k = k;
temp_conf.nprobe = nprobe;
milvus::json conf = extra_params;
conf[knowhere::meta::TOPK] = k;
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
ENGINE_LOG_DEBUG << "Search params: " << conf.dump();
if (!adapter->CheckSearch(conf, index_->GetType())) {
throw Exception(DB_ERROR, "Illegal search params");
}
if (hybrid) {
HybridLoad();
@ -886,6 +889,7 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, int64_t n
rc.RecordSection("search done");
// map offsets to ids
ENGINE_LOG_DEBUG << "get uids " << index_->GetUids().size() << " from index " << location_;
MapUids(index_->GetUids(), labels, n * k);
rc.RecordSection("map uids " + std::to_string(n * k));
@ -901,8 +905,8 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, int64_t n
}
Status
ExecutionEngineImpl::Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, int64_t nprobe, float* distances,
int64_t* labels, bool hybrid) {
ExecutionEngineImpl::Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, const milvus::json& extra_params,
float* distances, int64_t* labels, bool hybrid) {
TimeRecorder rc("ExecutionEngineImpl::Search vector of ids");
if (index_ == nullptr) {
@ -910,15 +914,13 @@ ExecutionEngineImpl::Search(int64_t n, const std::vector<int64_t>& ids, int64_t
return Status(DB_ERROR, "index is null");
}
ENGINE_LOG_DEBUG << "Search by ids Params: [k] " << k << " [nprobe] " << nprobe;
// TODO(linxj): remove here. Get conf from function
TempMetaConf temp_conf;
temp_conf.k = k;
temp_conf.nprobe = nprobe;
milvus::json conf = extra_params;
conf[knowhere::meta::TOPK] = k;
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
ENGINE_LOG_DEBUG << "Search params: " << conf.dump();
if (!adapter->CheckSearch(conf, index_->GetType())) {
throw Exception(DB_ERROR, "Illegal search params");
}
if (hybrid) {
HybridLoad();
@ -971,6 +973,7 @@ ExecutionEngineImpl::Search(int64_t n, const std::vector<int64_t>& ids, int64_t
rc.RecordSection("search done");
// map offsets to ids
ENGINE_LOG_DEBUG << "get uids " << index_->GetUids().size() << " from index " << location_;
MapUids(uids, labels, offsets.size() * k);
rc.RecordSection("map uids " + std::to_string(offsets.size() * k));
@ -993,19 +996,13 @@ ExecutionEngineImpl::GetVectorByID(const int64_t& id, float* vector, bool hybrid
return Status(DB_ERROR, "index is null");
}
// TODO(linxj): remove here. Get conf from function
TempMetaConf temp_conf;
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
if (hybrid) {
HybridLoad();
}
// Only one id for now
std::vector<int64_t> ids{id};
auto status = index_->GetVectorById(1, ids.data(), vector, conf);
auto status = index_->GetVectorById(1, ids.data(), vector, milvus::json());
if (hybrid) {
HybridUnset();
@ -1026,19 +1023,13 @@ ExecutionEngineImpl::GetVectorByID(const int64_t& id, uint8_t* vector, bool hybr
ENGINE_LOG_DEBUG << "Get binary vector by id: " << id;
// TODO(linxj): remove here. Get conf from function
TempMetaConf temp_conf;
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
if (hybrid) {
HybridLoad();
}
// Only one id for now
std::vector<int64_t> ids{id};
auto status = index_->GetVectorById(1, ids.data(), vector, conf);
auto status = index_->GetVectorById(1, ids.data(), vector, milvus::json());
if (hybrid) {
HybridUnset();
@ -1075,7 +1066,7 @@ ExecutionEngineImpl::Init() {
std::vector<int64_t> gpu_ids;
Status s = config.GetGpuResourceConfigBuildIndexResources(gpu_ids);
if (!s.ok()) {
gpu_num_ = knowhere::INVALID_VALUE;
gpu_num_ = -1;
return s;
}
for (auto id : gpu_ids) {

View File

@ -11,7 +11,8 @@
#pragma once
#include <src/segment/SegmentReader.h>
#include "segment/SegmentReader.h"
#include "utils/Json.h"
#include <memory>
#include <string>
@ -26,10 +27,10 @@ namespace engine {
class ExecutionEngineImpl : public ExecutionEngine {
public:
ExecutionEngineImpl(uint16_t dimension, const std::string& location, EngineType index_type, MetricType metric_type,
int32_t nlist);
const milvus::json& index_params);
ExecutionEngineImpl(VecIndexPtr index, const std::string& location, EngineType index_type, MetricType metric_type,
int32_t nlist);
const milvus::json& index_params);
Status
AddWithIds(int64_t n, const float* xdata, const int64_t* xids) override;
@ -77,16 +78,16 @@ class ExecutionEngineImpl : public ExecutionEngine {
GetVectorByID(const int64_t& id, uint8_t* vector, bool hybrid) override;
Status
Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
Search(int64_t n, const float* data, int64_t k, const milvus::json& extra_params, float* distances, int64_t* labels,
bool hybrid = false) override;
Status
Search(int64_t n, const uint8_t* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
bool hybrid = false) override;
Search(int64_t n, const uint8_t* data, int64_t k, const milvus::json& extra_params, float* distances,
int64_t* labels, bool hybrid = false) override;
Status
Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, int64_t nprobe, float* distances, int64_t* labels,
bool hybrid) override;
Search(int64_t n, const std::vector<int64_t>& ids, int64_t k, const milvus::json& extra_params, float* distances,
int64_t* labels, bool hybrid) override;
ExecutionEnginePtr
BuildIndex(const std::string& location, EngineType engine_type) override;
@ -136,7 +137,7 @@ class ExecutionEngineImpl : public ExecutionEngine {
int64_t dim_;
std::string location_;
int64_t nlist_ = 0;
milvus::json index_params_;
int64_t gpu_num_ = 0;
};

View File

@ -54,7 +54,7 @@ struct TableSchema {
int64_t flag_ = 0;
int64_t index_file_size_ = DEFAULT_INDEX_FILE_SIZE;
int32_t engine_type_ = DEFAULT_ENGINE_TYPE;
int32_t nlist_ = DEFAULT_NLIST;
std::string index_params_ = "{ \"nlist\": 16384 }";
int32_t metric_type_ = DEFAULT_METRIC_TYPE;
std::string owner_table_;
std::string partition_tag_;
@ -89,7 +89,7 @@ struct TableFileSchema {
int64_t created_on_ = 0;
int64_t index_file_size_ = DEFAULT_INDEX_FILE_SIZE; // not persist to meta
int32_t engine_type_ = DEFAULT_ENGINE_TYPE;
int32_t nlist_ = DEFAULT_NLIST; // not persist to meta
std::string index_params_; // not persist to meta
int32_t metric_type_ = DEFAULT_METRIC_TYPE; // not persist to meta
uint64_t flush_lsn_ = 0;
}; // TableFileSchema

View File

@ -144,7 +144,7 @@ static const MetaSchema TABLES_SCHEMA(META_TABLES, {
MetaField("flag", "BIGINT", "DEFAULT 0 NOT NULL"),
MetaField("index_file_size", "BIGINT", "DEFAULT 1024 NOT NULL"),
MetaField("engine_type", "INT", "DEFAULT 1 NOT NULL"),
MetaField("nlist", "INT", "DEFAULT 16384 NOT NULL"),
MetaField("index_params", "VARCHAR(512)", "NOT NULL"),
MetaField("metric_type", "INT", "DEFAULT 1 NOT NULL"),
MetaField("owner_table", "VARCHAR(255)", "NOT NULL"),
MetaField("partition_tag", "VARCHAR(255)", "NOT NULL"),
@ -398,7 +398,7 @@ MySQLMetaImpl::CreateTable(TableSchema& table_schema) {
std::string flag = std::to_string(table_schema.flag_);
std::string index_file_size = std::to_string(table_schema.index_file_size_);
std::string engine_type = std::to_string(table_schema.engine_type_);
std::string nlist = std::to_string(table_schema.nlist_);
std::string& index_params = table_schema.index_params_;
std::string metric_type = std::to_string(table_schema.metric_type_);
std::string& owner_table = table_schema.owner_table_;
std::string& partition_tag = table_schema.partition_tag_;
@ -407,9 +407,9 @@ MySQLMetaImpl::CreateTable(TableSchema& table_schema) {
createTableQuery << "INSERT INTO " << META_TABLES << " VALUES(" << id << ", " << mysqlpp::quote << table_id
<< ", " << state << ", " << dimension << ", " << created_on << ", " << flag << ", "
<< index_file_size << ", " << engine_type << ", " << nlist << ", " << metric_type << ", "
<< mysqlpp::quote << owner_table << ", " << mysqlpp::quote << partition_tag << ", "
<< mysqlpp::quote << version << ", " << flush_lsn << ");";
<< index_file_size << ", " << engine_type << ", " << mysqlpp::quote << index_params << ", "
<< metric_type << ", " << mysqlpp::quote << owner_table << ", " << mysqlpp::quote
<< partition_tag << ", " << mysqlpp::quote << version << ", " << flush_lsn << ");";
ENGINE_LOG_DEBUG << "MySQLMetaImpl::CreateTable: " << createTableQuery.str();
@ -446,8 +446,8 @@ MySQLMetaImpl::DescribeTable(TableSchema& table_schema) {
mysqlpp::Query describeTableQuery = connectionPtr->query();
describeTableQuery
<< "SELECT id, state, dimension, created_on, flag, index_file_size, engine_type, nlist, metric_type"
<< " ,owner_table, partition_tag, version, flush_lsn"
<< "SELECT id, state, dimension, created_on, flag, index_file_size, engine_type, index_params"
<< " , metric_type ,owner_table, partition_tag, version, flush_lsn"
<< " FROM " << META_TABLES << " WHERE table_id = " << mysqlpp::quote << table_schema.table_id_
<< " AND state <> " << std::to_string(TableSchema::TO_DELETE) << ";";
@ -465,7 +465,7 @@ MySQLMetaImpl::DescribeTable(TableSchema& table_schema) {
table_schema.flag_ = resRow["flag"];
table_schema.index_file_size_ = resRow["index_file_size"];
table_schema.engine_type_ = resRow["engine_type"];
table_schema.nlist_ = resRow["nlist"];
resRow["index_params"].to_string(table_schema.index_params_);
table_schema.metric_type_ = resRow["metric_type"];
resRow["owner_table"].to_string(table_schema.owner_table_);
resRow["partition_tag"].to_string(table_schema.partition_tag_);
@ -534,7 +534,7 @@ MySQLMetaImpl::AllTables(std::vector<TableSchema>& table_schema_array) {
}
mysqlpp::Query allTablesQuery = connectionPtr->query();
allTablesQuery << "SELECT id, table_id, dimension, engine_type, nlist, index_file_size, metric_type"
allTablesQuery << "SELECT id, table_id, dimension, engine_type, index_params, index_file_size, metric_type"
<< " ,owner_table, partition_tag, version, flush_lsn"
<< " FROM " << META_TABLES << " WHERE state <> " << std::to_string(TableSchema::TO_DELETE)
<< " AND owner_table = \"\";";
@ -551,7 +551,7 @@ MySQLMetaImpl::AllTables(std::vector<TableSchema>& table_schema_array) {
table_schema.dimension_ = resRow["dimension"];
table_schema.index_file_size_ = resRow["index_file_size"];
table_schema.engine_type_ = resRow["engine_type"];
table_schema.nlist_ = resRow["nlist"];
resRow["index_params"].to_string(table_schema.index_params_);
table_schema.metric_type_ = resRow["metric_type"];
resRow["owner_table"].to_string(table_schema.owner_table_);
resRow["partition_tag"].to_string(table_schema.partition_tag_);
@ -673,17 +673,8 @@ MySQLMetaImpl::CreateTableFile(TableFileSchema& file_schema) {
file_schema.created_on_ = utils::GetMicroSecTimeStamp();
file_schema.updated_time_ = file_schema.created_on_;
file_schema.index_file_size_ = table_schema.index_file_size_;
if (file_schema.file_type_ == TableFileSchema::FILE_TYPE::NEW ||
file_schema.file_type_ == TableFileSchema::FILE_TYPE::NEW_MERGE) {
file_schema.engine_type_ = server::ValidationUtil::IsBinaryMetricType(table_schema.metric_type_)
? (int32_t)EngineType::FAISS_BIN_IDMAP
: (int32_t)EngineType::FAISS_IDMAP;
} else {
file_schema.engine_type_ = table_schema.engine_type_;
}
file_schema.nlist_ = table_schema.nlist_;
file_schema.index_params_ = table_schema.index_params_;
file_schema.engine_type_ = table_schema.engine_type_;
file_schema.metric_type_ = table_schema.metric_type_;
std::string id = "NULL"; // auto-increment
@ -785,7 +776,7 @@ MySQLMetaImpl::GetTableFiles(const std::string& table_id, const std::vector<size
resRow["segment_id"].to_string(file_schema.segment_id_);
file_schema.index_file_size_ = table_schema.index_file_size_;
file_schema.engine_type_ = resRow["engine_type"];
file_schema.nlist_ = table_schema.nlist_;
file_schema.index_params_ = table_schema.index_params_;
file_schema.metric_type_ = table_schema.metric_type_;
resRow["file_id"].to_string(file_schema.file_id_);
file_schema.file_type_ = resRow["file_type"];
@ -844,7 +835,7 @@ MySQLMetaImpl::GetTableFilesBySegmentId(const std::string& segment_id,
resRow["segment_id"].to_string(file_schema.segment_id_);
file_schema.index_file_size_ = table_schema.index_file_size_;
file_schema.engine_type_ = resRow["engine_type"];
file_schema.nlist_ = table_schema.nlist_;
file_schema.index_params_ = table_schema.index_params_;
file_schema.metric_type_ = table_schema.metric_type_;
resRow["file_id"].to_string(file_schema.file_id_);
file_schema.file_type_ = resRow["file_type"];
@ -900,7 +891,8 @@ MySQLMetaImpl::UpdateTableIndex(const std::string& table_id, const TableIndex& i
updateTableIndexParamQuery << "UPDATE " << META_TABLES << " SET id = " << id << " ,state = " << state
<< " ,dimension = " << dimension << " ,created_on = " << created_on
<< " ,engine_type = " << index.engine_type_ << " ,nlist = " << index.nlist_
<< " ,engine_type = " << index.engine_type_
<< " ,index_params = " << mysqlpp::quote << index.extra_params_.dump()
<< " ,metric_type = " << index.metric_type_
<< " WHERE table_id = " << mysqlpp::quote << table_id << ";";
@ -1044,7 +1036,7 @@ MySQLMetaImpl::GetTableFilesByFlushLSN(uint64_t flush_lsn, TableFilesSchema& tab
}
table_file.dimension_ = groups[table_file.table_id_].dimension_;
table_file.index_file_size_ = groups[table_file.table_id_].index_file_size_;
table_file.nlist_ = groups[table_file.table_id_].nlist_;
table_file.index_params_ = groups[table_file.table_id_].index_params_;
table_file.metric_type_ = groups[table_file.table_id_].metric_type_;
auto status = utils::GetTableFilePath(options_, table_file);
@ -1263,7 +1255,7 @@ MySQLMetaImpl::DescribeTableIndex(const std::string& table_id, TableIndex& index
}
mysqlpp::Query describeTableIndexQuery = connectionPtr->query();
describeTableIndexQuery << "SELECT engine_type, nlist, index_file_size, metric_type"
describeTableIndexQuery << "SELECT engine_type, index_params, index_file_size, metric_type"
<< " FROM " << META_TABLES << " WHERE table_id = " << mysqlpp::quote << table_id
<< " AND state <> " << std::to_string(TableSchema::TO_DELETE) << ";";
@ -1275,7 +1267,9 @@ MySQLMetaImpl::DescribeTableIndex(const std::string& table_id, TableIndex& index
const mysqlpp::Row& resRow = res[0];
index.engine_type_ = resRow["engine_type"];
index.nlist_ = resRow["nlist"];
std::string str_index_params;
resRow["index_params"].to_string(str_index_params);
index.extra_params_ = milvus::json::parse(str_index_params);
index.metric_type_ = resRow["metric_type"];
} else {
return Status(DB_NOT_FOUND, "Table " + table_id + " not found");
@ -1334,7 +1328,7 @@ MySQLMetaImpl::DropTableIndex(const std::string& table_id) {
// set table index type to raw
dropTableIndexQuery << "UPDATE " << META_TABLES
<< " SET engine_type = " << std::to_string(DEFAULT_ENGINE_TYPE)
<< " ,nlist = " << std::to_string(DEFAULT_NLIST)
<< " , index_params = '{}'"
<< " WHERE table_id = " << mysqlpp::quote << table_id << ";";
ENGINE_LOG_DEBUG << "MySQLMetaImpl::DropTableIndex: " << dropTableIndexQuery.str();
@ -1426,7 +1420,7 @@ MySQLMetaImpl::ShowPartitions(const std::string& table_id, std::vector<meta::Tab
mysqlpp::Query allPartitionsQuery = connectionPtr->query();
allPartitionsQuery << "SELECT table_id, id, state, dimension, created_on, flag, index_file_size,"
<< " engine_type, nlist, metric_type, partition_tag, version FROM " << META_TABLES
<< " engine_type, index_params, metric_type, partition_tag, version FROM " << META_TABLES
<< " WHERE owner_table = " << mysqlpp::quote << table_id << " AND state <> "
<< std::to_string(TableSchema::TO_DELETE) << ";";
@ -1445,7 +1439,7 @@ MySQLMetaImpl::ShowPartitions(const std::string& table_id, std::vector<meta::Tab
partition_schema.flag_ = resRow["flag"];
partition_schema.index_file_size_ = resRow["index_file_size"];
partition_schema.engine_type_ = resRow["engine_type"];
partition_schema.nlist_ = resRow["nlist"];
resRow["index_params"].to_string(partition_schema.index_params_);
partition_schema.metric_type_ = resRow["metric_type"];
partition_schema.owner_table_ = table_id;
resRow["partition_tag"].to_string(partition_schema.partition_tag_);
@ -1562,7 +1556,7 @@ MySQLMetaImpl::FilesToSearch(const std::string& table_id, const std::vector<size
resRow["segment_id"].to_string(table_file.segment_id_);
table_file.index_file_size_ = table_schema.index_file_size_;
table_file.engine_type_ = resRow["engine_type"];
table_file.nlist_ = table_schema.nlist_;
table_file.index_params_ = table_schema.index_params_;
table_file.metric_type_ = table_schema.metric_type_;
resRow["file_id"].to_string(table_file.file_id_);
table_file.file_type_ = resRow["file_type"];
@ -1644,7 +1638,7 @@ MySQLMetaImpl::FilesToMerge(const std::string& table_id, TableFilesSchema& files
table_file.date_ = resRow["date"];
table_file.index_file_size_ = table_schema.index_file_size_;
table_file.engine_type_ = resRow["engine_type"];
table_file.nlist_ = table_schema.nlist_;
table_file.index_params_ = table_schema.index_params_;
table_file.metric_type_ = table_schema.metric_type_;
table_file.created_on_ = resRow["created_on"];
table_file.dimension_ = table_schema.dimension_;
@ -1722,7 +1716,7 @@ MySQLMetaImpl::FilesToIndex(TableFilesSchema& files) {
}
table_file.dimension_ = groups[table_file.table_id_].dimension_;
table_file.index_file_size_ = groups[table_file.table_id_].index_file_size_;
table_file.nlist_ = groups[table_file.table_id_].nlist_;
table_file.index_params_ = groups[table_file.table_id_].index_params_;
table_file.metric_type_ = groups[table_file.table_id_].metric_type_;
auto status = utils::GetTableFilePath(options_, table_file);
@ -1809,7 +1803,7 @@ MySQLMetaImpl::FilesByType(const std::string& table_id, const std::vector<int>&
file_schema.created_on_ = resRow["created_on"];
file_schema.index_file_size_ = table_schema.index_file_size_;
file_schema.nlist_ = table_schema.nlist_;
file_schema.index_params_ = table_schema.index_params_;
file_schema.metric_type_ = table_schema.metric_type_;
file_schema.dimension_ = table_schema.dimension_;
@ -2083,8 +2077,7 @@ MySQLMetaImpl::CleanUpFilesWithTTL(uint64_t seconds /*, CleanUpFilter* filter*/)
// If we are deleting a raw table file, it means it's okay to delete the entire segment directory.
// Else, we can only delete the single file
// TODO(zhiru): We determine whether a table file is raw by its engine type. This is a bit hacky
if (table_file.engine_type_ == (int32_t)EngineType::FAISS_IDMAP ||
table_file.engine_type_ == (int32_t)EngineType::FAISS_BIN_IDMAP) {
if (utils::IsRawIndexType(table_file.engine_type_)) {
utils::DeleteSegment(options_, table_file);
std::string segment_dir;
utils::GetParentPath(table_file.location_, segment_dir);

View File

@ -68,7 +68,8 @@ StoragePrototype(const std::string& path) {
make_column("created_on", &TableSchema::created_on_),
make_column("flag", &TableSchema::flag_, default_value(0)),
make_column("index_file_size", &TableSchema::index_file_size_),
make_column("engine_type", &TableSchema::engine_type_), make_column("nlist", &TableSchema::nlist_),
make_column("engine_type", &TableSchema::engine_type_),
make_column("index_params", &TableSchema::index_params_),
make_column("metric_type", &TableSchema::metric_type_),
make_column("owner_table", &TableSchema::owner_table_, default_value("")),
make_column("partition_tag", &TableSchema::partition_tag_, default_value("")),
@ -213,7 +214,7 @@ SqliteMetaImpl::DescribeTable(TableSchema& table_schema) {
auto groups = ConnectorPtr->select(
columns(&TableSchema::id_, &TableSchema::state_, &TableSchema::dimension_, &TableSchema::created_on_,
&TableSchema::flag_, &TableSchema::index_file_size_, &TableSchema::engine_type_,
&TableSchema::nlist_, &TableSchema::metric_type_, &TableSchema::owner_table_,
&TableSchema::index_params_, &TableSchema::metric_type_, &TableSchema::owner_table_,
&TableSchema::partition_tag_, &TableSchema::version_, &TableSchema::flush_lsn_),
where(c(&TableSchema::table_id_) == table_schema.table_id_ and
c(&TableSchema::state_) != (int)TableSchema::TO_DELETE));
@ -226,7 +227,7 @@ SqliteMetaImpl::DescribeTable(TableSchema& table_schema) {
table_schema.flag_ = std::get<4>(groups[0]);
table_schema.index_file_size_ = std::get<5>(groups[0]);
table_schema.engine_type_ = std::get<6>(groups[0]);
table_schema.nlist_ = std::get<7>(groups[0]);
table_schema.index_params_ = std::get<7>(groups[0]);
table_schema.metric_type_ = std::get<8>(groups[0]);
table_schema.owner_table_ = std::get<9>(groups[0]);
table_schema.partition_tag_ = std::get<10>(groups[0]);
@ -272,7 +273,7 @@ SqliteMetaImpl::AllTables(std::vector<TableSchema>& table_schema_array) {
auto selected = ConnectorPtr->select(
columns(&TableSchema::id_, &TableSchema::table_id_, &TableSchema::dimension_, &TableSchema::created_on_,
&TableSchema::flag_, &TableSchema::index_file_size_, &TableSchema::engine_type_,
&TableSchema::nlist_, &TableSchema::metric_type_, &TableSchema::owner_table_,
&TableSchema::index_params_, &TableSchema::metric_type_, &TableSchema::owner_table_,
&TableSchema::partition_tag_, &TableSchema::version_, &TableSchema::flush_lsn_),
where(c(&TableSchema::state_) != (int)TableSchema::TO_DELETE and c(&TableSchema::owner_table_) == ""));
for (auto& table : selected) {
@ -284,7 +285,7 @@ SqliteMetaImpl::AllTables(std::vector<TableSchema>& table_schema_array) {
schema.flag_ = std::get<4>(table);
schema.index_file_size_ = std::get<5>(table);
schema.engine_type_ = std::get<6>(table);
schema.nlist_ = std::get<7>(table);
schema.index_params_ = std::get<7>(table);
schema.metric_type_ = std::get<8>(table);
schema.owner_table_ = std::get<9>(table);
schema.partition_tag_ = std::get<10>(table);
@ -373,17 +374,8 @@ SqliteMetaImpl::CreateTableFile(TableFileSchema& file_schema) {
file_schema.created_on_ = utils::GetMicroSecTimeStamp();
file_schema.updated_time_ = file_schema.created_on_;
file_schema.index_file_size_ = table_schema.index_file_size_;
if (file_schema.file_type_ == TableFileSchema::FILE_TYPE::NEW ||
file_schema.file_type_ == TableFileSchema::FILE_TYPE::NEW_MERGE) {
file_schema.engine_type_ = server::ValidationUtil::IsBinaryMetricType(table_schema.metric_type_)
? (int32_t)EngineType::FAISS_BIN_IDMAP
: (int32_t)EngineType::FAISS_IDMAP;
} else {
file_schema.engine_type_ = table_schema.engine_type_;
}
file_schema.nlist_ = table_schema.nlist_;
file_schema.index_params_ = table_schema.index_params_;
file_schema.engine_type_ = table_schema.engine_type_;
file_schema.metric_type_ = table_schema.metric_type_;
// multi-threads call sqlite update may get exception('bad logic', etc), so we add a lock here
@ -436,7 +428,7 @@ SqliteMetaImpl::GetTableFiles(const std::string& table_id, const std::vector<siz
file_schema.created_on_ = std::get<8>(file);
file_schema.dimension_ = table_schema.dimension_;
file_schema.index_file_size_ = table_schema.index_file_size_;
file_schema.nlist_ = table_schema.nlist_;
file_schema.index_params_ = table_schema.index_params_;
file_schema.metric_type_ = table_schema.metric_type_;
utils::GetTableFilePath(options_, file_schema);
@ -486,7 +478,7 @@ SqliteMetaImpl::GetTableFilesBySegmentId(const std::string& segment_id,
file_schema.created_on_ = std::get<9>(file);
file_schema.dimension_ = table_schema.dimension_;
file_schema.index_file_size_ = table_schema.index_file_size_;
file_schema.nlist_ = table_schema.nlist_;
file_schema.index_params_ = table_schema.index_params_;
file_schema.metric_type_ = table_schema.metric_type_;
utils::GetTableFilePath(options_, file_schema);
@ -601,7 +593,7 @@ SqliteMetaImpl::GetTableFilesByFlushLSN(uint64_t flush_lsn, TableFilesSchema& ta
}
table_file.dimension_ = groups[table_file.table_id_].dimension_;
table_file.index_file_size_ = groups[table_file.table_id_].index_file_size_;
table_file.nlist_ = groups[table_file.table_id_].nlist_;
table_file.index_params_ = groups[table_file.table_id_].index_params_;
table_file.metric_type_ = groups[table_file.table_id_].metric_type_;
table_files.push_back(table_file);
}
@ -721,7 +713,7 @@ SqliteMetaImpl::UpdateTableIndex(const std::string& table_id, const TableIndex&
table_schema.partition_tag_ = std::get<7>(tables[0]);
table_schema.version_ = std::get<8>(tables[0]);
table_schema.engine_type_ = index.engine_type_;
table_schema.nlist_ = index.nlist_;
table_schema.index_params_ = index.extra_params_.dump();
table_schema.metric_type_ = index.metric_type_;
ConnectorPtr->update(table_schema);
@ -773,12 +765,12 @@ SqliteMetaImpl::DescribeTableIndex(const std::string& table_id, TableIndex& inde
fiu_do_on("SqliteMetaImpl.DescribeTableIndex.throw_exception", throw std::exception());
auto groups = ConnectorPtr->select(
columns(&TableSchema::engine_type_, &TableSchema::nlist_, &TableSchema::metric_type_),
columns(&TableSchema::engine_type_, &TableSchema::index_params_, &TableSchema::metric_type_),
where(c(&TableSchema::table_id_) == table_id and c(&TableSchema::state_) != (int)TableSchema::TO_DELETE));
if (groups.size() == 1) {
index.engine_type_ = std::get<0>(groups[0]);
index.nlist_ = std::get<1>(groups[0]);
index.extra_params_ = milvus::json::parse(std::get<1>(groups[0]));
index.metric_type_ = std::get<2>(groups[0]);
} else {
return Status(DB_NOT_FOUND, "Table " + table_id + " not found");
@ -813,7 +805,7 @@ SqliteMetaImpl::DropTableIndex(const std::string& table_id) {
// set table index type to raw
ConnectorPtr->update_all(
set(c(&TableSchema::engine_type_) = DEFAULT_ENGINE_TYPE, c(&TableSchema::nlist_) = DEFAULT_NLIST),
set(c(&TableSchema::engine_type_) = DEFAULT_ENGINE_TYPE, c(&TableSchema::index_params_) = "{}"),
where(c(&TableSchema::table_id_) == table_id));
ENGINE_LOG_DEBUG << "Successfully drop table index, table id = " << table_id;
@ -886,13 +878,14 @@ SqliteMetaImpl::ShowPartitions(const std::string& table_id, std::vector<meta::Ta
server::MetricCollector metric;
fiu_do_on("SqliteMetaImpl.ShowPartitions.throw_exception", throw std::exception());
auto partitions =
ConnectorPtr->select(columns(&TableSchema::id_, &TableSchema::state_, &TableSchema::dimension_,
&TableSchema::created_on_, &TableSchema::flag_, &TableSchema::index_file_size_,
&TableSchema::engine_type_, &TableSchema::nlist_, &TableSchema::metric_type_,
&TableSchema::partition_tag_, &TableSchema::version_, &TableSchema::table_id_),
where(c(&TableSchema::owner_table_) == table_id and
c(&TableSchema::state_) != (int)TableSchema::TO_DELETE));
auto partitions = ConnectorPtr->select(
columns(&TableSchema::id_, &TableSchema::state_, &TableSchema::dimension_, &TableSchema::created_on_,
&TableSchema::flag_, &TableSchema::index_file_size_, &TableSchema::engine_type_,
&TableSchema::index_params_, &TableSchema::metric_type_, &TableSchema::partition_tag_,
&TableSchema::version_, &TableSchema::table_id_),
where(c(&TableSchema::owner_table_) == table_id and
c(&TableSchema::state_) != (int)TableSchema::TO_DELETE));
for (size_t i = 0; i < partitions.size(); i++) {
meta::TableSchema partition_schema;
partition_schema.id_ = std::get<0>(partitions[i]);
@ -902,7 +895,7 @@ SqliteMetaImpl::ShowPartitions(const std::string& table_id, std::vector<meta::Ta
partition_schema.flag_ = std::get<4>(partitions[i]);
partition_schema.index_file_size_ = std::get<5>(partitions[i]);
partition_schema.engine_type_ = std::get<6>(partitions[i]);
partition_schema.nlist_ = std::get<7>(partitions[i]);
partition_schema.index_params_ = std::get<7>(partitions[i]);
partition_schema.metric_type_ = std::get<8>(partitions[i]);
partition_schema.owner_table_ = table_id;
partition_schema.partition_tag_ = std::get<9>(partitions[i]);
@ -995,7 +988,7 @@ SqliteMetaImpl::FilesToSearch(const std::string& table_id, const std::vector<siz
table_file.engine_type_ = std::get<8>(file);
table_file.dimension_ = table_schema.dimension_;
table_file.index_file_size_ = table_schema.index_file_size_;
table_file.nlist_ = table_schema.nlist_;
table_file.index_params_ = table_schema.index_params_;
table_file.metric_type_ = table_schema.metric_type_;
auto status = utils::GetTableFilePath(options_, table_file);
@ -1063,7 +1056,7 @@ SqliteMetaImpl::FilesToMerge(const std::string& table_id, TableFilesSchema& file
table_file.created_on_ = std::get<8>(file);
table_file.dimension_ = table_schema.dimension_;
table_file.index_file_size_ = table_schema.index_file_size_;
table_file.nlist_ = table_schema.nlist_;
table_file.index_params_ = table_schema.index_params_;
table_file.metric_type_ = table_schema.metric_type_;
auto status = utils::GetTableFilePath(options_, table_file);
@ -1134,7 +1127,7 @@ SqliteMetaImpl::FilesToIndex(TableFilesSchema& files) {
}
table_file.dimension_ = groups[table_file.table_id_].dimension_;
table_file.index_file_size_ = groups[table_file.table_id_].index_file_size_;
table_file.nlist_ = groups[table_file.table_id_].nlist_;
table_file.index_params_ = groups[table_file.table_id_].index_params_;
table_file.metric_type_ = groups[table_file.table_id_].metric_type_;
files.push_back(table_file);
}
@ -1192,7 +1185,7 @@ SqliteMetaImpl::FilesByType(const std::string& table_id, const std::vector<int>&
file_schema.dimension_ = table_schema.dimension_;
file_schema.index_file_size_ = table_schema.index_file_size_;
file_schema.nlist_ = table_schema.nlist_;
file_schema.index_params_ = table_schema.index_params_;
file_schema.metric_type_ = table_schema.metric_type_;
switch (file_schema.file_type_) {
@ -1423,8 +1416,7 @@ SqliteMetaImpl::CleanUpFilesWithTTL(uint64_t seconds /*, CleanUpFilter* filter*/
// If we are deleting a raw table file, it means it's okay to delete the entire segment directory.
// Else, we can only delete the single file
// TODO(zhiru): We determine whether a table file is raw by its engine type. This is a bit hacky
if (table_file.engine_type_ == (int32_t)EngineType::FAISS_IDMAP ||
table_file.engine_type_ == (int32_t)EngineType::FAISS_BIN_IDMAP) {
if (utils::IsRawIndexType(table_file.engine_type_)) {
utils::DeleteSegment(options_, table_file);
std::string segment_dir;
utils::GetParentPath(table_file.location_, segment_dir);

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -4,6 +4,14 @@ import "status.proto";
package milvus.grpc;
/**
* @brief general usage
*/
message KeyValuePair {
string key = 1;
string value = 2;
}
/**
* @brief Table name
*/
@ -21,6 +29,7 @@ message TableNameList {
/**
* @brief Table schema
* metric_type: 1-L2, 2-IP
*/
message TableSchema {
Status status = 1;
@ -28,6 +37,7 @@ message TableSchema {
int64 dimension = 3;
int64 index_file_size = 4;
int32 metric_type = 5;
repeated KeyValuePair extra_params = 6;
}
/**
@ -62,6 +72,7 @@ message InsertParam {
repeated RowRecord row_record_array = 2;
repeated int64 row_id_array = 3; //optional
string partition_tag = 4;
repeated KeyValuePair extra_params = 5;
}
/**
@ -77,10 +88,10 @@ message VectorIds {
*/
message SearchParam {
string table_name = 1;
repeated RowRecord query_record_array = 2;
int64 topk = 3;
int64 nprobe = 4;
repeated string partition_tag_array = 5;
repeated string partition_tag_array = 2;
repeated RowRecord query_record_array = 3;
int64 topk = 4;
repeated KeyValuePair extra_params = 5;
}
/**
@ -96,10 +107,10 @@ message SearchInFilesParam {
*/
message SearchByIDParam {
string table_name = 1;
int64 id = 2;
int64 topk = 3;
int64 nprobe = 4;
repeated string partition_tag_array = 5;
repeated string partition_tag_array = 2;
int64 id = 3;
int64 topk = 4;
repeated KeyValuePair extra_params = 5;
}
/**
@ -143,23 +154,15 @@ message Command {
string cmd = 1;
}
/**
* @brief Index
* @index_type: 0-invalid, 1-idmap, 2-ivflat, 3-ivfsq8, 4-nsgmix
* @metric_type: 1-L2, 2-IP
*/
message Index {
int32 index_type = 1;
int32 nlist = 2;
}
/**
* @brief Index params
* @index_type: 0-invalid, 1-idmap, 2-ivflat, 3-ivfsq8, 4-nsgmix
*/
message IndexParam {
Status status = 1;
string table_name = 2;
Index index = 3;
int32 index_type = 3;
repeated KeyValuePair extra_params = 4;
}
/**

View File

@ -22,7 +22,6 @@ endif ()
set(external_srcs
knowhere/adapter/SptagAdapter.cpp
knowhere/adapter/VectorAdapter.cpp
knowhere/common/Exception.cpp
knowhere/common/Timer.cpp
)
@ -117,4 +116,4 @@ set(INDEX_INCLUDE_DIRS
${LAPACK_INCLUDE_DIR}
)
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)

View File

@ -43,7 +43,8 @@ std::vector<SPTAG::QueryResult>
ConvertToQueryResult(const DatasetPtr& dataset, const Config& config) {
GETTENSOR(dataset);
std::vector<SPTAG::QueryResult> query_results(rows, SPTAG::QueryResult(nullptr, config->k, true));
std::vector<SPTAG::QueryResult> query_results(rows,
SPTAG::QueryResult(nullptr, config[meta::TOPK].get<int64_t>(), true));
for (auto i = 0; i < rows; ++i) {
query_results[i].SetTarget(&p_data[i * dim]);
}

View File

@ -1,24 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/adapter/VectorAdapter.h"
namespace knowhere {
namespace meta {
const char* DIM = "dim";
const char* TENSOR = "tensor";
const char* ROWS = "rows";
const char* IDS = "ids";
const char* DISTANCE = "distance";
}; // namespace meta
} // namespace knowhere

View File

@ -13,17 +13,10 @@
#include <string>
#include "knowhere/common/Dataset.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
namespace knowhere {
namespace meta {
extern const char* DIM;
extern const char* TENSOR;
extern const char* ROWS;
extern const char* IDS;
extern const char* DISTANCE;
}; // namespace meta
#define GETTENSOR(dataset) \
auto dim = dataset->Get<int64_t>(meta::DIM); \
auto rows = dataset->Get<int64_t>(meta::ROWS); \

View File

@ -11,64 +11,10 @@
#pragma once
#include <memory>
#include <sstream>
#include "Log.h"
#include "knowhere/common/Exception.h"
#include "src/utils/Json.h"
namespace knowhere {
enum class METRICTYPE {
INVALID = 0,
L2 = 1,
IP = 2,
HAMMING = 20,
JACCARD = 21,
TANIMOTO = 22,
};
// General Config
constexpr int64_t INVALID_VALUE = -1;
constexpr int64_t DEFAULT_K = INVALID_VALUE;
constexpr int64_t DEFAULT_DIM = INVALID_VALUE;
constexpr int64_t DEFAULT_GPUID = INVALID_VALUE;
constexpr METRICTYPE DEFAULT_TYPE = METRICTYPE::INVALID;
struct Cfg {
METRICTYPE metric_type = DEFAULT_TYPE;
int64_t k = DEFAULT_K;
int64_t gpu_id = DEFAULT_GPUID;
int64_t d = DEFAULT_DIM;
Cfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, METRICTYPE type)
: metric_type(type), k(k), gpu_id(gpu_id), d(dim) {
}
Cfg() = default;
virtual bool
CheckValid() {
if (metric_type == METRICTYPE::IP || metric_type == METRICTYPE::L2) {
return true;
}
std::stringstream ss;
ss << "MetricType: " << int(metric_type) << " not support!";
KNOWHERE_THROW_MSG(ss.str());
return false;
}
void
Dump() {
KNOWHERE_LOG_DEBUG << DumpImpl().str();
}
virtual std::stringstream
DumpImpl() {
std::stringstream ss;
ss << "dim: " << d << ", metric: " << int(metric_type) << ", gpuid: " << gpu_id << ", k: " << k;
return ss;
}
};
using Config = std::shared_ptr<Cfg>;
using Config = milvus::json;
} // namespace knowhere

View File

@ -16,6 +16,7 @@
#include "IndexModel.h"
#include "IndexType.h"
#include "knowhere/common/BinarySet.h"
#include "knowhere/common/Config.h"
#include "knowhere/common/Dataset.h"
#include "knowhere/index/preprocessor/Preprocessor.h"

View File

@ -14,6 +14,7 @@
#include <utility>
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/FaissBaseIndex.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"

View File

@ -15,6 +15,8 @@
#include <faiss/MetaIndexes.h>
#include <faiss/index_factory.h>
#include <string>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
@ -43,13 +45,13 @@ BinaryIDMAP::Search(const DatasetPtr& dataset, const Config& config) {
}
GETBINARYTENSOR(dataset)
auto elems = rows * config->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
search_impl(rows, (uint8_t*)p_data, config->k, p_dist, p_id, Config());
search_impl(rows, (uint8_t*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, Config());
auto ret_ds = std::make_shared<Dataset>();
if (index_->metric_type == faiss::METRIC_Hamming) {
@ -90,14 +92,9 @@ BinaryIDMAP::Add(const DatasetPtr& dataset, const Config& config) {
void
BinaryIDMAP::Train(const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<BinIDMAPCfg>(config);
if (build_cfg == nullptr) {
KNOWHERE_THROW_MSG("not support this kind of config");
}
config->CheckValid();
const char* type = "BFlat";
auto index = faiss::index_binary_factory(config->d, type, GetMetricType(config->metric_type));
auto index = faiss::index_binary_factory(config[meta::DIM].get<int64_t>(), type,
GetMetricType(config[Metric::TYPE].get<std::string>()));
index_.reset(index);
}
@ -181,26 +178,18 @@ BinaryIDMAP::SearchById(const DatasetPtr& dataset, const Config& config) {
KNOWHERE_THROW_MSG("index not initialize");
}
// auto search_cfg = std::dynamic_pointer_cast<BinIDMAPCfg>(config);
// if (search_cfg == nullptr) {
// KNOWHERE_THROW_MSG("not support this kind of config");
// }
// GETBINARYTENSOR(dataset)
auto dim = dataset->Get<int64_t>(meta::DIM);
auto rows = dataset->Get<int64_t>(meta::ROWS);
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
auto elems = rows * config->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
auto* pdistances = (int32_t*)p_dist;
// index_->searchById(rows, (uint8_t*)p_data, config->k, pdistances, p_id, bitset_);
// auto blacklist = dataset->Get<faiss::ConcurrentBitsetPtr>("bitset");
index_->search_by_id(rows, p_data, config->k, pdistances, p_id, bitset_);
index_->search_by_id(rows, p_data, config[meta::TOPK].get<int64_t>(), pdistances, p_id, bitset_);
auto ret_ds = std::make_shared<Dataset>();
if (index_->metric_type == faiss::METRIC_Hamming) {

View File

@ -15,9 +15,11 @@
#include <faiss/IndexBinaryIVF.h>
#include <chrono>
#include <string>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
namespace knowhere {
@ -45,22 +47,17 @@ BinaryIVF::Search(const DatasetPtr& dataset, const Config& config) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
// auto search_cfg = std::dynamic_pointer_cast<IVFBinCfg>(config);
// if (search_cfg == nullptr) {
// KNOWHERE_THROW_MSG("not support this kind of config");
// }
GETBINARYTENSOR(dataset)
try {
auto elems = rows * config->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
search_impl(rows, (uint8_t*)p_data, config->k, p_dist, p_id, config);
search_impl(rows, (uint8_t*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, config);
auto ret_ds = std::make_shared<Dataset>();
if (index_->metric_type == faiss::METRIC_Hamming) {
@ -108,29 +105,20 @@ BinaryIVF::search_impl(int64_t n, const uint8_t* data, int64_t k, float* distanc
std::shared_ptr<faiss::IVFSearchParameters>
BinaryIVF::GenParams(const Config& config) {
auto params = std::make_shared<faiss::IVFSearchParameters>();
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
params->nprobe = search_cfg->nprobe;
// params->max_codes = config.get_with_default("max_codes", size_t(0));
params->nprobe = config[IndexParams::nprobe];
// params->max_codes = config["max_code"];
return params;
}
IndexModelPtr
BinaryIVF::Train(const DatasetPtr& dataset, const Config& config) {
std::lock_guard<std::mutex> lk(mutex_);
auto build_cfg = std::dynamic_pointer_cast<IVFBinCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
GETBINARYTENSOR(dataset)
auto p_ids = dataset->Get<const int64_t*>(meta::IDS);
faiss::IndexBinary* coarse_quantizer = new faiss::IndexBinaryFlat(dim, GetMetricType(build_cfg->metric_type));
auto index = std::make_shared<faiss::IndexBinaryIVF>(coarse_quantizer, dim, build_cfg->nlist,
GetMetricType(build_cfg->metric_type));
faiss::IndexBinary* coarse_quantizer =
new faiss::IndexBinaryFlat(dim, GetMetricType(config[Metric::TYPE].get<std::string>()));
auto index = std::make_shared<faiss::IndexBinaryIVF>(coarse_quantizer, dim, config[IndexParams::nlist],
GetMetricType(config[Metric::TYPE].get<std::string>()));
index->train(rows, (uint8_t*)p_data);
index->add_with_ids(rows, (uint8_t*)p_data, p_ids);
index_ = index;
@ -190,17 +178,11 @@ BinaryIVF::SearchById(const DatasetPtr& dataset, const Config& config) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
// auto search_cfg = std::dynamic_pointer_cast<IVFBinCfg>(config);
// if (search_cfg == nullptr) {
// KNOWHERE_THROW_MSG("not support this kind of config");
// }
// GETBINARYTENSOR(dataset)
auto rows = dataset->Get<int64_t>(meta::ROWS);
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
try {
auto elems = rows * config->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
@ -208,9 +190,7 @@ BinaryIVF::SearchById(const DatasetPtr& dataset, const Config& config) {
auto p_dist = (float*)malloc(p_dist_size);
int32_t* pdistances = (int32_t*)p_dist;
// auto blacklist = dataset->Get<faiss::ConcurrentBitsetPtr>("bitset");
// index_->searchById(rows, (uint8_t*)p_data, config->k, pdistances, p_id, blacklist);
index_->search_by_id(rows, p_data, config->k, pdistances, p_id, bitset_);
index_->search_by_id(rows, p_data, config[meta::TOPK].get<int64_t>(), pdistances, p_id, bitset_);
auto ret_ds = std::make_shared<Dataset>();
if (index_->metric_type == faiss::METRIC_Hamming) {

View File

@ -16,6 +16,7 @@
#include <faiss/MetaIndexes.h>
#include <faiss/index_io.h>
#include <fiu-local.h>
#include <string>
#ifdef MILVUS_GPU_VERSION
@ -127,7 +128,7 @@ GPUIDMAP::GenGraph(const float* data, const int64_t& k, Graph& graph, const Conf
int64_t K = k + 1;
auto ntotal = Count();
size_t dim = config->d;
size_t dim = config[meta::DIM];
auto batch_size = 1000;
auto tail_batch_size = ntotal % batch_size;
auto batch_search_count = ntotal / batch_size;

View File

@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <memory>
#include <string>
#include <faiss/gpu/GpuCloner.h>
#include <faiss/gpu/GpuIndexIVF.h>
@ -28,21 +29,16 @@ namespace knowhere {
IndexModelPtr
GPUIVF::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
gpu_id_ = build_cfg->gpu_id;
GETTENSOR(dataset)
gpu_id_ = config[knowhere::meta::DEVICEID];
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) {
ResScope rs(temp_resource, gpu_id_, true);
faiss::gpu::GpuIndexIVFFlatConfig idx_config;
idx_config.device = gpu_id_;
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, build_cfg->nlist,
GetMetricType(build_cfg->metric_type), idx_config);
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, config[IndexParams::nlist],
GetMetricType(config[Metric::TYPE].get<std::string>()), idx_config);
device_index.train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
@ -121,15 +117,13 @@ GPUIVF::LoadImpl(const BinarySet& index_binary) {
}
void
GPUIVF::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) {
GPUIVF::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
std::lock_guard<std::mutex> lk(mutex_);
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
fiu_do_on("GPUIVF.search_impl.invald_index", device_index = nullptr);
if (device_index) {
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(cfg);
device_index->nprobe = search_cfg->nprobe;
// assert(device_index->getNumProbes() == search_cfg->nprobe);
device_index->nprobe = config[IndexParams::nprobe];
ResScope rs(res_, gpu_id_);
device_index->search(n, (float*)data, k, distances, labels);
} else {

View File

@ -15,6 +15,7 @@
#include <faiss/index_factory.h>
#include <memory>
#include <string>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
@ -25,20 +26,16 @@ namespace knowhere {
IndexModelPtr
GPUIVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
gpu_id_ = build_cfg->gpu_id;
GETTENSOR(dataset)
gpu_id_ = config[knowhere::meta::DEVICEID];
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) {
ResScope rs(temp_resource, gpu_id_, true);
auto device_index = new faiss::gpu::GpuIndexIVFPQ(temp_resource->faiss_res.get(), dim, build_cfg->nlist,
build_cfg->m, build_cfg->nbits,
GetMetricType(build_cfg->metric_type)); // IP not support
auto device_index = new faiss::gpu::GpuIndexIVFPQ(
temp_resource->faiss_res.get(), dim, config[IndexParams::nlist].get<int64_t>(), config[IndexParams::m],
config[IndexParams::nbits],
GetMetricType(config[Metric::TYPE].get<std::string>())); // IP not support
device_index->train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
@ -51,11 +48,10 @@ GPUIVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
std::shared_ptr<faiss::IVFSearchParameters>
GPUIVFPQ::GenParams(const Config& config) {
auto params = std::make_shared<faiss::IVFPQSearchParameters>();
auto search_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
params->nprobe = search_cfg->nprobe;
// params->scan_table_threshold = conf->scan_table_threhold;
// params->polysemous_ht = conf->polysemous_ht;
// params->max_codes = conf->max_codes;
params->nprobe = config[IndexParams::nprobe];
// params->scan_table_threshold = config["scan_table_threhold"]
// params->polysemous_ht = config["polysemous_ht"]
// params->max_codes = config["max_codes"]
return params;
}

View File

@ -13,6 +13,7 @@
#include <faiss/index_factory.h>
#include <memory>
#include <string>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
@ -23,18 +24,14 @@ namespace knowhere {
IndexModelPtr
GPUIVFSQ::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
gpu_id_ = build_cfg->gpu_id;
GETTENSOR(dataset)
gpu_id_ = config[knowhere::meta::DEVICEID];
std::stringstream index_type;
index_type << "IVF" << build_cfg->nlist << ","
<< "SQ" << build_cfg->nbits;
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type));
index_type << "IVF" << config[IndexParams::nlist] << ","
<< "SQ" << config[IndexParams::nbits];
auto build_index =
faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get<std::string>()));
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) {

View File

@ -79,20 +79,15 @@ IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
auto search_cfg = std::dynamic_pointer_cast<HNSWCfg>(config);
if (search_cfg == nullptr) {
KNOWHERE_THROW_MSG("search conf is null");
}
index_->setEf(search_cfg->ef);
GETTENSOR(dataset)
size_t id_size = sizeof(int64_t) * config->k;
size_t dist_size = sizeof(float) * config->k;
size_t id_size = sizeof(int64_t) * config[meta::TOPK].get<int64_t>();
size_t dist_size = sizeof(float) * config[meta::TOPK].get<int64_t>();
auto p_id = (int64_t*)malloc(id_size * rows);
auto p_dist = (float*)malloc(dist_size * rows);
index_->setEf(config[IndexParams::ef]);
using P = std::pair<float, int64_t>;
auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; };
#pragma omp parallel for
@ -103,13 +98,13 @@ IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) {
// if (normalize) {
// std::vector<float> norm_vector(Dimension());
// normalize_vector((float*)(single_query), norm_vector.data(), Dimension());
// ret = index_->searchKnn((float*)(norm_vector.data()), config->k, compare);
// ret = index_->searchKnn((float*)(norm_vector.data()), config[meta::TOPK].get<int64_t>(), compare);
// } else {
// ret = index_->searchKnn((float*)single_query, config->k, compare);
// ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get<int64_t>(), compare);
// }
ret = index_->searchKnn((float*)single_query, config->k, compare);
ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get<int64_t>(), compare);
while (ret.size() < config->k) {
while (ret.size() < config[meta::TOPK]) {
ret.push_back(std::make_pair(-1, -1));
}
std::vector<float> dist;
@ -125,8 +120,8 @@ IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) {
std::transform(ret.begin(), ret.end(), std::back_inserter(ids),
[](const std::pair<float, int64_t>& e) { return e.second; });
memcpy(p_dist + i * config->k, dist.data(), dist_size);
memcpy(p_id + i * config->k, ids.data(), id_size);
memcpy(p_dist + i * config[meta::TOPK].get<int64_t>(), dist.data(), dist_size);
memcpy(p_id + i * config[meta::TOPK].get<int64_t>(), ids.data(), id_size);
}
auto ret_ds = std::make_shared<Dataset>();
@ -137,21 +132,17 @@ IndexHNSW::Search(const DatasetPtr& dataset, const Config& config) {
IndexModelPtr
IndexHNSW::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<HNSWCfg>(config);
if (build_cfg == nullptr) {
KNOWHERE_THROW_MSG("build conf is null");
}
GETTENSOR(dataset)
hnswlib::SpaceInterface<float>* space;
if (config->metric_type == METRICTYPE::L2) {
if (config[Metric::TYPE] == Metric::L2) {
space = new hnswlib::L2Space(dim);
} else if (config->metric_type == METRICTYPE::IP) {
} else if (config[Metric::TYPE] == Metric::IP) {
space = new hnswlib::InnerProductSpace(dim);
normalize = true;
}
index_ = std::make_shared<hnswlib::HierarchicalNSW<float>>(space, rows, build_cfg->M, build_cfg->ef);
index_ = std::make_shared<hnswlib::HierarchicalNSW<float>>(space, rows, config[IndexParams::M].get<int64_t>(),
config[IndexParams::efConstruction].get<int64_t>());
return nullptr;
}

View File

@ -22,6 +22,7 @@
#endif
#include <string>
#include <vector>
#include "knowhere/adapter/VectorAdapter.h"
@ -61,13 +62,13 @@ IDMAP::Search(const DatasetPtr& dataset, const Config& config) {
}
GETTENSOR(dataset)
auto elems = rows * config->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
search_impl(rows, (float*)p_data, config->k, p_dist, p_id, Config());
search_impl(rows, (float*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, Config());
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
@ -144,10 +145,9 @@ IDMAP::GetRawIds() {
void
IDMAP::Train(const Config& config) {
config->CheckValid();
const char* type = "IDMap,Flat";
auto index = faiss::index_factory(config->d, type, GetMetricType(config->metric_type));
auto index = faiss::index_factory(config[meta::DIM].get<int64_t>(), type,
GetMetricType(config[Metric::TYPE].get<std::string>()));
index_.reset(index);
}
@ -214,7 +214,7 @@ IDMAP::SearchById(const DatasetPtr& dataset, const Config& config) {
auto rows = dataset->Get<int64_t>(meta::ROWS);
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
auto elems = rows * config->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
@ -222,8 +222,8 @@ IDMAP::SearchById(const DatasetPtr& dataset, const Config& config) {
// todo: enable search by id (zhiru)
// auto blacklist = dataset->Get<faiss::ConcurrentBitsetPtr>("bitset");
// index_->searchById(rows, (float*)p_data, config->k, p_dist, p_id, blacklist);
index_->search_by_id(rows, p_data, config->k, p_dist, p_id, bitset_);
// index_->searchById(rows, (float*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, blacklist);
index_->search_by_id(rows, p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, bitset_);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);

View File

@ -26,6 +26,7 @@
#include <fiu-local.h>
#include <chrono>
#include <memory>
#include <string>
#include <utility>
#include <vector>
@ -43,16 +44,11 @@ using stdclock = std::chrono::high_resolution_clock;
IndexModelPtr
IVF::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
GETTENSOR(dataset)
faiss::Index* coarse_quantizer = new faiss::IndexFlatL2(dim);
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, build_cfg->nlist,
GetMetricType(build_cfg->metric_type));
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(),
GetMetricType(config[Metric::TYPE].get<std::string>()));
index->train(rows, (float*)p_data);
// TODO(linxj): override here. train return model or not.
@ -106,24 +102,19 @@ IVF::Search(const DatasetPtr& dataset, const Config& config) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
if (search_cfg == nullptr) {
KNOWHERE_THROW_MSG("not support this kind of config");
}
GETTENSOR(dataset)
try {
fiu_do_on("IVF.Search.throw_std_exception", throw std::exception());
fiu_do_on("IVF.Search.throw_faiss_exception", throw faiss::FaissException(""));
auto elems = rows * search_cfg->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
search_impl(rows, (float*)p_data, search_cfg->k, p_dist, p_id, config);
search_impl(rows, (float*)p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, config);
// std::stringstream ss_res_id, ss_res_dist;
// for (int i = 0; i < 10; ++i) {
@ -163,9 +154,8 @@ std::shared_ptr<faiss::IVFSearchParameters>
IVF::GenParams(const Config& config) {
auto params = std::make_shared<faiss::IVFSearchParameters>();
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
params->nprobe = search_cfg->nprobe;
// params->max_codes = config.get_with_default("max_codes", size_t(0));
params->nprobe = config[IndexParams::nprobe];
// params->max_codes = config["max_codes"];
return params;
}
@ -185,7 +175,7 @@ IVF::GenGraph(const float* data, const int64_t& k, Graph& graph, const Config& c
int64_t K = k + 1;
auto ntotal = Count();
size_t dim = config->d;
size_t dim = config[meta::DIM];
auto batch_size = 1000;
auto tail_batch_size = ntotal % batch_size;
auto batch_search_count = ntotal / batch_size;
@ -279,12 +269,6 @@ IVF::GetVectorById(const DatasetPtr& dataset, const Config& config) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
if (search_cfg == nullptr) {
KNOWHERE_THROW_MSG("not support this kind of config");
}
// auto rows = dataset->Get<int64_t>(meta::ROWS);
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
auto elems = dataset->Get<int64_t>(meta::DIM);
@ -311,16 +295,11 @@ IVF::SearchById(const DatasetPtr& dataset, const Config& config) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
if (search_cfg == nullptr) {
KNOWHERE_THROW_MSG("not support this kind of config");
}
auto rows = dataset->Get<int64_t>(meta::ROWS);
auto p_data = dataset->Get<const int64_t*>(meta::IDS);
try {
auto elems = rows * search_cfg->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
@ -330,7 +309,7 @@ IVF::SearchById(const DatasetPtr& dataset, const Config& config) {
// todo: enable search by id (zhiru)
// auto blacklist = dataset->Get<faiss::ConcurrentBitsetPtr>("bitset");
auto index_ivf = std::static_pointer_cast<faiss::IndexIVF>(index_);
index_ivf->search_by_id(rows, p_data, search_cfg->k, p_dist, p_id, bitset_);
index_ivf->search_by_id(rows, p_data, config[meta::TOPK].get<int64_t>(), p_dist, p_id, bitset_);
// std::stringstream ss_res_id, ss_res_dist;
// for (int i = 0; i < 10; ++i) {

View File

@ -16,6 +16,7 @@
#endif
#include <memory>
#include <string>
#include <utility>
#include "knowhere/adapter/VectorAdapter.h"
@ -30,16 +31,12 @@ namespace knowhere {
IndexModelPtr
IVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
GETTENSOR(dataset)
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, GetMetricType(build_cfg->metric_type));
auto index =
std::make_shared<faiss::IndexIVFPQ>(coarse_quantizer, dim, build_cfg->nlist, build_cfg->m, build_cfg->nbits);
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, GetMetricType(config[Metric::TYPE].get<std::string>()));
auto index = std::make_shared<faiss::IndexIVFPQ>(coarse_quantizer, dim, config[IndexParams::nlist].get<int64_t>(),
config[IndexParams::m].get<int64_t>(),
config[IndexParams::nbits].get<int64_t>());
index->train(rows, (float*)p_data);
return std::make_shared<IVFIndexModel>(index);
@ -48,11 +45,10 @@ IVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
std::shared_ptr<faiss::IVFSearchParameters>
IVFPQ::GenParams(const Config& config) {
auto params = std::make_shared<faiss::IVFPQSearchParameters>();
auto search_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
params->nprobe = search_cfg->nprobe;
// params->scan_table_threshold = conf->scan_table_threhold;
// params->polysemous_ht = conf->polysemous_ht;
// params->max_codes = conf->max_codes;
params->nprobe = config[IndexParams::nprobe];
// params->scan_table_threshold = config["scan_table_threhold"]
// params->polysemous_ht = config["polysemous_ht"]
// params->max_codes = config["max_codes"]
return params;
}

View File

@ -16,6 +16,7 @@
#include <faiss/index_factory.h>
#include <memory>
#include <string>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
@ -30,17 +31,13 @@ namespace knowhere {
IndexModelPtr
IVFSQ::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
GETTENSOR(dataset)
std::stringstream index_type;
index_type << "IVF" << build_cfg->nlist << ","
<< "SQ" << build_cfg->nbits;
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type));
index_type << "IVF" << config[IndexParams::nlist] << ","
<< "SQ" << config[IndexParams::nbits];
auto build_index =
faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get<std::string>()));
build_index->train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> ret_index;

View File

@ -19,6 +19,7 @@
#include <faiss/gpu/GpuIndexIVF.h>
#include <faiss/index_factory.h>
#include <fiu-local.h>
#include <string>
#include <utility>
namespace knowhere {
@ -30,19 +31,14 @@ namespace knowhere {
IndexModelPtr
IVFSQHybrid::Train(const DatasetPtr& dataset, const Config& config) {
// std::lock_guard<std::mutex> lk(g_mutex);
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
gpu_id_ = build_cfg->gpu_id;
GETTENSOR(dataset)
gpu_id_ = config[knowhere::meta::DEVICEID];
std::stringstream index_type;
index_type << "IVF" << build_cfg->nlist << ","
index_type << "IVF" << config[IndexParams::nlist] << ","
<< "SQ8Hybrid";
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type));
auto build_index =
faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get<std::string>()));
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) {
@ -133,17 +129,10 @@ IVFSQHybrid::search_impl(int64_t n, const float* data, int64_t k, float* distanc
}
QuantizerPtr
IVFSQHybrid::LoadQuantizer(const Config& conf) {
IVFSQHybrid::LoadQuantizer(const Config& config) {
// std::lock_guard<std::mutex> lk(g_mutex);
auto quantizer_conf = std::dynamic_pointer_cast<QuantizerCfg>(conf);
if (quantizer_conf != nullptr) {
if (quantizer_conf->mode != 1) {
KNOWHERE_THROW_MSG("mode only support 1 in this func");
}
}
auto gpu_id = quantizer_conf->gpu_id;
auto gpu_id = config[knowhere::meta::DEVICEID].get<int64_t>();
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) {
ResScope rs(res, gpu_id, false);
faiss::gpu::GpuClonerOptions option;
@ -152,7 +141,7 @@ IVFSQHybrid::LoadQuantizer(const Config& conf) {
auto index_composition = new faiss::IndexComposition;
index_composition->index = index_.get();
index_composition->quantizer = nullptr;
index_composition->mode = quantizer_conf->mode; // only 1
index_composition->mode = 1; // only 1
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id, index_composition, &option);
delete gpu_index;
@ -205,19 +194,10 @@ IVFSQHybrid::UnsetQuantizer() {
}
VectorIndexPtr
IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& config) {
// std::lock_guard<std::mutex> lk(g_mutex);
auto quantizer_conf = std::dynamic_pointer_cast<QuantizerCfg>(conf);
if (quantizer_conf != nullptr) {
if (quantizer_conf->mode != 2) {
KNOWHERE_THROW_MSG("mode only support 2 in this func");
}
} else {
KNOWHERE_THROW_MSG("conf error");
}
auto gpu_id = quantizer_conf->gpu_id;
int64_t gpu_id = config[knowhere::meta::DEVICEID];
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) {
ResScope rs(res, gpu_id, false);
@ -231,7 +211,7 @@ IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
auto index_composition = new faiss::IndexComposition;
index_composition->index = index_.get();
index_composition->quantizer = ivf_quantizer->quantizer;
index_composition->mode = quantizer_conf->mode; // only 2
index_composition->mode = 2; // only 2
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id, index_composition, &option);
std::shared_ptr<faiss::Index> new_idx;

View File

@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "knowhere/index/vector_index/IndexNSG.h"
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Timer.h"
@ -23,6 +24,7 @@
#endif
#include <fiu-local.h>
#include "knowhere/index/vector_index/IndexIDMAP.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/nsg/NSG.h"
@ -72,23 +74,21 @@ NSG::Load(const BinarySet& index_binary) {
DatasetPtr
NSG::Search(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<NSGCfg>(config);
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
GETTENSOR(dataset)
auto elems = rows * build_cfg->k;
auto elems = rows * config[meta::TOPK].get<int64_t>();
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
algo::SearchParams s_params;
s_params.search_length = build_cfg->search_length;
index_->Search((float*)p_data, rows, dim, build_cfg->k, p_dist, p_id, s_params);
s_params.search_length = config[IndexParams::search_length];
index_->Search((float*)p_data, rows, dim, config[meta::TOPK].get<int64_t>(), p_dist, p_id, s_params);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
@ -98,41 +98,35 @@ NSG::Search(const DatasetPtr& dataset, const Config& config) {
IndexModelPtr
NSG::Train(const DatasetPtr& dataset, const Config& config) {
config->Dump();
auto build_cfg = std::dynamic_pointer_cast<NSGCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
auto idmap = std::make_shared<IDMAP>();
idmap->Train(config);
idmap->AddWithoutId(dataset, config);
Graph knng;
const float* raw_data = idmap->GetRawVectors();
#ifdef MILVUS_GPU_VERSION
if (build_cfg->gpu_id == knowhere::INVALID_VALUE) {
if (config[knowhere::meta::DEVICEID].get<int64_t>() == -1) {
auto preprocess_index = std::make_shared<IVF>();
auto model = preprocess_index->Train(dataset, config);
preprocess_index->set_index_model(model);
preprocess_index->Add(dataset, config);
preprocess_index->GenGraph(raw_data, build_cfg->knng, knng, config);
preprocess_index->AddWithoutIds(dataset, config);
preprocess_index->GenGraph(raw_data, config[IndexParams::knng].get<int64_t>(), knng, config);
} else {
auto gpu_idx = cloner::CopyCpuToGpu(idmap, build_cfg->gpu_id, config);
auto gpu_idx = cloner::CopyCpuToGpu(idmap, config[knowhere::meta::DEVICEID].get<int64_t>(), config);
auto gpu_idmap = std::dynamic_pointer_cast<GPUIDMAP>(gpu_idx);
gpu_idmap->GenGraph(raw_data, build_cfg->knng, knng, config);
gpu_idmap->GenGraph(raw_data, config[IndexParams::knng].get<int64_t>(), knng, config);
}
#else
auto preprocess_index = std::make_shared<IVF>();
auto model = preprocess_index->Train(dataset, config);
preprocess_index->set_index_model(model);
preprocess_index->AddWithoutIds(dataset, config);
preprocess_index->GenGraph(raw_data, build_cfg->knng, knng, config);
preprocess_index->GenGraph(raw_data, config[IndexParams::knng].get<int64_t>(), knng, config);
#endif
algo::BuildParams b_params;
b_params.candidate_pool_size = build_cfg->candidate_pool_size;
b_params.out_degree = build_cfg->out_degree;
b_params.search_length = build_cfg->search_length;
b_params.candidate_pool_size = config[IndexParams::candidate];
b_params.out_degree = config[IndexParams::out_degree];
b_params.search_length = config[IndexParams::search_length];
auto p_ids = dataset->Get<const int64_t*>(meta::IDS);

View File

@ -123,9 +123,6 @@ CPUSPTAGRNG::Load(const BinarySet& binary_set) {
IndexModelPtr
CPUSPTAGRNG::Train(const DatasetPtr& origin, const Config& train_config) {
SetParameters(train_config);
if (train_config != nullptr) {
train_config->CheckValid(); // throw exception
}
DatasetPtr dataset = origin; // TODO(linxj): copy or reference?
@ -159,62 +156,56 @@ CPUSPTAGRNG::Add(const DatasetPtr& origin, const Config& add_config) {
void
CPUSPTAGRNG::SetParameters(const Config& config) {
#define Assign(param_name, str_name) \
conf->param_name == INVALID_VALUE ? index_ptr_->SetParameter(str_name, std::to_string(build_cfg->param_name)) \
: index_ptr_->SetParameter(str_name, std::to_string(conf->param_name))
#define Assign(param_name, str_name) \
index_ptr_->SetParameter(str_name, std::to_string(build_cfg[param_name].get<int64_t>()))
if (index_type_ == SPTAG::IndexAlgoType::KDT) {
auto conf = std::dynamic_pointer_cast<KDTCfg>(config);
auto build_cfg = SPTAGParameterMgr::GetInstance().GetKDTParameters();
Assign(kdtnumber, "KDTNumber");
Assign(numtopdimensionkdtsplit, "NumTopDimensionKDTSplit");
Assign(samples, "Samples");
Assign(tptnumber, "TPTNumber");
Assign(tptleafsize, "TPTLeafSize");
Assign(numtopdimensiontptsplit, "NumTopDimensionTPTSplit");
Assign(neighborhoodsize, "NeighborhoodSize");
Assign(graphneighborhoodscale, "GraphNeighborhoodScale");
Assign(graphcefscale, "GraphCEFScale");
Assign(refineiterations, "RefineIterations");
Assign(cef, "CEF");
Assign(maxcheckforrefinegraph, "MaxCheckForRefineGraph");
Assign(numofthreads, "NumberOfThreads");
Assign(maxcheck, "MaxCheck");
Assign(thresholdofnumberofcontinuousnobetterpropagation, "ThresholdOfNumberOfContinuousNoBetterPropagation");
Assign(numberofinitialdynamicpivots, "NumberOfInitialDynamicPivots");
Assign(numberofotherdynamicpivots, "NumberOfOtherDynamicPivots");
Assign("kdtnumber", "KDTNumber");
Assign("numtopdimensionkdtsplit", "NumTopDimensionKDTSplit");
Assign("samples", "Samples");
Assign("tptnumber", "TPTNumber");
Assign("tptleafsize", "TPTLeafSize");
Assign("numtopdimensiontptsplit", "NumTopDimensionTPTSplit");
Assign("neighborhoodsize", "NeighborhoodSize");
Assign("graphneighborhoodscale", "GraphNeighborhoodScale");
Assign("graphcefscale", "GraphCEFScale");
Assign("refineiterations", "RefineIterations");
Assign("cef", "CEF");
Assign("maxcheckforrefinegraph", "MaxCheckForRefineGraph");
Assign("numofthreads", "NumberOfThreads");
Assign("maxcheck", "MaxCheck");
Assign("thresholdofnumberofcontinuousnobetterpropagation", "ThresholdOfNumberOfContinuousNoBetterPropagation");
Assign("numberofinitialdynamicpivots", "NumberOfInitialDynamicPivots");
Assign("numberofotherdynamicpivots", "NumberOfOtherDynamicPivots");
} else {
auto conf = std::dynamic_pointer_cast<BKTCfg>(config);
auto build_cfg = SPTAGParameterMgr::GetInstance().GetBKTParameters();
Assign(bktnumber, "BKTNumber");
Assign(bktkmeansk, "BKTKMeansK");
Assign(bktleafsize, "BKTLeafSize");
Assign(samples, "Samples");
Assign(tptnumber, "TPTNumber");
Assign(tptleafsize, "TPTLeafSize");
Assign(numtopdimensiontptsplit, "NumTopDimensionTPTSplit");
Assign(neighborhoodsize, "NeighborhoodSize");
Assign(graphneighborhoodscale, "GraphNeighborhoodScale");
Assign(graphcefscale, "GraphCEFScale");
Assign(refineiterations, "RefineIterations");
Assign(cef, "CEF");
Assign(maxcheckforrefinegraph, "MaxCheckForRefineGraph");
Assign(numofthreads, "NumberOfThreads");
Assign(maxcheck, "MaxCheck");
Assign(thresholdofnumberofcontinuousnobetterpropagation, "ThresholdOfNumberOfContinuousNoBetterPropagation");
Assign(numberofinitialdynamicpivots, "NumberOfInitialDynamicPivots");
Assign(numberofotherdynamicpivots, "NumberOfOtherDynamicPivots");
Assign("bktnumber", "BKTNumber");
Assign("bktkmeansk", "BKTKMeansK");
Assign("bktleafsize", "BKTLeafSize");
Assign("samples", "Samples");
Assign("tptnumber", "TPTNumber");
Assign("tptleafsize", "TPTLeafSize");
Assign("numtopdimensiontptsplit", "NumTopDimensionTPTSplit");
Assign("neighborhoodsize", "NeighborhoodSize");
Assign("graphneighborhoodscale", "GraphNeighborhoodScale");
Assign("graphcefscale", "GraphCEFScale");
Assign("refineiterations", "RefineIterations");
Assign("cef", "CEF");
Assign("maxcheckforrefinegraph", "MaxCheckForRefineGraph");
Assign("numofthreads", "NumberOfThreads");
Assign("maxcheck", "MaxCheck");
Assign("thresholdofnumberofcontinuousnobetterpropagation", "ThresholdOfNumberOfContinuousNoBetterPropagation");
Assign("numberofinitialdynamicpivots", "NumberOfInitialDynamicPivots");
Assign("numberofotherdynamicpivots", "NumberOfOtherDynamicPivots");
}
}
DatasetPtr
CPUSPTAGRNG::Search(const DatasetPtr& dataset, const Config& config) {
SetParameters(config);
// if (config != nullptr) {
// config->CheckValid(); // throw exception
// }
auto p_data = dataset->Get<const float*>(meta::TENSOR);
for (auto i = 0; i < 10; ++i) {

View File

@ -23,9 +23,9 @@ struct Quantizer {
};
using QuantizerPtr = std::shared_ptr<Quantizer>;
struct QuantizerCfg : Cfg {
int64_t mode = -1; // 0: all data, 1: copy quantizer, 2: copy data
};
using QuantizerConfig = std::shared_ptr<QuantizerCfg>;
// struct QuantizerCfg : Cfg {
// int64_t mode = -1; // 0: all data, 1: copy quantizer, 2: copy data
// };
// using QuantizerConfig = std::shared_ptr<QuantizerCfg>;
} // namespace knowhere

View File

@ -24,7 +24,10 @@ namespace cloner {
VectorIndexPtr
CopyGpuToCpu(const VectorIndexPtr& index, const Config& config) {
if (auto device_index = std::dynamic_pointer_cast<GPUIndex>(index)) {
return device_index->CopyGpuToCpu(config);
VectorIndexPtr result = device_index->CopyGpuToCpu(config);
auto uids = index->GetUids();
result->SetUids(uids);
return result;
} else {
KNOWHERE_THROW_MSG("index type is not gpuindex");
}

View File

@ -17,47 +17,23 @@
namespace knowhere {
faiss::MetricType
GetMetricType(METRICTYPE& type) {
if (type == METRICTYPE::L2) {
GetMetricType(const std::string& type) {
if (type == Metric::L2) {
return faiss::METRIC_L2;
}
if (type == METRICTYPE::IP) {
if (type == Metric::IP) {
return faiss::METRIC_INNER_PRODUCT;
}
// binary only
if (type == METRICTYPE::JACCARD) {
if (type == Metric::JACCARD) {
return faiss::METRIC_Jaccard;
}
if (type == METRICTYPE::TANIMOTO) {
if (type == Metric::TANIMOTO) {
return faiss::METRIC_Tanimoto;
}
if (type == METRICTYPE::HAMMING) {
if (type == Metric::HAMMING) {
return faiss::METRIC_Hamming;
}
KNOWHERE_THROW_MSG("Metric type is invalid");
}
std::stringstream
IVFCfg::DumpImpl() {
auto ss = Cfg::DumpImpl();
ss << ", nlist: " << nlist << ", nprobe: " << nprobe;
return ss;
}
std::stringstream
IVFSQCfg::DumpImpl() {
auto ss = IVFCfg::DumpImpl();
ss << ", nbits: " << nbits;
return ss;
}
std::stringstream
NSGCfg::DumpImpl() {
auto ss = IVFCfg::DumpImpl();
ss << ", knng: " << knng << ", search_length: " << search_length << ", out_degree: " << out_degree
<< ", candidate: " << candidate_pool_size;
return ss;
}
} // namespace knowhere

View File

@ -12,240 +12,49 @@
#pragma once
#include <faiss/Index.h>
#include <memory>
#include "knowhere/common/Config.h"
#include <string>
namespace knowhere {
namespace meta {
constexpr const char* DIM = "dim";
constexpr const char* TENSOR = "tensor";
constexpr const char* ROWS = "rows";
constexpr const char* IDS = "ids";
constexpr const char* DISTANCE = "distance";
constexpr const char* TOPK = "k";
constexpr const char* DEVICEID = "gpu_id";
}; // namespace meta
namespace IndexParams {
// IVF Params
constexpr const char* nprobe = "nprobe";
constexpr const char* nlist = "nlist";
constexpr const char* m = "m"; // PQ
constexpr const char* nbits = "nbits"; // PQ/SQ
// NSG Params
constexpr const char* knng = "knng";
constexpr const char* search_length = "search_length";
constexpr const char* out_degree = "out_degree";
constexpr const char* candidate = "candidate_pool_size";
// HNSW Params
constexpr const char* efConstruction = "efConstruction";
constexpr const char* M = "M";
constexpr const char* ef = "ef";
} // namespace IndexParams
namespace Metric {
constexpr const char* TYPE = "metric_type";
constexpr const char* IP = "IP";
constexpr const char* L2 = "L2";
constexpr const char* HAMMING = "HAMMING";
constexpr const char* JACCARD = "JACCARD";
constexpr const char* TANIMOTO = "TANIMOTO";
} // namespace Metric
extern faiss::MetricType
GetMetricType(METRICTYPE& type);
// IVF Config
constexpr int64_t DEFAULT_NLIST = INVALID_VALUE;
constexpr int64_t DEFAULT_NPROBE = INVALID_VALUE;
constexpr int64_t DEFAULT_NSUBVECTORS = INVALID_VALUE;
constexpr int64_t DEFAULT_NBITS = INVALID_VALUE;
constexpr int64_t DEFAULT_SCAN_TABLE_THREHOLD = INVALID_VALUE;
constexpr int64_t DEFAULT_POLYSEMOUS_HT = INVALID_VALUE;
constexpr int64_t DEFAULT_MAX_CODES = INVALID_VALUE;
// NSG Config
constexpr int64_t DEFAULT_SEARCH_LENGTH = INVALID_VALUE;
constexpr int64_t DEFAULT_OUT_DEGREE = INVALID_VALUE;
constexpr int64_t DEFAULT_CANDIDATE_SISE = INVALID_VALUE;
constexpr int64_t DEFAULT_NNG_K = INVALID_VALUE;
// SPTAG Config
constexpr int64_t DEFAULT_SAMPLES = INVALID_VALUE;
constexpr int64_t DEFAULT_TPTNUMBER = INVALID_VALUE;
constexpr int64_t DEFAULT_TPTLEAFSIZE = INVALID_VALUE;
constexpr int64_t DEFAULT_NUMTOPDIMENSIONTPTSPLIT = INVALID_VALUE;
constexpr int64_t DEFAULT_NEIGHBORHOODSIZE = INVALID_VALUE;
constexpr int64_t DEFAULT_GRAPHNEIGHBORHOODSCALE = INVALID_VALUE;
constexpr int64_t DEFAULT_GRAPHCEFSCALE = INVALID_VALUE;
constexpr int64_t DEFAULT_REFINEITERATIONS = INVALID_VALUE;
constexpr int64_t DEFAULT_CEF = INVALID_VALUE;
constexpr int64_t DEFAULT_MAXCHECKFORREFINEGRAPH = INVALID_VALUE;
constexpr int64_t DEFAULT_NUMOFTHREADS = INVALID_VALUE;
constexpr int64_t DEFAULT_MAXCHECK = INVALID_VALUE;
constexpr int64_t DEFAULT_THRESHOLDOFNUMBEROFCONTINUOUSNOBETTERPROPAGATION = INVALID_VALUE;
constexpr int64_t DEFAULT_NUMBEROFINITIALDYNAMICPIVOTS = INVALID_VALUE;
constexpr int64_t DEFAULT_NUMBEROFOTHERDYNAMICPIVOTS = INVALID_VALUE;
// KDT Config
constexpr int64_t DEFAULT_KDTNUMBER = INVALID_VALUE;
constexpr int64_t DEFAULT_NUMTOPDIMENSIONKDTSPLIT = INVALID_VALUE;
// BKT Config
constexpr int64_t DEFAULT_BKTNUMBER = INVALID_VALUE;
constexpr int64_t DEFAULT_BKTKMEANSK = INVALID_VALUE;
constexpr int64_t DEFAULT_BKTLEAFSIZE = INVALID_VALUE;
// HNSW Config
constexpr int64_t DEFAULT_M = INVALID_VALUE;
constexpr int64_t DEFAULT_EF = INVALID_VALUE;
struct IVFCfg : public Cfg {
int64_t nlist = DEFAULT_NLIST;
int64_t nprobe = DEFAULT_NPROBE;
IVFCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe,
METRICTYPE type)
: Cfg(dim, k, gpu_id, type), nlist(nlist), nprobe(nprobe) {
}
IVFCfg() = default;
std::stringstream
DumpImpl() override;
// bool
// CheckValid() override {
// return true;
// };
};
using IVFConfig = std::shared_ptr<IVFCfg>;
struct IVFBinCfg : public IVFCfg {
bool
CheckValid() override {
if (metric_type == METRICTYPE::HAMMING || metric_type == METRICTYPE::TANIMOTO ||
metric_type == METRICTYPE::JACCARD) {
return true;
}
std::stringstream ss;
ss << "MetricType: " << int(metric_type) << " not support!";
KNOWHERE_THROW_MSG(ss.str());
return false;
}
};
struct IVFSQCfg : public IVFCfg {
// TODO(linxj): cpu only support SQ4 SQ6 SQ8 SQ16, gpu only support SQ4, SQ8, SQ16
int64_t nbits = DEFAULT_NBITS;
IVFSQCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe,
const int64_t& nbits, METRICTYPE type)
: IVFCfg(dim, k, gpu_id, nlist, nprobe, type), nbits(nbits) {
}
std::stringstream
DumpImpl() override;
IVFSQCfg() = default;
// bool
// CheckValid() override {
// return true;
// };
};
using IVFSQConfig = std::shared_ptr<IVFSQCfg>;
struct IVFPQCfg : public IVFCfg {
int64_t m = DEFAULT_NSUBVECTORS; // number of subquantizers(subvector)
int64_t nbits = DEFAULT_NBITS; // number of bit per subvector index
// TODO(linxj): not use yet
int64_t scan_table_threhold = DEFAULT_SCAN_TABLE_THREHOLD;
int64_t polysemous_ht = DEFAULT_POLYSEMOUS_HT;
int64_t max_codes = DEFAULT_MAX_CODES;
IVFPQCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe,
const int64_t& nbits, const int64_t& m, METRICTYPE type)
: IVFCfg(dim, k, gpu_id, nlist, nprobe, type), m(m), nbits(nbits) {
}
IVFPQCfg() = default;
// bool
// CheckValid() override {
// return true;
// };
};
using IVFPQConfig = std::shared_ptr<IVFPQCfg>;
struct NSGCfg : public IVFCfg {
int64_t knng = DEFAULT_NNG_K;
int64_t search_length = DEFAULT_SEARCH_LENGTH;
int64_t out_degree = DEFAULT_OUT_DEGREE;
int64_t candidate_pool_size = DEFAULT_CANDIDATE_SISE;
NSGCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe,
const int64_t& knng, const int64_t& search_length, const int64_t& out_degree, const int64_t& candidate_size,
METRICTYPE type)
: IVFCfg(dim, k, gpu_id, nlist, nprobe, type),
knng(knng),
search_length(search_length),
out_degree(out_degree),
candidate_pool_size(candidate_size) {
}
NSGCfg() = default;
std::stringstream
DumpImpl() override;
// bool
// CheckValid() override {
// return true;
// };
};
using NSGConfig = std::shared_ptr<NSGCfg>;
struct SPTAGCfg : public Cfg {
int64_t samples = DEFAULT_SAMPLES;
int64_t tptnumber = DEFAULT_TPTNUMBER;
int64_t tptleafsize = DEFAULT_TPTLEAFSIZE;
int64_t numtopdimensiontptsplit = DEFAULT_NUMTOPDIMENSIONTPTSPLIT;
int64_t neighborhoodsize = DEFAULT_NEIGHBORHOODSIZE;
int64_t graphneighborhoodscale = DEFAULT_GRAPHNEIGHBORHOODSCALE;
int64_t graphcefscale = DEFAULT_GRAPHCEFSCALE;
int64_t refineiterations = DEFAULT_REFINEITERATIONS;
int64_t cef = DEFAULT_CEF;
int64_t maxcheckforrefinegraph = DEFAULT_MAXCHECKFORREFINEGRAPH;
int64_t numofthreads = DEFAULT_NUMOFTHREADS;
int64_t maxcheck = DEFAULT_MAXCHECK;
int64_t thresholdofnumberofcontinuousnobetterpropagation = DEFAULT_THRESHOLDOFNUMBEROFCONTINUOUSNOBETTERPROPAGATION;
int64_t numberofinitialdynamicpivots = DEFAULT_NUMBEROFINITIALDYNAMICPIVOTS;
int64_t numberofotherdynamicpivots = DEFAULT_NUMBEROFOTHERDYNAMICPIVOTS;
SPTAGCfg() = default;
// bool
// CheckValid() override {
// return true;
// };
};
using SPTAGConfig = std::shared_ptr<SPTAGCfg>;
struct KDTCfg : public SPTAGCfg {
int64_t kdtnumber = DEFAULT_KDTNUMBER;
int64_t numtopdimensionkdtsplit = DEFAULT_NUMTOPDIMENSIONKDTSPLIT;
KDTCfg() = default;
// bool
// CheckValid() override {
// return true;
// };
};
using KDTConfig = std::shared_ptr<KDTCfg>;
struct BKTCfg : public SPTAGCfg {
int64_t bktnumber = DEFAULT_BKTNUMBER;
int64_t bktkmeansk = DEFAULT_BKTKMEANSK;
int64_t bktleafsize = DEFAULT_BKTLEAFSIZE;
BKTCfg() = default;
// bool
// CheckValid() override {
// return true;
// };
};
using BKTConfig = std::shared_ptr<BKTCfg>;
struct BinIDMAPCfg : public Cfg {
bool
CheckValid() override {
if (metric_type == METRICTYPE::HAMMING || metric_type == METRICTYPE::TANIMOTO ||
metric_type == METRICTYPE::JACCARD) {
return true;
}
std::stringstream ss;
ss << "MetricType: " << int(metric_type) << " not support!";
KNOWHERE_THROW_MSG(ss.str());
return false;
}
};
struct HNSWCfg : public Cfg {
int64_t M = DEFAULT_M;
int64_t ef = DEFAULT_EF;
HNSWCfg() = default;
};
using HNSWConfig = std::shared_ptr<HNSWCfg>;
GetMetricType(const std::string& type);
} // namespace knowhere

View File

@ -15,55 +15,53 @@
namespace knowhere {
const KDTConfig&
const Config&
SPTAGParameterMgr::GetKDTParameters() {
return kdt_config_;
}
const BKTConfig&
const Config&
SPTAGParameterMgr::GetBKTParameters() {
return bkt_config_;
}
SPTAGParameterMgr::SPTAGParameterMgr() {
kdt_config_ = std::make_shared<KDTCfg>();
kdt_config_->kdtnumber = 1;
kdt_config_->numtopdimensionkdtsplit = 5;
kdt_config_->samples = 100;
kdt_config_->tptnumber = 1;
kdt_config_->tptleafsize = 2000;
kdt_config_->numtopdimensiontptsplit = 5;
kdt_config_->neighborhoodsize = 32;
kdt_config_->graphneighborhoodscale = 2;
kdt_config_->graphcefscale = 2;
kdt_config_->refineiterations = 0;
kdt_config_->cef = 1000;
kdt_config_->maxcheckforrefinegraph = 10000;
kdt_config_->numofthreads = 1;
kdt_config_->maxcheck = 8192;
kdt_config_->thresholdofnumberofcontinuousnobetterpropagation = 3;
kdt_config_->numberofinitialdynamicpivots = 50;
kdt_config_->numberofotherdynamicpivots = 4;
kdt_config_["kdtnumber"] = 1;
kdt_config_["numtopdimensionkdtsplit"] = 5;
kdt_config_["samples"] = 100;
kdt_config_["tptnumber"] = 1;
kdt_config_["tptleafsize"] = 2000;
kdt_config_["numtopdimensiontptsplit"] = 5;
kdt_config_["neighborhoodsize"] = 32;
kdt_config_["graphneighborhoodscale"] = 2;
kdt_config_["graphcefscale"] = 2;
kdt_config_["refineiterations"] = 0;
kdt_config_["cef"] = 1000;
kdt_config_["maxcheckforrefinegraph"] = 10000;
kdt_config_["numofthreads"] = 1;
kdt_config_["maxcheck"] = 8192;
kdt_config_["thresholdofnumberofcontinuousnobetterpropagation"] = 3;
kdt_config_["numberofinitialdynamicpivots"] = 50;
kdt_config_["numberofotherdynamicpivots"] = 4;
bkt_config_ = std::make_shared<BKTCfg>();
bkt_config_->bktnumber = 1;
bkt_config_->bktkmeansk = 32;
bkt_config_->bktleafsize = 8;
bkt_config_->samples = 100;
bkt_config_->tptnumber = 1;
bkt_config_->tptleafsize = 2000;
bkt_config_->numtopdimensiontptsplit = 5;
bkt_config_->neighborhoodsize = 32;
bkt_config_->graphneighborhoodscale = 2;
bkt_config_->graphcefscale = 2;
bkt_config_->refineiterations = 0;
bkt_config_->cef = 1000;
bkt_config_->maxcheckforrefinegraph = 10000;
bkt_config_->numofthreads = 1;
bkt_config_->maxcheck = 8192;
bkt_config_->thresholdofnumberofcontinuousnobetterpropagation = 3;
bkt_config_->numberofinitialdynamicpivots = 50;
bkt_config_->numberofotherdynamicpivots = 4;
bkt_config_["bktnumber"] = 1;
bkt_config_["bktkmeansk"] = 32;
bkt_config_["bktleafsize"] = 8;
bkt_config_["samples"] = 100;
bkt_config_["tptnumber"] = 1;
bkt_config_["tptleafsize"] = 2000;
bkt_config_["numtopdimensiontptsplit"] = 5;
bkt_config_["neighborhoodsize"] = 32;
bkt_config_["graphneighborhoodscale"] = 2;
bkt_config_["graphcefscale"] = 2;
bkt_config_["refineiterations"] = 0;
bkt_config_["cef"] = 1000;
bkt_config_["maxcheckforrefinegraph"] = 10000;
bkt_config_["numofthreads"] = 1;
bkt_config_["maxcheck"] = 8192;
bkt_config_["thresholdofnumberofcontinuousnobetterpropagation"] = 3;
bkt_config_["numberofinitialdynamicpivots"] = 50;
bkt_config_["numberofotherdynamicpivots"] = 4;
}
} // namespace knowhere

View File

@ -18,18 +18,15 @@
#include <SPTAG/AnnService/inc/Core/Common.h>
#include "IndexParameter.h"
#include "knowhere/common/Config.h"
namespace knowhere {
using KDTConfig = std::shared_ptr<KDTCfg>;
using BKTConfig = std::shared_ptr<BKTCfg>;
class SPTAGParameterMgr {
public:
const KDTConfig&
const Config&
GetKDTParameters();
const BKTConfig&
const Config&
GetBKTParameters();
public:
@ -48,8 +45,8 @@ class SPTAGParameterMgr {
SPTAGParameterMgr();
private:
KDTConfig kdt_config_;
BKTConfig bkt_config_;
Config kdt_config_;
Config bkt_config_;
};
} // namespace knowhere

View File

@ -29,16 +29,9 @@ namespace algo {
unsigned int seed = 100;
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, METRICTYPE metric)
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, const std::string& metric)
: dimension(dimension), ntotal(n), metric_type(metric) {
switch (metric) {
case METRICTYPE::L2:
distance_ = new DistanceL2;
break;
case METRICTYPE::IP:
distance_ = new DistanceIP;
break;
}
distance_ = new DistanceL2; // hardcode here
}
NsgIndex::~NsgIndex() {

View File

@ -13,6 +13,7 @@
#include <cstddef>
#include <mutex>
#include <string>
#include <vector>
#include <boost/dynamic_bitset.hpp>
@ -41,8 +42,8 @@ using Graph = std::vector<std::vector<node_t>>;
class NsgIndex {
public:
size_t dimension;
size_t ntotal; // totabl nb of indexed vectors
METRICTYPE metric_type; // L2 | IP
size_t ntotal; // totabl nb of indexed vectors
std::string metric_type; // todo(linxj) IP
Distance* distance_;
float* ori_data_;
@ -62,7 +63,7 @@ class NsgIndex {
size_t out_degree;
public:
explicit NsgIndex(const size_t& dimension, const size_t& n, METRICTYPE metric = METRICTYPE::L2);
explicit NsgIndex(const size_t& dimension, const size_t& n, const std::string& metric = "L2");
NsgIndex() = default;

View File

@ -14,10 +14,6 @@
#include <omp.h>
#ifdef __SSE__
#include <immintrin.h>
#endif
#include <faiss/utils/utils.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/AuxIndexStructures.h>

View File

@ -30,7 +30,6 @@ set(util_srcs
${MILVUS_THIRDPARTY_SRC}/easyloggingpp/easylogging++.cc
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/adapter/VectorAdapter.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Exception.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Timer.cpp
${INDEX_SOURCE_DIR}/unittest/utils.cpp

View File

@ -72,35 +72,32 @@ class ParamGenerator {
knowhere::Config
Gen(const ParameterType& type) {
if (type == ParameterType::ivf) {
auto tempconf = std::make_shared<knowhere::IVFCfg>();
tempconf->d = DIM;
tempconf->gpu_id = DEVICEID;
tempconf->nlist = 100;
tempconf->nprobe = 4;
tempconf->k = K;
tempconf->metric_type = knowhere::METRICTYPE::L2;
return tempconf;
return knowhere::Config{
{knowhere::meta::DIM, DIM},
{knowhere::meta::TOPK, K},
{knowhere::IndexParams::nlist, 100},
{knowhere::IndexParams::nprobe, 4},
{knowhere::Metric::TYPE, knowhere::Metric::L2},
{knowhere::meta::DEVICEID, DEVICEID},
};
} else if (type == ParameterType::ivfpq) {
auto tempconf = std::make_shared<knowhere::IVFPQCfg>();
tempconf->d = DIM;
tempconf->gpu_id = DEVICEID;
tempconf->nlist = 100;
tempconf->nprobe = 4;
tempconf->k = K;
tempconf->m = 4;
tempconf->nbits = 8;
tempconf->metric_type = knowhere::METRICTYPE::L2;
return tempconf;
return knowhere::Config{
{knowhere::meta::DIM, DIM},
{knowhere::meta::TOPK, K},
{knowhere::IndexParams::nlist, 100},
{knowhere::IndexParams::nprobe, 4},
{knowhere::IndexParams::m, 4},
{knowhere::IndexParams::nbits, 8},
{knowhere::Metric::TYPE, knowhere::Metric::L2},
{knowhere::meta::DEVICEID, DEVICEID},
};
} else if (type == ParameterType::ivfsq) {
auto tempconf = std::make_shared<knowhere::IVFSQCfg>();
tempconf->d = DIM;
tempconf->gpu_id = DEVICEID;
tempconf->nlist = 100;
tempconf->nprobe = 4;
tempconf->k = K;
tempconf->nbits = 8;
tempconf->metric_type = knowhere::METRICTYPE::L2;
return tempconf;
return knowhere::Config{
{knowhere::meta::DIM, DIM}, {knowhere::meta::TOPK, K},
{knowhere::IndexParams::nlist, 100}, {knowhere::IndexParams::nprobe, 4},
{knowhere::IndexParams::nbits, 8}, {knowhere::Metric::TYPE, knowhere::Metric::L2},
{knowhere::meta::DEVICEID, DEVICEID},
};
}
}
};

View File

@ -21,7 +21,7 @@ using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
class BinaryIDMAPTest : public BinaryDataGen, public TestWithParam<knowhere::METRICTYPE> {
class BinaryIDMAPTest : public BinaryDataGen, public TestWithParam<std::string> {
protected:
void
SetUp() override {
@ -37,17 +37,17 @@ class BinaryIDMAPTest : public BinaryDataGen, public TestWithParam<knowhere::MET
};
INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIDMAPTest,
Values(knowhere::METRICTYPE::JACCARD, knowhere::METRICTYPE::TANIMOTO,
knowhere::METRICTYPE::HAMMING));
Values(std::string("JACCARD"), std::string("TANIMOTO"), std::string("HAMMING")));
TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
ASSERT_TRUE(!xb.empty());
knowhere::METRICTYPE MetricType = GetParam();
auto conf = std::make_shared<knowhere::BinIDMAPCfg>();
conf->d = dim;
conf->k = k;
conf->metric_type = MetricType;
std::string MetricType = GetParam();
knowhere::Config conf{
{knowhere::meta::DIM, dim},
{knowhere::meta::TOPK, k},
{knowhere::Metric::TYPE, MetricType},
};
index_->Train(conf);
index_->Add(base_dataset, conf);
@ -88,11 +88,12 @@ TEST_P(BinaryIDMAPTest, binaryidmap_serialize) {
reader(ret, bin->size);
};
knowhere::METRICTYPE MetricType = GetParam();
auto conf = std::make_shared<knowhere::BinIDMAPCfg>();
conf->d = dim;
conf->k = k;
conf->metric_type = MetricType;
std::string MetricType = GetParam();
knowhere::Config conf{
{knowhere::meta::DIM, dim},
{knowhere::meta::TOPK, k},
{knowhere::Metric::TYPE, MetricType},
};
{
// serialize index

View File

@ -27,25 +27,24 @@ using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
class BinaryIVFTest : public BinaryDataGen, public TestWithParam<knowhere::METRICTYPE> {
class BinaryIVFTest : public BinaryDataGen, public TestWithParam<std::string> {
protected:
void
SetUp() override {
knowhere::METRICTYPE MetricType = GetParam();
std::string MetricType = GetParam();
Init_with_binary_default();
// nb = 1000000;
// nq = 1000;
// k = 1000;
// Generate(DIM, NB, NQ);
index_ = std::make_shared<knowhere::BinaryIVF>();
auto x_conf = std::make_shared<knowhere::IVFBinCfg>();
x_conf->d = dim;
x_conf->k = k;
x_conf->metric_type = MetricType;
x_conf->nlist = 100;
x_conf->nprobe = 10;
conf = x_conf;
conf->Dump();
knowhere::Config temp_conf{
{knowhere::meta::DIM, dim}, {knowhere::meta::TOPK, k},
{knowhere::IndexParams::nlist, 100}, {knowhere::IndexParams::nprobe, 10},
{knowhere::Metric::TYPE, MetricType},
};
conf = temp_conf;
}
void
@ -59,8 +58,7 @@ class BinaryIVFTest : public BinaryDataGen, public TestWithParam<knowhere::METRI
};
INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIVFTest,
Values(knowhere::METRICTYPE::JACCARD, knowhere::METRICTYPE::TANIMOTO,
knowhere::METRICTYPE::HAMMING));
Values(std::string("JACCARD"), std::string("TANIMOTO"), std::string("HAMMING")));
TEST_P(BinaryIVFTest, binaryivf_basic) {
assert(!xb.empty());
@ -75,7 +73,7 @@ TEST_P(BinaryIVFTest, binaryivf_basic) {
EXPECT_EQ(index_->Dimension(), dim);
auto result = index_->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared<faiss::ConcurrentBitset>(nb);
@ -123,7 +121,7 @@ TEST_P(BinaryIVFTest, binaryivf_serialize) {
// index_->set_index_model(model);
// index_->Add(base_dataset, conf);
// auto result = index_->Search(query_dataset, conf);
// AssertAnns(result, nq, conf->k);
// AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// }
{
@ -147,7 +145,7 @@ TEST_P(BinaryIVFTest, binaryivf_serialize) {
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dimension(), dim);
auto result = index_->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
}
}

View File

@ -69,7 +69,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
for (int i = 0; i < 3; ++i) {
auto gpu_idx = cpu_idx->CopyCpuToGpu(DEVICEID, conf);
auto result = gpu_idx->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
}
}
@ -86,30 +86,18 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
auto quantization = pair.second;
auto result = gpu_idx->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
auto quantizer_conf = std::make_shared<knowhere::QuantizerCfg>();
quantizer_conf->mode = 2; // only copy data
quantizer_conf->gpu_id = DEVICEID;
milvus::json quantizer_conf{{knowhere::meta::DEVICEID, DEVICEID}, {"mode", 2}};
for (int i = 0; i < 2; ++i) {
auto hybrid_idx = std::make_shared<knowhere::IVFSQHybrid>(DEVICEID);
hybrid_idx->Load(binaryset);
auto new_idx = hybrid_idx->LoadData(quantization, quantizer_conf);
auto result = new_idx->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
}
{
// invalid quantizer config
quantizer_conf = std::make_shared<knowhere::QuantizerCfg>();
auto hybrid_idx = std::make_shared<knowhere::IVFSQHybrid>(DEVICEID);
ASSERT_ANY_THROW(hybrid_idx->LoadData(quantization, nullptr));
ASSERT_ANY_THROW(hybrid_idx->LoadData(quantization, quantizer_conf));
quantizer_conf->mode = 2; // only copy data
ASSERT_ANY_THROW(hybrid_idx->LoadData(quantization, quantizer_conf));
}
}
{
@ -126,7 +114,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
hybrid_idx->SetQuantizer(quantization);
auto result = hybrid_idx->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
hybrid_idx->UnsetQuantizer();
}

View File

@ -45,10 +45,8 @@ class IDMAPTest : public DataGen, public TestGpuIndexBase {
TEST_F(IDMAPTest, idmap_basic) {
ASSERT_TRUE(!xb.empty());
auto conf = std::make_shared<knowhere::Cfg>();
conf->d = dim;
conf->k = k;
conf->metric_type = knowhere::METRICTYPE::L2;
knowhere::Config conf{
{knowhere::meta::DIM, dim}, {knowhere::meta::TOPK, k}, {knowhere::Metric::TYPE, knowhere::Metric::L2}};
// null faiss index
{
@ -107,10 +105,8 @@ TEST_F(IDMAPTest, idmap_serialize) {
reader(ret, bin->size);
};
auto conf = std::make_shared<knowhere::Cfg>();
conf->d = dim;
conf->k = k;
conf->metric_type = knowhere::METRICTYPE::L2;
knowhere::Config conf{
{knowhere::meta::DIM, dim}, {knowhere::meta::TOPK, k}, {knowhere::Metric::TYPE, knowhere::Metric::L2}};
{
// serialize index
@ -146,10 +142,8 @@ TEST_F(IDMAPTest, idmap_serialize) {
TEST_F(IDMAPTest, copy_test) {
ASSERT_TRUE(!xb.empty());
auto conf = std::make_shared<knowhere::Cfg>();
conf->d = dim;
conf->k = k;
conf->metric_type = knowhere::METRICTYPE::L2;
knowhere::Config conf{
{knowhere::meta::DIM, dim}, {knowhere::meta::TOPK, k}, {knowhere::Metric::TYPE, knowhere::Metric::L2}};
index_->Train(conf);
index_->Add(base_dataset, conf);

View File

@ -62,7 +62,7 @@ class IVFTest : public DataGen, public TestWithParam<::std::tuple<std::string, P
Generate(DIM, NB, NQ);
index_ = IndexFactory(index_type);
conf = ParamGenerator::GetInstance().Gen(parameter_type_);
conf->Dump();
// KNOWHERE_LOG_DEBUG << "conf: " << conf->dump();
}
void
@ -109,7 +109,7 @@ TEST_P(IVFTest, ivf_basic) {
EXPECT_EQ(index_->Dimension(), dim);
auto result = index_->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
if (index_type.find("GPU") == std::string::npos && index_type.find("Hybrid") == std::string::npos &&
@ -190,7 +190,7 @@ TEST_P(IVFTest, ivf_serialize) {
index_->set_index_model(model);
index_->Add(base_dataset, conf);
auto result = index_->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
}
{
@ -214,7 +214,7 @@ TEST_P(IVFTest, ivf_serialize) {
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dimension(), dim);
auto result = index_->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
}
}
@ -232,7 +232,7 @@ TEST_P(IVFTest, clone_test) {
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dimension(), dim);
auto result = index_->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// PrintResult(result, nq, k);
auto AssertEqual = [&](knowhere::DatasetPtr p1, knowhere::DatasetPtr p2) {
@ -254,7 +254,7 @@ TEST_P(IVFTest, clone_test) {
// EXPECT_NO_THROW({
// auto clone_index = index_->Clone();
// auto clone_result = clone_index->Search(query_dataset, conf);
// //AssertAnns(result, nq, conf->k);
// //AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
// AssertEqual(result, clone_result);
// std::cout << "inplace clone [" << index_type << "] success" << std::endl;
// });
@ -339,7 +339,7 @@ TEST_P(IVFTest, gpu_seal_test) {
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dimension(), dim);
auto result = index_->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
AssertAnns(result, nq, conf[knowhere::meta::TOPK]);
fiu_init(0);
fiu_enable("IVF.Search.throw_std_exception", 1, nullptr, 0);
@ -374,7 +374,7 @@ TEST_P(IVFTest, invalid_gpu_source) {
}
auto invalid_conf = ParamGenerator::GetInstance().Gen(parameter_type_);
invalid_conf->gpu_id = -1;
invalid_conf[knowhere::meta::DEVICEID] = -1;
if (index_type == "GPUIVF") {
// null faiss index
@ -430,15 +430,6 @@ TEST_P(IVFTest, IVFSQHybrid_test) {
ASSERT_TRUE(index != nullptr);
ASSERT_ANY_THROW(index->UnsetQuantizer());
knowhere::QuantizerConfig config = std::make_shared<knowhere::QuantizerCfg>();
config->gpu_id = knowhere::INVALID_VALUE;
// mode = -1
ASSERT_ANY_THROW(index->LoadQuantizer(config));
config->mode = 1;
ASSERT_ANY_THROW(index->LoadQuantizer(config));
config->gpu_id = DEVICEID;
// index->LoadQuantizer(config);
ASSERT_ANY_THROW(index->SetQuantizer(nullptr));
}

View File

@ -46,24 +46,19 @@ class NSGInterfaceTest : public DataGen, public ::testing::Test {
Generate(256, 1000000 / 100, 1);
index_ = std::make_shared<knowhere::NSG>();
auto tmp_conf = std::make_shared<knowhere::NSGCfg>();
tmp_conf->gpu_id = DEVICEID;
tmp_conf->d = 256;
tmp_conf->knng = 20;
tmp_conf->nprobe = 8;
tmp_conf->nlist = 163;
tmp_conf->search_length = 40;
tmp_conf->out_degree = 30;
tmp_conf->candidate_pool_size = 100;
tmp_conf->metric_type = knowhere::METRICTYPE::L2;
train_conf = tmp_conf;
train_conf->Dump();
train_conf = knowhere::Config{{knowhere::meta::DIM, 256},
{knowhere::IndexParams::nlist, 163},
{knowhere::IndexParams::nprobe, 8},
{knowhere::IndexParams::knng, 20},
{knowhere::IndexParams::search_length, 40},
{knowhere::IndexParams::out_degree, 30},
{knowhere::IndexParams::candidate, 100},
{knowhere::Metric::TYPE, knowhere::Metric::L2}};
auto tmp2_conf = std::make_shared<knowhere::NSGCfg>();
tmp2_conf->k = k;
tmp2_conf->search_length = 30;
search_conf = tmp2_conf;
search_conf->Dump();
search_conf = knowhere::Config{
{knowhere::meta::TOPK, k},
{knowhere::IndexParams::search_length, 30},
};
}
void
@ -87,9 +82,9 @@ TEST_F(NSGInterfaceTest, basic_test) {
ASSERT_ANY_THROW(index_->Search(query_dataset, search_conf));
ASSERT_ANY_THROW(index_->Serialize());
}
train_conf->gpu_id = knowhere::INVALID_VALUE;
auto model_invalid_gpu = index_->Train(base_dataset, train_conf);
train_conf->gpu_id = DEVICEID;
// train_conf->gpu_id = knowhere::INVALID_VALUE;
// auto model_invalid_gpu = index_->Train(base_dataset, train_conf);
train_conf[knowhere::meta::DEVICEID] = DEVICEID;
auto model = index_->Train(base_dataset, train_conf);
auto result = index_->Search(query_dataset, search_conf);
AssertAnns(result, nq, k);

View File

@ -33,17 +33,17 @@ class SPTAGTest : public DataGen, public TestWithParam<std::string> {
Generate(128, 100, 5);
index_ = std::make_shared<knowhere::CPUSPTAGRNG>(IndexType);
if (IndexType == "KDT") {
auto tempconf = std::make_shared<knowhere::KDTCfg>();
tempconf->tptnumber = 1;
tempconf->k = 10;
tempconf->metric_type = knowhere::METRICTYPE::L2;
conf = tempconf;
conf = knowhere::Config{
{knowhere::meta::DIM, dim},
{knowhere::meta::TOPK, 10},
{knowhere::Metric::TYPE, knowhere::Metric::L2},
};
} else {
auto tempconf = std::make_shared<knowhere::BKTCfg>();
tempconf->tptnumber = 1;
tempconf->k = 10;
tempconf->metric_type = knowhere::METRICTYPE::L2;
conf = tempconf;
conf = knowhere::Config{
{knowhere::meta::DIM, dim},
{knowhere::meta::TOPK, 10},
{knowhere::Metric::TYPE, knowhere::Metric::L2},
};
}
Init_with_default();

View File

@ -16,9 +16,9 @@
namespace milvus {
namespace scheduler {
SearchJob::SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, uint64_t nprobe,
SearchJob::SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, const milvus::json& extra_params,
const engine::VectorsData& vectors)
: Job(JobType::SEARCH), context_(context), topk_(topk), nprobe_(nprobe), vectors_(vectors) {
: Job(JobType::SEARCH), context_(context), topk_(topk), extra_params_(extra_params), vectors_(vectors) {
}
bool
@ -72,7 +72,7 @@ SearchJob::Dump() const {
json ret{
{"topk", topk_},
{"nq", vectors_.vector_count_},
{"nprobe", nprobe_},
{"extra_params", extra_params_.dump()},
};
auto base = Job::Dump();
ret.insert(base.begin(), base.end());

View File

@ -40,7 +40,7 @@ using ResultDistances = engine::ResultDistances;
class SearchJob : public Job {
public:
SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, uint64_t nprobe,
SearchJob(const std::shared_ptr<server::Context>& context, uint64_t topk, const milvus::json& extra_params,
const engine::VectorsData& vectors);
public:
@ -79,9 +79,9 @@ class SearchJob : public Job {
return vectors_.vector_count_;
}
uint64_t
nprobe() const {
return nprobe_;
const milvus::json&
extra_params() const {
return extra_params_;
}
const engine::VectorsData&
@ -103,7 +103,7 @@ class SearchJob : public Job {
const std::shared_ptr<server::Context> context_;
uint64_t topk_ = 0;
uint64_t nprobe_ = 0;
milvus::json extra_params_;
// TODO: smart pointer
const engine::VectorsData& vectors_;

View File

@ -41,8 +41,9 @@ XBuildIndexTask::XBuildIndexTask(TableFileSchemaPtr file, TaskLabelPtr label)
engine_type = (EngineType)file->engine_type_;
}
auto json = milvus::json::parse(file_->index_params_);
to_index_engine_ = EngineFactory::Build(file_->dimension_, file_->location_, engine_type,
(MetricType)file_->metric_type_, file_->nlist_);
(MetricType)file_->metric_type_, json);
}
}

View File

@ -115,8 +115,12 @@ XSearchTask::XSearchTask(const std::shared_ptr<server::Context>& context, TableF
engine_type = (EngineType)file->engine_type_;
}
milvus::json json_params;
if (!file_->index_params_.empty()) {
json_params = milvus::json::parse(file_->index_params_);
}
index_engine_ = EngineFactory::Build(file_->dimension_, file_->location_, engine_type,
(MetricType)file_->metric_type_, file_->nlist_);
(MetricType)file_->metric_type_, json_params);
}
}
@ -217,7 +221,8 @@ XSearchTask::Execute() {
// step 1: allocate memory
uint64_t nq = search_job->nq();
uint64_t topk = search_job->topk();
uint64_t nprobe = search_job->nprobe();
const milvus::json& extra_params = search_job->extra_params();
ENGINE_LOG_DEBUG << "Search job extra params: " << extra_params.dump();
const engine::VectorsData& vectors = search_job->vectors();
output_ids.resize(topk * nq);
@ -235,13 +240,13 @@ XSearchTask::Execute() {
}
Status s;
if (!vectors.float_data_.empty()) {
s = index_engine_->Search(nq, vectors.float_data_.data(), topk, nprobe, output_distance.data(),
s = index_engine_->Search(nq, vectors.float_data_.data(), topk, extra_params, output_distance.data(),
output_ids.data(), hybrid);
} else if (!vectors.binary_data_.empty()) {
s = index_engine_->Search(nq, vectors.binary_data_.data(), topk, nprobe, output_distance.data(),
s = index_engine_->Search(nq, vectors.binary_data_.data(), topk, extra_params, output_distance.data(),
output_ids.data(), hybrid);
} else if (!vectors.id_array_.empty()) {
s = index_engine_->Search(nq, vectors.id_array_, topk, nprobe, output_distance.data(),
s = index_engine_->Search(nq, vectors.id_array_, topk, extra_params, output_distance.data(),
output_ids.data(), hybrid);
}

View File

@ -549,9 +549,14 @@ Config::UpdateFileConfigFromMem(const std::string& parent_key, const std::string
// convert value string to standard string stored in yaml file
std::string value_str;
if (child_key == CONFIG_CACHE_CACHE_INSERT_DATA || child_key == CONFIG_STORAGE_S3_ENABLE ||
child_key == CONFIG_METRIC_ENABLE_MONITOR || child_key == CONFIG_GPU_RESOURCE_ENABLE) {
value_str =
(value == "True" || value == "true" || value == "On" || value == "on" || value == "1") ? "true" : "false";
child_key == CONFIG_METRIC_ENABLE_MONITOR || child_key == CONFIG_GPU_RESOURCE_ENABLE ||
child_key == CONFIG_WAL_ENABLE || child_key == CONFIG_WAL_RECOVERY_ERROR_IGNORE) {
bool ok = false;
status = StringHelpFunctions::ConvertToBoolean(value, ok);
if (!status.ok()) {
return status;
}
value_str = ok ? "true" : "false";
} else if (child_key == CONFIG_GPU_RESOURCE_SEARCH_RESOURCES ||
child_key == CONFIG_GPU_RESOURCE_BUILD_INDEX_RESOURCES) {
std::vector<std::string> vec;
@ -593,7 +598,6 @@ Config::UpdateFileConfigFromMem(const std::string& parent_key, const std::string
}
// values of gpu resources are sequences, need to remove old here
std::regex reg("\\S*");
if (child_key == CONFIG_GPU_RESOURCE_SEARCH_RESOURCES || child_key == CONFIG_GPU_RESOURCE_BUILD_INDEX_RESOURCES) {
while (getline(conf_fin, line)) {
if (line.find("- gpu") != std::string::npos)
@ -1009,8 +1013,7 @@ Config::CheckCacheConfigCpuCacheCapacity(const std::string& value) {
std::cerr << "WARNING: cpu cache capacity value is too big" << std::endl;
}
std::string str =
GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_INSERT_BUFFER_SIZE, CONFIG_CACHE_INSERT_BUFFER_SIZE_DEFAULT);
std::string str = GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_INSERT_BUFFER_SIZE, "0");
int64_t buffer_value = std::stoll(str);
int64_t insert_buffer_size = buffer_value * GB;
@ -1059,9 +1062,8 @@ Config::CheckCacheConfigInsertBufferSize(const std::string& value) {
return Status(SERVER_INVALID_ARGUMENT, msg);
}
std::string str =
GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_CPU_CACHE_CAPACITY, CONFIG_CACHE_CPU_CACHE_CAPACITY_DEFAULT);
int64_t cache_size = std::stoll(str);
std::string str = GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_CPU_CACHE_CAPACITY, "0");
int64_t cache_size = std::stoll(str) * GB;
uint64_t total_mem = 0, free_mem = 0;
CommonUtil::GetSystemMemInfo(total_mem, free_mem);
@ -1855,7 +1857,8 @@ Config::SetDBConfigBackendUrl(const std::string& value) {
Status
Config::SetDBConfigPreloadTable(const std::string& value) {
CONFIG_CHECK(CheckDBConfigPreloadTable(value));
return SetConfigValueInMem(CONFIG_DB, CONFIG_DB_PRELOAD_TABLE, value);
std::string cor_value = value == "*" ? "\'*\'" : value;
return SetConfigValueInMem(CONFIG_DB, CONFIG_DB_PRELOAD_TABLE, cor_value);
}
Status

View File

@ -70,8 +70,8 @@ RequestHandler::DropTable(const std::shared_ptr<Context>& context, const std::st
Status
RequestHandler::CreateIndex(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type,
int64_t nlist) {
BaseRequestPtr request_ptr = CreateIndexRequest::Create(context, table_name, index_type, nlist);
const milvus::json& json_params) {
BaseRequestPtr request_ptr = CreateIndexRequest::Create(context, table_name, index_type, json_params);
RequestScheduler::ExecRequest(request_ptr);
return request_ptr->status();
@ -123,11 +123,11 @@ RequestHandler::ShowTableInfo(const std::shared_ptr<Context>& context, const std
Status
RequestHandler::Search(const std::shared_ptr<Context>& context, const std::string& table_name,
const engine::VectorsData& vectors, int64_t topk, int64_t nprobe,
const engine::VectorsData& vectors, int64_t topk, const milvus::json& extra_params,
const std::vector<std::string>& partition_list, const std::vector<std::string>& file_id_list,
TopKQueryResult& result) {
BaseRequestPtr request_ptr =
SearchRequest::Create(context, table_name, vectors, topk, nprobe, partition_list, file_id_list, result);
SearchRequest::Create(context, table_name, vectors, topk, extra_params, partition_list, file_id_list, result);
RequestScheduler::ExecRequest(request_ptr);
return request_ptr->status();
@ -135,10 +135,10 @@ RequestHandler::Search(const std::shared_ptr<Context>& context, const std::strin
Status
RequestHandler::SearchByID(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t vector_id,
int64_t topk, int64_t nprobe, const std::vector<std::string>& partition_list,
TopKQueryResult& result) {
int64_t topk, const milvus::json& extra_params,
const std::vector<std::string>& partition_list, TopKQueryResult& result) {
BaseRequestPtr request_ptr =
SearchByIDRequest::Create(context, table_name, vector_id, topk, nprobe, partition_list, result);
SearchByIDRequest::Create(context, table_name, vector_id, topk, extra_params, partition_list, result);
RequestScheduler::ExecRequest(request_ptr);
return request_ptr->status();

View File

@ -38,7 +38,7 @@ class RequestHandler {
Status
CreateIndex(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type,
int64_t nlist);
const milvus::json& json_params);
Status
Insert(const std::shared_ptr<Context>& context, const std::string& table_name, engine::VectorsData& vectors,
@ -60,12 +60,13 @@ class RequestHandler {
Status
Search(const std::shared_ptr<Context>& context, const std::string& table_name, const engine::VectorsData& vectors,
int64_t topk, int64_t nprobe, const std::vector<std::string>& partition_list,
int64_t topk, const milvus::json& extra_params, const std::vector<std::string>& partition_list,
const std::vector<std::string>& file_id_list, TopKQueryResult& result);
Status
SearchByID(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t vector_id, int64_t topk,
int64_t nprobe, const std::vector<std::string>& partition_list, TopKQueryResult& result);
const milvus::json& extra_params, const std::vector<std::string>& partition_list,
TopKQueryResult& result);
Status
DescribeTable(const std::shared_ptr<Context>& context, const std::string& table_name, TableSchema& table_schema);

View File

@ -17,6 +17,7 @@
#include "grpc/gen-status/status.grpc.pb.h"
#include "grpc/gen-status/status.pb.h"
#include "server/context/Context.h"
#include "utils/Json.h"
#include "utils/Status.h"
#include <condition_variable>
@ -73,17 +74,15 @@ struct TopKQueryResult {
struct IndexParam {
std::string table_name_;
int64_t index_type_;
int64_t nlist_;
std::string extra_params_;
IndexParam() {
index_type_ = 0;
nlist_ = 0;
}
IndexParam(const std::string& table_name, int64_t index_type, int64_t nlist) {
IndexParam(const std::string& table_name, int64_t index_type) {
table_name_ = table_name;
index_type_ = index_type;
nlist_ = nlist;
}
};

View File

@ -24,14 +24,17 @@ namespace milvus {
namespace server {
CreateIndexRequest::CreateIndexRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
int64_t index_type, int64_t nlist)
: BaseRequest(context, DDL_DML_REQUEST_GROUP), table_name_(table_name), index_type_(index_type), nlist_(nlist) {
int64_t index_type, const milvus::json& json_params)
: BaseRequest(context, DDL_DML_REQUEST_GROUP),
table_name_(table_name),
index_type_(index_type),
json_params_(json_params) {
}
BaseRequestPtr
CreateIndexRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type,
int64_t nlist) {
return std::shared_ptr<BaseRequest>(new CreateIndexRequest(context, table_name, index_type, nlist));
const milvus::json& json_params) {
return std::shared_ptr<BaseRequest>(new CreateIndexRequest(context, table_name, index_type, json_params));
}
Status
@ -69,7 +72,7 @@ CreateIndexRequest::OnExecute() {
return status;
}
status = ValidationUtil::ValidateTableIndexNlist(nlist_);
status = ValidationUtil::ValidateIndexParams(json_params_, table_schema, index_type_);
if (!status.ok()) {
return status;
}
@ -109,7 +112,7 @@ CreateIndexRequest::OnExecute() {
// step 3: create index
engine::TableIndex index;
index.engine_type_ = adapter_index_type;
index.nlist_ = nlist_;
index.extra_params_ = json_params_;
status = DBWrapper::DB()->CreateIndex(table_name_, index);
fiu_do_on("CreateIndexRequest.OnExecute.create_index_fail",
status = Status(milvus::SERVER_UNEXPECTED_ERROR, ""));

View File

@ -21,11 +21,12 @@ namespace server {
class CreateIndexRequest : public BaseRequest {
public:
static BaseRequestPtr
Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type, int64_t nlist);
Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type,
const milvus::json& json_params);
protected:
CreateIndexRequest(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t index_type,
int64_t nlist);
const milvus::json& json_params);
Status
OnExecute() override;
@ -33,7 +34,7 @@ class CreateIndexRequest : public BaseRequest {
private:
const std::string table_name_;
const int64_t index_type_;
const int64_t nlist_;
milvus::json json_params_;
};
} // namespace server

View File

@ -78,7 +78,7 @@ DescribeIndexRequest::OnExecute() {
index_param_.table_name_ = table_name_;
index_param_.index_type_ = index.engine_type_;
index_param_.nlist_ = index.nlist_;
index_param_.extra_params_ = index.extra_params_.dump();
} catch (std::exception& ex) {
return Status(SERVER_UNEXPECTED_ERROR, ex.what());
}

View File

@ -34,23 +34,23 @@ namespace milvus {
namespace server {
SearchByIDRequest::SearchByIDRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
int64_t vector_id, int64_t topk, int64_t nprobe,
int64_t vector_id, int64_t topk, const milvus::json& extra_params,
const std::vector<std::string>& partition_list, TopKQueryResult& result)
: BaseRequest(context, DQL_REQUEST_GROUP),
table_name_(table_name),
vector_id_(vector_id),
topk_(topk),
nprobe_(nprobe),
extra_params_(extra_params),
partition_list_(partition_list),
result_(result) {
}
BaseRequestPtr
SearchByIDRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t vector_id,
int64_t topk, int64_t nprobe, const std::vector<std::string>& partition_list,
TopKQueryResult& result) {
int64_t topk, const milvus::json& extra_params,
const std::vector<std::string>& partition_list, TopKQueryResult& result) {
return std::shared_ptr<BaseRequest>(
new SearchByIDRequest(context, table_name, vector_id, topk, nprobe, partition_list, result));
new SearchByIDRequest(context, table_name, vector_id, topk, extra_params, partition_list, result));
}
Status
@ -59,7 +59,7 @@ SearchByIDRequest::OnExecute() {
auto pre_query_ctx = context_->Child("Pre query");
std::string hdr = "SearchByIDRequest(table=" + table_name_ + ", id=" + std::to_string(vector_id_) +
", k=" + std::to_string(topk_) + ", nprob=" + std::to_string(nprobe_) + ")";
", k=" + std::to_string(topk_) + ", extra_params=" + extra_params_.dump() + ")";
TimeRecorder rc(hdr);
@ -88,6 +88,11 @@ SearchByIDRequest::OnExecute() {
}
}
status = ValidationUtil::ValidateSearchParams(extra_params_, table_schema, topk_);
if (!status.ok()) {
return status;
}
// Check whether GPU search resource is enabled
#ifdef MILVUS_GPU_VERSION
Config& config = Config::GetInstance();
@ -122,11 +127,6 @@ SearchByIDRequest::OnExecute() {
return status;
}
status = ValidationUtil::ValidateSearchNprobe(nprobe_, table_schema);
if (!status.ok()) {
return status;
}
rc.RecordSection("check validation");
// step 5: search vectors
@ -140,8 +140,8 @@ SearchByIDRequest::OnExecute() {
pre_query_ctx->GetTraceContext()->GetSpan()->Finish();
status = DBWrapper::DB()->QueryByID(context_, table_name_, partition_list_, (size_t)topk_, nprobe_, vector_id_,
result_ids, result_distances);
status = DBWrapper::DB()->QueryByID(context_, table_name_, partition_list_, (size_t)topk_, extra_params_,
vector_id_, result_ids, result_distances);
#ifdef MILVUS_ENABLE_PROFILING
ProfilerStop();

View File

@ -30,11 +30,11 @@ class SearchByIDRequest : public BaseRequest {
public:
static BaseRequestPtr
Create(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t vector_id, int64_t topk,
int64_t nprobe, const std::vector<std::string>& partition_list, TopKQueryResult& result);
const milvus::json& extra_params, const std::vector<std::string>& partition_list, TopKQueryResult& result);
protected:
SearchByIDRequest(const std::shared_ptr<Context>& context, const std::string& table_name, int64_t vector_id,
int64_t topk, int64_t nprobe, const std::vector<std::string>& partition_list,
int64_t topk, const milvus::json& extra_params, const std::vector<std::string>& partition_list,
TopKQueryResult& result);
Status
@ -44,7 +44,7 @@ class SearchByIDRequest : public BaseRequest {
const std::string table_name_;
const int64_t vector_id_;
int64_t topk_;
int64_t nprobe_;
milvus::json extra_params_;
const std::vector<std::string> partition_list_;
TopKQueryResult& result_;

View File

@ -26,14 +26,14 @@ namespace milvus {
namespace server {
SearchRequest::SearchRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
const engine::VectorsData& vectors, int64_t topk, int64_t nprobe,
const engine::VectorsData& vectors, int64_t topk, const milvus::json& extra_params,
const std::vector<std::string>& partition_list,
const std::vector<std::string>& file_id_list, TopKQueryResult& result)
: BaseRequest(context, DQL_REQUEST_GROUP),
table_name_(table_name),
vectors_data_(vectors),
topk_(topk),
nprobe_(nprobe),
extra_params_(extra_params),
partition_list_(partition_list),
file_id_list_(file_id_list),
result_(result) {
@ -41,11 +41,11 @@ SearchRequest::SearchRequest(const std::shared_ptr<Context>& context, const std:
BaseRequestPtr
SearchRequest::Create(const std::shared_ptr<Context>& context, const std::string& table_name,
const engine::VectorsData& vectors, int64_t topk, int64_t nprobe,
const engine::VectorsData& vectors, int64_t topk, const milvus::json& extra_params,
const std::vector<std::string>& partition_list, const std::vector<std::string>& file_id_list,
TopKQueryResult& result) {
return std::shared_ptr<BaseRequest>(
new SearchRequest(context, table_name, vectors, topk, nprobe, partition_list, file_id_list, result));
new SearchRequest(context, table_name, vectors, topk, extra_params, partition_list, file_id_list, result));
}
Status
@ -56,7 +56,7 @@ SearchRequest::OnExecute() {
auto pre_query_ctx = context_->Child("Pre query");
std::string hdr = "SearchRequest(table=" + table_name_ + ", nq=" + std::to_string(vector_count) +
", k=" + std::to_string(topk_) + ", nprob=" + std::to_string(nprobe_) + ")";
", k=" + std::to_string(topk_) + ", extra_params=" + extra_params_.dump() + ")";
TimeRecorder rc(hdr);
@ -84,13 +84,13 @@ SearchRequest::OnExecute() {
}
}
// step 3: check search parameter
status = ValidationUtil::ValidateSearchTopk(topk_, table_schema);
status = ValidationUtil::ValidateSearchParams(extra_params_, table_schema, topk_);
if (!status.ok()) {
return status;
}
status = ValidationUtil::ValidateSearchNprobe(nprobe_, table_schema);
// step 3: check search parameter
status = ValidationUtil::ValidateSearchTopk(topk_, table_schema);
if (!status.ok()) {
return status;
}
@ -150,10 +150,10 @@ SearchRequest::OnExecute() {
return status;
}
status = DBWrapper::DB()->Query(context_, table_name_, partition_list_, (size_t)topk_, nprobe_,
status = DBWrapper::DB()->Query(context_, table_name_, partition_list_, (size_t)topk_, extra_params_,
vectors_data_, result_ids, result_distances);
} else {
status = DBWrapper::DB()->QueryByFileID(context_, table_name_, file_id_list_, (size_t)topk_, nprobe_,
status = DBWrapper::DB()->QueryByFileID(context_, table_name_, file_id_list_, (size_t)topk_, extra_params_,
vectors_data_, result_ids, result_distances);
}

View File

@ -24,12 +24,12 @@ class SearchRequest : public BaseRequest {
public:
static BaseRequestPtr
Create(const std::shared_ptr<Context>& context, const std::string& table_name, const engine::VectorsData& vectors,
int64_t topk, int64_t nprobe, const std::vector<std::string>& partition_list,
int64_t topk, const milvus::json& extra_params, const std::vector<std::string>& partition_list,
const std::vector<std::string>& file_id_list, TopKQueryResult& result);
protected:
SearchRequest(const std::shared_ptr<Context>& context, const std::string& table_name,
const engine::VectorsData& vectors, int64_t topk, int64_t nprobe,
const engine::VectorsData& vectors, int64_t topk, const milvus::json& extra_params,
const std::vector<std::string>& partition_list, const std::vector<std::string>& file_id_list,
TopKQueryResult& result);
@ -40,7 +40,7 @@ class SearchRequest : public BaseRequest {
const std::string table_name_;
const engine::VectorsData& vectors_data_;
int64_t topk_;
int64_t nprobe_;
milvus::json extra_params_;
const std::vector<std::string> partition_list_;
const std::vector<std::string> file_id_list_;

View File

@ -280,8 +280,16 @@ GrpcRequestHandler::CreateIndex(::grpc::ServerContext* context, const ::milvus::
::milvus::grpc::Status* response) {
CHECK_NULLPTR_RETURN(request);
Status status = request_handler_.CreateIndex(context_map_[context], request->table_name(),
request->index().index_type(), request->index().nlist());
milvus::json json_params;
for (int i = 0; i < request->extra_params_size(); i++) {
const ::milvus::grpc::KeyValuePair& extra = request->extra_params(i);
if (extra.key() == EXTRA_PARAM_KEY) {
json_params = json::parse(extra.value());
}
}
Status status =
request_handler_.CreateIndex(context_map_[context], request->table_name(), request->index_type(), json_params);
SET_RESPONSE(response, status, context);
return ::grpc::Status::OK;
@ -366,14 +374,23 @@ GrpcRequestHandler::Search(::grpc::ServerContext* context, const ::milvus::grpc:
partitions.emplace_back(partition);
}
// step 3: search vectors
// step 3: parse extra parameters
milvus::json json_params;
for (int i = 0; i < request->extra_params_size(); i++) {
const ::milvus::grpc::KeyValuePair& extra = request->extra_params(i);
if (extra.key() == EXTRA_PARAM_KEY) {
json_params = json::parse(extra.value());
}
}
// step 4: search vectors
std::vector<std::string> file_ids;
TopKQueryResult result;
fiu_do_on("GrpcRequestHandler.Search.not_empty_file_ids", file_ids.emplace_back("test_file_id"));
Status status = request_handler_.Search(context_map_[context], request->table_name(), vectors, request->topk(),
request->nprobe(), partitions, file_ids, result);
json_params, partitions, file_ids, result);
// step 4: construct and return result
// step 5: construct and return result
ConstructResults(result, response);
SET_RESPONSE(response->mutable_status(), status, context);
@ -392,12 +409,21 @@ GrpcRequestHandler::SearchByID(::grpc::ServerContext* context, const ::milvus::g
partitions.emplace_back(partition);
}
// step 2: search vectors
// step 2: parse extra parameters
milvus::json json_params;
for (int i = 0; i < request->extra_params_size(); i++) {
const ::milvus::grpc::KeyValuePair& extra = request->extra_params(i);
if (extra.key() == EXTRA_PARAM_KEY) {
json_params = json::parse(extra.value());
}
}
// step 3: search vectors
TopKQueryResult result;
Status status = request_handler_.SearchByID(context_map_[context], request->table_name(), request->id(),
request->topk(), request->nprobe(), partitions, result);
request->topk(), json_params, partitions, result);
// step 3: construct and return result
// step 4: construct and return result
ConstructResults(result, response);
SET_RESPONSE(response->mutable_status(), status, context);
@ -429,13 +455,21 @@ GrpcRequestHandler::SearchInFiles(::grpc::ServerContext* context, const ::milvus
partitions.emplace_back(partition);
}
// step 4: search vectors
TopKQueryResult result;
Status status =
request_handler_.Search(context_map_[context], search_request->table_name(), vectors, search_request->topk(),
search_request->nprobe(), partitions, file_ids, result);
// step 4: parse extra parameters
milvus::json json_params;
for (int i = 0; i < search_request->extra_params_size(); i++) {
const ::milvus::grpc::KeyValuePair& extra = search_request->extra_params(i);
if (extra.key() == EXTRA_PARAM_KEY) {
json_params = json::parse(extra.value());
}
}
// step 5: construct and return result
// step 5: search vectors
TopKQueryResult result;
Status status = request_handler_.Search(context_map_[context], search_request->table_name(), vectors,
search_request->topk(), json_params, partitions, file_ids, result);
// step 6: construct and return result
ConstructResults(result, response);
SET_RESPONSE(response->mutable_status(), status, context);
@ -549,8 +583,10 @@ GrpcRequestHandler::DescribeIndex(::grpc::ServerContext* context, const ::milvus
IndexParam param;
Status status = request_handler_.DescribeIndex(context_map_[context], request->table_name(), param);
response->set_table_name(param.table_name_);
response->mutable_index()->set_index_type(param.index_type_);
response->mutable_index()->set_nlist(param.nlist_);
response->set_index_type(param.index_type_);
::milvus::grpc::KeyValuePair* kv = response->add_extra_params();
kv->set_key(EXTRA_PARAM_KEY);
kv->set_value(param.extra_params_);
SET_RESPONSE(response->mutable_status(), status, context);
return ::grpc::Status::OK;

View File

@ -56,6 +56,8 @@ namespace grpc {
::milvus::grpc::ErrorCode
ErrorMap(ErrorCode code);
static const char* EXTRA_PARAM_KEY = "params";
class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service, public GrpcInterceptorHookHandler {
public:
explicit GrpcRequestHandler(const std::shared_ptr<opentracing::Tracer>& tracer);

View File

@ -208,14 +208,14 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(TablesOptions)
ENDPOINT("OPTIONS", "/tables", TablesOptions) {
ENDPOINT("OPTIONS", "/collections", TablesOptions) {
return createResponse(Status::CODE_204, "No Content");
}
ADD_CORS(CreateTable)
ENDPOINT("POST", "/tables", CreateTable, BODY_DTO(TableRequestDto::ObjectWrapper, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables\'");
ENDPOINT("POST", "/collections", CreateTable, BODY_DTO(TableRequestDto::ObjectWrapper, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections\'");
tr.RecordSection("Received request.");
WebRequestHandler handler = WebRequestHandler();
@ -238,8 +238,8 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(ShowTables)
ENDPOINT("GET", "/tables", ShowTables, QUERIES(const QueryParams&, query_params)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/tables\'");
ENDPOINT("GET", "/collections", ShowTables, QUERIES(const QueryParams&, query_params)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections\'");
tr.RecordSection("Received request.");
WebRequestHandler handler = WebRequestHandler();
@ -265,21 +265,21 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(TableOptions)
ENDPOINT("OPTIONS", "/tables/{table_name}", TableOptions) {
ENDPOINT("OPTIONS", "/collections/{collection_name}", TableOptions) {
return createResponse(Status::CODE_204, "No Content");
}
ADD_CORS(GetTable)
ENDPOINT("GET", "/tables/{table_name}", GetTable,
PATH(String, table_name), QUERIES(const QueryParams&, query_params)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/tables/" + table_name->std_str() + "\'");
ENDPOINT("GET", "/collections/{collection_name}", GetTable,
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "\'");
tr.RecordSection("Received request.");
WebRequestHandler handler = WebRequestHandler();
String response_str;
auto status_dto = handler.GetTable(table_name, query_params, response_str);
auto status_dto = handler.GetTable(collection_name, query_params, response_str);
std::shared_ptr<OutgoingResponse> response;
switch (status_dto->code->getValue()) {
@ -302,14 +302,14 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(DropTable)
ENDPOINT("DELETE", "/tables/{table_name}", DropTable, PATH(String, table_name)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/tables/" + table_name->std_str() + "\'");
ENDPOINT("DELETE", "/collections/{collection_name}", DropTable, PATH(String, collection_name)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() + "\'");
tr.RecordSection("Received request.");
WebRequestHandler handler = WebRequestHandler();
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.DropTable(table_name);
auto status_dto = handler.DropTable(collection_name);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_204, status_dto);
@ -330,21 +330,21 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(IndexOptions)
ENDPOINT("OPTIONS", "/tables/{table_name}/indexes", IndexOptions) {
ENDPOINT("OPTIONS", "/collections/{collection_name}/indexes", IndexOptions) {
return createResponse(Status::CODE_204, "No Content");
}
ADD_CORS(CreateIndex)
ENDPOINT("POST", "/tables/{table_name}/indexes", CreateIndex,
PATH(String, table_name), BODY_DTO(IndexRequestDto::ObjectWrapper, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + table_name->std_str() + "/indexes\'");
ENDPOINT("POST", "/tables/{collection_name}/indexes", CreateIndex,
PATH(String, collection_name), BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + collection_name->std_str() + "/indexes\'");
tr.RecordSection("Received request.");
auto handler = WebRequestHandler();
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.CreateIndex(table_name, body);
auto status_dto = handler.CreateIndex(collection_name, body);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_201, status_dto);
@ -365,19 +365,19 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(GetIndex)
ENDPOINT("GET", "/tables/{table_name}/indexes", GetIndex, PATH(String, table_name)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/tables/" + table_name->std_str() + "/indexes\'");
ENDPOINT("GET", "/collections/{collection_name}/indexes", GetIndex, PATH(String, collection_name)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "/indexes\'");
tr.RecordSection("Received request.");
auto index_dto = IndexDto::createShared();
auto handler = WebRequestHandler();
auto status_dto = handler.GetIndex(table_name, index_dto);
OString result;
auto status_dto = handler.GetIndex(collection_name, result);
std::shared_ptr<OutgoingResponse> response;
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_200, index_dto);
response = createResponse(Status::CODE_200, result);
break;
case StatusCode::TABLE_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
@ -395,14 +395,14 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(DropIndex)
ENDPOINT("DELETE", "/tables/{table_name}/indexes", DropIndex, PATH(String, table_name)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/tables/" + table_name->std_str() + "/indexes\'");
ENDPOINT("DELETE", "/collections/{collection_name}/indexes", DropIndex, PATH(String, collection_name)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() + "/indexes\'");
tr.RecordSection("Received request.");
auto handler = WebRequestHandler();
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.DropIndex(table_name);
auto status_dto = handler.DropIndex(collection_name);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_204, status_dto);
@ -423,21 +423,21 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(PartitionsOptions)
ENDPOINT("OPTIONS", "/tables/{table_name}/partitions", PartitionsOptions) {
ENDPOINT("OPTIONS", "/collections/{collection_name}/partitions", PartitionsOptions) {
return createResponse(Status::CODE_204, "No Content");
}
ADD_CORS(CreatePartition)
ENDPOINT("POST", "/tables/{table_name}/partitions",
CreatePartition, PATH(String, table_name), BODY_DTO(PartitionRequestDto::ObjectWrapper, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + table_name->std_str() + "/partitions\'");
ENDPOINT("POST", "/collections/{collection_name}/partitions",
CreatePartition, PATH(String, collection_name), BODY_DTO(PartitionRequestDto::ObjectWrapper, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() + "/partitions\'");
tr.RecordSection("Received request.");
auto handler = WebRequestHandler();
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.CreatePartition(table_name, body);
auto status_dto = handler.CreatePartition(collection_name, body);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_201, status_dto);
@ -457,9 +457,9 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(ShowPartitions)
ENDPOINT("GET", "/tables/{table_name}/partitions", ShowPartitions,
PATH(String, table_name), QUERIES(const QueryParams&, query_params)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/tables/" + table_name->std_str() + "/partitions\'");
ENDPOINT("GET", "/collections/{collection_name}/partitions", ShowPartitions,
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "/partitions\'");
tr.RecordSection("Received request.");
auto offset = query_params.get("offset");
@ -469,7 +469,7 @@ class WebController : public oatpp::web::server::api::ApiController {
auto handler = WebRequestHandler();
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.ShowPartitions(table_name, query_params, partition_list_dto);
auto status_dto = handler.ShowPartitions(collection_name, query_params, partition_list_dto);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_200, partition_list_dto);
@ -488,16 +488,16 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(DropPartition)
ENDPOINT("DELETE", "/tables/{table_name}/partitions", DropPartition,
PATH(String, table_name), BODY_STRING(String, body)) {
ENDPOINT("DELETE", "/collections/{collection_name}/partitions", DropPartition,
PATH(String, collection_name), BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) +
"DELETE \'/tables/" + table_name->std_str() + "/partitions\'");
"DELETE \'/collections/" + collection_name->std_str() + "/partitions\'");
tr.RecordSection("Received request.");
auto handler = WebRequestHandler();
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.DropPartition(table_name, body);
auto status_dto = handler.DropPartition(collection_name, body);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_204, status_dto);
@ -517,14 +517,14 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(ShowSegments)
ENDPOINT("GET", "/tables/{table_name}/segments", ShowSegments,
PATH(String, table_name), QUERIES(const QueryParams&, query_params)) {
ENDPOINT("GET", "/collections/{collection_name}/segments", ShowSegments,
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
auto offset = query_params.get("offset");
auto page_size = query_params.get("page_size");
auto handler = WebRequestHandler();
String response;
auto status_dto = handler.ShowSegments(table_name, query_params, response);
auto status_dto = handler.ShowSegments(collection_name, query_params, response);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
@ -541,14 +541,14 @@ class WebController : public oatpp::web::server::api::ApiController {
*
* GetSegmentVector
*/
ENDPOINT("GET", "/tables/{table_name}/segments/{segment_name}/{info}", GetSegmentInfo,
PATH(String, table_name), PATH(String, segment_name), PATH(String, info), QUERIES(const QueryParams&, query_params)) {
ENDPOINT("GET", "/collections/{collection_name}/segments/{segment_name}/{info}", GetSegmentInfo,
PATH(String, collection_name), PATH(String, segment_name), PATH(String, info), QUERIES(const QueryParams&, query_params)) {
auto offset = query_params.get("offset");
auto page_size = query_params.get("page_size");
auto handler = WebRequestHandler();
String response;
auto status_dto = handler.GetSegmentInfo(table_name, segment_name, info, query_params, response);
auto status_dto = handler.GetSegmentInfo(collection_name, segment_name, info, query_params, response);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
@ -562,7 +562,7 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(VectorsOptions)
ENDPOINT("OPTIONS", "/tables/{table_name}/vectors", VectorsOptions) {
ENDPOINT("OPTIONS", "/collections/{collection_name}/vectors", VectorsOptions) {
return createResponse(Status::CODE_204, "No Content");
}
@ -571,11 +571,11 @@ class WebController : public oatpp::web::server::api::ApiController {
*
* GetVectorByID ?id=
*/
ENDPOINT("GET", "/tables/{table_name}/vectors", GetVectors,
PATH(String, table_name), QUERIES(const QueryParams&, query_params)) {
ENDPOINT("GET", "/collections/{collection_name}/vectors", GetVectors,
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
auto handler = WebRequestHandler();
String response;
auto status_dto = handler.GetVector(table_name, query_params, response);
auto status_dto = handler.GetVector(collection_name, query_params, response);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
@ -589,16 +589,16 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(Insert)
ENDPOINT("POST", "/tables/{table_name}/vectors", Insert,
PATH(String, table_name), BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + table_name->std_str() + "/vectors\'");
ENDPOINT("POST", "/collections/{collection_name}/vectors", Insert,
PATH(String, collection_name), BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() + "/vectors\'");
tr.RecordSection("Received request.");
auto ids_dto = VectorIdsDto::createShared();
WebRequestHandler handler = WebRequestHandler();
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.Insert(table_name, body, ids_dto);
auto status_dto = handler.Insert(collection_name, body, ids_dto);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createDtoResponse(Status::CODE_201, ids_dto);
@ -621,16 +621,16 @@ class WebController : public oatpp::web::server::api::ApiController {
* Search
* Delete by ID
* */
ENDPOINT("PUT", "/tables/{table_name}/vectors", VectorsOp,
PATH(String, table_name), BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/tables/" + table_name->std_str() + "/vectors\'");
ENDPOINT("PUT", "/collections/{collection_name}/vectors", VectorsOp,
PATH(String, collection_name), BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/collections/" + collection_name->std_str() + "/vectors\'");
tr.RecordSection("Received request.");
WebRequestHandler handler = WebRequestHandler();
OString result;
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.VectorsOp(table_name, body, result);
auto status_dto = handler.VectorsOp(collection_name, body, result);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createResponse(Status::CODE_200, result);
@ -674,11 +674,6 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(SystemOp)
/**
* Load
* Compact
* Flush
*/
ENDPOINT("PUT", "/system/{Op}", SystemOp, PATH(String, Op), BODY_STRING(String, body_str)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/system/" + Op->std_str() + "\'");
tr.RecordSection("Received request.");

View File

@ -26,7 +26,7 @@ class IndexRequestDto : public oatpp::data::mapping::type::Object {
DTO_FIELD(String, index_type) = VALUE_INDEX_INDEX_TYPE_DEFAULT;
DTO_FIELD(Int64, nlist) = VALUE_INDEX_NLIST_DEFAULT;
DTO_FIELD(String, params) = VALUE_INDEX_NLIST_DEFAULT;
};
using IndexDto = IndexRequestDto;

View File

@ -25,7 +25,7 @@ namespace web {
class TableRequestDto : public oatpp::data::mapping::type::Object {
DTO_INIT(TableRequestDto, Object)
DTO_FIELD(String, table_name, "table_name");
DTO_FIELD(String, collection_name, "collection_name");
DTO_FIELD(Int64, dimension, "dimension");
DTO_FIELD(Int64, index_file_size, "index_file_size") = VALUE_TABLE_INDEX_FILE_SIZE_DEFAULT;
DTO_FIELD(String, metric_type, "metric_type") = VALUE_TABLE_METRIC_TYPE_DEFAULT;
@ -34,25 +34,25 @@ class TableRequestDto : public oatpp::data::mapping::type::Object {
class TableFieldsDto : public oatpp::data::mapping::type::Object {
DTO_INIT(TableFieldsDto, Object)
DTO_FIELD(String, table_name);
DTO_FIELD(String, collection_name);
DTO_FIELD(Int64, dimension);
DTO_FIELD(Int64, index_file_size);
DTO_FIELD(String, metric_type);
DTO_FIELD(Int64, count);
DTO_FIELD(String, index);
DTO_FIELD(Int64, nlist);
DTO_FIELD(String, index_params);
};
class TableListDto : public OObject {
DTO_INIT(TableListDto, Object)
DTO_FIELD(List<String>::ObjectWrapper, table_names);
DTO_FIELD(List<String>::ObjectWrapper, collection_names);
};
class TableListFieldsDto : public OObject {
DTO_INIT(TableListFieldsDto, Object)
DTO_FIELD(List<TableFieldsDto::ObjectWrapper>::ObjectWrapper, tables);
DTO_FIELD(List<TableFieldsDto::ObjectWrapper>::ObjectWrapper, collections);
DTO_FIELD(Int64, count) = 0;
};

View File

@ -20,26 +20,6 @@ namespace web {
#include OATPP_CODEGEN_BEGIN(DTO)
class SearchRequestDto : public OObject {
DTO_INIT(SearchRequestDto, Object)
DTO_FIELD(Int64, topk);
DTO_FIELD(Int64, nprobe);
DTO_FIELD(List<String>::ObjectWrapper, tags);
DTO_FIELD(List<String>::ObjectWrapper, file_ids);
DTO_FIELD(List<List<Float32>::ObjectWrapper>::ObjectWrapper, records);
DTO_FIELD(List<List<Int64>::ObjectWrapper>::ObjectWrapper, records_bin);
};
class InsertRequestDto : public oatpp::data::mapping::type::Object {
DTO_INIT(InsertRequestDto, Object)
DTO_FIELD(String, tag) = VALUE_PARTITION_TAG_DEFAULT;
DTO_FIELD(List<List<Float32>::ObjectWrapper>::ObjectWrapper, records);
DTO_FIELD(List<List<Int64>::ObjectWrapper>::ObjectWrapper, records_bin);
DTO_FIELD(List<Int64>::ObjectWrapper, ids);
};
class VectorIdsDto : public oatpp::data::mapping::type::Object {
DTO_INIT(VectorIdsDto, Object)

View File

@ -109,6 +109,26 @@ WebRequestHandler::ParseQueryStr(const OQueryParams& query_params, const std::st
return Status::OK();
}
Status
WebRequestHandler::ParseQueryBool(const OQueryParams& query_params, const std::string& key, bool& value,
bool nullable) {
auto query = query_params.get(key.c_str());
if (nullptr != query.get() && query->getSize() > 0) {
std::string value_str = query->std_str();
if (!ValidationUtil::ValidateStringIsBool(value_str).ok()) {
return Status(ILLEGAL_QUERY_PARAM, "Query param \'all_required\' must be a bool");
}
value = value_str == "True" || value_str == "true";
return Status::OK();
}
if (!nullable) {
return Status(QUERY_PARAM_LOSS, "Query param \"" + key + "\" is required");
}
return Status::OK();
}
void
WebRequestHandler::AddStatusToJson(nlohmann::json& json, int64_t code, const std::string& msg) {
json["code"] = (int64_t)code;
@ -142,9 +162,9 @@ WebRequestHandler::ParsePartitionStat(const milvus::server::PartitionStat& par_s
}
Status
WebRequestHandler::IsBinaryTable(const std::string& table_name, bool& bin) {
WebRequestHandler::IsBinaryTable(const std::string& collection_name, bool& bin) {
TableSchema schema;
auto status = request_handler_.DescribeTable(context_ptr_, table_name, schema);
auto status = request_handler_.DescribeTable(context_ptr_, collection_name, schema);
if (status.ok()) {
auto metric = engine::MetricType(schema.metric_type_);
bin = engine::MetricType::HAMMING == metric || engine::MetricType::JACCARD == metric ||
@ -187,30 +207,30 @@ WebRequestHandler::CopyRecordsFromJson(const nlohmann::json& json, engine::Vecto
///////////////////////// WebRequestHandler methods ///////////////////////////////////////
Status
WebRequestHandler::GetTableMetaInfo(const std::string& table_name, nlohmann::json& json_out) {
WebRequestHandler::GetTableMetaInfo(const std::string& collection_name, nlohmann::json& json_out) {
TableSchema schema;
auto status = request_handler_.DescribeTable(context_ptr_, table_name, schema);
auto status = request_handler_.DescribeTable(context_ptr_, collection_name, schema);
if (!status.ok()) {
return status;
}
int64_t count;
status = request_handler_.CountTable(context_ptr_, table_name, count);
status = request_handler_.CountTable(context_ptr_, collection_name, count);
if (!status.ok()) {
return status;
}
IndexParam index_param;
status = request_handler_.DescribeIndex(context_ptr_, table_name, index_param);
status = request_handler_.DescribeIndex(context_ptr_, collection_name, index_param);
if (!status.ok()) {
return status;
}
json_out["table_name"] = schema.table_name_;
json_out["collection_name"] = schema.table_name_;
json_out["dimension"] = schema.dimension_;
json_out["index_file_size"] = schema.index_file_size_;
json_out["index"] = IndexMap.at(engine::EngineType(index_param.index_type_));
json_out["nlist"] = index_param.nlist_;
json_out["index_params"] = index_param.extra_params_;
json_out["metric_type"] = MetricMap.at(engine::MetricType(schema.metric_type_));
json_out["count"] = count;
@ -218,15 +238,15 @@ WebRequestHandler::GetTableMetaInfo(const std::string& table_name, nlohmann::jso
}
Status
WebRequestHandler::GetTableStat(const std::string& table_name, nlohmann::json& json_out) {
struct TableInfo table_info;
auto status = request_handler_.ShowTableInfo(context_ptr_, table_name, table_info);
WebRequestHandler::GetTableStat(const std::string& collection_name, nlohmann::json& json_out) {
struct TableInfo collection_info;
auto status = request_handler_.ShowTableInfo(context_ptr_, collection_name, collection_info);
if (status.ok()) {
json_out["count"] = table_info.total_row_num_;
json_out["count"] = collection_info.total_row_num_;
std::vector<nlohmann::json> par_stat_json;
for (auto& par : table_info.partitions_stat_) {
for (auto& par : collection_info.partitions_stat_) {
nlohmann::json par_json;
ParsePartitionStat(par, par_json);
par_stat_json.push_back(par_json);
@ -238,10 +258,10 @@ WebRequestHandler::GetTableStat(const std::string& table_name, nlohmann::json& j
}
Status
WebRequestHandler::GetSegmentVectors(const std::string& table_name, const std::string& segment_name, int64_t page_size,
int64_t offset, nlohmann::json& json_out) {
WebRequestHandler::GetSegmentVectors(const std::string& collection_name, const std::string& segment_name,
int64_t page_size, int64_t offset, nlohmann::json& json_out) {
std::vector<int64_t> vector_ids;
auto status = request_handler_.GetVectorIDs(context_ptr_, table_name, segment_name, vector_ids);
auto status = request_handler_.GetVectorIDs(context_ptr_, collection_name, segment_name, vector_ids);
if (!status.ok()) {
return status;
}
@ -251,7 +271,7 @@ WebRequestHandler::GetSegmentVectors(const std::string& table_name, const std::s
auto ids = std::vector<int64_t>(vector_ids.begin() + ids_begin, vector_ids.begin() + ids_end);
nlohmann::json vectors_json;
status = GetVectorsByIDs(table_name, ids, vectors_json);
status = GetVectorsByIDs(collection_name, ids, vectors_json);
nlohmann::json result_json;
if (vectors_json.empty()) {
@ -267,10 +287,10 @@ WebRequestHandler::GetSegmentVectors(const std::string& table_name, const std::s
}
Status
WebRequestHandler::GetSegmentIds(const std::string& table_name, const std::string& segment_name, int64_t page_size,
WebRequestHandler::GetSegmentIds(const std::string& collection_name, const std::string& segment_name, int64_t page_size,
int64_t offset, nlohmann::json& json_out) {
std::vector<int64_t> vector_ids;
auto status = request_handler_.GetVectorIDs(context_ptr_, table_name, segment_name, vector_ids);
auto status = request_handler_.GetVectorIDs(context_ptr_, collection_name, segment_name, vector_ids);
if (status.ok()) {
auto ids_begin = std::min(vector_ids.size(), (size_t)offset);
auto ids_end = std::min(vector_ids.size(), (size_t)(offset + page_size));
@ -310,12 +330,12 @@ WebRequestHandler::Cmd(const std::string& cmd, std::string& result_str) {
Status
WebRequestHandler::PreLoadTable(const nlohmann::json& json, std::string& result_str) {
if (!json.contains("table_name")) {
return Status(BODY_FIELD_LOSS, "Field \"load\" must contains table_name");
if (!json.contains("collection_name")) {
return Status(BODY_FIELD_LOSS, "Field \"load\" must contains collection_name");
}
auto table_name = json["table_name"];
auto status = request_handler_.PreloadTable(context_ptr_, table_name.get<std::string>());
auto collection_name = json["collection_name"];
auto status = request_handler_.PreloadTable(context_ptr_, collection_name.get<std::string>());
if (status.ok()) {
nlohmann::json result;
AddStatusToJson(result, status.code(), status.message());
@ -327,17 +347,17 @@ WebRequestHandler::PreLoadTable(const nlohmann::json& json, std::string& result_
Status
WebRequestHandler::Flush(const nlohmann::json& json, std::string& result_str) {
if (!json.contains("table_names")) {
return Status(BODY_FIELD_LOSS, "Field \"flush\" must contains table_names");
if (!json.contains("collection_names")) {
return Status(BODY_FIELD_LOSS, "Field \"flush\" must contains collection_names");
}
auto table_names = json["table_names"];
if (!table_names.is_array()) {
return Status(BODY_FIELD_LOSS, "Field \"table_names\" must be and array");
auto collection_names = json["collection_names"];
if (!collection_names.is_array()) {
return Status(BODY_FIELD_LOSS, "Field \"collection_names\" must be and array");
}
std::vector<std::string> names;
for (auto& name : table_names) {
for (auto& name : collection_names) {
names.emplace_back(name.get<std::string>());
}
@ -353,16 +373,16 @@ WebRequestHandler::Flush(const nlohmann::json& json, std::string& result_str) {
Status
WebRequestHandler::Compact(const nlohmann::json& json, std::string& result_str) {
if (!json.contains("table_name")) {
return Status(BODY_FIELD_LOSS, "Field \"compact\" must contains table_names");
if (!json.contains("collection_name")) {
return Status(BODY_FIELD_LOSS, "Field \"compact\" must contains collection_names");
}
auto table_name = json["table_name"];
if (!table_name.is_string()) {
return Status(BODY_FIELD_LOSS, "Field \"table_names\" must be a string");
auto collection_name = json["collection_name"];
if (!collection_name.is_string()) {
return Status(BODY_FIELD_LOSS, "Field \"collection_names\" must be a string");
}
auto name = table_name.get<std::string>();
auto name = collection_name.get<std::string>();
auto status = request_handler_.Compact(context_ptr_, name);
@ -463,17 +483,12 @@ WebRequestHandler::SetConfig(const nlohmann::json& json, std::string& result_str
}
Status
WebRequestHandler::Search(const std::string& table_name, const nlohmann::json& json, std::string& result_str) {
WebRequestHandler::Search(const std::string& collection_name, const nlohmann::json& json, std::string& result_str) {
if (!json.contains("topk")) {
return Status(BODY_FIELD_LOSS, "Field \'topk\' is required");
}
int64_t topk = json["topk"];
if (!json.contains("nprobe")) {
return Status(BODY_FIELD_LOSS, "Field \'nprobe\' is required");
}
int64_t nprobe = json["nprobe"];
std::vector<std::string> partition_tags;
if (json.contains("partition_tags")) {
auto tags = json["partition_tags"];
@ -497,8 +512,12 @@ WebRequestHandler::Search(const std::string& table_name, const nlohmann::json& j
}
}
if (!json.contains("params")) {
return Status(BODY_FIELD_LOSS, "Field \'params\' is required");
}
bool bin_flag = false;
auto status = IsBinaryTable(table_name, bin_flag);
auto status = IsBinaryTable(collection_name, bin_flag);
if (!status.ok()) {
return status;
}
@ -514,8 +533,9 @@ WebRequestHandler::Search(const std::string& table_name, const nlohmann::json& j
}
TopKQueryResult result;
status = request_handler_.Search(context_ptr_, table_name, vectors_data, topk, nprobe, partition_tags, file_id_vec,
result);
status = request_handler_.Search(context_ptr_, collection_name, vectors_data, topk, json["params"], partition_tags,
file_id_vec, result);
if (!status.ok()) {
return status;
}
@ -547,7 +567,8 @@ WebRequestHandler::Search(const std::string& table_name, const nlohmann::json& j
}
Status
WebRequestHandler::DeleteByIDs(const std::string& table_name, const nlohmann::json& json, std::string& result_str) {
WebRequestHandler::DeleteByIDs(const std::string& collection_name, const nlohmann::json& json,
std::string& result_str) {
std::vector<int64_t> vector_ids;
if (!json.contains("ids")) {
return Status(BODY_FIELD_LOSS, "Field \"delete\" must contains \"ids\"");
@ -565,7 +586,7 @@ WebRequestHandler::DeleteByIDs(const std::string& table_name, const nlohmann::js
vector_ids.emplace_back(std::stol(id_str));
}
auto status = request_handler_.DeleteByID(context_ptr_, table_name, vector_ids);
auto status = request_handler_.DeleteByID(context_ptr_, collection_name, vector_ids);
nlohmann::json result_json;
AddStatusToJson(result_json, status.code(), status.message());
@ -575,13 +596,13 @@ WebRequestHandler::DeleteByIDs(const std::string& table_name, const nlohmann::js
}
Status
WebRequestHandler::GetVectorsByIDs(const std::string& table_name, const std::vector<int64_t>& ids,
WebRequestHandler::GetVectorsByIDs(const std::string& collection_name, const std::vector<int64_t>& ids,
nlohmann::json& json_out) {
std::vector<engine::VectorsData> vector_batch;
for (size_t i = 0; i < ids.size(); i++) {
auto vec_ids = std::vector<int64_t>(ids.begin() + i, ids.begin() + i + 1);
engine::VectorsData vectors_data;
auto status = request_handler_.GetVectorByID(context_ptr_, table_name, vec_ids, vectors_data);
auto status = request_handler_.GetVectorByID(context_ptr_, collection_name, vec_ids, vectors_data);
if (!status.ok()) {
return status;
}
@ -589,7 +610,7 @@ WebRequestHandler::GetVectorsByIDs(const std::string& table_name, const std::vec
}
bool bin;
auto status = IsBinaryTable(table_name, bin);
auto status = IsBinaryTable(collection_name, bin);
if (!status.ok()) {
return status;
}
@ -879,30 +900,31 @@ WebRequestHandler::SetGpuConfig(const GPUConfigDto::ObjectWrapper& gpu_config_dt
* Table {
*/
StatusDto::ObjectWrapper
WebRequestHandler::CreateTable(const TableRequestDto::ObjectWrapper& table_schema) {
if (nullptr == table_schema->table_name.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'table_name\' is missing")
WebRequestHandler::CreateTable(const TableRequestDto::ObjectWrapper& collection_schema) {
if (nullptr == collection_schema->collection_name.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'collection_name\' is missing")
}
if (nullptr == table_schema->dimension.get()) {
if (nullptr == collection_schema->dimension.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'dimension\' is missing")
}
if (nullptr == table_schema->index_file_size.get()) {
if (nullptr == collection_schema->index_file_size.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'index_file_size\' is missing")
}
if (nullptr == table_schema->metric_type.get()) {
if (nullptr == collection_schema->metric_type.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'metric_type\' is missing")
}
if (MetricNameMap.find(table_schema->metric_type->std_str()) == MetricNameMap.end()) {
if (MetricNameMap.find(collection_schema->metric_type->std_str()) == MetricNameMap.end()) {
RETURN_STATUS_DTO(ILLEGAL_METRIC_TYPE, "metric_type is illegal")
}
auto status = request_handler_.CreateTable(
context_ptr_, table_schema->table_name->std_str(), table_schema->dimension, table_schema->index_file_size,
static_cast<int64_t>(MetricNameMap.at(table_schema->metric_type->std_str())));
auto status =
request_handler_.CreateTable(context_ptr_, collection_schema->collection_name->std_str(),
collection_schema->dimension, collection_schema->index_file_size,
static_cast<int64_t>(MetricNameMap.at(collection_schema->metric_type->std_str())));
ASSIGN_RETURN_STATUS_DTO(status)
}
@ -926,45 +948,41 @@ WebRequestHandler::ShowTables(const OQueryParams& query_params, OString& result)
}
bool all_required = false;
auto required = query_params.get("all_required");
if (nullptr != required.get()) {
auto required_str = required->std_str();
if (!ValidationUtil::ValidateStringIsBool(required_str).ok()) {
RETURN_STATUS_DTO(ILLEGAL_QUERY_PARAM, "Query param \'all_required\' must be a bool")
}
all_required = required_str == "True" || required_str == "true";
ParseQueryBool(query_params, "all_required", all_required);
if (!status.ok()) {
RETURN_STATUS_DTO(status.code(), status.message().c_str());
}
std::vector<std::string> tables;
status = request_handler_.ShowTables(context_ptr_, tables);
std::vector<std::string> collections;
status = request_handler_.ShowTables(context_ptr_, collections);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
if (all_required) {
offset = 0;
page_size = tables.size();
page_size = collections.size();
} else {
offset = std::min((size_t)offset, tables.size());
page_size = std::min(tables.size() - offset, (size_t)page_size);
offset = std::min((size_t)offset, collections.size());
page_size = std::min(collections.size() - offset, (size_t)page_size);
}
nlohmann::json tables_json;
nlohmann::json collections_json;
for (int64_t i = offset; i < page_size + offset; i++) {
nlohmann::json table_json;
status = GetTableMetaInfo(tables.at(i), table_json);
nlohmann::json collection_json;
status = GetTableMetaInfo(collections.at(i), collection_json);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
tables_json.push_back(table_json);
collections_json.push_back(collection_json);
}
nlohmann::json result_json;
result_json["count"] = tables.size();
if (tables_json.empty()) {
result_json["tables"] = std::vector<int64_t>();
result_json["count"] = collections.size();
if (collections_json.empty()) {
result_json["collections"] = std::vector<int64_t>();
} else {
result_json["tables"] = tables_json;
result_json["collections"] = collections_json;
}
result = result_json.dump().c_str();
@ -973,9 +991,9 @@ WebRequestHandler::ShowTables(const OQueryParams& query_params, OString& result)
}
StatusDto::ObjectWrapper
WebRequestHandler::GetTable(const OString& table_name, const OQueryParams& query_params, OString& result) {
if (nullptr == table_name.get()) {
RETURN_STATUS_DTO(PATH_PARAM_LOSS, "Path param \'table_name\' is required!");
WebRequestHandler::GetTable(const OString& collection_name, const OQueryParams& query_params, OString& result) {
if (nullptr == collection_name.get()) {
RETURN_STATUS_DTO(PATH_PARAM_LOSS, "Path param \'collection_name\' is required!");
}
std::string stat;
@ -986,11 +1004,11 @@ WebRequestHandler::GetTable(const OString& table_name, const OQueryParams& query
if (!stat.empty() && stat == "stat") {
nlohmann::json json;
status = GetTableStat(table_name->std_str(), json);
status = GetTableStat(collection_name->std_str(), json);
result = status.ok() ? json.dump().c_str() : "NULL";
} else {
nlohmann::json json;
status = GetTableMetaInfo(table_name->std_str(), json);
status = GetTableMetaInfo(collection_name->std_str(), json);
result = status.ok() ? json.dump().c_str() : "NULL";
}
@ -998,8 +1016,8 @@ WebRequestHandler::GetTable(const OString& table_name, const OQueryParams& query
}
StatusDto::ObjectWrapper
WebRequestHandler::DropTable(const OString& table_name) {
auto status = request_handler_.DropTable(context_ptr_, table_name->std_str());
WebRequestHandler::DropTable(const OString& collection_name) {
auto status = request_handler_.DropTable(context_ptr_, collection_name->std_str());
ASSIGN_RETURN_STATUS_DTO(status)
}
@ -1010,41 +1028,49 @@ WebRequestHandler::DropTable(const OString& table_name) {
*/
StatusDto::ObjectWrapper
WebRequestHandler::CreateIndex(const OString& table_name, const IndexRequestDto::ObjectWrapper& index_param) {
if (nullptr == index_param->index_type.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'index_type\' is required")
}
std::string index_type = index_param->index_type->std_str();
if (IndexNameMap.find(index_type) == IndexNameMap.end()) {
RETURN_STATUS_DTO(ILLEGAL_INDEX_TYPE, "The index type is invalid.")
WebRequestHandler::CreateIndex(const OString& table_name, const OString& body) {
try {
auto request_json = nlohmann::json::parse(body->std_str());
if (!request_json.contains("index_type")) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'index_type\' is required");
}
std::string index_type = request_json["index_type"];
if (IndexNameMap.find(index_type) == IndexNameMap.end()) {
RETURN_STATUS_DTO(ILLEGAL_INDEX_TYPE, "The index type is invalid.")
}
auto index = static_cast<int64_t>(IndexNameMap.at(index_type));
if (!request_json.contains("params")) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'params\' is required")
}
auto status = request_handler_.CreateIndex(context_ptr_, table_name->std_str(), index, request_json["params"]);
ASSIGN_RETURN_STATUS_DTO(status);
} catch (nlohmann::detail::parse_error& e) {
} catch (nlohmann::detail::type_error& e) {
}
if (nullptr == index_param->nlist.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'nlist\' is required")
}
auto status =
request_handler_.CreateIndex(context_ptr_, table_name->std_str(),
static_cast<int64_t>(IndexNameMap.at(index_type)), index_param->nlist->getValue());
ASSIGN_RETURN_STATUS_DTO(status)
ASSIGN_RETURN_STATUS_DTO(Status::OK())
}
StatusDto::ObjectWrapper
WebRequestHandler::GetIndex(const OString& table_name, IndexDto::ObjectWrapper& index_dto) {
WebRequestHandler::GetIndex(const OString& collection_name, OString& result) {
IndexParam param;
auto status = request_handler_.DescribeIndex(context_ptr_, table_name->std_str(), param);
auto status = request_handler_.DescribeIndex(context_ptr_, collection_name->std_str(), param);
if (status.ok()) {
index_dto->index_type = IndexMap.at(engine::EngineType(param.index_type_)).c_str();
index_dto->nlist = param.nlist_;
nlohmann::json json_out;
auto index_type = IndexMap.at(engine::EngineType(param.index_type_));
json_out["index_type"] = index_type;
json_out["params"] = nlohmann::json::parse(param.extra_params_);
result = json_out.dump().c_str();
}
ASSIGN_RETURN_STATUS_DTO(status)
}
StatusDto::ObjectWrapper
WebRequestHandler::DropIndex(const OString& table_name) {
auto status = request_handler_.DropIndex(context_ptr_, table_name->std_str());
WebRequestHandler::DropIndex(const OString& collection_name) {
auto status = request_handler_.DropIndex(context_ptr_, collection_name->std_str());
ASSIGN_RETURN_STATUS_DTO(status)
}
@ -1054,19 +1080,19 @@ WebRequestHandler::DropIndex(const OString& table_name) {
* Partition {
*/
StatusDto::ObjectWrapper
WebRequestHandler::CreatePartition(const OString& table_name, const PartitionRequestDto::ObjectWrapper& param) {
WebRequestHandler::CreatePartition(const OString& collection_name, const PartitionRequestDto::ObjectWrapper& param) {
if (nullptr == param->partition_tag.get()) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'partition_tag\' is required")
}
auto status =
request_handler_.CreatePartition(context_ptr_, table_name->std_str(), param->partition_tag->std_str());
request_handler_.CreatePartition(context_ptr_, collection_name->std_str(), param->partition_tag->std_str());
ASSIGN_RETURN_STATUS_DTO(status)
}
StatusDto::ObjectWrapper
WebRequestHandler::ShowPartitions(const OString& table_name, const OQueryParams& query_params,
WebRequestHandler::ShowPartitions(const OString& collection_name, const OQueryParams& query_params,
PartitionListDto::ObjectWrapper& partition_list_dto) {
int64_t offset = 0;
auto status = ParseQueryInteger(query_params, "offset", offset);
@ -1096,7 +1122,7 @@ WebRequestHandler::ShowPartitions(const OString& table_name, const OQueryParams&
}
std::vector<PartitionParam> partitions;
status = request_handler_.ShowPartitions(context_ptr_, table_name->std_str(), partitions);
status = request_handler_.ShowPartitions(context_ptr_, collection_name->std_str(), partitions);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
@ -1124,7 +1150,7 @@ WebRequestHandler::ShowPartitions(const OString& table_name, const OQueryParams&
}
StatusDto::ObjectWrapper
WebRequestHandler::DropPartition(const OString& table_name, const OString& body) {
WebRequestHandler::DropPartition(const OString& collection_name, const OString& body) {
std::string tag;
try {
auto json = nlohmann::json::parse(body->std_str());
@ -1134,7 +1160,7 @@ WebRequestHandler::DropPartition(const OString& table_name, const OString& body)
} catch (nlohmann::detail::type_error& e) {
RETURN_STATUS_DTO(BODY_PARSE_FAIL, e.what())
}
auto status = request_handler_.DropPartition(context_ptr_, table_name->std_str(), tag);
auto status = request_handler_.DropPartition(context_ptr_, collection_name->std_str(), tag);
ASSIGN_RETURN_STATUS_DTO(status)
}
@ -1144,7 +1170,7 @@ WebRequestHandler::DropPartition(const OString& table_name, const OString& body)
* Segment {
*/
StatusDto::ObjectWrapper
WebRequestHandler::ShowSegments(const OString& table_name, const OQueryParams& query_params, OString& response) {
WebRequestHandler::ShowSegments(const OString& collection_name, const OQueryParams& query_params, OString& response) {
int64_t offset = 0;
auto status = ParseQueryInteger(query_params, "offset", offset);
if (!status.ok()) {
@ -1177,7 +1203,7 @@ WebRequestHandler::ShowSegments(const OString& table_name, const OQueryParams& q
}
TableInfo info;
status = request_handler_.ShowTableInfo(context_ptr_, table_name->std_str(), info);
status = request_handler_.ShowTableInfo(context_ptr_, collection_name->std_str(), info);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
@ -1228,7 +1254,7 @@ WebRequestHandler::ShowSegments(const OString& table_name, const OQueryParams& q
}
StatusDto::ObjectWrapper
WebRequestHandler::GetSegmentInfo(const OString& table_name, const OString& segment_name, const OString& info,
WebRequestHandler::GetSegmentInfo(const OString& collection_name, const OString& segment_name, const OString& info,
const OQueryParams& query_params, OString& result) {
int64_t offset = 0;
auto status = ParseQueryInteger(query_params, "offset", offset);
@ -1252,10 +1278,10 @@ WebRequestHandler::GetSegmentInfo(const OString& table_name, const OString& segm
nlohmann::json json;
// Get vectors
if (re == "vectors") {
status = GetSegmentVectors(table_name->std_str(), segment_name->std_str(), page_size, offset, json);
status = GetSegmentVectors(collection_name->std_str(), segment_name->std_str(), page_size, offset, json);
// Get vector ids
} else if (re == "ids") {
status = GetSegmentIds(table_name->std_str(), segment_name->std_str(), page_size, offset, json);
status = GetSegmentIds(collection_name->std_str(), segment_name->std_str(), page_size, offset, json);
}
result = status.ok() ? json.dump().c_str() : "NULL";
@ -1268,14 +1294,14 @@ WebRequestHandler::GetSegmentInfo(const OString& table_name, const OString& segm
* Vector {
*/
StatusDto::ObjectWrapper
WebRequestHandler::Insert(const OString& table_name, const OString& body, VectorIdsDto::ObjectWrapper& ids_dto) {
WebRequestHandler::Insert(const OString& collection_name, const OString& body, VectorIdsDto::ObjectWrapper& ids_dto) {
if (nullptr == body.get() || body->getSize() == 0) {
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Request payload is required.")
}
// step 1: copy vectors
bool bin_flag;
auto status = IsBinaryTable(table_name->std_str(), bin_flag);
auto status = IsBinaryTable(collection_name->std_str(), bin_flag);
if (!status.ok()) {
ASSIGN_RETURN_STATUS_DTO(status)
}
@ -1310,7 +1336,7 @@ WebRequestHandler::Insert(const OString& table_name, const OString& body, Vector
}
// step 4: construct result
status = request_handler_.Insert(context_ptr_, table_name->std_str(), vectors, tag);
status = request_handler_.Insert(context_ptr_, collection_name->std_str(), vectors, tag);
if (status.ok()) {
ids_dto->ids = ids_dto->ids->createShared();
for (auto& id : vectors.id_array_) {
@ -1322,7 +1348,7 @@ WebRequestHandler::Insert(const OString& table_name, const OString& body, Vector
}
StatusDto::ObjectWrapper
WebRequestHandler::GetVector(const OString& table_name, const OQueryParams& query_params, OString& response) {
WebRequestHandler::GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response) {
int64_t id = 0;
auto status = ParseQueryInteger(query_params, "id", id, false);
if (!status.ok()) {
@ -1332,7 +1358,7 @@ WebRequestHandler::GetVector(const OString& table_name, const OQueryParams& quer
std::vector<int64_t> ids = {id};
engine::VectorsData vectors;
nlohmann::json vectors_json;
status = GetVectorsByIDs(table_name->std_str(), ids, vectors_json);
status = GetVectorsByIDs(collection_name->std_str(), ids, vectors_json);
if (!status.ok()) {
response = "NULL";
ASSIGN_RETURN_STATUS_DTO(status)
@ -1352,7 +1378,7 @@ WebRequestHandler::GetVector(const OString& table_name, const OQueryParams& quer
}
StatusDto::ObjectWrapper
WebRequestHandler::VectorsOp(const OString& table_name, const OString& payload, OString& response) {
WebRequestHandler::VectorsOp(const OString& collection_name, const OString& payload, OString& response) {
auto status = Status::OK();
std::string result_str;
@ -1360,9 +1386,9 @@ WebRequestHandler::VectorsOp(const OString& table_name, const OString& payload,
nlohmann::json payload_json = nlohmann::json::parse(payload->std_str());
if (payload_json.contains("delete")) {
status = DeleteByIDs(table_name->std_str(), payload_json["delete"], result_str);
status = DeleteByIDs(collection_name->std_str(), payload_json["delete"], result_str);
} else if (payload_json.contains("search")) {
status = Search(table_name->std_str(), payload_json["search"], result_str);
status = Search(collection_name->std_str(), payload_json["search"], result_str);
} else {
status = Status(ILLEGAL_BODY, "Unknown body");
}

View File

@ -84,6 +84,9 @@ class WebRequestHandler {
Status
ParseQueryStr(const OQueryParams& query_params, const std::string& key, std::string& value, bool nullable = true);
Status
ParseQueryBool(const OQueryParams& query_params, const std::string& key, bool& value, bool nullable = true);
private:
void
AddStatusToJson(nlohmann::json& json, int64_t code, const std::string& msg);
@ -189,10 +192,10 @@ class WebRequestHandler {
* Index
*/
StatusDto::ObjectWrapper
CreateIndex(const OString& table_name, const IndexRequestDto::ObjectWrapper& index_param);
CreateIndex(const OString& table_name, const OString& body);
StatusDto::ObjectWrapper
GetIndex(const OString& table_name, IndexDto::ObjectWrapper& index_dto);
GetIndex(const OString& table_name, OString& result);
StatusDto::ObjectWrapper
DropIndex(const OString& table_name);

View File

@ -12,9 +12,12 @@
#include "utils/StringHelpFunctions.h"
#include <fiu-local.h>
#include <algorithm>
#include <regex>
#include <string>
#include "utils/ValidationUtil.h"
namespace milvus {
namespace server {
@ -148,11 +151,21 @@ StringHelpFunctions::IsRegexMatch(const std::string& target_str, const std::stri
// regex match
std::regex pattern(pattern_str);
std::smatch results;
if (std::regex_match(target_str, results, pattern)) {
return true;
} else {
return false;
return std::regex_match(target_str, results, pattern);
}
Status
StringHelpFunctions::ConvertToBoolean(const std::string& str, bool& value) {
    // Convert a configuration string to a boolean, case-insensitively.
    // Accepted forms are defined by ValidationUtil::ValidateStringIsBool():
    //   "true", "on", "yes", "1"      ==> true
    //   anything else it accepts      ==> false
    // Returns the validation failure status (value untouched) for other input.
    auto status = ValidationUtil::ValidateStringIsBool(str);
    if (!status.ok()) {
        return status;
    }
    std::string s = str;
    // Cast through unsigned char before tolower(): passing a negative char
    // (possible for non-ASCII bytes where plain char is signed) is undefined
    // behavior.
    std::transform(s.begin(), s.end(), s.begin(),
                   [](unsigned char c) { return static_cast<char>(::tolower(c)); });
    value = s == "true" || s == "on" || s == "yes" || s == "1";
    return Status::OK();
}
} // namespace server

View File

@ -64,6 +64,12 @@ class StringHelpFunctions {
// regex grammar reference: http://www.cplusplus.com/reference/regex/ECMAScript/
static bool
IsRegexMatch(const std::string& target_str, const std::string& pattern);
// conversion rules refer to ValidationUtil::ValidateStringIsBool()
// "true", "on", "yes", "1" ==> true
// "false", "off", "no", "0", "" ==> false
static Status
ConvertToBoolean(const std::string& str, bool& value);
};
} // namespace server

View File

@ -12,6 +12,7 @@
#include "utils/ValidationUtil.h"
#include "Log.h"
#include "db/engine/ExecutionEngine.h"
#include "index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h"
#include "utils/StringHelpFunctions.h"
#include <arpa/inet.h>
@ -32,10 +33,62 @@
namespace milvus {
namespace server {
namespace {
constexpr size_t TABLE_NAME_SIZE_LIMIT = 255;
constexpr int64_t TABLE_DIMENSION_LIMIT = 32768;
constexpr int32_t INDEX_FILE_SIZE_LIMIT = 4096; // index trigger size max = 4096 MB
// Verify that `json_params[param_name]` exists and holds an integer within the
// requested range. `min_close` / `max_closed` select inclusive (default) or
// exclusive bounds. Returns SERVER_INVALID_ARGUMENT when the key is missing,
// the value is out of range, or the value is not convertible to an integer.
Status
CheckParameterRange(const milvus::json& json_params, const std::string& param_name, int64_t min, int64_t max,
                    bool min_close = true, bool max_closed = true) {
    if (json_params.count(param_name) == 0) {
        std::string msg = "Parameter list must contain: ";
        return Status(SERVER_INVALID_ARGUMENT, msg + param_name);
    }

    try {
        int64_t val = json_params[param_name];
        const bool below = min_close ? (val < min) : (val <= min);
        const bool above = max_closed ? (val > max) : (val >= max);
        if (below || above) {
            // Render the allowed interval with bracket style matching the
            // open/closed flags, e.g. "(0, 999999]".
            std::string range = std::string(min_close ? "[" : "(") + std::to_string(min) + ", " +
                                std::to_string(max) + (max_closed ? "]" : ")");
            std::string msg =
                "Invalid " + param_name + " value: " + std::to_string(val) + ". Valid range is " + range;
            SERVER_LOG_ERROR << msg;
            return Status(SERVER_INVALID_ARGUMENT, msg);
        }
    } catch (std::exception& e) {
        // Typically a nlohmann type_error when the value is not an integer.
        std::string msg = "Invalid " + param_name + ": ";
        return Status(SERVER_INVALID_ARGUMENT, msg + e.what());
    }

    return Status::OK();
}
// Verify that `json_params[param_name]` exists and holds a non-negative
// integer. Returns SERVER_INVALID_ARGUMENT when the key is missing, the value
// is negative, or the value is not convertible to an integer.
Status
CheckParameterExistence(const milvus::json& json_params, const std::string& param_name) {
    if (json_params.count(param_name) == 0) {
        std::string msg = "Parameter list must contain: ";
        return Status(SERVER_INVALID_ARGUMENT, msg + param_name);
    }

    try {
        int64_t val = json_params[param_name];
        if (val < 0) {
            std::string msg = "Invalid " + param_name + " value: " + std::to_string(val);
            SERVER_LOG_ERROR << msg;
            return Status(SERVER_INVALID_ARGUMENT, msg);
        }
    } catch (std::exception& e) {
        // Typically a nlohmann type_error when the value is not an integer.
        std::string msg = "Invalid " + param_name + ": ";
        return Status(SERVER_INVALID_ARGUMENT, msg + e.what());
    }

    return Status::OK();
}
} // namespace
Status
ValidationUtil::ValidateTableName(const std::string& table_name) {
// Table name shouldn't be empty.
@ -109,24 +162,114 @@ ValidationUtil::ValidateTableIndexType(int32_t index_type) {
return Status::OK();
}
Status
ValidationUtil::ValidateIndexParams(const milvus::json& index_params, const engine::meta::TableSchema& table_schema,
                                    int32_t index_type) {
    // Validate build-time ("create index") parameters for the given index
    // type. IDMAP (brute-force) variants take no build parameters; an
    // unrecognized index_type also validates successfully here.
    Status status = Status::OK();

    switch (index_type) {
        case (int32_t)engine::EngineType::FAISS_IDMAP:
        case (int32_t)engine::EngineType::FAISS_BIN_IDMAP:
            // Nothing to check for brute-force engines.
            break;

        case (int32_t)engine::EngineType::FAISS_IVFFLAT:
        case (int32_t)engine::EngineType::FAISS_IVFSQ8:
        case (int32_t)engine::EngineType::FAISS_IVFSQ8H:
        case (int32_t)engine::EngineType::FAISS_BIN_IVFFLAT:
            // nlist must lie in (0, 999999].
            status = CheckParameterRange(index_params, knowhere::IndexParams::nlist, 0, 999999, false);
            break;

        case (int32_t)engine::EngineType::FAISS_PQ:
            // nlist in (0, 999999]; m must be present and non-negative.
            status = CheckParameterRange(index_params, knowhere::IndexParams::nlist, 0, 999999, false);
            if (status.ok()) {
                status = CheckParameterExistence(index_params, knowhere::IndexParams::m);
            }
            break;

        case (int32_t)engine::EngineType::NSG_MIX:
            // Graph-build parameters, each with an inclusive range.
            status = CheckParameterRange(index_params, knowhere::IndexParams::search_length, 10, 300);
            if (status.ok()) {
                status = CheckParameterRange(index_params, knowhere::IndexParams::out_degree, 5, 300);
            }
            if (status.ok()) {
                status = CheckParameterRange(index_params, knowhere::IndexParams::candidate, 50, 1000);
            }
            if (status.ok()) {
                status = CheckParameterRange(index_params, knowhere::IndexParams::knng, 5, 300);
            }
            break;

        case (int32_t)engine::EngineType::HNSW:
            status = CheckParameterRange(index_params, knowhere::IndexParams::M, 5, 48);
            if (status.ok()) {
                status = CheckParameterRange(index_params, knowhere::IndexParams::efConstruction, 100, 500);
            }
            break;
    }

    return status;
}
Status
ValidationUtil::ValidateSearchParams(const milvus::json& search_params, const engine::meta::TableSchema& table_schema,
                                     int64_t topk) {
    // Validate per-query search parameters against the collection's engine
    // type. Brute-force engines (and unrecognized engine types) take no
    // search parameters and validate successfully.
    Status status = Status::OK();

    switch (table_schema.engine_type_) {
        case (int32_t)engine::EngineType::FAISS_IDMAP:
        case (int32_t)engine::EngineType::FAISS_BIN_IDMAP:
            break;

        case (int32_t)engine::EngineType::FAISS_IVFFLAT:
        case (int32_t)engine::EngineType::FAISS_IVFSQ8:
        case (int32_t)engine::EngineType::FAISS_IVFSQ8H:
        case (int32_t)engine::EngineType::FAISS_BIN_IVFFLAT:
        case (int32_t)engine::EngineType::FAISS_PQ:
            // nprobe in [1, 999999].
            status = CheckParameterRange(search_params, knowhere::IndexParams::nprobe, 1, 999999);
            break;

        case (int32_t)engine::EngineType::NSG_MIX:
            status = CheckParameterRange(search_params, knowhere::IndexParams::search_length, 10, 300);
            break;

        case (int32_t)engine::EngineType::HNSW:
            // ef must be at least topk, otherwise topk results cannot be
            // returned; capped at 1000.
            status = CheckParameterRange(search_params, knowhere::IndexParams::ef, topk, 1000);
            break;
    }

    return status;
}
bool
ValidationUtil::IsBinaryIndexType(int32_t index_type) {
    // Only the two binary-vector engines qualify.
    switch (index_type) {
        case static_cast<int32_t>(engine::EngineType::FAISS_BIN_IDMAP):
        case static_cast<int32_t>(engine::EngineType::FAISS_BIN_IVFFLAT):
            return true;
        default:
            return false;
    }
}
Status
ValidationUtil::ValidateTableIndexNlist(int32_t nlist) {
if (nlist <= 0) {
std::string msg =
"Invalid index nlist: " + std::to_string(nlist) + ". " + "The index nlist must be greater than 0.";
SERVER_LOG_ERROR << msg;
return Status(SERVER_INVALID_INDEX_NLIST, msg);
}
return Status::OK();
}
Status
ValidationUtil::ValidateTableIndexFileSize(int64_t index_file_size) {
if (index_file_size <= 0 || index_file_size > INDEX_FILE_SIZE_LIMIT) {
@ -170,18 +313,6 @@ ValidationUtil::ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchem
return Status::OK();
}
Status
ValidationUtil::ValidateSearchNprobe(int64_t nprobe, const engine::meta::TableSchema& table_schema) {
if (nprobe <= 0 || nprobe > table_schema.nlist_) {
std::string msg = "Invalid nprobe: " + std::to_string(nprobe) + ". " +
"The nprobe must be within the range of 1 ~ index nlist.";
SERVER_LOG_ERROR << msg;
return Status(SERVER_INVALID_NPROBE, msg);
}
return Status::OK();
}
Status
ValidationUtil::ValidatePartitionName(const std::string& partition_name) {
if (partition_name.empty()) {

View File

@ -12,6 +12,7 @@
#pragma once
#include "db/meta/MetaTypes.h"
#include "utils/Json.h"
#include "utils/Status.h"
#include <string>
@ -34,11 +35,16 @@ class ValidationUtil {
static Status
ValidateTableIndexType(int32_t index_type);
static bool
IsBinaryIndexType(int32_t index_type);
static Status
ValidateIndexParams(const milvus::json& index_params, const engine::meta::TableSchema& table_schema,
int32_t index_type);
static Status
ValidateTableIndexNlist(int32_t nlist);
ValidateSearchParams(const milvus::json& search_params, const engine::meta::TableSchema& table_schema,
int64_t topk);
static bool
IsBinaryIndexType(int32_t index_type);
static Status
ValidateTableIndexFileSize(int64_t index_file_size);
@ -52,9 +58,6 @@ class ValidationUtil {
static Status
ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchema& table_schema);
static Status
ValidateSearchNprobe(int64_t nprobe, const engine::meta::TableSchema& table_schema);
static Status
ValidatePartitionName(const std::string& partition_name);

View File

@ -25,7 +25,7 @@ Status
BinVecImpl::BuildAll(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
const uint8_t* xt) {
try {
dim = cfg->d;
dim = cfg[knowhere::meta::DIM];
auto ret_ds = std::make_shared<knowhere::Dataset>();
ret_ds->Set(knowhere::meta::ROWS, nb);
@ -47,15 +47,13 @@ BinVecImpl::BuildAll(const int64_t& nb, const uint8_t* xb, const int64_t* ids, c
Status
BinVecImpl::Search(const int64_t& nq, const uint8_t* xq, float* dist, int64_t* ids, const Config& cfg) {
try {
auto k = cfg->k;
int64_t k = cfg[knowhere::meta::TOPK];
auto ret_ds = std::make_shared<knowhere::Dataset>();
ret_ds->Set(knowhere::meta::ROWS, nq);
ret_ds->Set(knowhere::meta::DIM, dim);
ret_ds->Set(knowhere::meta::TENSOR, xq);
Config search_cfg = cfg;
auto res = index_->Search(ret_ds, search_cfg);
auto res = index_->Search(ret_ds, cfg);
//{
// auto& ids = ids_array;
// auto& dists = dis_array;
@ -150,9 +148,7 @@ BinVecImpl::GetVectorById(const int64_t n, const int64_t* ids, uint8_t* x, const
ret_ds->Set(knowhere::meta::DIM, dim);
ret_ds->Set(knowhere::meta::IDS, ids);
Config search_cfg = cfg;
auto res = index_->GetVectorById(ret_ds, search_cfg);
auto res = index_->GetVectorById(ret_ds, cfg);
// TODO(linxj): avoid copy here.
auto res_x = res->Get<uint8_t*>(knowhere::meta::TENSOR);
@ -176,15 +172,13 @@ BinVecImpl::SearchById(const int64_t& nq, const int64_t* xq, float* dist, int64_
throw WrapperException("not support");
}
try {
auto k = cfg->k;
int64_t k = cfg[knowhere::meta::TOPK];
auto ret_ds = std::make_shared<knowhere::Dataset>();
ret_ds->Set(knowhere::meta::ROWS, nq);
ret_ds->Set(knowhere::meta::DIM, dim);
ret_ds->Set(knowhere::meta::IDS, xq);
Config search_cfg = cfg;
auto res = index_->SearchById(ret_ds, search_cfg);
auto res = index_->SearchById(ret_ds, cfg);
//{
// auto& ids = ids_array;
// auto& dists = dis_array;
@ -235,7 +229,7 @@ BinVecImpl::GetBlacklist(faiss::ConcurrentBitsetPtr& list) {
ErrorCode
BinBFIndex::Build(const Config& cfg) {
try {
dim = cfg->d;
dim = cfg[knowhere::meta::DIM];
std::static_pointer_cast<knowhere::BinaryIDMAP>(index_)->Train(cfg);
} catch (knowhere::KnowhereException& e) {
WRAPPER_LOG_ERROR << e.what();
@ -251,7 +245,7 @@ Status
BinBFIndex::BuildAll(const int64_t& nb, const uint8_t* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
const uint8_t* xt) {
try {
dim = cfg->d;
dim = cfg[knowhere::meta::DIM];
auto ret_ds = std::make_shared<knowhere::Dataset>();
ret_ds->Set(knowhere::meta::ROWS, nb);
ret_ds->Set(knowhere::meta::DIM, dim);

View File

@ -8,12 +8,13 @@
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "wrapper/ConfAdapter.h"
#include <fiu-local.h>
#include <cmath>
#include <memory>
#include <string>
#include <vector>
#include "WrapperException.h"
@ -21,8 +22,6 @@
#include "server/Config.h"
#include "utils/Log.h"
// TODO(lxj): add conf checker
namespace milvus {
namespace engine {
@ -32,45 +31,64 @@ namespace engine {
#define GPU_MAX_NRPOBE 1024
#endif
void
ConfAdapter::MatchBase(knowhere::Config conf, knowhere::METRICTYPE default_metric) {
if (conf->metric_type == knowhere::DEFAULT_TYPE)
conf->metric_type = default_metric;
#define DEFAULT_MAX_DIM 16384
#define DEFAULT_MIN_DIM 1
#define DEFAULT_MAX_K 16384
#define DEFAULT_MIN_K 1
#define DEFAULT_MIN_ROWS 1 // minimum size for build index
#define DEFAULT_MAX_ROWS 50000000
#define CheckIntByRange(key, min, max) \
if (!oricfg.contains(key) || !oricfg[key].is_number_integer() || oricfg[key].get<int64_t>() > max || \
oricfg[key].get<int64_t>() < min) { \
return false; \
}
// #define checkfloat(key, min, max) \
// if (!oricfg.contains(key) || !oricfg[key].is_number_float() || oricfg[key] >= max || oricfg[key] <= min) { \
// return false; \
// }
#define CheckIntByValues(key, container) \
if (!oricfg.contains(key) || !oricfg[key].is_number_integer()) { \
return false; \
} else { \
auto finder = std::find(std::begin(container), std::end(container), oricfg[key].get<int64_t>()); \
if (finder == std::end(container)) { \
return false; \
} \
}
#define CheckStrByValues(key, container) \
if (!oricfg.contains(key) || !oricfg[key].is_string()) { \
return false; \
} else { \
auto finder = std::find(std::begin(container), std::end(container), oricfg[key].get<std::string>()); \
if (finder == std::end(container)) { \
return false; \
} \
}
// Validate the build parameters common to every index type: the vector
// dimension must be an integer in [DEFAULT_MIN_DIM, DEFAULT_MAX_DIM] and the
// metric type must be one of METRICS (L2 or IP).
// NOTE: CheckIntByRange / CheckStrByValues are macros that `return false`
// from this function when their check fails.
bool
ConfAdapter::CheckTrain(milvus::json& oricfg) {
    // Metrics accepted by the base (float-vector) adapter.
    static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::IP};
    CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
    CheckStrByValues(knowhere::Metric::TYPE, METRICS);
    return true;
}
knowhere::Config
ConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::Cfg>();
conf->d = metaconf.dim;
conf->metric_type = metaconf.metric_type;
conf->gpu_id = metaconf.gpu_id;
conf->k = metaconf.k;
MatchBase(conf);
return conf;
}
bool
ConfAdapter::CheckSearch(milvus::json& oricfg, const IndexType& type) {
CheckIntByRange(knowhere::meta::TOPK, DEFAULT_MIN_K - 1, DEFAULT_MAX_K);
knowhere::Config
ConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
auto conf = std::make_shared<knowhere::Cfg>();
conf->k = metaconf.k;
return conf;
return true;
}
knowhere::Config
IVFConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::IVFCfg>();
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist, 16384);
conf->d = metaconf.dim;
conf->metric_type = metaconf.metric_type;
conf->gpu_id = metaconf.gpu_id;
MatchBase(conf);
return conf;
}
static constexpr float TYPICAL_COUNT = 1000000.0;
int64_t
IVFConfAdapter::MatchNlist(const int64_t& size, const int64_t& nlist, const int64_t& per_nlist) {
MatchNlist(const int64_t& size, const int64_t& nlist, const int64_t& per_nlist) {
static float TYPICAL_COUNT = 1000000.0;
if (size <= TYPICAL_COUNT / per_nlist + 1) {
// handle less row count, avoid nlist set to 0
return 1;
@ -81,62 +99,88 @@ IVFConfAdapter::MatchNlist(const int64_t& size, const int64_t& nlist, const int6
return nlist;
}
knowhere::Config
IVFConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
auto conf = std::make_shared<knowhere::IVFCfg>();
conf->k = metaconf.k;
bool
IVFConfAdapter::CheckTrain(milvus::json& oricfg) {
static int64_t MAX_NLIST = 999999;
static int64_t MIN_NLIST = 1;
if (metaconf.nprobe <= 0)
conf->nprobe = 16; // hardcode here
else
conf->nprobe = metaconf.nprobe;
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
switch (type) {
case IndexType::FAISS_IVFFLAT_GPU:
case IndexType::FAISS_IVFSQ8_GPU:
case IndexType::FAISS_IVFPQ_GPU:
if (conf->nprobe > GPU_MAX_NRPOBE) {
WRAPPER_LOG_WARNING << "When search with GPU, nprobe shoud be no more than " << GPU_MAX_NRPOBE
<< ", but you passed " << conf->nprobe << ". Search with " << GPU_MAX_NRPOBE
<< " instead";
conf->nprobe = GPU_MAX_NRPOBE;
}
// int64_t nlist = oricfg[knowhere::IndexParams::nlist];
// CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);
// auto tune params
oricfg[knowhere::IndexParams::nlist] = MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(),
oricfg[knowhere::IndexParams::nlist].get<int64_t>(), 16384);
// Best Practice
// static int64_t MIN_POINTS_PER_CENTROID = 40;
// static int64_t MAX_POINTS_PER_CENTROID = 256;
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
return ConfAdapter::CheckTrain(oricfg);
}
bool
IVFConfAdapter::CheckSearch(milvus::json& oricfg, const IndexType& type) {
static int64_t MIN_NPROBE = 1;
static int64_t MAX_NPROBE = 999999; // todo(linxj): [1, nlist]
if (type == IndexType::FAISS_IVFPQ_GPU || type == IndexType::FAISS_IVFSQ8_GPU ||
type == IndexType::FAISS_IVFSQ8_HYBRID || type == IndexType::FAISS_IVFFLAT_GPU) {
CheckIntByRange(knowhere::IndexParams::nprobe, MIN_NPROBE, GPU_MAX_NRPOBE);
} else {
CheckIntByRange(knowhere::IndexParams::nprobe, MIN_NPROBE, MAX_NPROBE);
}
return conf;
return ConfAdapter::CheckSearch(oricfg, type);
}
knowhere::Config
IVFSQConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::IVFSQCfg>();
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist, 16384);
conf->d = metaconf.dim;
conf->metric_type = metaconf.metric_type;
conf->gpu_id = metaconf.gpu_id;
conf->nbits = 8;
MatchBase(conf);
return conf;
bool
IVFSQConfAdapter::CheckTrain(milvus::json& oricfg) {
    // SQ8 always quantizes each vector component into 8 bits, so any
    // caller-supplied nbits is overwritten before delegating to the common
    // IVF parameter validation.
    constexpr int64_t DEFAULT_NBITS = 8;
    oricfg[knowhere::IndexParams::nbits] = DEFAULT_NBITS;
    return IVFConfAdapter::CheckTrain(oricfg);
}
knowhere::Config
IVFPQConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::IVFPQCfg>();
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist);
conf->d = metaconf.dim;
conf->metric_type = metaconf.metric_type;
conf->gpu_id = metaconf.gpu_id;
conf->nbits = 8;
MatchBase(conf);
bool
IVFPQConfAdapter::CheckTrain(milvus::json& oricfg) {
static int64_t DEFAULT_NBITS = 8;
static int64_t MAX_NLIST = 999999;
static int64_t MIN_NLIST = 1;
static std::vector<std::string> CPU_METRICS{knowhere::Metric::L2, knowhere::Metric::IP};
static std::vector<std::string> GPU_METRICS{knowhere::Metric::L2};
oricfg[knowhere::IndexParams::nbits] = DEFAULT_NBITS;
#ifdef MILVUS_GPU_VERSION
Status s;
bool enable_gpu = false;
server::Config& config = server::Config::GetInstance();
s = config.GetGpuResourceConfigEnable(enable_gpu);
if (s.ok() && conf->metric_type == knowhere::METRICTYPE::IP) {
WRAPPER_LOG_ERROR << "PQ not support IP in GPU version!";
throw WrapperException("PQ not support IP in GPU version!");
if (s.ok()) {
CheckStrByValues(knowhere::Metric::TYPE, GPU_METRICS);
} else {
CheckStrByValues(knowhere::Metric::TYPE, CPU_METRICS);
}
#endif
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
// int64_t nlist = oricfg[knowhere::IndexParams::nlist];
// CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);
// auto tune params
oricfg[knowhere::IndexParams::nlist] = MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(),
oricfg[knowhere::IndexParams::nlist].get<int64_t>(), 16384);
// Best Practice
// static int64_t MIN_POINTS_PER_CENTROID = 40;
// static int64_t MAX_POINTS_PER_CENTROID = 256;
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
/*
* Faiss 1.6
@ -147,165 +191,110 @@ IVFPQConfAdapter::Match(const TempMetaConf& metaconf) {
static std::vector<int64_t> support_subquantizer{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1};
std::vector<int64_t> resset;
for (const auto& dimperquantizer : support_dim_per_subquantizer) {
if (!(conf->d % dimperquantizer)) {
auto subquantzier_num = conf->d / dimperquantizer;
if (!(oricfg[knowhere::meta::DIM].get<int64_t>() % dimperquantizer)) {
auto subquantzier_num = oricfg[knowhere::meta::DIM].get<int64_t>() / dimperquantizer;
auto finder = std::find(support_subquantizer.begin(), support_subquantizer.end(), subquantzier_num);
if (finder != support_subquantizer.end()) {
resset.push_back(subquantzier_num);
}
}
}
fiu_do_on("IVFPQConfAdapter.Match.empty_resset", resset.clear());
if (resset.empty()) {
// todo(linxj): throw exception here.
WRAPPER_LOG_ERROR << "The dims of PQ is wrong : only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims per sub-"
"quantizer are currently supported with no precomputed codes.";
throw WrapperException(
"The dims of PQ is wrong : only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims "
"per sub-quantizer are currently supported with no precomputed codes.");
// return nullptr;
}
static int64_t compression_level = 1; // 1:low, 2:high
if (compression_level == 1) {
conf->m = resset[int(resset.size() / 2)];
WRAPPER_LOG_DEBUG << "PQ m = " << conf->m << ", compression radio = " << conf->d / conf->m * 4;
}
return conf;
CheckIntByValues(knowhere::IndexParams::m, resset);
return true;
}
knowhere::Config
IVFPQConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
auto conf = std::make_shared<knowhere::IVFPQCfg>();
conf->k = metaconf.k;
bool
NSGConfAdapter::CheckTrain(milvus::json& oricfg) {
static int64_t MIN_KNNG = 5;
static int64_t MAX_KNNG = 300;
static int64_t MIN_SEARCH_LENGTH = 10;
static int64_t MAX_SEARCH_LENGTH = 300;
static int64_t MIN_OUT_DEGREE = 5;
static int64_t MAX_OUT_DEGREE = 300;
static int64_t MIN_CANDIDATE_POOL_SIZE = 50;
static int64_t MAX_CANDIDATE_POOL_SIZE = 1000;
static std::vector<std::string> METRICS{knowhere::Metric::L2};
if (metaconf.nprobe <= 0) {
WRAPPER_LOG_ERROR << "The nprobe of PQ is wrong!";
throw WrapperException("The nprobe of PQ is wrong!");
} else {
conf->nprobe = metaconf.nprobe;
}
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::knng, MIN_KNNG, MAX_KNNG);
CheckIntByRange(knowhere::IndexParams::search_length, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH);
CheckIntByRange(knowhere::IndexParams::out_degree, MIN_OUT_DEGREE, MAX_OUT_DEGREE);
CheckIntByRange(knowhere::IndexParams::candidate, MIN_CANDIDATE_POOL_SIZE, MAX_CANDIDATE_POOL_SIZE);
return conf;
// auto tune params
oricfg[knowhere::IndexParams::nlist] = MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(), 8192, 8192);
oricfg[knowhere::IndexParams::nprobe] = int(oricfg[knowhere::IndexParams::nlist].get<int64_t>() * 0.01);
return true;
}
int64_t
IVFPQConfAdapter::MatchNlist(const int64_t& size, const int64_t& nlist) {
if (size <= TYPICAL_COUNT / 16384 + 1) {
// handle less row count, avoid nlist set to 0
return 1;
} else if (int(size / TYPICAL_COUNT) * nlist <= 0) {
// calculate a proper nlist if nlist not specified or size less than TYPICAL_COUNT
return int(size / TYPICAL_COUNT * 16384);
}
return nlist;
bool
NSGConfAdapter::CheckSearch(milvus::json& oricfg, const IndexType& type) {
static int64_t MIN_SEARCH_LENGTH = 1;
static int64_t MAX_SEARCH_LENGTH = 300;
CheckIntByRange(knowhere::IndexParams::search_length, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH);
return ConfAdapter::CheckSearch(oricfg, type);
}
knowhere::Config
NSGConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::NSGCfg>();
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist, 16384);
conf->d = metaconf.dim;
conf->metric_type = metaconf.metric_type;
conf->gpu_id = metaconf.gpu_id;
conf->k = metaconf.k;
bool
HNSWConfAdapter::CheckTrain(milvus::json& oricfg) {
static int64_t MIN_EFCONSTRUCTION = 100;
static int64_t MAX_EFCONSTRUCTION = 500;
static int64_t MIN_M = 5;
static int64_t MAX_M = 48;
auto scale_factor = round(metaconf.dim / 128.0);
scale_factor = scale_factor >= 4 ? 4 : scale_factor;
conf->nprobe = int64_t(conf->nlist * 0.01);
// conf->knng = 40 + 10 * scale_factor; // the size of knng
conf->knng = 50;
conf->search_length = 50 + 5 * scale_factor;
conf->out_degree = 50 + 5 * scale_factor;
conf->candidate_pool_size = 300;
MatchBase(conf);
return conf;
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
return ConfAdapter::CheckTrain(oricfg);
}
knowhere::Config
NSGConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
auto conf = std::make_shared<knowhere::NSGCfg>();
conf->k = metaconf.k;
conf->search_length = metaconf.search_length;
if (metaconf.search_length == TEMPMETA_DEFAULT_VALUE) {
conf->search_length = 30; // TODO(linxj): hardcode here.
}
return conf;
bool
HNSWConfAdapter::CheckSearch(milvus::json& oricfg, const IndexType& type) {
static int64_t MAX_EF = 4096;
CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF);
return ConfAdapter::CheckSearch(oricfg, type);
}
knowhere::Config
SPTAGKDTConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::KDTCfg>();
conf->d = metaconf.dim;
conf->metric_type = metaconf.metric_type;
return conf;
bool
BinIDMAPConfAdapter::CheckTrain(milvus::json& oricfg) {
static std::vector<std::string> METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD,
knowhere::Metric::TANIMOTO};
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
return true;
}
knowhere::Config
SPTAGKDTConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
auto conf = std::make_shared<knowhere::KDTCfg>();
conf->k = metaconf.k;
return conf;
}
bool
BinIVFConfAdapter::CheckTrain(milvus::json& oricfg) {
static std::vector<std::string> METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD,
knowhere::Metric::TANIMOTO};
static int64_t MAX_NLIST = 999999;
static int64_t MIN_NLIST = 1;
knowhere::Config
SPTAGBKTConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::BKTCfg>();
conf->d = metaconf.dim;
conf->metric_type = metaconf.metric_type;
return conf;
}
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST);
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
knowhere::Config
SPTAGBKTConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
auto conf = std::make_shared<knowhere::BKTCfg>();
conf->k = metaconf.k;
return conf;
}
int64_t nlist = oricfg[knowhere::IndexParams::nlist];
CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS);
knowhere::Config
HNSWConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::HNSWCfg>();
conf->d = metaconf.dim;
conf->metric_type = metaconf.metric_type;
// Best Practice
// static int64_t MIN_POINTS_PER_CENTROID = 40;
// static int64_t MAX_POINTS_PER_CENTROID = 256;
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
conf->ef = 500; // ef can be auto-configured by using sample data.
conf->M = 24; // A reasonable range of M is from 5 to 48.
return conf;
}
knowhere::Config
HNSWConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
auto conf = std::make_shared<knowhere::HNSWCfg>();
conf->k = metaconf.k;
if (metaconf.nprobe < metaconf.k) {
conf->ef = metaconf.k + 32;
} else {
conf->ef = metaconf.nprobe;
}
return conf;
}
knowhere::Config
BinIDMAPConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::BinIDMAPCfg>();
conf->d = metaconf.dim;
conf->metric_type = metaconf.metric_type;
conf->gpu_id = metaconf.gpu_id;
conf->k = metaconf.k;
MatchBase(conf, knowhere::METRICTYPE::HAMMING);
return conf;
}
knowhere::Config
BinIVFConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::IVFBinCfg>();
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist, 2048);
conf->d = metaconf.dim;
conf->metric_type = metaconf.metric_type;
conf->gpu_id = metaconf.gpu_id;
MatchBase(conf, knowhere::METRICTYPE::HAMMING);
return conf;
return true;
}
} // namespace engine
} // namespace milvus

View File

@ -14,117 +14,78 @@
#include <memory>
#include "VecIndex.h"
#include "knowhere/common/Config.h"
#include "utils/Json.h"
namespace milvus {
namespace engine {
// TODO(linxj): remove later, replace with real metaconf
constexpr int64_t TEMPMETA_DEFAULT_VALUE = -1;
struct TempMetaConf {
int64_t size = TEMPMETA_DEFAULT_VALUE;
int64_t nlist = TEMPMETA_DEFAULT_VALUE;
int64_t dim = TEMPMETA_DEFAULT_VALUE;
int64_t gpu_id = TEMPMETA_DEFAULT_VALUE;
int64_t k = TEMPMETA_DEFAULT_VALUE;
int64_t nprobe = TEMPMETA_DEFAULT_VALUE;
int64_t search_length = TEMPMETA_DEFAULT_VALUE;
knowhere::METRICTYPE metric_type = knowhere::DEFAULT_TYPE;
};
class ConfAdapter {
public:
virtual knowhere::Config
Match(const TempMetaConf& metaconf);
virtual bool
CheckTrain(milvus::json& oricfg);
virtual knowhere::Config
MatchSearch(const TempMetaConf& metaconf, const IndexType& type);
virtual bool
CheckSearch(milvus::json& oricfg, const IndexType& type);
protected:
static void
MatchBase(knowhere::Config conf, knowhere::METRICTYPE defalut_metric = knowhere::METRICTYPE::L2);
// todo(linxj): refactor in next release.
//
// virtual bool
// CheckTrain(milvus::json&, IndexMode&) = 0;
//
// virtual bool
// CheckSearch(milvus::json&, const IndexType&, IndexMode&) = 0;
};
using ConfAdapterPtr = std::shared_ptr<ConfAdapter>;
class IVFConfAdapter : public ConfAdapter {
public:
knowhere::Config
Match(const TempMetaConf& metaconf) override;
bool
CheckTrain(milvus::json& oricfg) override;
knowhere::Config
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) override;
protected:
static int64_t
MatchNlist(const int64_t& size, const int64_t& nlist, const int64_t& per_nlist);
bool
CheckSearch(milvus::json& oricfg, const IndexType& type) override;
};
class IVFSQConfAdapter : public IVFConfAdapter {
public:
knowhere::Config
Match(const TempMetaConf& metaconf) override;
bool
CheckTrain(milvus::json& oricfg) override;
};
class IVFPQConfAdapter : public IVFConfAdapter {
public:
knowhere::Config
Match(const TempMetaConf& metaconf) override;
knowhere::Config
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) override;
protected:
static int64_t
MatchNlist(const int64_t& size, const int64_t& nlist);
bool
CheckTrain(milvus::json& oricfg) override;
};
class NSGConfAdapter : public IVFConfAdapter {
public:
knowhere::Config
Match(const TempMetaConf& metaconf) override;
bool
CheckTrain(milvus::json& oricfg) override;
knowhere::Config
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) final;
};
class SPTAGKDTConfAdapter : public ConfAdapter {
public:
knowhere::Config
Match(const TempMetaConf& metaconf) override;
knowhere::Config
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) override;
};
class SPTAGBKTConfAdapter : public ConfAdapter {
public:
knowhere::Config
Match(const TempMetaConf& metaconf) override;
knowhere::Config
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) override;
bool
CheckSearch(milvus::json& oricfg, const IndexType& type) override;
};
class BinIDMAPConfAdapter : public ConfAdapter {
public:
knowhere::Config
Match(const TempMetaConf& metaconf) override;
bool
CheckTrain(milvus::json& oricfg) override;
};
class BinIVFConfAdapter : public IVFConfAdapter {
public:
knowhere::Config
Match(const TempMetaConf& metaconf) override;
bool
CheckTrain(milvus::json& oricfg) override;
};
class HNSWConfAdapter : public ConfAdapter {
public:
knowhere::Config
Match(const TempMetaConf& metaconf) override;
bool
CheckTrain(milvus::json& oricfg) override;
knowhere::Config
MatchSearch(const TempMetaConf& metaconf, const IndexType& type) override;
bool
CheckSearch(milvus::json& oricfg, const IndexType& type) override;
};
} // namespace engine

View File

@ -54,8 +54,8 @@ AdapterMgr::RegisterAdapter() {
REGISTER_CONF_ADAPTER(NSGConfAdapter, IndexType::NSG_MIX, nsg_mix);
REGISTER_CONF_ADAPTER(SPTAGKDTConfAdapter, IndexType::SPTAG_KDT_RNT_CPU, sptag_kdt);
REGISTER_CONF_ADAPTER(SPTAGBKTConfAdapter, IndexType::SPTAG_BKT_RNT_CPU, sptag_bkt);
REGISTER_CONF_ADAPTER(ConfAdapter, IndexType::SPTAG_KDT_RNT_CPU, sptag_kdt);
REGISTER_CONF_ADAPTER(ConfAdapter, IndexType::SPTAG_BKT_RNT_CPU, sptag_bkt);
REGISTER_CONF_ADAPTER(HNSWConfAdapter, IndexType::HNSW, hnsw);
}

View File

@ -40,7 +40,7 @@ Status
VecIndexImpl::BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
const float* xt) {
try {
dim = cfg->d;
dim = cfg[knowhere::meta::DIM];
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
fiu_do_on("VecIndexImpl.BuildAll.throw_knowhere_exception", throw knowhere::KnowhereException(""));
fiu_do_on("VecIndexImpl.BuildAll.throw_std_exception", throw std::exception());
@ -80,15 +80,13 @@ VecIndexImpl::Add(const int64_t& nb, const float* xb, const int64_t* ids, const
Status
VecIndexImpl::Search(const int64_t& nq, const float* xq, float* dist, int64_t* ids, const Config& cfg) {
try {
auto k = cfg->k;
int64_t k = cfg[knowhere::meta::TOPK];
auto dataset = GenDataset(nq, dim, xq);
Config search_cfg = cfg;
fiu_do_on("VecIndexImpl.Search.throw_knowhere_exception", throw knowhere::KnowhereException(""));
fiu_do_on("VecIndexImpl.Search.throw_std_exception", throw std::exception());
auto res = index_->Search(dataset, search_cfg);
auto res = index_->Search(dataset, cfg);
//{
// auto& ids = ids_array;
// auto& dists = dis_array;
@ -216,8 +214,7 @@ VecIndexImpl::GetVectorById(const int64_t n, const int64_t* xid, float* x, const
dataset->Set(knowhere::meta::DIM, dim);
dataset->Set(knowhere::meta::IDS, xid);
Config search_cfg = cfg;
auto res = index_->GetVectorById(dataset, search_cfg);
auto res = index_->GetVectorById(dataset, cfg);
// TODO(linxj): avoid copy here.
auto res_x = res->Get<float*>(knowhere::meta::TENSOR);
@ -242,14 +239,13 @@ VecIndexImpl::SearchById(const int64_t& nq, const int64_t* xq, float* dist, int6
}
try {
auto k = cfg->k;
int64_t k = cfg[knowhere::meta::TOPK];
auto dataset = std::make_shared<knowhere::Dataset>();
dataset->Set(knowhere::meta::ROWS, nq);
dataset->Set(knowhere::meta::DIM, dim);
dataset->Set(knowhere::meta::IDS, xq);
Config search_cfg = cfg;
auto res = index_->SearchById(dataset, search_cfg);
auto res = index_->SearchById(dataset, cfg);
//{
// auto& ids = ids_array;
// auto& dists = dis_array;
@ -337,7 +333,7 @@ BFIndex::Build(const Config& cfg) {
try {
fiu_do_on("BFIndex.Build.throw_knowhere_exception", throw knowhere::KnowhereException(""));
fiu_do_on("BFIndex.Build.throw_std_exception", throw std::exception());
dim = cfg->d;
dim = cfg[knowhere::meta::DIM];
std::static_pointer_cast<knowhere::IDMAP>(index_)->Train(cfg);
} catch (knowhere::KnowhereException& e) {
WRAPPER_LOG_ERROR << e.what();
@ -353,7 +349,7 @@ Status
BFIndex::BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
const float* xt) {
try {
dim = cfg->d;
dim = cfg[knowhere::meta::DIM];
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
fiu_do_on("BFIndex.BuildAll.throw_knowhere_exception", throw knowhere::KnowhereException(""));
fiu_do_on("BFIndex.BuildAll.throw_std_exception", throw std::exception());

View File

@ -12,6 +12,7 @@
#pragma once
#include <faiss/utils/ConcurrentBitset.h>
#include <thirdparty/nlohmann/json.hpp>
#include <memory>
#include <string>
@ -29,7 +30,8 @@
namespace milvus {
namespace engine {
using Config = knowhere::Config;
using json = nlohmann::json;
using Config = json;
// TODO(linxj): replace with string, Do refactor serialization
enum class IndexType {

View File

@ -38,7 +38,7 @@ IVFMixIndex::BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, co
fiu_do_on("IVFMixIndex.BuildAll.throw_knowhere_exception", throw knowhere::KnowhereException(""));
fiu_do_on("IVFMixIndex.BuildAll.throw_std_exception", throw std::exception());
dim = cfg->d;
dim = cfg[knowhere::meta::DIM];
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
auto preprocessor = index_->BuildPreprocessor(dataset, cfg);
index_->set_preprocessor(preprocessor);

View File

@ -186,16 +186,20 @@ TEST_F(DBTest, DB_TEST) {
std::stringstream ss;
uint64_t count = 0;
uint64_t prev_count = 0;
milvus::json json_params = {{"nprobe", 10}};
for (auto j = 0; j < 10; ++j) {
ss.str("");
db_->Size(count);
prev_count = count;
if (count == 0) {
continue;
}
START_TIMER;
std::vector<std::string> tags;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, qxb, result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, qxb, result_ids, result_distances);
ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
STOP_TIMER(ss.str());
@ -306,37 +310,41 @@ TEST_F(DBTest, SEARCH_TEST) {
stat = db_->InsertVectors(TABLE_NAME, "", xb);
ASSERT_TRUE(stat.ok());
milvus::json json_params = {{"nprobe", 10}};
milvus::engine::TableIndex index;
index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IDMAP;
db_->CreateIndex(TABLE_NAME, index); // wait until build index finish
{
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}
index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IVFFLAT;
db_->CreateIndex(TABLE_NAME, index); // wait until build index finish
{
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}
index.extra_params_ = {{"nlist", 16384}};
// db_->CreateIndex(TABLE_NAME, index); // wait until build index finish
//
// {
// std::vector<std::string> tags;
// milvus::engine::ResultIds result_ids;
// milvus::engine::ResultDistances result_distances;
// stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, xq, result_ids, result_distances);
// ASSERT_TRUE(stat.ok());
// }
//
// index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IVFFLAT;
// index.extra_params_ = {{"nlist", 16384}};
// db_->CreateIndex(TABLE_NAME, index); // wait until build index finish
//
// {
// std::vector<std::string> tags;
// milvus::engine::ResultIds result_ids;
// milvus::engine::ResultDistances result_distances;
// stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, xq, result_ids, result_distances);
// ASSERT_TRUE(stat.ok());
// }
index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IVFSQ8;
index.extra_params_ = {{"nlist", 16384}};
db_->CreateIndex(TABLE_NAME, index); // wait until build index finish
{
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}
@ -349,7 +357,7 @@ TEST_F(DBTest, SEARCH_TEST) {
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}
#endif
@ -365,16 +373,22 @@ TEST_F(DBTest, SEARCH_TEST) {
}
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, 10, xq, result_ids, result_distances);
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k,
json_params,
xq,
result_ids,
result_distances);
ASSERT_TRUE(stat.ok());
FIU_ENABLE_FIU("SqliteMetaImpl.FilesToSearch.throw_exception");
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, 10, xq, result_ids, result_distances);
stat =
db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, json_params, xq, result_ids, result_distances);
ASSERT_FALSE(stat.ok());
fiu_disable("SqliteMetaImpl.FilesToSearch.throw_exception");
FIU_ENABLE_FIU("DBImpl.QueryByFileID.empty_files_array");
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, 10, xq, result_ids, result_distances);
stat =
db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, json_params, xq, result_ids, result_distances);
ASSERT_FALSE(stat.ok());
fiu_disable("DBImpl.QueryByFileID.empty_files_array");
}
@ -385,13 +399,13 @@ TEST_F(DBTest, SEARCH_TEST) {
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
FIU_ENABLE_FIU("SqliteMetaImpl.FilesToSearch.throw_exception");
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, xq, result_ids, result_distances);
ASSERT_FALSE(stat.ok());
fiu_disable("SqliteMetaImpl.FilesToSearch.throw_exception");
}
@ -409,7 +423,7 @@ TEST_F(DBTest, SEARCH_TEST) {
{
result_ids.clear();
result_dists.clear();
stat = db_->Query(dummy_context_, TABLE_NAME, partition_tag, k, 10, xq, result_ids, result_dists);
stat = db_->Query(dummy_context_, TABLE_NAME, partition_tag, k, json_params, xq, result_ids, result_dists);
ASSERT_TRUE(stat.ok());
}
@ -423,7 +437,7 @@ TEST_F(DBTest, SEARCH_TEST) {
}
result_ids.clear();
result_dists.clear();
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, 10, xq, result_ids, result_dists);
stat = db_->QueryByFileID(dummy_context_, TABLE_NAME, file_ids, k, json_params, xq, result_ids, result_dists);
ASSERT_TRUE(stat.ok());
}
#endif
@ -564,13 +578,27 @@ TEST_F(DBTest, SHUTDOWN_TEST) {
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, table_info.table_id_, tags, 1, 1, xb, result_ids, result_distances);
milvus::json json_params = {{"nprobe", 1}};
stat = db_->Query(dummy_context_, table_info.table_id_, tags, 1, json_params, xb, result_ids, result_distances);
ASSERT_FALSE(stat.ok());
std::vector<std::string> file_ids;
stat = db_->QueryByFileID(dummy_context_, table_info.table_id_, file_ids, 1, 1, xb, result_ids, result_distances);
stat = db_->QueryByFileID(dummy_context_,
table_info.table_id_,
file_ids,
1,
json_params,
xb,
result_ids,
result_distances);
ASSERT_FALSE(stat.ok());
stat = db_->Query(dummy_context_, table_info.table_id_, tags, 1, 1, milvus::engine::VectorsData(), result_ids,
stat = db_->Query(dummy_context_,
table_info.table_id_,
tags,
1,
json_params,
milvus::engine::VectorsData(),
result_ids,
result_distances);
ASSERT_FALSE(stat.ok());
@ -731,7 +759,7 @@ TEST_F(DBTest, INDEX_TEST) {
stat = db_->DescribeIndex(table_info.table_id_, index_out);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(index.engine_type_, index_out.engine_type_);
ASSERT_EQ(index.nlist_, index_out.nlist_);
ASSERT_EQ(index.extra_params_, index_out.extra_params_);
ASSERT_EQ(table_info.metric_type_, index_out.metric_type_);
stat = db_->DropIndex(table_info.table_id_);
@ -845,7 +873,9 @@ TEST_F(DBTest, PARTITION_TEST) {
std::vector<std::string> tags = {"0", std::to_string(PARTITION_COUNT - 1)};
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nprobe, xq, result_ids, result_distances);
milvus::json json_params = {{"nprobe", nprobe}};
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, json_params, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, nq);
@ -853,7 +883,7 @@ TEST_F(DBTest, PARTITION_TEST) {
tags.clear();
result_ids.clear();
result_distances.clear();
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nprobe, xq, result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, json_params, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, nq);
@ -861,7 +891,7 @@ TEST_F(DBTest, PARTITION_TEST) {
tags.push_back("\\d");
result_ids.clear();
result_distances.clear();
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, nprobe, xq, result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, json_params, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, nq);
}
@ -1074,11 +1104,12 @@ TEST_F(DBTestWAL, DB_STOP_TEST) {
const int64_t topk = 10;
const int64_t nprobe = 10;
milvus::json json_params = {{"nprobe", nprobe}};
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
milvus::engine::VectorsData qxb;
BuildVectors(qb, 0, qxb);
stat = db_->Query(dummy_context_, table_info.table_id_, {}, topk, nprobe, qxb, result_ids, result_distances);
stat = db_->Query(dummy_context_, table_info.table_id_, {}, topk, json_params, qxb, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, qb);
@ -1102,11 +1133,12 @@ TEST_F(DBTestWALRecovery, RECOVERY_WITH_NO_ERROR) {
const int64_t topk = 10;
const int64_t nprobe = 10;
milvus::json json_params = {{"nprobe", nprobe}};
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
milvus::engine::VectorsData qxb;
BuildVectors(qb, 0, qxb);
stat = db_->Query(dummy_context_, table_info.table_id_, {}, topk, nprobe, qxb, result_ids, result_distances);
stat = db_->Query(dummy_context_, table_info.table_id_, {}, topk, json_params, qxb, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_NE(result_ids.size() / topk, qb);
@ -1119,14 +1151,14 @@ TEST_F(DBTestWALRecovery, RECOVERY_WITH_NO_ERROR) {
result_ids.clear();
result_distances.clear();
stat = db_->Query(dummy_context_, table_info.table_id_, {}, topk, nprobe, qxb, result_ids, result_distances);
stat = db_->Query(dummy_context_, table_info.table_id_, {}, topk, json_params, qxb, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size(), 0);
db_->Flush();
result_ids.clear();
result_distances.clear();
stat = db_->Query(dummy_context_, table_info.table_id_, {}, topk, nprobe, qxb, result_ids, result_distances);
stat = db_->Query(dummy_context_, table_info.table_id_, {}, topk, json_params, qxb, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, qb);
}
@ -1312,6 +1344,7 @@ TEST_F(DBTest2, SEARCH_WITH_DIFFERENT_INDEX) {
ASSERT_TRUE(stat.ok());
int topk = 10, nprobe = 10;
milvus::json json_params = {{"nprobe", nprobe}};
for (auto id : ids_to_search) {
// std::cout << "xxxxxxxxxxxxxxxxxxxx " << i << std::endl;
@ -1319,7 +1352,7 @@ TEST_F(DBTest2, SEARCH_WITH_DIFFERENT_INDEX) {
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->QueryByID(dummy_context_, table_info.table_id_, tags, topk, nprobe, id, result_ids,
stat = db_->QueryByID(dummy_context_, table_info.table_id_, tags, topk, json_params, id, result_ids,
result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids[0], id);
@ -1341,7 +1374,7 @@ result_distances);
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->QueryByID(dummy_context_, table_info.table_id_, tags, topk, nprobe, id, result_ids,
stat = db_->QueryByID(dummy_context_, table_info.table_id_, tags, topk, json_params, id, result_ids,
result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids[0], id);

View File

@ -78,16 +78,20 @@ TEST_F(MySqlDBTest, DB_TEST) {
std::stringstream ss;
uint64_t count = 0;
uint64_t prev_count = 0;
milvus::json json_params = {{"nprobe", 10}};
for (auto j = 0; j < 10; ++j) {
ss.str("");
db_->Size(count);
prev_count = count;
if (count == 0) {
continue;
}
START_TIMER;
std::vector<std::string> tags;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, qxb, result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, qxb, result_ids, result_distances);
ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
STOP_TIMER(ss.str());
@ -186,7 +190,8 @@ TEST_F(MySqlDBTest, SEARCH_TEST) {
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, 10, xq, result_ids, result_distances);
milvus::json json_params = {{"nprobe", 10}};
stat = db_->Query(dummy_context_, TABLE_NAME, tags, k, json_params, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
}
@ -377,7 +382,8 @@ TEST_F(MySqlDBTest, PARTITION_TEST) {
std::vector<std::string> tags = {"0", std::to_string(PARTITION_COUNT - 1)};
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, 10, xq, result_ids, result_distances);
milvus::json json_params = {{"nprobe", nprobe}};
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, json_params, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, nq);
@ -385,7 +391,7 @@ TEST_F(MySqlDBTest, PARTITION_TEST) {
tags.clear();
result_ids.clear();
result_distances.clear();
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, 10, xq, result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, json_params, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, nq);
@ -393,7 +399,7 @@ TEST_F(MySqlDBTest, PARTITION_TEST) {
tags.push_back("\\d");
result_ids.clear();
result_distances.clear();
stat = db_->Query(dummy_context_, TABLE_NAME, tags, 10, 10, xq, result_ids, result_distances);
stat = db_->Query(dummy_context_, TABLE_NAME, tags, topk, json_params, xq, result_ids, result_distances);
ASSERT_TRUE(stat.ok());
ASSERT_EQ(result_ids.size() / topk, nq);
}

View File

@ -301,6 +301,7 @@ TEST_F(DeleteTest, delete_with_index) {
milvus::engine::TableIndex index;
index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IVFSQ8;
index.extra_params_ = {{"nlist", 100}};
stat = db_->CreateIndex(GetTableName(), index);
ASSERT_TRUE(stat.ok());
@ -368,12 +369,13 @@ TEST_F(DeleteTest, delete_single_vector) {
ASSERT_TRUE(stat.ok());
ASSERT_EQ(row_count, 0);
int topk = 1, nprobe = 1;
const int topk = 1, nprobe = 1;
milvus::json json_params = {{"nprobe", nprobe}};
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, nprobe, xb, result_ids, result_distances);
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, json_params, xb, result_ids, result_distances);
ASSERT_TRUE(result_ids.empty());
ASSERT_TRUE(result_distances.empty());
// ASSERT_EQ(result_ids[0], -1);
@ -402,6 +404,7 @@ TEST_F(DeleteTest, delete_add_create_index) {
// ASSERT_TRUE(stat.ok());
milvus::engine::TableIndex index;
index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IVFSQ8;
index.extra_params_ = {{"nlist", 100}};
stat = db_->CreateIndex(GetTableName(), index);
ASSERT_TRUE(stat.ok());
@ -426,7 +429,8 @@ TEST_F(DeleteTest, delete_add_create_index) {
ASSERT_TRUE(stat.ok());
ASSERT_EQ(row_count, nb * 2 - 1);
int topk = 10, nprobe = 10;
const int topk = 10, nprobe = 10;
milvus::json json_params = {{"nprobe", nprobe}};
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
@ -435,14 +439,14 @@ TEST_F(DeleteTest, delete_add_create_index) {
qb.float_data_.resize(TABLE_DIM);
qb.vector_count_ = 1;
qb.id_array_.clear();
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, nprobe, qb, result_ids, result_distances);
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, json_params, qb, result_ids, result_distances);
ASSERT_EQ(result_ids[0], xb2.id_array_.front());
ASSERT_LT(result_distances[0], 1e-4);
result_ids.clear();
result_distances.clear();
stat = db_->QueryByID(dummy_context_, GetTableName(), tags, topk, nprobe, ids_to_delete.front(), result_ids,
stat = db_->QueryByID(dummy_context_, GetTableName(), tags, topk, json_params, ids_to_delete.front(), result_ids,
result_distances);
ASSERT_EQ(result_ids[0], -1);
ASSERT_EQ(result_distances[0], std::numeric_limits<float>::max());
@ -496,7 +500,8 @@ TEST_F(DeleteTest, delete_add_auto_flush) {
ASSERT_TRUE(stat.ok());
ASSERT_EQ(row_count, nb * 2 - 1);
int topk = 10, nprobe = 10;
const int topk = 10, nprobe = 10;
milvus::json json_params = {{"nprobe", nprobe}};
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
@ -505,7 +510,7 @@ TEST_F(DeleteTest, delete_add_auto_flush) {
qb.float_data_.resize(TABLE_DIM);
qb.vector_count_ = 1;
qb.id_array_.clear();
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, nprobe, qb, result_ids, result_distances);
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, json_params, qb, result_ids, result_distances);
ASSERT_EQ(result_ids[0], xb2.id_array_.front());
ASSERT_LT(result_distances[0], 1e-4);
@ -555,7 +560,8 @@ TEST_F(CompactTest, compact_basic) {
stat = db_->Compact(GetTableName());
ASSERT_TRUE(stat.ok());
int topk = 1, nprobe = 1;
const int topk = 1, nprobe = 1;
milvus::json json_params = {{"nprobe", nprobe}};
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
@ -563,7 +569,8 @@ TEST_F(CompactTest, compact_basic) {
milvus::engine::VectorsData qb = xb;
for (auto& id : ids_to_delete) {
stat = db_->QueryByID(dummy_context_, GetTableName(), tags, topk, nprobe, id, result_ids, result_distances);
stat =
db_->QueryByID(dummy_context_, GetTableName(), tags, topk, json_params, id, result_ids, result_distances);
ASSERT_EQ(result_ids[0], -1);
ASSERT_EQ(result_distances[0], std::numeric_limits<float>::max());
}
@ -643,14 +650,17 @@ TEST_F(CompactTest, compact_with_index) {
ASSERT_TRUE(stat.ok());
ASSERT_FLOAT_EQ(table_index.engine_type_, index.engine_type_);
int topk = 10, nprobe = 10;
const int topk = 10, nprobe = 10;
milvus::json json_params = {{"nprobe", nprobe}};
for (auto& pair : search_vectors) {
auto& search = pair.second;
std::vector<std::string> tags;
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, nprobe, search, result_ids, result_distances);
stat =
db_->Query(dummy_context_, GetTableName(), tags, topk, json_params, search, result_ids, result_distances);
ASSERT_NE(result_ids[0], pair.first);
// ASSERT_LT(result_distances[0], 1e-4);
ASSERT_GT(result_distances[0], 1);

View File

@ -19,17 +19,58 @@
#include <fiu-local.h>
#include <fiu-control.h>
namespace {
static constexpr uint16_t DIMENSION = 64;
static constexpr int64_t ROW_COUNT = 1000;
static const char* INIT_PATH = "/tmp/milvus_index_1";
milvus::engine::ExecutionEnginePtr
CreateExecEngine(const milvus::json& json_params, milvus::engine::MetricType metric = milvus::engine::MetricType::IP) {
auto engine_ptr = milvus::engine::EngineFactory::Build(
DIMENSION,
INIT_PATH,
milvus::engine::EngineType::FAISS_IDMAP,
metric,
json_params);
std::vector<float> data;
std::vector<int64_t> ids;
data.reserve(ROW_COUNT * DIMENSION);
ids.reserve(ROW_COUNT);
for (int64_t i = 0; i < ROW_COUNT; i++) {
ids.push_back(i);
for (uint16_t k = 0; k < DIMENSION; k++) {
data.push_back(i * DIMENSION + k);
}
}
auto status = engine_ptr->AddWithIds((int64_t)ids.size(), data.data(), ids.data());
return engine_ptr;
}
} // namespace
TEST_F(EngineTest, FACTORY_TEST) {
const milvus::json index_params = {{"nlist", 1024}};
{
auto engine_ptr = milvus::engine::EngineFactory::Build(
512, "/tmp/milvus_index_1", milvus::engine::EngineType::INVALID, milvus::engine::MetricType::IP, 1024);
512,
"/tmp/milvus_index_1",
milvus::engine::EngineType::INVALID,
milvus::engine::MetricType::IP,
index_params);
ASSERT_TRUE(engine_ptr == nullptr);
}
{
auto engine_ptr = milvus::engine::EngineFactory::Build(
512, "/tmp/milvus_index_1", milvus::engine::EngineType::FAISS_IDMAP, milvus::engine::MetricType::IP, 1024);
512,
"/tmp/milvus_index_1",
milvus::engine::EngineType::FAISS_IDMAP,
milvus::engine::MetricType::IP,
index_params);
ASSERT_TRUE(engine_ptr != nullptr);
}
@ -37,28 +78,40 @@ TEST_F(EngineTest, FACTORY_TEST) {
{
auto engine_ptr =
milvus::engine::EngineFactory::Build(512, "/tmp/milvus_index_1", milvus::engine::EngineType::FAISS_IVFFLAT,
milvus::engine::MetricType::IP, 1024);
milvus::engine::MetricType::IP, index_params);
ASSERT_TRUE(engine_ptr != nullptr);
}
{
auto engine_ptr = milvus::engine::EngineFactory::Build(
512, "/tmp/milvus_index_1", milvus::engine::EngineType::FAISS_IVFSQ8, milvus::engine::MetricType::IP, 1024);
512,
"/tmp/milvus_index_1",
milvus::engine::EngineType::FAISS_IVFSQ8,
milvus::engine::MetricType::IP,
index_params);
ASSERT_TRUE(engine_ptr != nullptr);
}
{
auto engine_ptr = milvus::engine::EngineFactory::Build(
512, "/tmp/milvus_index_1", milvus::engine::EngineType::NSG_MIX, milvus::engine::MetricType::IP, 1024);
512,
"/tmp/milvus_index_1",
milvus::engine::EngineType::NSG_MIX,
milvus::engine::MetricType::IP,
index_params);
ASSERT_TRUE(engine_ptr != nullptr);
}
{
auto engine_ptr = milvus::engine::EngineFactory::Build(
512, "/tmp/milvus_index_1", milvus::engine::EngineType::FAISS_PQ, milvus::engine::MetricType::IP, 1024);
512,
"/tmp/milvus_index_1",
milvus::engine::EngineType::FAISS_PQ,
milvus::engine::MetricType::IP,
index_params);
ASSERT_TRUE(engine_ptr != nullptr);
}
@ -66,7 +119,7 @@ TEST_F(EngineTest, FACTORY_TEST) {
{
auto engine_ptr = milvus::engine::EngineFactory::Build(
512, "/tmp/milvus_index_1", milvus::engine::EngineType::SPTAG_KDT,
milvus::engine::MetricType::L2, 1024);
milvus::engine::MetricType::L2, index_params);
ASSERT_TRUE(engine_ptr != nullptr);
}
@ -74,7 +127,7 @@ TEST_F(EngineTest, FACTORY_TEST) {
{
auto engine_ptr = milvus::engine::EngineFactory::Build(
512, "/tmp/milvus_index_1", milvus::engine::EngineType::SPTAG_KDT,
milvus::engine::MetricType::L2, 1024);
milvus::engine::MetricType::L2, index_params);
ASSERT_TRUE(engine_ptr != nullptr);
}
@ -85,7 +138,7 @@ TEST_F(EngineTest, FACTORY_TEST) {
FIU_ENABLE_FIU("ExecutionEngineImpl.CreatetVecIndex.invalid_type");
ASSERT_ANY_THROW(milvus::engine::EngineFactory::Build(
512, "/tmp/milvus_index_1", milvus::engine::EngineType::SPTAG_KDT,
milvus::engine::MetricType::L2, 1024));
milvus::engine::MetricType::L2, index_params));
fiu_disable("ExecutionEngineImpl.CreatetVecIndex.invalid_type");
}
@ -94,91 +147,88 @@ TEST_F(EngineTest, FACTORY_TEST) {
FIU_ENABLE_FIU("BFIndex.Build.throw_knowhere_exception");
ASSERT_ANY_THROW(milvus::engine::EngineFactory::Build(
512, "/tmp/milvus_index_1", milvus::engine::EngineType::SPTAG_KDT,
milvus::engine::MetricType::L2, 1024));
milvus::engine::MetricType::L2, index_params));
fiu_disable("BFIndex.Build.throw_knowhere_exception");
}
}
TEST_F(EngineTest, ENGINE_IMPL_TEST) {
fiu_init(0);
uint16_t dimension = 64;
std::string file_path = "/tmp/milvus_index_1";
auto engine_ptr = milvus::engine::EngineFactory::Build(
dimension, file_path, milvus::engine::EngineType::FAISS_IVFFLAT, milvus::engine::MetricType::IP, 1024);
std::vector<float> data;
std::vector<int64_t> ids;
const int row_count = 500;
data.reserve(row_count * dimension);
ids.reserve(row_count);
for (int64_t i = 0; i < row_count; i++) {
ids.push_back(i);
for (uint16_t k = 0; k < dimension; k++) {
data.push_back(i * dimension + k);
}
{
milvus::json index_params = {{"nlist", 10}};
auto engine_ptr = CreateExecEngine(index_params);
ASSERT_EQ(engine_ptr->Dimension(), DIMENSION);
ASSERT_EQ(engine_ptr->Count(), ROW_COUNT);
ASSERT_EQ(engine_ptr->GetLocation(), INIT_PATH);
ASSERT_EQ(engine_ptr->IndexMetricType(), milvus::engine::MetricType::IP);
ASSERT_ANY_THROW(engine_ptr->BuildIndex(INIT_PATH, milvus::engine::EngineType::INVALID));
FIU_ENABLE_FIU("VecIndexImpl.BuildAll.throw_knowhere_exception");
ASSERT_ANY_THROW(engine_ptr->BuildIndex(INIT_PATH, milvus::engine::EngineType::SPTAG_KDT));
fiu_disable("VecIndexImpl.BuildAll.throw_knowhere_exception");
auto engine_build = engine_ptr->BuildIndex("/tmp/milvus_index_2", milvus::engine::EngineType::FAISS_IVFSQ8);
ASSERT_NE(engine_build, nullptr);
}
auto status = engine_ptr->AddWithIds((int64_t)ids.size(), data.data(), ids.data());
ASSERT_TRUE(status.ok());
ASSERT_EQ(engine_ptr->Dimension(), dimension);
ASSERT_EQ(engine_ptr->Count(), ids.size());
ASSERT_EQ(engine_ptr->GetLocation(), file_path);
ASSERT_EQ(engine_ptr->IndexMetricType(), milvus::engine::MetricType::IP);
ASSERT_ANY_THROW(engine_ptr->BuildIndex(file_path, milvus::engine::EngineType::INVALID));
FIU_ENABLE_FIU("VecIndexImpl.BuildAll.throw_knowhere_exception");
ASSERT_ANY_THROW(engine_ptr->BuildIndex(file_path, milvus::engine::EngineType::SPTAG_KDT));
fiu_disable("VecIndexImpl.BuildAll.throw_knowhere_exception");
auto engine_build = engine_ptr->BuildIndex("/tmp/milvus_index_2", milvus::engine::EngineType::FAISS_IVFSQ8);
{
#ifndef MILVUS_GPU_VERSION
//PQ don't support IP In gpu version
engine_build = engine_ptr->BuildIndex("/tmp/milvus_index_3", milvus::engine::EngineType::FAISS_PQ);
milvus::json index_params = {{"nlist", 10}, {"m", 16}};
auto engine_ptr = CreateExecEngine(index_params);
//PQ don't support IP In gpu version
auto engine_build = engine_ptr->BuildIndex("/tmp/milvus_index_3", milvus::engine::EngineType::FAISS_PQ);
ASSERT_NE(engine_build, nullptr);
#endif
engine_build = engine_ptr->BuildIndex("/tmp/milvus_index_4", milvus::engine::EngineType::SPTAG_KDT);
engine_build = engine_ptr->BuildIndex("/tmp/milvus_index_5", milvus::engine::EngineType::SPTAG_BKT);
engine_ptr->BuildIndex("/tmp/milvus_index_SPTAG_BKT", milvus::engine::EngineType::SPTAG_BKT);
}
{
milvus::json index_params = {{"nlist", 10}};
auto engine_ptr = CreateExecEngine(index_params);
auto engine_build = engine_ptr->BuildIndex("/tmp/milvus_index_4", milvus::engine::EngineType::SPTAG_KDT);
engine_build = engine_ptr->BuildIndex("/tmp/milvus_index_5", milvus::engine::EngineType::SPTAG_BKT);
engine_ptr->BuildIndex("/tmp/milvus_index_SPTAG_BKT", milvus::engine::EngineType::SPTAG_BKT);
//CPU version invoke CopyToCpu will fail
auto status = engine_ptr->CopyToCpu();
ASSERT_FALSE(status.ok());
}
#ifdef MILVUS_GPU_VERSION
FIU_ENABLE_FIU("ExecutionEngineImpl.CreatetVecIndex.gpu_res_disabled");
engine_ptr->BuildIndex("/tmp/milvus_index_NSG_MIX", milvus::engine::EngineType::NSG_MIX);
engine_ptr->BuildIndex("/tmp/milvus_index_6", milvus::engine::EngineType::FAISS_IVFFLAT);
engine_ptr->BuildIndex("/tmp/milvus_index_7", milvus::engine::EngineType::FAISS_IVFSQ8);
ASSERT_ANY_THROW(engine_ptr->BuildIndex("/tmp/milvus_index_8", milvus::engine::EngineType::FAISS_IVFSQ8H));
ASSERT_ANY_THROW(engine_ptr->BuildIndex("/tmp/milvus_index_9", milvus::engine::EngineType::FAISS_PQ));
fiu_disable("ExecutionEngineImpl.CreatetVecIndex.gpu_res_disabled");
#endif
{
FIU_ENABLE_FIU("ExecutionEngineImpl.CreatetVecIndex.gpu_res_disabled");
milvus::json index_params = {{"search_length", 100}, {"out_degree", 40}, {"pool_size", 100}, {"knng", 200},
{"candidate_pool_size", 500}};
auto engine_ptr = CreateExecEngine(index_params, milvus::engine::MetricType::L2);
engine_ptr->BuildIndex("/tmp/milvus_index_NSG_MIX", milvus::engine::EngineType::NSG_MIX);
fiu_disable("ExecutionEngineImpl.CreatetVecIndex.gpu_res_disabled");
//CPU version invoke CopyToCpu will fail
status = engine_ptr->CopyToCpu();
ASSERT_FALSE(status.ok());
auto status = engine_ptr->CopyToGpu(0, false);
ASSERT_TRUE(status.ok());
status = engine_ptr->GpuCache(0);
ASSERT_TRUE(status.ok());
status = engine_ptr->CopyToGpu(0, false);
ASSERT_TRUE(status.ok());
#ifdef MILVUS_GPU_VERSION
status = engine_ptr->CopyToGpu(0, false);
ASSERT_TRUE(status.ok());
status = engine_ptr->GpuCache(0);
ASSERT_TRUE(status.ok());
status = engine_ptr->CopyToGpu(0, false);
ASSERT_TRUE(status.ok());
// auto new_engine = engine_ptr->Clone();
// ASSERT_EQ(new_engine->Dimension(), dimension);
// ASSERT_EQ(new_engine->Count(), ids.size());
// auto new_engine = engine_ptr->Clone();
// ASSERT_EQ(new_engine->Dimension(), dimension);
// ASSERT_EQ(new_engine->Count(), ids.size());
status = engine_ptr->CopyToCpu();
ASSERT_TRUE(status.ok());
engine_ptr->CopyToCpu();
ASSERT_TRUE(status.ok());
status = engine_ptr->CopyToCpu();
ASSERT_TRUE(status.ok());
engine_ptr->CopyToCpu();
ASSERT_TRUE(status.ok());
}
#endif
}
TEST_F(EngineTest, ENGINE_IMPL_NULL_INDEX_TEST) {
uint16_t dimension = 64;
std::string file_path = "/tmp/milvus_index_1";
milvus::json index_params = {{"nlist", 1024}};
auto engine_ptr = milvus::engine::EngineFactory::Build(
dimension, file_path, milvus::engine::EngineType::FAISS_IVFFLAT, milvus::engine::MetricType::IP, 1024);
dimension, file_path, milvus::engine::EngineType::FAISS_IVFFLAT, milvus::engine::MetricType::IP, index_params);
fiu_init(0); // init
fiu_enable("read_null_index", 1, NULL, 0);
@ -209,11 +259,13 @@ TEST_F(EngineTest, ENGINE_IMPL_NULL_INDEX_TEST) {
TEST_F(EngineTest, ENGINE_IMPL_THROW_EXCEPTION_TEST) {
uint16_t dimension = 64;
std::string file_path = "/tmp/invalid_file";
milvus::json index_params = {{"nlist", 1024}};
fiu_init(0); // init
fiu_enable("ValidateStringNotBool", 1, NULL, 0);
auto engine_ptr = milvus::engine::EngineFactory::Build(
dimension, file_path, milvus::engine::EngineType::FAISS_IVFFLAT, milvus::engine::MetricType::IP, 1024);
dimension, file_path, milvus::engine::EngineType::FAISS_IVFFLAT, milvus::engine::MetricType::IP, index_params);
fiu_disable("ValidateStringNotBool");

View File

@ -285,7 +285,8 @@ TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) {
search_vectors.insert(std::make_pair(xb.id_array_[index], search));
}
int topk = 10, nprobe = 10;
const int topk = 10, nprobe = 10;
milvus::json json_params = {{"nprobe", nprobe}};
for (auto& pair : search_vectors) {
auto& search = pair.second;
@ -293,7 +294,8 @@ TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) {
milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
stat = db_->Query(dummy_context_, GetTableName(), tags, topk, nprobe, search, result_ids, result_distances);
stat =
db_->Query(dummy_context_, GetTableName(), tags, topk, json_params, search, result_ids, result_distances);
ASSERT_EQ(result_ids[0], pair.first);
ASSERT_LT(result_distances[0], 1e-4);
}
@ -388,6 +390,7 @@ TEST_F(MemManagerTest2, INSERT_BINARY_TEST) {
// std::stringstream ss;
// uint64_t count = 0;
// uint64_t prev_count = 0;
// milvus::json json_params = {{"nprobe", 10}};
//
// for (auto j = 0; j < 10; ++j) {
// ss.str("");
@ -397,7 +400,8 @@ TEST_F(MemManagerTest2, INSERT_BINARY_TEST) {
// START_TIMER;
//
// std::vector<std::string> tags;
// stat = db_->Query(dummy_context_, GetTableName(), tags, k, 10, qxb, result_ids, result_distances);
// stat =
// db_->Query(dummy_context_, GetTableName(), tags, k, json_params, qxb, result_ids, result_distances);
// ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
// STOP_TIMER(ss.str());
//

View File

@ -647,7 +647,7 @@ TEST_F(MetaTest, INDEX_TEST) {
milvus::engine::TableIndex index;
index.metric_type_ = 2;
index.nlist_ = 1234;
index.extra_params_ = {{"nlist", 1234}};
index.engine_type_ = 3;
status = impl_->UpdateTableIndex(table_id, index);
ASSERT_TRUE(status.ok());
@ -664,14 +664,13 @@ TEST_F(MetaTest, INDEX_TEST) {
milvus::engine::TableIndex index_out;
status = impl_->DescribeTableIndex(table_id, index_out);
ASSERT_EQ(index_out.metric_type_, index.metric_type_);
ASSERT_EQ(index_out.nlist_, index.nlist_);
ASSERT_EQ(index_out.extra_params_, index.extra_params_);
ASSERT_EQ(index_out.engine_type_, index.engine_type_);
status = impl_->DropTableIndex(table_id);
ASSERT_TRUE(status.ok());
status = impl_->DescribeTableIndex(table_id, index_out);
ASSERT_EQ(index_out.metric_type_, index.metric_type_);
ASSERT_NE(index_out.nlist_, index.nlist_);
ASSERT_NE(index_out.engine_type_, index.engine_type_);
status = impl_->UpdateTableFilesToIndex(table_id);

View File

@ -700,7 +700,7 @@ TEST_F(MySqlMetaTest, INDEX_TEST) {
milvus::engine::TableIndex index;
index.metric_type_ = 2;
index.nlist_ = 1234;
index.extra_params_ = {{"nlist", 1234}};
index.engine_type_ = 3;
status = impl_->UpdateTableIndex(table_id, index);
ASSERT_TRUE(status.ok());
@ -740,14 +740,13 @@ TEST_F(MySqlMetaTest, INDEX_TEST) {
milvus::engine::TableIndex index_out;
status = impl_->DescribeTableIndex(table_id, index_out);
ASSERT_EQ(index_out.metric_type_, index.metric_type_);
ASSERT_EQ(index_out.nlist_, index.nlist_);
ASSERT_EQ(index_out.extra_params_, index.extra_params_);
ASSERT_EQ(index_out.engine_type_, index.engine_type_);
status = impl_->DropTableIndex(table_id);
ASSERT_TRUE(status.ok());
status = impl_->DescribeTableIndex(table_id, index_out);
ASSERT_EQ(index_out.metric_type_, index.metric_type_);
ASSERT_NE(index_out.nlist_, index.nlist_);
ASSERT_NE(index_out.engine_type_, index.engine_type_);
FIU_ENABLE_FIU("MySQLMetaImpl.DescribeTableIndex.null_connection");

Some files were not shown because too many files have changed in this diff Show More