Optimize search performance

Signed-off-by: xige-16 <xi.ge@zilliz.com>
pull/4973/head^2
xige-16 2021-04-19 19:30:36 +08:00 committed by yefu.chen
parent b79a408491
commit 1165db75f6
4 changed files with 69 additions and 28 deletions

@@ -65,6 +65,9 @@ set( FETCHCONTENT_BASE_DIR ${MILVUS_BINARY_DIR}/3rdparty_download )
set(FETCHCONTENT_QUIET OFF)
include( ThirdPartyPackages )
find_package(OpenMP REQUIRED)
if (OPENMP_FOUND)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
endif()
# **************************** Compiler arguments ****************************
message( STATUS "Building Milvus CPU version" )

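For context on the build change above: find_package(OpenMP REQUIRED) locates the toolchain's OpenMP support, and appending ${OpenMP_CXX_FLAGS} (typically -fopenmp on GCC/Clang) is what turns the #pragma omp parallel for directives added in the C++ changes below into real multi-threaded loops; without the flag the pragmas compile but are ignored. A minimal standalone sketch of the construct, not part of this commit:

// build with: g++ -fopenmp omp_sketch.cpp   (file name is illustrative)
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> out(1000);
    // Iterations are split across threads; each one writes a distinct index,
    // so no locking is required.
#pragma omp parallel for
    for (int i = 0; i < 1000; i++) {
        out[i] = i * i;
    }
    std::printf("out[999] = %d\n", out[999]);
    return 0;
}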

@@ -33,7 +33,7 @@ add_library(milvus_segcore SHARED
)
target_link_libraries(milvus_segcore
tbb milvus_utils pthread knowhere log milvus_proto
tbb milvus_utils pthread knowhere log milvus_proto ${OpenMP_CXX_FLAGS}
dl backtrace
milvus_common
milvus_query

@@ -172,19 +172,20 @@ ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
try {
auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
auto topk = GetTopK(c_plan);
std::vector<int64_t> num_queries_peer_group;
std::vector<int64_t> num_queries_peer_group(num_groups);
int64_t total_num_queries = 0;
for (int i = 0; i < num_groups; i++) {
auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
num_queries_peer_group.push_back(num_queries);
num_queries_peer_group[i] = num_queries;
total_num_queries += num_queries;
}
std::vector<float> result_distances(total_num_queries * topk);
std::vector<int64_t> result_ids(total_num_queries * topk);
std::vector<std::vector<char>> row_datas(total_num_queries * topk);
std::vector<char> temp_ids;
int64_t count = 0;
std::vector<int64_t> counts(num_segments);
for (int i = 0; i < num_segments; i++) {
if (is_selected[i] == false) {
continue;
@@ -192,30 +193,46 @@ ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
auto search_result = (SearchResult*)c_search_results[i];
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
auto size = search_result->result_offsets_.size();
#pragma omp parallel for
for (int j = 0; j < size; j++) {
auto loc = search_result->result_offsets_[j];
result_distances[loc] = search_result->result_distances_[j];
row_datas[loc] = search_result->row_data_[j];
memcpy(&result_ids[loc], search_result->row_data_[j].data(), sizeof(int64_t));
}
count += size;
counts[i] = size;
}
AssertInfo(count == total_num_queries * topk, "the reduces result's size less than total_num_queries*topk");
int64_t fill_hit_offset = 0;
int64_t total_count = 0;
for (int i = 0; i < num_segments; i++) {
total_count += counts[i];
}
AssertInfo(total_count == total_num_queries * topk,
"the reduces result's size less than total_num_queries*topk");
int64_t last_offset = 0;
for (int i = 0; i < num_groups; i++) {
MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
for (int j = 0; j < num_queries_peer_group[i]; j++) {
milvus::proto::milvus::Hits hits;
for (int k = 0; k < topk; k++, fill_hit_offset++) {
hits.add_ids(result_ids[fill_hit_offset]);
hits.add_scores(result_distances[fill_hit_offset]);
auto& row_data = row_datas[fill_hit_offset];
hits.add_row_data(row_data.data(), row_data.size());
hits_peer_group.hits_.resize(num_queries_peer_group[i]);
hits_peer_group.blob_length_.resize(num_queries_peer_group[i]);
std::vector<milvus::proto::milvus::Hits> hits(num_queries_peer_group[i]);
#pragma omp parallel for
for (int m = 0; m < num_queries_peer_group[i]; m++) {
for (int n = 0; n < topk; n++) {
int64_t result_offset = last_offset + m * topk + n;
hits[m].add_ids(result_ids[result_offset]);
hits[m].add_scores(result_distances[result_offset]);
auto& row_data = row_datas[result_offset];
hits[m].add_row_data(row_data.data(), row_data.size());
}
auto blob = hits.SerializeAsString();
hits_peer_group.hits_.push_back(blob);
hits_peer_group.blob_length_.push_back(blob.size());
}
last_offset = last_offset + num_queries_peer_group[i] * topk;
#pragma omp parallel for
for (int j = 0; j < num_queries_peer_group[i]; j++) {
auto blob = hits[j].SerializeAsString();
hits_peer_group.hits_[j] = blob;
hits_peer_group.blob_length_[j] = blob.size();
}
}
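The ReorganizeQueryResults hunk above (and ReorganizeSingleQueryResult below, which gets the same treatment) follows one pattern: every output container is sized before the loop and each OpenMP iteration writes only its own slot, because calling push_back concurrently from a parallel for would race on the vector's size and storage; building the per-query Hits messages and serializing them each run in their own #pragma omp parallel for since the iterations are independent. A reduced sketch of that pattern, with made-up sizes and std::string standing in for the protobuf Hits message (not the actual Milvus code):

#include <cstdint>
#include <string>
#include <vector>

int main() {
    const int kNumQueries = 8, kTopK = 4;                 // hypothetical sizes
    std::vector<float> distances(kNumQueries * kTopK, 0.5f);

    // 1) Pre-size every output so parallel writes never reallocate or race.
    std::vector<std::string> per_query(kNumQueries);      // stands in for the Hits messages
    std::vector<std::string> blobs(kNumQueries);
    std::vector<int64_t> blob_length(kNumQueries);

    // 2) Build each query's result independently; iteration m touches only slot m.
#pragma omp parallel for
    for (int m = 0; m < kNumQueries; m++) {
        for (int n = 0; n < kTopK; n++) {
            per_query[m] += std::to_string(distances[m * kTopK + n]) + ",";
        }
    }

    // 3) Serialize in a second parallel loop; the blob slots are again disjoint.
#pragma omp parallel for
    for (int m = 0; m < kNumQueries; m++) {
        blobs[m] = per_query[m];                          // SerializeAsString() in the real code
        blob_length[m] = static_cast<int64_t>(blobs[m].size());
    }
    return 0;
}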
@@ -245,27 +262,37 @@ ReorganizeSingleQueryResult(CMarshaledHits* c_marshaled_hits,
auto search_result = (SearchResult*)c_search_result;
auto topk = GetTopK(c_plan);
std::vector<int64_t> num_queries_peer_group;
int64_t total_num_queries = 0;
for (int i = 0; i < num_groups; i++) {
auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
num_queries_peer_group.push_back(num_queries);
}
int64_t fill_hit_offset = 0;
int64_t last_offset = 0;
for (int i = 0; i < num_groups; i++) {
MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
for (int j = 0; j < num_queries_peer_group[i]; j++) {
milvus::proto::milvus::Hits hits;
for (int k = 0; k < topk; k++, fill_hit_offset++) {
hits.add_scores(search_result->result_distances_[fill_hit_offset]);
auto& row_data = search_result->row_data_[fill_hit_offset];
hits.add_row_data(row_data.data(), row_data.size());
hits_peer_group.hits_.resize(num_queries_peer_group[i]);
hits_peer_group.blob_length_.resize(num_queries_peer_group[i]);
std::vector<milvus::proto::milvus::Hits> hits(num_queries_peer_group[i]);
#pragma omp parallel for
for (int m = 0; m < num_queries_peer_group[i]; m++) {
for (int n = 0; n < topk; n++) {
int64_t result_offset = last_offset + m * topk + n;
hits[m].add_scores(search_result->result_distances_[result_offset]);
auto& row_data = search_result->row_data_[result_offset];
hits[m].add_row_data(row_data.data(), row_data.size());
int64_t result_id;
memcpy(&result_id, row_data.data(), sizeof(int64_t));
hits.add_ids(result_id);
hits[m].add_ids(result_id);
}
auto blob = hits.SerializeAsString();
hits_peer_group.hits_.push_back(blob);
hits_peer_group.blob_length_.push_back(blob.size());
}
last_offset = last_offset + num_queries_peer_group[i] * topk;
#pragma omp parallel for
for (int j = 0; j < num_queries_peer_group[i]; j++) {
auto blob = hits[j].SerializeAsString();
hits_peer_group.hits_[j] = blob;
hits_peer_group.blob_length_[j] = blob.size();
}
}

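Both reorganize functions index the flattened result buffers the same way: hit n of query m in placeholder group i lives at last_offset + m * topk + n, where last_offset is the running sum of num_queries * topk over all preceding groups, so every (group, query, hit) triple maps to a unique slot and the parallel loops never collide. A quick arithmetic check of that layout with made-up group sizes:

// Offsets for two hypothetical placeholder groups (3 and 2 queries) at topk = 10.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    const int64_t topk = 10;
    const std::vector<int64_t> num_queries_per_group = {3, 2};

    int64_t last_offset = 0;
    assert(last_offset + 2 * topk + 9 == 29);   // group 0, query 2, hit 9: last slot of group 0
    last_offset += num_queries_per_group[0] * topk;
    assert(last_offset == 30);                  // group 1 starts right after group 0
    assert(last_offset + 1 * topk + 4 == 44);   // group 1, query 1, hit 4
    last_offset += num_queries_per_group[1] * topk;
    assert(last_offset == 50);                  // == total_num_queries * topk
    return 0;
}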

@@ -287,6 +287,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
if err != nil {
return err
}
queryNum := searchReq.getNumOfQuery()
searchRequests := make([]*searchRequest, 0)
searchRequests = append(searchRequests, searchReq)
@@ -315,6 +316,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
searchPartitionIDs = partitionIDsInQuery
}
sp.LogFields(oplog.String("statistical time", "stats start"), oplog.Object("nq", queryNum), oplog.Object("dsl", dsl))
for _, partitionID := range searchPartitionIDs {
segmentIDs, err := s.replica.getSegmentIDs(partitionID)
if err != nil {
@@ -336,6 +338,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
}
}
sp.LogFields(oplog.String("statistical time", "segment search end"))
if len(searchResults) <= 0 {
for _, group := range searchRequests {
nq := group.getNumOfQuery()
@@ -378,28 +381,34 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
if numSegment == 1 {
inReduced[0] = true
err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
sp.LogFields(oplog.String("statistical time", "fillTargetEntry end"))
if err != nil {
return err
}
marshaledHits, err = reorganizeSingleQueryResult(plan, searchRequests, searchResults[0])
sp.LogFields(oplog.String("statistical time", "reorganizeSingleQueryResult end"))
if err != nil {
return err
}
} else {
err = reduceSearchResults(searchResults, numSegment, inReduced)
sp.LogFields(oplog.String("statistical time", "reduceSearchResults end"))
if err != nil {
return err
}
err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
sp.LogFields(oplog.String("statistical time", "fillTargetEntry end"))
if err != nil {
return err
}
marshaledHits, err = reorganizeQueryResults(plan, searchRequests, searchResults, numSegment, inReduced)
sp.LogFields(oplog.String("statistical time", "reorganizeQueryResults end"))
if err != nil {
return err
}
}
hitsBlob, err := marshaledHits.getHitsBlob()
sp.LogFields(oplog.String("statistical time", "getHitsBlob end"))
if err != nil {
return err
}
@@ -457,8 +466,10 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
}
}
sp.LogFields(oplog.String("statistical time", "before free c++ memory"))
deleteSearchResults(searchResults)
deleteMarshaledHits(marshaledHits)
sp.LogFields(oplog.String("statistical time", "stats done"))
plan.delete()
searchReq.delete()
return nil