diff --git a/cpp/CHANGELOG.md b/cpp/CHANGELOG.md index 7fd7271ca5..205ed8b5ad 100644 --- a/cpp/CHANGELOG.md +++ b/cpp/CHANGELOG.md @@ -20,3 +20,4 @@ Please mark all change in change log and use the ticket from JIRA. - MS-4 - Refactor the vecwise_engine code structure - MS-6 - Implement SDK interface part 1 - MS-20 - Clean Code Part 1 +- MS-6 - Implement SDK interface part 2 diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 7a36bdfd10..b6b8ad1aa0 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -101,7 +101,6 @@ link_directories(${CMAKE_CURRRENT_BINARY_DIR}) #execute_process(COMMAND bash build.sh # WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/third_party) - add_subdirectory(src) if (BUILD_UNIT_TEST) diff --git a/cpp/conf/server_config.yaml b/cpp/conf/server_config.yaml index 964a23fd48..523b1f9968 100644 --- a/cpp/conf/server_config.yaml +++ b/cpp/conf/server_config.yaml @@ -1,7 +1,7 @@ server_config: address: 0.0.0.0 port: 33001 - transfer_protocol: json #optional: binary, compact, json, debug + transfer_protocol: binary #optional: binary, compact, json server_mode: thread_pool #optional: simple, thread_pool gpu_index: 0 #which gpu to be used diff --git a/cpp/conf/server_config_template.yaml b/cpp/conf/server_config_template.yaml index fb6f6beae2..b701956498 100644 --- a/cpp/conf/server_config_template.yaml +++ b/cpp/conf/server_config_template.yaml @@ -1,7 +1,7 @@ server_config: address: 0.0.0.0 port: 33001 - transfer_protocol: json #optional: binary, compact, json, debug + transfer_protocol: binary #optional: binary, compact, json server_mode: thread_pool #optional: simple, thread_pool gpu_index: 0 #which gpu to be used diff --git a/cpp/src/sdk/examples/simple/src/ClientTest.cpp b/cpp/src/sdk/examples/simple/src/ClientTest.cpp index 01669680fc..b57ab75c37 100644 --- a/cpp/src/sdk/examples/simple/src/ClientTest.cpp +++ b/cpp/src/sdk/examples/simple/src/ClientTest.cpp @@ -13,13 +13,49 @@ using namespace megasearch; namespace { + +#define BLOCK_SPLITER std::cout << "===========================================" << std::endl; + void PrintTableSchema(const megasearch::TableSchema& tb_schema) { - std::cout << "===========================================" << std::endl; + BLOCK_SPLITER std::cout << "Table name: " << tb_schema.table_name << std::endl; std::cout << "Table vectors: " << tb_schema.vector_column_array.size() << std::endl; std::cout << "Table attributes: " << tb_schema.attribute_column_array.size() << std::endl; std::cout << "Table partitions: " << tb_schema.partition_column_name_array.size() << std::endl; - std::cout << "===========================================" << std::endl; + BLOCK_SPLITER + } + + void PrintRecordIdArray(const std::vector& record_ids) { + BLOCK_SPLITER + std::cout << "Returned id array count: " << record_ids.size() << std::endl; +#if 0 + for(auto id : record_ids) { + std::cout << std::to_string(id) << std::endl; + } +#endif + BLOCK_SPLITER + } + + void PrintSearchResult(const std::vector& topk_query_result_array) { + BLOCK_SPLITER + std::cout << "Returned result count: " << topk_query_result_array.size() << std::endl; + + int32_t index = 0; + for(auto& result : topk_query_result_array) { + index++; + std::cout << "No." << std::to_string(index) << " vector top " + << std::to_string(result.query_result_arrays.size()) + << " search result:" << std::endl; + for(auto& item : result.query_result_arrays) { + std::cout << "\t" << std::to_string(item.id) << "\tscore:" << std::to_string(item.score); + for(auto& attribute : item.column_map) { + std::cout << "\t" << attribute.first << ":" << attribute.second; + } + std::cout << std::endl; + } + } + + BLOCK_SPLITER } std::string CurrentTime() { @@ -42,8 +78,29 @@ namespace { static const std::string TABLE_NAME = GetTableName(); static const std::string VECTOR_COLUMN_NAME = "face_vector"; + static const std::string AGE_COLUMN_NAME = "age"; + static const std::string CITY_COLUMN_NAME = "city"; static const int64_t TABLE_DIMENSION = 512; + TableSchema BuildTableSchema() { + TableSchema tb_schema; + VectorColumn col1; + col1.name = VECTOR_COLUMN_NAME; + col1.dimension = TABLE_DIMENSION; + col1.store_raw_vector = true; + tb_schema.vector_column_array.emplace_back(col1); + + Column col2 = {ColumnType::int8, AGE_COLUMN_NAME}; + tb_schema.attribute_column_array.emplace_back(col2); + + Column col3 = {ColumnType::int16, CITY_COLUMN_NAME}; + tb_schema.attribute_column_array.emplace_back(col3); + + tb_schema.table_name = TABLE_NAME; + + return tb_schema; + } + void BuildVectors(int64_t from, int64_t to, std::vector* vector_record_array, std::vector* query_record_array) { @@ -58,6 +115,19 @@ namespace { query_record_array->clear(); } + static const std::map CITY_MAP = { + {0, "Beijing"}, + {1, "Shanhai"}, + {2, "Hangzhou"}, + {3, "Guangzhou"}, + {4, "Shenzheng"}, + {5, "Wuhan"}, + {6, "Chengdu"}, + {7, "Chongqin"}, + {8, "Tianjing"}, + {9, "Hongkong"}, + }; + for (int64_t k = from; k < to; k++) { std::vector f_p; @@ -69,12 +139,16 @@ namespace { if(vector_record_array) { RowRecord record; record.vector_map.insert(std::make_pair(VECTOR_COLUMN_NAME, f_p)); + record.attribute_map[AGE_COLUMN_NAME] = std::to_string(k%100); + record.attribute_map[CITY_COLUMN_NAME] = CITY_MAP.at(k%CITY_MAP.size()); vector_record_array->emplace_back(record); } if(query_record_array) { QueryRecord record; record.vector_map.insert(std::make_pair(VECTOR_COLUMN_NAME, f_p)); + record.selected_column_array.push_back(AGE_COLUMN_NAME); + record.selected_column_array.push_back(CITY_COLUMN_NAME); query_record_array->emplace_back(record); } } @@ -87,29 +161,35 @@ ClientTest::Test(const std::string& address, const std::string& port) { ConnectParam param = { address, port }; conn->Connect(param); + {//get server version + std::string version = conn->ServerVersion(); + std::cout << "MegaSearch server version: " << version << std::endl; + } + + { + std::cout << "ShowTables" << std::endl; + std::vector tables; + Status stat = conn->ShowTables(tables); + std::cout << "Function call status: " << stat.ToString() << std::endl; + std::cout << "All tables: " << std::endl; + for(auto& table : tables) { + std::cout << "\t" << table << std::endl; + } + } + {//create table - TableSchema tb_schema; - VectorColumn col1; - col1.name = VECTOR_COLUMN_NAME; - col1.dimension = TABLE_DIMENSION; - col1.store_raw_vector = true; - tb_schema.vector_column_array.emplace_back(col1); - - Column col2; - col2.name = "age"; - tb_schema.attribute_column_array.emplace_back(col2); - - tb_schema.table_name = TABLE_NAME; - + TableSchema tb_schema = BuildTableSchema(); PrintTableSchema(tb_schema); + std::cout << "CreateTable" << std::endl; Status stat = conn->CreateTable(tb_schema); - std::cout << "Create table result: " << stat.ToString() << std::endl; + std::cout << "Function call status: " << stat.ToString() << std::endl; } {//describe table TableSchema tb_schema; + std::cout << "DescribeTable" << std::endl; Status stat = conn->DescribeTable(TABLE_NAME, tb_schema); - std::cout << "Describe table result: " << stat.ToString() << std::endl; + std::cout << "Function call status: " << stat.ToString() << std::endl; PrintTableSchema(tb_schema); } @@ -117,22 +197,23 @@ ClientTest::Test(const std::string& address, const std::string& port) { std::vector record_array; BuildVectors(0, 10000, &record_array, nullptr); std::vector record_ids; - std::cout << "Begin add vectors" << std::endl; + std::cout << "AddVector" << std::endl; Status stat = conn->AddVector(TABLE_NAME, record_array, record_ids); - std::cout << "Add vector result: " << stat.ToString() << std::endl; - std::cout << "Returned vector ids: " << record_ids.size() << std::endl; + std::cout << "Function call status: " << stat.ToString() << std::endl; + PrintRecordIdArray(record_ids); } {//search vectors + std::cout << "Waiting data persist. Sleep 10 seconds ..." << std::endl; sleep(10); std::vector record_array; BuildVectors(500, 510, nullptr, &record_array); std::vector topk_query_result_array; - std::cout << "Begin search vectors" << std::endl; + std::cout << "SearchVector" << std::endl; Status stat = conn->SearchVector(TABLE_NAME, record_array, topk_query_result_array, 10); - std::cout << "Search vector result: " << stat.ToString() << std::endl; - std::cout << "Returned result count: " << topk_query_result_array.size() << std::endl; + std::cout << "Function call status: " << stat.ToString() << std::endl; + PrintSearchResult(topk_query_result_array); } // {//delete table @@ -140,5 +221,13 @@ ClientTest::Test(const std::string& address, const std::string& port) { // std::cout << "Delete table result: " << stat.ToString() << std::endl; // } + {//server status + std::string status = conn->ServerStatus(); + std::cout << "Server status before disconnect: " << status << std::endl; + } Connection::Destroy(conn); + {//server status + std::string status = conn->ServerStatus(); + std::cout << "Server status after disconnect: " << status << std::endl; + } } \ No newline at end of file diff --git a/cpp/src/sdk/src/client/ClientProxy.cpp b/cpp/src/sdk/src/client/ClientProxy.cpp index 90d55ca1e0..28ebad6b84 100644 --- a/cpp/src/sdk/src/client/ClientProxy.cpp +++ b/cpp/src/sdk/src/client/ClientProxy.cpp @@ -21,7 +21,7 @@ ClientProxy::Connect(const ConnectParam ¶m) { Disconnect(); int32_t port = atoi(param.port.c_str()); - return ClientPtr()->Connect(param.ip_address, port, "json"); + return ClientPtr()->Connect(param.ip_address, port, THRIFT_PROTOCOL_BINARY); } Status @@ -58,7 +58,7 @@ ClientProxy::Disconnect() { std::string ClientProxy::ClientVersion() const { - return std::string("Current Version"); + return std::string("v1.0"); } Status diff --git a/cpp/src/sdk/src/client/ThriftClient.cpp b/cpp/src/sdk/src/client/ThriftClient.cpp index d492ee69d4..7d95795b37 100644 --- a/cpp/src/sdk/src/client/ThriftClient.cpp +++ b/cpp/src/sdk/src/client/ThriftClient.cpp @@ -50,14 +50,12 @@ ThriftClient::Connect(const std::string& address, int32_t port, const std::strin stdcxx::shared_ptr socket_ptr(new transport::TSocket(address, port)); stdcxx::shared_ptr transport_ptr(new TBufferedTransport(socket_ptr)); stdcxx::shared_ptr protocol_ptr; - if(protocol == "binary") { + if(protocol == THRIFT_PROTOCOL_BINARY) { protocol_ptr.reset(new TBinaryProtocol(transport_ptr)); - } else if(protocol == "json") { + } else if(protocol == THRIFT_PROTOCOL_JSON) { protocol_ptr.reset(new TJSONProtocol(transport_ptr)); - } else if(protocol == "compact") { + } else if(protocol == THRIFT_PROTOCOL_COMPACT) { protocol_ptr.reset(new TCompactProtocol(transport_ptr)); - } else if(protocol == "debug") { - protocol_ptr.reset(new TDebugProtocol(transport_ptr)); } else { //CLIENT_LOG_ERROR << "Service protocol: " << protocol << " is not supported currently"; return Status(StatusCode::Invalid, "unsupported protocol"); diff --git a/cpp/src/sdk/src/client/ThriftClient.h b/cpp/src/sdk/src/client/ThriftClient.h index ccfca6b47c..4834835edf 100644 --- a/cpp/src/sdk/src/client/ThriftClient.h +++ b/cpp/src/sdk/src/client/ThriftClient.h @@ -14,6 +14,10 @@ namespace megasearch { using MegasearchServiceClientPtr = std::shared_ptr; +static const std::string THRIFT_PROTOCOL_JSON = "json"; +static const std::string THRIFT_PROTOCOL_BINARY = "binary"; +static const std::string THRIFT_PROTOCOL_COMPACT = "compact"; + class ThriftClient { public: ThriftClient(); diff --git a/cpp/src/server/MegasearchHandler.cpp b/cpp/src/server/MegasearchHandler.cpp index 54310d72be..4e7c575606 100644 --- a/cpp/src/server/MegasearchHandler.cpp +++ b/cpp/src/server/MegasearchHandler.cpp @@ -32,14 +32,14 @@ MegasearchServiceHandler::DeleteTable(const std::string &table_name) { void MegasearchServiceHandler::CreateTablePartition(const thrift::CreateTablePartitionParam ¶m) { - // Your implementation goes here - printf("CreateTablePartition\n"); + BaseTaskPtr task_ptr = CreateTablePartitionTask::Create(param); + MegasearchScheduler::ExecTask(task_ptr); } void MegasearchServiceHandler::DeleteTablePartition(const thrift::DeleteTablePartitionParam ¶m) { - // Your implementation goes here - printf("DeleteTablePartition\n"); + BaseTaskPtr task_ptr = DeleteTablePartitionTask::Create(param); + MegasearchScheduler::ExecTask(task_ptr); } void @@ -67,14 +67,14 @@ MegasearchServiceHandler::DescribeTable(thrift::TableSchema &_return, const std: void MegasearchServiceHandler::ShowTables(std::vector &_return) { - // Your implementation goes here - printf("ShowTables\n"); + BaseTaskPtr task_ptr = ShowTablesTask::Create(_return); + MegasearchScheduler::ExecTask(task_ptr); } void MegasearchServiceHandler::Ping(std::string& _return, const std::string& cmd) { - // Your implementation goes here - printf("Ping\n"); + BaseTaskPtr task_ptr = PingTask::Create(cmd, _return); + MegasearchScheduler::ExecTask(task_ptr); } } diff --git a/cpp/src/server/MegasearchServer.cpp b/cpp/src/server/MegasearchServer.cpp index 5940d20bbc..f771fc4dd8 100644 --- a/cpp/src/server/MegasearchServer.cpp +++ b/cpp/src/server/MegasearchServer.cpp @@ -54,7 +54,6 @@ MegasearchServer::StartService() { stdcxx::shared_ptr server_transport(new TServerSocket(address, port)); stdcxx::shared_ptr transport_factory(new TBufferedTransportFactory()); - std::string protocol = "json"; stdcxx::shared_ptr protocol_factory; if (protocol == "binary") { protocol_factory.reset(new TBinaryProtocolFactory()); @@ -62,8 +61,6 @@ MegasearchServer::StartService() { protocol_factory.reset(new TJSONProtocolFactory()); } else if (protocol == "compact") { protocol_factory.reset(new TCompactProtocolFactory()); - } else if (protocol == "debug") { - protocol_factory.reset(new TDebugProtocolFactory()); } else { //SERVER_LOG_INFO << "Service protocol: " << protocol << " is not supported currently"; return; diff --git a/cpp/src/server/MegasearchTask.cpp b/cpp/src/server/MegasearchTask.cpp index 6f01b38ce1..a1df83bcb3 100644 --- a/cpp/src/server/MegasearchTask.cpp +++ b/cpp/src/server/MegasearchTask.cpp @@ -21,6 +21,7 @@ namespace server { static const std::string DQL_TASK_GROUP = "dql"; static const std::string DDL_DML_TASK_GROUP = "ddl_dml"; +static const std::string PING_TASK_GROUP = "ping"; static const std::string VECTOR_UID = "uid"; static const uint64_t USE_MT = 5000; @@ -48,6 +49,10 @@ namespace { } } + ~DBWrapper() { + delete db_; + } + zilliz::vecwise::engine::DB* DB() { return db_; } private: @@ -78,17 +83,17 @@ BaseTaskPtr CreateTableTask::Create(const thrift::TableSchema& schema) { ServerError CreateTableTask::OnExecute() { TimeRecorder rc("CreateTableTask"); - + try { if(schema_.vector_column_array.empty()) { return SERVER_INVALID_ARGUMENT; } IVecIdMapper::GetInstance()->AddGroup(schema_.table_name); - engine::meta::TableSchema table_schema; - table_schema.dimension = (uint16_t)schema_.vector_column_array[0].dimension; - table_schema.table_id = schema_.table_name; - engine::Status stat = DB()->CreateTable(table_schema); + engine::meta::TableSchema table_info; + table_info.dimension = (uint16_t)schema_.vector_column_array[0].dimension; + table_info.table_id = schema_.table_name; + engine::Status stat = DB()->CreateTable(table_info); if(!stat.ok()) {//could exist error_msg_ = "Engine failed: " + stat.ToString(); SERVER_LOG_ERROR << error_msg_; @@ -109,7 +114,7 @@ ServerError CreateTableTask::OnExecute() { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// DescribeTableTask::DescribeTableTask(const std::string &table_name, thrift::TableSchema &schema) - : BaseTask(DDL_DML_TASK_GROUP), + : BaseTask(PING_TASK_GROUP), table_name_(table_name), schema_(schema) { schema_.table_name = table_name_; @@ -123,9 +128,9 @@ ServerError DescribeTableTask::OnExecute() { TimeRecorder rc("DescribeTableTask"); try { - engine::meta::TableSchema table_schema; - table_schema.table_id = table_name_; - engine::Status stat = DB()->DescribeTable(table_schema); + engine::meta::TableSchema table_info; + table_info.table_id = table_name_; + engine::Status stat = DB()->DescribeTable(table_info); if(!stat.ok()) { error_code_ = SERVER_GROUP_NOT_EXIST; error_msg_ = "Engine failed: " + stat.ToString(); @@ -154,8 +159,8 @@ DeleteTableTask::DeleteTableTask(const std::string& table_name) } -BaseTaskPtr DeleteTableTask::Create(const std::string& table_id) { - return std::shared_ptr(new DeleteTableTask(table_id)); +BaseTaskPtr DeleteTableTask::Create(const std::string& group_id) { + return std::shared_ptr(new DeleteTableTask(group_id)); } ServerError DeleteTableTask::OnExecute() { @@ -168,6 +173,60 @@ ServerError DeleteTableTask::OnExecute() { return SERVER_NOT_IMPLEMENT; } +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +CreateTablePartitionTask::CreateTablePartitionTask(const thrift::CreateTablePartitionParam ¶m) + : BaseTask(DDL_DML_TASK_GROUP), + param_(param) { + +} + +BaseTaskPtr CreateTablePartitionTask::Create(const thrift::CreateTablePartitionParam ¶m) { + return std::shared_ptr(new CreateTablePartitionTask(param)); +} + +ServerError CreateTablePartitionTask::OnExecute() { + error_code_ = SERVER_NOT_IMPLEMENT; + error_msg_ = "create table partition not implemented"; + SERVER_LOG_ERROR << error_msg_; + + return SERVER_NOT_IMPLEMENT; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +DeleteTablePartitionTask::DeleteTablePartitionTask(const thrift::DeleteTablePartitionParam ¶m) + : BaseTask(DDL_DML_TASK_GROUP), + param_(param) { + +} + +BaseTaskPtr DeleteTablePartitionTask::Create(const thrift::DeleteTablePartitionParam ¶m) { + return std::shared_ptr(new DeleteTablePartitionTask(param)); +} + +ServerError DeleteTablePartitionTask::OnExecute() { + error_code_ = SERVER_NOT_IMPLEMENT; + error_msg_ = "delete table partition not implemented"; + SERVER_LOG_ERROR << error_msg_; + + return SERVER_NOT_IMPLEMENT; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +ShowTablesTask::ShowTablesTask(std::vector& tables) + : BaseTask(PING_TASK_GROUP), + tables_(tables) { + +} + +BaseTaskPtr ShowTablesTask::Create(std::vector& tables) { + return std::shared_ptr(new ShowTablesTask(tables)); +} + +ServerError ShowTablesTask::OnExecute() { + IVecIdMapper::GetInstance()->AllGroups(tables_); + + return SERVER_SUCCESS; +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// AddVectorTask::AddVectorTask(const std::string& table_name, @@ -195,9 +254,9 @@ ServerError AddVectorTask::OnExecute() { return SERVER_SUCCESS; } - engine::meta::TableSchema table_schema; - table_schema.table_id = table_name_; - engine::Status stat = DB()->DescribeTable(table_schema); + engine::meta::TableSchema table_info; + table_info.table_id = table_name_; + engine::Status stat = DB()->DescribeTable(table_info); if(!stat.ok()) { error_code_ = SERVER_GROUP_NOT_EXIST; error_msg_ = "Engine failed: " + stat.ToString(); @@ -208,7 +267,7 @@ ServerError AddVectorTask::OnExecute() { rc.Record("get group info"); uint64_t vec_count = (uint64_t)record_array_.size(); - uint64_t group_dim = table_schema.dimension; + uint64_t group_dim = table_info.dimension; std::vector vec_f; vec_f.resize(vec_count*group_dim);//allocate enough memory for(uint64_t i = 0; i < vec_count; i++) { @@ -228,6 +287,7 @@ ServerError AddVectorTask::OnExecute() { return error_code_; } + //convert double array to float array(thrift has no float type) const double* d_p = reinterpret_cast(record.vector_map.begin()->second.data()); for(uint64_t d = 0; d < vec_dim; d++) { vec_f[i*vec_dim + d] = (float)(d_p[d]); @@ -245,12 +305,27 @@ ServerError AddVectorTask::OnExecute() { return error_code_; } - if(record_ids_.size() < vec_count) { + if(record_ids_.size() != vec_count) { SERVER_LOG_ERROR << "Vector ID not returned"; return SERVER_UNEXPECTED_ERROR; } - rc.Record("done"); + //persist attributes + for(uint64_t i = 0; i < vec_count; i++) { + const auto &record = record_array_[i]; + + //any attributes? + if(record.attribute_map.empty()) { + continue; + } + + std::string nid = std::to_string(record_ids_[i]); + std::string attrib_str; + AttributeSerializer::Encode(record.attribute_map, attrib_str); + IVecIdMapper::GetInstance()->Put(nid, attrib_str, table_name_); + } + + rc.Record("persist vector attributes"); } catch (std::exception& ex) { error_code_ = SERVER_UNEXPECTED_ERROR; @@ -293,9 +368,9 @@ ServerError SearchVectorTask::OnExecute() { return error_code_; } - engine::meta::TableSchema table_schema; - table_schema.table_id = table_name_; - engine::Status stat = DB()->DescribeTable(table_schema); + engine::meta::TableSchema table_info; + table_info.table_id = table_name_; + engine::Status stat = DB()->DescribeTable(table_info); if(!stat.ok()) { error_code_ = SERVER_GROUP_NOT_EXIST; error_msg_ = "Engine failed: " + stat.ToString(); @@ -305,7 +380,7 @@ ServerError SearchVectorTask::OnExecute() { std::vector vec_f; uint64_t record_count = (uint64_t)record_array_.size(); - vec_f.resize(record_count*table_schema.dimension); + vec_f.resize(record_count*table_info.dimension); for(uint64_t i = 0; i < record_array_.size(); i++) { const auto& record = record_array_[i]; @@ -317,14 +392,15 @@ ServerError SearchVectorTask::OnExecute() { } uint64_t vec_dim = record.vector_map.begin()->second.size() / sizeof(double);//how many double value? - if (vec_dim != table_schema.dimension) { + if (vec_dim != table_info.dimension) { SERVER_LOG_ERROR << "Invalid vector dimension: " << vec_dim - << " vs. group dimension:" << table_schema.dimension; + << " vs. group dimension:" << table_info.dimension; error_code_ = SERVER_INVALID_VECTOR_DIMENSION; error_msg_ = "Engine failed: " + stat.ToString(); return error_code_; } + //convert double array to float array(thrift has no float type) const double* d_p = reinterpret_cast(record.vector_map.begin()->second.data()); for(uint64_t d = 0; d < vec_dim; d++) { vec_f[i*vec_dim + d] = (float)(d_p[d]); @@ -336,25 +412,50 @@ ServerError SearchVectorTask::OnExecute() { std::vector dates; engine::QueryResults results; stat = DB()->Query(table_name_, (size_t)top_k_, record_count, vec_f.data(), dates, results); + rc.Record("search vectors from engine"); if(!stat.ok()) { SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); return SERVER_UNEXPECTED_ERROR; - } else { - rc.Record("do searching"); - for(engine::QueryResult& result : results){ - thrift::TopKQueryResult thrift_topk_result; - for(auto id : result) { - thrift::QueryResult thrift_result; - thrift_result.__set_id(id); - thrift_topk_result.query_result_arrays.emplace_back(thrift_result); - } - - result_array_.emplace_back(thrift_topk_result); - } - rc.Record("construct result"); } - rc.Record("done"); + if(results.size() != record_count) { + SERVER_LOG_ERROR << "Search result not returned"; + return SERVER_UNEXPECTED_ERROR; + } + + //construct result array + for(uint64_t i = 0; i < record_count; i++) { + auto& result = results[i]; + const auto& record = record_array_[i]; + + thrift::TopKQueryResult thrift_topk_result; + for(auto id : result) { + thrift::QueryResult thrift_result; + thrift_result.__set_id(id); + + //need get attributes? + if(record.selected_column_array.empty()) { + thrift_topk_result.query_result_arrays.emplace_back(thrift_result); + continue; + } + + std::string nid = std::to_string(id); + std::string attrib_str; + IVecIdMapper::GetInstance()->Get(nid, attrib_str, table_name_); + + AttribMap attrib_map; + AttributeSerializer::Decode(attrib_str, attrib_map); + + for(auto& attribute : record.selected_column_array) { + thrift_result.column_map[attribute] = attrib_map[attribute]; + } + + thrift_topk_result.query_result_arrays.emplace_back(thrift_result); + } + + result_array_.emplace_back(thrift_topk_result); + } + rc.Record("construct result"); } catch (std::exception& ex) { error_code_ = SERVER_UNEXPECTED_ERROR; @@ -366,6 +467,26 @@ ServerError SearchVectorTask::OnExecute() { return SERVER_SUCCESS; } +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +PingTask::PingTask(const std::string& cmd, std::string& result) + : BaseTask(PING_TASK_GROUP), + cmd_(cmd), + result_(result) { + +} + +BaseTaskPtr PingTask::Create(const std::string& cmd, std::string& result) { + return std::shared_ptr(new PingTask(cmd, result)); +} + +ServerError PingTask::OnExecute() { + if(cmd_ == "version") { + result_ = "v1.2.0";//currently hardcode + } + + return SERVER_SUCCESS; +} + } } } diff --git a/cpp/src/server/MegasearchTask.h b/cpp/src/server/MegasearchTask.h index 26c23e1b26..af4178eb21 100644 --- a/cpp/src/server/MegasearchTask.h +++ b/cpp/src/server/MegasearchTask.h @@ -65,6 +65,50 @@ private: std::string table_name_; }; +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class CreateTablePartitionTask : public BaseTask { +public: + static BaseTaskPtr Create(const thrift::CreateTablePartitionParam ¶m); + +protected: + CreateTablePartitionTask(const thrift::CreateTablePartitionParam ¶m); + + ServerError OnExecute() override; + + +private: + const thrift::CreateTablePartitionParam ¶m_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class DeleteTablePartitionTask : public BaseTask { +public: + static BaseTaskPtr Create(const thrift::DeleteTablePartitionParam ¶m); + +protected: + DeleteTablePartitionTask(const thrift::DeleteTablePartitionParam ¶m); + + ServerError OnExecute() override; + + +private: + const thrift::DeleteTablePartitionParam ¶m_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class ShowTablesTask : public BaseTask { +public: + static BaseTaskPtr Create(std::vector& tables); + +protected: + ShowTablesTask(std::vector& tables); + + ServerError OnExecute() override; + +private: + std::vector& tables_; +}; + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class AddVectorTask : public BaseTask { public: @@ -108,6 +152,21 @@ private: std::vector& result_array_; }; +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class PingTask : public BaseTask { +public: + static BaseTaskPtr Create(const std::string& cmd, std::string& result); + +protected: + PingTask(const std::string& cmd, std::string& result); + + ServerError OnExecute() override; + +private: + std::string cmd_; + std::string& result_; +}; + } } } \ No newline at end of file diff --git a/cpp/src/server/RocksIdMapper.cpp b/cpp/src/server/RocksIdMapper.cpp index 6eec10c60d..37c67d88ec 100644 --- a/cpp/src/server/RocksIdMapper.cpp +++ b/cpp/src/server/RocksIdMapper.cpp @@ -108,6 +108,19 @@ bool RocksIdMapper::IsGroupExist(const std::string& group) const { return IsGroupExistInternal(group); } +ServerError RocksIdMapper::AllGroups(std::vector& groups) const { + groups.clear(); + + std::lock_guard lck(db_mutex_); + for(auto& pair : column_handles_) { + if(pair.first == ROCKSDB_DEFAULT_GROUP) { + continue; + } + groups.push_back(pair.first); + } + + return SERVER_SUCCESS; +} ServerError RocksIdMapper::Put(const std::string& nid, const std::string& sid, const std::string& group) { std::lock_guard lck(db_mutex_); diff --git a/cpp/src/server/RocksIdMapper.h b/cpp/src/server/RocksIdMapper.h index 1ffee7f335..714a4ef47b 100644 --- a/cpp/src/server/RocksIdMapper.h +++ b/cpp/src/server/RocksIdMapper.h @@ -26,6 +26,7 @@ class RocksIdMapper : public IVecIdMapper{ ServerError AddGroup(const std::string& group) override; bool IsGroupExist(const std::string& group) const override; + ServerError AllGroups(std::vector& groups) const override; ServerError Put(const std::string& nid, const std::string& sid, const std::string& group = "") override; ServerError Put(const std::vector& nid, const std::vector& sid, const std::string& group = "") override; diff --git a/cpp/src/server/VecIdMapper.cpp b/cpp/src/server/VecIdMapper.cpp index d9de2ca3eb..df24a9c729 100644 --- a/cpp/src/server/VecIdMapper.cpp +++ b/cpp/src/server/VecIdMapper.cpp @@ -52,6 +52,15 @@ SimpleIdMapper::IsGroupExist(const std::string& group) const { return id_groups_.count(group) > 0; } +ServerError SimpleIdMapper::AllGroups(std::vector& groups) const { + groups.clear(); + + for(auto& pair : id_groups_) { + groups.push_back(pair.first); + } + + return SERVER_SUCCESS; +} //not thread-safe ServerError SimpleIdMapper::Put(const std::string& nid, const std::string& sid, const std::string& group) { diff --git a/cpp/src/server/VecIdMapper.h b/cpp/src/server/VecIdMapper.h index f3c2bdde27..7dcdc1dd2a 100644 --- a/cpp/src/server/VecIdMapper.h +++ b/cpp/src/server/VecIdMapper.h @@ -27,6 +27,7 @@ public: virtual ServerError AddGroup(const std::string& group) = 0; virtual bool IsGroupExist(const std::string& group) const = 0; + virtual ServerError AllGroups(std::vector& groups) const = 0; virtual ServerError Put(const std::string& nid, const std::string& sid, const std::string& group = "") = 0; virtual ServerError Put(const std::vector& nid, const std::vector& sid, const std::string& group = "") = 0; @@ -46,6 +47,7 @@ public: ServerError AddGroup(const std::string& group) override; bool IsGroupExist(const std::string& group) const override; + ServerError AllGroups(std::vector& groups) const override; ServerError Put(const std::string& nid, const std::string& sid, const std::string& group = "") override; ServerError Put(const std::vector& nid, const std::vector& sid, const std::string& group = "") override;