mirror of https://github.com/milvus-io/milvus.git
Change the type of slice_nqs and slice_topks from int32_t[] to int64_t[] (#18867)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com> Signed-off-by: yudong.cai <yudong.cai@zilliz.com>pull/18879/head
parent
8e22d03cf3
commit
9dc3bbecbd
|
@ -10,10 +10,11 @@
|
|||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "utils/Status.h"
|
||||
#include "common/type_c.h"
|
||||
|
@ -31,9 +32,13 @@ class ReduceHelper {
|
|||
public:
|
||||
explicit ReduceHelper(std::vector<SearchResult*>& search_results,
|
||||
milvus::query::Plan* plan,
|
||||
std::vector<int64_t>& slice_nqs,
|
||||
std::vector<int64_t>& slice_topKs)
|
||||
: search_results_(search_results), plan_(plan), slice_nqs_(slice_nqs), slice_topKs_(slice_topKs) {
|
||||
int64_t* slice_nqs,
|
||||
int64_t* slice_topKs,
|
||||
int64_t slice_num)
|
||||
: search_results_(search_results),
|
||||
plan_(plan),
|
||||
slice_nqs_(slice_nqs, slice_nqs + slice_num),
|
||||
slice_topKs_(slice_topKs, slice_topKs + slice_num) {
|
||||
Initialize();
|
||||
}
|
||||
|
||||
|
|
|
@ -26,9 +26,9 @@ ReduceSearchResultsAndFillData(CSearchResultDataBlobs* cSearchResultDataBlobs,
|
|||
CSearchPlan c_plan,
|
||||
CSearchResult* c_search_results,
|
||||
int64_t num_segments,
|
||||
int32_t* slice_nqs,
|
||||
int32_t* slice_topKs,
|
||||
int32_t num_slices) {
|
||||
int64_t* slice_nqs,
|
||||
int64_t* slice_topKs,
|
||||
int64_t num_slices) {
|
||||
try {
|
||||
// get SearchResult and SearchPlan
|
||||
auto plan = static_cast<milvus::query::Plan*>(c_plan);
|
||||
|
@ -38,15 +38,7 @@ ReduceSearchResultsAndFillData(CSearchResultDataBlobs* cSearchResultDataBlobs,
|
|||
search_results[i] = static_cast<SearchResult*>(c_search_results[i]);
|
||||
}
|
||||
|
||||
// get slice_nqs and slice_topKs
|
||||
auto slice_nqs_vec = std::vector<int64_t>(num_slices);
|
||||
auto slice_topKs_vec = std::vector<int64_t>(num_slices);
|
||||
for (int i = 0; i < num_slices; i++) {
|
||||
slice_nqs_vec[i] = slice_nqs[i];
|
||||
slice_topKs_vec[i] = slice_topKs[i];
|
||||
}
|
||||
|
||||
auto reduce_helper = milvus::segcore::ReduceHelper(search_results, plan, slice_nqs_vec, slice_topKs_vec);
|
||||
auto reduce_helper = milvus::segcore::ReduceHelper(search_results, plan, slice_nqs, slice_topKs, num_slices);
|
||||
reduce_helper.Reduce();
|
||||
reduce_helper.Marshal();
|
||||
|
||||
|
|
|
@ -24,9 +24,9 @@ ReduceSearchResultsAndFillData(CSearchResultDataBlobs* cSearchResultDataBlobs,
|
|||
CSearchPlan c_plan,
|
||||
CSearchResult* search_results,
|
||||
int64_t num_segments,
|
||||
int32_t* slice_nqs,
|
||||
int32_t* slice_topKs,
|
||||
int32_t num_slices);
|
||||
int64_t* slice_nqs,
|
||||
int64_t* slice_topKs,
|
||||
int64_t num_slices);
|
||||
|
||||
CStatus
|
||||
GetSearchResultDataBlob(CProto* searchResultDataBlob,
|
||||
|
|
|
@ -1137,8 +1137,8 @@ TEST(CApiTest, ReudceNullResult) {
|
|||
dataset.timestamps_.push_back(1);
|
||||
|
||||
{
|
||||
auto slice_nqs = std::vector<int32_t>{10};
|
||||
auto slice_topKs = std::vector<int32_t>{1};
|
||||
auto slice_nqs = std::vector<int64_t>{10};
|
||||
auto slice_topKs = std::vector<int64_t>{1};
|
||||
std::vector<CSearchResult> results;
|
||||
CSearchResult res;
|
||||
status = Search(segment, plan, placeholderGroup, dataset.timestamps_[0], &res, -1);
|
||||
|
@ -1214,8 +1214,8 @@ TEST(CApiTest, ReduceRemoveDuplicates) {
|
|||
dataset.timestamps_.push_back(1);
|
||||
|
||||
{
|
||||
auto slice_nqs = std::vector<int32_t>{num_queries / 2, num_queries / 2};
|
||||
auto slice_topKs = std::vector<int32_t>{topK / 2, topK};
|
||||
auto slice_nqs = std::vector<int64_t>{num_queries / 2, num_queries / 2};
|
||||
auto slice_topKs = std::vector<int64_t>{topK / 2, topK};
|
||||
std::vector<CSearchResult> results;
|
||||
CSearchResult res1, res2;
|
||||
status = Search(segment, plan, placeholderGroup, dataset.timestamps_[0], &res1, -1);
|
||||
|
@ -1239,8 +1239,8 @@ TEST(CApiTest, ReduceRemoveDuplicates) {
|
|||
int nq1 = num_queries / 3;
|
||||
int nq2 = num_queries / 3;
|
||||
int nq3 = num_queries - nq1 - nq2;
|
||||
auto slice_nqs = std::vector<int32_t>{nq1, nq2, nq3};
|
||||
auto slice_topKs = std::vector<int32_t>{topK / 2, topK, topK};
|
||||
auto slice_nqs = std::vector<int64_t>{nq1, nq2, nq3};
|
||||
auto slice_topKs = std::vector<int64_t>{topK / 2, topK, topK};
|
||||
std::vector<CSearchResult> results;
|
||||
CSearchResult res1, res2, res3;
|
||||
status = Search(segment, plan, placeholderGroup, dataset.timestamps_[0], &res1, -1);
|
||||
|
@ -1324,13 +1324,13 @@ testReduceSearchWithExpr(int N, int topK, int num_queries) {
|
|||
results.push_back(res1);
|
||||
results.push_back(res2);
|
||||
|
||||
auto slice_nqs = std::vector<int32_t>{num_queries / 2, num_queries / 2};
|
||||
auto slice_nqs = std::vector<int64_t>{num_queries / 2, num_queries / 2};
|
||||
if (num_queries == 1) {
|
||||
slice_nqs = std::vector<int32_t>{num_queries};
|
||||
slice_nqs = std::vector<int64_t>{num_queries};
|
||||
}
|
||||
auto slice_topKs = std::vector<int32_t>{topK / 2, topK};
|
||||
auto slice_topKs = std::vector<int64_t>{topK / 2, topK};
|
||||
if (topK == 1) {
|
||||
slice_topKs = std::vector<int32_t>{topK, topK};
|
||||
slice_topKs = std::vector<int64_t>{topK, topK};
|
||||
}
|
||||
|
||||
// 1. reduce
|
||||
|
@ -2749,8 +2749,8 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Term) {
|
|||
std::vector<CSearchResult> results;
|
||||
results.push_back(c_search_result_on_bigIndex);
|
||||
|
||||
auto slice_nqs = std::vector<int32_t>{num_queries};
|
||||
auto slice_topKs = std::vector<int32_t>{topK};
|
||||
auto slice_nqs = std::vector<int64_t>{num_queries};
|
||||
auto slice_topKs = std::vector<int64_t>{topK};
|
||||
|
||||
CSearchResultDataBlobs cSearchResultData;
|
||||
status = ReduceSearchResultsAndFillData(&cSearchResultData, plan, results.data(), results.size(), slice_nqs.data(),
|
||||
|
@ -2915,8 +2915,8 @@ TEST(CApiTest, Indexing_Expr_With_binary_Predicate_Term) {
|
|||
std::vector<CSearchResult> results;
|
||||
results.push_back(c_search_result_on_bigIndex);
|
||||
|
||||
auto slice_nqs = std::vector<int32_t>{num_queries};
|
||||
auto slice_topKs = std::vector<int32_t>{topK};
|
||||
auto slice_nqs = std::vector<int64_t>{num_queries};
|
||||
auto slice_topKs = std::vector<int64_t>{topK};
|
||||
|
||||
CSearchResultDataBlobs cSearchResultData;
|
||||
status = ReduceSearchResultsAndFillData(&cSearchResultData, plan, results.data(), results.size(), slice_nqs.data(),
|
||||
|
|
|
@ -1632,7 +1632,7 @@ func checkSearchResult(nq int64, plan *SearchPlan, searchResult *SearchResult) e
|
|||
if result.TopK != sliceTopKs[i] {
|
||||
return fmt.Errorf("unexpected topK when checkSearchResult")
|
||||
}
|
||||
if result.NumQueries != int64(sInfo.sliceNQs[i]) {
|
||||
if result.NumQueries != sInfo.sliceNQs[i] {
|
||||
return fmt.Errorf("unexpected nq when checkSearchResult")
|
||||
}
|
||||
// search empty segment, return empty result.IDs
|
||||
|
|
|
@ -28,8 +28,8 @@ import (
|
|||
)
|
||||
|
||||
type sliceInfo struct {
|
||||
sliceNQs []int32
|
||||
sliceTopKs []int32
|
||||
sliceNQs []int64
|
||||
sliceTopKs []int64
|
||||
}
|
||||
|
||||
// SearchResult contains a pointer to the search result in C++ memory
|
||||
|
@ -47,8 +47,8 @@ type RetrieveResult struct {
|
|||
|
||||
func parseSliceInfo(originNQs []int64, originTopKs []int64, nqPerSlice int64) *sliceInfo {
|
||||
sInfo := &sliceInfo{
|
||||
sliceNQs: make([]int32, 0),
|
||||
sliceTopKs: make([]int32, 0),
|
||||
sliceNQs: make([]int64, 0),
|
||||
sliceTopKs: make([]int64, 0),
|
||||
}
|
||||
|
||||
if nqPerSlice == 0 {
|
||||
|
@ -57,12 +57,12 @@ func parseSliceInfo(originNQs []int64, originTopKs []int64, nqPerSlice int64) *s
|
|||
|
||||
for i := 0; i < len(originNQs); i++ {
|
||||
for j := 0; j < int(originNQs[i]/nqPerSlice); j++ {
|
||||
sInfo.sliceNQs = append(sInfo.sliceNQs, int32(nqPerSlice))
|
||||
sInfo.sliceTopKs = append(sInfo.sliceTopKs, int32(originTopKs[i]))
|
||||
sInfo.sliceNQs = append(sInfo.sliceNQs, nqPerSlice)
|
||||
sInfo.sliceTopKs = append(sInfo.sliceTopKs, originTopKs[i])
|
||||
}
|
||||
if tailSliceSize := originNQs[i] % nqPerSlice; tailSliceSize > 0 {
|
||||
sInfo.sliceNQs = append(sInfo.sliceNQs, int32(tailSliceSize))
|
||||
sInfo.sliceTopKs = append(sInfo.sliceTopKs, int32(originTopKs[i]))
|
||||
sInfo.sliceNQs = append(sInfo.sliceNQs, tailSliceSize)
|
||||
sInfo.sliceTopKs = append(sInfo.sliceTopKs, originTopKs[i])
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -70,7 +70,7 @@ func parseSliceInfo(originNQs []int64, originTopKs []int64, nqPerSlice int64) *s
|
|||
}
|
||||
|
||||
func reduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchResult,
|
||||
numSegments int64, sliceNQs []int32, sliceTopKs []int32) (searchResultDataBlobs, error) {
|
||||
numSegments int64, sliceNQs []int64, sliceTopKs []int64) (searchResultDataBlobs, error) {
|
||||
if plan.cSearchPlan == nil {
|
||||
return nil, fmt.Errorf("nil search plan")
|
||||
}
|
||||
|
@ -92,9 +92,9 @@ func reduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchRes
|
|||
}
|
||||
cSearchResultPtr := (*C.CSearchResult)(&cSearchResults[0])
|
||||
cNumSegments := C.int64_t(numSegments)
|
||||
var cSliceNQSPtr = (*C.int32_t)(&sliceNQs[0])
|
||||
var cSliceTopKSPtr = (*C.int32_t)(&sliceTopKs[0])
|
||||
var cNumSlices = C.int32_t(len(sliceNQs))
|
||||
var cSliceNQSPtr = (*C.int64_t)(&sliceNQs[0])
|
||||
var cSliceTopKSPtr = (*C.int64_t)(&sliceTopKs[0])
|
||||
var cNumSlices = C.int64_t(len(sliceNQs))
|
||||
var cSearchResultDataBlobs searchResultDataBlobs
|
||||
status := C.ReduceSearchResultsAndFillData(&cSearchResultDataBlobs, plan.cSearchPlan, cSearchResultPtr,
|
||||
cNumSegments, cSliceNQSPtr, cSliceTopKSPtr, cNumSlices)
|
||||
|
|
|
@ -37,8 +37,8 @@ func TestReduce_parseSliceInfo(t *testing.T) {
|
|||
nqPerSlice := int64(2)
|
||||
sInfo := parseSliceInfo(originNQs, originTopKs, nqPerSlice)
|
||||
|
||||
expectedSliceNQs := []int32{2, 2, 1, 2}
|
||||
expectedSliceTopKs := []int32{10, 5, 5, 20}
|
||||
expectedSliceNQs := []int64{2, 2, 1, 2}
|
||||
expectedSliceTopKs := []int64{10, 5, 5, 20}
|
||||
assert.True(t, funcutil.SliceSetEqual(sInfo.sliceNQs, expectedSliceNQs))
|
||||
assert.True(t, funcutil.SliceSetEqual(sInfo.sliceTopKs, expectedSliceTopKs))
|
||||
}
|
||||
|
@ -117,7 +117,7 @@ func TestReduce_Invalid(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
searchResults := make([]*SearchResult, 0)
|
||||
searchResults = append(searchResults, nil)
|
||||
_, err = reduceSearchResultsAndFillData(searchReq.plan, searchResults, 1, []int32{10}, []int32{10})
|
||||
_, err = reduceSearchResultsAndFillData(searchReq.plan, searchResults, 1, []int64{10}, []int64{10})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue