mirror of https://github.com/milvus-io/milvus.git

Support return primary key when doing search

Signed-off-by: xige-16 <xi.ge@zilliz.com>

Branch: pull/4973/head^2
parent 0530fdf62f
commit 36cf8a8ea7
@@ -50,89 +50,167 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) {
 }
 
 struct SearchResultPair {
-    uint64_t id_;
     float distance_;
-    int64_t segment_id_;
+    SearchResult* search_result_;
+    int64_t offset_;
+    int64_t index_;
 
-    SearchResultPair(uint64_t id, float distance, int64_t segment_id)
-        : id_(id), distance_(distance), segment_id_(segment_id) {
+    SearchResultPair(float distance, SearchResult* search_result, int64_t offset, int64_t index)
+        : distance_(distance), search_result_(search_result), offset_(offset), index_(index) {
     }
 
     bool
     operator<(const SearchResultPair& pair) const {
        return (distance_ < pair.distance_);
     }
+
+    void
+    reset_distance() {
+        distance_ = search_result_->result_distances_[offset_];
+    }
 };
 
 void
-GetResultData(std::vector<SearchResult*>& search_results,
-              SearchResult& final_result,
+GetResultData(std::vector<std::vector<int64_t>>& search_records,
+              std::vector<SearchResult*>& search_results,
               int64_t query_offset,
+              bool* is_selected,
              int64_t topk) {
     auto num_segments = search_results.size();
-    std::map<int, int> iter_loc_peer_result;
+    AssertInfo(num_segments > 0, "num segment must greater than 0");
     std::vector<SearchResultPair> result_pairs;
     for (int j = 0; j < num_segments; ++j) {
-        auto id = search_results[j]->result_ids_[query_offset];
         auto distance = search_results[j]->result_distances_[query_offset];
-        result_pairs.push_back(SearchResultPair(id, distance, j));
-        iter_loc_peer_result[j] = query_offset;
+        auto search_result = search_results[j];
+        AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
+        result_pairs.push_back(SearchResultPair(distance, search_result, query_offset, j));
     }
-    std::sort(result_pairs.begin(), result_pairs.end());
-    final_result.result_ids_.push_back(result_pairs[0].id_);
-    final_result.result_distances_.push_back(result_pairs[0].distance_);
-
-    for (int i = 1; i < topk; ++i) {
-        auto segment_id = result_pairs[0].segment_id_;
-        auto query_offset = ++(iter_loc_peer_result[segment_id]);
-        auto id = search_results[segment_id]->result_ids_[query_offset];
-        auto distance = search_results[segment_id]->result_distances_[query_offset];
-        result_pairs[0] = SearchResultPair(id, distance, segment_id);
-        std::sort(result_pairs.begin(), result_pairs.end());
-        final_result.result_ids_.push_back(result_pairs[0].id_);
-        final_result.result_distances_.push_back(result_pairs[0].distance_);
+    int64_t loc_offset = query_offset;
+    AssertInfo(topk > 0, "topK must greater than 0");
+    for (int i = 0; i < topk; ++i) {
+        result_pairs[0].reset_distance();
+        std::sort(result_pairs.begin(), result_pairs.end());
+        auto& result_pair = result_pairs[0];
+        auto index = result_pair.index_;
+        is_selected[index] = true;
+        result_pair.search_result_->result_offsets_.push_back(loc_offset++);
+        search_records[index].push_back(result_pair.offset_++);
     }
 }
 
-CQueryResult
-ReduceQueryResults(CQueryResult* query_results, int64_t num_segments) {
+void
+ResetSearchResult(std::vector<std::vector<int64_t>>& search_records,
+                  std::vector<SearchResult*>& search_results,
+                  bool* is_selected) {
+    auto num_segments = search_results.size();
+    AssertInfo(num_segments > 0, "num segment must greater than 0");
+    for (int i = 0; i < num_segments; i++) {
+        if (is_selected[i] == false) {
+            continue;
+        }
+        auto search_result = search_results[i];
+        AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
+
+        std::vector<float> result_distances;
+        std::vector<int64_t> internal_seg_offsets;
+        std::vector<int64_t> result_ids;
+
+        for (int j = 0; j < search_records[i].size(); j++) {
+            auto& offset = search_records[i][j];
+            auto distance = search_result->result_distances_[offset];
+            auto internal_seg_offset = search_result->internal_seg_offsets_[offset];
+            auto id = search_result->result_ids_[offset];
+            result_distances.push_back(distance);
+            internal_seg_offsets.push_back(internal_seg_offset);
+            result_ids.push_back(id);
+        }
+
+        search_result->result_distances_ = result_distances;
+        search_result->internal_seg_offsets_ = internal_seg_offsets;
+        search_result->result_ids_ = result_ids;
+    }
+}
+
+CStatus
+ReduceQueryResults(CQueryResult* c_search_results, int64_t num_segments, bool* is_selected) {
     std::vector<SearchResult*> search_results;
     for (int i = 0; i < num_segments; ++i) {
-        search_results.push_back((SearchResult*)query_results[i]);
+        search_results.push_back((SearchResult*)c_search_results[i]);
     }
-    auto topk = search_results[0]->topK_;
-    auto num_queries = search_results[0]->num_queries_;
-    auto final_result = std::make_unique<SearchResult>();
-
-    int64_t query_offset = 0;
-    for (int j = 0; j < num_queries; ++j) {
-        GetResultData(search_results, *final_result, query_offset, topk);
-        query_offset += topk;
-    }
-
-    return (CQueryResult)final_result.release();
+    try {
+        auto topk = search_results[0]->topK_;
+        auto num_queries = search_results[0]->num_queries_;
+        std::vector<std::vector<int64_t>> search_records(num_segments);
+
+        int64_t query_offset = 0;
+        for (int j = 0; j < num_queries; ++j) {
+            GetResultData(search_records, search_results, query_offset, is_selected, topk);
+            query_offset += topk;
+        }
+
+        ResetSearchResult(search_records, search_results, is_selected);
+        auto status = CStatus();
+        status.error_code = Success;
+        status.error_msg = "";
+        return status;
+    } catch (std::exception& e) {
+        auto status = CStatus();
+        status.error_code = UnexpectedException;
+        status.error_msg = strdup(e.what());
+        return status;
+    }
 }
 
-CMarshaledHits
-ReorganizeQueryResults(CQueryResult c_query_result,
-                       CPlan c_plan,
-                       CPlaceholderGroup* c_placeholder_groups,
-                       int64_t num_groups) {
-    auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
-    auto search_result = (milvus::engine::QueryResult*)c_query_result;
-    auto& result_ids = search_result->result_ids_;
-    auto& result_distances = search_result->result_distances_;
-    auto topk = GetTopK(c_plan);
-    int64_t queries_offset = 0;
-    for (int i = 0; i < num_groups; i++) {
-        auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
-        MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
-        for (int j = 0; j < num_queries; j++) {
-            auto index = topk * queries_offset++;
-            milvus::proto::service::Hits hits;
-            for (int k = index; k < index + topk; k++) {
-                hits.add_ids(result_ids[k]);
-                hits.add_scores(result_distances[k]);
-            }
+CStatus
+ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
+                       CPlaceholderGroup* c_placeholder_groups,
+                       int64_t num_groups,
+                       CQueryResult* c_search_results,
+                       bool* is_selected,
+                       int64_t num_segments,
+                       CPlan c_plan) {
+    try {
+        auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
+        auto topk = GetTopK(c_plan);
+        std::vector<int64_t> num_queries_peer_group;
+        int64_t total_num_queries = 0;
+        for (int i = 0; i < num_groups; i++) {
+            auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
+            num_queries_peer_group.push_back(num_queries);
+            total_num_queries += num_queries;
+        }
+
+        std::vector<float> result_distances(total_num_queries * topk);
+        std::vector<int64_t> result_ids(total_num_queries * topk);
+        std::vector<std::vector<char>> row_datas(total_num_queries * topk);
+
+        int64_t count = 0;
+        for (int i = 0; i < num_segments; i++) {
+            if (is_selected[i] == false) {
+                continue;
+            }
+            auto search_result = (SearchResult*)c_search_results[i];
+            AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
+            auto size = search_result->result_offsets_.size();
+            for (int j = 0; j < size; j++) {
+                auto loc = search_result->result_offsets_[j];
+                result_distances[loc] = search_result->result_distances_[j];
+                row_datas[loc] = search_result->row_data_[j];
+                result_ids[loc] = search_result->result_ids_[j];
+            }
+            count += size;
+        }
+        AssertInfo(count == total_num_queries * topk, "the reduces result's size less than total_num_queries*topk");
+
+        int64_t fill_hit_offset = 0;
+        for (int i = 0; i < num_groups; i++) {
+            MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
+            for (int j = 0; j < num_queries_peer_group[i]; j++) {
+                milvus::proto::service::Hits hits;
+                for (int k = 0; k < topk; k++, fill_hit_offset++) {
+                    hits.add_ids(result_ids[fill_hit_offset]);
+                    hits.add_scores(result_distances[fill_hit_offset]);
+                    auto& row_data = row_datas[fill_hit_offset];
+                    hits.add_row_data(row_data.data(), row_data.size());
+                }
                 auto blob = hits.SerializeAsString();
                 hits_peer_group.hits_.push_back(blob);
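The rewritten GetResultData is the heart of this change: instead of copying merged ids/distances into a separate final result, it repeatedly re-reads the previous winner's distance (reset_distance), re-sorts the per-segment heads, marks the winning segment in is_selected, and records both the global slot (result_offsets_) and the segment-local offset (search_records) of each winner. Below is a minimal Go sketch of that selection loop for a single query block, using plain slices instead of the real SearchResult type; all names here are illustrative, and like the C++ code it assumes each segment holds at least topk results:

```go
package main

import (
	"fmt"
	"sort"
)

// head tracks one segment's current best candidate during the merge.
type head struct {
	dist float32
	seg  int // index of the owning segment
	off  int // current offset inside that segment's result list
}

func mergeTopK(segDists [][]float32, topk int) (selected []bool, picks [][]int) {
	selected = make([]bool, len(segDists))
	picks = make([][]int, len(segDists)) // segment-local offsets of the winners
	heads := make([]head, 0, len(segDists))
	for s, d := range segDists {
		heads = append(heads, head{dist: d[0], seg: s})
	}
	for i := 0; i < topk; i++ {
		// reset_distance + sort, as in the C++ loop: refresh the previous
		// winner's distance at its advanced offset, then re-sort the heads.
		heads[0].dist = segDists[heads[0].seg][heads[0].off]
		sort.Slice(heads, func(a, b int) bool { return heads[a].dist < heads[b].dist })
		w := &heads[0]
		selected[w.seg] = true                     // this segment contributes to the top-k
		picks[w.seg] = append(picks[w.seg], w.off) // record the chosen offset, then advance
		w.off++
	}
	return selected, picks
}

func main() {
	selected, picks := mergeTopK([][]float32{{0.1, 0.4, 0.9}, {0.2, 0.3, 0.8}}, 3)
	fmt.Println(selected, picks) // [true true] [[0] [0 1]]
}
```

A min-heap would avoid the full re-sort on every iteration, but with a handful of segments sorting that many heads per pick is cheap, and it keeps the C++ loop above very short.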
@@ -140,7 +218,19 @@ ReorganizeQueryResults(CQueryResult c_query_result,
             }
         }
 
-    return (CMarshaledHits)marshaledHits.release();
+        auto status = CStatus();
+        status.error_code = Success;
+        status.error_msg = "";
+        auto marshled_res = (CMarshaledHits)marshaledHits.release();
+        *c_marshaled_hits = marshled_res;
+        return status;
+    } catch (std::exception& e) {
+        auto status = CStatus();
+        status.error_code = UnexpectedException;
+        status.error_msg = strdup(e.what());
+        *c_marshaled_hits = nullptr;
+        return status;
+    }
 }
 
 int64_t
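After reduce, ResetSearchResult shrinks each selected segment's arrays down to only the offsets recorded in search_records, and ReorganizeQueryResults then scatters those survivors into one flat, globally ordered buffer using the result_offsets_ recorded during the merge. The placement step is just a scatter by precomputed slot; a standalone Go sketch with illustrative types:

```go
package main

import "fmt"

// scatter places each segment's surviving results into one flat buffer using
// the global slots recorded during reduce (result_offsets_ in the C++ code).
// segLocs[i][j] is the global slot of segment i's j-th surviving result;
// segIDs holds the corresponding ids (distances and row data work the same way).
func scatter(segLocs [][]int64, segIDs [][]int64, total int) []int64 {
	flat := make([]int64, total)
	for s := range segLocs {
		for j, loc := range segLocs[s] {
			flat[loc] = segIDs[s][j]
		}
	}
	return flat
}

func main() {
	// two segments whose survivors interleave in the global ordering
	ids := scatter(
		[][]int64{{0, 2}, {1, 3}},
		[][]int64{{100, 102}, {201, 203}},
		4,
	)
	fmt.Println(ids) // [100 201 102 203]
}
```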
@@ -25,14 +25,17 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits);
 int
 MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, float* new_distances, int64_t* new_uids);
 
-CQueryResult
-ReduceQueryResults(CQueryResult* query_results, int64_t num_segments);
+CStatus
+ReduceQueryResults(CQueryResult* query_results, int64_t num_segments, bool* is_selected);
 
-CMarshaledHits
-ReorganizeQueryResults(CQueryResult query_result,
-                       CPlan c_plan,
-                       CPlaceholderGroup* c_placeholder_groups,
-                       int64_t num_groups);
+CStatus
+ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
+                       CPlaceholderGroup* c_placeholder_groups,
+                       int64_t num_groups,
+                       CQueryResult* c_search_results,
+                       bool* is_selected,
+                       int64_t num_segments,
+                       CPlan c_plan);
 
 int64_t
 GetHitsBlobSize(CMarshaledHits c_marshaled_hits);
@@ -155,6 +155,24 @@ Search(CSegmentBase c_segment,
     return status;
 }
 
+CStatus
+FillTargetEntry(CSegmentBase c_segment, CPlan c_plan, CQueryResult c_result) {
+    auto segment = (milvus::segcore::SegmentBase*)c_segment;
+    auto plan = (milvus::query::Plan*)c_plan;
+    auto result = (milvus::engine::QueryResult*)c_result;
+
+    auto status = CStatus();
+    try {
+        auto res = segment->FillTargetEntry(plan, *result);
+        status.error_code = Success;
+        status.error_msg = "";
+    } catch (std::runtime_error& e) {
+        status.error_code = UnexpectedException;
+        status.error_msg = strdup(e.what());
+    }
+    return status;
+}
+
 //////////////////////////////////////////////////////////////////
 
 int
@@ -61,6 +61,9 @@ Search(CSegmentBase c_segment,
            int num_groups,
            CQueryResult* result);
 
+CStatus
+FillTargetEntry(CSegmentBase c_segment, CPlan c_plan, CQueryResult result);
+
 //////////////////////////////////////////////////////////////////
 
 int
@@ -641,8 +641,14 @@ TEST(CApiTest, Reduce) {
     results.push_back(res1);
     results.push_back(res2);
 
-    auto reduced_search_result = ReduceQueryResults(results.data(), 2);
-    auto reorganize_search_result = ReorganizeQueryResults(reduced_search_result, plan, placeholderGroups.data(), 1);
+    bool is_selected[1] = {false};
+    status = ReduceQueryResults(results.data(), 1, is_selected);
+    assert(status.error_code == Success);
+    FillTargetEntry(segment, plan, res1);
+    void* reorganize_search_result = nullptr;
+    status = ReorganizeQueryResults(&reorganize_search_result, placeholderGroups.data(), 1, results.data(), is_selected,
+                                    1, plan);
+    assert(status.error_code == Success);
     auto hits_blob_size = GetHitsBlobSize(reorganize_search_result);
     assert(hits_blob_size > 0);
     std::vector<char> hits_blob;

@@ -660,7 +666,6 @@ TEST(CApiTest, Reduce) {
     DeletePlaceholderGroup(placeholderGroup);
     DeleteQueryResult(res1);
     DeleteQueryResult(res2);
-    DeleteQueryResult(reduced_search_result);
     DeleteMarshaledHits(reorganize_search_result);
     DeleteCollection(collection);
     DeleteSegment(segment);
@@ -70,11 +70,22 @@ func (ms *PulsarMsgStream) SetPulsarClient(address string) {
 
 func (ms *PulsarMsgStream) CreatePulsarProducers(channels []string) {
 	for i := 0; i < len(channels); i++ {
+		fn := func() error {
 			pp, err := (*ms.client).CreateProducer(pulsar.ProducerOptions{Topic: channels[i]})
 			if err != nil {
 				log.Printf("Failed to create querynode producer %s, error = %v", channels[i], err)
+				return err
 			}
+			if pp == nil {
+				return errors.New("pulsar is not ready, producer is nil")
+			}
 			ms.producers = append(ms.producers, &pp)
+			return nil
+		}
+		err := Retry(10, time.Millisecond*200, fn)
+		if err != nil {
+			errMsg := "Failed to create producer " + channels[i] + ", error = " + err.Error()
+			panic(errMsg)
+		}
 	}
 }

@@ -104,7 +115,8 @@ func (ms *PulsarMsgStream) CreatePulsarConsumers(channels []string,
 		}
 		err := Retry(10, time.Millisecond*200, fn)
 		if err != nil {
-			panic("create pulsar consumer timeout!")
+			errMsg := "Failed to create consumer " + channels[i] + ", error = " + err.Error()
+			panic(errMsg)
 		}
 	}
 }

@@ -239,10 +251,6 @@ func (ms *PulsarMsgStream) bufMsgPackToChannel() {
 
 	cases := make([]reflect.SelectCase, len(ms.consumers))
 	for i := 0; i < len(ms.consumers); i++ {
-		pc := *ms.consumers[i]
-		if pc == nil {
-			panic("pc is nil")
-		}
 		ch := (*ms.consumers[i]).Chan()
 		cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)}
 	}
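Both producer and consumer creation now go through Retry(10, time.Millisecond*200, fn). The Retry helper itself is outside this diff; judging from the call sites it presumably looks something like the sketch below (the signature is inferred from the calls, not confirmed by this commit):

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// Retry runs fn up to attempts times, sleeping between tries, and returns
// the last error if every attempt fails. This is a guess at the helper's
// shape based on the call sites above.
func Retry(attempts int, sleep time.Duration, fn func() error) error {
	var err error
	for i := 0; i < attempts; i++ {
		if err = fn(); err == nil {
			return nil
		}
		time.Sleep(sleep)
	}
	return err
}

func main() {
	n := 0
	err := Retry(3, 10*time.Millisecond, func() error {
		n++
		if n < 3 {
			return errors.New("not ready")
		}
		return nil
	})
	fmt.Println(n, err) // 3 <nil>
}
```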
@@ -10,6 +10,8 @@ package querynode
 */
 import "C"
 import (
+	"errors"
+	"strconv"
 	"unsafe"
 )

@@ -21,26 +23,66 @@ type MarshaledHits struct {
 	cMarshaledHits C.CMarshaledHits
 }
 
-func reduceSearchResults(searchResults []*SearchResult, numSegments int64) *SearchResult {
+func reduceSearchResults(searchResults []*SearchResult, numSegments int64, inReduced []bool) error {
 	cSearchResults := make([]C.CQueryResult, 0)
 	for _, res := range searchResults {
 		cSearchResults = append(cSearchResults, res.cQueryResult)
 	}
 	cSearchResultPtr := (*C.CQueryResult)(&cSearchResults[0])
 	cNumSegments := C.long(numSegments)
-	res := C.ReduceQueryResults(cSearchResultPtr, cNumSegments)
-	return &SearchResult{cQueryResult: res}
+	cInReduced := (*C.bool)(&inReduced[0])
+
+	status := C.ReduceQueryResults(cSearchResultPtr, cNumSegments, cInReduced)
+
+	errorCode := status.error_code
+
+	if errorCode != 0 {
+		errorMsg := C.GoString(status.error_msg)
+		defer C.free(unsafe.Pointer(status.error_msg))
+		return errors.New("reduceSearchResults failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
+	}
+	return nil
 }
 
-func (sr *SearchResult) reorganizeQueryResults(plan *Plan, placeholderGroups []*PlaceholderGroup) *MarshaledHits {
+func fillTargetEntry(plan *Plan, searchResults []*SearchResult, matchedSegments []*Segment, inReduced []bool) error {
+	for i, value := range inReduced {
+		if value {
+			err := matchedSegments[i].fillTargetEntry(plan, searchResults[i])
+			if err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+func reorganizeQueryResults(plan *Plan, placeholderGroups []*PlaceholderGroup, searchResults []*SearchResult, numSegments int64, inReduced []bool) (*MarshaledHits, error) {
 	cPlaceholderGroups := make([]C.CPlaceholderGroup, 0)
 	for _, pg := range placeholderGroups {
 		cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup)
 	}
-	cNumGroup := (C.long)(len(placeholderGroups))
-	var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
-	res := C.ReorganizeQueryResults(sr.cQueryResult, plan.cPlan, cPlaceHolder, cNumGroup)
-	return &MarshaledHits{cMarshaledHits: res}
+	var cPlaceHolderGroupPtr = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
+	var cNumGroup = (C.long)(len(placeholderGroups))
+
+	cSearchResults := make([]C.CQueryResult, 0)
+	for _, res := range searchResults {
+		cSearchResults = append(cSearchResults, res.cQueryResult)
+	}
+	cSearchResultPtr := (*C.CQueryResult)(&cSearchResults[0])
+
+	var cNumSegments = C.long(numSegments)
+	var cInReduced = (*C.bool)(&inReduced[0])
+	var cMarshaledHits C.CMarshaledHits
+
+	status := C.ReorganizeQueryResults(&cMarshaledHits, cPlaceHolderGroupPtr, cNumGroup, cSearchResultPtr, cInReduced, cNumSegments, plan.cPlan)
+	errorCode := status.error_code
+
+	if errorCode != 0 {
+		errorMsg := C.GoString(status.error_msg)
+		defer C.free(unsafe.Pointer(status.error_msg))
+		return nil, errors.New("reorganizeQueryResults failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
+	}
+	return &MarshaledHits{cMarshaledHits: cMarshaledHits}, nil
 }
 
 func (mh *MarshaledHits) getHitsBlobSize() int64 {
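reduceSearchResults, reorganizeQueryResults, and the Segment.fillTargetEntry wrapper further down all repeat the same CStatus-to-error dance: read error_code, copy error_msg with C.GoString, free the strdup'd C string, and wrap it all in a Go error. If this pattern keeps spreading, it could be factored into a single helper; the sketch below is a hypothetical consolidation, written without cgo so it runs standalone (in the real file the struct would be C.CStatus, the message a *C.char copied via C.GoString and released with C.free):

```go
package main

import (
	"errors"
	"fmt"
	"strconv"
)

// cStatus stands in for C.CStatus in this standalone sketch.
type cStatus struct {
	errorCode int
	errorMsg  string
}

// statusToError is a hypothetical helper consolidating the repeated
// error-translation pattern in the cgo wrappers above.
func statusToError(op string, status cStatus) error {
	if status.errorCode == 0 {
		return nil
	}
	return errors.New(op + " failed, C runtime error detected, error code = " +
		strconv.Itoa(status.errorCode) + ", error msg = " + status.errorMsg)
}

func main() {
	fmt.Println(statusToError("reduceSearchResults", cStatus{errorCode: 1, errorMsg: "boom"}))
	fmt.Println(statusToError("reduceSearchResults", cStatus{errorCode: 0}))
}
```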
@@ -107,15 +107,21 @@ func TestReduce_AllFunc(t *testing.T) {
 	placeholderGroups = append(placeholderGroups, holder)
 
 	searchResults := make([]*SearchResult, 0)
+	matchedSegment := make([]*Segment, 0)
 	searchResult, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{0})
 	assert.Nil(t, err)
 	searchResults = append(searchResults, searchResult)
+	matchedSegment = append(matchedSegment, segment)
 
-	reducedSearchResults := reduceSearchResults(searchResults, 1)
-	assert.NotNil(t, reducedSearchResults)
+	testReduce := make([]bool, len(searchResults))
+	err = reduceSearchResults(searchResults, 1, testReduce)
+	assert.Nil(t, err)
+	err = fillTargetEntry(plan, searchResults, matchedSegment, testReduce)
+	assert.Nil(t, err)
 
-	marshaledHits := reducedSearchResults.reorganizeQueryResults(plan, placeholderGroups)
+	marshaledHits, err := reorganizeQueryResults(plan, placeholderGroups, searchResults, 1, testReduce)
 	assert.NotNil(t, marshaledHits)
+	assert.Nil(t, err)
 
 	hitsBlob, err := marshaledHits.getHitsBlob()
 	assert.Nil(t, err)

@@ -137,7 +143,6 @@ func TestReduce_AllFunc(t *testing.T) {
 	plan.delete()
 	holder.delete()
 	deleteSearchResults(searchResults)
-	deleteSearchResults([]*SearchResult{reducedSearchResults})
 	deleteMarshaledHits(marshaledHits)
 	deleteSegment(segment)
 	deleteCollection(collection)
@@ -238,6 +238,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
 	placeholderGroups = append(placeholderGroups, placeholderGroup)
 
 	searchResults := make([]*SearchResult, 0)
+	matchedSegments := make([]*Segment, 0)
 
 	for _, partitionTag := range partitionTags {
 		hasPartition := (*ss.replica).hasPartition(collectionID, partitionTag)

@@ -257,6 +258,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
 				return err
 			}
 			searchResults = append(searchResults, searchResult)
+			matchedSegments = append(matchedSegments, segment)
 		}
 	}
 

@@ -282,8 +284,20 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
 		return nil
 	}
 
-	reducedSearchResult := reduceSearchResults(searchResults, int64(len(searchResults)))
-	marshaledHits := reducedSearchResult.reorganizeQueryResults(plan, placeholderGroups)
+	inReduced := make([]bool, len(searchResults))
+	numSegment := int64(len(searchResults))
+	err = reduceSearchResults(searchResults, numSegment, inReduced)
+	if err != nil {
+		return err
+	}
+	err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
+	if err != nil {
+		return err
+	}
+	marshaledHits, err := reorganizeQueryResults(plan, placeholderGroups, searchResults, numSegment, inReduced)
+	if err != nil {
+		return err
+	}
 	hitsBlob, err := marshaledHits.getHitsBlob()
 	if err != nil {
 		return err

@@ -291,12 +305,12 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
 
 	var offset int64 = 0
 	for index := range placeholderGroups {
-		hitBolbSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index))
+		hitBlobSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index))
 		if err != nil {
 			return err
 		}
 		hits := make([][]byte, 0)
-		for _, len := range hitBolbSizePeerQuery {
+		for _, len := range hitBlobSizePeerQuery {
 			hits = append(hits, hitsBlob[offset:offset+len])
 			//test code to checkout marshaled hits
 			//marshaledHit := hitsBlob[offset:offset+len]

@@ -329,7 +343,6 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
 	}
 
 	deleteSearchResults(searchResults)
-	deleteSearchResults([]*SearchResult{reducedSearchResult})
 	deleteMarshaledHits(marshaledHits)
 	plan.delete()
 	placeholderGroup.delete()
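One small wart in the loop above: the loop variable len shadows Go's builtin len. The bookkeeping itself is just carving the concatenated hits blob into per-query byte slices by size; an equivalent standalone sketch with a non-shadowing name:

```go
package main

import "fmt"

// splitBlob carves one concatenated blob into per-query slices, the same
// offset arithmetic the search loop above performs with hitsBlob and the
// per-group hit sizes.
func splitBlob(blob []byte, sizes []int64) [][]byte {
	hits := make([][]byte, 0, len(sizes))
	var offset int64 = 0
	for _, size := range sizes {
		hits = append(hits, blob[offset:offset+size])
		offset += size
	}
	return hits
}

func main() {
	blob := []byte{1, 2, 3, 4, 5, 6}
	fmt.Println(splitBlob(blob, []int64{2, 4})) // [[1 2] [3 4 5 6]]
}
```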
@@ -253,3 +253,242 @@ func TestSearch_Search(t *testing.T) {
 	cancel()
 	node.Close()
 }
+
+func TestSearch_SearchMultiSegments(t *testing.T) {
+	Params.Init()
+	ctx, cancel := context.WithCancel(context.Background())
+
+	// init query node
+	pulsarURL, _ := Params.pulsarAddress()
+	node := NewQueryNode(ctx, 0)
+
+	// init meta
+	collectionName := "collection0"
+	fieldVec := schemapb.FieldSchema{
+		Name:         "vec",
+		IsPrimaryKey: false,
+		DataType:     schemapb.DataType_VECTOR_FLOAT,
+		TypeParams: []*commonpb.KeyValuePair{
+			{
+				Key:   "dim",
+				Value: "16",
+			},
+		},
+	}
+
+	fieldInt := schemapb.FieldSchema{
+		Name:         "age",
+		IsPrimaryKey: false,
+		DataType:     schemapb.DataType_INT32,
+		TypeParams: []*commonpb.KeyValuePair{
+			{
+				Key:   "dim",
+				Value: "1",
+			},
+		},
+	}
+
+	schema := schemapb.CollectionSchema{
+		Name:   collectionName,
+		AutoID: true,
+		Fields: []*schemapb.FieldSchema{
+			&fieldVec, &fieldInt,
+		},
+	}
+
+	collectionMeta := etcdpb.CollectionMeta{
+		ID:            UniqueID(0),
+		Schema:        &schema,
+		CreateTime:    Timestamp(0),
+		SegmentIDs:    []UniqueID{0},
+		PartitionTags: []string{"default"},
+	}
+
+	collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
+	assert.NotEqual(t, "", collectionMetaBlob)
+
+	var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
+	assert.NoError(t, err)
+
+	collection, err := (*node.replica).getCollectionByName(collectionName)
+	assert.NoError(t, err)
+	assert.Equal(t, collection.meta.Schema.Name, "collection0")
+	assert.Equal(t, collection.meta.ID, UniqueID(0))
+	assert.Equal(t, (*node.replica).getCollectionNum(), 1)
+
+	err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
+	assert.NoError(t, err)
+
+	segmentID := UniqueID(0)
+	err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
+	assert.NoError(t, err)
+
+	// test data generate
+	const msgLength = 1024
+	const receiveBufSize = 1024
+	const DIM = 16
+	insertProducerChannels := Params.insertChannelNames()
+	searchProducerChannels := Params.searchChannelNames()
+	var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
+
+	// start search service
+	dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
+	var searchRawData1 []byte
+	var searchRawData2 []byte
+	for i, ele := range vec {
+		buf := make([]byte, 4)
+		binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*2)))
+		searchRawData1 = append(searchRawData1, buf...)
+	}
+	for i, ele := range vec {
+		buf := make([]byte, 4)
+		binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*4)))
+		searchRawData2 = append(searchRawData2, buf...)
+	}
+	placeholderValue := servicepb.PlaceholderValue{
+		Tag:    "$0",
+		Type:   servicepb.PlaceholderType_VECTOR_FLOAT,
+		Values: [][]byte{searchRawData1, searchRawData2},
+	}
+
+	placeholderGroup := servicepb.PlaceholderGroup{
+		Placeholders: []*servicepb.PlaceholderValue{&placeholderValue},
+	}
+
+	placeGroupByte, err := proto.Marshal(&placeholderGroup)
+	if err != nil {
+		log.Print("marshal placeholderGroup failed")
+	}
+
+	query := servicepb.Query{
+		CollectionName:   "collection0",
+		PartitionTags:    []string{"default"},
+		Dsl:              dslString,
+		PlaceholderGroup: placeGroupByte,
+	}
+
+	queryByte, err := proto.Marshal(&query)
+	if err != nil {
+		log.Print("marshal query failed")
+	}
+
+	blob := commonpb.Blob{
+		Value: queryByte,
+	}
+
+	searchMsg := &msgstream.SearchMsg{
+		BaseMsg: msgstream.BaseMsg{
+			HashValues: []uint32{0},
+		},
+		SearchRequest: internalpb.SearchRequest{
+			MsgType:         internalpb.MsgType_kSearch,
+			ReqID:           int64(1),
+			ProxyID:         int64(1),
+			Timestamp:       uint64(10 + 1000),
+			ResultChannelID: int64(0),
+			Query:           &blob,
+		},
+	}
+
+	msgPackSearch := msgstream.MsgPack{}
+	msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg)
+
+	searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
+	searchStream.SetPulsarClient(pulsarURL)
+	searchStream.CreatePulsarProducers(searchProducerChannels)
+	searchStream.Start()
+	err = searchStream.Produce(&msgPackSearch)
+	assert.NoError(t, err)
+
+	node.searchService = newSearchService(node.ctx, node.replica)
+	go node.searchService.start()
+
+	// start insert
+	timeRange := TimeRange{
+		timestampMin: 0,
+		timestampMax: math.MaxUint64,
+	}
+
+	insertMessages := make([]msgstream.TsMsg, 0)
+	for i := 0; i < msgLength; i++ {
+		segmentID := 0
+		if i >= msgLength/2 {
+			segmentID = 1
+		}
+		var rawData []byte
+		for _, ele := range vec {
+			buf := make([]byte, 4)
+			binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*2)))
+			rawData = append(rawData, buf...)
+		}
+		bs := make([]byte, 4)
+		binary.LittleEndian.PutUint32(bs, 1)
+		rawData = append(rawData, bs...)
+
+		var msg msgstream.TsMsg = &msgstream.InsertMsg{
+			BaseMsg: msgstream.BaseMsg{
+				HashValues: []uint32{
+					uint32(i),
+				},
+			},
+			InsertRequest: internalpb.InsertRequest{
+				MsgType:        internalpb.MsgType_kInsert,
+				ReqID:          int64(i),
+				CollectionName: "collection0",
+				PartitionTag:   "default",
+				SegmentID:      int64(segmentID),
+				ChannelID:      int64(0),
+				ProxyID:        int64(0),
+				Timestamps:     []uint64{uint64(i + 1000)},
+				RowIDs:         []int64{int64(i)},
+				RowData: []*commonpb.Blob{
+					{Value: rawData},
+				},
+			},
+		}
+		insertMessages = append(insertMessages, msg)
+	}
+
+	msgPack := msgstream.MsgPack{
+		BeginTs: timeRange.timestampMin,
+		EndTs:   timeRange.timestampMax,
+		Msgs:    insertMessages,
+	}
+
+	// generate timeTick
+	timeTickMsgPack := msgstream.MsgPack{}
+	baseMsg := msgstream.BaseMsg{
+		BeginTimestamp: 0,
+		EndTimestamp:   0,
+		HashValues:     []uint32{0},
+	}
+	timeTickResult := internalpb.TimeTickMsg{
+		MsgType:   internalpb.MsgType_kTimeTick,
+		PeerID:    UniqueID(0),
+		Timestamp: math.MaxUint64,
+	}
+	timeTickMsg := &msgstream.TimeTickMsg{
+		BaseMsg:     baseMsg,
+		TimeTickMsg: timeTickResult,
+	}
+	timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg)
+
+	// pulsar produce
+	insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
+	insertStream.SetPulsarClient(pulsarURL)
+	insertStream.CreatePulsarProducers(insertProducerChannels)
+	insertStream.Start()
+	err = insertStream.Produce(&msgPack)
+	assert.NoError(t, err)
+	err = insertStream.Broadcast(&timeTickMsgPack)
+	assert.NoError(t, err)
+
+	// dataSync
+	node.dataSyncService = newDataSyncService(node.ctx, node.replica)
+	go node.dataSyncService.start()
+
+	time.Sleep(1 * time.Second)
+
+	cancel()
+	node.Close()
+}
@@ -208,7 +208,7 @@ func (s *Segment) segmentSearch(plan *Plan,
 	var cTimestamp = (*C.ulong)(&timestamp[0])
 	var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
 	var cNumGroups = C.int(len(placeHolderGroups))
-	cQueryResult := (*C.CQueryResult)(&searchResult.cQueryResult)
+	var cQueryResult = (*C.CQueryResult)(&searchResult.cQueryResult)
 
 	var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, cQueryResult)
 	errorCode := status.error_code

@@ -221,3 +221,18 @@ func (s *Segment) segmentSearch(plan *Plan,
 
 	return &searchResult, nil
 }
+
+func (s *Segment) fillTargetEntry(plan *Plan,
+	result *SearchResult) error {
+
+	var status = C.FillTargetEntry(s.segmentPtr, plan.cPlan, result.cQueryResult)
+	errorCode := status.error_code
+
+	if errorCode != 0 {
+		errorMsg := C.GoString(status.error_msg)
+		defer C.free(unsafe.Pointer(status.error_msg))
+		return errors.New("FillTargetEntry failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
+	}
+
+	return nil
+}