diff --git a/internal/core/src/segcore/load_index_c.cpp b/internal/core/src/segcore/load_index_c.cpp index 5c8da29795..9afd0b6262 100644 --- a/internal/core/src/segcore/load_index_c.cpp +++ b/internal/core/src/segcore/load_index_c.cpp @@ -133,7 +133,7 @@ AppendBinaryIndex(CBinarySet c_binary_set, void* index_binary, int64_t index_siz auto binary_set = (milvus::knowhere::BinarySet*)c_binary_set; std::string index_key(c_index_key); uint8_t* index = (uint8_t*)index_binary; - std::shared_ptr data(index); + std::shared_ptr data(index, [](void*) {}); binary_set->Append(index_key, data, index_size); auto status = CStatus(); diff --git a/internal/msgstream/msgstream.go b/internal/msgstream/msgstream.go index 37dd71c053..6efbb1ef69 100644 --- a/internal/msgstream/msgstream.go +++ b/internal/msgstream/msgstream.go @@ -353,15 +353,16 @@ func (ms *PulsarTtMsgStream) bufMsgPackToChannel() { default: wg := sync.WaitGroup{} mu := sync.Mutex{} + findMapMutex := sync.RWMutex{} for i := 0; i < len(ms.consumers); i++ { if isChannelReady[i] { continue } wg.Add(1) - go ms.findTimeTick(i, eofMsgTimeStamp, &wg, &mu) + go ms.findTimeTick(i, eofMsgTimeStamp, &wg, &mu, &findMapMutex) } wg.Wait() - timeStamp, ok := checkTimeTickMsg(eofMsgTimeStamp, isChannelReady) + timeStamp, ok := checkTimeTickMsg(eofMsgTimeStamp, isChannelReady, &findMapMutex) if !ok || timeStamp <= ms.lastTimeStamp { log.Printf("All timeTick's timestamps are inconsistent") continue @@ -394,7 +395,8 @@ func (ms *PulsarTtMsgStream) bufMsgPackToChannel() { func (ms *PulsarTtMsgStream) findTimeTick(channelIndex int, eofMsgMap map[int]Timestamp, wg *sync.WaitGroup, - mu *sync.Mutex) { + mu *sync.Mutex, + findMapMutex *sync.RWMutex) { defer wg.Done() for { select { @@ -421,7 +423,9 @@ func (ms *PulsarTtMsgStream) findTimeTick(channelIndex int, log.Printf("Failed to unmarshal, error = %v", err) } if headerMsg.MsgType == internalPb.MsgType_kTimeTick { + findMapMutex.Lock() eofMsgMap[channelIndex] = tsMsg.(*TimeTickMsg).Timestamp 
+ findMapMutex.Unlock() return } mu.Lock() @@ -470,7 +474,7 @@ func (ms *InMemMsgStream) Chan() <- chan *MsgPack { } */ -func checkTimeTickMsg(msg map[int]Timestamp, isChannelReady []bool) (Timestamp, bool) { +func checkTimeTickMsg(msg map[int]Timestamp, isChannelReady []bool, mu *sync.RWMutex) (Timestamp, bool) { checkMap := make(map[Timestamp]int) var maxTime Timestamp = 0 for _, v := range msg { @@ -485,7 +489,10 @@ func checkTimeTickMsg(msg map[int]Timestamp, isChannelReady []bool) (Timestamp, } return maxTime, true } - for i, v := range msg { + for i := range msg { + mu.RLock() + v := msg[i] + mu.RUnlock() if v != maxTime { isChannelReady[i] = false } else { diff --git a/internal/proxy/task.go b/internal/proxy/task.go index d676a067dd..254a7bba10 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -374,9 +374,6 @@ func (qt *QueryTask) PreExecute() error { } } qt.MsgType = internalpb.MsgType_kSearch - if qt.query.PartitionTags == nil || len(qt.query.PartitionTags) <= 0 { - qt.query.PartitionTags = []string{Params.defaultPartitionTag()} - } queryBytes, err := proto.Marshal(qt.query) if err != nil { return err diff --git a/internal/querynode/search_service.go b/internal/querynode/search_service.go index ab5360c5ef..b7644e362f 100644 --- a/internal/querynode/search_service.go +++ b/internal/querynode/search_service.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log" + "regexp" "sync" "github.com/golang/protobuf/proto" @@ -223,7 +224,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { return errors.New("unmarshal query failed") } collectionName := query.CollectionName - partitionTags := query.PartitionTags + partitionTagsInQuery := query.PartitionTags collection, err := ss.replica.getCollectionByName(collectionName) if err != nil { return err @@ -245,11 +246,29 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { searchResults := make([]*SearchResult, 0) matchedSegments := make([]*Segment, 0) - for _, partitionTag := range 
partitionTags { - partition, err := ss.replica.getPartitionByTag(collectionID, partitionTag) - if err != nil { - continue + fmt.Println("search msg's partitionTag = ", partitionTagsInQuery) + + var partitionTagsInCol []string + for _, partition := range collection.partitions { + partitionTag := partition.partitionTag + partitionTagsInCol = append(partitionTagsInCol, partitionTag) + } + var searchPartitionTag []string + if len(partitionTagsInQuery) == 0 { + searchPartitionTag = partitionTagsInCol + } else { + for _, tag := range partitionTagsInCol { + for _, toMatchTag := range partitionTagsInQuery { + re := regexp.MustCompile("^" + toMatchTag + "$") + if re.MatchString(tag) { + searchPartitionTag = append(searchPartitionTag, tag) + } + } } + } + + for _, partitionTag := range searchPartitionTag { + partition, _ := ss.replica.getPartitionByTag(collectionID, partitionTag) for _, segment := range partition.segments { //fmt.Println("dsl = ", dsl) @@ -360,6 +379,7 @@ func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error { } func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg, errMsg string) error { + fmt.Println("Public fail SearchResult!") msgPack := msgstream.MsgPack{} searchMsg, ok := msg.(*msgstream.SearchMsg) if !ok { diff --git a/tests/python/test_search.py b/tests/python/test_search.py index c93835c903..70815e8a76 100644 --- a/tests/python/test_search.py +++ b/tests/python/test_search.py @@ -255,7 +255,7 @@ class TestSearchBase: assert res2[0][0].id == res[0][1].id assert res2[0][0].entity.get("int64") == res[0][1].entity.get("int64") - # pass + # Pass @pytest.mark.skip("search_after_index") @pytest.mark.level(2) def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq): @@ -303,6 +303,7 @@ class TestSearchBase: assert len(res) == nq assert len(res[0]) == default_top_k + # should fix, 336 assert fail, insert data don't have partitionTag, But search data have 
@pytest.mark.skip("search_index_partition") @pytest.mark.level(2) def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq): @@ -334,7 +335,7 @@ class TestSearchBase: res = connect.search(collection, query, partition_tags=[default_tag]) assert len(res) == nq - # pass + # PASS @pytest.mark.skip("search_index_partition_B") @pytest.mark.level(2) def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq): @@ -385,6 +386,7 @@ class TestSearchBase: assert len(res) == nq assert len(res[0]) == 0 + # PASS @pytest.mark.skip("search_index_partitions") @pytest.mark.level(2) def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k): @@ -419,6 +421,7 @@ class TestSearchBase: assert res[0]._distances[0] > epsilon assert res[1]._distances[0] > epsilon + # Pass @pytest.mark.skip("search_index_partitions_B") @pytest.mark.level(2) def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k): @@ -479,7 +482,7 @@ class TestSearchBase: with pytest.raises(Exception) as e: res = connect.search(collection, query) - # pass + # PASS @pytest.mark.skip("search_ip_after_index") @pytest.mark.level(2) def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq): @@ -509,6 +512,7 @@ class TestSearchBase: assert check_id_result(res[0], ids[0]) assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0]) + # should fix, nq not correct @pytest.mark.skip("search_ip_index_partition") @pytest.mark.level(2) def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq): @@ -542,6 +546,7 @@ class TestSearchBase: res = connect.search(collection, query, partition_tags=[default_tag]) assert len(res) == nq + # PASS @pytest.mark.skip("search_ip_index_partitions") @pytest.mark.level(2) def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k): @@ -621,6 +626,7 @@ 
class TestSearchBase: res = connect.search(collection, query) assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0]) + # Pass @pytest.mark.skip("test_search_distance_l2_after_index") def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index): ''' @@ -675,7 +681,7 @@ class TestSearchBase: res = connect.search(collection, query) assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon - # pass + # Pass @pytest.mark.skip("search_distance_ip_after_index") def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index): ''' @@ -946,6 +952,7 @@ class TestSearchBase: assert res[i]._distances[0] < epsilon assert res[i]._distances[1] > epsilon + # should fix @pytest.mark.skip("query_entities_with_field_less_than_top_k") def test_query_entities_with_field_less_than_top_k(self, connect, id_collection): """ @@ -1745,7 +1752,7 @@ class TestSearchInvalid(object): def get_search_params(self, request): yield request.param - # pass + # Pass @pytest.mark.skip("search_with_invalid_params") @pytest.mark.level(2) def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params): @@ -1787,7 +1794,7 @@ class TestSearchInvalid(object): with pytest.raises(Exception) as e: res = connect.search(binary_collection, query) - # pass + # Pass @pytest.mark.skip("search_with_empty_params") @pytest.mark.level(2) def test_search_with_empty_params(self, connect, collection, args, get_simple_index):