mirror of https://github.com/milvus-io/milvus.git
Optimize search performance
Signed-off-by: xige-16 <xi.ge@zilliz.com>
branch: pull/4973/head^2
parent b79a408491
commit 1165db75f6
@@ -65,6 +65,9 @@ set( FETCHCONTENT_BASE_DIR ${MILVUS_BINARY_DIR}/3rdparty_download )
 set(FETCHCONTENT_QUIET OFF)
 include( ThirdPartyPackages )
+find_package(OpenMP REQUIRED)
+if (OPENMP_FOUND)
+    set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
+endif()

 # **************************** Compiler arguments ****************************
 message( STATUS "Building Milvus CPU version" )
@@ -33,7 +33,7 @@ add_library(milvus_segcore SHARED
 )

 target_link_libraries(milvus_segcore
-    tbb milvus_utils pthread knowhere log milvus_proto
+    tbb milvus_utils pthread knowhere log milvus_proto ${OpenMP_CXX_FLAGS}
     dl backtrace
     milvus_common
     milvus_query
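The two build hunks above are what activate the #pragma omp directives in the segcore changes below: find_package(OpenMP REQUIRED) supplies the compile flag (e.g. -fopenmp with GCC), and repeating ${OpenMP_CXX_FLAGS} in target_link_libraries pulls in the OpenMP runtime (libgomp for GCC) at link time. A standalone smoke test, not part of this commit, to confirm both halves are wired up:

// omp_check.cpp -- sketch only. Build with the same flag CMake injects:
//   g++ -fopenmp omp_check.cpp -o omp_check && ./omp_check
#include <omp.h>
#include <cstdio>

int main() {
    #pragma omp parallel
    {
        // If the compile flag is missing, the pragma is ignored and this
        // prints a single line; if only the link flag is missing, the build
        // fails on unresolved omp_* symbols.
        #pragma omp critical
        std::printf("thread %d of %d\n", omp_get_thread_num(), omp_get_num_threads());
    }
    return 0;
}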
@@ -172,19 +172,20 @@ ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
 try {
     auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
     auto topk = GetTopK(c_plan);
-    std::vector<int64_t> num_queries_peer_group;
+    std::vector<int64_t> num_queries_peer_group(num_groups);
     int64_t total_num_queries = 0;
     for (int i = 0; i < num_groups; i++) {
         auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
-        num_queries_peer_group.push_back(num_queries);
+        num_queries_peer_group[i] = num_queries;
         total_num_queries += num_queries;
     }

     std::vector<float> result_distances(total_num_queries * topk);
     std::vector<int64_t> result_ids(total_num_queries * topk);
     std::vector<std::vector<char>> row_datas(total_num_queries * topk);
-    std::vector<char> temp_ids;

-    int64_t count = 0;
+    std::vector<int64_t> counts(num_segments);
     for (int i = 0; i < num_segments; i++) {
         if (is_selected[i] == false) {
             continue;
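Note the pattern in this hunk: push_back into num_queries_peer_group and a single running count are replaced by vectors sized up front (num_queries_peer_group(num_groups), counts(num_segments)) and written by index. That is what makes the #pragma omp parallel for loops in the next hunk safe: push_back mutates shared vector state, while indexed stores into a pre-sized buffer touch disjoint slots and need no synchronization. A minimal sketch of the idiom, with illustrative names not taken from the commit:

#include <cstdint>
#include <vector>

// Double every element. The output is sized before the loop, so each
// iteration writes only its own slot and the loop parallelizes without locks.
std::vector<int64_t> Scale(const std::vector<int64_t>& src) {
    std::vector<int64_t> dst(src.size());  // no reallocation inside the loop
#pragma omp parallel for
    for (int64_t i = 0; i < static_cast<int64_t>(src.size()); i++) {
        dst[i] = src[i] * 2;               // thread-private index, no race
    }
    return dst;
}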
@@ -192,30 +193,46 @@ ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
         auto search_result = (SearchResult*)c_search_results[i];
         AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
         auto size = search_result->result_offsets_.size();
+#pragma omp parallel for
         for (int j = 0; j < size; j++) {
             auto loc = search_result->result_offsets_[j];
             result_distances[loc] = search_result->result_distances_[j];
             row_datas[loc] = search_result->row_data_[j];
             memcpy(&result_ids[loc], search_result->row_data_[j].data(), sizeof(int64_t));
         }
-        count += size;
+        counts[i] = size;
     }
-    AssertInfo(count == total_num_queries * topk, "the reduces result's size less than total_num_queries*topk");

-    int64_t fill_hit_offset = 0;
+    int64_t total_count = 0;
+    for (int i = 0; i < num_segments; i++) {
+        total_count += counts[i];
+    }
+    AssertInfo(total_count == total_num_queries * topk,
+               "the reduces result's size less than total_num_queries*topk");
+
+    int64_t last_offset = 0;
     for (int i = 0; i < num_groups; i++) {
         MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
-        for (int j = 0; j < num_queries_peer_group[i]; j++) {
-            milvus::proto::milvus::Hits hits;
-            for (int k = 0; k < topk; k++, fill_hit_offset++) {
-                hits.add_ids(result_ids[fill_hit_offset]);
-                hits.add_scores(result_distances[fill_hit_offset]);
-                auto& row_data = row_datas[fill_hit_offset];
-                hits.add_row_data(row_data.data(), row_data.size());
-            }
-            auto blob = hits.SerializeAsString();
-            hits_peer_group.hits_.push_back(blob);
-            hits_peer_group.blob_length_.push_back(blob.size());
-        }
+        hits_peer_group.hits_.resize(num_queries_peer_group[i]);
+        hits_peer_group.blob_length_.resize(num_queries_peer_group[i]);
+        std::vector<milvus::proto::milvus::Hits> hits(num_queries_peer_group[i]);
+#pragma omp parallel for
+        for (int m = 0; m < num_queries_peer_group[i]; m++) {
+            for (int n = 0; n < topk; n++) {
+                int64_t result_offset = last_offset + m * topk + n;
+                hits[m].add_ids(result_ids[result_offset]);
+                hits[m].add_scores(result_distances[result_offset]);
+                auto& row_data = row_datas[result_offset];
+                hits[m].add_row_data(row_data.data(), row_data.size());
+            }
+        }
+        last_offset = last_offset + num_queries_peer_group[i] * topk;
+
+#pragma omp parallel for
+        for (int j = 0; j < num_queries_peer_group[i]; j++) {
+            auto blob = hits[j].SerializeAsString();
+            hits_peer_group.hits_[j] = blob;
+            hits_peer_group.blob_length_[j] = blob.size();
+        }
     }
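This hunk is the core of the optimization. The old code built one Hits message per query in a serial loop, serialized it, and pushed the blob back. The new code splits the work into two parallel phases: phase one fills hits[m] for every query, phase two serializes each message. last_offset hands query m the disjoint slice [last_offset + m*topk, last_offset + (m+1)*topk), so no two threads ever touch the same result element or the same Hits object. The shape of that transformation, with a plain struct standing in for milvus::proto::milvus::Hits (a sketch, not the commit's code):

#include <cstdint>
#include <string>
#include <vector>

// Stand-in for the protobuf message; SerializeAsString mimics the API the
// commit calls on the real milvus::proto::milvus::Hits.
struct Hits {
    std::vector<int64_t> ids;
    std::string SerializeAsString() const {
        return std::string(reinterpret_cast<const char*>(ids.data()),
                           ids.size() * sizeof(int64_t));
    }
};

// Turn a flat num_queries*topk result buffer into one blob per query.
std::vector<std::string> Marshal(const std::vector<int64_t>& result_ids,
                                 int64_t num_queries, int64_t topk) {
    std::vector<Hits> hits(num_queries);
    std::vector<std::string> blobs(num_queries);
    // Phase 1: each thread fills its own Hits from a disjoint topk slice.
#pragma omp parallel for
    for (int64_t m = 0; m < num_queries; m++) {
        for (int64_t n = 0; n < topk; n++) {
            hits[m].ids.push_back(result_ids[m * topk + n]);
        }
    }
    // Phase 2: serialization is independent per query, so it parallelizes too.
#pragma omp parallel for
    for (int64_t m = 0; m < num_queries; m++) {
        blobs[m] = hits[m].SerializeAsString();
    }
    return blobs;
}

Pre-resizing hits_ and blob_length_ in the real code plays the same role as blobs(num_queries) here: the parallel loops assign by index instead of calling push_back on a shared vector.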
@@ -245,27 +262,37 @@ ReorganizeSingleQueryResult(CMarshaledHits* c_marshaled_hits,
     auto search_result = (SearchResult*)c_search_result;
     auto topk = GetTopK(c_plan);
     std::vector<int64_t> num_queries_peer_group;
     int64_t total_num_queries = 0;
     for (int i = 0; i < num_groups; i++) {
         auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
         num_queries_peer_group.push_back(num_queries);
     }

-    int64_t fill_hit_offset = 0;
+    int64_t last_offset = 0;
     for (int i = 0; i < num_groups; i++) {
         MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
-        for (int j = 0; j < num_queries_peer_group[i]; j++) {
-            milvus::proto::milvus::Hits hits;
-            for (int k = 0; k < topk; k++, fill_hit_offset++) {
-                hits.add_scores(search_result->result_distances_[fill_hit_offset]);
-                auto& row_data = search_result->row_data_[fill_hit_offset];
-                hits.add_row_data(row_data.data(), row_data.size());
-                int64_t result_id;
-                memcpy(&result_id, row_data.data(), sizeof(int64_t));
-                hits.add_ids(result_id);
-            }
-            auto blob = hits.SerializeAsString();
-            hits_peer_group.hits_.push_back(blob);
-            hits_peer_group.blob_length_.push_back(blob.size());
-        }
+        hits_peer_group.hits_.resize(num_queries_peer_group[i]);
+        hits_peer_group.blob_length_.resize(num_queries_peer_group[i]);
+        std::vector<milvus::proto::milvus::Hits> hits(num_queries_peer_group[i]);
+#pragma omp parallel for
+        for (int m = 0; m < num_queries_peer_group[i]; m++) {
+            for (int n = 0; n < topk; n++) {
+                int64_t result_offset = last_offset + m * topk + n;
+                hits[m].add_scores(search_result->result_distances_[result_offset]);
+                auto& row_data = search_result->row_data_[result_offset];
+                hits[m].add_row_data(row_data.data(), row_data.size());
+                int64_t result_id;
+                memcpy(&result_id, row_data.data(), sizeof(int64_t));
+                hits[m].add_ids(result_id);
+            }
+        }
+        last_offset = last_offset + num_queries_peer_group[i] * topk;
+
+#pragma omp parallel for
+        for (int j = 0; j < num_queries_peer_group[i]; j++) {
+            auto blob = hits[j].SerializeAsString();
+            hits_peer_group.hits_[j] = blob;
+            hits_peer_group.blob_length_[j] = blob.size();
+        }
     }
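ReorganizeSingleQueryResult gets the same two-phase treatment. The one extra wrinkle is that the single-segment path stores no separate id array: the primary key lives in the first sizeof(int64_t) bytes of each row-data blob (the same layout the memcpy into result_ids relies on in the multi-segment path above). memcpy is the well-defined way to read it back, since a char buffer is not guaranteed to be aligned for a direct int64_t* load. A minimal sketch:

#include <cstdint>
#include <cstring>
#include <vector>

// Recover the int64 primary key prefixed to a row-data blob. Assumes
// row_data holds at least sizeof(int64_t) bytes, as in the commit's layout.
// memcpy sidesteps the alignment and strict-aliasing problems of
// *(const int64_t*)row_data.data().
int64_t ExtractId(const std::vector<char>& row_data) {
    int64_t id = 0;
    std::memcpy(&id, row_data.data(), sizeof(int64_t));
    return id;
}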
@@ -287,6 +287,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
     if err != nil {
         return err
     }
+    queryNum := searchReq.getNumOfQuery()
     searchRequests := make([]*searchRequest, 0)
     searchRequests = append(searchRequests, searchReq)
@@ -315,6 +316,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
         searchPartitionIDs = partitionIDsInQuery
     }

+    sp.LogFields(oplog.String("statistical time", "stats start"), oplog.Object("nq", queryNum), oplog.Object("dsl", dsl))
     for _, partitionID := range searchPartitionIDs {
         segmentIDs, err := s.replica.getSegmentIDs(partitionID)
         if err != nil {
@@ -336,6 +338,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
         }
     }

+    sp.LogFields(oplog.String("statistical time", "segment search end"))
     if len(searchResults) <= 0 {
         for _, group := range searchRequests {
             nq := group.getNumOfQuery()
@@ -378,28 +381,34 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
     if numSegment == 1 {
         inReduced[0] = true
         err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
+        sp.LogFields(oplog.String("statistical time", "fillTargetEntry end"))
         if err != nil {
             return err
         }
         marshaledHits, err = reorganizeSingleQueryResult(plan, searchRequests, searchResults[0])
+        sp.LogFields(oplog.String("statistical time", "reorganizeSingleQueryResult end"))
         if err != nil {
             return err
         }
     } else {
         err = reduceSearchResults(searchResults, numSegment, inReduced)
+        sp.LogFields(oplog.String("statistical time", "reduceSearchResults end"))
         if err != nil {
             return err
         }
         err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
+        sp.LogFields(oplog.String("statistical time", "fillTargetEntry end"))
         if err != nil {
             return err
         }
         marshaledHits, err = reorganizeQueryResults(plan, searchRequests, searchResults, numSegment, inReduced)
+        sp.LogFields(oplog.String("statistical time", "reorganizeQueryResults end"))
         if err != nil {
             return err
         }
     }
     hitsBlob, err := marshaledHits.getHitsBlob()
+    sp.LogFields(oplog.String("statistical time", "getHitsBlob end"))
     if err != nil {
         return err
     }
@@ -457,8 +466,10 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
         }
     }

+    sp.LogFields(oplog.String("statistical time", "before free c++ memory"))
     deleteSearchResults(searchResults)
     deleteMarshaledHits(marshaledHits)
+    sp.LogFields(oplog.String("statistical time", "stats done"))
     plan.delete()
     searchReq.delete()
     return nil