Add interface for collect result

Signed-off-by: cai.zhang <cai.zhang@zilliz.com>
pull/4973/head^2
cai.zhang 2020-09-11 10:43:08 +08:00 committed by yefu.chen
parent 7479aedf07
commit 34b670f654
7 changed files with 3642 additions and 1541 deletions

View File

@ -30,7 +30,7 @@ storage:
secretkey: dd
pulsar:
address: localhost
address: 0.0.0.0
port: 6650
proxy:

File diff suppressed because it is too large Load Diff

3292
pkg/message/message.pb.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -63,7 +63,7 @@ include( FetchContent )
set( FETCHCONTENT_BASE_DIR ${MILVUS_BINARY_DIR}/3rdparty_download )
set(FETCHCONTENT_QUIET OFF)
include( ThirdPartyPackages )
find_package(OpenMP REQUIRED)
# **************************** Compiler arguments ****************************
message( STATUS "Building Milvus CPU version" )

View File

@ -2,78 +2,86 @@
#include "pulsar/Result.h"
#include "PartitionPolicy.h"
#include "utils/CommonUtil.h"
#include <omp.h>
namespace milvus::message_client {
std::map<int64_t, std::vector<std::shared_ptr<grpc::QueryResult>>> total_results;
MsgClientV2::MsgClientV2(int64_t client_id, const std::string &service_url, const pulsar::ClientConfiguration &config)
: client_id_(client_id), service_url_(service_url) {}
MsgClientV2::MsgClientV2(int64_t client_id, const std::string &service_url, const uint32_t mut_parallelism, const pulsar::ClientConfiguration &config)
: client_id_(client_id), service_url_(service_url), mut_parallelism_(mut_parallelism) {}
Status MsgClientV2::Init(const std::string &insert_delete,
const std::string &search,
const std::string &time_sync,
const std::string &search_by_id,
const std::string &search_result) {
const std::string &search,
const std::string &time_sync,
const std::string &search_by_id,
const std::string &search_result) {
//create pulsar client
auto pulsar_client = std::make_shared<MsgClient>(service_url_);
//create pulsar producer
ProducerConfiguration producerConfiguration;
producerConfiguration.setPartitionsRoutingMode(ProducerConfiguration::CustomPartition);
producerConfiguration.setMessageRouter(std::make_shared<PartitionPolicy>());
insert_delete_producer_ = std::make_shared<MsgProducer>(pulsar_client, insert_delete, producerConfiguration);
// insert_delete_producer_ = std::make_shared<MsgProducer>(pulsar_client, insert_delete, producerConfiguration);
search_producer_ = std::make_shared<MsgProducer>(pulsar_client, search, producerConfiguration);
search_by_id_producer_ = std::make_shared<MsgProducer>(pulsar_client, search_result, producerConfiguration);
search_by_id_producer_ = std::make_shared<MsgProducer>(pulsar_client, search_by_id, producerConfiguration);
time_sync_producer_ = std::make_shared<MsgProducer>(pulsar_client, time_sync);
for (auto i = 0; i < mut_parallelism_; i++) {
paralle_mut_producers_.emplace_back(std::make_shared<MsgProducer>(pulsar_client,
insert_delete,
producerConfiguration));
}
//create pulsar consumer
std::string subscribe_name = std::to_string(CommonUtil::RandomUINT64());
consumer_ = std::make_shared<MsgConsumer>(pulsar_client, search_result+subscribe_name);
consumer_ = std::make_shared<MsgConsumer>(pulsar_client, search_result + subscribe_name);
auto result = consumer_->subscribe(search_result);
if (result != pulsar::Result::ResultOk) {
return Status(SERVER_UNEXPECTED_ERROR, "Pulsar message client init occur error, " + std::string(pulsar::strResult(result)));
return Status(SERVER_UNEXPECTED_ERROR,
"Pulsar message client init occur error, " + std::string(pulsar::strResult(result)));
}
return Status::OK();
}
int64_t GetQueryNodeNum() {
return 2;
return 1;
}
milvus::grpc::QueryResult Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> &results){
//TODO: QueryNode has only one
int64_t length = results.size();
std::vector<float> all_scores;
std::vector<float> all_distance;
std::vector<grpc::KeyValuePair> all_kv_pairs;
std::vector<int> index(length * results[0]->scores_size());
milvus::grpc::QueryResult Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> &results) {
//TODO: QueryNode has only one
int64_t length = results.size();
std::vector<float> all_scores;
std::vector<float> all_distance;
std::vector<grpc::KeyValuePair> all_kv_pairs;
std::vector<int> index(length * results[0]->scores_size());
for (int n = 0; n < length * results[0]->scores_size(); ++n) {
index[n] = n;
for (int n = 0; n < length * results[0]->scores_size(); ++n) {
index[n] = n;
}
for (int i = 0; i < length; i++) {
for (int j = 0; j < results[i]->scores_size(); j++) {
all_scores.push_back(results[i]->scores()[j]);
all_distance.push_back(results[i]->distances()[j]);
all_kv_pairs.push_back(results[i]->extra_params()[j]);
}
}
for (int i = 0; i < length; i++){
for (int j = 0; j < results[i]->scores_size(); j++){
all_scores.push_back(results[i]->scores()[j]);
all_distance.push_back(results[i]->distances()[j]);
all_kv_pairs.push_back(results[i]->extra_params()[j]);
}
}
for (int k = 0; k < all_distance.size() - 1; ++k) {
for (int l = k + 1; l < all_distance.size(); ++l) {
if (all_distance[l] > all_distance[k]){
float distance_temp = all_distance[k];
all_distance[k] = all_distance[l];
all_distance[l] = distance_temp;
int index_temp = index[k];
index[k] = index[l];
index[l] = index_temp;
}
}
for (int k = 0; k < all_distance.size() - 1; ++k) {
for (int l = k + 1; l < all_distance.size(); ++l) {
if (all_distance[l] > all_distance[k]) {
float distance_temp = all_distance[k];
all_distance[k] = all_distance[l];
all_distance[l] = distance_temp;
int index_temp = index[k];
index[k] = index[l];
index[l] = index_temp;
}
}
}
grpc::QueryResult result;
@ -88,10 +96,10 @@ milvus::grpc::QueryResult Aggregation(std::vector<std::shared_ptr<grpc::QueryRes
result.mutable_extra_params(m)->CopyFrom(all_kv_pairs[index[m]]);
}
result.set_query_id(results[0]->query_id());
result.set_client_id(results[0]->client_id());
result.set_query_id(results[0]->query_id());
result.set_client_id(results[0]->client_id());
return result;
return result;
}
Status MsgClientV2::GetQueryResult(int64_t query_id, milvus::grpc::QueryResult &result) {
@ -114,9 +122,10 @@ Status MsgClientV2::GetQueryResult(int64_t query_id, milvus::grpc::QueryResult &
auto message = std::make_shared<grpc::QueryResult>(search_res_msg);
total_results[message->query_id()].push_back(message);
consumer_->acknowledge(msg);
} else {
return Status(DB_ERROR, "can't parse message which from pulsar!");
}
}
result = Aggregation(total_results[query_id]);
return Status::OK();
}
@ -126,8 +135,12 @@ Status MsgClientV2::SendMutMessage(const milvus::grpc::InsertParam &request, uin
auto row_count = request.rows_data_size();
// TODO: Get the segment from master
int64_t segment = 0;
milvus::grpc::InsertOrDeleteMsg mut_msg;
auto stats = std::vector<pulsar::Result>(ParallelNum);
#pragma omp parallel for default(none), shared(row_count, request, timestamp, segment, stats), num_threads(ParallelNum)
for (auto i = 0; i < row_count; i++) {
milvus::grpc::InsertOrDeleteMsg mut_msg;
int this_thread = omp_get_thread_num();
mut_msg.set_op(milvus::grpc::OpType::INSERT);
mut_msg.set_uid(request.entity_id_array(i));
mut_msg.set_client_id(client_id_);
@ -138,29 +151,40 @@ Status MsgClientV2::SendMutMessage(const milvus::grpc::InsertParam &request, uin
mut_msg.mutable_rows_data()->CopyFrom(request.rows_data(i));
mut_msg.mutable_extra_params()->CopyFrom(request.extra_params());
auto result = insert_delete_producer_->send(mut_msg);
auto result = paralle_mut_producers_[this_thread]->send(mut_msg);
if (result != pulsar::ResultOk) {
// TODO: error code
return Status(DB_ERROR, pulsar::strResult(result));
stats[this_thread] = result;
}
}
for (auto &stat : stats) {
if (stat == pulsar::ResultOk) {
return Status(DB_ERROR, pulsar::strResult(stat));
}
}
return Status::OK();
}
Status MsgClientV2::SendMutMessage(const milvus::grpc::DeleteByIDParam &request, uint64_t timestamp) {
milvus::grpc::InsertOrDeleteMsg mut_msg;
for (auto id: request.id_array()) {
auto stats = std::vector<pulsar::Result>(ParallelNum);
#pragma omp parallel for default(none), shared( request, timestamp, stats), num_threads(ParallelNum)
for (auto i = 0; i < request.id_array_size(); i++) {
milvus::grpc::InsertOrDeleteMsg mut_msg;
mut_msg.set_op(milvus::grpc::OpType::DELETE);
mut_msg.set_uid(GetUniqueQId());
mut_msg.set_client_id(client_id_);
mut_msg.set_uid(id);
mut_msg.set_uid(request.id_array(i));
mut_msg.set_collection_name(request.collection_name());
mut_msg.set_timestamp(timestamp);
auto result = insert_delete_producer_->send(mut_msg);
int this_thread = omp_get_thread_num();
auto result = paralle_mut_producers_[this_thread]->send(mut_msg);
if (result != pulsar::ResultOk) {
// TODO: error code
return Status(DB_ERROR, pulsar::strResult(result));
stats[this_thread] = result;
}
}
for (auto &stat : stats) {
if (stat == pulsar::ResultOk) {
return Status(DB_ERROR, pulsar::strResult(stat));
}
}
return Status::OK();
@ -200,7 +224,8 @@ Status MsgClientV2::SendQueryMessage(const milvus::grpc::SearchParam &request, u
search_msg.mutable_extra_params(l)->CopyFrom(request.extra_params(l));
}
auto result = search_by_id_producer_->send(search_msg);
std::cout << search_msg.collection_name() << std::endl;
auto result = search_producer_->send(search_msg);
if (result != pulsar::Result::ResultOk) {
return Status(DB_ERROR, pulsar::strResult(result));
}
@ -276,7 +301,10 @@ Status MsgClientV2::SendQueryMessage(const milvus::grpc::SearchParam &request, u
}
MsgClientV2::~MsgClientV2() {
insert_delete_producer_->close();
// insert_delete_producer_->close();
for (auto& producer: paralle_mut_producers_){
producer->close();
}
search_producer_->close();
search_by_id_producer_->close();
time_sync_producer_->close();

View File

@ -6,10 +6,13 @@
#include "grpc/message.pb.h"
namespace milvus::message_client {
constexpr uint32_t ParallelNum = 12 * 20;
class MsgClientV2 {
public:
MsgClientV2(int64_t client_id,
const std::string &service_url,
const uint32_t mut_parallelism = ParallelNum,
const pulsar::ClientConfiguration &config = pulsar::ClientConfiguration());
~MsgClientV2();
@ -40,9 +43,11 @@ class MsgClientV2 {
int64_t client_id_;
std::string service_url_;
std::shared_ptr<MsgConsumer> consumer_;
std::shared_ptr<MsgProducer> insert_delete_producer_;
// std::shared_ptr<MsgProducer> insert_delete_producer_;
std::shared_ptr<MsgProducer> search_producer_;
std::shared_ptr<MsgProducer> time_sync_producer_;
std::shared_ptr<MsgProducer> search_by_id_producer_;
std::vector<std::shared_ptr<MsgProducer>> paralle_mut_producers_;
const uint32_t mut_parallelism_;
};
}

View File

@ -78,7 +78,7 @@ TEST(CLIENT_CPP, GetResult) {
// client_v2.SendQueryMessage();
milvus::grpc::SearchParam request;
auto status_send = client_v2.SendQueryMessage(request, query_id);
auto status_send = client_v2.SendQueryMessage(request, 10, query_id);
milvus::grpc::QueryResult result;
auto status = client_v2.GetQueryResult(query_id, result);