Support returning the primary key when doing search

Signed-off-by: xige-16 <xi.ge@zilliz.com>
pull/4973/head^2
xige-16 2020-12-03 17:18:06 +08:00 committed by yefu.chen
parent 0530fdf62f
commit 36cf8a8ea7
11 changed files with 536 additions and 95 deletions

View File

@ -50,89 +50,167 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) {
}
struct SearchResultPair {
uint64_t id_;
float distance_;
int64_t segment_id_;
SearchResult* search_result_;
int64_t offset_;
int64_t index_;
SearchResultPair(uint64_t id, float distance, int64_t segment_id)
: id_(id), distance_(distance), segment_id_(segment_id) {
SearchResultPair(float distance, SearchResult* search_result, int64_t offset, int64_t index)
: distance_(distance), search_result_(search_result), offset_(offset), index_(index) {
}
bool
operator<(const SearchResultPair& pair) const {
return (distance_ < pair.distance_);
}
void
reset_distance() {
distance_ = search_result_->result_distances_[offset_];
}
};
void
GetResultData(std::vector<SearchResult*>& search_results,
SearchResult& final_result,
GetResultData(std::vector<std::vector<int64_t>>& search_records,
std::vector<SearchResult*>& search_results,
int64_t query_offset,
bool* is_selected,
int64_t topk) {
auto num_segments = search_results.size();
std::map<int, int> iter_loc_peer_result;
AssertInfo(num_segments > 0, "num segment must greater than 0");
std::vector<SearchResultPair> result_pairs;
for (int j = 0; j < num_segments; ++j) {
auto id = search_results[j]->result_ids_[query_offset];
auto distance = search_results[j]->result_distances_[query_offset];
result_pairs.push_back(SearchResultPair(id, distance, j));
iter_loc_peer_result[j] = query_offset;
auto search_result = search_results[j];
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
result_pairs.push_back(SearchResultPair(distance, search_result, query_offset, j));
}
int64_t loc_offset = query_offset;
AssertInfo(topk > 0, "topK must greater than 0");
for (int i = 0; i < topk; ++i) {
result_pairs[0].reset_distance();
std::sort(result_pairs.begin(), result_pairs.end());
final_result.result_ids_.push_back(result_pairs[0].id_);
final_result.result_distances_.push_back(result_pairs[0].distance_);
for (int i = 1; i < topk; ++i) {
auto segment_id = result_pairs[0].segment_id_;
auto query_offset = ++(iter_loc_peer_result[segment_id]);
auto id = search_results[segment_id]->result_ids_[query_offset];
auto distance = search_results[segment_id]->result_distances_[query_offset];
result_pairs[0] = SearchResultPair(id, distance, segment_id);
std::sort(result_pairs.begin(), result_pairs.end());
final_result.result_ids_.push_back(result_pairs[0].id_);
final_result.result_distances_.push_back(result_pairs[0].distance_);
auto& result_pair = result_pairs[0];
auto index = result_pair.index_;
is_selected[index] = true;
result_pair.search_result_->result_offsets_.push_back(loc_offset++);
search_records[index].push_back(result_pair.offset_++);
}
}
CQueryResult
ReduceQueryResults(CQueryResult* query_results, int64_t num_segments) {
void
ResetSearchResult(std::vector<std::vector<int64_t>>& search_records,
                  std::vector<SearchResult*>& search_results,
                  bool* is_selected) {
    // Compact each selected segment's SearchResult so that it keeps only the
    // entries chosen during reduction (the offsets recorded in search_records).
    // Segments that contributed nothing to the merged top-k are left untouched.
    auto num_segments = search_results.size();
    AssertInfo(num_segments > 0, "num segment must greater than 0");
    for (size_t i = 0; i < num_segments; i++) {
        if (!is_selected[i]) {
            continue;
        }
        auto search_result = search_results[i];
        AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
        auto& record = search_records[i];
        std::vector<float> result_distances;
        std::vector<int64_t> internal_seg_offsets;
        std::vector<int64_t> result_ids;
        // One entry per selected offset; reserve up front to avoid reallocations.
        result_distances.reserve(record.size());
        internal_seg_offsets.reserve(record.size());
        result_ids.reserve(record.size());
        for (auto offset : record) {
            result_distances.push_back(search_result->result_distances_[offset]);
            internal_seg_offsets.push_back(search_result->internal_seg_offsets_[offset]);
            result_ids.push_back(search_result->result_ids_[offset]);
        }
        // Move the compacted vectors back in place of the full per-segment results.
        search_result->result_distances_ = std::move(result_distances);
        search_result->internal_seg_offsets_ = std::move(internal_seg_offsets);
        search_result->result_ids_ = std::move(result_ids);
    }
}
CStatus
ReduceQueryResults(CQueryResult* c_search_results, int64_t num_segments, bool* is_selected) {
std::vector<SearchResult*> search_results;
for (int i = 0; i < num_segments; ++i) {
search_results.push_back((SearchResult*)query_results[i]);
search_results.push_back((SearchResult*)c_search_results[i]);
}
try {
auto topk = search_results[0]->topK_;
auto num_queries = search_results[0]->num_queries_;
auto final_result = std::make_unique<SearchResult>();
std::vector<std::vector<int64_t>> search_records(num_segments);
int64_t query_offset = 0;
for (int j = 0; j < num_queries; ++j) {
GetResultData(search_results, *final_result, query_offset, topk);
GetResultData(search_records, search_results, query_offset, is_selected, topk);
query_offset += topk;
}
return (CQueryResult)final_result.release();
ResetSearchResult(search_records, search_results, is_selected);
auto status = CStatus();
status.error_code = Success;
status.error_msg = "";
return status;
} catch (std::exception& e) {
auto status = CStatus();
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
return status;
}
}
CMarshaledHits
ReorganizeQueryResults(CQueryResult c_query_result,
CPlan c_plan,
CStatus
ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
CPlaceholderGroup* c_placeholder_groups,
int64_t num_groups) {
int64_t num_groups,
CQueryResult* c_search_results,
bool* is_selected,
int64_t num_segments,
CPlan c_plan) {
try {
auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
auto search_result = (milvus::engine::QueryResult*)c_query_result;
auto& result_ids = search_result->result_ids_;
auto& result_distances = search_result->result_distances_;
auto topk = GetTopK(c_plan);
int64_t queries_offset = 0;
std::vector<int64_t> num_queries_peer_group;
int64_t total_num_queries = 0;
for (int i = 0; i < num_groups; i++) {
auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
num_queries_peer_group.push_back(num_queries);
total_num_queries += num_queries;
}
std::vector<float> result_distances(total_num_queries * topk);
std::vector<int64_t> result_ids(total_num_queries * topk);
std::vector<std::vector<char>> row_datas(total_num_queries * topk);
int64_t count = 0;
for (int i = 0; i < num_segments; i++) {
if (is_selected[i] == false) {
continue;
}
auto search_result = (SearchResult*)c_search_results[i];
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
auto size = search_result->result_offsets_.size();
for (int j = 0; j < size; j++) {
auto loc = search_result->result_offsets_[j];
result_distances[loc] = search_result->result_distances_[j];
row_datas[loc] = search_result->row_data_[j];
result_ids[loc] = search_result->result_ids_[j];
}
count += size;
}
AssertInfo(count == total_num_queries * topk, "the reduces result's size less than total_num_queries*topk");
int64_t fill_hit_offset = 0;
for (int i = 0; i < num_groups; i++) {
MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
for (int j = 0; j < num_queries; j++) {
auto index = topk * queries_offset++;
for (int j = 0; j < num_queries_peer_group[i]; j++) {
milvus::proto::service::Hits hits;
for (int k = index; k < index + topk; k++) {
hits.add_ids(result_ids[k]);
hits.add_scores(result_distances[k]);
for (int k = 0; k < topk; k++, fill_hit_offset++) {
hits.add_ids(result_ids[fill_hit_offset]);
hits.add_scores(result_distances[fill_hit_offset]);
auto& row_data = row_datas[fill_hit_offset];
hits.add_row_data(row_data.data(), row_data.size());
}
auto blob = hits.SerializeAsString();
hits_peer_group.hits_.push_back(blob);
@ -140,7 +218,19 @@ ReorganizeQueryResults(CQueryResult c_query_result,
}
}
return (CMarshaledHits)marshaledHits.release();
auto status = CStatus();
status.error_code = Success;
status.error_msg = "";
auto marshled_res = (CMarshaledHits)marshaledHits.release();
*c_marshaled_hits = marshled_res;
return status;
} catch (std::exception& e) {
auto status = CStatus();
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
*c_marshaled_hits = nullptr;
return status;
}
}
int64_t

View File

@ -25,14 +25,17 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits);
int
MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, float* new_distances, int64_t* new_uids);
CQueryResult
ReduceQueryResults(CQueryResult* query_results, int64_t num_segments);
CStatus
ReduceQueryResults(CQueryResult* query_results, int64_t num_segments, bool* is_selected);
CMarshaledHits
ReorganizeQueryResults(CQueryResult query_result,
CPlan c_plan,
CStatus
ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
CPlaceholderGroup* c_placeholder_groups,
int64_t num_groups);
int64_t num_groups,
CQueryResult* c_search_results,
bool* is_selected,
int64_t num_segments,
CPlan c_plan);
int64_t
GetHitsBlobSize(CMarshaledHits c_marshaled_hits);

View File

@ -155,6 +155,24 @@ Search(CSegmentBase c_segment,
return status;
}
CStatus
FillTargetEntry(CSegmentBase c_segment, CPlan c_plan, CQueryResult c_result) {
    // C-API wrapper: fill the target entry (row data for the requested output
    // fields) of a query result from this segment.
    auto segment = (milvus::segcore::SegmentBase*)c_segment;
    auto plan = (milvus::query::Plan*)c_plan;
    auto result = (milvus::engine::QueryResult*)c_result;
    auto status = CStatus();
    try {
        // NOTE(review): the return value of FillTargetEntry is discarded here
        // (it was previously bound to an unused local) — confirm that failures
        // are reported via exceptions rather than the return value.
        segment->FillTargetEntry(plan, *result);
        status.error_code = Success;
        status.error_msg = "";
    } catch (std::exception& e) {
        // Catch std::exception (not only std::runtime_error) so this wrapper
        // reports errors consistently with the other C-API entry points.
        status.error_code = UnexpectedException;
        status.error_msg = strdup(e.what());
    }
    return status;
}
//////////////////////////////////////////////////////////////////
int

View File

@ -61,6 +61,9 @@ Search(CSegmentBase c_segment,
int num_groups,
CQueryResult* result);
CStatus
FillTargetEntry(CSegmentBase c_segment, CPlan c_plan, CQueryResult result);
//////////////////////////////////////////////////////////////////
int

View File

@ -641,8 +641,14 @@ TEST(CApiTest, Reduce) {
results.push_back(res1);
results.push_back(res2);
auto reduced_search_result = ReduceQueryResults(results.data(), 2);
auto reorganize_search_result = ReorganizeQueryResults(reduced_search_result, plan, placeholderGroups.data(), 1);
bool is_selected[1] = {false};
status = ReduceQueryResults(results.data(), 1, is_selected);
assert(status.error_code == Success);
FillTargetEntry(segment, plan, res1);
void* reorganize_search_result = nullptr;
status = ReorganizeQueryResults(&reorganize_search_result, placeholderGroups.data(), 1, results.data(), is_selected,
1, plan);
assert(status.error_code == Success);
auto hits_blob_size = GetHitsBlobSize(reorganize_search_result);
assert(hits_blob_size > 0);
std::vector<char> hits_blob;
@ -660,7 +666,6 @@ TEST(CApiTest, Reduce) {
DeletePlaceholderGroup(placeholderGroup);
DeleteQueryResult(res1);
DeleteQueryResult(res2);
DeleteQueryResult(reduced_search_result);
DeleteMarshaledHits(reorganize_search_result);
DeleteCollection(collection);
DeleteSegment(segment);

View File

@ -70,11 +70,22 @@ func (ms *PulsarMsgStream) SetPulsarClient(address string) {
// CreatePulsarProducers creates one pulsar producer per channel, retrying each
// creation up to 10 times (200ms apart) and panicking if a channel still fails.
func (ms *PulsarMsgStream) CreatePulsarProducers(channels []string) {
	for i := range channels {
		channel := channels[i]
		createProducer := func() error {
			pp, err := (*ms.client).CreateProducer(pulsar.ProducerOptions{Topic: channel})
			if err != nil {
				log.Printf("Failed to create querynode producer %s, error = %v", channel, err)
				return err
			}
			if pp == nil {
				// A nil producer with no error means the broker is not ready yet;
				// surface it so Retry keeps trying.
				return errors.New("pulsar is not ready, producer is nil")
			}
			ms.producers = append(ms.producers, &pp)
			return nil
		}
		if err := Retry(10, time.Millisecond*200, createProducer); err != nil {
			panic("Failed to create producer " + channel + ", error = " + err.Error())
		}
	}
}
@ -104,7 +115,8 @@ func (ms *PulsarMsgStream) CreatePulsarConsumers(channels []string,
}
err := Retry(10, time.Millisecond*200, fn)
if err != nil {
panic("create pulsar consumer timeout!")
errMsg := "Failed to create consumer " + channels[i] + ", error = " + err.Error()
panic(errMsg)
}
}
}
@ -239,10 +251,6 @@ func (ms *PulsarMsgStream) bufMsgPackToChannel() {
cases := make([]reflect.SelectCase, len(ms.consumers))
for i := 0; i < len(ms.consumers); i++ {
pc := *ms.consumers[i]
if pc == nil {
panic("pc is nil")
}
ch := (*ms.consumers[i]).Chan()
cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)}
}

View File

@ -10,6 +10,8 @@ package querynode
*/
import "C"
import (
"errors"
"strconv"
"unsafe"
)
@ -21,26 +23,66 @@ type MarshaledHits struct {
cMarshaledHits C.CMarshaledHits
}
func reduceSearchResults(searchResults []*SearchResult, numSegments int64) *SearchResult {
func reduceSearchResults(searchResults []*SearchResult, numSegments int64, inReduced []bool) error {
cSearchResults := make([]C.CQueryResult, 0)
for _, res := range searchResults {
cSearchResults = append(cSearchResults, res.cQueryResult)
}
cSearchResultPtr := (*C.CQueryResult)(&cSearchResults[0])
cNumSegments := C.long(numSegments)
res := C.ReduceQueryResults(cSearchResultPtr, cNumSegments)
return &SearchResult{cQueryResult: res}
cInReduced := (*C.bool)(&inReduced[0])
status := C.ReduceQueryResults(cSearchResultPtr, cNumSegments, cInReduced)
errorCode := status.error_code
if errorCode != 0 {
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New("reduceSearchResults failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
}
return nil
}
func (sr *SearchResult) reorganizeQueryResults(plan *Plan, placeholderGroups []*PlaceholderGroup) *MarshaledHits {
// fillTargetEntry fills the target entry of the search result for every
// segment flagged in inReduced (i.e. segments that contributed to the reduced
// top-k), returning the first error encountered.
func fillTargetEntry(plan *Plan, searchResults []*SearchResult, matchedSegments []*Segment, inReduced []bool) error {
	for i, selected := range inReduced {
		if !selected {
			continue
		}
		if err := matchedSegments[i].fillTargetEntry(plan, searchResults[i]); err != nil {
			return err
		}
	}
	return nil
}
func reorganizeQueryResults(plan *Plan, placeholderGroups []*PlaceholderGroup, searchResults []*SearchResult, numSegments int64, inReduced []bool) (*MarshaledHits, error) {
cPlaceholderGroups := make([]C.CPlaceholderGroup, 0)
for _, pg := range placeholderGroups {
cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup)
}
cNumGroup := (C.long)(len(placeholderGroups))
var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
res := C.ReorganizeQueryResults(sr.cQueryResult, plan.cPlan, cPlaceHolder, cNumGroup)
return &MarshaledHits{cMarshaledHits: res}
var cPlaceHolderGroupPtr = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
var cNumGroup = (C.long)(len(placeholderGroups))
cSearchResults := make([]C.CQueryResult, 0)
for _, res := range searchResults {
cSearchResults = append(cSearchResults, res.cQueryResult)
}
cSearchResultPtr := (*C.CQueryResult)(&cSearchResults[0])
var cNumSegments = C.long(numSegments)
var cInReduced = (*C.bool)(&inReduced[0])
var cMarshaledHits C.CMarshaledHits
status := C.ReorganizeQueryResults(&cMarshaledHits, cPlaceHolderGroupPtr, cNumGroup, cSearchResultPtr, cInReduced, cNumSegments, plan.cPlan)
errorCode := status.error_code
if errorCode != 0 {
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return nil, errors.New("reorganizeQueryResults failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
}
return &MarshaledHits{cMarshaledHits: cMarshaledHits}, nil
}
func (mh *MarshaledHits) getHitsBlobSize() int64 {

View File

@ -107,15 +107,21 @@ func TestReduce_AllFunc(t *testing.T) {
placeholderGroups = append(placeholderGroups, holder)
searchResults := make([]*SearchResult, 0)
matchedSegment := make([]*Segment, 0)
searchResult, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{0})
assert.Nil(t, err)
searchResults = append(searchResults, searchResult)
matchedSegment = append(matchedSegment, segment)
reducedSearchResults := reduceSearchResults(searchResults, 1)
assert.NotNil(t, reducedSearchResults)
testReduce := make([]bool, len(searchResults))
err = reduceSearchResults(searchResults, 1, testReduce)
assert.Nil(t, err)
err = fillTargetEntry(plan, searchResults, matchedSegment, testReduce)
assert.Nil(t, err)
marshaledHits := reducedSearchResults.reorganizeQueryResults(plan, placeholderGroups)
marshaledHits, err := reorganizeQueryResults(plan, placeholderGroups, searchResults, 1, testReduce)
assert.NotNil(t, marshaledHits)
assert.Nil(t, err)
hitsBlob, err := marshaledHits.getHitsBlob()
assert.Nil(t, err)
@ -137,7 +143,6 @@ func TestReduce_AllFunc(t *testing.T) {
plan.delete()
holder.delete()
deleteSearchResults(searchResults)
deleteSearchResults([]*SearchResult{reducedSearchResults})
deleteMarshaledHits(marshaledHits)
deleteSegment(segment)
deleteCollection(collection)

View File

@ -238,6 +238,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
placeholderGroups = append(placeholderGroups, placeholderGroup)
searchResults := make([]*SearchResult, 0)
matchedSegments := make([]*Segment, 0)
for _, partitionTag := range partitionTags {
hasPartition := (*ss.replica).hasPartition(collectionID, partitionTag)
@ -257,6 +258,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
return err
}
searchResults = append(searchResults, searchResult)
matchedSegments = append(matchedSegments, segment)
}
}
@ -282,8 +284,20 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
return nil
}
reducedSearchResult := reduceSearchResults(searchResults, int64(len(searchResults)))
marshaledHits := reducedSearchResult.reorganizeQueryResults(plan, placeholderGroups)
inReduced := make([]bool, len(searchResults))
numSegment := int64(len(searchResults))
err = reduceSearchResults(searchResults, numSegment, inReduced)
if err != nil {
return err
}
err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
if err != nil {
return err
}
marshaledHits, err := reorganizeQueryResults(plan, placeholderGroups, searchResults, numSegment, inReduced)
if err != nil {
return err
}
hitsBlob, err := marshaledHits.getHitsBlob()
if err != nil {
return err
@ -291,12 +305,12 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
var offset int64 = 0
for index := range placeholderGroups {
hitBolbSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index))
hitBlobSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index))
if err != nil {
return err
}
hits := make([][]byte, 0)
for _, len := range hitBolbSizePeerQuery {
for _, len := range hitBlobSizePeerQuery {
hits = append(hits, hitsBlob[offset:offset+len])
//test code to checkout marshaled hits
//marshaledHit := hitsBlob[offset:offset+len]
@ -329,7 +343,6 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
}
deleteSearchResults(searchResults)
deleteSearchResults([]*SearchResult{reducedSearchResult})
deleteMarshaledHits(marshaledHits)
plan.delete()
placeholderGroup.delete()

View File

@ -253,3 +253,242 @@ func TestSearch_Search(t *testing.T) {
cancel()
node.Close()
}
// TestSearch_SearchMultiSegments is an end-to-end query-node test: it sets up
// collection metadata, publishes a search request and insert messages (split
// across two segment IDs) through pulsar, then starts the search and data-sync
// services and lets them run briefly. It relies on a reachable pulsar broker
// (Params.pulsarAddress) and makes no assertions on the search output itself.
func TestSearch_SearchMultiSegments(t *testing.T) {
	Params.Init()
	ctx, cancel := context.WithCancel(context.Background())

	// init query node
	pulsarURL, _ := Params.pulsarAddress()
	node := NewQueryNode(ctx, 0)

	// init meta
	collectionName := "collection0"
	// 16-dim float vector field; "dim" is carried as a type param.
	fieldVec := schemapb.FieldSchema{
		Name:         "vec",
		IsPrimaryKey: false,
		DataType:     schemapb.DataType_VECTOR_FLOAT,
		TypeParams: []*commonpb.KeyValuePair{
			{
				Key:   "dim",
				Value: "16",
			},
		},
	}
	// Scalar int32 field accompanying each row.
	fieldInt := schemapb.FieldSchema{
		Name:         "age",
		IsPrimaryKey: false,
		DataType:     schemapb.DataType_INT32,
		TypeParams: []*commonpb.KeyValuePair{
			{
				Key:   "dim",
				Value: "1",
			},
		},
	}
	schema := schemapb.CollectionSchema{
		Name:   collectionName,
		AutoID: true,
		Fields: []*schemapb.FieldSchema{
			&fieldVec, &fieldInt,
		},
	}
	collectionMeta := etcdpb.CollectionMeta{
		ID:            UniqueID(0),
		Schema:        &schema,
		CreateTime:    Timestamp(0),
		SegmentIDs:    []UniqueID{0},
		PartitionTags: []string{"default"},
	}
	collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
	assert.NotEqual(t, "", collectionMetaBlob)
	var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
	assert.NoError(t, err)
	collection, err := (*node.replica).getCollectionByName(collectionName)
	assert.NoError(t, err)
	assert.Equal(t, collection.meta.Schema.Name, "collection0")
	assert.Equal(t, collection.meta.ID, UniqueID(0))
	assert.Equal(t, (*node.replica).getCollectionNum(), 1)
	err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
	assert.NoError(t, err)
	// NOTE(review): only segment 0 is registered here, yet the insert loop
	// below targets segment IDs 0 and 1 — presumably segment 1 is created
	// on the fly by the data-sync service; confirm.
	segmentID := UniqueID(0)
	err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
	assert.NoError(t, err)

	// test data generate
	const msgLength = 1024
	const receiveBufSize = 1024
	const DIM = 16
	insertProducerChannels := Params.insertChannelNames()
	searchProducerChannels := Params.searchChannelNames()
	var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}

	// start search service
	dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
	// Two query vectors, little-endian float32, offset from `vec` by i*2 / i*4.
	var searchRawData1 []byte
	var searchRawData2 []byte
	for i, ele := range vec {
		buf := make([]byte, 4)
		binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*2)))
		searchRawData1 = append(searchRawData1, buf...)
	}
	for i, ele := range vec {
		buf := make([]byte, 4)
		binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*4)))
		searchRawData2 = append(searchRawData2, buf...)
	}
	placeholderValue := servicepb.PlaceholderValue{
		Tag:    "$0",
		Type:   servicepb.PlaceholderType_VECTOR_FLOAT,
		Values: [][]byte{searchRawData1, searchRawData2},
	}
	placeholderGroup := servicepb.PlaceholderGroup{
		Placeholders: []*servicepb.PlaceholderValue{&placeholderValue},
	}
	placeGroupByte, err := proto.Marshal(&placeholderGroup)
	if err != nil {
		log.Print("marshal placeholderGroup failed")
	}
	query := servicepb.Query{
		CollectionName:   "collection0",
		PartitionTags:    []string{"default"},
		Dsl:              dslString,
		PlaceholderGroup: placeGroupByte,
	}
	queryByte, err := proto.Marshal(&query)
	if err != nil {
		log.Print("marshal query failed")
	}
	blob := commonpb.Blob{
		Value: queryByte,
	}
	searchMsg := &msgstream.SearchMsg{
		BaseMsg: msgstream.BaseMsg{
			HashValues: []uint32{0},
		},
		SearchRequest: internalpb.SearchRequest{
			MsgType:         internalpb.MsgType_kSearch,
			ReqID:           int64(1),
			ProxyID:         int64(1),
			Timestamp:       uint64(10 + 1000),
			ResultChannelID: int64(0),
			Query:           &blob,
		},
	}
	msgPackSearch := msgstream.MsgPack{}
	msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg)
	// Publish the search request before the search service starts consuming.
	searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
	searchStream.SetPulsarClient(pulsarURL)
	searchStream.CreatePulsarProducers(searchProducerChannels)
	searchStream.Start()
	err = searchStream.Produce(&msgPackSearch)
	assert.NoError(t, err)
	node.searchService = newSearchService(node.ctx, node.replica)
	go node.searchService.start()

	// start insert
	timeRange := TimeRange{
		timestampMin: 0,
		timestampMax: math.MaxUint64,
	}
	insertMessages := make([]msgstream.TsMsg, 0)
	for i := 0; i < msgLength; i++ {
		// First half of the rows go to segment 0, second half to segment 1
		// (this local segmentID deliberately shadows the outer one).
		segmentID := 0
		if i >= msgLength/2 {
			segmentID = 1
		}
		// Row payload: 16 little-endian float32 values followed by one int32.
		var rawData []byte
		for _, ele := range vec {
			buf := make([]byte, 4)
			binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*2)))
			rawData = append(rawData, buf...)
		}
		bs := make([]byte, 4)
		binary.LittleEndian.PutUint32(bs, 1)
		rawData = append(rawData, bs...)
		var msg msgstream.TsMsg = &msgstream.InsertMsg{
			BaseMsg: msgstream.BaseMsg{
				HashValues: []uint32{
					uint32(i),
				},
			},
			InsertRequest: internalpb.InsertRequest{
				MsgType:        internalpb.MsgType_kInsert,
				ReqID:          int64(i),
				CollectionName: "collection0",
				PartitionTag:   "default",
				SegmentID:      int64(segmentID),
				ChannelID:      int64(0),
				ProxyID:        int64(0),
				Timestamps:     []uint64{uint64(i + 1000)},
				RowIDs:         []int64{int64(i)},
				RowData: []*commonpb.Blob{
					{Value: rawData},
				},
			},
		}
		insertMessages = append(insertMessages, msg)
	}
	msgPack := msgstream.MsgPack{
		BeginTs: timeRange.timestampMin,
		EndTs:   timeRange.timestampMax,
		Msgs:    insertMessages,
	}

	// generate timeTick
	// A max-timestamp time tick flushes/advances the flow graph past all inserts.
	timeTickMsgPack := msgstream.MsgPack{}
	baseMsg := msgstream.BaseMsg{
		BeginTimestamp: 0,
		EndTimestamp:   0,
		HashValues:     []uint32{0},
	}
	timeTickResult := internalpb.TimeTickMsg{
		MsgType:   internalpb.MsgType_kTimeTick,
		PeerID:    UniqueID(0),
		Timestamp: math.MaxUint64,
	}
	timeTickMsg := &msgstream.TimeTickMsg{
		BaseMsg:     baseMsg,
		TimeTickMsg: timeTickResult,
	}
	timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg)

	// pulsar produce
	insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
	insertStream.SetPulsarClient(pulsarURL)
	insertStream.CreatePulsarProducers(insertProducerChannels)
	insertStream.Start()
	err = insertStream.Produce(&msgPack)
	assert.NoError(t, err)
	err = insertStream.Broadcast(&timeTickMsgPack)
	assert.NoError(t, err)

	// dataSync
	node.dataSyncService = newDataSyncService(node.ctx, node.replica)
	go node.dataSyncService.start()

	// Give the services a moment to consume and process before shutting down.
	time.Sleep(1 * time.Second)
	cancel()
	node.Close()
}

View File

@ -208,7 +208,7 @@ func (s *Segment) segmentSearch(plan *Plan,
var cTimestamp = (*C.ulong)(&timestamp[0])
var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
var cNumGroups = C.int(len(placeHolderGroups))
cQueryResult := (*C.CQueryResult)(&searchResult.cQueryResult)
var cQueryResult = (*C.CQueryResult)(&searchResult.cQueryResult)
var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, cQueryResult)
errorCode := status.error_code
@ -221,3 +221,18 @@ func (s *Segment) segmentSearch(plan *Plan,
return &searchResult, nil
}
// fillTargetEntry asks the C core to fill the target entry (row data for the
// requested output fields) of a search result produced by this segment.
func (s *Segment) fillTargetEntry(plan *Plan, result *SearchResult) error {
	status := C.FillTargetEntry(s.segmentPtr, plan.cPlan, result.cQueryResult)
	code := status.error_code
	if code == 0 {
		return nil
	}
	// On failure the C side strdup'ed the message; copy it into Go and free it.
	msg := C.GoString(status.error_msg)
	defer C.free(unsafe.Pointer(status.error_msg))
	return errors.New("FillTargetEntry failed, C runtime error detected, error code = " + strconv.Itoa(int(code)) + ", error msg = " + msg)
}