milvus/internal/reader/search_service.go

package reader

import "C"

import (
	"context"
	"errors"
	"fmt"
	"log"
	"math"
	"sync"

	"github.com/golang/protobuf/proto"

	"github.com/zilliztech/milvus-distributed/internal/msgstream"
	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
	"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
	"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)
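
// searchService pulls search requests off the "search" message stream, runs
// them against the segments held by the collectionReplica once the replica's
// tSafe covers their timestamps, and publishes the results on "searchResult".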
type searchService struct {
	ctx    context.Context
	wait   sync.WaitGroup
	cancel context.CancelFunc

	msgBuffer   chan msgstream.TsMsg
	unsolvedMsg []msgstream.TsMsg

	replica               *collectionReplica
	tSafeWatcher          *tSafeWatcher
	searchMsgStream       *msgstream.MsgStream
	searchResultMsgStream *msgstream.MsgStream
}

type ResultEntityIds []UniqueID
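
// newSearchService builds the Pulsar consumer ("search" / "subSearch") and
// producer ("searchResult") streams and the internal message buffer. The
// tSafe watcher is created here but is only attached to the replica when
// register() is called.
//
// A minimal usage sketch (the register/start ordering below is an assumption
// about the caller, not something enforced in this file):
//
//	searchService := newSearchService(ctx, replica)
//	searchService.register()
//	go searchService.start()
//	defer searchService.close()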
func newSearchService(ctx context.Context, replica *collectionReplica) *searchService {
	receiveBufSize := Params.searchReceiveBufSize()
	pulsarBufSize := Params.searchPulsarBufSize()

	msgStreamURL, err := Params.pulsarAddress()
	if err != nil {
		log.Fatal(err)
	}

	consumeChannels := []string{"search"}
	consumeSubName := "subSearch"
	searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
	searchStream.SetPulsarClient(msgStreamURL)
	unmarshalDispatcher := msgstream.NewUnmarshalDispatcher()
	searchStream.CreatePulsarConsumers(consumeChannels, consumeSubName, unmarshalDispatcher, pulsarBufSize)
	var inputStream msgstream.MsgStream = searchStream

	producerChannels := []string{"searchResult"}
	searchResultStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
	searchResultStream.SetPulsarClient(msgStreamURL)
	searchResultStream.CreatePulsarProducers(producerChannels)
	var outputStream msgstream.MsgStream = searchResultStream

	searchServiceCtx, searchServiceCancel := context.WithCancel(ctx)
	msgBuffer := make(chan msgstream.TsMsg, receiveBufSize)
	unsolvedMsg := make([]msgstream.TsMsg, 0)
	return &searchService{
		ctx:         searchServiceCtx,
		cancel:      searchServiceCancel,
		msgBuffer:   msgBuffer,
		unsolvedMsg: unsolvedMsg,

		replica:      replica,
		tSafeWatcher: newTSafeWatcher(),

		searchMsgStream:       &inputStream,
		searchResultMsgStream: &outputStream,
	}
}
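
// start starts both message streams, launches the receive and schedule
// goroutines, and blocks on the WaitGroup until both of them return.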
func (ss *searchService) start() {
	(*ss.searchMsgStream).Start()
	(*ss.searchResultMsgStream).Start()
	ss.wait.Add(2)
	go ss.receiveSearchMsg()
	go ss.startSearchService()
	ss.wait.Wait()
}
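
// close closes both message streams and cancels the service context, which
// makes receiveSearchMsg and startSearchService exit.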
func (ss *searchService) close() {
	(*ss.searchMsgStream).Close()
	(*ss.searchResultMsgStream).Close()
	ss.cancel()
}
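
// register attaches the service's tSafeWatcher to the replica's tSafe so the
// service is notified whenever dataSyncService advances the safe timestamp.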
func (ss *searchService) register() {
	tSafe := (*(ss.replica)).getTSafe()
	(*tSafe).registerTSafeWatcher(ss.tSafeWatcher)
}
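
// waitNewTSafe blocks until the watcher reports a tSafe update, then returns
// the current safe timestamp.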
func (ss *searchService) waitNewTSafe() Timestamp {
	// block until dataSyncService updates tSafe
	ss.tSafeWatcher.hasUpdate()
	timestamp := (*(*ss.replica).getTSafe()).get()
	return timestamp
}
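
// receiveSearchMsg drains message packs from the consumer stream and forwards
// the individual messages into msgBuffer for scheduling.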
func (ss *searchService) receiveSearchMsg() {
	defer ss.wait.Done()
	for {
		select {
		case <-ss.ctx.Done():
			return
		default:
			msgPack := (*ss.searchMsgStream).Consume()
			if msgPack == nil || len(msgPack.Msgs) <= 0 {
				continue
			}
			for i := range msgPack.Msgs {
				ss.msgBuffer <- msgPack.Msgs[i]
				// fmt.Println("receive a search msg")
			}
		}
	}
}
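
// startSearchService repeatedly snapshots the current tSafe, pulls messages
// from the unsolved backlog and from msgBuffer, serves those whose timestamps
// are already covered by tSafe, and parks the rest back in unsolvedMsg until
// tSafe catches up.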
func (ss *searchService) startSearchService() {
	defer ss.wait.Done()
	for {
		select {
		case <-ss.ctx.Done():
			return
		default:
			serviceTimestamp := (*(*ss.replica).getTSafe()).get()
			searchMsg := make([]msgstream.TsMsg, 0)
			tempMsg := make([]msgstream.TsMsg, 0)
			tempMsg = append(tempMsg, ss.unsolvedMsg...)
			ss.unsolvedMsg = ss.unsolvedMsg[:0]

			// A search request is serviceable only once tSafe has reached its
			// timestamp; anything newer than the current tSafe stays in unsolvedMsg.
			for _, msg := range tempMsg {
				if msg.BeginTs() <= serviceTimestamp {
					searchMsg = append(searchMsg, msg)
					continue
				}
				ss.unsolvedMsg = append(ss.unsolvedMsg, msg)
			}

			msgBufferLength := len(ss.msgBuffer)
			for i := 0; i < msgBufferLength; i++ {
				msg := <-ss.msgBuffer
				if msg.BeginTs() <= serviceTimestamp {
					searchMsg = append(searchMsg, msg)
					continue
				}
				ss.unsolvedMsg = append(ss.unsolvedMsg, msg)
			}

			if len(searchMsg) <= 0 {
				continue
			}
			err := ss.search(searchMsg)
			if err != nil {
				log.Println("search failed:", err)
				ss.publishFailedSearchResult()
				continue
			}
			log.Println("search done")
		}
	}
}
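
// search handles one batch of SearchMsg: it unmarshals the servicepb.Query,
// builds a plan from the DSL, parses the placeholder group, runs the plan
// over every segment of the requested partitions into shared result buffers,
// and publishes one SearchResult per request.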
func (ss *searchService) search(searchMessages []msgstream.TsMsg) error {
	// TODO: cache map[dsl]plan
	// TODO: re-batch search requests
	for _, msg := range searchMessages {
		searchMsg, ok := msg.(*msgstream.SearchMsg)
		if !ok {
			return fmt.Errorf("invalid request type = %v", msg.Type())
		}

		searchTimestamp := searchMsg.Timestamp
		var queryBlob = searchMsg.Query.Value
		query := servicepb.Query{}
		err := proto.Unmarshal(queryBlob, &query)
		if err != nil {
			return fmt.Errorf("unmarshal query failed: %v", err)
		}
		collectionName := query.CollectionName
		partitionTags := query.PartitionTags
		collection, err := (*ss.replica).getCollectionByName(collectionName)
		if err != nil {
			return err
		}
		collectionID := collection.ID()
		dsl := query.Dsl

		// 1. Build the search plan from the DSL and parse the placeholder group.
		plan := CreatePlan(*collection, dsl)
		topK := plan.GetTopK()
		placeHolderGroupBlob := query.PlaceholderGroup
		group := servicepb.PlaceholderGroup{}
		err = proto.Unmarshal(placeHolderGroupBlob, &group)
		if err != nil {
			return err
		}
		placeholderGroup := ParserPlaceholderGroup(plan, placeHolderGroupBlob)
		placeholderGroups := make([]*PlaceholderGroup, 0)
		placeholderGroups = append(placeholderGroups, placeholderGroup)

		// 2. Allocate flat result buffers sized for all queries' results; distances
		// start at MaxFloat32 so segment results can be merged in place.
		var numQueries int64 = 0
		for _, pg := range placeholderGroups {
			numQueries += pg.GetNumOfQuery()
		}
		resultIds := make([]IntPrimaryKey, topK*numQueries)
		resultDistances := make([]float32, topK*numQueries)
		for i := range resultDistances {
			resultDistances[i] = math.MaxFloat32
		}

		// 3. Do search in all segments.
		for _, partitionTag := range partitionTags {
			partition, err := (*ss.replica).getPartitionByTag(collectionID, partitionTag)
			if err != nil {
				return err
			}
			for _, segment := range partition.segments {
				err := segment.segmentSearch(plan,
					placeholderGroups,
					[]Timestamp{searchTimestamp},
					resultIds,
					resultDistances,
					numQueries,
					topK)
				if err != nil {
					return err
				}
			}
		}

		// 4. Return results.
		hits := make([]*servicepb.Hits, 0)
		for i := int64(0); i < numQueries; i++ {
			hit := servicepb.Hits{}
			score := servicepb.Score{}
			for j := i * topK; j < (i+1)*topK; j++ {
				hit.IDs = append(hit.IDs, resultIds[j])
				score.Values = append(score.Values, resultDistances[j])
			}
			hit.Scores = append(hit.Scores, &score)
			hits = append(hits, &hit)
		}
		var results = internalpb.SearchResult{
			MsgType:         internalpb.MsgType_kSearchResult,
			Status:          &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS},
			ReqID:           searchMsg.ReqID,
			ProxyID:         searchMsg.ProxyID,
			QueryNodeID:     searchMsg.ProxyID,
			Timestamp:       searchTimestamp,
			ResultChannelID: searchMsg.ResultChannelID,
			Hits:            hits,
		}
		var tsMsg msgstream.TsMsg = &msgstream.SearchResultMsg{SearchResult: results}
		ss.publishSearchResult(tsMsg)
		plan.Delete()
		placeholderGroup.Delete()
	}
	return nil
}
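
// publishSearchResult wraps a single result message in a MsgPack and produces
// it on the searchResult stream.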
func (ss *searchService) publishSearchResult(res msgstream.TsMsg) {
	msgPack := msgstream.MsgPack{}
	msgPack.Msgs = append(msgPack.Msgs, res)
	(*ss.searchResultMsgStream).Produce(&msgPack)
}
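
// publishFailedSearchResult publishes a SearchResult that carries only an
// UNEXPECTED_ERROR status, with no hits or request identifiers attached.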
func (ss *searchService) publishFailedSearchResult() {
	var errorResults = internalpb.SearchResult{
		MsgType: internalpb.MsgType_kSearchResult,
		Status:  &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR},
	}
	var tsMsg msgstream.TsMsg = &msgstream.SearchResultMsg{SearchResult: errorResults}
	msgPack := msgstream.MsgPack{}
	msgPack.Msgs = append(msgPack.Msgs, tsMsg)
	(*ss.searchResultMsgStream).Produce(&msgPack)
}