Change the type of slice_nqs and slice_topks from int32_t[] to int64_t[] (#18867)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/18879/head
Cai Yudong 2022-08-29 11:36:56 +08:00 committed by GitHub
parent 8e22d03cf3
commit 9dc3bbecbd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 47 additions and 50 deletions

View File

@ -10,10 +10,11 @@
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <cstdint>
#include <vector>
#include <algorithm>
#include <cstdint>
#include <memory>
#include <vector>
#include "utils/Status.h"
#include "common/type_c.h"
@ -31,9 +32,13 @@ class ReduceHelper {
public:
explicit ReduceHelper(std::vector<SearchResult*>& search_results,
milvus::query::Plan* plan,
std::vector<int64_t>& slice_nqs,
std::vector<int64_t>& slice_topKs)
: search_results_(search_results), plan_(plan), slice_nqs_(slice_nqs), slice_topKs_(slice_topKs) {
int64_t* slice_nqs,
int64_t* slice_topKs,
int64_t slice_num)
: search_results_(search_results),
plan_(plan),
slice_nqs_(slice_nqs, slice_nqs + slice_num),
slice_topKs_(slice_topKs, slice_topKs + slice_num) {
Initialize();
}

View File

@ -26,9 +26,9 @@ ReduceSearchResultsAndFillData(CSearchResultDataBlobs* cSearchResultDataBlobs,
CSearchPlan c_plan,
CSearchResult* c_search_results,
int64_t num_segments,
int32_t* slice_nqs,
int32_t* slice_topKs,
int32_t num_slices) {
int64_t* slice_nqs,
int64_t* slice_topKs,
int64_t num_slices) {
try {
// get SearchResult and SearchPlan
auto plan = static_cast<milvus::query::Plan*>(c_plan);
@ -38,15 +38,7 @@ ReduceSearchResultsAndFillData(CSearchResultDataBlobs* cSearchResultDataBlobs,
search_results[i] = static_cast<SearchResult*>(c_search_results[i]);
}
// get slice_nqs and slice_topKs
auto slice_nqs_vec = std::vector<int64_t>(num_slices);
auto slice_topKs_vec = std::vector<int64_t>(num_slices);
for (int i = 0; i < num_slices; i++) {
slice_nqs_vec[i] = slice_nqs[i];
slice_topKs_vec[i] = slice_topKs[i];
}
auto reduce_helper = milvus::segcore::ReduceHelper(search_results, plan, slice_nqs_vec, slice_topKs_vec);
auto reduce_helper = milvus::segcore::ReduceHelper(search_results, plan, slice_nqs, slice_topKs, num_slices);
reduce_helper.Reduce();
reduce_helper.Marshal();

View File

@ -24,9 +24,9 @@ ReduceSearchResultsAndFillData(CSearchResultDataBlobs* cSearchResultDataBlobs,
CSearchPlan c_plan,
CSearchResult* search_results,
int64_t num_segments,
int32_t* slice_nqs,
int32_t* slice_topKs,
int32_t num_slices);
int64_t* slice_nqs,
int64_t* slice_topKs,
int64_t num_slices);
CStatus
GetSearchResultDataBlob(CProto* searchResultDataBlob,

View File

@ -1137,8 +1137,8 @@ TEST(CApiTest, ReudceNullResult) {
dataset.timestamps_.push_back(1);
{
auto slice_nqs = std::vector<int32_t>{10};
auto slice_topKs = std::vector<int32_t>{1};
auto slice_nqs = std::vector<int64_t>{10};
auto slice_topKs = std::vector<int64_t>{1};
std::vector<CSearchResult> results;
CSearchResult res;
status = Search(segment, plan, placeholderGroup, dataset.timestamps_[0], &res, -1);
@ -1214,8 +1214,8 @@ TEST(CApiTest, ReduceRemoveDuplicates) {
dataset.timestamps_.push_back(1);
{
auto slice_nqs = std::vector<int32_t>{num_queries / 2, num_queries / 2};
auto slice_topKs = std::vector<int32_t>{topK / 2, topK};
auto slice_nqs = std::vector<int64_t>{num_queries / 2, num_queries / 2};
auto slice_topKs = std::vector<int64_t>{topK / 2, topK};
std::vector<CSearchResult> results;
CSearchResult res1, res2;
status = Search(segment, plan, placeholderGroup, dataset.timestamps_[0], &res1, -1);
@ -1239,8 +1239,8 @@ TEST(CApiTest, ReduceRemoveDuplicates) {
int nq1 = num_queries / 3;
int nq2 = num_queries / 3;
int nq3 = num_queries - nq1 - nq2;
auto slice_nqs = std::vector<int32_t>{nq1, nq2, nq3};
auto slice_topKs = std::vector<int32_t>{topK / 2, topK, topK};
auto slice_nqs = std::vector<int64_t>{nq1, nq2, nq3};
auto slice_topKs = std::vector<int64_t>{topK / 2, topK, topK};
std::vector<CSearchResult> results;
CSearchResult res1, res2, res3;
status = Search(segment, plan, placeholderGroup, dataset.timestamps_[0], &res1, -1);
@ -1324,13 +1324,13 @@ testReduceSearchWithExpr(int N, int topK, int num_queries) {
results.push_back(res1);
results.push_back(res2);
auto slice_nqs = std::vector<int32_t>{num_queries / 2, num_queries / 2};
auto slice_nqs = std::vector<int64_t>{num_queries / 2, num_queries / 2};
if (num_queries == 1) {
slice_nqs = std::vector<int32_t>{num_queries};
slice_nqs = std::vector<int64_t>{num_queries};
}
auto slice_topKs = std::vector<int32_t>{topK / 2, topK};
auto slice_topKs = std::vector<int64_t>{topK / 2, topK};
if (topK == 1) {
slice_topKs = std::vector<int32_t>{topK, topK};
slice_topKs = std::vector<int64_t>{topK, topK};
}
// 1. reduce
@ -2749,8 +2749,8 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Term) {
std::vector<CSearchResult> results;
results.push_back(c_search_result_on_bigIndex);
auto slice_nqs = std::vector<int32_t>{num_queries};
auto slice_topKs = std::vector<int32_t>{topK};
auto slice_nqs = std::vector<int64_t>{num_queries};
auto slice_topKs = std::vector<int64_t>{topK};
CSearchResultDataBlobs cSearchResultData;
status = ReduceSearchResultsAndFillData(&cSearchResultData, plan, results.data(), results.size(), slice_nqs.data(),
@ -2915,8 +2915,8 @@ TEST(CApiTest, Indexing_Expr_With_binary_Predicate_Term) {
std::vector<CSearchResult> results;
results.push_back(c_search_result_on_bigIndex);
auto slice_nqs = std::vector<int32_t>{num_queries};
auto slice_topKs = std::vector<int32_t>{topK};
auto slice_nqs = std::vector<int64_t>{num_queries};
auto slice_topKs = std::vector<int64_t>{topK};
CSearchResultDataBlobs cSearchResultData;
status = ReduceSearchResultsAndFillData(&cSearchResultData, plan, results.data(), results.size(), slice_nqs.data(),

View File

@ -1632,7 +1632,7 @@ func checkSearchResult(nq int64, plan *SearchPlan, searchResult *SearchResult) e
if result.TopK != sliceTopKs[i] {
return fmt.Errorf("unexpected topK when checkSearchResult")
}
if result.NumQueries != int64(sInfo.sliceNQs[i]) {
if result.NumQueries != sInfo.sliceNQs[i] {
return fmt.Errorf("unexpected nq when checkSearchResult")
}
// search empty segment, return empty result.IDs

View File

@ -28,8 +28,8 @@ import (
)
type sliceInfo struct {
sliceNQs []int32
sliceTopKs []int32
sliceNQs []int64
sliceTopKs []int64
}
// SearchResult contains a pointer to the search result in C++ memory
@ -47,8 +47,8 @@ type RetrieveResult struct {
func parseSliceInfo(originNQs []int64, originTopKs []int64, nqPerSlice int64) *sliceInfo {
sInfo := &sliceInfo{
sliceNQs: make([]int32, 0),
sliceTopKs: make([]int32, 0),
sliceNQs: make([]int64, 0),
sliceTopKs: make([]int64, 0),
}
if nqPerSlice == 0 {
@ -57,12 +57,12 @@ func parseSliceInfo(originNQs []int64, originTopKs []int64, nqPerSlice int64) *s
for i := 0; i < len(originNQs); i++ {
for j := 0; j < int(originNQs[i]/nqPerSlice); j++ {
sInfo.sliceNQs = append(sInfo.sliceNQs, int32(nqPerSlice))
sInfo.sliceTopKs = append(sInfo.sliceTopKs, int32(originTopKs[i]))
sInfo.sliceNQs = append(sInfo.sliceNQs, nqPerSlice)
sInfo.sliceTopKs = append(sInfo.sliceTopKs, originTopKs[i])
}
if tailSliceSize := originNQs[i] % nqPerSlice; tailSliceSize > 0 {
sInfo.sliceNQs = append(sInfo.sliceNQs, int32(tailSliceSize))
sInfo.sliceTopKs = append(sInfo.sliceTopKs, int32(originTopKs[i]))
sInfo.sliceNQs = append(sInfo.sliceNQs, tailSliceSize)
sInfo.sliceTopKs = append(sInfo.sliceTopKs, originTopKs[i])
}
}
@ -70,7 +70,7 @@ func parseSliceInfo(originNQs []int64, originTopKs []int64, nqPerSlice int64) *s
}
func reduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchResult,
numSegments int64, sliceNQs []int32, sliceTopKs []int32) (searchResultDataBlobs, error) {
numSegments int64, sliceNQs []int64, sliceTopKs []int64) (searchResultDataBlobs, error) {
if plan.cSearchPlan == nil {
return nil, fmt.Errorf("nil search plan")
}
@ -92,9 +92,9 @@ func reduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchRes
}
cSearchResultPtr := (*C.CSearchResult)(&cSearchResults[0])
cNumSegments := C.int64_t(numSegments)
var cSliceNQSPtr = (*C.int32_t)(&sliceNQs[0])
var cSliceTopKSPtr = (*C.int32_t)(&sliceTopKs[0])
var cNumSlices = C.int32_t(len(sliceNQs))
var cSliceNQSPtr = (*C.int64_t)(&sliceNQs[0])
var cSliceTopKSPtr = (*C.int64_t)(&sliceTopKs[0])
var cNumSlices = C.int64_t(len(sliceNQs))
var cSearchResultDataBlobs searchResultDataBlobs
status := C.ReduceSearchResultsAndFillData(&cSearchResultDataBlobs, plan.cSearchPlan, cSearchResultPtr,
cNumSegments, cSliceNQSPtr, cSliceTopKSPtr, cNumSlices)

View File

@ -37,8 +37,8 @@ func TestReduce_parseSliceInfo(t *testing.T) {
nqPerSlice := int64(2)
sInfo := parseSliceInfo(originNQs, originTopKs, nqPerSlice)
expectedSliceNQs := []int32{2, 2, 1, 2}
expectedSliceTopKs := []int32{10, 5, 5, 20}
expectedSliceNQs := []int64{2, 2, 1, 2}
expectedSliceTopKs := []int64{10, 5, 5, 20}
assert.True(t, funcutil.SliceSetEqual(sInfo.sliceNQs, expectedSliceNQs))
assert.True(t, funcutil.SliceSetEqual(sInfo.sliceTopKs, expectedSliceTopKs))
}
@ -117,7 +117,7 @@ func TestReduce_Invalid(t *testing.T) {
assert.NoError(t, err)
searchResults := make([]*SearchResult, 0)
searchResults = append(searchResults, nil)
_, err = reduceSearchResultsAndFillData(searchReq.plan, searchResults, 1, []int32{10}, []int32{10})
_, err = reduceSearchResultsAndFillData(searchReq.plan, searchResults, 1, []int64{10}, []int64{10})
assert.Error(t, err)
})
}