diff --git a/internal/core/src/segcore/reduce_c.cpp b/internal/core/src/segcore/reduce_c.cpp index 844f1317e8..8e06936888 100644 --- a/internal/core/src/segcore/reduce_c.cpp +++ b/internal/core/src/segcore/reduce_c.cpp @@ -50,97 +50,187 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) { } struct SearchResultPair { - uint64_t id_; float distance_; - int64_t segment_id_; + SearchResult* search_result_; + int64_t offset_; + int64_t index_; - SearchResultPair(uint64_t id, float distance, int64_t segment_id) - : id_(id), distance_(distance), segment_id_(segment_id) { + SearchResultPair(float distance, SearchResult* search_result, int64_t offset, int64_t index) + : distance_(distance), search_result_(search_result), offset_(offset), index_(index) { } bool operator<(const SearchResultPair& pair) const { return (distance_ < pair.distance_); } + + void + reset_distance() { + distance_ = search_result_->result_distances_[offset_]; + } }; void -GetResultData(std::vector& search_results, - SearchResult& final_result, +GetResultData(std::vector>& search_records, + std::vector& search_results, int64_t query_offset, + bool* is_selected, int64_t topk) { auto num_segments = search_results.size(); - std::map iter_loc_peer_result; + AssertInfo(num_segments > 0, "num segment must greater than 0"); std::vector result_pairs; for (int j = 0; j < num_segments; ++j) { - auto id = search_results[j]->result_ids_[query_offset]; auto distance = search_results[j]->result_distances_[query_offset]; - result_pairs.push_back(SearchResultPair(id, distance, j)); - iter_loc_peer_result[j] = query_offset; + auto search_result = search_results[j]; + AssertInfo(search_result != nullptr, "search result must not equal to nullptr"); + result_pairs.push_back(SearchResultPair(distance, search_result, query_offset, j)); } - std::sort(result_pairs.begin(), result_pairs.end()); - final_result.result_ids_.push_back(result_pairs[0].id_); - final_result.result_distances_.push_back(result_pairs[0].distance_); - - for (int i = 1; i < topk; ++i) { - auto segment_id = result_pairs[0].segment_id_; - auto query_offset = ++(iter_loc_peer_result[segment_id]); - auto id = search_results[segment_id]->result_ids_[query_offset]; - auto distance = search_results[segment_id]->result_distances_[query_offset]; - result_pairs[0] = SearchResultPair(id, distance, segment_id); + int64_t loc_offset = query_offset; + AssertInfo(topk > 0, "topK must greater than 0"); + for (int i = 0; i < topk; ++i) { + result_pairs[0].reset_distance(); std::sort(result_pairs.begin(), result_pairs.end()); - final_result.result_ids_.push_back(result_pairs[0].id_); - final_result.result_distances_.push_back(result_pairs[0].distance_); + auto& result_pair = result_pairs[0]; + auto index = result_pair.index_; + is_selected[index] = true; + result_pair.search_result_->result_offsets_.push_back(loc_offset++); + search_records[index].push_back(result_pair.offset_++); } } -CQueryResult -ReduceQueryResults(CQueryResult* query_results, int64_t num_segments) { +void +ResetSearchResult(std::vector>& search_records, + std::vector& search_results, + bool* is_selected) { + auto num_segments = search_results.size(); + AssertInfo(num_segments > 0, "num segment must greater than 0"); + for (int i = 0; i < num_segments; i++) { + if (is_selected[i] == false) { + continue; + } + auto search_result = search_results[i]; + AssertInfo(search_result != nullptr, "search result must not equal to nullptr"); + + std::vector result_distances; + std::vector internal_seg_offsets; + std::vector result_ids; + + for (int j = 0; j < search_records[i].size(); j++) { + auto& offset = search_records[i][j]; + auto distance = search_result->result_distances_[offset]; + auto internal_seg_offset = search_result->internal_seg_offsets_[offset]; + auto id = search_result->result_ids_[offset]; + result_distances.push_back(distance); + internal_seg_offsets.push_back(internal_seg_offset); + result_ids.push_back(id); + } + + search_result->result_distances_ = result_distances; + search_result->internal_seg_offsets_ = internal_seg_offsets; + search_result->result_ids_ = result_ids; + } +} + +CStatus +ReduceQueryResults(CQueryResult* c_search_results, int64_t num_segments, bool* is_selected) { std::vector search_results; for (int i = 0; i < num_segments; ++i) { - search_results.push_back((SearchResult*)query_results[i]); + search_results.push_back((SearchResult*)c_search_results[i]); } - auto topk = search_results[0]->topK_; - auto num_queries = search_results[0]->num_queries_; - auto final_result = std::make_unique(); + try { + auto topk = search_results[0]->topK_; + auto num_queries = search_results[0]->num_queries_; + std::vector> search_records(num_segments); - int64_t query_offset = 0; - for (int j = 0; j < num_queries; ++j) { - GetResultData(search_results, *final_result, query_offset, topk); - query_offset += topk; + int64_t query_offset = 0; + for (int j = 0; j < num_queries; ++j) { + GetResultData(search_records, search_results, query_offset, is_selected, topk); + query_offset += topk; + } + ResetSearchResult(search_records, search_results, is_selected); + auto status = CStatus(); + status.error_code = Success; + status.error_msg = ""; + return status; + } catch (std::exception& e) { + auto status = CStatus(); + status.error_code = UnexpectedException; + status.error_msg = strdup(e.what()); + return status; } - - return (CQueryResult)final_result.release(); } -CMarshaledHits -ReorganizeQueryResults(CQueryResult c_query_result, - CPlan c_plan, +CStatus +ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits, CPlaceholderGroup* c_placeholder_groups, - int64_t num_groups) { - auto marshaledHits = std::make_unique(num_groups); - auto search_result = (milvus::engine::QueryResult*)c_query_result; - auto& result_ids = search_result->result_ids_; - auto& result_distances = search_result->result_distances_; - auto topk = GetTopK(c_plan); - int64_t queries_offset = 0; - for (int i = 0; i < num_groups; i++) { - auto num_queries = GetNumOfQueries(c_placeholder_groups[i]); - MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i]; - for (int j = 0; j < num_queries; j++) { - auto index = topk * queries_offset++; - milvus::proto::service::Hits hits; - for (int k = index; k < index + topk; k++) { - hits.add_ids(result_ids[k]); - hits.add_scores(result_distances[k]); - } - auto blob = hits.SerializeAsString(); - hits_peer_group.hits_.push_back(blob); - hits_peer_group.blob_length_.push_back(blob.size()); + int64_t num_groups, + CQueryResult* c_search_results, + bool* is_selected, + int64_t num_segments, + CPlan c_plan) { + try { + auto marshaledHits = std::make_unique(num_groups); + auto topk = GetTopK(c_plan); + std::vector num_queries_peer_group; + int64_t total_num_queries = 0; + for (int i = 0; i < num_groups; i++) { + auto num_queries = GetNumOfQueries(c_placeholder_groups[i]); + num_queries_peer_group.push_back(num_queries); + total_num_queries += num_queries; } - } - return (CMarshaledHits)marshaledHits.release(); + std::vector result_distances(total_num_queries * topk); + std::vector result_ids(total_num_queries * topk); + std::vector> row_datas(total_num_queries * topk); + + int64_t count = 0; + for (int i = 0; i < num_segments; i++) { + if (is_selected[i] == false) { + continue; + } + auto search_result = (SearchResult*)c_search_results[i]; + AssertInfo(search_result != nullptr, "search result must not equal to nullptr"); + auto size = search_result->result_offsets_.size(); + for (int j = 0; j < size; j++) { + auto loc = search_result->result_offsets_[j]; + result_distances[loc] = search_result->result_distances_[j]; + row_datas[loc] = search_result->row_data_[j]; + result_ids[loc] = search_result->result_ids_[j]; + } + count += size; + } + AssertInfo(count == total_num_queries * topk, "the reduces result's size less than total_num_queries*topk"); + + int64_t fill_hit_offset = 0; + for (int i = 0; i < num_groups; i++) { + MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i]; + for (int j = 0; j < num_queries_peer_group[i]; j++) { + milvus::proto::service::Hits hits; + for (int k = 0; k < topk; k++, fill_hit_offset++) { + hits.add_ids(result_ids[fill_hit_offset]); + hits.add_scores(result_distances[fill_hit_offset]); + auto& row_data = row_datas[fill_hit_offset]; + hits.add_row_data(row_data.data(), row_data.size()); + } + auto blob = hits.SerializeAsString(); + hits_peer_group.hits_.push_back(blob); + hits_peer_group.blob_length_.push_back(blob.size()); + } + } + + auto status = CStatus(); + status.error_code = Success; + status.error_msg = ""; + auto marshled_res = (CMarshaledHits)marshaledHits.release(); + *c_marshaled_hits = marshled_res; + return status; + } catch (std::exception& e) { + auto status = CStatus(); + status.error_code = UnexpectedException; + status.error_msg = strdup(e.what()); + *c_marshaled_hits = nullptr; + return status; + } } int64_t diff --git a/internal/core/src/segcore/reduce_c.h b/internal/core/src/segcore/reduce_c.h index e9e2016665..59a3274286 100644 --- a/internal/core/src/segcore/reduce_c.h +++ b/internal/core/src/segcore/reduce_c.h @@ -25,14 +25,17 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits); int MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, float* new_distances, int64_t* new_uids); -CQueryResult -ReduceQueryResults(CQueryResult* query_results, int64_t num_segments); +CStatus +ReduceQueryResults(CQueryResult* query_results, int64_t num_segments, bool* is_selected); -CMarshaledHits -ReorganizeQueryResults(CQueryResult query_result, - CPlan c_plan, +CStatus +ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits, CPlaceholderGroup* c_placeholder_groups, - int64_t num_groups); + int64_t num_groups, + CQueryResult* c_search_results, + bool* is_selected, + int64_t num_segments, + CPlan c_plan); int64_t GetHitsBlobSize(CMarshaledHits c_marshaled_hits); diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index b053daddca..8c6786c4f0 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -155,6 +155,24 @@ Search(CSegmentBase c_segment, return status; } +CStatus +FillTargetEntry(CSegmentBase c_segment, CPlan c_plan, CQueryResult c_result) { + auto segment = (milvus::segcore::SegmentBase*)c_segment; + auto plan = (milvus::query::Plan*)c_plan; + auto result = (milvus::engine::QueryResult*)c_result; + + auto status = CStatus(); + try { + auto res = segment->FillTargetEntry(plan, *result); + status.error_code = Success; + status.error_msg = ""; + } catch (std::runtime_error& e) { + status.error_code = UnexpectedException; + status.error_msg = strdup(e.what()); + } + return status; +} + ////////////////////////////////////////////////////////////////// int diff --git a/internal/core/src/segcore/segment_c.h b/internal/core/src/segcore/segment_c.h index 7f73223347..0dc3f7cdcd 100644 --- a/internal/core/src/segcore/segment_c.h +++ b/internal/core/src/segcore/segment_c.h @@ -61,6 +61,9 @@ Search(CSegmentBase c_segment, int num_groups, CQueryResult* result); +CStatus +FillTargetEntry(CSegmentBase c_segment, CPlan c_plan, CQueryResult result); + ////////////////////////////////////////////////////////////////// int diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index d974f4bc46..c9a621a3e5 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -641,8 +641,14 @@ TEST(CApiTest, Reduce) { results.push_back(res1); results.push_back(res2); - auto reduced_search_result = ReduceQueryResults(results.data(), 2); - auto reorganize_search_result = ReorganizeQueryResults(reduced_search_result, plan, placeholderGroups.data(), 1); + bool is_selected[1] = {false}; + status = ReduceQueryResults(results.data(), 1, is_selected); + assert(status.error_code == Success); + FillTargetEntry(segment, plan, res1); + void* reorganize_search_result = nullptr; + status = ReorganizeQueryResults(&reorganize_search_result, placeholderGroups.data(), 1, results.data(), is_selected, + 1, plan); + assert(status.error_code == Success); auto hits_blob_size = GetHitsBlobSize(reorganize_search_result); assert(hits_blob_size > 0); std::vector hits_blob; @@ -660,7 +666,6 @@ TEST(CApiTest, Reduce) { DeletePlaceholderGroup(placeholderGroup); DeleteQueryResult(res1); DeleteQueryResult(res2); - DeleteQueryResult(reduced_search_result); DeleteMarshaledHits(reorganize_search_result); DeleteCollection(collection); DeleteSegment(segment); diff --git a/internal/msgstream/msgstream.go b/internal/msgstream/msgstream.go index 4f1eed4a3a..11c0b8f2ab 100644 --- a/internal/msgstream/msgstream.go +++ b/internal/msgstream/msgstream.go @@ -70,11 +70,22 @@ func (ms *PulsarMsgStream) SetPulsarClient(address string) { func (ms *PulsarMsgStream) CreatePulsarProducers(channels []string) { for i := 0; i < len(channels); i++ { - pp, err := (*ms.client).CreateProducer(pulsar.ProducerOptions{Topic: channels[i]}) - if err != nil { - log.Printf("Failed to create querynode producer %s, error = %v", channels[i], err) + fn := func() error { + pp, err := (*ms.client).CreateProducer(pulsar.ProducerOptions{Topic: channels[i]}) + if err != nil { + return err + } + if pp == nil { + return errors.New("pulsar is not ready, producer is nil") + } + ms.producers = append(ms.producers, &pp) + return nil + } + err := Retry(10, time.Millisecond*200, fn) + if err != nil { + errMsg := "Failed to create producer " + channels[i] + ", error = " + err.Error() + panic(errMsg) } - ms.producers = append(ms.producers, &pp) } } @@ -104,7 +115,8 @@ func (ms *PulsarMsgStream) CreatePulsarConsumers(channels []string, } err := Retry(10, time.Millisecond*200, fn) if err != nil { - panic("create pulsar consumer timeout!") + errMsg := "Failed to create consumer " + channels[i] + ", error = " + err.Error() + panic(errMsg) } } } @@ -239,10 +251,6 @@ func (ms *PulsarMsgStream) bufMsgPackToChannel() { cases := make([]reflect.SelectCase, len(ms.consumers)) for i := 0; i < len(ms.consumers); i++ { - pc := *ms.consumers[i] - if pc == nil { - panic("pc is nil") - } ch := (*ms.consumers[i]).Chan() cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)} } diff --git a/internal/querynode/reduce.go b/internal/querynode/reduce.go index 035227e8be..8fa56bf7fb 100644 --- a/internal/querynode/reduce.go +++ b/internal/querynode/reduce.go @@ -10,6 +10,8 @@ package querynode */ import "C" import ( + "errors" + "strconv" "unsafe" ) @@ -21,26 +23,66 @@ type MarshaledHits struct { cMarshaledHits C.CMarshaledHits } -func reduceSearchResults(searchResults []*SearchResult, numSegments int64) *SearchResult { +func reduceSearchResults(searchResults []*SearchResult, numSegments int64, inReduced []bool) error { cSearchResults := make([]C.CQueryResult, 0) for _, res := range searchResults { cSearchResults = append(cSearchResults, res.cQueryResult) } cSearchResultPtr := (*C.CQueryResult)(&cSearchResults[0]) cNumSegments := C.long(numSegments) - res := C.ReduceQueryResults(cSearchResultPtr, cNumSegments) - return &SearchResult{cQueryResult: res} + cInReduced := (*C.bool)(&inReduced[0]) + + status := C.ReduceQueryResults(cSearchResultPtr, cNumSegments, cInReduced) + + errorCode := status.error_code + + if errorCode != 0 { + errorMsg := C.GoString(status.error_msg) + defer C.free(unsafe.Pointer(status.error_msg)) + return errors.New("reduceSearchResults failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg) + } + return nil } -func (sr *SearchResult) reorganizeQueryResults(plan *Plan, placeholderGroups []*PlaceholderGroup) *MarshaledHits { +func fillTargetEntry(plan *Plan, searchResults []*SearchResult, matchedSegments []*Segment, inReduced []bool) error { + for i, value := range inReduced { + if value { + err := matchedSegments[i].fillTargetEntry(plan, searchResults[i]) + if err != nil { + return err + } + } + } + return nil +} + +func reorganizeQueryResults(plan *Plan, placeholderGroups []*PlaceholderGroup, searchResults []*SearchResult, numSegments int64, inReduced []bool) (*MarshaledHits, error) { cPlaceholderGroups := make([]C.CPlaceholderGroup, 0) for _, pg := range placeholderGroups { cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup) } - cNumGroup := (C.long)(len(placeholderGroups)) - var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0]) - res := C.ReorganizeQueryResults(sr.cQueryResult, plan.cPlan, cPlaceHolder, cNumGroup) - return &MarshaledHits{cMarshaledHits: res} + var cPlaceHolderGroupPtr = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0]) + var cNumGroup = (C.long)(len(placeholderGroups)) + + cSearchResults := make([]C.CQueryResult, 0) + for _, res := range searchResults { + cSearchResults = append(cSearchResults, res.cQueryResult) + } + cSearchResultPtr := (*C.CQueryResult)(&cSearchResults[0]) + + var cNumSegments = C.long(numSegments) + var cInReduced = (*C.bool)(&inReduced[0]) + var cMarshaledHits C.CMarshaledHits + + status := C.ReorganizeQueryResults(&cMarshaledHits, cPlaceHolderGroupPtr, cNumGroup, cSearchResultPtr, cInReduced, cNumSegments, plan.cPlan) + errorCode := status.error_code + + if errorCode != 0 { + errorMsg := C.GoString(status.error_msg) + defer C.free(unsafe.Pointer(status.error_msg)) + return nil, errors.New("reorganizeQueryResults failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg) + } + return &MarshaledHits{cMarshaledHits: cMarshaledHits}, nil } func (mh *MarshaledHits) getHitsBlobSize() int64 { diff --git a/internal/querynode/reduce_test.go b/internal/querynode/reduce_test.go index 5774c38acc..a14ae3a919 100644 --- a/internal/querynode/reduce_test.go +++ b/internal/querynode/reduce_test.go @@ -107,15 +107,21 @@ func TestReduce_AllFunc(t *testing.T) { placeholderGroups = append(placeholderGroups, holder) searchResults := make([]*SearchResult, 0) + matchedSegment := make([]*Segment, 0) searchResult, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{0}) assert.Nil(t, err) searchResults = append(searchResults, searchResult) + matchedSegment = append(matchedSegment, segment) - reducedSearchResults := reduceSearchResults(searchResults, 1) - assert.NotNil(t, reducedSearchResults) + testReduce := make([]bool, len(searchResults)) + err = reduceSearchResults(searchResults, 1, testReduce) + assert.Nil(t, err) + err = fillTargetEntry(plan, searchResults, matchedSegment, testReduce) + assert.Nil(t, err) - marshaledHits := reducedSearchResults.reorganizeQueryResults(plan, placeholderGroups) + marshaledHits, err := reorganizeQueryResults(plan, placeholderGroups, searchResults, 1, testReduce) assert.NotNil(t, marshaledHits) + assert.Nil(t, err) hitsBlob, err := marshaledHits.getHitsBlob() assert.Nil(t, err) @@ -137,7 +143,6 @@ func TestReduce_AllFunc(t *testing.T) { plan.delete() holder.delete() deleteSearchResults(searchResults) - deleteSearchResults([]*SearchResult{reducedSearchResults}) deleteMarshaledHits(marshaledHits) deleteSegment(segment) deleteCollection(collection) diff --git a/internal/querynode/search_service.go b/internal/querynode/search_service.go index e66aca4e22..a5cda00534 100644 --- a/internal/querynode/search_service.go +++ b/internal/querynode/search_service.go @@ -238,6 +238,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { placeholderGroups = append(placeholderGroups, placeholderGroup) searchResults := make([]*SearchResult, 0) + matchedSegments := make([]*Segment, 0) for _, partitionTag := range partitionTags { hasPartition := (*ss.replica).hasPartition(collectionID, partitionTag) @@ -257,6 +258,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { return err } searchResults = append(searchResults, searchResult) + matchedSegments = append(matchedSegments, segment) } } @@ -282,8 +284,20 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { return nil } - reducedSearchResult := reduceSearchResults(searchResults, int64(len(searchResults))) - marshaledHits := reducedSearchResult.reorganizeQueryResults(plan, placeholderGroups) + inReduced := make([]bool, len(searchResults)) + numSegment := int64(len(searchResults)) + err = reduceSearchResults(searchResults, numSegment, inReduced) + if err != nil { + return err + } + err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced) + if err != nil { + return err + } + marshaledHits, err := reorganizeQueryResults(plan, placeholderGroups, searchResults, numSegment, inReduced) + if err != nil { + return err + } hitsBlob, err := marshaledHits.getHitsBlob() if err != nil { return err @@ -291,12 +305,12 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { var offset int64 = 0 for index := range placeholderGroups { - hitBolbSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index)) + hitBlobSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index)) if err != nil { return err } hits := make([][]byte, 0) - for _, len := range hitBolbSizePeerQuery { + for _, len := range hitBlobSizePeerQuery { hits = append(hits, hitsBlob[offset:offset+len]) //test code to checkout marshaled hits //marshaledHit := hitsBlob[offset:offset+len] @@ -329,7 +343,6 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { } deleteSearchResults(searchResults) - deleteSearchResults([]*SearchResult{reducedSearchResult}) deleteMarshaledHits(marshaledHits) plan.delete() placeholderGroup.delete() diff --git a/internal/querynode/search_service_test.go b/internal/querynode/search_service_test.go index d095db3a7f..3624e9b99e 100644 --- a/internal/querynode/search_service_test.go +++ b/internal/querynode/search_service_test.go @@ -253,3 +253,242 @@ func TestSearch_Search(t *testing.T) { cancel() node.Close() } + +func TestSearch_SearchMultiSegments(t *testing.T) { + Params.Init() + ctx, cancel := context.WithCancel(context.Background()) + + // init query node + pulsarURL, _ := Params.pulsarAddress() + node := NewQueryNode(ctx, 0) + + // init meta + collectionName := "collection0" + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0]) + assert.NoError(t, err) + + segmentID := UniqueID(0) + err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0)) + assert.NoError(t, err) + + // test data generate + const msgLength = 1024 + const receiveBufSize = 1024 + const DIM = 16 + insertProducerChannels := Params.insertChannelNames() + searchProducerChannels := Params.searchChannelNames() + var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + + // start search service + dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" + var searchRawData1 []byte + var searchRawData2 []byte + for i, ele := range vec { + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*2))) + searchRawData1 = append(searchRawData1, buf...) + } + for i, ele := range vec { + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*4))) + searchRawData2 = append(searchRawData2, buf...) + } + placeholderValue := servicepb.PlaceholderValue{ + Tag: "$0", + Type: servicepb.PlaceholderType_VECTOR_FLOAT, + Values: [][]byte{searchRawData1, searchRawData2}, + } + + placeholderGroup := servicepb.PlaceholderGroup{ + Placeholders: []*servicepb.PlaceholderValue{&placeholderValue}, + } + + placeGroupByte, err := proto.Marshal(&placeholderGroup) + if err != nil { + log.Print("marshal placeholderGroup failed") + } + + query := servicepb.Query{ + CollectionName: "collection0", + PartitionTags: []string{"default"}, + Dsl: dslString, + PlaceholderGroup: placeGroupByte, + } + + queryByte, err := proto.Marshal(&query) + if err != nil { + log.Print("marshal query failed") + } + + blob := commonpb.Blob{ + Value: queryByte, + } + + searchMsg := &msgstream.SearchMsg{ + BaseMsg: msgstream.BaseMsg{ + HashValues: []uint32{0}, + }, + SearchRequest: internalpb.SearchRequest{ + MsgType: internalpb.MsgType_kSearch, + ReqID: int64(1), + ProxyID: int64(1), + Timestamp: uint64(10 + 1000), + ResultChannelID: int64(0), + Query: &blob, + }, + } + + msgPackSearch := msgstream.MsgPack{} + msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg) + + searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize) + searchStream.SetPulsarClient(pulsarURL) + searchStream.CreatePulsarProducers(searchProducerChannels) + searchStream.Start() + err = searchStream.Produce(&msgPackSearch) + assert.NoError(t, err) + + node.searchService = newSearchService(node.ctx, node.replica) + go node.searchService.start() + + // start insert + timeRange := TimeRange{ + timestampMin: 0, + timestampMax: math.MaxUint64, + } + + insertMessages := make([]msgstream.TsMsg, 0) + for i := 0; i < msgLength; i++ { + segmentID := 0 + if i >= msgLength/2 { + segmentID = 1 + } + var rawData []byte + for _, ele := range vec { + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*2))) + rawData = append(rawData, buf...) + } + bs := make([]byte, 4) + binary.LittleEndian.PutUint32(bs, 1) + rawData = append(rawData, bs...) + + var msg msgstream.TsMsg = &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + HashValues: []uint32{ + uint32(i), + }, + }, + InsertRequest: internalpb.InsertRequest{ + MsgType: internalpb.MsgType_kInsert, + ReqID: int64(i), + CollectionName: "collection0", + PartitionTag: "default", + SegmentID: int64(segmentID), + ChannelID: int64(0), + ProxyID: int64(0), + Timestamps: []uint64{uint64(i + 1000)}, + RowIDs: []int64{int64(i)}, + RowData: []*commonpb.Blob{ + {Value: rawData}, + }, + }, + } + insertMessages = append(insertMessages, msg) + } + + msgPack := msgstream.MsgPack{ + BeginTs: timeRange.timestampMin, + EndTs: timeRange.timestampMax, + Msgs: insertMessages, + } + + // generate timeTick + timeTickMsgPack := msgstream.MsgPack{} + baseMsg := msgstream.BaseMsg{ + BeginTimestamp: 0, + EndTimestamp: 0, + HashValues: []uint32{0}, + } + timeTickResult := internalpb.TimeTickMsg{ + MsgType: internalpb.MsgType_kTimeTick, + PeerID: UniqueID(0), + Timestamp: math.MaxUint64, + } + timeTickMsg := &msgstream.TimeTickMsg{ + BaseMsg: baseMsg, + TimeTickMsg: timeTickResult, + } + timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg) + + // pulsar produce + insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize) + insertStream.SetPulsarClient(pulsarURL) + insertStream.CreatePulsarProducers(insertProducerChannels) + insertStream.Start() + err = insertStream.Produce(&msgPack) + assert.NoError(t, err) + err = insertStream.Broadcast(&timeTickMsgPack) + assert.NoError(t, err) + + // dataSync + node.dataSyncService = newDataSyncService(node.ctx, node.replica) + go node.dataSyncService.start() + + time.Sleep(1 * time.Second) + + cancel() + node.Close() +} diff --git a/internal/querynode/segment.go b/internal/querynode/segment.go index ad3a50ce5c..e3f7d81252 100644 --- a/internal/querynode/segment.go +++ b/internal/querynode/segment.go @@ -208,7 +208,7 @@ func (s *Segment) segmentSearch(plan *Plan, var cTimestamp = (*C.ulong)(×tamp[0]) var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0]) var cNumGroups = C.int(len(placeHolderGroups)) - cQueryResult := (*C.CQueryResult)(&searchResult.cQueryResult) + var cQueryResult = (*C.CQueryResult)(&searchResult.cQueryResult) var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, cQueryResult) errorCode := status.error_code @@ -221,3 +221,18 @@ func (s *Segment) segmentSearch(plan *Plan, return &searchResult, nil } + +func (s *Segment) fillTargetEntry(plan *Plan, + result *SearchResult) error { + + var status = C.FillTargetEntry(s.segmentPtr, plan.cPlan, result.cQueryResult) + errorCode := status.error_code + + if errorCode != 0 { + errorMsg := C.GoString(status.error_msg) + defer C.free(unsafe.Pointer(status.error_msg)) + return errors.New("FillTargetEntry failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg) + } + + return nil +}