mirror of https://github.com/milvus-io/milvus.git
Add interface for collect result
Signed-off-by: cai.zhang <cai.zhang@zilliz.com>pull/4973/head^2
parent
7479aedf07
commit
34b670f654
|
@ -30,7 +30,7 @@ storage:
|
||||||
secretkey: dd
|
secretkey: dd
|
||||||
|
|
||||||
pulsar:
|
pulsar:
|
||||||
address: localhost
|
address: 0.0.0.0
|
||||||
port: 6650
|
port: 6650
|
||||||
|
|
||||||
proxy:
|
proxy:
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -63,7 +63,7 @@ include( FetchContent )
|
||||||
set( FETCHCONTENT_BASE_DIR ${MILVUS_BINARY_DIR}/3rdparty_download )
|
set( FETCHCONTENT_BASE_DIR ${MILVUS_BINARY_DIR}/3rdparty_download )
|
||||||
set(FETCHCONTENT_QUIET OFF)
|
set(FETCHCONTENT_QUIET OFF)
|
||||||
include( ThirdPartyPackages )
|
include( ThirdPartyPackages )
|
||||||
|
find_package(OpenMP REQUIRED)
|
||||||
# **************************** Compiler arguments ****************************
|
# **************************** Compiler arguments ****************************
|
||||||
message( STATUS "Building Milvus CPU version" )
|
message( STATUS "Building Milvus CPU version" )
|
||||||
|
|
||||||
|
|
|
@ -2,78 +2,86 @@
|
||||||
#include "pulsar/Result.h"
|
#include "pulsar/Result.h"
|
||||||
#include "PartitionPolicy.h"
|
#include "PartitionPolicy.h"
|
||||||
#include "utils/CommonUtil.h"
|
#include "utils/CommonUtil.h"
|
||||||
|
#include <omp.h>
|
||||||
|
|
||||||
namespace milvus::message_client {
|
namespace milvus::message_client {
|
||||||
|
|
||||||
std::map<int64_t, std::vector<std::shared_ptr<grpc::QueryResult>>> total_results;
|
std::map<int64_t, std::vector<std::shared_ptr<grpc::QueryResult>>> total_results;
|
||||||
|
|
||||||
MsgClientV2::MsgClientV2(int64_t client_id, const std::string &service_url, const pulsar::ClientConfiguration &config)
|
MsgClientV2::MsgClientV2(int64_t client_id, const std::string &service_url, const uint32_t mut_parallelism, const pulsar::ClientConfiguration &config)
|
||||||
: client_id_(client_id), service_url_(service_url) {}
|
: client_id_(client_id), service_url_(service_url), mut_parallelism_(mut_parallelism) {}
|
||||||
|
|
||||||
Status MsgClientV2::Init(const std::string &insert_delete,
|
Status MsgClientV2::Init(const std::string &insert_delete,
|
||||||
const std::string &search,
|
const std::string &search,
|
||||||
const std::string &time_sync,
|
const std::string &time_sync,
|
||||||
const std::string &search_by_id,
|
const std::string &search_by_id,
|
||||||
const std::string &search_result) {
|
const std::string &search_result) {
|
||||||
//create pulsar client
|
//create pulsar client
|
||||||
auto pulsar_client = std::make_shared<MsgClient>(service_url_);
|
auto pulsar_client = std::make_shared<MsgClient>(service_url_);
|
||||||
//create pulsar producer
|
//create pulsar producer
|
||||||
ProducerConfiguration producerConfiguration;
|
ProducerConfiguration producerConfiguration;
|
||||||
producerConfiguration.setPartitionsRoutingMode(ProducerConfiguration::CustomPartition);
|
producerConfiguration.setPartitionsRoutingMode(ProducerConfiguration::CustomPartition);
|
||||||
producerConfiguration.setMessageRouter(std::make_shared<PartitionPolicy>());
|
producerConfiguration.setMessageRouter(std::make_shared<PartitionPolicy>());
|
||||||
insert_delete_producer_ = std::make_shared<MsgProducer>(pulsar_client, insert_delete, producerConfiguration);
|
// insert_delete_producer_ = std::make_shared<MsgProducer>(pulsar_client, insert_delete, producerConfiguration);
|
||||||
search_producer_ = std::make_shared<MsgProducer>(pulsar_client, search, producerConfiguration);
|
search_producer_ = std::make_shared<MsgProducer>(pulsar_client, search, producerConfiguration);
|
||||||
search_by_id_producer_ = std::make_shared<MsgProducer>(pulsar_client, search_result, producerConfiguration);
|
search_by_id_producer_ = std::make_shared<MsgProducer>(pulsar_client, search_by_id, producerConfiguration);
|
||||||
time_sync_producer_ = std::make_shared<MsgProducer>(pulsar_client, time_sync);
|
time_sync_producer_ = std::make_shared<MsgProducer>(pulsar_client, time_sync);
|
||||||
|
|
||||||
|
for (auto i = 0; i < mut_parallelism_; i++) {
|
||||||
|
paralle_mut_producers_.emplace_back(std::make_shared<MsgProducer>(pulsar_client,
|
||||||
|
insert_delete,
|
||||||
|
producerConfiguration));
|
||||||
|
}
|
||||||
//create pulsar consumer
|
//create pulsar consumer
|
||||||
std::string subscribe_name = std::to_string(CommonUtil::RandomUINT64());
|
std::string subscribe_name = std::to_string(CommonUtil::RandomUINT64());
|
||||||
consumer_ = std::make_shared<MsgConsumer>(pulsar_client, search_result+subscribe_name);
|
consumer_ = std::make_shared<MsgConsumer>(pulsar_client, search_result + subscribe_name);
|
||||||
|
|
||||||
auto result = consumer_->subscribe(search_result);
|
auto result = consumer_->subscribe(search_result);
|
||||||
if (result != pulsar::Result::ResultOk) {
|
if (result != pulsar::Result::ResultOk) {
|
||||||
return Status(SERVER_UNEXPECTED_ERROR, "Pulsar message client init occur error, " + std::string(pulsar::strResult(result)));
|
return Status(SERVER_UNEXPECTED_ERROR,
|
||||||
|
"Pulsar message client init occur error, " + std::string(pulsar::strResult(result)));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t GetQueryNodeNum() {
|
int64_t GetQueryNodeNum() {
|
||||||
return 2;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
milvus::grpc::QueryResult Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> &results){
|
milvus::grpc::QueryResult Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> &results) {
|
||||||
//TODO: QueryNode has only one
|
//TODO: QueryNode has only one
|
||||||
int64_t length = results.size();
|
int64_t length = results.size();
|
||||||
std::vector<float> all_scores;
|
std::vector<float> all_scores;
|
||||||
std::vector<float> all_distance;
|
std::vector<float> all_distance;
|
||||||
std::vector<grpc::KeyValuePair> all_kv_pairs;
|
std::vector<grpc::KeyValuePair> all_kv_pairs;
|
||||||
std::vector<int> index(length * results[0]->scores_size());
|
std::vector<int> index(length * results[0]->scores_size());
|
||||||
|
|
||||||
for (int n = 0; n < length * results[0]->scores_size(); ++n) {
|
for (int n = 0; n < length * results[0]->scores_size(); ++n) {
|
||||||
index[n] = n;
|
index[n] = n;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < length; i++) {
|
||||||
|
for (int j = 0; j < results[i]->scores_size(); j++) {
|
||||||
|
all_scores.push_back(results[i]->scores()[j]);
|
||||||
|
all_distance.push_back(results[i]->distances()[j]);
|
||||||
|
all_kv_pairs.push_back(results[i]->extra_params()[j]);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (int i = 0; i < length; i++){
|
for (int k = 0; k < all_distance.size() - 1; ++k) {
|
||||||
for (int j = 0; j < results[i]->scores_size(); j++){
|
for (int l = k + 1; l < all_distance.size(); ++l) {
|
||||||
all_scores.push_back(results[i]->scores()[j]);
|
|
||||||
all_distance.push_back(results[i]->distances()[j]);
|
if (all_distance[l] > all_distance[k]) {
|
||||||
all_kv_pairs.push_back(results[i]->extra_params()[j]);
|
float distance_temp = all_distance[k];
|
||||||
}
|
all_distance[k] = all_distance[l];
|
||||||
}
|
all_distance[l] = distance_temp;
|
||||||
|
|
||||||
for (int k = 0; k < all_distance.size() - 1; ++k) {
|
int index_temp = index[k];
|
||||||
for (int l = k + 1; l < all_distance.size(); ++l) {
|
index[k] = index[l];
|
||||||
|
index[l] = index_temp;
|
||||||
if (all_distance[l] > all_distance[k]){
|
}
|
||||||
float distance_temp = all_distance[k];
|
|
||||||
all_distance[k] = all_distance[l];
|
|
||||||
all_distance[l] = distance_temp;
|
|
||||||
|
|
||||||
int index_temp = index[k];
|
|
||||||
index[k] = index[l];
|
|
||||||
index[l] = index_temp;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
grpc::QueryResult result;
|
grpc::QueryResult result;
|
||||||
|
|
||||||
|
@ -88,10 +96,10 @@ milvus::grpc::QueryResult Aggregation(std::vector<std::shared_ptr<grpc::QueryRes
|
||||||
result.mutable_extra_params(m)->CopyFrom(all_kv_pairs[index[m]]);
|
result.mutable_extra_params(m)->CopyFrom(all_kv_pairs[index[m]]);
|
||||||
}
|
}
|
||||||
|
|
||||||
result.set_query_id(results[0]->query_id());
|
result.set_query_id(results[0]->query_id());
|
||||||
result.set_client_id(results[0]->client_id());
|
result.set_client_id(results[0]->client_id());
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MsgClientV2::GetQueryResult(int64_t query_id, milvus::grpc::QueryResult &result) {
|
Status MsgClientV2::GetQueryResult(int64_t query_id, milvus::grpc::QueryResult &result) {
|
||||||
|
@ -114,9 +122,10 @@ Status MsgClientV2::GetQueryResult(int64_t query_id, milvus::grpc::QueryResult &
|
||||||
auto message = std::make_shared<grpc::QueryResult>(search_res_msg);
|
auto message = std::make_shared<grpc::QueryResult>(search_res_msg);
|
||||||
total_results[message->query_id()].push_back(message);
|
total_results[message->query_id()].push_back(message);
|
||||||
consumer_->acknowledge(msg);
|
consumer_->acknowledge(msg);
|
||||||
|
} else {
|
||||||
|
return Status(DB_ERROR, "can't parse message which from pulsar!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result = Aggregation(total_results[query_id]);
|
result = Aggregation(total_results[query_id]);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -126,8 +135,12 @@ Status MsgClientV2::SendMutMessage(const milvus::grpc::InsertParam &request, uin
|
||||||
auto row_count = request.rows_data_size();
|
auto row_count = request.rows_data_size();
|
||||||
// TODO: Get the segment from master
|
// TODO: Get the segment from master
|
||||||
int64_t segment = 0;
|
int64_t segment = 0;
|
||||||
milvus::grpc::InsertOrDeleteMsg mut_msg;
|
auto stats = std::vector<pulsar::Result>(ParallelNum);
|
||||||
|
|
||||||
|
#pragma omp parallel for default(none), shared(row_count, request, timestamp, segment, stats), num_threads(ParallelNum)
|
||||||
for (auto i = 0; i < row_count; i++) {
|
for (auto i = 0; i < row_count; i++) {
|
||||||
|
milvus::grpc::InsertOrDeleteMsg mut_msg;
|
||||||
|
int this_thread = omp_get_thread_num();
|
||||||
mut_msg.set_op(milvus::grpc::OpType::INSERT);
|
mut_msg.set_op(milvus::grpc::OpType::INSERT);
|
||||||
mut_msg.set_uid(request.entity_id_array(i));
|
mut_msg.set_uid(request.entity_id_array(i));
|
||||||
mut_msg.set_client_id(client_id_);
|
mut_msg.set_client_id(client_id_);
|
||||||
|
@ -138,29 +151,40 @@ Status MsgClientV2::SendMutMessage(const milvus::grpc::InsertParam &request, uin
|
||||||
mut_msg.mutable_rows_data()->CopyFrom(request.rows_data(i));
|
mut_msg.mutable_rows_data()->CopyFrom(request.rows_data(i));
|
||||||
mut_msg.mutable_extra_params()->CopyFrom(request.extra_params());
|
mut_msg.mutable_extra_params()->CopyFrom(request.extra_params());
|
||||||
|
|
||||||
auto result = insert_delete_producer_->send(mut_msg);
|
auto result = paralle_mut_producers_[this_thread]->send(mut_msg);
|
||||||
if (result != pulsar::ResultOk) {
|
if (result != pulsar::ResultOk) {
|
||||||
// TODO: error code
|
stats[this_thread] = result;
|
||||||
return Status(DB_ERROR, pulsar::strResult(result));
|
}
|
||||||
|
}
|
||||||
|
for (auto &stat : stats) {
|
||||||
|
if (stat == pulsar::ResultOk) {
|
||||||
|
return Status(DB_ERROR, pulsar::strResult(stat));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MsgClientV2::SendMutMessage(const milvus::grpc::DeleteByIDParam &request, uint64_t timestamp) {
|
Status MsgClientV2::SendMutMessage(const milvus::grpc::DeleteByIDParam &request, uint64_t timestamp) {
|
||||||
milvus::grpc::InsertOrDeleteMsg mut_msg;
|
auto stats = std::vector<pulsar::Result>(ParallelNum);
|
||||||
for (auto id: request.id_array()) {
|
#pragma omp parallel for default(none), shared( request, timestamp, stats), num_threads(ParallelNum)
|
||||||
|
for (auto i = 0; i < request.id_array_size(); i++) {
|
||||||
|
milvus::grpc::InsertOrDeleteMsg mut_msg;
|
||||||
mut_msg.set_op(milvus::grpc::OpType::DELETE);
|
mut_msg.set_op(milvus::grpc::OpType::DELETE);
|
||||||
mut_msg.set_uid(GetUniqueQId());
|
mut_msg.set_uid(GetUniqueQId());
|
||||||
mut_msg.set_client_id(client_id_);
|
mut_msg.set_client_id(client_id_);
|
||||||
mut_msg.set_uid(id);
|
mut_msg.set_uid(request.id_array(i));
|
||||||
mut_msg.set_collection_name(request.collection_name());
|
mut_msg.set_collection_name(request.collection_name());
|
||||||
mut_msg.set_timestamp(timestamp);
|
mut_msg.set_timestamp(timestamp);
|
||||||
|
|
||||||
auto result = insert_delete_producer_->send(mut_msg);
|
int this_thread = omp_get_thread_num();
|
||||||
|
auto result = paralle_mut_producers_[this_thread]->send(mut_msg);
|
||||||
if (result != pulsar::ResultOk) {
|
if (result != pulsar::ResultOk) {
|
||||||
// TODO: error code
|
stats[this_thread] = result;
|
||||||
return Status(DB_ERROR, pulsar::strResult(result));
|
}
|
||||||
|
}
|
||||||
|
for (auto &stat : stats) {
|
||||||
|
if (stat == pulsar::ResultOk) {
|
||||||
|
return Status(DB_ERROR, pulsar::strResult(stat));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -200,7 +224,8 @@ Status MsgClientV2::SendQueryMessage(const milvus::grpc::SearchParam &request, u
|
||||||
search_msg.mutable_extra_params(l)->CopyFrom(request.extra_params(l));
|
search_msg.mutable_extra_params(l)->CopyFrom(request.extra_params(l));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto result = search_by_id_producer_->send(search_msg);
|
std::cout << search_msg.collection_name() << std::endl;
|
||||||
|
auto result = search_producer_->send(search_msg);
|
||||||
if (result != pulsar::Result::ResultOk) {
|
if (result != pulsar::Result::ResultOk) {
|
||||||
return Status(DB_ERROR, pulsar::strResult(result));
|
return Status(DB_ERROR, pulsar::strResult(result));
|
||||||
}
|
}
|
||||||
|
@ -276,7 +301,10 @@ Status MsgClientV2::SendQueryMessage(const milvus::grpc::SearchParam &request, u
|
||||||
}
|
}
|
||||||
|
|
||||||
MsgClientV2::~MsgClientV2() {
|
MsgClientV2::~MsgClientV2() {
|
||||||
insert_delete_producer_->close();
|
// insert_delete_producer_->close();
|
||||||
|
for (auto& producer: paralle_mut_producers_){
|
||||||
|
producer->close();
|
||||||
|
}
|
||||||
search_producer_->close();
|
search_producer_->close();
|
||||||
search_by_id_producer_->close();
|
search_by_id_producer_->close();
|
||||||
time_sync_producer_->close();
|
time_sync_producer_->close();
|
||||||
|
|
|
@ -6,10 +6,13 @@
|
||||||
#include "grpc/message.pb.h"
|
#include "grpc/message.pb.h"
|
||||||
|
|
||||||
namespace milvus::message_client {
|
namespace milvus::message_client {
|
||||||
|
constexpr uint32_t ParallelNum = 12 * 20;
|
||||||
|
|
||||||
class MsgClientV2 {
|
class MsgClientV2 {
|
||||||
public:
|
public:
|
||||||
MsgClientV2(int64_t client_id,
|
MsgClientV2(int64_t client_id,
|
||||||
const std::string &service_url,
|
const std::string &service_url,
|
||||||
|
const uint32_t mut_parallelism = ParallelNum,
|
||||||
const pulsar::ClientConfiguration &config = pulsar::ClientConfiguration());
|
const pulsar::ClientConfiguration &config = pulsar::ClientConfiguration());
|
||||||
~MsgClientV2();
|
~MsgClientV2();
|
||||||
|
|
||||||
|
@ -40,9 +43,11 @@ class MsgClientV2 {
|
||||||
int64_t client_id_;
|
int64_t client_id_;
|
||||||
std::string service_url_;
|
std::string service_url_;
|
||||||
std::shared_ptr<MsgConsumer> consumer_;
|
std::shared_ptr<MsgConsumer> consumer_;
|
||||||
std::shared_ptr<MsgProducer> insert_delete_producer_;
|
// std::shared_ptr<MsgProducer> insert_delete_producer_;
|
||||||
std::shared_ptr<MsgProducer> search_producer_;
|
std::shared_ptr<MsgProducer> search_producer_;
|
||||||
std::shared_ptr<MsgProducer> time_sync_producer_;
|
std::shared_ptr<MsgProducer> time_sync_producer_;
|
||||||
std::shared_ptr<MsgProducer> search_by_id_producer_;
|
std::shared_ptr<MsgProducer> search_by_id_producer_;
|
||||||
|
std::vector<std::shared_ptr<MsgProducer>> paralle_mut_producers_;
|
||||||
|
const uint32_t mut_parallelism_;
|
||||||
};
|
};
|
||||||
}
|
}
|
|
@ -78,7 +78,7 @@ TEST(CLIENT_CPP, GetResult) {
|
||||||
|
|
||||||
// client_v2.SendQueryMessage();
|
// client_v2.SendQueryMessage();
|
||||||
milvus::grpc::SearchParam request;
|
milvus::grpc::SearchParam request;
|
||||||
auto status_send = client_v2.SendQueryMessage(request, query_id);
|
auto status_send = client_v2.SendQueryMessage(request, 10, query_id);
|
||||||
|
|
||||||
milvus::grpc::QueryResult result;
|
milvus::grpc::QueryResult result;
|
||||||
auto status = client_v2.GetQueryResult(query_id, result);
|
auto status = client_v2.GetQueryResult(query_id, result);
|
||||||
|
|
Loading…
Reference in New Issue