change test_client to unittest

Former-commit-id: 8efbb1314e6cebd5c36643cca0d31668a8be639d
pull/191/head
groot 2019-04-28 12:42:04 +08:00
parent 95f3900910
commit 8e23d2eb66
4 changed files with 227 additions and 167 deletions

View File

@ -33,6 +33,11 @@ link_directories(
"${VECWISE_THIRD_PARTY_BUILD}/lib"
)
set(unittest_libs
gtest_main
gmock_main
pthread)
set(client_libs
yaml-cpp
boost_system
@ -44,7 +49,7 @@ set(client_libs
include_directories(/usr/local/cuda/include)
find_library(cuda_library cudart cublas HINTS /usr/local/cuda/lib64)
target_link_libraries(test_client ${client_libs} ${cuda_library})
target_link_libraries(test_client ${unittest_libs} ${client_libs} ${cuda_library})
#add_executable(skeleton_server
# ../src/thrift/gen-cpp/VecService_server.skeleton.cpp

View File

@ -8,6 +8,8 @@
#include <libgen.h>
#include <cstring>
#include <string>
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <easylogging++.h>
#include "src/ClientApp.h"
@ -26,20 +28,14 @@ main(int argc, char *argv[]) {
// return 0;
std::string app_name = basename(argv[0]);
static struct option long_options[] = {{"conf_file", required_argument, 0, 'c'},
static struct option long_options[] = {{"conf_file", optional_argument, 0, 'c'},
{"help", no_argument, 0, 'h'},
{NULL, 0, 0, 0}};
int option_index = 0;
std::string config_filename;
std::string config_filename = "../../conf/server_config.yaml";
app_name = argv[0];
if(argc < 2) {
print_help(app_name);
printf("Client exit...\n");
return EXIT_FAILURE;
}
int value;
while ((value = getopt_long(argc, argv, "c:p:dh", long_options, &option_index)) != -1) {
switch (value) {
@ -64,8 +60,8 @@ main(int argc, char *argv[]) {
zilliz::vecwise::client::ClientApp app;
app.Run(config_filename);
printf("Client exit...\n");
return 0;
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
void

View File

@ -3,177 +3,21 @@
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include <utils/TimeRecorder.h>
#include "ClientApp.h"
#include "ClientSession.h"
#include "server/ServerConfig.h"
#include "Log.h"
#include <time.h>
namespace zilliz {
namespace vecwise {
namespace client {
namespace {
std::string CurrentTime() {
time_t tt;
time( &tt );
tt = tt + 8*3600;
tm* t= gmtime( &tt );
std::string str = std::to_string(t->tm_year + 1900) + "_" + std::to_string(t->tm_mon + 1)
+ "_" + std::to_string(t->tm_mday) + "_" + std::to_string(t->tm_hour)
+ "_" + std::to_string(t->tm_min) + "_" + std::to_string(t->tm_sec);
return str;
}
}
void ClientApp::Run(const std::string &config_file) {
server::ServerConfig& config = server::ServerConfig::GetInstance();
config.LoadConfigFile(config_file);
CLIENT_LOG_INFO << "Load config file:" << config_file;
server::ConfigNode server_config = config.GetConfig(server::CONFIG_SERVER);
std::string address = server_config.GetValue(server::CONFIG_SERVER_ADDRESS, "127.0.0.1");
int32_t port = server_config.GetInt32Value(server::CONFIG_SERVER_PORT, 33001);
std::string protocol = server_config.GetValue(server::CONFIG_SERVER_PROTOCOL, "binary");
//std::string mode = server_config.GetValue(server::CONFIG_SERVER_MODE, "thread_pool");
int32_t flush_interval = server_config.GetInt32Value(server::CONFIG_SERVER_DB_FLUSH_INTERVAL);
CLIENT_LOG_INFO << "Connect to server: " << address << ":" << std::to_string(port);
try {
ClientSession session(address, port, protocol);
//add group
const int32_t dim = 256;
VecGroup group;
group.id = CurrentTime();
group.dimension = dim;
group.index_type = 0;
session.interface()->add_group(group);
//prepare data
const int64_t count = 10000;
VecTensorList tensor_list;
VecBinaryTensorList bin_tensor_list;
for (int64_t k = 0; k < count; k++) {
VecTensor tensor;
tensor.tensor.reserve(dim);
VecBinaryTensor bin_tensor;
bin_tensor.tensor.resize(dim*sizeof(double));
double* d_p = (double*)(const_cast<char*>(bin_tensor.tensor.data()));
for (int32_t i = 0; i < dim; i++) {
double val = (double)(i + k);
tensor.tensor.push_back(val);
d_p[i] = val;
}
tensor.uid = "normal_vec_" + std::to_string(k);
tensor_list.tensor_list.emplace_back(tensor);
bin_tensor.uid = "binary_vec_" + std::to_string(k);
bin_tensor_list.tensor_list.emplace_back(bin_tensor);
}
//add vectors one by one
{
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one");
for (int64_t k = 0; k < count; k++) {
session.interface()->add_vector(group.id, tensor_list.tensor_list[k]);
if(k%1000 == 0) {
CLIENT_LOG_INFO << "add normal vector no." << k;
}
}
rc.Elapse("done!");
}
//add vectors in one batch
{
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch");
session.interface()->add_vector_batch(group.id, tensor_list);
rc.Elapse("done!");
}
//add binary vectors one by one
{
server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors one by one");
for (int64_t k = 0; k < count; k++) {
session.interface()->add_binary_vector(group.id, bin_tensor_list.tensor_list[k]);
if(k%1000 == 0) {
CLIENT_LOG_INFO << "add binary vector no." << k;
}
}
rc.Elapse("done!");
}
//add binary vectors in one batch
{
server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors in one batch");
session.interface()->add_binary_vector_batch(group.id, bin_tensor_list);
rc.Elapse("done!");
}
std::cout << "Sleep " << flush_interval << " seconds..." << std::endl;
sleep(flush_interval);
//search vector
{
server::TimeRecorder rc("Search top_k");
VecTensor tensor;
for (int32_t i = 0; i < dim; i++) {
tensor.tensor.push_back((double) (i + 100));
}
VecSearchResult res;
VecTimeRangeList range;
session.interface()->search_vector(res, group.id, 10, tensor, range);
std::cout << "Search result: " << std::endl;
for(auto id : res.id_list) {
std::cout << "\t" << id << std::endl;
}
rc.Elapse("done!");
}
//search binary vector
{
server::TimeRecorder rc("Search binary batch top_k");
VecBinaryTensorList tensor_list;
for(int32_t k = 350; k < 360; k++) {
VecBinaryTensor bin_tensor;
bin_tensor.tensor.resize(dim * sizeof(double));
double* d_p = new double[dim];
for (int32_t i = 0; i < dim; i++) {
d_p[i] = (double)(i + k);
}
memcpy(const_cast<char*>(bin_tensor.tensor.data()), d_p, dim * sizeof(double));
tensor_list.tensor_list.emplace_back(bin_tensor);
}
VecSearchResultList res;
VecTimeRangeList range;
session.interface()->search_binary_vector_batch(res, group.id, 5, tensor_list, range);
std::cout << "Search binary batch result: " << std::endl;
for(size_t i = 0 ; i < res.result_list.size(); i++) {
std::cout << "No " << i << ":" << std::endl;
for(auto id : res.result_list[i].id_list) {
std::cout << "\t" << id << std::endl;
}
}
rc.Elapse("done!");
}
} catch (std::exception& ex) {
CLIENT_LOG_ERROR << "request encounter exception: " << ex.what();
}
CLIENT_LOG_INFO << "Test finished";
}
}

View File

@ -0,0 +1,215 @@
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#include <gtest/gtest.h>
#include <utils/TimeRecorder.h>
#include "ClientApp.h"
#include "ClientSession.h"
#include "server/ServerConfig.h"
#include "Log.h"
#include <time.h>
using namespace zilliz::vecwise;
namespace {
static const int32_t VEC_DIMENSION = 256;
std::string CurrentTime() {
time_t tt;
time( &tt );
tt = tt + 8*3600;
tm* t= gmtime( &tt );
std::string str = std::to_string(t->tm_year + 1900) + "_" + std::to_string(t->tm_mon + 1)
+ "_" + std::to_string(t->tm_mday) + "_" + std::to_string(t->tm_hour)
+ "_" + std::to_string(t->tm_min) + "_" + std::to_string(t->tm_sec);
return str;
}
void GetServerAddress(std::string& address, int32_t& port, std::string& protocol) {
server::ServerConfig& config = server::ServerConfig::GetInstance();
server::ConfigNode server_config = config.GetConfig(server::CONFIG_SERVER);
address = server_config.GetValue(server::CONFIG_SERVER_ADDRESS, "127.0.0.1");
port = server_config.GetInt32Value(server::CONFIG_SERVER_PORT, 33001);
protocol = server_config.GetValue(server::CONFIG_SERVER_PROTOCOL, "binary");
//std::string mode = server_config.GetValue(server::CONFIG_SERVER_MODE, "thread_pool");
}
int32_t GetFlushInterval() {
server::ServerConfig& config = server::ServerConfig::GetInstance();
server::ConfigNode server_config = config.GetConfig(server::CONFIG_SERVER);
return server_config.GetInt32Value(server::CONFIG_SERVER_DB_FLUSH_INTERVAL);
}
std::string GetGroupID() {
static std::string s_id(CurrentTime());
return s_id;
}
}
TEST(AddVector, CLIENT_TEST) {
try {
std::string address, protocol;
int32_t port = 0;
GetServerAddress(address, port, protocol);
client::ClientSession session(address, port, protocol);
//add group
VecGroup group;
group.id = GetGroupID();
group.dimension = VEC_DIMENSION;
group.index_type = 0;
session.interface()->add_group(group);
//prepare data
const int64_t count = 10000;
VecTensorList tensor_list;
VecBinaryTensorList bin_tensor_list;
for (int64_t k = 0; k < count; k++) {
VecTensor tensor;
tensor.tensor.reserve(VEC_DIMENSION);
VecBinaryTensor bin_tensor;
bin_tensor.tensor.resize(VEC_DIMENSION * sizeof(double));
double *d_p = (double *) (const_cast<char *>(bin_tensor.tensor.data()));
for (int32_t i = 0; i < VEC_DIMENSION; i++) {
double val = (double) (i + k);
tensor.tensor.push_back(val);
d_p[i] = val;
}
tensor.uid = "normal_vec_" + std::to_string(k);
tensor_list.tensor_list.emplace_back(tensor);
bin_tensor.uid = "binary_vec_" + std::to_string(k);
bin_tensor_list.tensor_list.emplace_back(bin_tensor);
}
//add vectors one by one
{
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one");
for (int64_t k = 0; k < count; k++) {
session.interface()->add_vector(group.id, tensor_list.tensor_list[k]);
if (k % 1000 == 0) {
CLIENT_LOG_INFO << "add normal vector no." << k;
}
}
rc.Elapse("done!");
}
//add vectors in one batch
{
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch");
session.interface()->add_vector_batch(group.id, tensor_list);
rc.Elapse("done!");
}
//add binary vectors one by one
{
server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors one by one");
for (int64_t k = 0; k < count; k++) {
session.interface()->add_binary_vector(group.id, bin_tensor_list.tensor_list[k]);
if (k % 1000 == 0) {
CLIENT_LOG_INFO << "add binary vector no." << k;
}
}
rc.Elapse("done!");
}
//add binary vectors in one batch
{
server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors in one batch");
session.interface()->add_binary_vector_batch(group.id, bin_tensor_list);
rc.Elapse("done!");
}
} catch (std::exception &ex) {
CLIENT_LOG_ERROR << "request encounter exception: " << ex.what();
ASSERT_TRUE(false);
}
}
TEST(SearchVector, CLIENT_TEST) {
std::cout << "Sleep " << GetFlushInterval() << " seconds..." << std::endl;
sleep(GetFlushInterval());
try {
std::string address, protocol;
int32_t port = 0;
GetServerAddress(address, port, protocol);
client::ClientSession session(address, port, protocol);
//search vector
{
const int32_t anchor_index = 100;
const int64_t top_k = 10;
server::TimeRecorder rc("Search top_k");
VecTensor tensor;
for (int32_t i = 0; i < VEC_DIMENSION; i++) {
tensor.tensor.push_back((double) (i + anchor_index));
}
VecSearchResult res;
VecTimeRangeList range;
session.interface()->search_vector(res, GetGroupID(), top_k, tensor, range);
std::cout << "Search result: " << std::endl;
for(auto id : res.id_list) {
std::cout << "\t" << id << std::endl;
}
rc.Elapse("done!");
ASSERT_EQ(res.id_list.size(), (uint64_t)top_k);
if(!res.id_list.empty()) {
ASSERT_TRUE(res.id_list[0].find(std::to_string(anchor_index)) != std::string::npos);
}
}
//search binary vector
{
const int32_t anchor_index = 100;
const int32_t search_count = 10;
const int64_t top_k = 10;
server::TimeRecorder rc("Search binary batch top_k");
VecBinaryTensorList tensor_list;
for(int32_t k = anchor_index; k < anchor_index + search_count; k++) {
VecBinaryTensor bin_tensor;
bin_tensor.tensor.resize(VEC_DIMENSION * sizeof(double));
double* d_p = new double[VEC_DIMENSION];
for (int32_t i = 0; i < VEC_DIMENSION; i++) {
d_p[i] = (double)(i + k);
}
memcpy(const_cast<char*>(bin_tensor.tensor.data()), d_p, VEC_DIMENSION * sizeof(double));
tensor_list.tensor_list.emplace_back(bin_tensor);
}
VecSearchResultList res;
VecTimeRangeList range;
session.interface()->search_binary_vector_batch(res, GetGroupID(), top_k, tensor_list, range);
std::cout << "Search binary batch result: " << std::endl;
for(size_t i = 0 ; i < res.result_list.size(); i++) {
std::cout << "No " << i << ":" << std::endl;
for(auto id : res.result_list[i].id_list) {
std::cout << "\t" << id << std::endl;
}
}
rc.Elapse("done!");
ASSERT_EQ(res.result_list.size(), search_count);
for(size_t i = 0 ; i < res.result_list.size(); i++) {
ASSERT_EQ(res.result_list[i].id_list.size(), (uint64_t) top_k);
if (!res.result_list[i].id_list.empty()) {
ASSERT_TRUE(res.result_list[i].id_list[0].find(std::to_string(anchor_index + i)) != std::string::npos);
}
}
}
} catch (std::exception& ex) {
CLIENT_LOG_ERROR << "request encounter exception: " << ex.what();
ASSERT_TRUE(false);
}
}