mirror of https://github.com/milvus-io/milvus.git
Merge search_collection and retrieve_colletion into query_collection (#6037)
* Merge search_collection and retrieve_colletion into query_collection
  Signed-off-by: fishpenguin <kun.yu@zilliz.com>
* Fix static-check
  Signed-off-by: fishpenguin <kun.yu@zilliz.com>

pull/6054/head
parent 24038be146
commit b72e4c6372
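Note: the heart of this change is that one queryCollection now owns a single query message stream and dispatches on the concrete message type, where the old code ran separate searchCollection and retrieveCollection pipelines. A minimal runnable Go sketch of that dispatch shape (simplified stand-in types, not the actual msgstream API):

    package main

    import "fmt"

    // Simplified stand-ins for msgstream.SearchMsg and msgstream.RetrieveMsg.
    type searchMsg struct{ id int64 }
    type retrieveMsg struct{ id int64 }

    // consume mirrors the shape of queryCollection.consumeQuery below: one
    // channel, one type switch, instead of two services each owning a stream.
    func consume(msgs <-chan interface{}) {
        for m := range msgs {
            switch msg := m.(type) {
            case *searchMsg:
                fmt.Println("handle search", msg.id)
            case *retrieveMsg:
                fmt.Println("handle retrieve", msg.id)
            default:
                fmt.Println("unsupported msg type")
            }
        }
    }

    func main() {
        ch := make(chan interface{}, 2)
        ch <- &searchMsg{id: 1}
        ch <- &retrieveMsg{id: 2}
        close(ch)
        consume(ch)
    }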
@@ -53,6 +53,8 @@ func (pudf *ProtoUDFactory) NewUnmarshalDispatcher() *ProtoUnmarshalDispatcher {
      deleteMsg := DeleteMsg{}
      searchMsg := SearchMsg{}
      searchResultMsg := SearchResultMsg{}
+     retrieveMsg := RetrieveMsg{}
+     retrieveResultMsg := RetrieveResultMsg{}
      timeTickMsg := TimeTickMsg{}
      createCollectionMsg := CreateCollectionMsg{}
      dropCollectionMsg := DropCollectionMsg{}

@@ -72,6 +74,8 @@ func (pudf *ProtoUDFactory) NewUnmarshalDispatcher() *ProtoUnmarshalDispatcher {
      p.TempMap[commonpb.MsgType_Delete] = deleteMsg.Unmarshal
      p.TempMap[commonpb.MsgType_Search] = searchMsg.Unmarshal
      p.TempMap[commonpb.MsgType_SearchResult] = searchResultMsg.Unmarshal
+     p.TempMap[commonpb.MsgType_Retrieve] = retrieveMsg.Unmarshal
+     p.TempMap[commonpb.MsgType_RetrieveResult] = retrieveResultMsg.Unmarshal
      p.TempMap[commonpb.MsgType_TimeTick] = timeTickMsg.Unmarshal
      p.TempMap[commonpb.MsgType_QueryNodeStats] = queryNodeSegStatsMsg.Unmarshal
      p.TempMap[commonpb.MsgType_CreateCollection] = createCollectionMsg.Unmarshal
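The two hunks above register unmarshal functions for the new Retrieve and RetrieveResult message types, so the shared query channel can decode them. TempMap is simply a map from MsgType to a decode function; a hedged sketch of the same registry pattern (hypothetical simplified types, not the real ProtoUnmarshalDispatcher):

    package main

    import (
        "errors"
        "fmt"
    )

    type msgType int

    const (
        msgTypeSearch msgType = iota
        msgTypeRetrieve
    )

    type unmarshalFunc func(data []byte) (interface{}, error)

    // tempMap plays the role of ProtoUnmarshalDispatcher.TempMap.
    var tempMap = map[msgType]unmarshalFunc{}

    func unmarshal(data []byte, t msgType) (interface{}, error) {
        f, ok := tempMap[t]
        if !ok {
            return nil, errors.New("no unmarshal func registered for this msg type")
        }
        return f(data)
    }

    func main() {
        // Registering Retrieve here is the analogue of the added lines above.
        tempMap[msgTypeRetrieve] = func([]byte) (interface{}, error) {
            return "RetrieveMsg", nil
        }
        fmt.Println(unmarshal(nil, msgTypeRetrieve))
    }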
@@ -1537,8 +1537,7 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
              zap.Uint64("timestamp", rt.Base.Timestamp),
              zap.String("db", retrieveRequest.DbName),
              zap.String("collection", retrieveRequest.CollectionName),
-             zap.Any("partitions", retrieveRequest.PartitionNames),
-             zap.Any("len(Ids)", len(rt.result.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data)))
+             zap.Any("partitions", retrieveRequest.PartitionNames))
      }()

      err = rt.WaitToFinish()
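The deleted zap.Any("len(Ids)", ...) field dereferenced rt.result inside a deferred log call that also runs on error paths, before WaitToFinish has necessarily populated the result; logging only the partition names avoids a potential nil dereference and type-assertion panic. (This rationale is inferred from the hunk; the commit message does not state it.)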
@@ -75,8 +75,8 @@ func (node *QueryNode) GetStatisticsChannel(ctx context.Context) (*milvuspb.Stri

  func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQueryChannelRequest) (*commonpb.Status, error) {
      collectionID := in.CollectionID
-     if node.searchService == nil {
-         errMsg := "null search service, collectionID = " + fmt.Sprintln(collectionID)
+     if node.queryService == nil {
+         errMsg := "null query service, collectionID = " + fmt.Sprintln(collectionID)
          status := &commonpb.Status{
              ErrorCode: commonpb.ErrorCode_UnexpectedError,
              Reason:    errMsg,

@@ -94,34 +94,32 @@ func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQuery
      //}

      // add search collection
-     if !node.searchService.hasSearchCollection(collectionID) {
-         node.searchService.addSearchCollection(collectionID)
-         log.Debug("add search collection", zap.Any("collectionID", collectionID))
+     if !node.queryService.hasQueryCollection(collectionID) {
+         node.queryService.addQueryCollection(collectionID)
+         log.Debug("add query collection", zap.Any("collectionID", collectionID))
      }

      // add request channel
-     sc := node.searchService.searchCollections[in.CollectionID]
+     sc := node.queryService.queryCollections[in.CollectionID]
      consumeChannels := []string{in.RequestChannelID}
      //consumeSubName := Params.MsgChannelSubName
      consumeSubName := Params.MsgChannelSubName + "-" + strconv.FormatInt(collectionID, 10) + "-" + strconv.Itoa(rand.Int())
-     sc.searchMsgStream.AsConsumer(consumeChannels, consumeSubName)
-     node.retrieveService.retrieveMsgStream.AsConsumer(consumeChannels, "RetrieveSubName")
+     sc.queryMsgStream.AsConsumer(consumeChannels, consumeSubName)
      log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)

      // add result channel
      producerChannels := []string{in.ResultChannelID}
-     sc.searchResultMsgStream.AsProducer(producerChannels)
-     node.retrieveService.retrieveResultMsgStream.AsProducer(producerChannels)
+     sc.queryResultMsgStream.AsProducer(producerChannels)
      log.Debug("querynode AsProducer: " + strings.Join(producerChannels, ", "))

      // message stream need to asConsumer before start
      // add search collection
-     if !node.searchService.hasSearchCollection(collectionID) {
-         node.searchService.addSearchCollection(collectionID)
-         log.Debug("add search collection", zap.Any("collectionID", collectionID))
+     if !node.queryService.hasQueryCollection(collectionID) {
+         node.queryService.addQueryCollection(collectionID)
+         log.Debug("add query collection", zap.Any("collectionID", collectionID))
      }
      sc.start()
-     log.Debug("start search collection", zap.Any("collectionID", collectionID))
+     log.Debug("start query collection", zap.Any("collectionID", collectionID))

      status := &commonpb.Status{
          ErrorCode: commonpb.ErrorCode_Success,
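Note the new consumer subscription name above: it embeds the collection ID and a random integer, so repeated AddQueryChannel calls never collide on one message-queue subscription. A runnable sketch of that naming scheme ("sub" standing in for Params.MsgChannelSubName):

    package main

    import (
        "fmt"
        "math/rand"
        "strconv"
    )

    // consumeSubName reproduces the expression used in the hunk above.
    func consumeSubName(collectionID int64) string {
        return "sub" + "-" + strconv.FormatInt(collectionID, 10) + "-" + strconv.Itoa(rand.Int())
    }

    func main() {
        fmt.Println(consumeSubName(42)) // e.g. sub-42-5577006791947779410
    }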
@@ -21,6 +21,7 @@ import (
      "sync"

      "github.com/milvus-io/milvus/internal/proto/schemapb"
+     "github.com/milvus-io/milvus/internal/proto/segcorepb"
      "github.com/milvus-io/milvus/internal/util/typeutil"

      "github.com/golang/protobuf/proto"

@@ -35,7 +36,7 @@ import (
      "go.uber.org/zap"
  )

- type searchCollection struct {
+ type queryCollection struct {
      releaseCtx context.Context
      cancel     context.CancelFunc

@@ -43,9 +44,9 @@ type searchCollection struct {
      historical *historical
      streaming  *streaming

-     msgBuffer     chan *msgstream.SearchMsg
-     unsolvedMsgMu sync.Mutex // guards unsolvedMsg
-     unsolvedMsg   []*msgstream.SearchMsg
+     unsolvedMsgMu       sync.Mutex // guards unsolvedMsg
+     unsolvedMsg         []*msgstream.SearchMsg
+     unsolvedRetrieveMsg []*msgstream.RetrieveMsg

      tSafeWatchers     map[Channel]*tSafeWatcher
      watcherSelectCase []reflect.SelectCase

@@ -53,27 +54,25 @@ type searchCollection struct {
      serviceableTimeMutex sync.Mutex // guards serviceableTime
      serviceableTime      Timestamp

-     searchMsgStream       msgstream.MsgStream
-     searchResultMsgStream msgstream.MsgStream
+     queryMsgStream       msgstream.MsgStream
+     queryResultMsgStream msgstream.MsgStream
  }

  type ResultEntityIds []UniqueID

- func newSearchCollection(releaseCtx context.Context,
+ func newQueryCollection(releaseCtx context.Context,
      cancel context.CancelFunc,
      collectionID UniqueID,
      historical *historical,
      streaming *streaming,
-     factory msgstream.Factory) *searchCollection {
+     factory msgstream.Factory) *queryCollection {

-     receiveBufSize := Params.SearchReceiveBufSize
-     msgBuffer := make(chan *msgstream.SearchMsg, receiveBufSize)
      unsolvedMsg := make([]*msgstream.SearchMsg, 0)

-     searchStream, _ := factory.NewQueryMsgStream(releaseCtx)
-     searchResultStream, _ := factory.NewQueryMsgStream(releaseCtx)
+     queryStream, _ := factory.NewQueryMsgStream(releaseCtx)
+     queryResultStream, _ := factory.NewQueryMsgStream(releaseCtx)

-     sc := &searchCollection{
+     qc := &queryCollection{
          releaseCtx: releaseCtx,
          cancel:     cancel,

@@ -83,80 +82,94 @@ func newSearchCollection(releaseCtx context.Context,

          tSafeWatchers: make(map[Channel]*tSafeWatcher),

-         msgBuffer:   msgBuffer,
          unsolvedMsg: unsolvedMsg,

-         searchMsgStream:       searchStream,
-         searchResultMsgStream: searchResultStream,
+         queryMsgStream:       queryStream,
+         queryResultMsgStream: queryResultStream,
      }

-     sc.register()
-     return sc
+     qc.register()
+     return qc
  }

- func (s *searchCollection) start() {
-     go s.searchMsgStream.Start()
-     go s.searchResultMsgStream.Start()
-     go s.consumeSearch()
-     go s.doUnsolvedMsgSearch()
+ func (q *queryCollection) start() {
+     go q.queryMsgStream.Start()
+     go q.queryResultMsgStream.Start()
+     go q.consumeQuery()
+     go q.doUnsolvedMsgSearch()
+     go q.doUnsolvedMsgRetrieve()
  }

- func (s *searchCollection) close() {
-     if s.searchMsgStream != nil {
-         s.searchMsgStream.Close()
+ func (q *queryCollection) close() {
+     if q.queryMsgStream != nil {
+         q.queryMsgStream.Close()
      }
-     if s.searchResultMsgStream != nil {
-         s.searchResultMsgStream.Close()
+     if q.queryResultMsgStream != nil {
+         q.queryResultMsgStream.Close()
      }
  }

- func (s *searchCollection) register() {
-     collection, err := s.streaming.replica.getCollectionByID(s.collectionID)
+ func (q *queryCollection) register() {
+     collection, err := q.streaming.replica.getCollectionByID(q.collectionID)
      if err != nil {
          log.Error(err.Error())
          return
      }

-     s.watcherSelectCase = make([]reflect.SelectCase, 0)
+     q.watcherSelectCase = make([]reflect.SelectCase, 0)
      log.Debug("register tSafe watcher and init watcher select case",
          zap.Any("collectionID", collection.ID()),
          zap.Any("dml channels", collection.getVChannels()),
      )
      for _, channel := range collection.getVChannels() {
-         s.tSafeWatchers[channel] = newTSafeWatcher()
-         s.streaming.tSafeReplica.registerTSafeWatcher(channel, s.tSafeWatchers[channel])
-         s.watcherSelectCase = append(s.watcherSelectCase, reflect.SelectCase{
+         q.tSafeWatchers[channel] = newTSafeWatcher()
+         q.streaming.tSafeReplica.registerTSafeWatcher(channel, q.tSafeWatchers[channel])
+         q.watcherSelectCase = append(q.watcherSelectCase, reflect.SelectCase{
              Dir:  reflect.SelectRecv,
-             Chan: reflect.ValueOf(s.tSafeWatchers[channel].watcherChan()),
+             Chan: reflect.ValueOf(q.tSafeWatchers[channel].watcherChan()),
          })
      }
  }

- func (s *searchCollection) addToUnsolvedMsg(msg *msgstream.SearchMsg) {
-     s.unsolvedMsgMu.Lock()
-     defer s.unsolvedMsgMu.Unlock()
-     s.unsolvedMsg = append(s.unsolvedMsg, msg)
+ func (q *queryCollection) addToUnsolvedMsg(msg *msgstream.SearchMsg) {
+     q.unsolvedMsgMu.Lock()
+     defer q.unsolvedMsgMu.Unlock()
+     q.unsolvedMsg = append(q.unsolvedMsg, msg)
  }

- func (s *searchCollection) popAllUnsolvedMsg() []*msgstream.SearchMsg {
-     s.unsolvedMsgMu.Lock()
-     defer s.unsolvedMsgMu.Unlock()
-     tmp := s.unsolvedMsg
-     s.unsolvedMsg = s.unsolvedMsg[:0]
+ func (q *queryCollection) addToUnsolvedRetrieveMsg(msg *msgstream.RetrieveMsg) {
+     q.unsolvedMsgMu.Lock()
+     defer q.unsolvedMsgMu.Unlock()
+     q.unsolvedRetrieveMsg = append(q.unsolvedRetrieveMsg, msg)
+ }
+
+ func (q *queryCollection) popAllUnsolvedMsg() []*msgstream.SearchMsg {
+     q.unsolvedMsgMu.Lock()
+     defer q.unsolvedMsgMu.Unlock()
+     tmp := q.unsolvedMsg
+     q.unsolvedMsg = q.unsolvedMsg[:0]
      return tmp
  }

- func (s *searchCollection) waitNewTSafe() Timestamp {
+ func (q *queryCollection) popAllUnsolvedRetrieveMsg() []*msgstream.RetrieveMsg {
+     q.unsolvedMsgMu.Lock()
+     defer q.unsolvedMsgMu.Unlock()
+     tmp := q.unsolvedRetrieveMsg
+     q.unsolvedRetrieveMsg = q.unsolvedRetrieveMsg[:0]
+     return tmp
+ }
+
+ func (q *queryCollection) waitNewTSafe() Timestamp {
      // block until any vChannel updating tSafe
-     _, _, recvOK := reflect.Select(s.watcherSelectCase)
+     _, _, recvOK := reflect.Select(q.watcherSelectCase)
      if !recvOK {
-         log.Error("tSafe has been closed", zap.Any("collectionID", s.collectionID))
+         log.Error("tSafe has been closed", zap.Any("collectionID", q.collectionID))
          return invalidTimestamp
      }
      //log.Debug("wait new tSafe", zap.Any("collectionID", s.collectionID))
      t := Timestamp(math.MaxInt64)
-     for channel := range s.tSafeWatchers {
-         ts := s.streaming.tSafeReplica.getTSafe(channel)
+     for channel := range q.tSafeWatchers {
+         ts := q.streaming.tSafeReplica.getTSafe(channel)
          if ts <= t {
              t = ts
          }
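waitNewTSafe blocks until any vChannel advances its tSafe by doing a dynamic select over the watcher channels collected in register, then takes the minimum tSafe across all channels. A self-contained sketch of selecting over a runtime-sized set of channels with reflect.Select:

    package main

    import (
        "fmt"
        "reflect"
    )

    func main() {
        // One receive case per channel, built at runtime; this is the same
        // mechanism queryCollection.register uses for its tSafe watchers.
        chans := []chan uint64{make(chan uint64, 1), make(chan uint64, 1)}
        cases := make([]reflect.SelectCase, 0, len(chans))
        for _, ch := range chans {
            cases = append(cases, reflect.SelectCase{
                Dir:  reflect.SelectRecv,
                Chan: reflect.ValueOf(ch),
            })
        }

        chans[1] <- 7 // simulate one vChannel updating its tSafe

        chosen, val, ok := reflect.Select(cases) // blocks until a case is ready
        fmt.Println(chosen, val.Uint(), ok)      // -> 1 7 true
    }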
@@ -164,64 +177,66 @@ func (s *searchCollection) waitNewTSafe() Timestamp {
      return t
  }

- func (s *searchCollection) getServiceableTime() Timestamp {
-     s.serviceableTimeMutex.Lock()
-     defer s.serviceableTimeMutex.Unlock()
-     return s.serviceableTime
+ func (q *queryCollection) getServiceableTime() Timestamp {
+     q.serviceableTimeMutex.Lock()
+     defer q.serviceableTimeMutex.Unlock()
+     return q.serviceableTime
  }

- func (s *searchCollection) setServiceableTime(t Timestamp) {
-     s.serviceableTimeMutex.Lock()
-     defer s.serviceableTimeMutex.Unlock()
+ func (q *queryCollection) setServiceableTime(t Timestamp) {
+     q.serviceableTimeMutex.Lock()
+     defer q.serviceableTimeMutex.Unlock()

-     if t < s.serviceableTime {
+     if t < q.serviceableTime {
          return
      }

      gracefulTimeInMilliSecond := Params.GracefulTime
      if gracefulTimeInMilliSecond > 0 {
          gracefulTime := tsoutil.ComposeTS(gracefulTimeInMilliSecond, 0)
-         s.serviceableTime = t + gracefulTime
+         q.serviceableTime = t + gracefulTime
      } else {
-         s.serviceableTime = t
+         q.serviceableTime = t
      }
  }

- func (s *searchCollection) emptySearch(searchMsg *msgstream.SearchMsg) {
+ func (q *queryCollection) emptySearch(searchMsg *msgstream.SearchMsg) {
      sp, ctx := trace.StartSpanFromContext(searchMsg.TraceCtx())
      defer sp.Finish()
      searchMsg.SetTraceCtx(ctx)
-     err := s.search(searchMsg)
+     err := q.search(searchMsg)
      if err != nil {
          log.Error(err.Error())
-         s.publishFailedSearchResult(searchMsg, err.Error())
+         q.publishFailedSearchResult(searchMsg, err.Error())
      }
  }

- func (s *searchCollection) consumeSearch() {
+ func (q *queryCollection) consumeQuery() {
      for {
          select {
-         case <-s.releaseCtx.Done():
-             log.Debug("stop searchCollection's receiveSearchMsg", zap.Int64("collectionID", s.collectionID))
+         case <-q.releaseCtx.Done():
+             log.Debug("stop queryCollection's receiveQueryMsg", zap.Int64("collectionID", q.collectionID))
              return
          default:
-             msgPack := s.searchMsgStream.Consume()
+             msgPack := q.queryMsgStream.Consume()
              if msgPack == nil || len(msgPack.Msgs) <= 0 {
                  msgPackNil := msgPack == nil
                  msgPackEmpty := true
                  if msgPack != nil {
                      msgPackEmpty = len(msgPack.Msgs) <= 0
                  }
-                 log.Debug("consume search message failed", zap.Any("msgPack is Nil", msgPackNil),
+                 log.Debug("consume query message failed", zap.Any("msgPack is Nil", msgPackNil),
                      zap.Any("msgPackEmpty", msgPackEmpty))
                  continue
              }
              for _, msg := range msgPack.Msgs {
                  switch sm := msg.(type) {
                  case *msgstream.SearchMsg:
-                     s.receiveSearch(sm)
+                     q.receiveSearch(sm)
                  case *msgstream.LoadBalanceSegmentsMsg:
-                     s.loadBalance(sm)
+                     q.loadBalance(sm)
+                 case *msgstream.RetrieveMsg:
+                     q.receiveRetrieve(sm)
                  default:
                      log.Warn("unsupported msg type in search channel", zap.Any("msg", sm))
                  }
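setServiceableTime above widens the serviceable window by Params.GracefulTime milliseconds. tsoutil.ComposeTS is needed because these timestamps are hybrid TSO values; assuming the usual Milvus/PD layout of 18 logical bits below the physical milliseconds (an assumption, not spelled out in this diff), the conversion looks like:

    package main

    import "fmt"

    const logicalBits = 18 // assumed TSO layout: low 18 bits = logical counter

    // composeTS mirrors what tsoutil.ComposeTS(physical, logical) computes.
    func composeTS(physicalMs, logical int64) uint64 {
        return uint64(physicalMs<<logicalBits + logical)
    }

    func main() {
        tSafe := composeTS(1_000_000, 0) // timestamp obtained from the flow graph
        gracefulMs := int64(100)         // Params.GracefulTime, in milliseconds
        serviceable := tSafe + composeTS(gracefulMs, 0)
        fmt.Println(serviceable-tSafe == 100<<logicalBits) // true
    }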
@@ -230,7 +245,7 @@ func (s *searchCollection) consumeSearch() {
          }
      }

- func (s *searchCollection) loadBalance(msg *msgstream.LoadBalanceSegmentsMsg) {
+ func (q *queryCollection) loadBalance(msg *msgstream.LoadBalanceSegmentsMsg) {
      //TODO:: get loadBalance info from etcd
      //log.Debug("consume load balance message",
      //	zap.Int64("msgID", msg.ID()))

@@ -261,8 +276,82 @@ func (s *searchCollection) loadBalance(msg *msgstream.LoadBalanceSegmentsMsg) {
      //	zap.Int("num of segment", len(msg.Infos)))
  }

- func (s *searchCollection) receiveSearch(msg *msgstream.SearchMsg) {
-     if msg.CollectionID != s.collectionID {
+ func (q *queryCollection) receiveRetrieve(msg *msgstream.RetrieveMsg) {
+     if msg.CollectionID != q.collectionID {
+         log.Debug("not target collection retrieve request",
+             zap.Any("collectionID", msg.CollectionID),
+             zap.Int64("msgID", msg.ID()),
+         )
+         return
+     }
+
+     log.Debug("consume retrieve message",
+         zap.Any("collectionID", msg.CollectionID),
+         zap.Int64("msgID", msg.ID()),
+     )
+     sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+     msg.SetTraceCtx(ctx)
+
+     // check if collection has been released
+     collection, err := q.historical.replica.getCollectionByID(msg.CollectionID)
+     if err != nil {
+         log.Error(err.Error())
+         q.publishFailedRetrieveResult(msg, err.Error())
+         return
+     }
+     if msg.BeginTs() >= collection.getReleaseTime() {
+         err := errors.New("retrieve failed, collection has been released, msgID = " +
+             fmt.Sprintln(msg.ID()) +
+             ", collectionID = " +
+             fmt.Sprintln(msg.CollectionID))
+         log.Error(err.Error())
+         q.publishFailedRetrieveResult(msg, err.Error())
+         return
+     }
+
+     serviceTime := q.getServiceableTime()
+     if msg.BeginTs() > serviceTime {
+         bt, _ := tsoutil.ParseTS(msg.BeginTs())
+         st, _ := tsoutil.ParseTS(serviceTime)
+         log.Debug("query node::receiveRetrieveMsg: add to unsolvedMsg",
+             zap.Any("collectionID", q.collectionID),
+             zap.Any("sm.BeginTs", bt),
+             zap.Any("serviceTime", st),
+             zap.Any("delta seconds", (msg.BeginTs()-serviceTime)/(1000*1000*1000)),
+             zap.Any("msgID", msg.ID()),
+         )
+         q.addToUnsolvedRetrieveMsg(msg)
+         sp.LogFields(
+             oplog.String("send to unsolved buffer", "send to unsolved buffer"),
+             oplog.Object("begin ts", bt),
+             oplog.Object("serviceTime", st),
+             oplog.Float64("delta seconds", float64(msg.BeginTs()-serviceTime)/(1000.0*1000.0*1000.0)),
+         )
+         sp.Finish()
+         return
+     }
+     log.Debug("doing retrieve in receiveRetrieveMsg...",
+         zap.Int64("collectionID", msg.CollectionID),
+         zap.Int64("msgID", msg.ID()),
+     )
+     err = q.retrieve(msg)
+     if err != nil {
+         log.Error(err.Error())
+         log.Debug("do retrieve failed in receiveRetrieveMsg, prepare to publish failed retrieve result",
+             zap.Int64("collectionID", msg.CollectionID),
+             zap.Int64("msgID", msg.ID()),
+         )
+         q.publishFailedRetrieveResult(msg, err.Error())
+     }
+     log.Debug("do retrieve done in receiveRetrieve",
+         zap.Int64("collectionID", msg.CollectionID),
+         zap.Int64("msgID", msg.ID()),
+     )
+     sp.Finish()
+ }
+
+ func (q *queryCollection) receiveSearch(msg *msgstream.SearchMsg) {
+     if msg.CollectionID != q.collectionID {
          log.Debug("not target collection search request",
              zap.Any("collectionID", msg.CollectionID),
              zap.Int64("msgID", msg.ID()),

@@ -278,10 +367,10 @@ func (s *searchCollection) receiveSearch(msg *msgstream.SearchMsg) {
      msg.SetTraceCtx(ctx)

      // check if collection has been released
-     collection, err := s.historical.replica.getCollectionByID(msg.CollectionID)
+     collection, err := q.historical.replica.getCollectionByID(msg.CollectionID)
      if err != nil {
          log.Error(err.Error())
-         s.publishFailedSearchResult(msg, err.Error())
+         q.publishFailedSearchResult(msg, err.Error())
          return
      }
      if msg.BeginTs() >= collection.getReleaseTime() {

@@ -290,22 +379,22 @@ func (s *searchCollection) receiveSearch(msg *msgstream.SearchMsg) {
              ", collectionID = " +
              fmt.Sprintln(msg.CollectionID))
          log.Error(err.Error())
-         s.publishFailedSearchResult(msg, err.Error())
+         q.publishFailedSearchResult(msg, err.Error())
          return
      }

-     serviceTime := s.getServiceableTime()
+     serviceTime := q.getServiceableTime()
      if msg.BeginTs() > serviceTime {
          bt, _ := tsoutil.ParseTS(msg.BeginTs())
          st, _ := tsoutil.ParseTS(serviceTime)
          log.Debug("query node::receiveSearchMsg: add to unsolvedMsg",
-             zap.Any("collectionID", s.collectionID),
+             zap.Any("collectionID", q.collectionID),
              zap.Any("sm.BeginTs", bt),
              zap.Any("serviceTime", st),
              zap.Any("delta seconds", (msg.BeginTs()-serviceTime)/(1000*1000*1000)),
              zap.Any("msgID", msg.ID()),
          )
-         s.addToUnsolvedMsg(msg)
+         q.addToUnsolvedMsg(msg)
          sp.LogFields(
              oplog.String("send to unsolved buffer", "send to unsolved buffer"),
              oplog.Object("begin ts", bt),

@@ -319,14 +408,14 @@ func (s *searchCollection) receiveSearch(msg *msgstream.SearchMsg) {
          zap.Int64("collectionID", msg.CollectionID),
          zap.Int64("msgID", msg.ID()),
      )
-     err = s.search(msg)
+     err = q.search(msg)
      if err != nil {
          log.Error(err.Error())
          log.Debug("do search failed in receiveSearchMsg, prepare to publish failed search result",
              zap.Int64("collectionID", msg.CollectionID),
              zap.Int64("msgID", msg.ID()),
          )
-         s.publishFailedSearchResult(msg, err.Error())
+         q.publishFailedSearchResult(msg, err.Error())
      }
      log.Debug("do search done in receiveSearch",
          zap.Int64("collectionID", msg.CollectionID),

@@ -335,28 +424,28 @@ func (s *searchCollection) receiveSearch(msg *msgstream.SearchMsg) {
      sp.Finish()
  }

- func (s *searchCollection) doUnsolvedMsgSearch() {
-     log.Debug("starting doUnsolvedMsgSearch...", zap.Any("collectionID", s.collectionID))
+ func (q *queryCollection) doUnsolvedMsgSearch() {
+     log.Debug("starting doUnsolvedMsgSearch...", zap.Any("collectionID", q.collectionID))
      for {
          select {
-         case <-s.releaseCtx.Done():
-             log.Debug("stop searchCollection's doUnsolvedMsgSearch", zap.Int64("collectionID", s.collectionID))
+         case <-q.releaseCtx.Done():
+             log.Debug("stop searchCollection's doUnsolvedMsgSearch", zap.Int64("collectionID", q.collectionID))
              return
          default:
              //time.Sleep(10 * time.Millisecond)
-             serviceTime := s.waitNewTSafe()
+             serviceTime := q.waitNewTSafe()
              st, _ := tsoutil.ParseTS(serviceTime)
              log.Debug("get tSafe from flow graph",
-                 zap.Int64("collectionID", s.collectionID),
+                 zap.Int64("collectionID", q.collectionID),
                  zap.Any("tSafe", st))

-             s.setServiceableTime(serviceTime)
+             q.setServiceableTime(serviceTime)
              //log.Debug("query node::doUnsolvedMsgSearch: setServiceableTime",
              //	zap.Any("serviceTime", st),
              //)

              searchMsg := make([]*msgstream.SearchMsg, 0)
-             tempMsg := s.popAllUnsolvedMsg()
+             tempMsg := q.popAllUnsolvedMsg()

              for _, sm := range tempMsg {
                  bt, _ := tsoutil.ParseTS(sm.EndTs())

@@ -374,13 +463,13 @@ func (s *searchCollection) doUnsolvedMsgSearch() {
                      continue
                  }
                  log.Debug("query node::doUnsolvedMsgSearch: add to unsolvedMsg",
-                     zap.Any("collectionID", s.collectionID),
+                     zap.Any("collectionID", q.collectionID),
                      zap.Any("sm.BeginTs", bt),
                      zap.Any("serviceTime", st),
                      zap.Any("delta seconds", (sm.BeginTs()-serviceTime)/(1000*1000*1000)),
                      zap.Any("msgID", sm.ID()),
                  )
-                 s.addToUnsolvedMsg(sm)
+                 q.addToUnsolvedMsg(sm)
              }

              if len(searchMsg) <= 0 {

@@ -393,14 +482,14 @@ func (s *searchCollection) doUnsolvedMsgSearch() {
                      zap.Int64("collectionID", sm.CollectionID),
                      zap.Int64("msgID", sm.ID()),
                  )
-                 err := s.search(sm)
+                 err := q.search(sm)
                  if err != nil {
                      log.Error(err.Error())
                      log.Debug("do search failed in doUnsolvedMsgSearch, prepare to publish failed search result",
                          zap.Int64("collectionID", sm.CollectionID),
                          zap.Int64("msgID", sm.ID()),
                      )
-                     s.publishFailedSearchResult(sm, err.Error())
+                     q.publishFailedSearchResult(sm, err.Error())
                  }
                  sp.Finish()
                  log.Debug("do search done in doUnsolvedMsgSearch",

@@ -637,14 +726,14 @@ func translateHits(schema *typeutil.SchemaHelper, fieldIDs []int64, rawHits [][]

  // TODO:: cache map[dsl]plan
  // TODO: reBatched search requests
- func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
+ func (q *queryCollection) search(searchMsg *msgstream.SearchMsg) error {
      sp, ctx := trace.StartSpanFromContext(searchMsg.TraceCtx())
      defer sp.Finish()
      searchMsg.SetTraceCtx(ctx)
      searchTimestamp := searchMsg.SearchRequest.TravelTimestamp

      collectionID := searchMsg.CollectionID
-     collection, err := s.streaming.replica.getCollectionByID(collectionID)
+     collection, err := q.streaming.replica.getCollectionByID(collectionID)
      if err != nil {
          return err
      }

@@ -691,7 +780,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
      sealedSegmentSearched := make([]UniqueID, 0)

      // historical search
-     hisSearchResults, hisSegmentResults, err1 := s.historical.search(searchRequests, collectionID, searchMsg.PartitionIDs, plan, searchTimestamp)
+     hisSearchResults, hisSegmentResults, err1 := q.historical.search(searchRequests, collectionID, searchMsg.PartitionIDs, plan, searchTimestamp)
      if err1 != nil {
          log.Error(err1.Error())
          return err1

@@ -707,7 +796,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
      for _, channel := range collection.getVChannels() {
          var strSearchResults []*SearchResult
          var strSegmentResults []*Segment
-         strSearchResults, strSegmentResults, err2 = s.streaming.search(searchRequests, collectionID, searchMsg.PartitionIDs, channel, plan, searchTimestamp)
+         strSearchResults, strSegmentResults, err2 = q.streaming.search(searchRequests, collectionID, searchMsg.PartitionIDs, channel, plan, searchTimestamp)
          if err2 != nil {
              log.Error(err2.Error())
              return err2

@@ -772,7 +861,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
          zap.Any("vChannels", collection.getVChannels()),
          zap.Any("sealedSegmentSearched", sealedSegmentSearched),
      )
-     err = s.publishSearchResult(searchResultMsg, searchMsg.CollectionID)
+     err = q.publishSearchResult(searchResultMsg, searchMsg.CollectionID)
      if err != nil {
          return err
      }

@@ -892,7 +981,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
      //	fmt.Println(testHits.IDs)
      //	fmt.Println(testHits.Scores)
      //}
-     err = s.publishSearchResult(searchResultMsg, searchMsg.CollectionID)
+     err = q.publishSearchResult(searchResultMsg, searchMsg.CollectionID)
      if err != nil {
          return err
      }

@@ -907,7 +996,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
      return nil
  }

- func (s *searchCollection) publishSearchResult(msg msgstream.TsMsg, collectionID UniqueID) error {
+ func (q *queryCollection) publishSearchResult(msg msgstream.TsMsg, collectionID UniqueID) error {
      log.Debug("publishing search result...",
          zap.Int64("collectionID", collectionID),
          zap.Int64("msgID", msg.ID()),

@@ -917,7 +1006,7 @@ func (s *searchCollection) publishSearchResult(msg msgstream.TsMsg, collectionID
      msg.SetTraceCtx(ctx)
      msgPack := msgstream.MsgPack{}
      msgPack.Msgs = append(msgPack.Msgs, msg)
-     err := s.searchResultMsgStream.Produce(&msgPack)
+     err := q.queryResultMsgStream.Produce(&msgPack)
      if err != nil {
          log.Error("publishing search result failed, err = "+err.Error(),
              zap.Int64("collectionID", collectionID),

@@ -932,7 +1021,7 @@ func (s *searchCollection) publishSearchResult(msg msgstream.TsMsg, collectionID
      return err
  }

- func (s *searchCollection) publishFailedSearchResult(searchMsg *msgstream.SearchMsg, errMsg string) {
+ func (q *queryCollection) publishFailedSearchResult(searchMsg *msgstream.SearchMsg, errMsg string) {
      span, ctx := trace.StartSpanFromContext(searchMsg.TraceCtx())
      defer span.Finish()
      searchMsg.SetTraceCtx(ctx)

@@ -955,8 +1044,299 @@ func (s *searchCollection) publishFailedSearchResult(searchMsg *msgstream.Search
      }

      msgPack.Msgs = append(msgPack.Msgs, searchResultMsg)
-     err := s.searchResultMsgStream.Produce(&msgPack)
+     err := q.queryResultMsgStream.Produce(&msgPack)
      if err != nil {
          log.Error("publish FailedSearchResult failed" + err.Error())
      }
  }

+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ func (q *queryCollection) doUnsolvedMsgRetrieve() {
+     log.Debug("starting doUnsolvedMsgRetrieve...", zap.Any("collectionID", q.collectionID))
+     for {
+         select {
+         case <-q.releaseCtx.Done():
+             log.Debug("stop retrieveCollection's doUnsolvedMsgRertieve", zap.Int64("collectionID", q.collectionID))
+             return
+         default:
+             //time.Sleep(10 * time.Millisecond)
+             serviceTime := q.waitNewTSafe()
+             st, _ := tsoutil.ParseTS(serviceTime)
+             log.Debug("get tSafe from flow graph",
+                 zap.Int64("collectionID", q.collectionID),
+                 zap.Any("tSafe", st))
+
+             q.setServiceableTime(serviceTime)
+             //log.Debug("query node::doUnsolvedMsgSearch: setServiceableTime",
+             //	zap.Any("serviceTime", st),
+             //)
+
+             retrieveMsg := make([]*msgstream.RetrieveMsg, 0)
+             tempMsg := q.popAllUnsolvedRetrieveMsg()
+
+             for _, rm := range tempMsg {
+                 bt, _ := tsoutil.ParseTS(rm.EndTs())
+                 st, _ = tsoutil.ParseTS(serviceTime)
+                 log.Debug("get retrieve message from unsolvedMsg",
+                     zap.Int64("collectionID", rm.CollectionID),
+                     zap.Int64("msgID", rm.ID()),
+                     zap.Any("reqTime_p", bt),
+                     zap.Any("serviceTime_p", st),
+                     zap.Any("reqTime_l", rm.EndTs()),
+                     zap.Any("serviceTime_l", serviceTime),
+                 )
+                 if rm.EndTs() <= serviceTime {
+                     retrieveMsg = append(retrieveMsg, rm)
+                     continue
+                 }
+                 log.Debug("query node::doUnsolvedMsgRetrieve: add to unsolvedMsg",
+                     zap.Any("collectionID", q.collectionID),
+                     zap.Any("sm.BeginTs", bt),
+                     zap.Any("serviceTime", st),
+                     zap.Any("delta seconds", (rm.BeginTs()-serviceTime)/(1000*1000*1000)),
+                     zap.Any("msgID", rm.ID()),
+                 )
+                 q.addToUnsolvedRetrieveMsg(rm)
+             }
+
+             if len(retrieveMsg) <= 0 {
+                 continue
+             }
+             for _, rm := range retrieveMsg {
+                 sp, ctx := trace.StartSpanFromContext(rm.TraceCtx())
+                 rm.SetTraceCtx(ctx)
+                 log.Debug("doing search in doUnsolvedMsgRetrieve...",
+                     zap.Int64("collectionID", rm.CollectionID),
+                     zap.Int64("msgID", rm.ID()),
+                 )
+                 err := q.retrieve(rm)
+                 if err != nil {
+                     log.Error(err.Error())
+                     log.Debug("do retrieve failed in doUnsolvedMsgSearch, prepare to publish failed retrieve result",
+                         zap.Int64("collectionID", rm.CollectionID),
+                         zap.Int64("msgID", rm.ID()),
+                     )
+                     q.publishFailedRetrieveResult(rm, err.Error())
+                 }
+                 sp.Finish()
+                 log.Debug("do retrieve done in doUnsolvedMsgSearch",
+                     zap.Int64("collectionID", rm.CollectionID),
+                     zap.Int64("msgID", rm.ID()),
+                 )
+             }
+             log.Debug("doUnsolvedMsgRetrieve, do retrieve done", zap.Int("num of retrieveMsg", len(retrieveMsg)))
+         }
+     }
+ }
+
+ func (q *queryCollection) retrieve(retrieveMsg *msgstream.RetrieveMsg) error {
+     // TODO(yukun)
+     // step 1: get retrieve object and defer destruction
+     // step 2: for each segment, call retrieve to get ids proto buffer
+     // step 3: merge all proto in go
+     // step 4: publish results
+     // retrieveProtoBlob, err := proto.Marshal(&retrieveMsg.RetrieveRequest)
+     sp, ctx := trace.StartSpanFromContext(retrieveMsg.TraceCtx())
+     defer sp.Finish()
+     retrieveMsg.SetTraceCtx(ctx)
+     timestamp := retrieveMsg.RetrieveRequest.TravelTimestamp
+
+     collectionID := retrieveMsg.CollectionID
+     collection, err := q.streaming.replica.getCollectionByID(collectionID)
+     if err != nil {
+         return err
+     }
+
+     req := &segcorepb.RetrieveRequest{
+         Ids:          retrieveMsg.Ids,
+         OutputFields: retrieveMsg.OutputFields,
+     }
+
+     plan, err := createRetrievePlan(collection, req, timestamp)
+     if err != nil {
+         return err
+     }
+     defer plan.delete()
+
+     var partitionIDsInHistorical []UniqueID
+     var partitionIDsInStreaming []UniqueID
+     partitionIDsInQuery := retrieveMsg.PartitionIDs
+     if len(partitionIDsInQuery) == 0 {
+         partitionIDsInHistoricalCol, err1 := q.historical.replica.getPartitionIDs(collectionID)
+         partitionIDsInStreamingCol, err2 := q.streaming.replica.getPartitionIDs(collectionID)
+         if err1 != nil && err2 != nil {
+             return err2
+         }
+         if len(partitionIDsInHistoricalCol) == 0 {
+             return errors.New("none of this collection's partition has been loaded")
+         }
+         partitionIDsInHistorical = partitionIDsInHistoricalCol
+         partitionIDsInStreaming = partitionIDsInStreamingCol
+     } else {
+         for _, id := range partitionIDsInQuery {
+             _, err1 := q.historical.replica.getPartitionByID(id)
+             if err1 == nil {
+                 partitionIDsInHistorical = append(partitionIDsInHistorical, id)
+             }
+             _, err2 := q.streaming.replica.getPartitionByID(id)
+             if err2 == nil {
+                 partitionIDsInStreaming = append(partitionIDsInStreaming, id)
+             }
+             if err1 != nil && err2 != nil {
+                 return err2
+             }
+         }
+     }
+
+     sealedSegmentRetrieved := make([]UniqueID, 0)
+     var mergeList []*segcorepb.RetrieveResults
+     for _, partitionID := range partitionIDsInHistorical {
+         segmentIDs, err := q.historical.replica.getSegmentIDs(partitionID)
+         if err != nil {
+             return err
+         }
+         for _, segmentID := range segmentIDs {
+             segment, err := q.historical.replica.getSegmentByID(segmentID)
+             if err != nil {
+                 return err
+             }
+             result, err := segment.segmentGetEntityByIds(plan)
+             if err != nil {
+                 return err
+             }
+             mergeList = append(mergeList, result)
+             sealedSegmentRetrieved = append(sealedSegmentRetrieved, segmentID)
+         }
+     }
+
+     for _, partitionID := range partitionIDsInStreaming {
+         segmentIDs, err := q.streaming.replica.getSegmentIDs(partitionID)
+         if err != nil {
+             return err
+         }
+         for _, segmentID := range segmentIDs {
+             segment, err := q.streaming.replica.getSegmentByID(segmentID)
+             if err != nil {
+                 return err
+             }
+             result, err := segment.segmentGetEntityByIds(plan)
+             if err != nil {
+                 return err
+             }
+             mergeList = append(mergeList, result)
+         }
+     }
+
+     result, err := mergeRetrieveResults(mergeList)
+     if err != nil {
+         return err
+     }
+
+     resultChannelInt := 0
+     retrieveResultMsg := &msgstream.RetrieveResultMsg{
+         BaseMsg: msgstream.BaseMsg{Ctx: retrieveMsg.Ctx, HashValues: []uint32{uint32(resultChannelInt)}},
+         RetrieveResults: internalpb.RetrieveResults{
+             Base: &commonpb.MsgBase{
+                 MsgType:  commonpb.MsgType_RetrieveResult,
+                 MsgID:    retrieveMsg.Base.MsgID,
+                 SourceID: retrieveMsg.Base.SourceID,
+             },
+             Status:                    &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
+             Ids:                       result.Ids,
+             FieldsData:                result.FieldsData,
+             ResultChannelID:           retrieveMsg.ResultChannelID,
+             SealedSegmentIDsRetrieved: sealedSegmentRetrieved,
+             ChannelIDsRetrieved:       collection.getPChannels(),
+             //TODO(yukun):: get global sealed segment from etcd
+             GlobalSealedSegmentIDs: sealedSegmentRetrieved,
+         },
+     }
+     log.Debug("QueryNode RetrieveResultMsg",
+         zap.Any("pChannels", collection.getPChannels()),
+         zap.Any("collectionID", collection.ID()),
+         zap.Any("sealedSegmentRetrieved", sealedSegmentRetrieved),
+     )
+     err3 := q.publishRetrieveResult(retrieveResultMsg, retrieveMsg.CollectionID)
+     if err3 != nil {
+         return err3
+     }
+     return nil
+ }
+
+ func mergeRetrieveResults(dataArr []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) {
+     var final *segcorepb.RetrieveResults
+     for _, data := range dataArr {
+         if data == nil {
+             continue
+         }
+
+         if final == nil {
+             final = proto.Clone(data).(*segcorepb.RetrieveResults)
+             continue
+         }
+
+         proto.Merge(final.Ids, data.Ids)
+         if len(final.FieldsData) != len(data.FieldsData) {
+             return nil, fmt.Errorf("mismatch FieldData in RetrieveResults")
+         }
+
+         for i := range final.FieldsData {
+             proto.Merge(final.FieldsData[i], data.FieldsData[i])
+         }
+     }
+
+     return final, nil
+ }
+
+ func (q *queryCollection) publishRetrieveResult(msg msgstream.TsMsg, collectionID UniqueID) error {
+     log.Debug("publishing retrieve result...",
+         zap.Int64("msgID", msg.ID()),
+         zap.Int64("collectionID", collectionID))
+     span, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+     defer span.Finish()
+     msg.SetTraceCtx(ctx)
+     msgPack := msgstream.MsgPack{}
+     msgPack.Msgs = append(msgPack.Msgs, msg)
+     err := q.queryResultMsgStream.Produce(&msgPack)
+     if err != nil {
+         log.Error(err.Error())
+     } else {
+         log.Debug("publish retrieve result done",
+             zap.Int64("msgID", msg.ID()),
+             zap.Int64("collectionID", collectionID))
+     }
+     return err
+ }
+
+ func (q *queryCollection) publishFailedRetrieveResult(retrieveMsg *msgstream.RetrieveMsg, errMsg string) error {
+     span, ctx := trace.StartSpanFromContext(retrieveMsg.TraceCtx())
+     defer span.Finish()
+     retrieveMsg.SetTraceCtx(ctx)
+     msgPack := msgstream.MsgPack{}
+
+     resultChannelInt := 0
+     retrieveResultMsg := &msgstream.RetrieveResultMsg{
+         BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(resultChannelInt)}},
+         RetrieveResults: internalpb.RetrieveResults{
+             Base: &commonpb.MsgBase{
+                 MsgType:   commonpb.MsgType_RetrieveResult,
+                 MsgID:     retrieveMsg.Base.MsgID,
+                 Timestamp: retrieveMsg.Base.Timestamp,
+                 SourceID:  retrieveMsg.Base.SourceID,
+             },
+             Status:          &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: errMsg},
+             ResultChannelID: retrieveMsg.ResultChannelID,
+             Ids:             nil,
+             FieldsData:      nil,
+         },
+     }
+
+     msgPack.Msgs = append(msgPack.Msgs, retrieveResultMsg)
+     err := q.queryResultMsgStream.Produce(&msgPack)
+     if err != nil {
+         return err
+     }
+
+     return nil
+ }
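mergeRetrieveResults above folds per-segment results together with proto.Merge, which appends repeated fields, so IDs and column data are concatenated segment by segment; the length check rejects segments that returned a different set of output columns. A rough sketch of the same merge invariant using plain slices instead of the segcorepb protobuf types:

    package main

    import "fmt"

    // retrieveResults is a simplified stand-in for segcorepb.RetrieveResults:
    // an ID column plus parallel output-field columns.
    type retrieveResults struct {
        ids    []int64
        fields [][]float32
    }

    func merge(dataArr []*retrieveResults) (*retrieveResults, error) {
        var final *retrieveResults
        for _, data := range dataArr {
            if data == nil {
                continue
            }
            if final == nil {
                cp := *data
                final = &cp
                continue
            }
            if len(final.fields) != len(data.fields) {
                return nil, fmt.Errorf("mismatch FieldData in RetrieveResults")
            }
            // proto.Merge on repeated fields appends, which these appends imitate.
            final.ids = append(final.ids, data.ids...)
            for i := range final.fields {
                final.fields[i] = append(final.fields[i], data.fields[i]...)
            }
        }
        return final, nil
    }

    func main() {
        a := &retrieveResults{ids: []int64{1}, fields: [][]float32{{0.1}}}
        b := &retrieveResults{ids: []int64{2}, fields: [][]float32{{0.2}}}
        out, _ := merge([]*retrieveResults{a, b})
        fmt.Println(out.ids, out.fields) // [1 2] [[0.1 0.2]]
    }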
@@ -58,8 +58,7 @@ type QueryNode struct {
      streaming  *streaming

      // internal services
-     searchService   *searchService
-     retrieveService *retrieveService
+     queryService *queryService

      // clients
      queryCoord types.QueryCoord

@@ -83,8 +82,7 @@ func NewQueryNode(ctx context.Context, queryNodeID UniqueID, factory msgstream.F
          queryNodeLoopCtx:    ctx1,
          queryNodeLoopCancel: cancel,
          QueryNodeID:         queryNodeID,
-         searchService:       nil,
-         retrieveService:     nil,
+         queryService:        nil,
          msFactory:           factory,
      }

@@ -99,8 +97,7 @@ func NewQueryNodeWithoutID(ctx context.Context, factory msgstream.Factory) *Quer
      node := &QueryNode{
          queryNodeLoopCtx:    ctx1,
          queryNodeLoopCancel: cancel,
-         searchService:       nil,
-         retrieveService:     nil,
+         queryService:        nil,
          msFactory:           factory,
      }

@@ -216,22 +213,15 @@ func (node *QueryNode) Start() error {

      // init services and manager
      // TODO: pass node.streaming.replica to search service
-     node.searchService = newSearchService(node.queryNodeLoopCtx,
+     node.queryService = newQueryService(node.queryNodeLoopCtx,
          node.historical,
          node.streaming,
          node.msFactory)

-     node.retrieveService = newRetrieveService(node.queryNodeLoopCtx,
-         node.historical,
-         node.streaming,
-         node.msFactory,
-     )
-
      // start task scheduler
      go node.scheduler.Start()

      // start services
-     go node.retrieveService.start()
      go node.historical.start()
      node.UpdateStateCode(internalpb.StateCode_Healthy)
      return nil

@@ -248,11 +238,11 @@ func (node *QueryNode) Stop() error {
      if node.streaming != nil {
          node.streaming.close()
      }
-     if node.searchService != nil {
-         node.searchService.close()
+     if node.queryService != nil {
+         node.queryService.close()
      }
-     if node.retrieveService != nil {
-         node.retrieveService.close()
+     if node.queryService != nil {
+         node.queryService.close()
      }
      return nil
  }
@@ -0,0 +1,96 @@
+ // Copyright (C) 2019-2020 Zilliz. All rights reserved.
+ //
+ // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+ // with the License. You may obtain a copy of the License at
+ //
+ //     http://www.apache.org/licenses/LICENSE-2.0
+ //
+ // Unless required by applicable law or agreed to in writing, software distributed under the License
+ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ // or implied. See the License for the specific language governing permissions and limitations under the License.
+
+ package querynode
+
+ import "C"
+ import (
+     "context"
+
+     "go.uber.org/zap"
+
+     "github.com/milvus-io/milvus/internal/log"
+     "github.com/milvus-io/milvus/internal/msgstream"
+ )
+
+ type queryService struct {
+     ctx    context.Context
+     cancel context.CancelFunc
+
+     historical *historical
+     streaming  *streaming
+
+     queryNodeID      UniqueID
+     queryCollections map[UniqueID]*queryCollection
+
+     factory msgstream.Factory
+ }
+
+ func newQueryService(ctx context.Context,
+     historical *historical,
+     streaming *streaming,
+     factory msgstream.Factory) *queryService {
+
+     queryServiceCtx, queryServiceCancel := context.WithCancel(ctx)
+     return &queryService{
+         ctx:    queryServiceCtx,
+         cancel: queryServiceCancel,
+
+         historical: historical,
+         streaming:  streaming,
+
+         queryNodeID:      Params.QueryNodeID,
+         queryCollections: make(map[UniqueID]*queryCollection),
+
+         factory: factory,
+     }
+ }
+
+ func (q *queryService) close() {
+     log.Debug("search service closed")
+     for collectionID := range q.queryCollections {
+         q.stopQueryCollection(collectionID)
+     }
+     q.queryCollections = make(map[UniqueID]*queryCollection)
+     q.cancel()
+ }
+
+ func (q *queryService) addQueryCollection(collectionID UniqueID) {
+     if _, ok := q.queryCollections[collectionID]; ok {
+         log.Warn("query collection already exists", zap.Any("collectionID", collectionID))
+         return
+     }
+
+     ctx1, cancel := context.WithCancel(q.ctx)
+     qc := newQueryCollection(ctx1,
+         cancel,
+         collectionID,
+         q.historical,
+         q.streaming,
+         q.factory)
+     q.queryCollections[collectionID] = qc
+ }
+
+ func (q *queryService) hasQueryCollection(collectionID UniqueID) bool {
+     _, ok := q.queryCollections[collectionID]
+     return ok
+ }
+
+ func (q *queryService) stopQueryCollection(collectionID UniqueID) {
+     sc, ok := q.queryCollections[collectionID]
+     if !ok {
+         log.Error("stopQueryCollection failed, collection doesn't exist", zap.Int64("collectionID", collectionID))
+         return
+     }
+     sc.close()
+     sc.cancel()
+     delete(q.queryCollections, collectionID)
+ }
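The new queryService keeps exactly one queryCollection per loaded collection in a map and tears it down (close the streams, cancel the context, delete the map entry) when the collection is released. A hedged sketch of that ownership pattern with stand-in types:

    package main

    import (
        "context"
        "fmt"
    )

    // perCollection stands in for *queryCollection: a worker that owns a
    // cancellable context driving its consume goroutines.
    type perCollection struct{ cancel context.CancelFunc }

    type service struct {
        ctx         context.Context
        collections map[int64]*perCollection
    }

    func (s *service) add(id int64) {
        if _, ok := s.collections[id]; ok {
            return // already registered, as in addQueryCollection
        }
        _, cancel := context.WithCancel(s.ctx) // child ctx would feed the worker's loops
        s.collections[id] = &perCollection{cancel: cancel}
    }

    func (s *service) stop(id int64) {
        c, ok := s.collections[id]
        if !ok {
            return
        }
        c.cancel() // stops consumeQuery / doUnsolvedMsg* via releaseCtx
        delete(s.collections, id)
    }

    func main() {
        s := &service{ctx: context.Background(), collections: map[int64]*perCollection{}}
        s.add(7)
        s.stop(7)
        fmt.Println(len(s.collections)) // 0
    }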
@@ -141,11 +141,11 @@ func TestSearch_Search(t *testing.T) {
      assert.NoError(t, err)

      // start search service
-     node.searchService = newSearchService(node.queryNodeLoopCtx,
+     node.queryService = newQueryService(node.queryNodeLoopCtx,
          node.historical,
          node.streaming,
          msFactory)
-     node.searchService.addSearchCollection(collectionID)
+     node.queryService.addQueryCollection(collectionID)

      // load segment
      err = node.historical.replica.addSegment(segmentID, defaultPartitionID, collectionID, "", segmentTypeSealed, true)

@@ -179,11 +179,11 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
      assert.NoError(t, err)

      // start search service
-     node.searchService = newSearchService(node.queryNodeLoopCtx,
+     node.queryService = newQueryService(node.queryNodeLoopCtx,
          node.historical,
          node.streaming,
          msFactory)
-     node.searchService.addSearchCollection(collectionID)
+     node.queryService.addQueryCollection(collectionID)

      // load segments
      err = node.historical.replica.addSegment(segmentID1, defaultPartitionID, collectionID, "", segmentTypeSealed, true)
@@ -14,12 +14,10 @@ package querynode
  import (
      "context"
      "errors"
-     "fmt"
      "math"
      "reflect"
      "sync"

-     "github.com/golang/protobuf/proto"
      oplog "github.com/opentracing/opentracing-go/log"
      "go.uber.org/zap"

@@ -317,31 +315,6 @@ func (rc *retrieveCollection) doUnsolvedMsgRetrieve() {
      }
  }

- func mergeRetrieveResults(dataArr []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) {
-     var final *segcorepb.RetrieveResults
-     for _, data := range dataArr {
-         if data == nil {
-             continue
-         }
-
-         if final == nil {
-             final = proto.Clone(data).(*segcorepb.RetrieveResults)
-             continue
-         }
-
-         proto.Merge(final.Ids, data.Ids)
-         if len(final.FieldsData) != len(data.FieldsData) {
-             return nil, fmt.Errorf("mismatch FieldData in RetrieveResults")
-         }
-
-         for i := range final.FieldsData {
-             proto.Merge(final.FieldsData[i], data.FieldsData[i])
-         }
-     }
-
-     return final, nil
- }
-
  func (rc *retrieveCollection) retrieve(retrieveMsg *msgstream.RetrieveMsg) error {
      // TODO(yukun)
      // step 1: get retrieve object and defer destruction
@@ -1,96 +0,0 @@
- // Copyright (C) 2019-2020 Zilliz. All rights reserved.
- //
- // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
- // with the License. You may obtain a copy of the License at
- //
- //     http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software distributed under the License
- // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
- // or implied. See the License for the specific language governing permissions and limitations under the License.
-
- package querynode
-
- import "C"
- import (
-     "context"
-
-     "go.uber.org/zap"
-
-     "github.com/milvus-io/milvus/internal/log"
-     "github.com/milvus-io/milvus/internal/msgstream"
- )
-
- type searchService struct {
-     ctx    context.Context
-     cancel context.CancelFunc
-
-     historical *historical
-     streaming  *streaming
-
-     queryNodeID       UniqueID
-     searchCollections map[UniqueID]*searchCollection
-
-     factory msgstream.Factory
- }
-
- func newSearchService(ctx context.Context,
-     historical *historical,
-     streaming *streaming,
-     factory msgstream.Factory) *searchService {
-
-     searchServiceCtx, searchServiceCancel := context.WithCancel(ctx)
-     return &searchService{
-         ctx:    searchServiceCtx,
-         cancel: searchServiceCancel,
-
-         historical: historical,
-         streaming:  streaming,
-
-         queryNodeID:       Params.QueryNodeID,
-         searchCollections: make(map[UniqueID]*searchCollection),
-
-         factory: factory,
-     }
- }
-
- func (s *searchService) close() {
-     log.Debug("search service closed")
-     for collectionID := range s.searchCollections {
-         s.stopSearchCollection(collectionID)
-     }
-     s.searchCollections = make(map[UniqueID]*searchCollection)
-     s.cancel()
- }
-
- func (s *searchService) addSearchCollection(collectionID UniqueID) {
-     if _, ok := s.searchCollections[collectionID]; ok {
-         log.Warn("search collection already exists", zap.Any("collectionID", collectionID))
-         return
-     }
-
-     ctx1, cancel := context.WithCancel(s.ctx)
-     sc := newSearchCollection(ctx1,
-         cancel,
-         collectionID,
-         s.historical,
-         s.streaming,
-         s.factory)
-     s.searchCollections[collectionID] = sc
- }
-
- func (s *searchService) hasSearchCollection(collectionID UniqueID) bool {
-     _, ok := s.searchCollections[collectionID]
-     return ok
- }
-
- func (s *searchService) stopSearchCollection(collectionID UniqueID) {
-     sc, ok := s.searchCollections[collectionID]
-     if !ok {
-         log.Error("stopSearchCollection failed, collection doesn't exist", zap.Int64("collectionID", collectionID))
-         return
-     }
-     sc.close()
-     sc.cancel()
-     delete(s.searchCollections, collectionID)
- }
@@ -412,7 +412,7 @@ func (r *releaseCollectionTask) Execute(ctx context.Context) error {
      }

      r.node.streaming.replica.removeExcludedSegments(r.req.CollectionID)
-     r.node.searchService.stopSearchCollection(r.req.CollectionID)
+     r.node.queryService.stopQueryCollection(r.req.CollectionID)

      hasCollectionInHistorical := r.node.historical.replica.hasCollection(r.req.CollectionID)
      if hasCollectionInHistorical {