Refactor Load/Release into async calls in query service

Signed-off-by: xige-16 <xi.ge@zilliz.com>
pull/4973/head^2
xige-16 2021-04-15 15:15:46 +08:00 committed by yefu.chen
parent e4256a4400
commit 78155d3959
6 changed files with 1547 additions and 777 deletions


@ -654,7 +654,7 @@ func (ms *MqTtMsgStream) Seek(mp *internalpb.MsgPosition) error {
consumer, hasWatched = ms.consumers[seekChannel]
if hasWatched {
return errors.New("the channel should has been subscribed")
return errors.New("the channel should has not been subscribed")
}
fn := func() error {


@ -0,0 +1,43 @@
package queryservice
import (
"context"
"errors"
)
type Condition interface {
WaitToFinish() error
Notify(err error)
Ctx() context.Context
}
type TaskCondition struct {
done chan error
ctx context.Context
}
func (tc *TaskCondition) WaitToFinish() error {
select {
case <-tc.ctx.Done():
return errors.New("timeout")
case err := <-tc.done:
return err
}
}
func (tc *TaskCondition) Notify(err error) {
tc.done <- err
}
func (tc *TaskCondition) Ctx() context.Context {
return tc.ctx
}
func NewTaskCondition(ctx context.Context) *TaskCondition {
return &TaskCondition{
done: make(chan error, 1),
ctx: ctx,
}
}
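For orientation, a minimal usage sketch of this condition type (the caller below is hypothetical and not part of this commit): the RPC handler blocks in WaitToFinish while a worker goroutine reports completion through Notify; because done is buffered with capacity 1, Notify can proceed even if the waiter already gave up on a context timeout.

// Hypothetical illustration of the TaskCondition handshake; assumes the
// surrounding queryservice package plus a "time" import.
func exampleConditionUsage(parent context.Context) error {
	ctx, cancel := context.WithTimeout(parent, 10*time.Second)
	defer cancel()
	cond := NewTaskCondition(ctx)
	go func() {
		// ... do the task's work here, then wake the waiter:
		cond.Notify(nil) // nil on success, the failure error otherwise
	}()
	return cond.WaitToFinish() // notified error, or "timeout" once ctx expires
}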


@ -0,0 +1,484 @@
package queryservice
import (
"context"
"errors"
"fmt"
"strconv"
"time"
"go.uber.org/zap"
nodeclient "github.com/zilliztech/milvus-distributed/internal/distributed/querynode/client"
"github.com/zilliztech/milvus-distributed/internal/log"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
"github.com/zilliztech/milvus-distributed/internal/proto/querypb"
"github.com/zilliztech/milvus-distributed/internal/util/retry"
)
func (qs *QueryService) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
serviceComponentInfo := &internalpb.ComponentInfo{
NodeID: Params.QueryServiceID,
StateCode: qs.stateCode.Load().(internalpb.StateCode),
}
subComponentInfos := make([]*internalpb.ComponentInfo, 0)
for nodeID, node := range qs.queryNodes {
componentStates, err := node.GetComponentStates(ctx)
if err != nil {
subComponentInfos = append(subComponentInfos, &internalpb.ComponentInfo{
NodeID: nodeID,
StateCode: internalpb.StateCode_Abnormal,
})
continue
}
subComponentInfos = append(subComponentInfos, componentStates.State)
}
return &internalpb.ComponentStates{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
State: serviceComponentInfo,
SubcomponentStates: subComponentInfos,
}, nil
}
func (qs *QueryService) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return &milvuspb.StringResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
Value: Params.TimeTickChannelName,
}, nil
}
func (qs *QueryService) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return &milvuspb.StringResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
Value: Params.StatsChannelName,
}, nil
}
func (qs *QueryService) RegisterNode(ctx context.Context, req *querypb.RegisterNodeRequest) (*querypb.RegisterNodeResponse, error) {
log.Debug("register query node", zap.String("address", req.Address.String()))
// TODO:: add mutex
nodeID := req.Base.SourceID
if _, ok := qs.queryNodes[nodeID]; ok {
err := errors.New("nodeID already exists")
return &querypb.RegisterNodeResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, err
}
registerNodeAddress := req.Address.Ip + ":" + strconv.FormatInt(req.Address.Port, 10)
client := nodeclient.NewClient(registerNodeAddress)
if err := client.Init(); err != nil {
return &querypb.RegisterNodeResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
InitParams: new(internalpb.InitParams),
}, err
}
if err := client.Start(); err != nil {
return nil, err
}
qs.queryNodes[nodeID] = newQueryNodeInfo(client)
//TODO::return init params to queryNode
startParams := []*commonpb.KeyValuePair{
{Key: "StatsChannelName", Value: Params.StatsChannelName},
{Key: "TimeTickChannelName", Value: Params.TimeTickChannelName},
}
qs.qcMutex.Lock()
for _, queryChannel := range qs.queryChannels {
startParams = append(startParams, &commonpb.KeyValuePair{
Key: "QueryChannelName",
Value: queryChannel.requestChannel,
})
startParams = append(startParams, &commonpb.KeyValuePair{
Key: "QueryResultChannelName",
Value: queryChannel.responseChannel,
})
}
qs.qcMutex.Unlock()
return &querypb.RegisterNodeResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
InitParams: &internalpb.InitParams{
NodeID: nodeID,
StartParams: startParams,
},
}, nil
}
func (qs *QueryService) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
dbID := req.DbID
log.Debug("show collection start, dbID = ", zap.String("dbID", strconv.FormatInt(dbID, 10)))
collections, err := qs.replica.getCollections(dbID)
collectionIDs := make([]UniqueID, 0)
for _, collection := range collections {
collectionIDs = append(collectionIDs, collection.id)
}
if err != nil {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, err
}
log.Debug("show collection end")
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: collectionIDs,
}, nil
}
func (qs *QueryService) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) {
dbID := req.DbID
collectionID := req.CollectionID
schema := req.Schema
watchNeeded := false
log.Debug("LoadCollectionRequest received", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", collectionID),
zap.Stringer("schema", req.Schema))
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
}
_, err := qs.replica.getCollectionByID(dbID, collectionID)
if err != nil {
watchNeeded = true
err = qs.replica.addCollection(dbID, collectionID, schema)
if err != nil {
status.Reason = err.Error()
return status, err
}
}
loadCtx, cancel := context.WithTimeout(qs.loopCtx, 30*time.Second)
loadCollectionTask := &LoadCollectionTask{
BaseTask: BaseTask{
ctx: loadCtx,
cancel: cancel,
Condition: NewTaskCondition(loadCtx),
},
LoadCollectionRequest: req,
masterService: qs.masterServiceClient,
dataService: qs.dataServiceClient,
queryNodes: qs.queryNodes,
meta: qs.replica,
watchNeeded: watchNeeded,
}
err = qs.sched.DdQueue.Enqueue(loadCollectionTask)
if err != nil {
status.Reason = err.Error()
return status, err
}
err = loadCollectionTask.WaitToFinish()
if err != nil {
status.Reason = err.Error()
return status, err
}
log.Debug("LoadCollectionRequest completed", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", collectionID))
status.ErrorCode = commonpb.ErrorCode_Success
return status, nil
}
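Every load/release handler in this file now follows the same asynchronous shape: build a task, enqueue it on the scheduler's DdQueue, and block on the task's condition until the scheduler notifies completion. A condensed sketch of that shared pattern (the helper below is illustrative only, not part of the commit):

// Illustrative enqueue-then-wait helper; mirrors the handler bodies above.
func runDdTask(sched *TaskScheduler, t task, status *commonpb.Status) (*commonpb.Status, error) {
	if err := sched.DdQueue.Enqueue(t); err != nil { // hand the task to the scheduler
		status.ErrorCode = commonpb.ErrorCode_UnexpectedError
		status.Reason = err.Error()
		return status, err
	}
	if err := t.WaitToFinish(); err != nil { // block until the task's Notify fires
		status.ErrorCode = commonpb.ErrorCode_UnexpectedError
		status.Reason = err.Error()
		return status, err
	}
	status.ErrorCode = commonpb.ErrorCode_Success
	return status, nil
}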
func (qs *QueryService) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
dbID := req.DbID
collectionID := req.CollectionID
log.Debug("ReleaseCollectionRequest received", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", collectionID))
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
_, err := qs.replica.getCollectionByID(dbID, collectionID)
if err != nil {
log.Error("release collection end, query service don't have the log of", zap.String("collectionID", fmt.Sprintln(collectionID)))
return status, nil
}
releaseCollectionTask := &ReleaseCollectionTask{
BaseTask: BaseTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
},
ReleaseCollectionRequest: req,
queryNodes: qs.queryNodes,
meta: qs.replica,
}
err = qs.sched.DdQueue.Enqueue(releaseCollectionTask)
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
return status, err
}
err = releaseCollectionTask.WaitToFinish()
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
return status, err
}
log.Debug("ReleaseCollectionRequest completed", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", collectionID))
//TODO:: queryNode cancel subscribe dmChannels
return status, nil
}
func (qs *QueryService) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
dbID := req.DbID
collectionID := req.CollectionID
partitions, err := qs.replica.getPartitions(dbID, collectionID)
partitionIDs := make([]UniqueID, 0)
for _, partition := range partitions {
partitionIDs = append(partitionIDs, partition.id)
}
if err != nil {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, err
}
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
PartitionIDs: partitionIDs,
}, nil
}
func (qs *QueryService) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
//TODO::suggest different partitions have different dm channel
collectionID := req.CollectionID
partitionIDs := req.PartitionIDs
log.Debug("LoadPartitionRequest received", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
status, watchNeeded, err := LoadPartitionMetaCheck(qs.replica, req)
if err != nil {
return status, err
}
loadCtx, cancel := context.WithTimeout(qs.loopCtx, 30*time.Second)
loadPartitionTask := &LoadPartitionTask{
BaseTask: BaseTask{
ctx: loadCtx,
cancel: cancel,
Condition: NewTaskCondition(loadCtx),
},
},
LoadPartitionsRequest: req,
masterService: qs.masterServiceClient,
dataService: qs.dataServiceClient,
queryNodes: qs.queryNodes,
meta: qs.replica,
watchNeeded: watchNeeded,
}
err = qs.sched.DdQueue.Enqueue(loadPartitionTask)
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
return status, err
}
err = loadPartitionTask.WaitToFinish()
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
return status, err
}
log.Debug("LoadPartitionRequest completed", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", req.CollectionID))
return status, nil
}
func LoadPartitionMetaCheck(meta Replica, req *querypb.LoadPartitionsRequest) (*commonpb.Status, bool, error) {
dbID := req.DbID
collectionID := req.CollectionID
partitionIDs := req.PartitionIDs
schema := req.Schema
watchNeeded := false
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
}
if len(partitionIDs) == 0 {
err := errors.New("partitionIDs are empty")
status.Reason = err.Error()
return status, watchNeeded, err
}
_, err := meta.getCollectionByID(dbID, collectionID)
if err != nil {
err = meta.addCollection(dbID, collectionID, schema)
if err != nil {
status.Reason = err.Error()
return status, watchNeeded, err
}
watchNeeded = true
}
for _, partitionID := range partitionIDs {
_, err = meta.getPartitionByID(dbID, collectionID, partitionID)
if err == nil {
continue
}
err = meta.addPartition(dbID, collectionID, partitionID)
if err != nil {
status.Reason = err.Error()
return status, watchNeeded, err
}
}
status.ErrorCode = commonpb.ErrorCode_Success
return status, watchNeeded, nil
}
func (qs *QueryService) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
dbID := req.DbID
collectionID := req.CollectionID
partitionIDs := req.PartitionIDs
log.Debug("ReleasePartitionRequest received", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", req.CollectionID), zap.Int64s("partitionIDs", partitionIDs))
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
toReleasedPartitionID := make([]UniqueID, 0)
for _, partitionID := range partitionIDs {
_, err := qs.replica.getPartitionByID(dbID, collectionID, partitionID)
if err == nil {
toReleasedPartitionID = append(toReleasedPartitionID, partitionID)
}
}
if len(toReleasedPartitionID) > 0 {
req.PartitionIDs = toReleasedPartitionID
releasePartitionTask := &ReleasePartitionTask{
BaseTask: BaseTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
},
ReleasePartitionsRequest: req,
queryNodes: qs.queryNodes,
meta: qs.replica,
}
err := qs.sched.DdQueue.Enqueue(releasePartitionTask)
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
return status, err
}
err = releasePartitionTask.WaitToFinish()
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
return status, err
}
}
log.Debug("ReleasePartitionRequest completed", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
//TODO:: queryNode cancel subscribe dmChannels
return status, nil
}
func (qs *QueryService) CreateQueryChannel(ctx context.Context) (*querypb.CreateQueryChannelResponse, error) {
qs.qcMutex.Lock()
channelID := len(qs.queryChannels)
allocatedQueryChannel := "query-" + strconv.FormatInt(int64(channelID), 10)
allocatedQueryResultChannel := "queryResult-" + strconv.FormatInt(int64(channelID), 10)
qs.queryChannels = append(qs.queryChannels, &queryChannelInfo{
requestChannel: allocatedQueryChannel,
responseChannel: allocatedQueryResultChannel,
})
addQueryChannelsRequest := &querypb.AddQueryChannelRequest{
RequestChannelID: allocatedQueryChannel,
ResultChannelID: allocatedQueryResultChannel,
}
log.Debug("query service create query channel", zap.String("queryChannelName", allocatedQueryChannel))
for nodeID, node := range qs.queryNodes {
log.Debug("node watch query channel", zap.String("nodeID", fmt.Sprintln(nodeID)))
fn := func() error {
_, err := node.AddQueryChannel(ctx, addQueryChannelsRequest)
return err
}
err := retry.Retry(10, time.Millisecond*200, fn)
if err != nil {
qs.qcMutex.Unlock()
return &querypb.CreateQueryChannelResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, err
}
}
qs.qcMutex.Unlock()
return &querypb.CreateQueryChannelResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
RequestChannel: allocatedQueryChannel,
ResultChannel: allocatedQueryResultChannel,
}, nil
}
func (qs *QueryService) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) {
states, err := qs.replica.getPartitionStates(req.DbID, req.CollectionID, req.PartitionIDs)
if err != nil {
return &querypb.GetPartitionStatesResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
PartitionDescriptions: states,
}, err
}
return &querypb.GetPartitionStatesResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
PartitionDescriptions: states,
}, nil
}
func (qs *QueryService) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
segmentInfos := make([]*querypb.SegmentInfo, 0)
totalMemSize := int64(0)
for nodeID, node := range qs.queryNodes {
segmentInfo, err := node.client.GetSegmentInfo(ctx, req)
if err != nil {
return &querypb.GetSegmentInfoResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, err
}
segmentInfos = append(segmentInfos, segmentInfo.Infos...)
for _, info := range segmentInfo.Infos {
totalMemSize = totalMemSize + info.MemSize
}
log.Debug("getSegmentInfo", zap.Int64("nodeID", nodeID), zap.Int64("memory size", totalMemSize))
}
return &querypb.GetSegmentInfoResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Infos: segmentInfos,
}, nil
}


@ -2,29 +2,20 @@ package queryservice
import (
"context"
"errors"
"fmt"
"math/rand"
"sort"
"strconv"
"sync"
"sync/atomic"
"time"
"go.uber.org/zap"
nodeclient "github.com/zilliztech/milvus-distributed/internal/distributed/querynode/client"
"github.com/zilliztech/milvus-distributed/internal/log"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/datapb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
"github.com/zilliztech/milvus-distributed/internal/proto/querypb"
"github.com/zilliztech/milvus-distributed/internal/types"
"github.com/zilliztech/milvus-distributed/internal/util/retry"
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
)
type Timestamp = typeutil.Timestamp
type queryChannelInfo struct {
requestChannel string
responseChannel string
@ -36,6 +27,7 @@ type QueryService struct {
queryServiceID uint64
replica Replica
sched *TaskScheduler
dataServiceClient types.DataService
masterServiceClient types.MasterService
@ -55,11 +47,15 @@ func (qs *QueryService) Init() error {
}
func (qs *QueryService) Start() error {
qs.sched.Start()
log.Debug("start scheduler ...")
qs.UpdateStateCode(internalpb.StateCode_Healthy)
return nil
}
func (qs *QueryService) Stop() error {
qs.sched.Close()
log.Debug("close scheduler ...")
qs.loopCancel()
qs.UpdateStateCode(internalpb.StateCode_Abnormal)
return nil
@ -69,610 +65,18 @@ func (qs *QueryService) UpdateStateCode(code internalpb.StateCode) {
qs.stateCode.Store(code)
}
func (qs *QueryService) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
serviceComponentInfo := &internalpb.ComponentInfo{
NodeID: Params.QueryServiceID,
StateCode: qs.stateCode.Load().(internalpb.StateCode),
}
subComponentInfos := make([]*internalpb.ComponentInfo, 0)
for nodeID, node := range qs.queryNodes {
componentStates, err := node.GetComponentStates(ctx)
if err != nil {
subComponentInfos = append(subComponentInfos, &internalpb.ComponentInfo{
NodeID: nodeID,
StateCode: internalpb.StateCode_Abnormal,
})
continue
}
subComponentInfos = append(subComponentInfos, componentStates.State)
}
return &internalpb.ComponentStates{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
State: serviceComponentInfo,
SubcomponentStates: subComponentInfos,
}, nil
}
func (qs *QueryService) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return &milvuspb.StringResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
Value: Params.TimeTickChannelName,
}, nil
}
func (qs *QueryService) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return &milvuspb.StringResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
Value: Params.StatsChannelName,
}, nil
}
func (qs *QueryService) RegisterNode(ctx context.Context, req *querypb.RegisterNodeRequest) (*querypb.RegisterNodeResponse, error) {
log.Debug("register query node", zap.String("address", req.Address.String()))
// TODO:: add mutex
nodeID := req.Base.SourceID
if _, ok := qs.queryNodes[nodeID]; ok {
err := errors.New("nodeID already exists")
return &querypb.RegisterNodeResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: err.Error(),
},
}, err
}
registerNodeAddress := req.Address.Ip + ":" + strconv.FormatInt(req.Address.Port, 10)
client := nodeclient.NewClient(registerNodeAddress)
if err := client.Init(); err != nil {
return &querypb.RegisterNodeResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
InitParams: new(internalpb.InitParams),
}, err
}
if err := client.Start(); err != nil {
return nil, err
}
qs.queryNodes[nodeID] = newQueryNodeInfo(client)
//TODO::return init params to queryNode
startParams := []*commonpb.KeyValuePair{
{Key: "StatsChannelName", Value: Params.StatsChannelName},
{Key: "TimeTickChannelName", Value: Params.TimeTickChannelName},
}
qs.qcMutex.Lock()
for _, queryChannel := range qs.queryChannels {
startParams = append(startParams, &commonpb.KeyValuePair{
Key: "QueryChannelName",
Value: queryChannel.requestChannel,
})
startParams = append(startParams, &commonpb.KeyValuePair{
Key: "QueryResultChannelName",
Value: queryChannel.responseChannel,
})
}
qs.qcMutex.Unlock()
return &querypb.RegisterNodeResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
InitParams: &internalpb.InitParams{
NodeID: nodeID,
StartParams: startParams,
},
}, nil
}
func (qs *QueryService) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
dbID := req.DbID
log.Debug("show collection start, dbID = ", zap.String("dbID", strconv.FormatInt(dbID, 10)))
collections, err := qs.replica.getCollections(dbID)
collectionIDs := make([]UniqueID, 0)
for _, collection := range collections {
collectionIDs = append(collectionIDs, collection.id)
}
if err != nil {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: err.Error(),
},
}, err
}
log.Debug("show collection end")
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: collectionIDs,
}, nil
}
func (qs *QueryService) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) {
log.Debug("LoadCollectionRequest received", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", req.CollectionID),
zap.Stringer("schema", req.Schema))
dbID := req.DbID
collectionID := req.CollectionID
schema := req.Schema
fn := func(err error) *commonpb.Status {
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}
}
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
}
log.Debug("load collection start", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", collectionID))
_, err := qs.replica.getCollectionByID(dbID, collectionID)
if err != nil {
err = qs.replica.addCollection(dbID, collectionID, schema)
if err != nil {
return fn(err), err
}
}
// get partitionIDs
showPartitionRequest := &milvuspb.ShowPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowPartitions,
MsgID: req.Base.MsgID,
},
CollectionID: collectionID,
}
showPartitionResponse, err := qs.masterServiceClient.ShowPartitions(ctx, showPartitionRequest)
if err != nil {
return fn(err), fmt.Errorf("call master ShowPartitions: %s", err)
}
log.Debug("ShowPartitions returned from Master", zap.String("role", Params.RoleName), zap.Int64("msgID", showPartitionRequest.Base.MsgID))
if showPartitionResponse.Status.ErrorCode != commonpb.ErrorCode_Success {
return showPartitionResponse.Status, err
}
partitionIDs := showPartitionResponse.PartitionIDs
partitionIDsToLoad := make([]UniqueID, 0)
partitionsInReplica, err := qs.replica.getPartitions(dbID, collectionID)
if err != nil {
return fn(err), err
}
for _, id := range partitionIDs {
cached := false
for _, partition := range partitionsInReplica {
if id == partition.id {
cached = true
break
}
}
if !cached {
partitionIDsToLoad = append(partitionIDsToLoad, id)
}
}
if len(partitionIDsToLoad) == 0 {
log.Debug("LoadCollectionRequest completed", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.String("collectionID", fmt.Sprintln(collectionID)))
return &commonpb.Status{
Reason: "Partitions has been already loaded!",
ErrorCode: commonpb.ErrorCode_Success,
}, nil
}
loadPartitionsRequest := &querypb.LoadPartitionsRequest{
Base: req.Base,
DbID: dbID,
CollectionID: collectionID,
PartitionIDs: partitionIDsToLoad,
Schema: schema,
}
status, err := qs.LoadPartitions(ctx, loadPartitionsRequest)
if err != nil {
log.Error("LoadCollectionRequest failed", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Error(err))
return status, fmt.Errorf("load partitions: %s", err)
}
err = qs.watchDmChannels(ctx, dbID, collectionID)
if err != nil {
log.Error("LoadCollectionRequest failed", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Error(err))
return fn(err), err
}
log.Debug("LoadCollectionRequest completed", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID))
return status, nil
}
func (qs *QueryService) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
dbID := req.DbID
collectionID := req.CollectionID
log.Debug("release collection start", zap.String("collectionID", fmt.Sprintln(collectionID)))
_, err := qs.replica.getCollectionByID(dbID, collectionID)
if err != nil {
log.Error("release collection end, query service don't have the log of", zap.String("collectionID", fmt.Sprintln(collectionID)))
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil
}
for nodeID, node := range qs.queryNodes {
status, err := node.ReleaseCollection(ctx, req)
if err != nil {
log.Error("release collection end, node occur error", zap.String("nodeID", fmt.Sprintln(nodeID)))
return status, err
}
}
err = qs.replica.releaseCollection(dbID, collectionID)
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}, err
}
log.Debug("release collection end", zap.Int64("collectionID", collectionID))
//TODO:: queryNode cancel subscribe dmChannels
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil
}
func (qs *QueryService) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
dbID := req.DbID
collectionID := req.CollectionID
partitions, err := qs.replica.getPartitions(dbID, collectionID)
partitionIDs := make([]UniqueID, 0)
for _, partition := range partitions {
partitionIDs = append(partitionIDs, partition.id)
}
if err != nil {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, err
}
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
PartitionIDs: partitionIDs,
}, nil
}
func (qs *QueryService) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
//TODO::suggest different partitions have different dm channel
log.Debug("LoadPartitionRequest received", zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", req.CollectionID),
zap.Stringer("schema", req.Schema))
dbID := req.DbID
collectionID := req.CollectionID
partitionIDs := req.PartitionIDs
schema := req.Schema
fn := func(err error) *commonpb.Status {
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}
}
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
}
log.Debug("load partitions start", zap.String("partitionIDs", fmt.Sprintln(partitionIDs)))
if len(partitionIDs) == 0 {
err := errors.New("partitionIDs are empty")
return fn(err), err
}
watchNeeded := false
_, err := qs.replica.getCollectionByID(dbID, collectionID)
if err != nil {
err = qs.replica.addCollection(dbID, collectionID, schema)
if err != nil {
return fn(err), err
}
watchNeeded = true
}
for _, partitionID := range partitionIDs {
_, err = qs.replica.getPartitionByID(dbID, collectionID, partitionID)
if err == nil {
continue
}
err = qs.replica.addPartition(dbID, collectionID, partitionID)
if err != nil {
return fn(err), err
}
showSegmentRequest := &milvuspb.ShowSegmentsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowSegments,
},
CollectionID: collectionID,
PartitionID: partitionID,
}
showSegmentResponse, err := qs.masterServiceClient.ShowSegments(ctx, showSegmentRequest)
if err != nil {
return fn(err), err
}
segmentIDs := showSegmentResponse.SegmentIDs
if len(segmentIDs) == 0 {
loadSegmentRequest := &querypb.LoadSegmentsRequest{
CollectionID: collectionID,
PartitionID: partitionID,
Schema: schema,
}
for _, node := range qs.queryNodes {
_, err := node.LoadSegments(ctx, loadSegmentRequest)
if err != nil {
return fn(err), nil
}
}
}
qs.replica.updatePartitionState(dbID, collectionID, partitionID, querypb.PartitionState_PartialInMemory)
segmentStates := make(map[UniqueID]*datapb.SegmentStateInfo)
channel2segs := make(map[string][]UniqueID)
resp, err := qs.dataServiceClient.GetSegmentStates(ctx, &datapb.GetSegmentStatesRequest{
SegmentIDs: segmentIDs,
})
if err != nil {
return fn(err), err
}
log.Debug("getSegmentStates result ", zap.Any("segment states", resp.States), zap.Any("result status", resp.Status))
for _, state := range resp.States {
log.Debug("segment ", zap.String("state.SegmentID", fmt.Sprintln(state.SegmentID)), zap.String("state", fmt.Sprintln(state.StartPosition)))
segmentID := state.SegmentID
segmentStates[segmentID] = state
channelName := state.StartPosition.ChannelName
if _, ok := channel2segs[channelName]; !ok {
segments := make([]UniqueID, 0)
segments = append(segments, segmentID)
channel2segs[channelName] = segments
} else {
channel2segs[channelName] = append(channel2segs[channelName], segmentID)
}
}
excludeSegment := make([]UniqueID, 0)
for id, state := range segmentStates {
if state.State > commonpb.SegmentState_Growing {
excludeSegment = append(excludeSegment, id)
}
}
for channel, segmentIDs := range channel2segs {
sort.Slice(segmentIDs, func(i, j int) bool {
return segmentStates[segmentIDs[i]].StartPosition.Timestamp < segmentStates[segmentIDs[j]].StartPosition.Timestamp
})
toLoadSegmentIDs := make([]UniqueID, 0)
var watchedStartPos *internalpb.MsgPosition = nil
var startPosition *internalpb.MsgPosition = nil
for index, id := range segmentIDs {
if segmentStates[id].State <= commonpb.SegmentState_Growing {
if index > 0 {
pos := segmentStates[id].StartPosition
if len(pos.MsgID) == 0 {
watchedStartPos = startPosition
break
}
}
watchedStartPos = segmentStates[id].StartPosition
break
}
toLoadSegmentIDs = append(toLoadSegmentIDs, id)
watchedStartPos = segmentStates[id].EndPosition
startPosition = segmentStates[id].StartPosition
}
if watchedStartPos == nil {
watchedStartPos = &internalpb.MsgPosition{
ChannelName: channel,
}
}
err = qs.replica.addDmChannel(dbID, collectionID, channel, watchedStartPos)
if err != nil {
return fn(err), err
}
err = qs.replica.addExcludeSegmentIDs(dbID, collectionID, toLoadSegmentIDs)
if err != nil {
return fn(err), err
}
segment2Node := qs.shuffleSegmentsToQueryNode(toLoadSegmentIDs)
for nodeID, assignedSegmentIDs := range segment2Node {
loadSegmentRequest := &querypb.LoadSegmentsRequest{
// TODO: use unique id allocator to assign reqID
Base: &commonpb.MsgBase{
MsgID: rand.Int63n(10000000000),
},
CollectionID: collectionID,
PartitionID: partitionID,
SegmentIDs: assignedSegmentIDs,
Schema: schema,
}
queryNode := qs.queryNodes[nodeID]
status, err := queryNode.LoadSegments(ctx, loadSegmentRequest)
if err != nil {
return status, err
}
queryNode.AddSegments(assignedSegmentIDs, collectionID)
}
}
qs.replica.updatePartitionState(dbID, collectionID, partitionID, querypb.PartitionState_InMemory)
}
if watchNeeded {
err = qs.watchDmChannels(ctx, dbID, collectionID)
if err != nil {
log.Debug("LoadPartitionRequest completed", zap.Int64("msgID", req.Base.MsgID), zap.Int64s("partitionIDs", partitionIDs), zap.Error(err))
return fn(err), err
}
}
log.Debug("LoadPartitionRequest completed", zap.Int64("msgID", req.Base.MsgID), zap.Int64s("partitionIDs", partitionIDs))
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil
}
func (qs *QueryService) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
dbID := req.DbID
collectionID := req.CollectionID
partitionIDs := req.PartitionIDs
log.Debug("start release partitions start", zap.String("partitionIDs", fmt.Sprintln(partitionIDs)))
toReleasedPartitionID := make([]UniqueID, 0)
for _, partitionID := range partitionIDs {
_, err := qs.replica.getPartitionByID(dbID, collectionID, partitionID)
if err == nil {
toReleasedPartitionID = append(toReleasedPartitionID, partitionID)
}
}
req.PartitionIDs = toReleasedPartitionID
for _, node := range qs.queryNodes {
status, err := node.client.ReleasePartitions(ctx, req)
if err != nil {
return status, err
}
}
for _, partitionID := range toReleasedPartitionID {
err := qs.replica.releasePartition(dbID, collectionID, partitionID)
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}, err
}
}
log.Debug("start release partitions end", zap.String("partitionIDs", fmt.Sprintln(partitionIDs)))
//TODO:: queryNode cancel subscribe dmChannels
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil
}
func (qs *QueryService) CreateQueryChannel(ctx context.Context) (*querypb.CreateQueryChannelResponse, error) {
channelID := len(qs.queryChannels)
allocatedQueryChannel := "query-" + strconv.FormatInt(int64(channelID), 10)
allocatedQueryResultChannel := "queryResult-" + strconv.FormatInt(int64(channelID), 10)
qs.qcMutex.Lock()
qs.queryChannels = append(qs.queryChannels, &queryChannelInfo{
requestChannel: allocatedQueryChannel,
responseChannel: allocatedQueryResultChannel,
})
addQueryChannelsRequest := &querypb.AddQueryChannelRequest{
RequestChannelID: allocatedQueryChannel,
ResultChannelID: allocatedQueryResultChannel,
}
log.Debug("query service create query channel", zap.String("queryChannelName", allocatedQueryChannel))
for nodeID, node := range qs.queryNodes {
log.Debug("node watch query channel", zap.String("nodeID", fmt.Sprintln(nodeID)))
fn := func() error {
_, err := node.AddQueryChannel(ctx, addQueryChannelsRequest)
return err
}
err := retry.Retry(10, time.Millisecond*200, fn)
if err != nil {
qs.qcMutex.Unlock()
return &querypb.CreateQueryChannelResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, err
}
}
qs.qcMutex.Unlock()
return &querypb.CreateQueryChannelResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
RequestChannel: allocatedQueryChannel,
ResultChannel: allocatedQueryResultChannel,
}, nil
}
func (qs *QueryService) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) {
states, err := qs.replica.getPartitionStates(req.DbID, req.CollectionID, req.PartitionIDs)
if err != nil {
return &querypb.GetPartitionStatesResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
PartitionDescriptions: states,
}, err
}
return &querypb.GetPartitionStatesResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
PartitionDescriptions: states,
}, nil
}
func (qs *QueryService) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
segmentInfos := make([]*querypb.SegmentInfo, 0)
totalMemSize := int64(0)
for nodeID, node := range qs.queryNodes {
segmentInfo, err := node.client.GetSegmentInfo(ctx, req)
if err != nil {
return &querypb.GetSegmentInfoResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, err
}
segmentInfos = append(segmentInfos, segmentInfo.Infos...)
for _, info := range segmentInfo.Infos {
totalMemSize = totalMemSize + info.MemSize
}
log.Debug("getSegmentInfo", zap.Int64("nodeID", nodeID), zap.Int64("memory size", totalMemSize))
}
return &querypb.GetSegmentInfoResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Infos: segmentInfos,
}, nil
}
func NewQueryService(ctx context.Context, factory msgstream.Factory) (*QueryService, error) {
rand.Seed(time.Now().UnixNano())
nodes := make(map[int64]*queryNodeInfo)
queryChannels := make([]*queryChannelInfo, 0)
ctx1, cancel := context.WithCancel(ctx)
replica := newMetaReplica()
scheduler := NewTaskScheduler(ctx1)
service := &QueryService{
loopCtx: ctx1,
loopCancel: cancel,
replica: replica,
sched: scheduler,
queryNodes: nodes,
queryChannels: queryChannels,
qcMutex: &sync.Mutex{},
@ -690,173 +94,3 @@ func (qs *QueryService) SetMasterService(masterService types.MasterService) {
func (qs *QueryService) SetDataService(dataService types.DataService) {
qs.dataServiceClient = dataService
}
func (qs *QueryService) watchDmChannels(ctx context.Context, dbID UniqueID, collectionID UniqueID) error {
collection, err := qs.replica.getCollectionByID(0, collectionID)
if err != nil {
return err
}
channelRequest := datapb.GetInsertChannelsRequest{
DbID: dbID,
CollectionID: collectionID,
}
resp, err := qs.dataServiceClient.GetInsertChannels(ctx, &channelRequest)
if err != nil {
return err
}
if len(resp.Values) == 0 {
err = errors.New("haven't assign dm channel to collection")
return err
}
dmChannels := resp.Values
channelsWithoutPos := make([]string, 0)
for _, channel := range dmChannels {
findChannel := false
ChannelsWithPos := collection.dmChannels
for _, ch := range ChannelsWithPos {
if channel == ch {
findChannel = true
break
}
}
if !findChannel {
channelsWithoutPos = append(channelsWithoutPos, channel)
}
}
for _, ch := range channelsWithoutPos {
pos := &internalpb.MsgPosition{
ChannelName: ch,
}
err = qs.replica.addDmChannel(dbID, collectionID, ch, pos)
if err != nil {
return err
}
}
channels2NodeID := qs.shuffleChannelsToQueryNode(dmChannels)
for nodeID, channels := range channels2NodeID {
node := qs.queryNodes[nodeID]
watchDmChannelsInfo := make([]*querypb.WatchDmChannelInfo, 0)
for _, ch := range channels {
info := &querypb.WatchDmChannelInfo{
ChannelID: ch,
Pos: collection.dmChannels2Pos[ch],
ExcludedSegments: collection.excludeSegmentIds,
}
watchDmChannelsInfo = append(watchDmChannelsInfo, info)
}
request := &querypb.WatchDmChannelsRequest{
CollectionID: collectionID,
ChannelIDs: channels,
Infos: watchDmChannelsInfo,
}
_, err := node.WatchDmChannels(ctx, request)
if err != nil {
return err
}
node.AddDmChannels(channels, collectionID)
log.Debug("query node ", zap.String("nodeID", strconv.FormatInt(nodeID, 10)), zap.String("watch channels", fmt.Sprintln(channels)))
}
return nil
}
func (qs *QueryService) shuffleChannelsToQueryNode(dmChannels []string) map[int64][]string {
maxNumChannels := 0
for _, node := range qs.queryNodes {
numChannels := node.getNumChannels()
if numChannels > maxNumChannels {
maxNumChannels = numChannels
}
}
res := make(map[int64][]string)
offset := 0
loopAll := false
for {
lastOffset := offset
if !loopAll {
for id, node := range qs.queryNodes {
if node.getSegmentsLength() >= maxNumChannels {
continue
}
if _, ok := res[id]; !ok {
res[id] = make([]string, 0)
}
res[id] = append(res[id], dmChannels[offset])
offset++
if offset == len(dmChannels) {
return res
}
}
} else {
for id := range qs.queryNodes {
if _, ok := res[id]; !ok {
res[id] = make([]string, 0)
}
res[id] = append(res[id], dmChannels[offset])
offset++
if offset == len(dmChannels) {
return res
}
}
}
if lastOffset == offset {
loopAll = true
}
}
}
func (qs *QueryService) shuffleSegmentsToQueryNode(segmentIDs []UniqueID) map[int64][]UniqueID {
maxNumSegments := 0
for _, node := range qs.queryNodes {
numSegments := node.getNumSegments()
if numSegments > maxNumSegments {
maxNumSegments = numSegments
}
}
res := make(map[int64][]UniqueID)
for nodeID := range qs.queryNodes {
segments := make([]UniqueID, 0)
res[nodeID] = segments
}
if len(segmentIDs) == 0 {
return res
}
offset := 0
loopAll := false
for {
lastOffset := offset
if !loopAll {
for id, node := range qs.queryNodes {
if node.getSegmentsLength() >= maxNumSegments {
continue
}
if _, ok := res[id]; !ok {
res[id] = make([]UniqueID, 0)
}
res[id] = append(res[id], segmentIDs[offset])
offset++
if offset == len(segmentIDs) {
return res
}
}
} else {
for id := range qs.queryNodes {
if _, ok := res[id]; !ok {
res[id] = make([]UniqueID, 0)
}
res[id] = append(res[id], segmentIDs[offset])
offset++
if offset == len(segmentIDs) {
return res
}
}
}
if lastOffset == offset {
loopAll = true
}
}
}


@ -0,0 +1,757 @@
package queryservice
import (
"context"
"errors"
"fmt"
"math/rand"
"sort"
"strconv"
"go.uber.org/zap"
"github.com/zilliztech/milvus-distributed/internal/log"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/datapb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
"github.com/zilliztech/milvus-distributed/internal/proto/querypb"
"github.com/zilliztech/milvus-distributed/internal/types"
)
const (
LoadCollectionTaskName = "LoadCollectionTask"
LoadPartitionTaskName = "LoadPartitionTask"
ReleaseCollectionTaskName = "ReleaseCollection"
ReleasePartitionTaskName = "ReleasePartition"
)
type task interface {
TraceCtx() context.Context
ID() UniqueID // return ReqId
Name() string
Type() commonpb.MsgType
Timestamp() Timestamp
PreExecute(ctx context.Context) error
Execute(ctx context.Context) error
PostExecute(ctx context.Context) error
WaitToFinish() error
Notify(err error)
}
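The scheduler that consumes these tasks is expected to run each one through PreExecute, Execute, and PostExecute, and then call Notify so the blocked RPC handler wakes up. A hedged sketch of that per-task step (illustrative only; the actual loop lives in the TaskScheduler, shown only in part in the task queue file below):

// Hypothetical per-task scheduler step; not the actual TaskScheduler code.
func processTask(ctx context.Context, t task) {
	err := t.PreExecute(ctx)
	if err == nil {
		err = t.Execute(ctx)
	}
	if err == nil {
		err = t.PostExecute(ctx)
	}
	t.Notify(err) // unblocks the handler waiting in WaitToFinish
}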
type BaseTask struct {
Condition
ctx context.Context
cancel context.CancelFunc
result *commonpb.Status
}
func (bt *BaseTask) TraceCtx() context.Context {
return bt.ctx
}
type LoadCollectionTask struct {
BaseTask
*querypb.LoadCollectionRequest
masterService types.MasterService
dataService types.DataService
queryNodes map[int64]*queryNodeInfo
meta Replica
watchNeeded bool
}
func (lct *LoadCollectionTask) ID() UniqueID {
return lct.Base.MsgID
}
func (lct *LoadCollectionTask) Type() commonpb.MsgType {
return lct.Base.MsgType
}
func (lct *LoadCollectionTask) Timestamp() Timestamp {
return lct.Base.Timestamp
}
func (lct *LoadCollectionTask) Name() string {
return LoadCollectionTaskName
}
func (lct *LoadCollectionTask) PreExecute(ctx context.Context) error {
collectionID := lct.CollectionID
schema := lct.Schema
log.Debug("start do LoadCollectionTask",
zap.Int64("msgID", lct.ID()),
zap.Int64("collectionID", collectionID),
zap.Stringer("schema", schema))
return nil
}
func (lct *LoadCollectionTask) Execute(ctx context.Context) error {
dbID := lct.DbID
collectionID := lct.CollectionID
schema := lct.LoadCollectionRequest.Schema
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
}
// get partitionIDs
showPartitionRequest := &milvuspb.ShowPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowPartitions,
MsgID: lct.Base.MsgID,
},
CollectionID: collectionID,
}
showPartitionResponse, err := lct.masterService.ShowPartitions(ctx, showPartitionRequest)
if err != nil {
status.Reason = err.Error()
lct.result = status
return fmt.Errorf("call master ShowPartitions: %s", err)
}
log.Debug("ShowPartitions returned from Master", zap.String("role", Params.RoleName), zap.Int64("msgID", showPartitionRequest.Base.MsgID))
if showPartitionResponse.Status.ErrorCode != commonpb.ErrorCode_Success {
lct.result = showPartitionResponse.Status
return errors.New(showPartitionResponse.Status.Reason)
}
partitionIDs := showPartitionResponse.PartitionIDs
partitionIDsToLoad := make([]UniqueID, 0)
partitionsInReplica, err := lct.meta.getPartitions(dbID, collectionID)
if err != nil {
status.Reason = err.Error()
lct.result = status
return err
}
for _, id := range partitionIDs {
cached := false
for _, partition := range partitionsInReplica {
if id == partition.id {
cached = true
break
}
}
if !cached {
partitionIDsToLoad = append(partitionIDsToLoad, id)
}
}
if len(partitionIDsToLoad) == 0 {
log.Debug("load collection done", zap.String("role", Params.RoleName), zap.Int64("msgID", lct.ID()), zap.String("collectionID", fmt.Sprintln(collectionID)))
status.ErrorCode = commonpb.ErrorCode_Success
status.Reason = "Partitions has been already loaded!"
lct.result = status
return nil
}
loadPartitionsRequest := &querypb.LoadPartitionsRequest{
Base: lct.Base,
DbID: dbID,
CollectionID: collectionID,
PartitionIDs: partitionIDsToLoad,
Schema: schema,
}
status, _, err = LoadPartitionMetaCheck(lct.meta, loadPartitionsRequest)
if err != nil {
lct.result = status
return err
}
loadPartitionTask := &LoadPartitionTask{
BaseTask: BaseTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
},
LoadPartitionsRequest: loadPartitionsRequest,
masterService: lct.masterService,
dataService: lct.dataService,
queryNodes: lct.queryNodes,
meta: lct.meta,
watchNeeded: false,
}
err = loadPartitionTask.PreExecute(ctx)
if err != nil {
status.Reason = err.Error()
lct.result = status
return err
}
err = loadPartitionTask.Execute(ctx)
if err != nil {
status.Reason = err.Error()
lct.result = status
return err
}
log.Debug("LoadCollection execute done",
zap.Int64("msgID", lct.ID()),
zap.Int64("collectionID", collectionID),
zap.Stringer("schema", schema))
return nil
}
func (lct *LoadCollectionTask) PostExecute(ctx context.Context) error {
dbID := lct.DbID
collectionID := lct.CollectionID
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
if lct.watchNeeded {
err := watchDmChannels(ctx, lct.dataService, lct.queryNodes, lct.meta, dbID, collectionID, lct.Base)
if err != nil {
log.Debug("watchDmChannels failed", zap.Int64("msgID", lct.ID()), zap.Int64("collectionID", collectionID), zap.Error(err))
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
lct.result = status
return err
}
}
log.Debug("LoadCollectionTask postExecute done",
zap.Int64("msgID", lct.ID()),
zap.Int64("collectionID", collectionID))
lct.result = status
//lct.cancel()
return nil
}
type ReleaseCollectionTask struct {
BaseTask
*querypb.ReleaseCollectionRequest
queryNodes map[int64]*queryNodeInfo
meta Replica
}
func (rct *ReleaseCollectionTask) ID() UniqueID {
return rct.Base.MsgID
}
func (rct *ReleaseCollectionTask) Type() commonpb.MsgType {
return rct.Base.MsgType
}
func (rct *ReleaseCollectionTask) Timestamp() Timestamp {
return rct.Base.Timestamp
}
func (rct *ReleaseCollectionTask) Name() string {
return ReleaseCollectionTaskName
}
func (rct *ReleaseCollectionTask) PreExecute(ctx context.Context) error {
collectionID := rct.CollectionID
log.Debug("start do ReleaseCollectionTask",
zap.Int64("msgID", rct.ID()),
zap.Int64("collectionID", collectionID))
return nil
}
func (rct *ReleaseCollectionTask) Execute(ctx context.Context) error {
collectionID := rct.CollectionID
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
for nodeID, node := range rct.queryNodes {
_, err := node.ReleaseCollection(ctx, rct.ReleaseCollectionRequest)
if err != nil {
log.Error("release collection end, node occur error", zap.String("nodeID", fmt.Sprintln(nodeID)))
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
rct.result = status
return err
}
}
rct.result = status
log.Debug("ReleaseCollectionTask Execute done",
zap.Int64("msgID", rct.ID()),
zap.Int64("collectionID", collectionID))
return nil
}
func (rct *ReleaseCollectionTask) PostExecute(ctx context.Context) error {
dbID := rct.DbID
collectionID := rct.CollectionID
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
err := rct.meta.releaseCollection(dbID, collectionID)
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
rct.result = status
return err
}
rct.result = status
log.Debug("ReleaseCollectionTask postExecute done",
zap.Int64("msgID", rct.ID()),
zap.Int64("collectionID", collectionID))
return nil
}
type LoadPartitionTask struct {
BaseTask
*querypb.LoadPartitionsRequest
masterService types.MasterService
dataService types.DataService
queryNodes map[int64]*queryNodeInfo
meta Replica
watchNeeded bool
}
func (lpt *LoadPartitionTask) ID() UniqueID {
return lpt.Base.MsgID
}
func (lpt *LoadPartitionTask) Type() commonpb.MsgType {
return lpt.Base.MsgType
}
func (lpt *LoadPartitionTask) Timestamp() Timestamp {
return lpt.Base.Timestamp
}
func (lpt *LoadPartitionTask) Name() string {
return LoadPartitionTaskName
}
func (lpt *LoadPartitionTask) PreExecute(ctx context.Context) error {
collectionID := lpt.CollectionID
log.Debug("start do LoadPartitionTask",
zap.Int64("msgID", lpt.ID()),
zap.Int64("collectionID", collectionID))
return nil
}
func (lpt *LoadPartitionTask) Execute(ctx context.Context) error {
//TODO::suggest different partitions have different dm channel
dbID := lpt.DbID
collectionID := lpt.CollectionID
partitionIDs := lpt.PartitionIDs
schema := lpt.Schema
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
}
for _, partitionID := range partitionIDs {
showSegmentRequest := &milvuspb.ShowSegmentsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowSegments,
},
CollectionID: collectionID,
PartitionID: partitionID,
}
showSegmentResponse, err := lpt.masterService.ShowSegments(ctx, showSegmentRequest)
if err != nil {
status.Reason = err.Error()
lpt.result = status
return err
}
segmentIDs := showSegmentResponse.SegmentIDs
if len(segmentIDs) == 0 {
loadSegmentRequest := &querypb.LoadSegmentsRequest{
CollectionID: collectionID,
PartitionID: partitionID,
Schema: schema,
}
for _, node := range lpt.queryNodes {
_, err := node.LoadSegments(ctx, loadSegmentRequest)
if err != nil {
status.Reason = err.Error()
lpt.result = status
return err
}
}
}
lpt.meta.updatePartitionState(dbID, collectionID, partitionID, querypb.PartitionState_PartialInMemory)
segmentStates := make(map[UniqueID]*datapb.SegmentStateInfo)
channel2segs := make(map[string][]UniqueID)
resp, err := lpt.dataService.GetSegmentStates(ctx, &datapb.GetSegmentStatesRequest{
SegmentIDs: segmentIDs,
})
if err != nil {
status.Reason = err.Error()
lpt.result = status
return err
}
log.Debug("getSegmentStates result ", zap.Any("segment states", resp.States), zap.Any("result status", resp.Status))
for _, state := range resp.States {
log.Debug("segment ", zap.String("state.SegmentID", fmt.Sprintln(state.SegmentID)), zap.String("state", fmt.Sprintln(state.StartPosition)))
segmentID := state.SegmentID
segmentStates[segmentID] = state
channelName := state.StartPosition.ChannelName
if _, ok := channel2segs[channelName]; !ok {
segments := make([]UniqueID, 0)
segments = append(segments, segmentID)
channel2segs[channelName] = segments
} else {
channel2segs[channelName] = append(channel2segs[channelName], segmentID)
}
}
excludeSegment := make([]UniqueID, 0)
for id, state := range segmentStates {
if state.State > commonpb.SegmentState_Growing {
excludeSegment = append(excludeSegment, id)
}
}
for channel, segmentIDs := range channel2segs {
sort.Slice(segmentIDs, func(i, j int) bool {
return segmentStates[segmentIDs[i]].StartPosition.Timestamp < segmentStates[segmentIDs[j]].StartPosition.Timestamp
})
toLoadSegmentIDs := make([]UniqueID, 0)
var watchedStartPos *internalpb.MsgPosition = nil
var startPosition *internalpb.MsgPosition = nil
for index, id := range segmentIDs {
if segmentStates[id].State <= commonpb.SegmentState_Growing {
if index > 0 {
pos := segmentStates[id].StartPosition
if len(pos.MsgID) == 0 {
watchedStartPos = startPosition
break
}
}
watchedStartPos = segmentStates[id].StartPosition
break
}
toLoadSegmentIDs = append(toLoadSegmentIDs, id)
watchedStartPos = segmentStates[id].EndPosition
startPosition = segmentStates[id].StartPosition
}
if watchedStartPos == nil {
watchedStartPos = &internalpb.MsgPosition{
ChannelName: channel,
}
}
err = lpt.meta.addDmChannel(dbID, collectionID, channel, watchedStartPos)
if err != nil {
status.Reason = err.Error()
lpt.result = status
return err
}
err = lpt.meta.addExcludeSegmentIDs(dbID, collectionID, toLoadSegmentIDs)
if err != nil {
status.Reason = err.Error()
lpt.result = status
return err
}
segment2Node := shuffleSegmentsToQueryNode(toLoadSegmentIDs, lpt.queryNodes)
for nodeID, assignedSegmentIDs := range segment2Node {
loadSegmentRequest := &querypb.LoadSegmentsRequest{
// TODO: use unique id allocator to assign reqID
Base: &commonpb.MsgBase{
MsgID: rand.Int63n(10000000000),
},
CollectionID: collectionID,
PartitionID: partitionID,
SegmentIDs: assignedSegmentIDs,
Schema: schema,
}
queryNode := lpt.queryNodes[nodeID]
status, err := queryNode.LoadSegments(ctx, loadSegmentRequest)
if err != nil {
lpt.result = status
return err
}
queryNode.AddSegments(assignedSegmentIDs, collectionID)
}
}
lpt.meta.updatePartitionState(dbID, collectionID, partitionID, querypb.PartitionState_InMemory)
}
log.Debug("LoadPartitionTask Execute done",
zap.Int64("msgID", lpt.ID()),
zap.Int64("collectionID", collectionID),
zap.Int64s("partitionIDs", partitionIDs))
status.ErrorCode = commonpb.ErrorCode_Success
lpt.result = status
return nil
}
func (lpt *LoadPartitionTask) PostExecute(ctx context.Context) error {
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
dbID := lpt.DbID
collectionID := lpt.CollectionID
partitionIDs := lpt.PartitionIDs
if lpt.watchNeeded {
err := watchDmChannels(ctx, lpt.dataService, lpt.queryNodes, lpt.meta, dbID, collectionID, lpt.Base)
if err != nil {
log.Debug("watchDmChannels failed", zap.Int64("msgID", lpt.ID()), zap.Int64s("partitionIDs", partitionIDs), zap.Error(err))
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
lpt.result = status
return err
}
}
log.Debug("watchDmChannels completed", zap.Int64("msgID", lpt.ID()), zap.Int64s("partitionIDs", partitionIDs))
lpt.result = status
log.Debug("LoadPartitionTask postExecute done",
zap.Int64("msgID", lpt.ID()),
zap.Int64("collectionID", collectionID))
//lpt.cancel()
return nil
}
type ReleasePartitionTask struct {
BaseTask
*querypb.ReleasePartitionsRequest
queryNodes map[int64]*queryNodeInfo
meta Replica
}
func (rpt *ReleasePartitionTask) ID() UniqueID {
return rpt.Base.MsgID
}
func (rpt *ReleasePartitionTask) Type() commonpb.MsgType {
return rpt.Base.MsgType
}
func (rpt *ReleasePartitionTask) Timestamp() Timestamp {
return rpt.Base.Timestamp
}
func (rpt *ReleasePartitionTask) Name() string {
return ReleasePartitionTaskName
}
func (rpt *ReleasePartitionTask) PreExecute(ctx context.Context) error {
collectionID := rpt.CollectionID
log.Debug("start do releasePartitionTask",
zap.Int64("msgID", rpt.ID()),
zap.Int64("collectionID", collectionID))
return nil
}
func (rpt *ReleasePartitionTask) Execute(ctx context.Context) error {
collectionID := rpt.CollectionID
partitionIDs := rpt.PartitionIDs
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
for _, node := range rpt.queryNodes {
status, err := node.client.ReleasePartitions(ctx, rpt.ReleasePartitionsRequest)
if err != nil {
rpt.result = status
return err
}
}
rpt.result = status
log.Debug("ReleasePartitionTask Execute done",
zap.Int64("msgID", rpt.ID()),
zap.Int64("collectionID", collectionID),
zap.Int64s("partitionIDs", partitionIDs))
return nil
}
func (rpt *ReleasePartitionTask) PostExecute(ctx context.Context) error {
dbID := rpt.DbID
collectionID := rpt.CollectionID
partitionIDs := rpt.PartitionIDs
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
for _, partitionID := range partitionIDs {
err := rpt.meta.releasePartition(dbID, collectionID, partitionID)
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
rpt.result = status
return err
}
}
rpt.result = status
log.Debug("ReleasePartitionTask postExecute done",
zap.Int64("msgID", rpt.ID()),
zap.Int64("collectionID", collectionID))
return nil
}
func watchDmChannels(ctx context.Context,
dataService types.DataService,
queryNodes map[int64]*queryNodeInfo,
meta Replica,
dbID UniqueID,
collectionID UniqueID,
msgBase *commonpb.MsgBase) error {
collection, err := meta.getCollectionByID(0, collectionID)
if err != nil {
return err
}
channelRequest := datapb.GetInsertChannelsRequest{
DbID: dbID,
CollectionID: collectionID,
}
resp, err := dataService.GetInsertChannels(ctx, &channelRequest)
if err != nil {
return err
}
if len(resp.Values) == 0 {
err = errors.New("haven't assign dm channel to collection")
return err
}
dmChannels := resp.Values
channelsWithoutPos := make([]string, 0)
for _, channel := range dmChannels {
findChannel := false
ChannelsWithPos := collection.dmChannels
for _, ch := range ChannelsWithPos {
if channel == ch {
findChannel = true
break
}
}
if !findChannel {
channelsWithoutPos = append(channelsWithoutPos, channel)
}
}
for _, ch := range channelsWithoutPos {
pos := &internalpb.MsgPosition{
ChannelName: ch,
}
err = meta.addDmChannel(dbID, collectionID, ch, pos)
if err != nil {
return err
}
}
channels2NodeID := shuffleChannelsToQueryNode(dmChannels, queryNodes)
for nodeID, channels := range channels2NodeID {
node := queryNodes[nodeID]
watchDmChannelsInfo := make([]*querypb.WatchDmChannelInfo, 0)
for _, ch := range channels {
info := &querypb.WatchDmChannelInfo{
ChannelID: ch,
Pos: collection.dmChannels2Pos[ch],
ExcludedSegments: collection.excludeSegmentIds,
}
watchDmChannelsInfo = append(watchDmChannelsInfo, info)
}
request := &querypb.WatchDmChannelsRequest{
Base: msgBase,
CollectionID: collectionID,
ChannelIDs: channels,
Infos: watchDmChannelsInfo,
}
_, err := node.WatchDmChannels(ctx, request)
if err != nil {
return err
}
node.AddDmChannels(channels, collectionID)
log.Debug("query node ", zap.String("nodeID", strconv.FormatInt(nodeID, 10)), zap.String("watch channels", fmt.Sprintln(channels)))
}
return nil
}
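
// shuffleChannelsToQueryNode spreads dm channels over the query nodes. The
// first pass skips nodes that already hold at least the current maximum number
// of channels; if a full pass assigns nothing, the loop falls back to plain
// round-robin over all nodes until every channel is placed.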
func shuffleChannelsToQueryNode(dmChannels []string, queryNodes map[int64]*queryNodeInfo) map[int64][]string {
maxNumChannels := 0
for _, node := range queryNodes {
numChannels := node.getNumChannels()
if numChannels > maxNumChannels {
maxNumChannels = numChannels
}
}
res := make(map[int64][]string)
offset := 0
loopAll := false
for {
lastOffset := offset
if !loopAll {
for id, node := range queryNodes {
			if node.getNumChannels() >= maxNumChannels {
continue
}
if _, ok := res[id]; !ok {
res[id] = make([]string, 0)
}
res[id] = append(res[id], dmChannels[offset])
offset++
if offset == len(dmChannels) {
return res
}
}
} else {
for id := range queryNodes {
if _, ok := res[id]; !ok {
res[id] = make([]string, 0)
}
res[id] = append(res[id], dmChannels[offset])
offset++
if offset == len(dmChannels) {
return res
}
}
}
if lastOffset == offset {
loopAll = true
}
}
}
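
// shuffleSegmentsToQueryNode distributes segment IDs with the same strategy:
// prefer nodes below the current maximum segment count, then fall back to
// round-robin once no under-loaded node remains.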
func shuffleSegmentsToQueryNode(segmentIDs []UniqueID, queryNodes map[int64]*queryNodeInfo) map[int64][]UniqueID {
maxNumSegments := 0
for _, node := range queryNodes {
numSegments := node.getNumSegments()
if numSegments > maxNumSegments {
maxNumSegments = numSegments
}
}
res := make(map[int64][]UniqueID)
for nodeID := range queryNodes {
segments := make([]UniqueID, 0)
res[nodeID] = segments
}
if len(segmentIDs) == 0 {
return res
}
offset := 0
loopAll := false
for {
lastOffset := offset
if !loopAll {
for id, node := range queryNodes {
			if node.getNumSegments() >= maxNumSegments {
continue
}
if _, ok := res[id]; !ok {
res[id] = make([]UniqueID, 0)
}
res[id] = append(res[id], segmentIDs[offset])
offset++
if offset == len(segmentIDs) {
return res
}
}
} else {
for id := range queryNodes {
if _, ok := res[id]; !ok {
res[id] = make([]UniqueID, 0)
}
res[id] = append(res[id], segmentIDs[offset])
offset++
if offset == len(segmentIDs) {
return res
}
}
}
if lastOffset == offset {
loopAll = true
}
}
}
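
// A hypothetical walkthrough: with two empty nodes the first pass skips both
// (each already holds the maximum, zero), so the fallback round-robin pass
// distributes segments [1, 2, 3, 4] evenly, e.g. {nodeA: [1, 3], nodeB: [2, 4]}
// (Go map iteration order makes the exact split nondeterministic).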

View File

@ -0,0 +1,252 @@
package queryservice
import (
	"container/list"
	"context"
	"errors"
	"sync"

	"github.com/opentracing/opentracing-go"
	oplog "github.com/opentracing/opentracing-go/log"
	"go.uber.org/zap"

	"github.com/zilliztech/milvus-distributed/internal/log"
	"github.com/zilliztech/milvus-distributed/internal/util/trace"
)
type TaskQueue interface {
utChan() <-chan int
utEmpty() bool
utFull() bool
addUnissuedTask(t task) error
FrontUnissuedTask() task
PopUnissuedTask() task
AddActiveTask(t task)
PopActiveTask(ts Timestamp) task
Enqueue(t task) error
}
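
// BaseTaskQueue keeps unissued tasks in a list ordered by task timestamp and
// active (executing) tasks in a map keyed by timestamp. utBufChan acts as a
// counting semaphore that wakes the scheduler whenever a task is enqueued.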
type BaseTaskQueue struct {
unissuedTasks *list.List
activeTasks map[Timestamp]task
utLock sync.Mutex
atLock sync.Mutex
maxTaskNum int64
utBufChan chan int // to block scheduler
sched *TaskScheduler
}
func (queue *BaseTaskQueue) utChan() <-chan int {
return queue.utBufChan
}
func (queue *BaseTaskQueue) utEmpty() bool {
queue.utLock.Lock()
defer queue.utLock.Unlock()
return queue.unissuedTasks.Len() == 0
}
func (queue *BaseTaskQueue) utFull() bool {
return int64(queue.unissuedTasks.Len()) >= queue.maxTaskNum
}
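
// addUnissuedTask inserts t while keeping the unissued list sorted by
// timestamp: append when t is newest, otherwise insert before the first task
// whose timestamp is not smaller than t's.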
func (queue *BaseTaskQueue) addUnissuedTask(t task) error {
queue.utLock.Lock()
defer queue.utLock.Unlock()
if queue.utFull() {
return errors.New("task queue is full")
}
if queue.unissuedTasks.Len() <= 0 {
queue.unissuedTasks.PushBack(t)
queue.utBufChan <- 1
return nil
}
if t.Timestamp() >= queue.unissuedTasks.Back().Value.(task).Timestamp() {
queue.unissuedTasks.PushBack(t)
queue.utBufChan <- 1
return nil
}
for e := queue.unissuedTasks.Front(); e != nil; e = e.Next() {
if t.Timestamp() <= e.Value.(task).Timestamp() {
queue.unissuedTasks.InsertBefore(t, e)
queue.utBufChan <- 1
return nil
}
}
return errors.New("unexpected error in addUnissuedTask")
}
func (queue *BaseTaskQueue) FrontUnissuedTask() task {
queue.utLock.Lock()
defer queue.utLock.Unlock()
if queue.unissuedTasks.Len() <= 0 {
log.Warn("sorry, but the unissued task list is empty!")
return nil
}
return queue.unissuedTasks.Front().Value.(task)
}
func (queue *BaseTaskQueue) PopUnissuedTask() task {
queue.utLock.Lock()
defer queue.utLock.Unlock()
if queue.unissuedTasks.Len() <= 0 {
log.Warn("sorry, but the unissued task list is empty!")
return nil
}
ft := queue.unissuedTasks.Front()
queue.unissuedTasks.Remove(ft)
return ft.Value.(task)
}
func (queue *BaseTaskQueue) AddActiveTask(t task) {
queue.atLock.Lock()
defer queue.atLock.Unlock()
ts := t.Timestamp()
_, ok := queue.activeTasks[ts]
if ok {
log.Debug("queryService", zap.Uint64("task with timestamp ts already in active task list! ts:", ts))
}
queue.activeTasks[ts] = t
}
func (queue *BaseTaskQueue) PopActiveTask(ts Timestamp) task {
queue.atLock.Lock()
defer queue.atLock.Unlock()
t, ok := queue.activeTasks[ts]
if ok {
log.Debug("queryService", zap.Uint64("task with timestamp ts has been deleted in active task list! ts:", ts))
delete(queue.activeTasks, ts)
return t
}
return nil
}
func (queue *BaseTaskQueue) Enqueue(t task) error {
return queue.addUnissuedTask(t)
}
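
// DdTaskQueue serializes Enqueue calls with an extra mutex so that DDL-style
// tasks enter the queue one at a time.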
type DdTaskQueue struct {
BaseTaskQueue
lock sync.Mutex
}
func (queue *DdTaskQueue) Enqueue(t task) error {
queue.lock.Lock()
defer queue.lock.Unlock()
return queue.BaseTaskQueue.Enqueue(t)
}
func NewDdTaskQueue(sched *TaskScheduler) *DdTaskQueue {
return &DdTaskQueue{
BaseTaskQueue: BaseTaskQueue{
unissuedTasks: list.New(),
activeTasks: make(map[Timestamp]task),
maxTaskNum: 1024,
utBufChan: make(chan int, 1024),
sched: sched,
},
}
}
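
// TaskScheduler owns the DDL task queue and the goroutine that drains it;
// Close cancels the scheduler's context and waits for the loop to exit.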
type TaskScheduler struct {
DdQueue TaskQueue
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
}
func NewTaskScheduler(ctx context.Context) *TaskScheduler {
ctx1, cancel := context.WithCancel(ctx)
s := &TaskScheduler{
ctx: ctx1,
cancel: cancel,
}
s.DdQueue = NewDdTaskQueue(s)
return s
}
func (sched *TaskScheduler) scheduleDdTask() task {
return sched.DdQueue.PopUnissuedTask()
}
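
// processTask runs a task through PreExecute/Execute/PostExecute inside a
// tracing span. The deferred Notify delivers the final value of err to the
// caller blocked in WaitToFinish, including an error from PostExecute.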
func (sched *TaskScheduler) processTask(t task, q TaskQueue) {
span, ctx := trace.StartSpanFromContext(t.TraceCtx(),
opentracing.Tags{
"Type": t.Name(),
"ID": t.ID(),
})
defer span.Finish()
span.LogFields(oplog.Int64("scheduler process PreExecute", t.ID()))
err := t.PreExecute(ctx)
defer func() {
t.Notify(err)
}()
if err != nil {
log.Debug("preExecute err", zap.String("reason", err.Error()))
trace.LogError(span, err)
return
}
span.LogFields(oplog.Int64("scheduler process AddActiveTask", t.ID()))
q.AddActiveTask(t)
defer func() {
span.LogFields(oplog.Int64("scheduler process PopActiveTask", t.ID()))
q.PopActiveTask(t.Timestamp())
}()
span.LogFields(oplog.Int64("scheduler process Execute", t.ID()))
err = t.Execute(ctx)
if err != nil {
log.Debug("execute err", zap.String("reason", err.Error()))
trace.LogError(span, err)
return
}
span.LogFields(oplog.Int64("scheduler process PostExecute", t.ID()))
err = t.PostExecute(ctx)
}
func (sched *TaskScheduler) definitionLoop() {
defer sched.wg.Done()
for {
select {
case <-sched.ctx.Done():
return
case <-sched.DdQueue.utChan():
if !sched.DdQueue.utEmpty() {
t := sched.scheduleDdTask()
sched.processTask(t, sched.DdQueue)
}
}
}
}
func (sched *TaskScheduler) Start() error {
sched.wg.Add(1)
go sched.definitionLoop()
return nil
}
func (sched *TaskScheduler) Close() {
sched.cancel()
sched.wg.Wait()
}
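
// A minimal usage sketch (hypothetical task value; error handling elided):
//
//	sched := NewTaskScheduler(context.Background())
//	_ = sched.Start()
//	t := newSomeDdTask(ctx)              // any type implementing task
//	if err := sched.DdQueue.Enqueue(t); err != nil {
//		// queue is full
//	}
//	err := t.WaitToFinish()              // receives the value passed to Notify
//	sched.Close()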