From 765907ab775589effe2d0d465a4e61b4ce79b2ca Mon Sep 17 00:00:00 2001
From: Cai Yudong <yudong.cai@zilliz.com>
Date: Tue, 6 Sep 2022 10:55:12 +0800
Subject: [PATCH] Optimize segcore Reduce (#18902)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
---
 internal/core/src/segcore/Reduce.cpp  | 335 ++++++++++++--------------
 internal/core/src/segcore/Reduce.h    |  25 +-
 internal/core/unittest/test_c_api.cpp |  21 +-
 3 files changed, 173 insertions(+), 208 deletions(-)

diff --git a/internal/core/src/segcore/Reduce.cpp b/internal/core/src/segcore/Reduce.cpp
index d0a6850361..e030632de8 100644
--- a/internal/core/src/segcore/Reduce.cpp
+++ b/internal/core/src/segcore/Reduce.cpp
@@ -28,127 +28,40 @@ ReduceHelper::Initialize() {
     AssertInfo(slice_nqs_.size() > 0, "empty slice_nqs");
     AssertInfo(slice_nqs_.size() == slice_topKs_.size(), "unaligned slice_nqs and slice_topKs");
 
-    unify_topK_ = search_results_[0]->unity_topK_;
     total_nq_ = search_results_[0]->total_nq_;
     num_segments_ = search_results_.size();
     num_slices_ = slice_nqs_.size();
 
     // prefix sum, get slices offsets
     AssertInfo(num_slices_ > 0, "empty slice_nqs is not allowed");
-    auto slice_offsets_size = num_slices_ + 1;
-    nq_slice_offsets_ = std::vector<int32_t>(slice_offsets_size);
-
-    for (int i = 1; i < slice_offsets_size; i++) {
-        nq_slice_offsets_[i] = nq_slice_offsets_[i - 1] + slice_nqs_[i - 1];
-        for (auto j = nq_slice_offsets_[i - 1]; j < nq_slice_offsets_[i]; j++) {
-        }
-    }
-    AssertInfo(nq_slice_offsets_[num_slices_] == total_nq_,
-               "illegal req sizes"
-               ", nq_slice_offsets[last] = " +
-                   std::to_string(nq_slice_offsets_[num_slices_]) + ", total_nq = " + std::to_string(total_nq_));
+    slice_nqs_prefix_sum_.resize(num_slices_ + 1);
+    std::partial_sum(slice_nqs_.begin(), slice_nqs_.end(), slice_nqs_prefix_sum_.begin() + 1);
+    AssertInfo(slice_nqs_prefix_sum_[num_slices_] == total_nq_, "illegal req sizes, slice_nqs_prefix_sum_[last] = " +
+                                                                    std::to_string(slice_nqs_prefix_sum_[num_slices_]) +
+                                                                    ", total_nq = " + std::to_string(total_nq_));
 
     // init final_search_records and final_read_topKs
-    final_search_records_ = std::vector<std::vector<int64_t>>(num_segments_);
-    final_real_topKs_ = std::vector<std::vector<int64_t>>(num_segments_);
-    for (auto& topKs : final_real_topKs_) {
-        // `topKs` records real topK of each query
-        topKs.resize(total_nq_);
+    final_search_records_.resize(num_segments_);
+    for (auto& search_record : final_search_records_) {
+        search_record.resize(total_nq_);
     }
 }
 
 void
 ReduceHelper::Reduce() {
-    std::vector<SearchResult*> valid_search_results;
-    // get primary keys for duplicates removal
-    for (auto search_result : search_results_) {
-        FilterInvalidSearchResult(search_result);
-        if (search_result->get_total_result_count() > 0) {
-            auto segment = static_cast<SegmentInterface*>(search_result->segment_);
-            segment->FillPrimaryKeys(plan_, *search_result);
-            valid_search_results.emplace_back(search_result);
-        }
-    }
-    search_results_ = valid_search_results;
-    num_segments_ = search_results_.size();
-    if (valid_search_results.size() == 0) {
-        // TODO: return empty search result?
-        return;
-    }
-
-    for (int i = 0; i < num_slices_; i++) {
-        // ReduceResultData for each slice
-        ReduceResultData(i);
-    }
-    // after reduce, remove invalid primary_keys, distances and ids by `final_search_records`
-    for (int i = 0; i < num_segments_; i++) {
-        auto search_result = search_results_[i];
-        if (search_result->result_offsets_.size() != 0) {
-            std::vector<milvus::PkType> primary_keys;
-            std::vector<float> distances;
-            std::vector<int64_t> seg_offsets;
-            for (int j = 0; j < final_search_records_[i].size(); j++) {
-                auto& offset = final_search_records_[i][j];
-                primary_keys.push_back(search_result->primary_keys_[offset]);
-                distances.push_back(search_result->distances_[offset]);
-                seg_offsets.push_back(search_result->seg_offsets_[offset]);
-            }
-
-            search_result->primary_keys_ = std::move(primary_keys);
-            search_result->distances_ = std::move(distances);
-            search_result->seg_offsets_ = std::move(seg_offsets);
-        }
-        search_result->topk_per_nq_prefix_sum_.resize(final_real_topKs_[i].size() + 1);
-        std::partial_sum(final_real_topKs_[i].begin(), final_real_topKs_[i].end(),
-                         search_result->topk_per_nq_prefix_sum_.begin() + 1);
-    }
-
-    // fill target entry
-    for (auto& search_result : search_results_) {
-        auto segment = static_cast<milvus::segcore::SegmentInterface*>(search_result->segment_);
-        segment->FillTargetEntry(plan_, *search_result);
-    }
+    FillPrimaryKey();
+    ReduceResultData();
+    RefreshSearchResult();
+    FillEntryData();
 }
 
 void
 ReduceHelper::Marshal() {
-    // example:
-    //  ----------------------------------
-    //           nq0     nq1     nq2
-    //   sr0    topk00  topk01  topk02
-    //   sr1    topk10  topk11  topk12
-    //  ----------------------------------
-    // then:
-    // result_slice_offsets[] = {
-    //      0,
-    //                                      == sr0->topk_per_nq_prefix_sum_[0] + sr1->topk_per_nq_prefix_sum_[0]
-    //      ((topk00) + (topk10)),
-    //                                      == sr0->topk_per_nq_prefix_sum_[1] + sr1->topk_per_nq_prefix_sum_[1]
-    //      ((topk00 + topk01) + (topk10 + topk11)),
-    //                                      == sr0->topk_per_nq_prefix_sum_[2] + sr1->topk_per_nq_prefix_sum_[2]
-    //      ((topk00 + topk01 + topk02) + (topk10 + topk11 + topk12)),
-    //                                      == sr0->topk_per_nq_prefix_sum_[3] + sr1->topk_per_nq_prefix_sum_[3]
-    // }
-    auto result_slice_offsets = std::vector<int64_t>(nq_slice_offsets_.size(), 0);
-    for (auto search_result : search_results_) {
-        AssertInfo(search_result->topk_per_nq_prefix_sum_.size() == search_result->total_nq_ + 1,
-                   "incorrect topk_per_nq_prefix_sum_ size in search result");
-        for (int i = 1; i < nq_slice_offsets_.size(); i++) {
-            result_slice_offsets[i] += search_result->topk_per_nq_prefix_sum_[nq_slice_offsets_[i]];
-        }
-    }
-    AssertInfo(result_slice_offsets[num_slices_] <= total_nq_ * unify_topK_,
-               "illegal result_slice_offsets when Marshal, result_slice_offsets[last] = " +
-                   std::to_string(result_slice_offsets[num_slices_]) + ", total_nq = " + std::to_string(total_nq_) +
-                   ", unify_topK = " + std::to_string(unify_topK_));
-
     // get search result data blobs of slices
     search_result_data_blobs_ = std::make_unique<milvus::segcore::SearchResultDataBlobs>();
     search_result_data_blobs_->blobs.resize(num_slices_);
-    //#pragma omp parallel for
     for (int i = 0; i < num_slices_; i++) {
-        auto result_count = result_slice_offsets[i + 1] - result_slice_offsets[i];
-        auto proto = GetSearchResultDataSlice(i, result_count);
+        auto proto = GetSearchResultDataSlice(i);
         search_result_data_blobs_->blobs[i] = proto;
     }
 }
@@ -178,102 +91,152 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
         }
     }
 
-    search_result->distances_ = std::move(distances);
-    search_result->seg_offsets_ = std::move(seg_offsets);
+    search_result->distances_.swap(distances);
+    search_result->seg_offsets_.swap(seg_offsets);
     search_result->topk_per_nq_prefix_sum_.resize(nq + 1);
     std::partial_sum(real_topks.begin(), real_topks.end(), search_result->topk_per_nq_prefix_sum_.begin() + 1);
 }
 
 void
-ReduceHelper::ReduceResultData(int slice_index) {
+ReduceHelper::FillPrimaryKey() {
+    std::vector<SearchResult*> valid_search_results;
+    // get primary keys for duplicates removal
+    for (auto search_result : search_results_) {
+        FilterInvalidSearchResult(search_result);
+        if (search_result->get_total_result_count() > 0) {
+            auto segment = static_cast<SegmentInterface*>(search_result->segment_);
+            segment->FillPrimaryKeys(plan_, *search_result);
+            valid_search_results.emplace_back(search_result);
+        }
+    }
+    search_results_.swap(valid_search_results);
+    num_segments_ = search_results_.size();
+}
+
+void
+ReduceHelper::RefreshSearchResult() {
+    for (int i = 0; i < num_segments_; i++) {
+        std::vector<int64_t> real_topks(total_nq_, 0);
+        auto search_result = search_results_[i];
+        if (search_result->result_offsets_.size() != 0) {
+            std::vector<milvus::PkType> primary_keys;
+            std::vector<float> distances;
+            std::vector<int64_t> seg_offsets;
+            for (int j = 0; j < total_nq_; j++) {
+                for (auto offset : final_search_records_[i][j]) {
+                    primary_keys.push_back(search_result->primary_keys_[offset]);
+                    distances.push_back(search_result->distances_[offset]);
+                    seg_offsets.push_back(search_result->seg_offsets_[offset]);
+                    real_topks[j]++;
+                }
+            }
+            search_result->primary_keys_ = std::move(primary_keys);
+            search_result->distances_ = std::move(distances);
+            search_result->seg_offsets_ = std::move(seg_offsets);
+        }
+        std::partial_sum(real_topks.begin(), real_topks.end(), search_result->topk_per_nq_prefix_sum_.begin() + 1);
+    }
+}
+
+void
+ReduceHelper::FillEntryData() {
+    for (auto search_result : search_results_) {
+        auto segment = static_cast<milvus::segcore::SegmentInterface*>(search_result->segment_);
+        segment->FillTargetEntry(plan_, *search_result);
+    }
+}
+
+int64_t
+ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& offset) {
+    std::vector<SearchResultPair> result_pairs;
+    for (int i = 0; i < num_segments_; i++) {
+        auto search_result = search_results_[i];
+        auto offset_beg = search_result->topk_per_nq_prefix_sum_[qi];
+        auto offset_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
+        if (offset_beg == offset_end) {
+            continue;
+        }
+        auto primary_key = search_result->primary_keys_[offset_beg];
+        auto distance = search_result->distances_[offset_beg];
+        result_pairs.emplace_back(primary_key, distance, search_result, i, offset_beg, offset_end);
+    }
+
+    // nq has no results for all segments
+    if (result_pairs.size() == 0) {
+        return 0;
+    }
+
+    int64_t dup_cnt = 0;
+    std::unordered_set<milvus::PkType> pk_set;
+    int64_t prev_offset = offset;
+    while (offset - prev_offset < topk) {
+        std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
+        auto& pilot = result_pairs[0];
+        auto index = pilot.segment_index_;
+        auto pk = pilot.primary_key_;
+        // no valid search result for this nq, break to next
+        if (pk == INVALID_PK) {
+            break;
+        }
+        // remove duplicates
+        if (pk_set.count(pk) == 0) {
+            pilot.search_result_->result_offsets_.push_back(offset++);
+            final_search_records_[index][qi].push_back(pilot.offset_);
+            pk_set.insert(pk);
+        } else {
+            // skip entity with same primary key
+            dup_cnt++;
+        }
+        pilot.reset();
+    }
+    return dup_cnt;
+}
+
+void
+ReduceHelper::ReduceResultData() {
     for (int i = 0; i < num_segments_; i++) {
         auto search_result = search_results_[i];
         auto result_count = search_result->get_total_result_count();
         AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
-        AssertInfo(search_result->primary_keys_.size() == result_count, "incorrect search result primary key size");
         AssertInfo(search_result->distances_.size() == result_count, "incorrect search result distance size");
+        AssertInfo(search_result->seg_offsets_.size() == result_count, "incorrect search result seg offset size");
+        AssertInfo(search_result->primary_keys_.size() == result_count, "incorrect search result primary key size");
     }
 
-    auto nq_offset_begin = nq_slice_offsets_[slice_index];
-    auto nq_offset_end = nq_slice_offsets_[slice_index + 1];
-    AssertInfo(nq_offset_begin < nq_offset_end,
-               "illegal nq offsets when ReduceResultData, nq_offset_begin = " + std::to_string(nq_offset_begin) +
-                   ", nq_offset_end = " + std::to_string(nq_offset_end));
-
-    // `search_records` records the search result offsets
-    std::vector<std::vector<int64_t>> search_records(num_segments_);
     int64_t skip_dup_cnt = 0;
+    for (int64_t slice_index = 0; slice_index < num_slices_; slice_index++) {
+        auto nq_begin = slice_nqs_prefix_sum_[slice_index];
+        auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];
 
-    // reduce search results
-    int64_t result_offset = 0;
-    for (int64_t qi = nq_offset_begin; qi < nq_offset_end; qi++) {
-        std::vector<SearchResultPair> result_pairs;
-        for (int i = 0; i < num_segments_; i++) {
-            auto search_result = search_results_[i];
-            if (search_result->topk_per_nq_prefix_sum_[qi + 1] - search_result->topk_per_nq_prefix_sum_[qi] == 0) {
-                continue;
-            }
-            auto base_offset = search_result->topk_per_nq_prefix_sum_[qi];
-            auto primary_key = search_result->primary_keys_[base_offset];
-            auto distance = search_result->distances_[base_offset];
-            result_pairs.emplace_back(primary_key, distance, search_result, i, base_offset,
-                                      search_result->topk_per_nq_prefix_sum_[qi + 1]);
-        }
-
-        // nq has no results for all segments
-        if (result_pairs.size() == 0) {
-            continue;
-        }
-        std::unordered_set<milvus::PkType> pk_set;
-        int64_t last_nq_result_offset = result_offset;
-        while (result_offset - last_nq_result_offset < slice_topKs_[slice_index]) {
-            std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
-            auto& pilot = result_pairs[0];
-            auto index = pilot.segment_index_;
-            auto curr_pk = pilot.primary_key_;
-            // no valid search result for this nq, break to next
-            if (curr_pk == INVALID_PK) {
-                break;
-            }
-            // remove duplicates
-            if (pk_set.count(curr_pk) == 0) {
-                pilot.search_result_->result_offsets_.push_back(result_offset++);
-                search_records[index].push_back(pilot.offset_);
-                pk_set.insert(curr_pk);
-                final_real_topKs_[index][qi]++;
-            } else {
-                // skip entity with same primary key
-                skip_dup_cnt++;
-            }
-            pilot.reset();
+        // reduce search results
+        int64_t result_offset = 0;
+        for (int64_t qi = nq_begin; qi < nq_end; qi++) {
+            skip_dup_cnt += ReduceSearchResultForOneNQ(qi, slice_topKs_[slice_index], result_offset);
         }
     }
-
     if (skip_dup_cnt > 0) {
         LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = " << skip_dup_cnt;
     }
-
-    // append search_records to final_search_records
-    for (int i = 0; i < num_segments_; i++) {
-        for (int j = 0; j < search_records[i].size(); j++) {
-            final_search_records_[i].emplace_back(search_records[i][j]);
-        }
-    }
 }
 
 std::vector<char>
-ReduceHelper::GetSearchResultDataSlice(int slice_index_, int64_t result_count) {
-    auto nq_offset_begin = nq_slice_offsets_[slice_index_];
-    auto nq_offset_end = nq_slice_offsets_[slice_index_ + 1];
-    AssertInfo(nq_offset_begin <= nq_offset_end,
-               "illegal offsets when GetSearchResultDataSlice, nq_offset_begin = " + std::to_string(nq_offset_begin) +
-                   ", nq_offset_end = " + std::to_string(nq_offset_end));
+ReduceHelper::GetSearchResultDataSlice(int slice_index) {
+    auto nq_begin = slice_nqs_prefix_sum_[slice_index];
+    auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];
+
+    int64_t result_count = 0;
+    for (auto search_result : search_results_) {
+        AssertInfo(search_result->topk_per_nq_prefix_sum_.size() == search_result->total_nq_ + 1,
+                   "incorrect topk_per_nq_prefix_sum_ size in search result");
+        result_count +=
+            search_result->topk_per_nq_prefix_sum_[nq_end] - search_result->topk_per_nq_prefix_sum_[nq_begin];
+    }
 
     auto search_result_data = std::make_unique<milvus::proto::schema::SearchResultData>();
     // set unify_topK and total_nq
-    search_result_data->set_top_k(slice_topKs_[slice_index_]);
-    search_result_data->set_num_queries(nq_offset_end - nq_offset_begin);
-    search_result_data->mutable_topks()->Resize(nq_offset_end - nq_offset_begin, 0);
+    search_result_data->set_top_k(slice_topKs_[slice_index]);
+    search_result_data->set_num_queries(nq_end - nq_begin);
+    search_result_data->mutable_topks()->Resize(nq_end - nq_begin, 0);
 
     // `result_pairs` contains the SearchResult and result_offset info, used for filling output fields
     std::vector<std::pair<SearchResult*, int64_t>> result_pairs(result_count);
@@ -306,19 +269,20 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index_, int64_t result_count) {
     search_result_data->mutable_scores()->Resize(result_count, 0);
 
     // fill pks and distances
-    for (auto nq_offset = nq_offset_begin; nq_offset < nq_offset_end; nq_offset++) {
-        int64_t topK_count = 0;
-        for (int i = 0; i < search_results_.size(); i++) {
-            auto search_result = search_results_[i];
+    for (auto qi = nq_begin; qi < nq_end; qi++) {
+        int64_t topk_count = 0;
+        for (auto search_result : search_results_) {
             AssertInfo(search_result != nullptr, "null search result when reorganize");
             if (search_result->result_offsets_.size() == 0) {
                 continue;
             }
 
-            auto result_start = search_result->topk_per_nq_prefix_sum_[nq_offset];
-            auto result_end = search_result->topk_per_nq_prefix_sum_[nq_offset + 1];
-            for (auto offset = result_start; offset < result_end; offset++) {
-                auto loc = search_result->result_offsets_[offset];
+            auto topk_start = search_result->topk_per_nq_prefix_sum_[qi];
+            auto topk_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
+            topk_count += topk_end - topk_start;
+
+            for (auto ki = topk_start; ki < topk_end; ki++) {
+                auto loc = search_result->result_offsets_[ki];
                 AssertInfo(loc < result_count && loc >= 0,
                            "invalid loc when GetSearchResultDataSlice, loc = " + std::to_string(loc) +
                                ", result_count = " + std::to_string(result_count));
@@ -326,12 +290,12 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index_, int64_t result_count) {
                 switch (pk_type) {
                     case milvus::DataType::INT64: {
                         search_result_data->mutable_ids()->mutable_int_id()->mutable_data()->Set(
-                            loc, std::visit(Int64PKVisitor{}, search_result->primary_keys_[offset]));
+                            loc, std::visit(Int64PKVisitor{}, search_result->primary_keys_[ki]));
                         break;
                     }
                     case milvus::DataType::VARCHAR: {
                         *search_result_data->mutable_ids()->mutable_str_id()->mutable_data()->Mutable(loc) =
-                            std::visit(StrPKVisitor{}, search_result->primary_keys_[offset]);
+                            std::visit(StrPKVisitor{}, search_result->primary_keys_[ki]);
                         break;
                     }
                     default: {
@@ -340,17 +304,14 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index_, int64_t result_count) {
                 }
 
                 // set result distances
-                search_result_data->mutable_scores()->Set(loc, search_result->distances_[offset]);
+                search_result_data->mutable_scores()->Set(loc, search_result->distances_[ki]);
                 // set result offset to fill output fields data
-                result_pairs[loc] = std::make_pair(search_result, offset);
+                result_pairs[loc] = std::make_pair(search_result, ki);
             }
-
-            topK_count += search_result->topk_per_nq_prefix_sum_[nq_offset + 1] -
-                          search_result->topk_per_nq_prefix_sum_[nq_offset];
         }
 
         // update result topKs
-        search_result_data->mutable_topks()->Set(nq_offset - nq_offset_begin, topK_count);
+        search_result_data->mutable_topks()->Set(qi - nq_begin, topk_count);
     }
 
     AssertInfo(search_result_data->scores_size() == result_count,
diff --git a/internal/core/src/segcore/Reduce.h b/internal/core/src/segcore/Reduce.h
index fee85e25e4..9147731de7 100644
--- a/internal/core/src/segcore/Reduce.h
+++ b/internal/core/src/segcore/Reduce.h
@@ -61,15 +61,26 @@ class ReduceHelper {
     FilterInvalidSearchResult(SearchResult* search_result);
 
     void
-    ReduceResultData(int slice_index);
+    FillPrimaryKey();
+
+    void
+    RefreshSearchResult();
+
+    void
+    FillEntryData();
+
+    int64_t
+    ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& result_offset);
+
+    void
+    ReduceResultData();
 
     std::vector<char>
-    GetSearchResultDataSlice(int slice_index_, int64_t result_count);
+    GetSearchResultDataSlice(int slice_index_);
 
  private:
     std::vector<int64_t> slice_topKs_;
     std::vector<int64_t> slice_nqs_;
-    int64_t unify_topK_;
     int64_t total_nq_;
     int64_t num_segments_;
     int64_t num_slices_;
@@ -77,10 +88,10 @@ class ReduceHelper {
     milvus::query::Plan* plan_;
     std::vector<SearchResult*>& search_results_;
 
-    //
-    std::vector<int32_t> nq_slice_offsets_;
-    std::vector<std::vector<int64_t>> final_search_records_;
-    std::vector<std::vector<int64_t>> final_real_topKs_;
+    std::vector<int64_t> slice_nqs_prefix_sum_;
+
+    // dim0: num_segments_; dim1: total_nq_; dim2: offset
+    std::vector<std::vector<std::vector<int64_t>>> final_search_records_;
 
     // output
     std::unique_ptr<SearchResultDataBlobs> search_result_data_blobs_;
diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp
index 363423bc3f..55ba7fba86 100644
--- a/internal/core/unittest/test_c_api.cpp
+++ b/internal/core/unittest/test_c_api.cpp
@@ -1347,25 +1347,16 @@ testReduceSearchWithExpr(int N, int topK, int num_queries) {
         auto suc = search_result_data.ParseFromArray(search_result_data_blobs->blobs[i].data(),
                                                      search_result_data_blobs->blobs[i].size());
         assert(suc);
-
-        assert(suc);
         assert(search_result_data.num_queries() == slice_nqs[i]);
         assert(search_result_data.top_k() == slice_topKs[i]);
-        assert(search_result_data.scores().size() == slice_topKs[i] * slice_nqs[i]);
-        assert(search_result_data.ids().int_id().data_size() == slice_topKs[i] * slice_nqs[i]);
+        assert(search_result_data.scores().size() == search_result_data.topks().at(0) * slice_nqs[i]);
+        assert(search_result_data.ids().int_id().data_size() == search_result_data.topks().at(0) * slice_nqs[i]);
 
-        // check topKs
+        // check real topks
         assert(search_result_data.topks().size() == slice_nqs[i]);
-        for (int j = 0; j < search_result_data.topks().size(); j++) {
-            assert(search_result_data.topks().at(j) == slice_topKs[i]);
+        for (auto real_topk : search_result_data.topks()) {
+            assert(real_topk <= slice_topKs[i]);
         }
-
-        // assert(search_result_data.scores().size() == slice_topKs[i] * slice_nqs[i]);
-        // assert(search_result_data.ids().int_id().data_size() == slice_topKs[i] * slice_nqs[i]);
-        // assert(search_result_data.top_k() == topK);
-        // assert(search_result_data.num_queries() == req_sizes[i]);
-        // assert(search_result_data.scores().size() == topK * req_sizes[i]);
-        // assert(search_result_data.ids().int_id().data_size() == topK * req_sizes[i]);
     }
 
     DeleteSearchResultDataBlobs(cSearchResultData);
@@ -1378,6 +1369,8 @@ testReduceSearchWithExpr(int N, int topK, int num_queries) {
 }
 
 TEST(CApiTest, ReduceSearchWithExpr) {
+    testReduceSearchWithExpr(2, 1, 1);
+    testReduceSearchWithExpr(2, 10, 10);
     testReduceSearchWithExpr(100, 1, 1);
     testReduceSearchWithExpr(100, 10, 10);
     testReduceSearchWithExpr(10000, 1, 1);