mirror of https://github.com/milvus-io/milvus.git

Make query node asynchronously load and release collection

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/4973/head^2

parent ce7a5ea699
commit 9e7559b865
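The change in one sentence: the query node's load/release RPCs (WatchDmChannels, LoadSegments, ReleaseCollection, ReleasePartitions) no longer execute inline; each handler enqueues a task on the node's task scheduler, returns ErrorCode_Success immediately, and leaves a goroutine behind to wait for the result and log it. A self-contained Go sketch of that shape (illustrative names, not the Milvus API):

package main

import (
    "errors"
    "fmt"
    "time"
)

// task mimics the commit's baseTask: completion is signalled on a done channel.
type task struct {
    name string
    done chan error
}

func (t *task) Notify(err error)    { t.done <- err }
func (t *task) WaitToFinish() error { return <-t.done }

// scheduler plays the role of loadAndReleaseLoop: it pops tasks and runs
// them inline, one at a time.
func scheduler(queue <-chan *task) {
    for t := range queue {
        time.Sleep(50 * time.Millisecond) // stand-in for load/release work
        if t.name == "bad" {
            t.Notify(errors.New("task failed"))
            continue
        }
        t.Notify(nil)
    }
}

// handler mirrors the rewritten RPCs: enqueue, answer Success at once,
// and let a goroutine log the eventual outcome.
func handler(queue chan<- *task, name string) string {
    t := &task{name: name, done: make(chan error)}
    queue <- t // Enqueue
    go func() {
        if err := t.WaitToFinish(); err != nil {
            fmt.Println("async task error:", err) // the caller never sees this
            return
        }
        fmt.Println("async task done:", name)
    }()
    return "Success" // returned before the task has actually run
}

func main() {
    queue := make(chan *task, 16)
    go scheduler(queue)
    fmt.Println(handler(queue, "loadCollection"))
    fmt.Println(handler(queue, "releaseCollection"))
    fmt.Println(handler(queue, "bad"))
    time.Sleep(300 * time.Millisecond) // give the async logs time to appear
}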
@@ -3,14 +3,11 @@ package querynode
 import (
     "context"
     "errors"
     "fmt"
-    "strconv"
-    "strings"

     "go.uber.org/zap"

     "github.com/zilliztech/milvus-distributed/internal/log"
-    "github.com/zilliztech/milvus-distributed/internal/msgstream"
     "github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
     "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
     "github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
@@ -141,83 +138,38 @@ func (node *QueryNode) RemoveQueryChannel(ctx context.Context, in *queryPb.Remov
 }

 func (node *QueryNode) WatchDmChannels(ctx context.Context, in *queryPb.WatchDmChannelsRequest) (*commonpb.Status, error) {
-    log.Debug("starting WatchDmChannels ...", zap.String("ChannelIDs", fmt.Sprintln(in.ChannelIDs)))
-    collectionID := in.CollectionID
-    ds, err := node.getDataSyncService(collectionID)
-    if err != nil || ds.dmStream == nil {
-        errMsg := "null data sync service or null data manipulation stream, collectionID = " + fmt.Sprintln(collectionID)
+    dct := &watchDmChannelsTask{
+        baseTask: baseTask{
+            ctx:  ctx,
+            done: make(chan error),
+        },
+        req:  in,
+        node: node,
+    }
+
+    err := node.scheduler.queue.Enqueue(dct)
+    if err != nil {
         status := &commonpb.Status{
             ErrorCode: commonpb.ErrorCode_UnexpectedError,
-            Reason:    errMsg,
+            Reason:    err.Error(),
         }
-        log.Error(errMsg)
-        return status, errors.New(errMsg)
+        log.Error(err.Error())
+        return status, err
     }
+    log.Debug("watchDmChannelsTask Enqueue done", zap.Any("collectionID", in.CollectionID))

-    switch t := ds.dmStream.(type) {
-    case *msgstream.MqTtMsgStream:
-    default:
-        _ = t
-        errMsg := "type assertion failed for dm message stream"
-        status := &commonpb.Status{
-            ErrorCode: commonpb.ErrorCode_UnexpectedError,
-            Reason:    errMsg,
-        }
-        log.Error(errMsg)
-        return status, errors.New(errMsg)
-    }
-
-    getUniqueSubName := func() string {
-        prefixName := Params.MsgChannelSubName
-        return prefixName + "-" + strconv.FormatInt(collectionID, 10)
-    }
-
-    // add request channel
-    consumeChannels := in.ChannelIDs
-    toSeekInfo := make([]*internalpb.MsgPosition, 0)
-    toDirSubChannels := make([]string, 0)
-
-    consumeSubName := getUniqueSubName()
-
-    for _, info := range in.Infos {
-        if len(info.Pos.MsgID) == 0 {
-            toDirSubChannels = append(toDirSubChannels, info.ChannelID)
-            continue
-        }
-        info.Pos.MsgGroup = consumeSubName
-        toSeekInfo = append(toSeekInfo, info.Pos)
-
-        log.Debug("prevent inserting segments", zap.String("segmentIDs", fmt.Sprintln(info.ExcludedSegments)))
-        err := node.replica.addExcludedSegments(collectionID, info.ExcludedSegments)
+    go func() {
+        err = dct.WaitToFinish()
         if err != nil {
-            status := &commonpb.Status{
-                ErrorCode: commonpb.ErrorCode_UnexpectedError,
-                Reason:    err.Error(),
-            }
             log.Error(err.Error())
-            return status, err
+            return
         }
-    }
-
-    ds.dmStream.AsConsumer(toDirSubChannels, consumeSubName)
-    for _, pos := range toSeekInfo {
-        err := ds.dmStream.Seek(pos)
-        if err != nil {
-            errMsg := "msgStream seek error :" + err.Error()
-            status := &commonpb.Status{
-                ErrorCode: commonpb.ErrorCode_UnexpectedError,
-                Reason:    errMsg,
-            }
-            log.Error(errMsg)
-            return status, errors.New(errMsg)
-        }
-    }
-    log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)
+        log.Debug("watchDmChannelsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
+    }()

     status := &commonpb.Status{
         ErrorCode: commonpb.ErrorCode_Success,
     }
-    log.Debug("WatchDmChannels done", zap.String("ChannelIDs", fmt.Sprintln(in.ChannelIDs)))
     return status, nil
 }
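Worth noting about the rewritten handler above: the RPC now answers ErrorCode_Success as soon as the task is queued, so a failed watch is visible only in the query node log, never to the caller. Also, the spawned goroutine assigns to the handler's own variable (`err = dct.WaitToFinish()`) after the handler has typically returned; a fresh `err := dct.WaitToFinish()` inside the closure would avoid sharing it across goroutines, though nothing reads the outer `err` after the return, so the logged result is the same.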
@@ -242,16 +194,14 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *queryPb.LoadSegment
     }
     log.Debug("loadSegmentsTask Enqueue done", zap.Any("collectionID", in.CollectionID))

-    err = dct.WaitToFinish()
-    if err != nil {
-        status := &commonpb.Status{
-            ErrorCode: commonpb.ErrorCode_UnexpectedError,
-            Reason:    err.Error(),
+    go func() {
+        err = dct.WaitToFinish()
+        if err != nil {
+            log.Error(err.Error())
+            return
         }
-        log.Error(err.Error())
-        return status, err
-    }
-    log.Debug("loadSegmentsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
+        log.Debug("loadSegmentsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
+    }()

     status := &commonpb.Status{
         ErrorCode: commonpb.ErrorCode_Success,
@@ -280,16 +230,14 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *queryPb.Releas
     }
     log.Debug("releaseCollectionTask Enqueue done", zap.Any("collectionID", in.CollectionID))

-    err = dct.WaitToFinish()
-    if err != nil {
-        status := &commonpb.Status{
-            ErrorCode: commonpb.ErrorCode_UnexpectedError,
-            Reason:    err.Error(),
+    go func() {
+        err = dct.WaitToFinish()
+        if err != nil {
+            log.Error(err.Error())
+            return
         }
-        log.Error(err.Error())
-        return status, err
-    }
-    log.Debug("releaseCollectionTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
+        log.Debug("releaseCollectionTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
+    }()

     status := &commonpb.Status{
         ErrorCode: commonpb.ErrorCode_Success,
@@ -318,16 +266,14 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *queryPb.Releas
     }
     log.Debug("releasePartitionsTask Enqueue done", zap.Any("collectionID", in.CollectionID))

-    err = dct.WaitToFinish()
-    if err != nil {
-        status := &commonpb.Status{
-            ErrorCode: commonpb.ErrorCode_UnexpectedError,
-            Reason:    err.Error(),
+    go func() {
+        err = dct.WaitToFinish()
+        if err != nil {
+            log.Error(err.Error())
+            return
         }
-        log.Error(err.Error())
-        return status, err
-    }
-    log.Debug("releasePartitionsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
+        log.Debug("releasePartitionsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
+    }()

     status := &commonpb.Status{
         ErrorCode: commonpb.ErrorCode_Success,
@@ -98,13 +98,13 @@ func (s *searchService) consumeSearch() {
                 zap.Int64("collectionID", sm.CollectionID))
             continue
         }
-        sc, ok := s.searchCollections[sm.CollectionID]
+        _, ok = s.searchCollections[sm.CollectionID]
         if !ok {
             s.startSearchCollection(sm.CollectionID)
             log.Debug("new search collection, start search collection service",
                 zap.Int64("collectionID", sm.CollectionID))
         }
-        sc.msgBuffer <- sm
+        s.searchCollections[sm.CollectionID].msgBuffer <- sm
         sp.Finish()
     }
 }
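This hunk fixes a lookup-order bug: `sc` was read from the map before the `!ok` branch called startSearchCollection, so for a brand-new collection `sc` stayed nil and `sc.msgBuffer <- sm` would panic; re-indexing the map after the service is started picks up the fresh entry. A minimal runnable reproduction of the pattern (assuming, as the diff suggests, that startSearchCollection inserts the entry):

package main

import "fmt"

type searchCollection struct{ msgBuffer chan string }

func main() {
    collections := map[int64]*searchCollection{}
    start := func(id int64) {
        collections[id] = &searchCollection{msgBuffer: make(chan string, 1)}
    }

    var id int64 = 42
    sc, ok := collections[id] // entry missing: sc is nil
    if !ok {
        start(id) // inserts the entry, but the local sc stays nil
    }
    // sc.msgBuffer <- "msg" would panic with a nil pointer dereference;
    // re-reading the map after start picks up the fresh entry:
    collections[id].msgBuffer <- "msg"
    fmt.Println(<-collections[id].msgBuffer, "sc is nil:", sc == nil)
}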
@@ -5,10 +5,14 @@ import (
     "errors"
     "fmt"
+    "math/rand"
+    "strconv"
+    "strings"

     "go.uber.org/zap"

     "github.com/zilliztech/milvus-distributed/internal/log"
+    "github.com/zilliztech/milvus-distributed/internal/msgstream"
     "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
     queryPb "github.com/zilliztech/milvus-distributed/internal/proto/querypb"
 )
@@ -30,6 +34,12 @@ type baseTask struct {
     id   UniqueID
 }

+type watchDmChannelsTask struct {
+    baseTask
+    req  *queryPb.WatchDmChannelsRequest
+    node *QueryNode
+}
+
 type loadSegmentsTask struct {
     baseTask
     req *queryPb.LoadSegmentsRequest
@@ -57,20 +67,107 @@ func (b *baseTask) SetID(uid UniqueID) {
 }

 func (b *baseTask) WaitToFinish() error {
-    select {
-    case <-b.ctx.Done():
-        return errors.New("task timeout")
-    case err := <-b.done:
-        return err
-    }
+    err := <-b.done
+    return err
 }

 func (b *baseTask) Notify(err error) {
     b.done <- err
 }

+// watchDmChannelsTask
+func (w *watchDmChannelsTask) Timestamp() Timestamp {
+    if w.req.Base == nil {
+        log.Error("nil base req in watchDmChannelsTask", zap.Any("collectionID", w.req.CollectionID))
+        return 0
+    }
+    return w.req.Base.Timestamp
+}
+
+func (w *watchDmChannelsTask) OnEnqueue() error {
+    if w.req == nil || w.req.Base == nil {
+        w.SetID(rand.Int63n(100000000000))
+    } else {
+        w.SetID(w.req.Base.MsgID)
+    }
+    return nil
+}
+
+func (w *watchDmChannelsTask) PreExecute(ctx context.Context) error {
+    return nil
+}
+
+func (w *watchDmChannelsTask) Execute(ctx context.Context) error {
+    log.Debug("starting WatchDmChannels ...", zap.String("ChannelIDs", fmt.Sprintln(w.req.ChannelIDs)))
+    collectionID := w.req.CollectionID
+    ds, err := w.node.getDataSyncService(collectionID)
+    if err != nil || ds.dmStream == nil {
+        errMsg := "null data sync service or null data manipulation stream, collectionID = " + fmt.Sprintln(collectionID)
+        log.Error(errMsg)
+        return errors.New(errMsg)
+    }
+
+    switch t := ds.dmStream.(type) {
+    case *msgstream.MqTtMsgStream:
+    default:
+        _ = t
+        errMsg := "type assertion failed for dm message stream"
+        log.Error(errMsg)
+        return errors.New(errMsg)
+    }
+
+    getUniqueSubName := func() string {
+        prefixName := Params.MsgChannelSubName
+        return prefixName + "-" + strconv.FormatInt(collectionID, 10)
+    }
+
+    // add request channel
+    consumeChannels := w.req.ChannelIDs
+    toSeekInfo := make([]*internalpb.MsgPosition, 0)
+    toDirSubChannels := make([]string, 0)
+
+    consumeSubName := getUniqueSubName()
+
+    for _, info := range w.req.Infos {
+        if len(info.Pos.MsgID) == 0 {
+            toDirSubChannels = append(toDirSubChannels, info.ChannelID)
+            continue
+        }
+        info.Pos.MsgGroup = consumeSubName
+        toSeekInfo = append(toSeekInfo, info.Pos)
+
+        log.Debug("prevent inserting segments", zap.String("segmentIDs", fmt.Sprintln(info.ExcludedSegments)))
+        err := w.node.replica.addExcludedSegments(collectionID, info.ExcludedSegments)
+        if err != nil {
+            log.Error(err.Error())
+            return err
+        }
+    }
+
+    ds.dmStream.AsConsumer(toDirSubChannels, consumeSubName)
+    for _, pos := range toSeekInfo {
+        err := ds.dmStream.Seek(pos)
+        if err != nil {
+            errMsg := "msgStream seek error :" + err.Error()
+            log.Error(errMsg)
+            return errors.New(errMsg)
+        }
+    }
+    log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)
+    log.Debug("WatchDmChannels done", zap.String("ChannelIDs", fmt.Sprintln(w.req.ChannelIDs)))
+    return nil
+}
+
+func (w *watchDmChannelsTask) PostExecute(ctx context.Context) error {
+    return nil
+}
+
 // loadSegmentsTask
 func (l *loadSegmentsTask) Timestamp() Timestamp {
+    if l.req.Base == nil {
+        log.Error("nil base req in loadSegmentsTask", zap.Any("collectionID", l.req.CollectionID))
+        return 0
+    }
     return l.req.Base.Timestamp
 }
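The WaitToFinish change above drops the select on b.ctx.Done(). That matters for the new asynchronous handlers: the task carries the RPC's context, and for a unary gRPC call that context is cancelled once the handler returns, so the old select would presumably have reported "task timeout" for every task still running after the RPC replied. A small runnable contrast, with illustrative names (not the Milvus types):

package main

import (
    "context"
    "errors"
    "fmt"
    "time"
)

type baseTask struct {
    ctx  context.Context
    done chan error
}

// Old behavior: tied to the RPC context, which is cancelled once the
// handler returns, so async waiters could see spurious "task timeout".
func (b *baseTask) waitWithCtx() error {
    select {
    case <-b.ctx.Done():
        return errors.New("task timeout")
    case err := <-b.done:
        return err
    }
}

// New behavior: block until the scheduler calls Notify, however long it takes.
func (b *baseTask) wait() error { return <-b.done }

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    t := &baseTask{ctx: ctx, done: make(chan error, 1)}
    cancel()                     // simulate the RPC returning and its context dying
    fmt.Println(t.waitWithCtx()) // "task timeout" even though nothing failed
    go func() {
        time.Sleep(10 * time.Millisecond)
        t.done <- nil // Notify
    }()
    fmt.Println(t.wait()) // <nil>: waits for the real outcome
}

The trade-off is that the new version blocks forever if Notify is never called.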
@@ -146,6 +243,10 @@ func (l *loadSegmentsTask) PostExecute(ctx context.Context) error {

 // releaseCollectionTask
 func (r *releaseCollectionTask) Timestamp() Timestamp {
+    if r.req.Base == nil {
+        log.Error("nil base req in releaseCollectionTask", zap.Any("collectionID", r.req.CollectionID))
+        return 0
+    }
     return r.req.Base.Timestamp
 }
@@ -190,6 +291,10 @@ func (r *releaseCollectionTask) PostExecute(ctx context.Context) error {

 // releasePartitionsTask
 func (r *releasePartitionsTask) Timestamp() Timestamp {
+    if r.req.Base == nil {
+        log.Error("nil base req in releasePartitionsTask", zap.Any("collectionID", r.req.CollectionID))
+        return 0
+    }
     return r.req.Base.Timestamp
 }
@@ -59,7 +59,7 @@ func (s *taskScheduler) loadAndReleaseLoop() {
         case <-s.queue.utChan():
             if !s.queue.utEmpty() {
                 t := s.queue.PopUnissuedTask()
-                go s.processTask(t, s.queue)
+                s.processTask(t, s.queue)
             }
         }
     }
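Dropping the `go` here means loadAndReleaseLoop now executes tasks one at a time, in queue order. Since the RPC handlers no longer block on completion, the loop can afford to serialize, and a release can no longer run concurrently with the load that preceded it. A toy illustration of the difference (illustrative names, not the scheduler's actual types):

package main

import (
    "fmt"
    "time"
)

type task struct{ name string }

// worker processes tasks inline, one at a time; writing `go process(t)`
// instead would let a release overlap an in-flight load.
func worker(queue <-chan task, done chan<- struct{}) {
    for t := range queue {
        fmt.Println("start ", t.name)
        time.Sleep(10 * time.Millisecond) // simulated load/release work
        fmt.Println("finish", t.name)
    }
    close(done)
}

func main() {
    queue := make(chan task, 4)
    done := make(chan struct{})
    go worker(queue, done)
    queue <- task{"load collection 1"}
    queue <- task{"release collection 1"} // starts only after the load finishes
    close(queue)
    <-done
}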
@@ -353,6 +353,11 @@ func (lpt *LoadPartitionTask) Execute(ctx context.Context) error {
     segmentIDs := showSegmentResponse.SegmentIDs
     if len(segmentIDs) == 0 {
         loadSegmentRequest := &querypb.LoadSegmentsRequest{
+            // TODO: use unique id allocator to assign reqID
+            Base: &commonpb.MsgBase{
+                Timestamp: lpt.Base.Timestamp,
+                MsgID:     rand.Int63n(10000000000),
+            },
             CollectionID: collectionID,
             PartitionID:  partitionID,
             Schema:       schema,
@@ -447,7 +452,8 @@ func (lpt *LoadPartitionTask) Execute(ctx context.Context) error {
         loadSegmentRequest := &querypb.LoadSegmentsRequest{
             // TODO: use unique id allocator to assign reqID
             Base: &commonpb.MsgBase{
-                MsgID: rand.Int63n(10000000000),
+                Timestamp: lpt.Base.Timestamp,
+                MsgID:     rand.Int63n(10000000000),
             },
             CollectionID: collectionID,
             PartitionID:  partitionID,
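Both LoadSegmentsRequest construction sites now carry the coordinator task's timestamp in Base (Timestamp: lpt.Base.Timestamp), which is what the query node's new Timestamp() methods read when ordering tasks; the first hunk also adds a Base to the previously Base-less empty-segments request, matching the nil-Base guards added above. The rand.Int63n MsgID remains a stopgap, as the TODO notes.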