mirror of https://github.com/milvus-io/milvus.git

Fix reduce algorithm in proxy search task (#8206)

Signed-off-by: dragondriver <jiquan.long@zilliz.com>

Branch: pull/8279/head
Parent: c1e229cb7e
Commit: b2e8ba7b33
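Before this fix, the k-way merge inside reduceSearchResultDataParallel initialized choice to 0 and cleared a shared valid flag whenever any one of the merge ways held a padded entry (id == -1) at its current offset, so a single invalid entry could abort the merge for a query even though other ways still held valid candidates. The fix initializes choice to -1, simply skips invalid ids, and stops only when no way can supply a valid candidate. Below is a minimal, self-contained sketch of the corrected selection step; the way struct and pickNext helper are illustrative names for this note, not code from the repository.

package main

import (
	"fmt"
	"math"
)

const minFloat32 = -1 * float32(math.MaxFloat32)

// way is one per-query result slice ("way") taking part in the merge.
type way struct {
	ids    []int64
	scores []float32
	loc    int // next unconsumed offset within this way
}

// pickNext returns the index of the way holding the best valid candidate,
// skipping padded entries (id == -1). It returns -1 when every way is
// exhausted or invalid, which is the new termination condition.
func pickNext(ways []way, topk int) int {
	choice, maxDistance := -1, minFloat32
	for q, w := range ways {
		if w.loc >= topk || w.loc >= len(w.ids) {
			continue
		}
		if id := w.ids[w.loc]; id != -1 {
			if d := w.scores[w.loc]; d > maxDistance {
				choice, maxDistance = q, d
			}
		}
	}
	return choice
}

func main() {
	// The first way's candidate is padded (id == -1). The old logic would
	// have aborted the whole round here; the fixed logic keeps going and
	// picks the valid candidate from the second way.
	ways := []way{
		{ids: []int64{-1}, scores: []float32{minFloat32}},
		{ids: []int64{7}, scores: []float32{0.5}},
	}
	fmt.Println(pickNext(ways, 1)) // prints 1
}

With this shape, choice == -1 doubles as the "no valid candidate anywhere" signal, which is why a later hunk replaces if !valid with if choice == -1.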
@@ -83,6 +83,8 @@ const (
 	CreateAliasTaskName = "CreateAliasTask"
 	DropAliasTaskName   = "DropAliasTask"
 	AlterAliasTaskName  = "AlterAliasTask"
+
+	minFloat32 = -1 * float32(math.MaxFloat32)
 )

 type task interface {
@@ -1755,8 +1757,6 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
 		}
 	}

-	const minFloat32 = -1 * float32(math.MaxFloat32)
-
 	// TODO(yukun): Use parallel function
 	var realTopK int64 = -1
 	var idx int64
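Taken together, the first two hunks promote minFloat32 from a constant local to reduceSearchResultDataParallel to a package-level constant, so the reduce loop and the new regression test added below can share the same sentinel score for padded entries.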
@@ -1766,17 +1766,14 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat

 		j = 0
 		for ; j < topk; j++ {
-			valid := true
-			choice, maxDistance := 0, minFloat32
+			choice, maxDistance := -1, minFloat32
 			for q, loc := range locs { // query num, the number of ways to merge
 				if loc >= topk {
 					continue
 				}
 				curIdx := idx*topk + loc
 				id := searchResultData[q].Ids.GetIntId().Data[curIdx]
-				if id == -1 {
-					valid = false
-				} else {
+				if id != -1 {
 					distance := searchResultData[q].Scores[curIdx]
 					if distance > maxDistance {
 						choice = q
@@ -1784,7 +1781,7 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
 					}
 				}
 			}
-			if !valid {
+			if choice == -1 {
 				break
 			}
 			choiceOffset := locs[choice]
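The termination condition changes accordingly: the old code broke out of the merge as soon as any single way reported an invalid id, while the new code breaks only when choice stayed -1, i.e. no way offered a valid candidate in this round.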
@@ -1903,7 +1903,7 @@ func TestSearchTask_all(t *testing.T) {
 						for i := 0; i < nq; i++ {
 							for j := 0; j < topk; j++ {
 								offset := i*topk + j
-								score := rand.Float32()
+								score := float32(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) // increasingly
 								id := int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
 								resultData.Scores[offset] = score
 								resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = id
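Replacing rand.Float32() with unique, monotonically increasing integer scores makes the relative order of the mocked results deterministic across runs, so the reduce step's output no longer depends on random draws.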
@@ -1981,6 +1981,250 @@ func TestSearchTask_all(t *testing.T) {
 	wg.Wait()
 }

+func TestSearchTask_7803_reduce(t *testing.T) {
+	var err error
+
+	Params.Init()
+	Params.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
+
+	rc := NewRootCoordMock()
+	rc.Start()
+	defer rc.Stop()
+
+	ctx := context.Background()
+
+	err = InitMetaCache(rc)
+	assert.NoError(t, err)
+
+	shardsNum := int32(2)
+	prefix := "TestSearchTask_7803_reduce"
+	dbName := ""
+	collectionName := prefix + funcutil.GenRandomStr()
+	int64Field := "int64"
+	floatVecField := "fvec"
+	dim := 128
+	expr := fmt.Sprintf("%s > 0", int64Field)
+	nq := 10
+	topk := 10
+	nprobe := 10
+
+	schema := constructCollectionSchema(
+		int64Field,
+		floatVecField,
+		dim,
+		collectionName)
+	marshaledSchema, err := proto.Marshal(schema)
+	assert.NoError(t, err)
+
+	createColT := &createCollectionTask{
+		Condition: NewTaskCondition(ctx),
+		CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
+			Base:           nil,
+			DbName:         dbName,
+			CollectionName: collectionName,
+			Schema:         marshaledSchema,
+			ShardsNum:      shardsNum,
+		},
+		ctx:       ctx,
+		rootCoord: rc,
+		result:    nil,
+		schema:    nil,
+	}
+
+	assert.NoError(t, createColT.OnEnqueue())
+	assert.NoError(t, createColT.PreExecute(ctx))
+	assert.NoError(t, createColT.Execute(ctx))
+	assert.NoError(t, createColT.PostExecute(ctx))
+
+	dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
+	query := newMockGetChannelsService()
+	factory := newSimpleMockMsgStreamFactory()
+	chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
+	defer chMgr.removeAllDMLStream()
+	defer chMgr.removeAllDQLStream()
+
+	collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
+	assert.NoError(t, err)
+
+	qc := NewQueryCoordMock()
+	qc.Start()
+	defer qc.Stop()
+	status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
+		Base: &commonpb.MsgBase{
+			MsgType:   commonpb.MsgType_LoadCollection,
+			MsgID:     0,
+			Timestamp: 0,
+			SourceID:  Params.ProxyID,
+		},
+		DbID:         0,
+		CollectionID: collectionID,
+		Schema:       nil,
+	})
+	assert.NoError(t, err)
+	assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
+
+	req := constructSearchRequest(dbName, collectionName,
+		expr,
+		floatVecField,
+		nq, dim, nprobe, topk)
+
+	task := &searchTask{
+		Condition: NewTaskCondition(ctx),
+		SearchRequest: &internalpb.SearchRequest{
+			Base: &commonpb.MsgBase{
+				MsgType:   commonpb.MsgType_Search,
+				MsgID:     0,
+				Timestamp: 0,
+				SourceID:  Params.ProxyID,
+			},
+			ResultChannelID:    strconv.FormatInt(Params.ProxyID, 10),
+			DbID:               0,
+			CollectionID:       0,
+			PartitionIDs:       nil,
+			Dsl:                "",
+			PlaceholderGroup:   nil,
+			DslType:            0,
+			SerializedExprPlan: nil,
+			OutputFieldsId:     nil,
+			TravelTimestamp:    0,
+			GuaranteeTimestamp: 0,
+		},
+		ctx:       ctx,
+		resultBuf: make(chan []*internalpb.SearchResults),
+		result:    nil,
+		query:     req,
+		chMgr:     chMgr,
+		qc:        qc,
+	}
+
+	// simple mock for query node
+	// TODO(dragondriver): should we replace this mock using RocksMq or MemMsgStream?
+
+	err = chMgr.createDQLStream(collectionID)
+	assert.NoError(t, err)
+	stream, err := chMgr.getDQLStream(collectionID)
+	assert.NoError(t, err)
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	consumeCtx, cancel := context.WithCancel(ctx)
+	go func() {
+		defer wg.Done()
+		for {
+			select {
+			case <-consumeCtx.Done():
+				return
+			case pack := <-stream.Chan():
+				for _, msg := range pack.Msgs {
+					_, ok := msg.(*msgstream.SearchMsg)
+					assert.True(t, ok)
+					// TODO(dragondriver): construct result according to the request
+
+					constructSearchResulstData := func(invalidNum int) *schemapb.SearchResultData {
+						resultData := &schemapb.SearchResultData{
+							NumQueries: int64(nq),
+							TopK:       int64(topk),
+							FieldsData: nil,
+							Scores:     make([]float32, nq*topk),
+							Ids: &schemapb.IDs{
+								IdField: &schemapb.IDs_IntId{
+									IntId: &schemapb.LongArray{
+										Data: make([]int64, nq*topk),
+									},
+								},
+							},
+							Topks: make([]int64, nq),
+						}
+
+						for i := 0; i < nq; i++ {
+							for j := 0; j < topk; j++ {
+								offset := i*topk + j
+								if j >= invalidNum {
+									resultData.Scores[offset] = minFloat32
+									resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = -1
+								} else {
+									score := float32(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) // increasingly
+									id := int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
+									resultData.Scores[offset] = score
+									resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = id
+								}
+							}
+							resultData.Topks[i] = int64(topk)
+						}
+
+						return resultData
+					}
+
+					result1 := &internalpb.SearchResults{
+						Base: &commonpb.MsgBase{
+							MsgType:   commonpb.MsgType_SearchResult,
+							MsgID:     0,
+							Timestamp: 0,
+							SourceID:  0,
+						},
+						Status: &commonpb.Status{
+							ErrorCode: commonpb.ErrorCode_Success,
+							Reason:    "",
+						},
+						ResultChannelID:          "",
+						MetricType:               distance.L2,
+						NumQueries:               int64(nq),
+						TopK:                     int64(topk),
+						SealedSegmentIDsSearched: nil,
+						ChannelIDsSearched:       nil,
+						GlobalSealedSegmentIDs:   nil,
+						SlicedBlob:               nil,
+						SlicedNumCount:           1,
+						SlicedOffset:             0,
+					}
+					resultData := constructSearchResulstData(topk / 2)
+					sliceBlob, err := proto.Marshal(resultData)
+					assert.NoError(t, err)
+					result1.SlicedBlob = sliceBlob
+
+					result2 := &internalpb.SearchResults{
+						Base: &commonpb.MsgBase{
+							MsgType:   commonpb.MsgType_SearchResult,
+							MsgID:     0,
+							Timestamp: 0,
+							SourceID:  0,
+						},
+						Status: &commonpb.Status{
+							ErrorCode: commonpb.ErrorCode_Success,
+							Reason:    "",
+						},
+						ResultChannelID:          "",
+						MetricType:               distance.L2,
+						NumQueries:               int64(nq),
+						TopK:                     int64(topk),
+						SealedSegmentIDsSearched: nil,
+						ChannelIDsSearched:       nil,
+						GlobalSealedSegmentIDs:   nil,
+						SlicedBlob:               nil,
+						SlicedNumCount:           1,
+						SlicedOffset:             0,
+					}
+					resultData2 := constructSearchResulstData(topk - topk/2)
+					sliceBlob2, err := proto.Marshal(resultData2)
+					assert.NoError(t, err)
+					result2.SlicedBlob = sliceBlob2
+
+					// send search result
+					task.resultBuf <- []*internalpb.SearchResults{result1, result2}
+				}
+			}
+		}
+	}()
+
+	assert.NoError(t, task.OnEnqueue())
+	assert.NoError(t, task.PreExecute(ctx))
+	assert.NoError(t, task.Execute(ctx))
+	assert.NoError(t, task.PostExecute(ctx))
+
+	cancel()
+	wg.Wait()
+}
+
 func TestSearchTask_Type(t *testing.T) {
 	Params.Init()

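The new regression test, presumably named for issue #7803, feeds the proxy two partial search results per query in which roughly half of the top-k slots are padded (id -1, score minFloat32), marshals them into SlicedBlob, and drives the full search task pipeline. Under the old reduce logic the padded entries would have cut the merge short, so the test exercises exactly the path this commit fixes.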