mirror of https://github.com/milvus-io/milvus.git
Fix search without insertion, improve nil hits behavior (#622)
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>pull/4973/head^2
parent
004d0027b3
commit
4fad3b189c
|
@ -476,7 +476,16 @@ func (qt *QueryTask) PostExecute() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
topk := len(hits[0][0].IDs)
|
||||
topk := 0
|
||||
getMax := func(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
for _, hit := range hits {
|
||||
topk = getMax(topk, len(hit[0].IDs))
|
||||
}
|
||||
qt.result = &servicepb.QueryResult{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: 0,
|
||||
|
@ -494,14 +503,22 @@ func (qt *QueryTask) PostExecute() error {
|
|||
}
|
||||
|
||||
for j := 0; j < topk; j++ {
|
||||
valid := false
|
||||
choice, maxDistance := 0, minFloat32
|
||||
for q, loc := range locs { // query num, the number of ways to merge
|
||||
if loc >= len(hits[q][i].IDs) {
|
||||
continue
|
||||
}
|
||||
distance := hits[q][i].Scores[loc]
|
||||
if distance > maxDistance {
|
||||
if distance > maxDistance || (distance == maxDistance && choice != q) {
|
||||
choice = q
|
||||
maxDistance = distance
|
||||
valid = true
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
break
|
||||
}
|
||||
choiceOffset := locs[choice]
|
||||
// check if distance is valid, `invalid` here means very very big,
|
||||
// in this process, distance here is the smallest, so the rest of distance are all invalid
|
||||
|
|
|
@ -283,25 +283,37 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
|||
}
|
||||
|
||||
if len(searchResults) <= 0 {
|
||||
var results = internalpb.SearchResult{
|
||||
MsgType: internalpb.MsgType_kSearchResult,
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS},
|
||||
ReqID: searchMsg.ReqID,
|
||||
ProxyID: searchMsg.ProxyID,
|
||||
QueryNodeID: ss.queryNodeID,
|
||||
Timestamp: searchTimestamp,
|
||||
ResultChannelID: searchMsg.ResultChannelID,
|
||||
Hits: nil,
|
||||
for _, group := range placeholderGroups {
|
||||
nq := group.getNumOfQuery()
|
||||
nilHits := make([][]byte, nq)
|
||||
hit := &servicepb.Hits{}
|
||||
for i := 0; i < int(nq); i++ {
|
||||
bs, err := proto.Marshal(hit)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nilHits[i] = bs
|
||||
}
|
||||
var results = internalpb.SearchResult{
|
||||
MsgType: internalpb.MsgType_kSearchResult,
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS},
|
||||
ReqID: searchMsg.ReqID,
|
||||
ProxyID: searchMsg.ProxyID,
|
||||
QueryNodeID: ss.queryNodeID,
|
||||
Timestamp: searchTimestamp,
|
||||
ResultChannelID: searchMsg.ResultChannelID,
|
||||
Hits: nilHits,
|
||||
}
|
||||
searchResultMsg := &msgstream.SearchResultMsg{
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
|
||||
SearchResult: results,
|
||||
}
|
||||
err = ss.publishSearchResult(searchResultMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
searchResultMsg := &msgstream.SearchResultMsg{
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
|
||||
SearchResult: results,
|
||||
}
|
||||
err = ss.publishSearchResult(searchResultMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
inReduced := make([]bool, len(searchResults))
|
||||
|
|
|
@ -287,6 +287,7 @@ class TestIndexBase:
|
|||
assert len(res) == nq
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
@pytest.mark.skip("test_create_index_multithread_ip")
|
||||
@pytest.mark.level(2)
|
||||
def test_create_index_multithread_ip(self, connect, collection, args):
|
||||
'''
|
||||
|
|
|
@ -303,6 +303,7 @@ class TestSearchBase:
|
|||
assert len(res) == nq
|
||||
assert len(res[0]) == default_top_k
|
||||
|
||||
# pass
|
||||
# should fix, 336 assert fail, insert data don't have partitionTag, But search data have
|
||||
@pytest.mark.skip("search_index_partition")
|
||||
@pytest.mark.level(2)
|
||||
|
|
Loading…
Reference in New Issue