mirror of https://github.com/milvus-io/milvus.git
change test_client to unittest
Former-commit-id: 8efbb1314e6cebd5c36643cca0d31668a8be639dpull/191/head
parent
95f3900910
commit
8e23d2eb66
|
@ -33,6 +33,11 @@ link_directories(
|
|||
"${VECWISE_THIRD_PARTY_BUILD}/lib"
|
||||
)
|
||||
|
||||
set(unittest_libs
|
||||
gtest_main
|
||||
gmock_main
|
||||
pthread)
|
||||
|
||||
set(client_libs
|
||||
yaml-cpp
|
||||
boost_system
|
||||
|
@ -44,7 +49,7 @@ set(client_libs
|
|||
include_directories(/usr/local/cuda/include)
|
||||
find_library(cuda_library cudart cublas HINTS /usr/local/cuda/lib64)
|
||||
|
||||
target_link_libraries(test_client ${client_libs} ${cuda_library})
|
||||
target_link_libraries(test_client ${unittest_libs} ${client_libs} ${cuda_library})
|
||||
|
||||
#add_executable(skeleton_server
|
||||
# ../src/thrift/gen-cpp/VecService_server.skeleton.cpp
|
||||
|
|
|
@ -8,6 +8,8 @@
|
|||
#include <libgen.h>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <easylogging++.h>
|
||||
|
||||
#include "src/ClientApp.h"
|
||||
|
@ -26,20 +28,14 @@ main(int argc, char *argv[]) {
|
|||
// return 0;
|
||||
|
||||
std::string app_name = basename(argv[0]);
|
||||
static struct option long_options[] = {{"conf_file", required_argument, 0, 'c'},
|
||||
static struct option long_options[] = {{"conf_file", optional_argument, 0, 'c'},
|
||||
{"help", no_argument, 0, 'h'},
|
||||
{NULL, 0, 0, 0}};
|
||||
|
||||
int option_index = 0;
|
||||
std::string config_filename;
|
||||
std::string config_filename = "../../conf/server_config.yaml";
|
||||
app_name = argv[0];
|
||||
|
||||
if(argc < 2) {
|
||||
print_help(app_name);
|
||||
printf("Client exit...\n");
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
int value;
|
||||
while ((value = getopt_long(argc, argv, "c:p:dh", long_options, &option_index)) != -1) {
|
||||
switch (value) {
|
||||
|
@ -64,8 +60,8 @@ main(int argc, char *argv[]) {
|
|||
zilliz::vecwise::client::ClientApp app;
|
||||
app.Run(config_filename);
|
||||
|
||||
printf("Client exit...\n");
|
||||
return 0;
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
||||
void
|
||||
|
|
|
@ -3,177 +3,21 @@
|
|||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include <utils/TimeRecorder.h>
|
||||
#include "ClientApp.h"
|
||||
#include "ClientSession.h"
|
||||
#include "server/ServerConfig.h"
|
||||
#include "Log.h"
|
||||
|
||||
#include <time.h>
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace client {
|
||||
|
||||
namespace {
|
||||
std::string CurrentTime() {
|
||||
time_t tt;
|
||||
time( &tt );
|
||||
tt = tt + 8*3600;
|
||||
tm* t= gmtime( &tt );
|
||||
|
||||
std::string str = std::to_string(t->tm_year + 1900) + "_" + std::to_string(t->tm_mon + 1)
|
||||
+ "_" + std::to_string(t->tm_mday) + "_" + std::to_string(t->tm_hour)
|
||||
+ "_" + std::to_string(t->tm_min) + "_" + std::to_string(t->tm_sec);
|
||||
|
||||
return str;
|
||||
}
|
||||
}
|
||||
|
||||
void ClientApp::Run(const std::string &config_file) {
|
||||
server::ServerConfig& config = server::ServerConfig::GetInstance();
|
||||
config.LoadConfigFile(config_file);
|
||||
|
||||
CLIENT_LOG_INFO << "Load config file:" << config_file;
|
||||
|
||||
server::ConfigNode server_config = config.GetConfig(server::CONFIG_SERVER);
|
||||
std::string address = server_config.GetValue(server::CONFIG_SERVER_ADDRESS, "127.0.0.1");
|
||||
int32_t port = server_config.GetInt32Value(server::CONFIG_SERVER_PORT, 33001);
|
||||
std::string protocol = server_config.GetValue(server::CONFIG_SERVER_PROTOCOL, "binary");
|
||||
//std::string mode = server_config.GetValue(server::CONFIG_SERVER_MODE, "thread_pool");
|
||||
int32_t flush_interval = server_config.GetInt32Value(server::CONFIG_SERVER_DB_FLUSH_INTERVAL);
|
||||
|
||||
CLIENT_LOG_INFO << "Connect to server: " << address << ":" << std::to_string(port);
|
||||
|
||||
try {
|
||||
ClientSession session(address, port, protocol);
|
||||
|
||||
//add group
|
||||
const int32_t dim = 256;
|
||||
VecGroup group;
|
||||
group.id = CurrentTime();
|
||||
group.dimension = dim;
|
||||
group.index_type = 0;
|
||||
session.interface()->add_group(group);
|
||||
|
||||
//prepare data
|
||||
const int64_t count = 10000;
|
||||
VecTensorList tensor_list;
|
||||
VecBinaryTensorList bin_tensor_list;
|
||||
for (int64_t k = 0; k < count; k++) {
|
||||
VecTensor tensor;
|
||||
tensor.tensor.reserve(dim);
|
||||
VecBinaryTensor bin_tensor;
|
||||
bin_tensor.tensor.resize(dim*sizeof(double));
|
||||
double* d_p = (double*)(const_cast<char*>(bin_tensor.tensor.data()));
|
||||
for (int32_t i = 0; i < dim; i++) {
|
||||
double val = (double)(i + k);
|
||||
tensor.tensor.push_back(val);
|
||||
d_p[i] = val;
|
||||
}
|
||||
|
||||
tensor.uid = "normal_vec_" + std::to_string(k);
|
||||
tensor_list.tensor_list.emplace_back(tensor);
|
||||
|
||||
bin_tensor.uid = "binary_vec_" + std::to_string(k);
|
||||
bin_tensor_list.tensor_list.emplace_back(bin_tensor);
|
||||
}
|
||||
|
||||
//add vectors one by one
|
||||
{
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one");
|
||||
for (int64_t k = 0; k < count; k++) {
|
||||
session.interface()->add_vector(group.id, tensor_list.tensor_list[k]);
|
||||
if(k%1000 == 0) {
|
||||
CLIENT_LOG_INFO << "add normal vector no." << k;
|
||||
}
|
||||
}
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
//add vectors in one batch
|
||||
{
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch");
|
||||
session.interface()->add_vector_batch(group.id, tensor_list);
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
//add binary vectors one by one
|
||||
{
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors one by one");
|
||||
for (int64_t k = 0; k < count; k++) {
|
||||
session.interface()->add_binary_vector(group.id, bin_tensor_list.tensor_list[k]);
|
||||
if(k%1000 == 0) {
|
||||
CLIENT_LOG_INFO << "add binary vector no." << k;
|
||||
}
|
||||
}
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
//add binary vectors in one batch
|
||||
{
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors in one batch");
|
||||
session.interface()->add_binary_vector_batch(group.id, bin_tensor_list);
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
std::cout << "Sleep " << flush_interval << " seconds..." << std::endl;
|
||||
sleep(flush_interval);
|
||||
|
||||
//search vector
|
||||
{
|
||||
server::TimeRecorder rc("Search top_k");
|
||||
VecTensor tensor;
|
||||
for (int32_t i = 0; i < dim; i++) {
|
||||
tensor.tensor.push_back((double) (i + 100));
|
||||
}
|
||||
|
||||
VecSearchResult res;
|
||||
VecTimeRangeList range;
|
||||
session.interface()->search_vector(res, group.id, 10, tensor, range);
|
||||
|
||||
std::cout << "Search result: " << std::endl;
|
||||
for(auto id : res.id_list) {
|
||||
std::cout << "\t" << id << std::endl;
|
||||
}
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
//search binary vector
|
||||
{
|
||||
server::TimeRecorder rc("Search binary batch top_k");
|
||||
VecBinaryTensorList tensor_list;
|
||||
for(int32_t k = 350; k < 360; k++) {
|
||||
VecBinaryTensor bin_tensor;
|
||||
bin_tensor.tensor.resize(dim * sizeof(double));
|
||||
double* d_p = new double[dim];
|
||||
for (int32_t i = 0; i < dim; i++) {
|
||||
d_p[i] = (double)(i + k);
|
||||
}
|
||||
memcpy(const_cast<char*>(bin_tensor.tensor.data()), d_p, dim * sizeof(double));
|
||||
tensor_list.tensor_list.emplace_back(bin_tensor);
|
||||
}
|
||||
|
||||
VecSearchResultList res;
|
||||
VecTimeRangeList range;
|
||||
session.interface()->search_binary_vector_batch(res, group.id, 5, tensor_list, range);
|
||||
|
||||
std::cout << "Search binary batch result: " << std::endl;
|
||||
for(size_t i = 0 ; i < res.result_list.size(); i++) {
|
||||
std::cout << "No " << i << ":" << std::endl;
|
||||
for(auto id : res.result_list[i].id_list) {
|
||||
std::cout << "\t" << id << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
CLIENT_LOG_ERROR << "request encounter exception: " << ex.what();
|
||||
}
|
||||
|
||||
CLIENT_LOG_INFO << "Test finished";
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,215 @@
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
// Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
// Proprietary and confidential.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
#include <gtest/gtest.h>
|
||||
#include <utils/TimeRecorder.h>
|
||||
#include "ClientApp.h"
|
||||
#include "ClientSession.h"
|
||||
#include "server/ServerConfig.h"
|
||||
#include "Log.h"
|
||||
|
||||
#include <time.h>
|
||||
|
||||
using namespace zilliz::vecwise;
|
||||
|
||||
namespace {
|
||||
static const int32_t VEC_DIMENSION = 256;
|
||||
|
||||
std::string CurrentTime() {
|
||||
time_t tt;
|
||||
time( &tt );
|
||||
tt = tt + 8*3600;
|
||||
tm* t= gmtime( &tt );
|
||||
|
||||
std::string str = std::to_string(t->tm_year + 1900) + "_" + std::to_string(t->tm_mon + 1)
|
||||
+ "_" + std::to_string(t->tm_mday) + "_" + std::to_string(t->tm_hour)
|
||||
+ "_" + std::to_string(t->tm_min) + "_" + std::to_string(t->tm_sec);
|
||||
|
||||
return str;
|
||||
}
|
||||
|
||||
void GetServerAddress(std::string& address, int32_t& port, std::string& protocol) {
|
||||
server::ServerConfig& config = server::ServerConfig::GetInstance();
|
||||
server::ConfigNode server_config = config.GetConfig(server::CONFIG_SERVER);
|
||||
address = server_config.GetValue(server::CONFIG_SERVER_ADDRESS, "127.0.0.1");
|
||||
port = server_config.GetInt32Value(server::CONFIG_SERVER_PORT, 33001);
|
||||
protocol = server_config.GetValue(server::CONFIG_SERVER_PROTOCOL, "binary");
|
||||
//std::string mode = server_config.GetValue(server::CONFIG_SERVER_MODE, "thread_pool");
|
||||
}
|
||||
|
||||
int32_t GetFlushInterval() {
|
||||
server::ServerConfig& config = server::ServerConfig::GetInstance();
|
||||
server::ConfigNode server_config = config.GetConfig(server::CONFIG_SERVER);
|
||||
return server_config.GetInt32Value(server::CONFIG_SERVER_DB_FLUSH_INTERVAL);
|
||||
}
|
||||
|
||||
std::string GetGroupID() {
|
||||
static std::string s_id(CurrentTime());
|
||||
return s_id;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(AddVector, CLIENT_TEST) {
|
||||
try {
|
||||
std::string address, protocol;
|
||||
int32_t port = 0;
|
||||
GetServerAddress(address, port, protocol);
|
||||
client::ClientSession session(address, port, protocol);
|
||||
|
||||
//add group
|
||||
VecGroup group;
|
||||
group.id = GetGroupID();
|
||||
group.dimension = VEC_DIMENSION;
|
||||
group.index_type = 0;
|
||||
session.interface()->add_group(group);
|
||||
|
||||
//prepare data
|
||||
const int64_t count = 10000;
|
||||
VecTensorList tensor_list;
|
||||
VecBinaryTensorList bin_tensor_list;
|
||||
for (int64_t k = 0; k < count; k++) {
|
||||
VecTensor tensor;
|
||||
tensor.tensor.reserve(VEC_DIMENSION);
|
||||
VecBinaryTensor bin_tensor;
|
||||
bin_tensor.tensor.resize(VEC_DIMENSION * sizeof(double));
|
||||
double *d_p = (double *) (const_cast<char *>(bin_tensor.tensor.data()));
|
||||
for (int32_t i = 0; i < VEC_DIMENSION; i++) {
|
||||
double val = (double) (i + k);
|
||||
tensor.tensor.push_back(val);
|
||||
d_p[i] = val;
|
||||
}
|
||||
|
||||
tensor.uid = "normal_vec_" + std::to_string(k);
|
||||
tensor_list.tensor_list.emplace_back(tensor);
|
||||
|
||||
bin_tensor.uid = "binary_vec_" + std::to_string(k);
|
||||
bin_tensor_list.tensor_list.emplace_back(bin_tensor);
|
||||
}
|
||||
|
||||
//add vectors one by one
|
||||
{
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one");
|
||||
for (int64_t k = 0; k < count; k++) {
|
||||
session.interface()->add_vector(group.id, tensor_list.tensor_list[k]);
|
||||
if (k % 1000 == 0) {
|
||||
CLIENT_LOG_INFO << "add normal vector no." << k;
|
||||
}
|
||||
}
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
//add vectors in one batch
|
||||
{
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch");
|
||||
session.interface()->add_vector_batch(group.id, tensor_list);
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
//add binary vectors one by one
|
||||
{
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors one by one");
|
||||
for (int64_t k = 0; k < count; k++) {
|
||||
session.interface()->add_binary_vector(group.id, bin_tensor_list.tensor_list[k]);
|
||||
if (k % 1000 == 0) {
|
||||
CLIENT_LOG_INFO << "add binary vector no." << k;
|
||||
}
|
||||
}
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
//add binary vectors in one batch
|
||||
{
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors in one batch");
|
||||
session.interface()->add_binary_vector_batch(group.id, bin_tensor_list);
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
} catch (std::exception &ex) {
|
||||
CLIENT_LOG_ERROR << "request encounter exception: " << ex.what();
|
||||
ASSERT_TRUE(false);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SearchVector, CLIENT_TEST) {
|
||||
std::cout << "Sleep " << GetFlushInterval() << " seconds..." << std::endl;
|
||||
sleep(GetFlushInterval());
|
||||
|
||||
try {
|
||||
std::string address, protocol;
|
||||
int32_t port = 0;
|
||||
GetServerAddress(address, port, protocol);
|
||||
client::ClientSession session(address, port, protocol);
|
||||
|
||||
//search vector
|
||||
{
|
||||
const int32_t anchor_index = 100;
|
||||
const int64_t top_k = 10;
|
||||
server::TimeRecorder rc("Search top_k");
|
||||
VecTensor tensor;
|
||||
for (int32_t i = 0; i < VEC_DIMENSION; i++) {
|
||||
tensor.tensor.push_back((double) (i + anchor_index));
|
||||
}
|
||||
|
||||
VecSearchResult res;
|
||||
VecTimeRangeList range;
|
||||
session.interface()->search_vector(res, GetGroupID(), top_k, tensor, range);
|
||||
|
||||
std::cout << "Search result: " << std::endl;
|
||||
for(auto id : res.id_list) {
|
||||
std::cout << "\t" << id << std::endl;
|
||||
}
|
||||
rc.Elapse("done!");
|
||||
|
||||
ASSERT_EQ(res.id_list.size(), (uint64_t)top_k);
|
||||
if(!res.id_list.empty()) {
|
||||
ASSERT_TRUE(res.id_list[0].find(std::to_string(anchor_index)) != std::string::npos);
|
||||
}
|
||||
}
|
||||
|
||||
//search binary vector
|
||||
{
|
||||
const int32_t anchor_index = 100;
|
||||
const int32_t search_count = 10;
|
||||
const int64_t top_k = 10;
|
||||
server::TimeRecorder rc("Search binary batch top_k");
|
||||
VecBinaryTensorList tensor_list;
|
||||
for(int32_t k = anchor_index; k < anchor_index + search_count; k++) {
|
||||
VecBinaryTensor bin_tensor;
|
||||
bin_tensor.tensor.resize(VEC_DIMENSION * sizeof(double));
|
||||
double* d_p = new double[VEC_DIMENSION];
|
||||
for (int32_t i = 0; i < VEC_DIMENSION; i++) {
|
||||
d_p[i] = (double)(i + k);
|
||||
}
|
||||
memcpy(const_cast<char*>(bin_tensor.tensor.data()), d_p, VEC_DIMENSION * sizeof(double));
|
||||
tensor_list.tensor_list.emplace_back(bin_tensor);
|
||||
}
|
||||
|
||||
VecSearchResultList res;
|
||||
VecTimeRangeList range;
|
||||
session.interface()->search_binary_vector_batch(res, GetGroupID(), top_k, tensor_list, range);
|
||||
|
||||
std::cout << "Search binary batch result: " << std::endl;
|
||||
for(size_t i = 0 ; i < res.result_list.size(); i++) {
|
||||
std::cout << "No " << i << ":" << std::endl;
|
||||
for(auto id : res.result_list[i].id_list) {
|
||||
std::cout << "\t" << id << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
rc.Elapse("done!");
|
||||
|
||||
ASSERT_EQ(res.result_list.size(), search_count);
|
||||
for(size_t i = 0 ; i < res.result_list.size(); i++) {
|
||||
ASSERT_EQ(res.result_list[i].id_list.size(), (uint64_t) top_k);
|
||||
if (!res.result_list[i].id_list.empty()) {
|
||||
ASSERT_TRUE(res.result_list[i].id_list[0].find(std::to_string(anchor_index + i)) != std::string::npos);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
CLIENT_LOG_ERROR << "request encounter exception: " << ex.what();
|
||||
ASSERT_TRUE(false);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue