mirror of https://github.com/milvus-io/milvus.git
Fix search error in regression test
Signed-off-by: xige-16 <xi.ge@zilliz.com>pull/4973/head^2
parent
5dfe9448ae
commit
f4566731fc
|
@ -133,7 +133,7 @@ AppendBinaryIndex(CBinarySet c_binary_set, void* index_binary, int64_t index_siz
|
|||
auto binary_set = (milvus::knowhere::BinarySet*)c_binary_set;
|
||||
std::string index_key(c_index_key);
|
||||
uint8_t* index = (uint8_t*)index_binary;
|
||||
std::shared_ptr<uint8_t[]> data(index);
|
||||
std::shared_ptr<uint8_t[]> data(index, [](void*) {});
|
||||
binary_set->Append(index_key, data, index_size);
|
||||
|
||||
auto status = CStatus();
|
||||
|
|
|
@ -353,15 +353,16 @@ func (ms *PulsarTtMsgStream) bufMsgPackToChannel() {
|
|||
default:
|
||||
wg := sync.WaitGroup{}
|
||||
mu := sync.Mutex{}
|
||||
findMapMutex := sync.RWMutex{}
|
||||
for i := 0; i < len(ms.consumers); i++ {
|
||||
if isChannelReady[i] {
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go ms.findTimeTick(i, eofMsgTimeStamp, &wg, &mu)
|
||||
go ms.findTimeTick(i, eofMsgTimeStamp, &wg, &mu, &findMapMutex)
|
||||
}
|
||||
wg.Wait()
|
||||
timeStamp, ok := checkTimeTickMsg(eofMsgTimeStamp, isChannelReady)
|
||||
timeStamp, ok := checkTimeTickMsg(eofMsgTimeStamp, isChannelReady, &findMapMutex)
|
||||
if !ok || timeStamp <= ms.lastTimeStamp {
|
||||
log.Printf("All timeTick's timestamps are inconsistent")
|
||||
continue
|
||||
|
@ -394,7 +395,8 @@ func (ms *PulsarTtMsgStream) bufMsgPackToChannel() {
|
|||
func (ms *PulsarTtMsgStream) findTimeTick(channelIndex int,
|
||||
eofMsgMap map[int]Timestamp,
|
||||
wg *sync.WaitGroup,
|
||||
mu *sync.Mutex) {
|
||||
mu *sync.Mutex,
|
||||
findMapMutex *sync.RWMutex) {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
|
@ -421,7 +423,9 @@ func (ms *PulsarTtMsgStream) findTimeTick(channelIndex int,
|
|||
log.Printf("Failed to unmarshal, error = %v", err)
|
||||
}
|
||||
if headerMsg.MsgType == internalPb.MsgType_kTimeTick {
|
||||
findMapMutex.Lock()
|
||||
eofMsgMap[channelIndex] = tsMsg.(*TimeTickMsg).Timestamp
|
||||
findMapMutex.Unlock()
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
|
@ -470,7 +474,7 @@ func (ms *InMemMsgStream) Chan() <- chan *MsgPack {
|
|||
}
|
||||
*/
|
||||
|
||||
func checkTimeTickMsg(msg map[int]Timestamp, isChannelReady []bool) (Timestamp, bool) {
|
||||
func checkTimeTickMsg(msg map[int]Timestamp, isChannelReady []bool, mu *sync.RWMutex) (Timestamp, bool) {
|
||||
checkMap := make(map[Timestamp]int)
|
||||
var maxTime Timestamp = 0
|
||||
for _, v := range msg {
|
||||
|
@ -485,7 +489,10 @@ func checkTimeTickMsg(msg map[int]Timestamp, isChannelReady []bool) (Timestamp,
|
|||
}
|
||||
return maxTime, true
|
||||
}
|
||||
for i, v := range msg {
|
||||
for i := range msg {
|
||||
mu.RLock()
|
||||
v := msg[i]
|
||||
mu.Unlock()
|
||||
if v != maxTime {
|
||||
isChannelReady[i] = false
|
||||
} else {
|
||||
|
|
|
@ -374,9 +374,6 @@ func (qt *QueryTask) PreExecute() error {
|
|||
}
|
||||
}
|
||||
qt.MsgType = internalpb.MsgType_kSearch
|
||||
if qt.query.PartitionTags == nil || len(qt.query.PartitionTags) <= 0 {
|
||||
qt.query.PartitionTags = []string{Params.defaultPartitionTag()}
|
||||
}
|
||||
queryBytes, err := proto.Marshal(qt.query)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"regexp"
|
||||
"sync"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
@ -223,7 +224,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
|||
return errors.New("unmarshal query failed")
|
||||
}
|
||||
collectionName := query.CollectionName
|
||||
partitionTags := query.PartitionTags
|
||||
partitionTagsInQuery := query.PartitionTags
|
||||
collection, err := ss.replica.getCollectionByName(collectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -245,11 +246,29 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
|||
searchResults := make([]*SearchResult, 0)
|
||||
matchedSegments := make([]*Segment, 0)
|
||||
|
||||
for _, partitionTag := range partitionTags {
|
||||
partition, err := ss.replica.getPartitionByTag(collectionID, partitionTag)
|
||||
if err != nil {
|
||||
continue
|
||||
fmt.Println("search msg's partitionTag = ", partitionTagsInQuery)
|
||||
|
||||
var partitionTagsInCol []string
|
||||
for _, partition := range collection.partitions {
|
||||
partitionTag := partition.partitionTag
|
||||
partitionTagsInCol = append(partitionTagsInCol, partitionTag)
|
||||
}
|
||||
var searchPartitionTag []string
|
||||
if len(partitionTagsInQuery) == 0 {
|
||||
searchPartitionTag = partitionTagsInCol
|
||||
} else {
|
||||
for _, tag := range partitionTagsInCol {
|
||||
for _, toMatchTag := range partitionTagsInQuery {
|
||||
re := regexp.MustCompile("^" + toMatchTag + "$")
|
||||
if re.MatchString(tag) {
|
||||
searchPartitionTag = append(searchPartitionTag, tag)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, partitionTag := range searchPartitionTag {
|
||||
partition, _ := ss.replica.getPartitionByTag(collectionID, partitionTag)
|
||||
for _, segment := range partition.segments {
|
||||
//fmt.Println("dsl = ", dsl)
|
||||
|
||||
|
@ -360,6 +379,7 @@ func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error {
|
|||
}
|
||||
|
||||
func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg, errMsg string) error {
|
||||
fmt.Println("Public fail SearchResult!")
|
||||
msgPack := msgstream.MsgPack{}
|
||||
searchMsg, ok := msg.(*msgstream.SearchMsg)
|
||||
if !ok {
|
||||
|
|
|
@ -255,7 +255,7 @@ class TestSearchBase:
|
|||
assert res2[0][0].id == res[0][1].id
|
||||
assert res2[0][0].entity.get("int64") == res[0][1].entity.get("int64")
|
||||
|
||||
# pass
|
||||
# Pass
|
||||
@pytest.mark.skip("search_after_index")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
|
||||
|
@ -303,6 +303,7 @@ class TestSearchBase:
|
|||
assert len(res) == nq
|
||||
assert len(res[0]) == default_top_k
|
||||
|
||||
# should fix, 336 assert fail, insert data don't have partitionTag, But search data have
|
||||
@pytest.mark.skip("search_index_partition")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
|
||||
|
@ -334,7 +335,7 @@ class TestSearchBase:
|
|||
res = connect.search(collection, query, partition_tags=[default_tag])
|
||||
assert len(res) == nq
|
||||
|
||||
# pass
|
||||
# PASS
|
||||
@pytest.mark.skip("search_index_partition_B")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
|
||||
|
@ -385,6 +386,7 @@ class TestSearchBase:
|
|||
assert len(res) == nq
|
||||
assert len(res[0]) == 0
|
||||
|
||||
# PASS
|
||||
@pytest.mark.skip("search_index_partitions")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
|
||||
|
@ -419,6 +421,7 @@ class TestSearchBase:
|
|||
assert res[0]._distances[0] > epsilon
|
||||
assert res[1]._distances[0] > epsilon
|
||||
|
||||
# Pass
|
||||
@pytest.mark.skip("search_index_partitions_B")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
|
||||
|
@ -479,7 +482,7 @@ class TestSearchBase:
|
|||
with pytest.raises(Exception) as e:
|
||||
res = connect.search(collection, query)
|
||||
|
||||
# pass
|
||||
# PASS
|
||||
@pytest.mark.skip("search_ip_after_index")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
|
||||
|
@ -509,6 +512,7 @@ class TestSearchBase:
|
|||
assert check_id_result(res[0], ids[0])
|
||||
assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
|
||||
|
||||
# should fix, nq not correct
|
||||
@pytest.mark.skip("search_ip_index_partition")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
|
||||
|
@ -542,6 +546,7 @@ class TestSearchBase:
|
|||
res = connect.search(collection, query, partition_tags=[default_tag])
|
||||
assert len(res) == nq
|
||||
|
||||
# PASS
|
||||
@pytest.mark.skip("search_ip_index_partitions")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
|
||||
|
@ -621,6 +626,7 @@ class TestSearchBase:
|
|||
res = connect.search(collection, query)
|
||||
assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0])
|
||||
|
||||
# Pass
|
||||
@pytest.mark.skip("test_search_distance_l2_after_index")
|
||||
def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index):
|
||||
'''
|
||||
|
@ -675,7 +681,7 @@ class TestSearchBase:
|
|||
res = connect.search(collection, query)
|
||||
assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon
|
||||
|
||||
# pass
|
||||
# Pass
|
||||
@pytest.mark.skip("search_distance_ip_after_index")
|
||||
def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index):
|
||||
'''
|
||||
|
@ -946,6 +952,7 @@ class TestSearchBase:
|
|||
assert res[i]._distances[0] < epsilon
|
||||
assert res[i]._distances[1] > epsilon
|
||||
|
||||
# should fix
|
||||
@pytest.mark.skip("query_entities_with_field_less_than_top_k")
|
||||
def test_query_entities_with_field_less_than_top_k(self, connect, id_collection):
|
||||
"""
|
||||
|
@ -1745,7 +1752,7 @@ class TestSearchInvalid(object):
|
|||
def get_search_params(self, request):
|
||||
yield request.param
|
||||
|
||||
# pass
|
||||
# Pass
|
||||
@pytest.mark.skip("search_with_invalid_params")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
|
||||
|
@ -1787,7 +1794,7 @@ class TestSearchInvalid(object):
|
|||
with pytest.raises(Exception) as e:
|
||||
res = connect.search(binary_collection, query)
|
||||
|
||||
# pass
|
||||
# Pass
|
||||
@pytest.mark.skip("search_with_empty_params")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
|
||||
|
|
Loading…
Reference in New Issue