mirror of https://github.com/milvus-io/milvus.git
Add interface for collect result
Signed-off-by: cai.zhang <cai.zhang@zilliz.com>pull/4973/head^2
parent
7479aedf07
commit
34b670f654
|
@ -30,7 +30,7 @@ storage:
|
|||
secretkey: dd
|
||||
|
||||
pulsar:
|
||||
address: localhost
|
||||
address: 0.0.0.0
|
||||
port: 6650
|
||||
|
||||
proxy:
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -63,7 +63,7 @@ include( FetchContent )
|
|||
set( FETCHCONTENT_BASE_DIR ${MILVUS_BINARY_DIR}/3rdparty_download )
|
||||
set(FETCHCONTENT_QUIET OFF)
|
||||
include( ThirdPartyPackages )
|
||||
|
||||
find_package(OpenMP REQUIRED)
|
||||
# **************************** Compiler arguments ****************************
|
||||
message( STATUS "Building Milvus CPU version" )
|
||||
|
||||
|
|
|
@ -2,78 +2,86 @@
|
|||
#include "pulsar/Result.h"
|
||||
#include "PartitionPolicy.h"
|
||||
#include "utils/CommonUtil.h"
|
||||
#include <omp.h>
|
||||
|
||||
namespace milvus::message_client {
|
||||
|
||||
std::map<int64_t, std::vector<std::shared_ptr<grpc::QueryResult>>> total_results;
|
||||
|
||||
MsgClientV2::MsgClientV2(int64_t client_id, const std::string &service_url, const pulsar::ClientConfiguration &config)
|
||||
: client_id_(client_id), service_url_(service_url) {}
|
||||
MsgClientV2::MsgClientV2(int64_t client_id, const std::string &service_url, const uint32_t mut_parallelism, const pulsar::ClientConfiguration &config)
|
||||
: client_id_(client_id), service_url_(service_url), mut_parallelism_(mut_parallelism) {}
|
||||
|
||||
Status MsgClientV2::Init(const std::string &insert_delete,
|
||||
const std::string &search,
|
||||
const std::string &time_sync,
|
||||
const std::string &search_by_id,
|
||||
const std::string &search_result) {
|
||||
const std::string &search,
|
||||
const std::string &time_sync,
|
||||
const std::string &search_by_id,
|
||||
const std::string &search_result) {
|
||||
//create pulsar client
|
||||
auto pulsar_client = std::make_shared<MsgClient>(service_url_);
|
||||
//create pulsar producer
|
||||
ProducerConfiguration producerConfiguration;
|
||||
producerConfiguration.setPartitionsRoutingMode(ProducerConfiguration::CustomPartition);
|
||||
producerConfiguration.setMessageRouter(std::make_shared<PartitionPolicy>());
|
||||
insert_delete_producer_ = std::make_shared<MsgProducer>(pulsar_client, insert_delete, producerConfiguration);
|
||||
// insert_delete_producer_ = std::make_shared<MsgProducer>(pulsar_client, insert_delete, producerConfiguration);
|
||||
search_producer_ = std::make_shared<MsgProducer>(pulsar_client, search, producerConfiguration);
|
||||
search_by_id_producer_ = std::make_shared<MsgProducer>(pulsar_client, search_result, producerConfiguration);
|
||||
search_by_id_producer_ = std::make_shared<MsgProducer>(pulsar_client, search_by_id, producerConfiguration);
|
||||
time_sync_producer_ = std::make_shared<MsgProducer>(pulsar_client, time_sync);
|
||||
|
||||
for (auto i = 0; i < mut_parallelism_; i++) {
|
||||
paralle_mut_producers_.emplace_back(std::make_shared<MsgProducer>(pulsar_client,
|
||||
insert_delete,
|
||||
producerConfiguration));
|
||||
}
|
||||
//create pulsar consumer
|
||||
std::string subscribe_name = std::to_string(CommonUtil::RandomUINT64());
|
||||
consumer_ = std::make_shared<MsgConsumer>(pulsar_client, search_result+subscribe_name);
|
||||
consumer_ = std::make_shared<MsgConsumer>(pulsar_client, search_result + subscribe_name);
|
||||
|
||||
auto result = consumer_->subscribe(search_result);
|
||||
if (result != pulsar::Result::ResultOk) {
|
||||
return Status(SERVER_UNEXPECTED_ERROR, "Pulsar message client init occur error, " + std::string(pulsar::strResult(result)));
|
||||
return Status(SERVER_UNEXPECTED_ERROR,
|
||||
"Pulsar message client init occur error, " + std::string(pulsar::strResult(result)));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64_t GetQueryNodeNum() {
|
||||
return 2;
|
||||
return 1;
|
||||
}
|
||||
|
||||
milvus::grpc::QueryResult Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> &results){
|
||||
//TODO: QueryNode has only one
|
||||
int64_t length = results.size();
|
||||
std::vector<float> all_scores;
|
||||
std::vector<float> all_distance;
|
||||
std::vector<grpc::KeyValuePair> all_kv_pairs;
|
||||
std::vector<int> index(length * results[0]->scores_size());
|
||||
milvus::grpc::QueryResult Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> &results) {
|
||||
//TODO: QueryNode has only one
|
||||
int64_t length = results.size();
|
||||
std::vector<float> all_scores;
|
||||
std::vector<float> all_distance;
|
||||
std::vector<grpc::KeyValuePair> all_kv_pairs;
|
||||
std::vector<int> index(length * results[0]->scores_size());
|
||||
|
||||
for (int n = 0; n < length * results[0]->scores_size(); ++n) {
|
||||
index[n] = n;
|
||||
for (int n = 0; n < length * results[0]->scores_size(); ++n) {
|
||||
index[n] = n;
|
||||
}
|
||||
|
||||
for (int i = 0; i < length; i++) {
|
||||
for (int j = 0; j < results[i]->scores_size(); j++) {
|
||||
all_scores.push_back(results[i]->scores()[j]);
|
||||
all_distance.push_back(results[i]->distances()[j]);
|
||||
all_kv_pairs.push_back(results[i]->extra_params()[j]);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < length; i++){
|
||||
for (int j = 0; j < results[i]->scores_size(); j++){
|
||||
all_scores.push_back(results[i]->scores()[j]);
|
||||
all_distance.push_back(results[i]->distances()[j]);
|
||||
all_kv_pairs.push_back(results[i]->extra_params()[j]);
|
||||
}
|
||||
}
|
||||
|
||||
for (int k = 0; k < all_distance.size() - 1; ++k) {
|
||||
for (int l = k + 1; l < all_distance.size(); ++l) {
|
||||
|
||||
if (all_distance[l] > all_distance[k]){
|
||||
float distance_temp = all_distance[k];
|
||||
all_distance[k] = all_distance[l];
|
||||
all_distance[l] = distance_temp;
|
||||
|
||||
int index_temp = index[k];
|
||||
index[k] = index[l];
|
||||
index[l] = index_temp;
|
||||
}
|
||||
}
|
||||
for (int k = 0; k < all_distance.size() - 1; ++k) {
|
||||
for (int l = k + 1; l < all_distance.size(); ++l) {
|
||||
|
||||
if (all_distance[l] > all_distance[k]) {
|
||||
float distance_temp = all_distance[k];
|
||||
all_distance[k] = all_distance[l];
|
||||
all_distance[l] = distance_temp;
|
||||
|
||||
int index_temp = index[k];
|
||||
index[k] = index[l];
|
||||
index[l] = index_temp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
grpc::QueryResult result;
|
||||
|
||||
|
@ -88,10 +96,10 @@ milvus::grpc::QueryResult Aggregation(std::vector<std::shared_ptr<grpc::QueryRes
|
|||
result.mutable_extra_params(m)->CopyFrom(all_kv_pairs[index[m]]);
|
||||
}
|
||||
|
||||
result.set_query_id(results[0]->query_id());
|
||||
result.set_client_id(results[0]->client_id());
|
||||
result.set_query_id(results[0]->query_id());
|
||||
result.set_client_id(results[0]->client_id());
|
||||
|
||||
return result;
|
||||
return result;
|
||||
}
|
||||
|
||||
Status MsgClientV2::GetQueryResult(int64_t query_id, milvus::grpc::QueryResult &result) {
|
||||
|
@ -114,9 +122,10 @@ Status MsgClientV2::GetQueryResult(int64_t query_id, milvus::grpc::QueryResult &
|
|||
auto message = std::make_shared<grpc::QueryResult>(search_res_msg);
|
||||
total_results[message->query_id()].push_back(message);
|
||||
consumer_->acknowledge(msg);
|
||||
} else {
|
||||
return Status(DB_ERROR, "can't parse message which from pulsar!");
|
||||
}
|
||||
}
|
||||
|
||||
result = Aggregation(total_results[query_id]);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -126,8 +135,12 @@ Status MsgClientV2::SendMutMessage(const milvus::grpc::InsertParam &request, uin
|
|||
auto row_count = request.rows_data_size();
|
||||
// TODO: Get the segment from master
|
||||
int64_t segment = 0;
|
||||
milvus::grpc::InsertOrDeleteMsg mut_msg;
|
||||
auto stats = std::vector<pulsar::Result>(ParallelNum);
|
||||
|
||||
#pragma omp parallel for default(none), shared(row_count, request, timestamp, segment, stats), num_threads(ParallelNum)
|
||||
for (auto i = 0; i < row_count; i++) {
|
||||
milvus::grpc::InsertOrDeleteMsg mut_msg;
|
||||
int this_thread = omp_get_thread_num();
|
||||
mut_msg.set_op(milvus::grpc::OpType::INSERT);
|
||||
mut_msg.set_uid(request.entity_id_array(i));
|
||||
mut_msg.set_client_id(client_id_);
|
||||
|
@ -138,29 +151,40 @@ Status MsgClientV2::SendMutMessage(const milvus::grpc::InsertParam &request, uin
|
|||
mut_msg.mutable_rows_data()->CopyFrom(request.rows_data(i));
|
||||
mut_msg.mutable_extra_params()->CopyFrom(request.extra_params());
|
||||
|
||||
auto result = insert_delete_producer_->send(mut_msg);
|
||||
auto result = paralle_mut_producers_[this_thread]->send(mut_msg);
|
||||
if (result != pulsar::ResultOk) {
|
||||
// TODO: error code
|
||||
return Status(DB_ERROR, pulsar::strResult(result));
|
||||
stats[this_thread] = result;
|
||||
}
|
||||
}
|
||||
for (auto &stat : stats) {
|
||||
if (stat == pulsar::ResultOk) {
|
||||
return Status(DB_ERROR, pulsar::strResult(stat));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MsgClientV2::SendMutMessage(const milvus::grpc::DeleteByIDParam &request, uint64_t timestamp) {
|
||||
milvus::grpc::InsertOrDeleteMsg mut_msg;
|
||||
for (auto id: request.id_array()) {
|
||||
auto stats = std::vector<pulsar::Result>(ParallelNum);
|
||||
#pragma omp parallel for default(none), shared( request, timestamp, stats), num_threads(ParallelNum)
|
||||
for (auto i = 0; i < request.id_array_size(); i++) {
|
||||
milvus::grpc::InsertOrDeleteMsg mut_msg;
|
||||
mut_msg.set_op(milvus::grpc::OpType::DELETE);
|
||||
mut_msg.set_uid(GetUniqueQId());
|
||||
mut_msg.set_client_id(client_id_);
|
||||
mut_msg.set_uid(id);
|
||||
mut_msg.set_uid(request.id_array(i));
|
||||
mut_msg.set_collection_name(request.collection_name());
|
||||
mut_msg.set_timestamp(timestamp);
|
||||
|
||||
auto result = insert_delete_producer_->send(mut_msg);
|
||||
int this_thread = omp_get_thread_num();
|
||||
auto result = paralle_mut_producers_[this_thread]->send(mut_msg);
|
||||
if (result != pulsar::ResultOk) {
|
||||
// TODO: error code
|
||||
return Status(DB_ERROR, pulsar::strResult(result));
|
||||
stats[this_thread] = result;
|
||||
}
|
||||
}
|
||||
for (auto &stat : stats) {
|
||||
if (stat == pulsar::ResultOk) {
|
||||
return Status(DB_ERROR, pulsar::strResult(stat));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -200,7 +224,8 @@ Status MsgClientV2::SendQueryMessage(const milvus::grpc::SearchParam &request, u
|
|||
search_msg.mutable_extra_params(l)->CopyFrom(request.extra_params(l));
|
||||
}
|
||||
|
||||
auto result = search_by_id_producer_->send(search_msg);
|
||||
std::cout << search_msg.collection_name() << std::endl;
|
||||
auto result = search_producer_->send(search_msg);
|
||||
if (result != pulsar::Result::ResultOk) {
|
||||
return Status(DB_ERROR, pulsar::strResult(result));
|
||||
}
|
||||
|
@ -276,7 +301,10 @@ Status MsgClientV2::SendQueryMessage(const milvus::grpc::SearchParam &request, u
|
|||
}
|
||||
|
||||
MsgClientV2::~MsgClientV2() {
|
||||
insert_delete_producer_->close();
|
||||
// insert_delete_producer_->close();
|
||||
for (auto& producer: paralle_mut_producers_){
|
||||
producer->close();
|
||||
}
|
||||
search_producer_->close();
|
||||
search_by_id_producer_->close();
|
||||
time_sync_producer_->close();
|
||||
|
|
|
@ -6,10 +6,13 @@
|
|||
#include "grpc/message.pb.h"
|
||||
|
||||
namespace milvus::message_client {
|
||||
constexpr uint32_t ParallelNum = 12 * 20;
|
||||
|
||||
class MsgClientV2 {
|
||||
public:
|
||||
MsgClientV2(int64_t client_id,
|
||||
const std::string &service_url,
|
||||
const uint32_t mut_parallelism = ParallelNum,
|
||||
const pulsar::ClientConfiguration &config = pulsar::ClientConfiguration());
|
||||
~MsgClientV2();
|
||||
|
||||
|
@ -40,9 +43,11 @@ class MsgClientV2 {
|
|||
int64_t client_id_;
|
||||
std::string service_url_;
|
||||
std::shared_ptr<MsgConsumer> consumer_;
|
||||
std::shared_ptr<MsgProducer> insert_delete_producer_;
|
||||
// std::shared_ptr<MsgProducer> insert_delete_producer_;
|
||||
std::shared_ptr<MsgProducer> search_producer_;
|
||||
std::shared_ptr<MsgProducer> time_sync_producer_;
|
||||
std::shared_ptr<MsgProducer> search_by_id_producer_;
|
||||
std::vector<std::shared_ptr<MsgProducer>> paralle_mut_producers_;
|
||||
const uint32_t mut_parallelism_;
|
||||
};
|
||||
}
|
|
@ -78,7 +78,7 @@ TEST(CLIENT_CPP, GetResult) {
|
|||
|
||||
// client_v2.SendQueryMessage();
|
||||
milvus::grpc::SearchParam request;
|
||||
auto status_send = client_v2.SendQueryMessage(request, query_id);
|
||||
auto status_send = client_v2.SendQueryMessage(request, 10, query_id);
|
||||
|
||||
milvus::grpc::QueryResult result;
|
||||
auto status = client_v2.GetQueryResult(query_id, result);
|
||||
|
|
Loading…
Reference in New Issue