Signed-off-by: yhmo <yihua.mo@zilliz.com>
pull/2615/head
groot 2020-06-18 09:36:04 +08:00 committed by JinHai-CN
parent 36d2214fc9
commit 508da260ca
1 changed files with 14 additions and 14 deletions

View File

@ -384,14 +384,6 @@ SearchCombineRequest::OnExecute() {
return status;
}
// avoid memcpy crash, check id count = target vector count * topk
if (result_ids.size() != total_count * search_topk_) {
status = Status(DB_ERROR, "Result count doesn't match target vectors count");
// let all request return
FreeRequests(status);
return status;
}
// avoid memcpy crash, check distance count = id count
if (result_distances.size() != result_ids.size()) {
status = Status(DB_ERROR, "Result distance and id count doesn't match");
@ -401,18 +393,26 @@ SearchCombineRequest::OnExecute() {
}
// step 5: construct result array
// engine ensure each target vector has same count of id/distance pairs
size_t pair_each_vector = result_ids.size() / vectors_data_.vector_count_;
offset = 0;
for (auto& request : request_list_) {
uint64_t count = request->VectorsData().vector_count_;
int64_t topk = request->TopK();
uint64_t element_cnt = count * topk;
uint64_t pair_cnt = (pair_each_vector > topk) ? topk : pair_each_vector;
TopKQueryResult& result = request->QueryResult();
result.row_num_ = count;
result.id_list_.resize(element_cnt);
result.distance_list_.resize(element_cnt);
memcpy(result.id_list_.data(), result_ids.data() + offset, element_cnt * sizeof(int64_t));
memcpy(result.distance_list_.data(), result_distances.data() + offset, element_cnt * sizeof(float));
offset += (count * search_topk_);
result.id_list_.resize(count * pair_cnt);
result.distance_list_.resize(count * pair_cnt);
for (uint64_t i = 0; i < count; i++) {
uint64_t poz = i * pair_cnt;
memcpy(result.id_list_.data() + poz, result_ids.data() + offset + poz, pair_cnt * sizeof(int64_t));
memcpy(result.distance_list_.data() + poz, result_distances.data() + offset + poz,
pair_cnt * sizeof(float));
}
offset += count * pair_cnt;
// let request return
FreeRequest(request, Status::OK());