mirror of https://github.com/milvus-io/milvus.git
related: #36407

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>

branch: pull/36463/head
parent 98a917c5d4
commit d55d9d6e1d
@@ -218,95 +218,91 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
		totalResCount += subSearchNqOffset[i][nq-1]
	}

	if subSearchNum == 1 && offset == 0 {
		ret.Results = subSearchResultData[0]
	} else {
		var realTopK int64 = -1
		var retSize int64

		maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
		// reducing nq * topk results
		for i := int64(0); i < nq; i++ {
			var (
				// cursor of current data of each subSearch for merging the j-th data of TopK.
				// sum(cursors) == j
				cursors = make([]int64, subSearchNum)

				j              int64
				groupByValMap  = make(map[interface{}][]*groupReduceInfo)
				skipOffsetMap  = make(map[interface{}]bool)
				groupByValList = make([]interface{}, limit)
				groupByValIdx  = 0
			)

			for j = 0; j < groupBound; {
				subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
				if subSearchIdx == -1 {
					break
				}
				subSearchRes := subSearchResultData[subSearchIdx]

				id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
				score := subSearchRes.GetScores()[resultDataIdx]
				groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))
				if groupByVal == nil {
					return nil, errors.New("get nil groupByVal from subSearchRes, wrong states, as milvus doesn't support nil value," +
						"there must be sth wrong on queryNode side")
				}

				if int64(len(skipOffsetMap)) < offset || skipOffsetMap[groupByVal] {
					skipOffsetMap[groupByVal] = true
					// the first offset's group will be ignored
				} else if len(groupByValMap[groupByVal]) == 0 && int64(len(groupByValMap)) >= limit {
					// skip when groupbyMap has been full and found new groupByVal
				} else if int64(len(groupByValMap[groupByVal])) >= groupSize {
					// skip when target group has been full
				} else {
					if len(groupByValMap[groupByVal]) == 0 {
						groupByValList[groupByValIdx] = groupByVal
						groupByValIdx++
					}
					groupByValMap[groupByVal] = append(groupByValMap[groupByVal], &groupReduceInfo{
						subSearchIdx: subSearchIdx,
						resultIdx:    resultDataIdx, id: id, score: score,
					})
					j++
				}

				cursors[subSearchIdx]++
			}

			// assemble all eligible values in group
			// values in groupByValList is sorted by the highest score in each group
			for _, groupVal := range groupByValList {
				if groupVal != nil {
					groupEntities := groupByValMap[groupVal]
					for _, groupEntity := range groupEntities {
						subResData := subSearchResultData[groupEntity.subSearchIdx]
						retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx)
						typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
						ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
						if err := typeutil.AppendGroupByValue(ret.Results, groupVal, subResData.GetGroupByFieldValue().GetType()); err != nil {
							log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
							return ret, err
						}
					}
				}
			}

			if realTopK != -1 && realTopK != j {
				log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
			}
			realTopK = j
			ret.Results.Topks = append(ret.Results.Topks, realTopK)

			// limit search result to avoid oom
			if retSize > maxOutputSize {
				return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
			}
		}
		ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
	}

	if !metric.PositivelyRelated(metricType) {
		for k := range ret.Results.Scores {
			ret.Results.Scores[k] *= -1
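The hunk above is the proxy-side group-by reduce: per query it repeatedly picks the highest-scoring head among the sub-search results, drops the first offset distinct groups, keeps at most limit groups with at most groupSize hits each, and finally emits the kept groups in the order of their best hit. The following is a minimal, self-contained sketch of just that grouping rule. Everything in it (the hit type, groupReduce, the sort-based merge) is a hypothetical simplification for illustration, not Milvus code; in particular it replaces the cursor-based selectHighestScoreIndex merge with a plain sort.

// group_reduce_sketch.go
//
// Minimal sketch of the per-query group-by merge performed above.
// Types and names here are hypothetical simplifications, not Milvus APIs.
package main

import (
	"fmt"
	"sort"
)

// hit is a simplified stand-in for one entry of a sub-search result.
type hit struct {
	id       int64
	score    float32
	groupVal string
}

// groupReduce merges several sub-result lists into at most `limit` groups of
// at most `groupSize` hits each, skipping the first `offset` distinct groups
// encountered (mirroring skipOffsetMap above).
func groupReduce(subResults [][]hit, offset, limit, groupSize int) []hit {
	// Flatten and sort by score. The real code instead keeps one cursor per
	// sub-result and repeatedly picks the highest-scoring head, avoiding a
	// full sort.
	var all []hit
	for _, sub := range subResults {
		all = append(all, sub...)
	}
	sort.Slice(all, func(i, j int) bool { return all[i].score > all[j].score })

	skipped := map[string]bool{}      // groups consumed by offset
	groups := map[string][]hit{}      // kept hits per group value
	order := make([]string, 0, limit) // group values in first-seen (best-score) order

	for _, h := range all {
		switch {
		case len(skipped) < offset || skipped[h.groupVal]:
			skipped[h.groupVal] = true // the whole group is skipped by offset
		case len(groups[h.groupVal]) == 0 && len(groups) >= limit:
			// a brand-new group, but we already hold `limit` groups
		case len(groups[h.groupVal]) >= groupSize:
			// this group is already full
		default:
			if len(groups[h.groupVal]) == 0 {
				order = append(order, h.groupVal)
			}
			groups[h.groupVal] = append(groups[h.groupVal], h)
		}
	}

	// Emit groups in the order of their best hit, mirroring groupByValList.
	var out []hit
	for _, g := range order {
		out = append(out, groups[g]...)
	}
	return out
}

func main() {
	subResults := [][]hit{
		{{1, 0.95, "a"}, {2, 0.80, "b"}, {3, 0.60, "a"}},
		{{4, 0.90, "c"}, {5, 0.70, "b"}, {6, 0.50, "c"}},
	}
	for _, h := range groupReduce(subResults, 0, 2, 2) {
		fmt.Printf("id=%d group=%s score=%.2f\n", h.id, h.groupVal, h.score)
	}
}

With offset=0, limit=2, groupSize=2 the sketch prints the two best groups ("a" and "c") with two hits each, which is the shape of result the reduce above is meant to produce.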
@@ -12,8 +12,8 @@ allure-pytest==2.7.0
pytest-print==0.2.1
pytest-level==0.1.1
pytest-xdist==2.5.0
-pymilvus==2.5.0rc80
-pymilvus[bulk_writer]==2.5.0rc80
+pymilvus==2.5.0rc81
+pymilvus[bulk_writer]==2.5.0rc81
pytest-rerunfailures==9.1.1
git+https://github.com/Projectplace/pytest-tags
ndg-httpsclient
@@ -1266,7 +1266,6 @@ class TestGroupSearch(TestCaseClassBase):
        self.collection_wrap.load()

    @pytest.mark.tags(CaseLabel.L0)
-   @pytest.mark.xfail(reason="issue #36407")
    @pytest.mark.parametrize("group_by_field", [DataType.VARCHAR.name, "varchar_with_index"])
    def test_search_group_size(self, group_by_field):
        """
@@ -1308,7 +1307,6 @@ class TestGroupSearch(TestCaseClassBase):
        assert len(set(group_values)) == limit

    @pytest.mark.tags(CaseLabel.L0)
-   @pytest.mark.xfail(reason="issue #36407")
    def test_hybrid_search_group_size(self):
        """
        hybrid search group by on 3 different float vector fields with group by varchar field with group size
@@ -1360,7 +1358,6 @@ class TestGroupSearch(TestCaseClassBase):
                group_distances = [res[i][l + 1].distance]

    @pytest.mark.tags(CaseLabel.L1)
-   @pytest.mark.xfail(reason="issue #36407")
    def test_hybrid_search_group_by(self):
        """
        verify hybrid search group by works with different Rankers
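The three @pytest.mark.xfail(reason="issue #36407") markers removed above were placeholders while group_size reduction was broken; with the reduce fix these tests are expected to pass outright. Stripped of the pymilvus plumbing, the property they check is a simple invariant over the flattened hits. Below is a hedged Go sketch of that invariant; the row type and checkGroupResult helper are hypothetical stand-ins, not part of any Milvus SDK, and the real tests additionally assert that the number of distinct groups is exactly limit when the data contains enough groups.

// group_size_check.go
//
// Sketch of the invariant asserted by the tests above: a group_by search with
// `limit` groups and `group_size` entries per group returns at most `limit`
// distinct group values and at most `group_size` hits per value.
package main

import "fmt"

// row is a hypothetical simplification of one returned hit.
type row struct {
	id       int64
	groupVal string
}

// checkGroupResult reports the first violated group invariant, if any.
func checkGroupResult(rows []row, limit, groupSize int) error {
	perGroup := map[string]int{}
	for _, r := range rows {
		perGroup[r.groupVal]++
		if perGroup[r.groupVal] > groupSize {
			return fmt.Errorf("group %q has more than %d hits", r.groupVal, groupSize)
		}
	}
	if len(perGroup) > limit {
		return fmt.Errorf("got %d distinct groups, expected at most %d", len(perGroup), limit)
	}
	return nil
}

func main() {
	rows := []row{{1, "a"}, {2, "a"}, {3, "b"}, {4, "c"}}
	if err := checkGroupResult(rows, 3, 2); err != nil {
		fmt.Println("violation:", err)
		return
	}
	fmt.Println("group_size/limit invariants hold")
}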