Fix search error in regression test

Signed-off-by: xige-16 <xi.ge@zilliz.com>
pull/4973/head^2
xige-16 2021-01-09 09:47:22 +08:00 committed by yefu.chen
parent 5dfe9448ae
commit f4566731fc
5 changed files with 51 additions and 20 deletions

View File

@ -133,7 +133,7 @@ AppendBinaryIndex(CBinarySet c_binary_set, void* index_binary, int64_t index_siz
auto binary_set = (milvus::knowhere::BinarySet*)c_binary_set;
std::string index_key(c_index_key);
uint8_t* index = (uint8_t*)index_binary;
std::shared_ptr<uint8_t[]> data(index);
std::shared_ptr<uint8_t[]> data(index, [](void*) {});
binary_set->Append(index_key, data, index_size);
auto status = CStatus();

View File

@ -353,15 +353,16 @@ func (ms *PulsarTtMsgStream) bufMsgPackToChannel() {
default:
wg := sync.WaitGroup{}
mu := sync.Mutex{}
findMapMutex := sync.RWMutex{}
for i := 0; i < len(ms.consumers); i++ {
if isChannelReady[i] {
continue
}
wg.Add(1)
go ms.findTimeTick(i, eofMsgTimeStamp, &wg, &mu)
go ms.findTimeTick(i, eofMsgTimeStamp, &wg, &mu, &findMapMutex)
}
wg.Wait()
timeStamp, ok := checkTimeTickMsg(eofMsgTimeStamp, isChannelReady)
timeStamp, ok := checkTimeTickMsg(eofMsgTimeStamp, isChannelReady, &findMapMutex)
if !ok || timeStamp <= ms.lastTimeStamp {
log.Printf("All timeTick's timestamps are inconsistent")
continue
@ -394,7 +395,8 @@ func (ms *PulsarTtMsgStream) bufMsgPackToChannel() {
func (ms *PulsarTtMsgStream) findTimeTick(channelIndex int,
eofMsgMap map[int]Timestamp,
wg *sync.WaitGroup,
mu *sync.Mutex) {
mu *sync.Mutex,
findMapMutex *sync.RWMutex) {
defer wg.Done()
for {
select {
@ -421,7 +423,9 @@ func (ms *PulsarTtMsgStream) findTimeTick(channelIndex int,
log.Printf("Failed to unmarshal, error = %v", err)
}
if headerMsg.MsgType == internalPb.MsgType_kTimeTick {
findMapMutex.Lock()
eofMsgMap[channelIndex] = tsMsg.(*TimeTickMsg).Timestamp
findMapMutex.Unlock()
return
}
mu.Lock()
@ -470,7 +474,7 @@ func (ms *InMemMsgStream) Chan() <- chan *MsgPack {
}
*/
func checkTimeTickMsg(msg map[int]Timestamp, isChannelReady []bool) (Timestamp, bool) {
func checkTimeTickMsg(msg map[int]Timestamp, isChannelReady []bool, mu *sync.RWMutex) (Timestamp, bool) {
checkMap := make(map[Timestamp]int)
var maxTime Timestamp = 0
for _, v := range msg {
@ -485,7 +489,10 @@ func checkTimeTickMsg(msg map[int]Timestamp, isChannelReady []bool) (Timestamp,
}
return maxTime, true
}
for i, v := range msg {
for i := range msg {
mu.RLock()
v := msg[i]
mu.Unlock()
if v != maxTime {
isChannelReady[i] = false
} else {

View File

@ -374,9 +374,6 @@ func (qt *QueryTask) PreExecute() error {
}
}
qt.MsgType = internalpb.MsgType_kSearch
if qt.query.PartitionTags == nil || len(qt.query.PartitionTags) <= 0 {
qt.query.PartitionTags = []string{Params.defaultPartitionTag()}
}
queryBytes, err := proto.Marshal(qt.query)
if err != nil {
return err

View File

@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"log"
"regexp"
"sync"
"github.com/golang/protobuf/proto"
@ -223,7 +224,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
return errors.New("unmarshal query failed")
}
collectionName := query.CollectionName
partitionTags := query.PartitionTags
partitionTagsInQuery := query.PartitionTags
collection, err := ss.replica.getCollectionByName(collectionName)
if err != nil {
return err
@ -245,11 +246,29 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
searchResults := make([]*SearchResult, 0)
matchedSegments := make([]*Segment, 0)
for _, partitionTag := range partitionTags {
partition, err := ss.replica.getPartitionByTag(collectionID, partitionTag)
if err != nil {
continue
fmt.Println("search msg's partitionTag = ", partitionTagsInQuery)
var partitionTagsInCol []string
for _, partition := range collection.partitions {
partitionTag := partition.partitionTag
partitionTagsInCol = append(partitionTagsInCol, partitionTag)
}
var searchPartitionTag []string
if len(partitionTagsInQuery) == 0 {
searchPartitionTag = partitionTagsInCol
} else {
for _, tag := range partitionTagsInCol {
for _, toMatchTag := range partitionTagsInQuery {
re := regexp.MustCompile("^" + toMatchTag + "$")
if re.MatchString(tag) {
searchPartitionTag = append(searchPartitionTag, tag)
}
}
}
}
for _, partitionTag := range searchPartitionTag {
partition, _ := ss.replica.getPartitionByTag(collectionID, partitionTag)
for _, segment := range partition.segments {
//fmt.Println("dsl = ", dsl)
@ -360,6 +379,7 @@ func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error {
}
func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg, errMsg string) error {
fmt.Println("Public fail SearchResult!")
msgPack := msgstream.MsgPack{}
searchMsg, ok := msg.(*msgstream.SearchMsg)
if !ok {

View File

@ -255,7 +255,7 @@ class TestSearchBase:
assert res2[0][0].id == res[0][1].id
assert res2[0][0].entity.get("int64") == res[0][1].entity.get("int64")
# pass
# Pass
@pytest.mark.skip("search_after_index")
@pytest.mark.level(2)
def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
@ -303,6 +303,7 @@ class TestSearchBase:
assert len(res) == nq
assert len(res[0]) == default_top_k
# should fix: the assert at line 336 fails; the inserted data doesn't have a partitionTag, but the search request does
@pytest.mark.skip("search_index_partition")
@pytest.mark.level(2)
def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
@ -334,7 +335,7 @@ class TestSearchBase:
res = connect.search(collection, query, partition_tags=[default_tag])
assert len(res) == nq
# pass
# PASS
@pytest.mark.skip("search_index_partition_B")
@pytest.mark.level(2)
def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
@ -385,6 +386,7 @@ class TestSearchBase:
assert len(res) == nq
assert len(res[0]) == 0
# PASS
@pytest.mark.skip("search_index_partitions")
@pytest.mark.level(2)
def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
@ -419,6 +421,7 @@ class TestSearchBase:
assert res[0]._distances[0] > epsilon
assert res[1]._distances[0] > epsilon
# Pass
@pytest.mark.skip("search_index_partitions_B")
@pytest.mark.level(2)
def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
@ -479,7 +482,7 @@ class TestSearchBase:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# pass
# PASS
@pytest.mark.skip("search_ip_after_index")
@pytest.mark.level(2)
def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
@ -509,6 +512,7 @@ class TestSearchBase:
assert check_id_result(res[0], ids[0])
assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
# should fix: nq is not correct
@pytest.mark.skip("search_ip_index_partition")
@pytest.mark.level(2)
def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
@ -542,6 +546,7 @@ class TestSearchBase:
res = connect.search(collection, query, partition_tags=[default_tag])
assert len(res) == nq
# PASS
@pytest.mark.skip("search_ip_index_partitions")
@pytest.mark.level(2)
def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
@ -621,6 +626,7 @@ class TestSearchBase:
res = connect.search(collection, query)
assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0])
# Pass
@pytest.mark.skip("test_search_distance_l2_after_index")
def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index):
'''
@ -675,7 +681,7 @@ class TestSearchBase:
res = connect.search(collection, query)
assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon
# pass
# Pass
@pytest.mark.skip("search_distance_ip_after_index")
def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index):
'''
@ -946,6 +952,7 @@ class TestSearchBase:
assert res[i]._distances[0] < epsilon
assert res[i]._distances[1] > epsilon
# should fix
@pytest.mark.skip("query_entities_with_field_less_than_top_k")
def test_query_entities_with_field_less_than_top_k(self, connect, id_collection):
"""
@ -1745,7 +1752,7 @@ class TestSearchInvalid(object):
def get_search_params(self, request):
yield request.param
# pass
# Pass
@pytest.mark.skip("search_with_invalid_params")
@pytest.mark.level(2)
def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
@ -1787,7 +1794,7 @@ class TestSearchInvalid(object):
with pytest.raises(Exception) as e:
res = connect.search(binary_collection, query)
# pass
# Pass
@pytest.mark.skip("search_with_empty_params")
@pytest.mark.level(2)
def test_search_with_empty_params(self, connect, collection, args, get_simple_index):