mirror of https://github.com/milvus-io/milvus.git

Refactor load/release to async calls in query service

Signed-off-by: xige-16 <xi.ge@zilliz.com>
Branch: pull/4973/head^2
Parent: e4256a4400
Commit: 78155d3959
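This commit converts the query service's load and release handlers into asynchronous, scheduler-driven calls: each gRPC handler wraps the request into a task, enqueues it on the service's DdQueue, and blocks on a channel-backed condition until the task notifies a result or the bounded context expires. The standalone Go sketch below illustrates that enqueue-and-wait pattern in miniature; the names used here (loadTask, newTaskCondition, the queue channel) are illustrative stand-ins, not the Milvus types introduced in the hunks that follow (TaskCondition, BaseTask, qs.sched.DdQueue.Enqueue).

// Minimal sketch of the enqueue-and-wait pattern: a handler submits a task to a
// scheduler queue and waits on a channel-backed condition for the result.
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

type taskCondition struct {
	done chan error
	ctx  context.Context
}

func newTaskCondition(ctx context.Context) *taskCondition {
	return &taskCondition{done: make(chan error, 1), ctx: ctx}
}

func (tc *taskCondition) Notify(err error) { tc.done <- err }

func (tc *taskCondition) WaitToFinish() error {
	select {
	case <-tc.ctx.Done():
		return errors.New("timeout")
	case err := <-tc.done:
		return err
	}
}

type loadTask struct {
	*taskCondition
	collectionID int64
}

// execute stands in for the real task's PreExecute/Execute/PostExecute pipeline.
func (t *loadTask) execute() error {
	time.Sleep(10 * time.Millisecond) // pretend to load segments
	return nil
}

func main() {
	queue := make(chan *loadTask, 16)

	// Scheduler loop: run tasks one by one and notify their condition.
	go func() {
		for t := range queue {
			t.Notify(t.execute())
		}
	}()

	// Handler side: enqueue, then block until the task finishes or times out.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	task := &loadTask{taskCondition: newTaskCondition(ctx), collectionID: 100}
	queue <- task
	fmt.Println("load result:", task.WaitToFinish())
}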
@ -654,7 +654,7 @@ func (ms *MqTtMsgStream) Seek(mp *internalpb.MsgPosition) error {
	consumer, hasWatched = ms.consumers[seekChannel]

	if hasWatched {
-		return errors.New("the channel should has been subscribed")
+		return errors.New("the channel should has not been subscribed")
	}

	fn := func() error {
@ -0,0 +1,43 @@
package queryservice

import (
	"context"
	"errors"
)

type Condition interface {
	WaitToFinish() error
	Notify(err error)
	Ctx() context.Context
}

type TaskCondition struct {
	done chan error
	ctx  context.Context
}

func (tc *TaskCondition) WaitToFinish() error {
	for {
		select {
		case <-tc.ctx.Done():
			return errors.New("timeout")
		case err := <-tc.done:
			return err
		}
	}
}

func (tc *TaskCondition) Notify(err error) {
	tc.done <- err
}

func (tc *TaskCondition) Ctx() context.Context {
	return tc.ctx
}

func NewTaskCondition(ctx context.Context) *TaskCondition {
	return &TaskCondition{
		done: make(chan error, 1),
		ctx:  ctx,
	}
}
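The new Condition/TaskCondition type above is the synchronization primitive the refactored handlers rely on. As a hypothetical usage sketch (not part of the commit), a caller in this package would typically pair it with a bounded context so that WaitToFinish returns "timeout" if the scheduler never calls Notify:

package queryservice

import (
	"context"
	"time"
)

// exampleWaitWithTimeout shows the caller side: the scheduler (runTask) is
// expected to call Notify on the condition when the work completes.
func exampleWaitWithTimeout(runTask func(Condition)) error {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	cond := NewTaskCondition(ctx)
	go runTask(cond) // eventually calls cond.Notify(err)
	return cond.WaitToFinish()
}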
@ -0,0 +1,484 @@
|
|||
package queryservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
nodeclient "github.com/zilliztech/milvus-distributed/internal/distributed/querynode/client"
|
||||
"github.com/zilliztech/milvus-distributed/internal/log"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/querypb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/retry"
|
||||
)
|
||||
|
||||
func (qs *QueryService) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
|
||||
serviceComponentInfo := &internalpb.ComponentInfo{
|
||||
NodeID: Params.QueryServiceID,
|
||||
StateCode: qs.stateCode.Load().(internalpb.StateCode),
|
||||
}
|
||||
subComponentInfos := make([]*internalpb.ComponentInfo, 0)
|
||||
for nodeID, node := range qs.queryNodes {
|
||||
componentStates, err := node.GetComponentStates(ctx)
|
||||
if err != nil {
|
||||
subComponentInfos = append(subComponentInfos, &internalpb.ComponentInfo{
|
||||
NodeID: nodeID,
|
||||
StateCode: internalpb.StateCode_Abnormal,
|
||||
})
|
||||
continue
|
||||
}
|
||||
subComponentInfos = append(subComponentInfos, componentStates.State)
|
||||
}
|
||||
return &internalpb.ComponentStates{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
State: serviceComponentInfo,
|
||||
SubcomponentStates: subComponentInfos,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
|
||||
return &milvuspb.StringResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
},
|
||||
Value: Params.TimeTickChannelName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
|
||||
return &milvuspb.StringResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
},
|
||||
Value: Params.StatsChannelName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) RegisterNode(ctx context.Context, req *querypb.RegisterNodeRequest) (*querypb.RegisterNodeResponse, error) {
|
||||
log.Debug("register query node", zap.String("address", req.Address.String()))
|
||||
// TODO:: add mutex
|
||||
nodeID := req.Base.SourceID
|
||||
if _, ok := qs.queryNodes[nodeID]; ok {
|
||||
err := errors.New("nodeID already exists")
|
||||
return &querypb.RegisterNodeResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, err
|
||||
}
|
||||
|
||||
registerNodeAddress := req.Address.Ip + ":" + strconv.FormatInt(req.Address.Port, 10)
|
||||
client := nodeclient.NewClient(registerNodeAddress)
|
||||
if err := client.Init(); err != nil {
|
||||
return &querypb.RegisterNodeResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
},
|
||||
InitParams: new(internalpb.InitParams),
|
||||
}, err
|
||||
}
|
||||
if err := client.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qs.queryNodes[nodeID] = newQueryNodeInfo(client)
|
||||
|
||||
//TODO::return init params to queryNode
|
||||
startParams := []*commonpb.KeyValuePair{
|
||||
{Key: "StatsChannelName", Value: Params.StatsChannelName},
|
||||
{Key: "TimeTickChannelName", Value: Params.TimeTickChannelName},
|
||||
}
|
||||
qs.qcMutex.Lock()
|
||||
for _, queryChannel := range qs.queryChannels {
|
||||
startParams = append(startParams, &commonpb.KeyValuePair{
|
||||
Key: "QueryChannelName",
|
||||
Value: queryChannel.requestChannel,
|
||||
})
|
||||
startParams = append(startParams, &commonpb.KeyValuePair{
|
||||
Key: "QueryResultChannelName",
|
||||
Value: queryChannel.responseChannel,
|
||||
})
|
||||
}
|
||||
qs.qcMutex.Unlock()
|
||||
|
||||
return &querypb.RegisterNodeResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
InitParams: &internalpb.InitParams{
|
||||
NodeID: nodeID,
|
||||
StartParams: startParams,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||
dbID := req.DbID
|
||||
log.Debug("show collection start, dbID = ", zap.String("dbID", strconv.FormatInt(dbID, 10)))
|
||||
collections, err := qs.replica.getCollections(dbID)
|
||||
collectionIDs := make([]UniqueID, 0)
|
||||
for _, collection := range collections {
|
||||
collectionIDs = append(collectionIDs, collection.id)
|
||||
}
|
||||
if err != nil {
|
||||
return &querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, err
|
||||
}
|
||||
log.Debug("show collection end")
|
||||
return &querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
CollectionIDs: collectionIDs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) {
|
||||
dbID := req.DbID
|
||||
collectionID := req.CollectionID
|
||||
schema := req.Schema
|
||||
watchNeeded := false
|
||||
log.Debug("LoadCollectionRequest received", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", collectionID),
|
||||
zap.Stringer("schema", req.Schema))
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
}
|
||||
|
||||
_, err := qs.replica.getCollectionByID(dbID, collectionID)
|
||||
if err != nil {
|
||||
watchNeeded = true
|
||||
err = qs.replica.addCollection(dbID, collectionID, schema)
|
||||
if err != nil {
|
||||
status.Reason = err.Error()
|
||||
return status, err
|
||||
}
|
||||
}
|
||||
loadCtx, cancel := context.WithTimeout(qs.loopCtx, 30*time.Second)
|
||||
loadCollectionTask := &LoadCollectionTask{
|
||||
BaseTask: BaseTask{
|
||||
ctx: loadCtx,
|
||||
cancel: cancel,
|
||||
Condition: NewTaskCondition(loadCtx),
|
||||
},
|
||||
LoadCollectionRequest: req,
|
||||
masterService: qs.masterServiceClient,
|
||||
dataService: qs.dataServiceClient,
|
||||
queryNodes: qs.queryNodes,
|
||||
meta: qs.replica,
|
||||
watchNeeded: watchNeeded,
|
||||
}
|
||||
err = qs.sched.DdQueue.Enqueue(loadCollectionTask)
|
||||
if err != nil {
|
||||
status.Reason = err.Error()
|
||||
return status, err
|
||||
}
|
||||
|
||||
err = loadCollectionTask.WaitToFinish()
|
||||
if err != nil {
|
||||
status.Reason = err.Error()
|
||||
return status, err
|
||||
}
|
||||
|
||||
log.Debug("LoadCollectionRequest completed", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", collectionID))
|
||||
status.ErrorCode = commonpb.ErrorCode_Success
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
|
||||
dbID := req.DbID
|
||||
collectionID := req.CollectionID
|
||||
log.Debug("ReleaseCollectionRequest received", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", collectionID))
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}
|
||||
_, err := qs.replica.getCollectionByID(dbID, collectionID)
|
||||
if err != nil {
|
||||
log.Error("release collection end, query service don't have the log of", zap.String("collectionID", fmt.Sprintln(collectionID)))
|
||||
return status, nil
|
||||
}
|
||||
|
||||
releaseCollectionTask := &ReleaseCollectionTask{
|
||||
BaseTask: BaseTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
},
|
||||
ReleaseCollectionRequest: req,
|
||||
queryNodes: qs.queryNodes,
|
||||
meta: qs.replica,
|
||||
}
|
||||
err = qs.sched.DdQueue.Enqueue(releaseCollectionTask)
|
||||
if err != nil {
|
||||
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
status.Reason = err.Error()
|
||||
return status, err
|
||||
}
|
||||
err = releaseCollectionTask.WaitToFinish()
|
||||
if err != nil {
|
||||
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
status.Reason = err.Error()
|
||||
return status, err
|
||||
}
|
||||
|
||||
log.Debug("ReleaseCollectionRequest completed", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", collectionID))
|
||||
//TODO:: queryNode cancel subscribe dmChannels
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
|
||||
dbID := req.DbID
|
||||
collectionID := req.CollectionID
|
||||
partitions, err := qs.replica.getPartitions(dbID, collectionID)
|
||||
partitionIDs := make([]UniqueID, 0)
|
||||
for _, partition := range partitions {
|
||||
partitionIDs = append(partitionIDs, partition.id)
|
||||
}
|
||||
if err != nil {
|
||||
return &querypb.ShowPartitionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, err
|
||||
}
|
||||
return &querypb.ShowPartitionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
PartitionIDs: partitionIDs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
|
||||
//TODO::suggest different partitions have different dm channel
|
||||
collectionID := req.CollectionID
|
||||
partitionIDs := req.PartitionIDs
|
||||
log.Debug("LoadPartitionRequest received", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
|
||||
status, watchNeeded, err := LoadPartitionMetaCheck(qs.replica, req)
|
||||
if err != nil {
|
||||
return status, err
|
||||
}
|
||||
|
||||
releaseCtx, cancel := context.WithTimeout(qs.loopCtx, 30*time.Second)
|
||||
loadPartitionTask := &LoadPartitionTask{
|
||||
BaseTask: BaseTask{
|
||||
ctx: releaseCtx,
|
||||
cancel: cancel,
|
||||
Condition: NewTaskCondition(releaseCtx),
|
||||
},
|
||||
LoadPartitionsRequest: req,
|
||||
masterService: qs.masterServiceClient,
|
||||
dataService: qs.dataServiceClient,
|
||||
queryNodes: qs.queryNodes,
|
||||
meta: qs.replica,
|
||||
watchNeeded: watchNeeded,
|
||||
}
|
||||
err = qs.sched.DdQueue.Enqueue(loadPartitionTask)
|
||||
if err != nil {
|
||||
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
status.Reason = err.Error()
|
||||
return status, err
|
||||
}
|
||||
|
||||
err = loadPartitionTask.WaitToFinish()
|
||||
if err != nil {
|
||||
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
status.Reason = err.Error()
|
||||
return status, err
|
||||
}
|
||||
|
||||
log.Debug("LoadPartitionRequest completed", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", req.CollectionID))
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func LoadPartitionMetaCheck(meta Replica, req *querypb.LoadPartitionsRequest) (*commonpb.Status, bool, error) {
|
||||
dbID := req.DbID
|
||||
collectionID := req.CollectionID
|
||||
partitionIDs := req.PartitionIDs
|
||||
schema := req.Schema
|
||||
watchNeeded := false
|
||||
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
}
|
||||
|
||||
if len(partitionIDs) == 0 {
|
||||
err := errors.New("partitionIDs are empty")
|
||||
status.Reason = err.Error()
|
||||
return status, watchNeeded, err
|
||||
}
|
||||
|
||||
_, err := meta.getCollectionByID(dbID, collectionID)
|
||||
if err != nil {
|
||||
err = meta.addCollection(dbID, collectionID, schema)
|
||||
if err != nil {
|
||||
status.Reason = err.Error()
|
||||
return status, watchNeeded, err
|
||||
}
|
||||
watchNeeded = true
|
||||
}
|
||||
|
||||
for _, partitionID := range partitionIDs {
|
||||
_, err = meta.getPartitionByID(dbID, collectionID, partitionID)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
err = meta.addPartition(dbID, collectionID, partitionID)
|
||||
if err != nil {
|
||||
status.Reason = err.Error()
|
||||
return status, watchNeeded, err
|
||||
}
|
||||
}
|
||||
|
||||
status.ErrorCode = commonpb.ErrorCode_Success
|
||||
return status, watchNeeded, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
|
||||
dbID := req.DbID
|
||||
collectionID := req.CollectionID
|
||||
partitionIDs := req.PartitionIDs
|
||||
log.Debug("ReleasePartitionRequest received", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", req.CollectionID), zap.Int64s("partitionIDs", partitionIDs))
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}
|
||||
toReleasedPartitionID := make([]UniqueID, 0)
|
||||
for _, partitionID := range partitionIDs {
|
||||
_, err := qs.replica.getPartitionByID(dbID, collectionID, partitionID)
|
||||
if err == nil {
|
||||
toReleasedPartitionID = append(toReleasedPartitionID, partitionID)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toReleasedPartitionID) > 0 {
|
||||
req.PartitionIDs = toReleasedPartitionID
|
||||
releasePartitionTask := &ReleasePartitionTask{
|
||||
BaseTask: BaseTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
},
|
||||
ReleasePartitionsRequest: req,
|
||||
queryNodes: qs.queryNodes,
|
||||
meta: qs.replica,
|
||||
}
|
||||
err := qs.sched.DdQueue.Enqueue(releasePartitionTask)
|
||||
if err != nil {
|
||||
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
status.Reason = err.Error()
|
||||
return status, err
|
||||
}
|
||||
|
||||
err = releasePartitionTask.WaitToFinish()
|
||||
if err != nil {
|
||||
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
status.Reason = err.Error()
|
||||
return status, err
|
||||
}
|
||||
}
|
||||
log.Debug("ReleasePartitionRequest completed", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
|
||||
//TODO:: queryNode cancel subscribe dmChannels
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) CreateQueryChannel(ctx context.Context) (*querypb.CreateQueryChannelResponse, error) {
|
||||
channelID := len(qs.queryChannels)
|
||||
allocatedQueryChannel := "query-" + strconv.FormatInt(int64(channelID), 10)
|
||||
allocatedQueryResultChannel := "queryResult-" + strconv.FormatInt(int64(channelID), 10)
|
||||
|
||||
qs.qcMutex.Lock()
|
||||
qs.queryChannels = append(qs.queryChannels, &queryChannelInfo{
|
||||
requestChannel: allocatedQueryChannel,
|
||||
responseChannel: allocatedQueryResultChannel,
|
||||
})
|
||||
|
||||
addQueryChannelsRequest := &querypb.AddQueryChannelRequest{
|
||||
RequestChannelID: allocatedQueryChannel,
|
||||
ResultChannelID: allocatedQueryResultChannel,
|
||||
}
|
||||
log.Debug("query service create query channel", zap.String("queryChannelName", allocatedQueryChannel))
|
||||
for nodeID, node := range qs.queryNodes {
|
||||
log.Debug("node watch query channel", zap.String("nodeID", fmt.Sprintln(nodeID)))
|
||||
fn := func() error {
|
||||
_, err := node.AddQueryChannel(ctx, addQueryChannelsRequest)
|
||||
return err
|
||||
}
|
||||
err := retry.Retry(10, time.Millisecond*200, fn)
|
||||
if err != nil {
|
||||
qs.qcMutex.Unlock()
|
||||
return &querypb.CreateQueryChannelResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, err
|
||||
}
|
||||
}
|
||||
qs.qcMutex.Unlock()
|
||||
|
||||
return &querypb.CreateQueryChannelResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
RequestChannel: allocatedQueryChannel,
|
||||
ResultChannel: allocatedQueryResultChannel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) {
|
||||
states, err := qs.replica.getPartitionStates(req.DbID, req.CollectionID, req.PartitionIDs)
|
||||
if err != nil {
|
||||
return &querypb.GetPartitionStatesResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
PartitionDescriptions: states,
|
||||
}, err
|
||||
}
|
||||
return &querypb.GetPartitionStatesResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
PartitionDescriptions: states,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
|
||||
segmentInfos := make([]*querypb.SegmentInfo, 0)
|
||||
totalMemSize := int64(0)
|
||||
for nodeID, node := range qs.queryNodes {
|
||||
segmentInfo, err := node.client.GetSegmentInfo(ctx, req)
|
||||
if err != nil {
|
||||
return &querypb.GetSegmentInfoResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, err
|
||||
}
|
||||
segmentInfos = append(segmentInfos, segmentInfo.Infos...)
|
||||
for _, info := range segmentInfo.Infos {
|
||||
totalMemSize = totalMemSize + info.MemSize
|
||||
}
|
||||
log.Debug("getSegmentInfo", zap.Int64("nodeID", nodeID), zap.Int64("memory size", totalMemSize))
|
||||
}
|
||||
return &querypb.GetSegmentInfoResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
Infos: segmentInfos,
|
||||
}, nil
|
||||
}
|
|
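In CreateQueryChannel above, subscribing each registered query node to the newly allocated channel pair is wrapped in a closure and retried at a fixed interval via the internal retry.Retry(10, time.Millisecond*200, fn) helper. The standalone sketch below restates that fixed-interval retry pattern with a hand-rolled helper; retryFixed and the simulated RPC are illustrative stand-ins, not the repository's util/retry API.

package main

import (
	"errors"
	"fmt"
	"time"
)

// retryFixed calls fn up to attempts times, sleeping interval between failures.
func retryFixed(attempts int, interval time.Duration, fn func() error) error {
	var err error
	for i := 0; i < attempts; i++ {
		if err = fn(); err == nil {
			return nil
		}
		time.Sleep(interval)
	}
	return err
}

func main() {
	calls := 0
	// Simulate a node RPC that fails twice before succeeding.
	fn := func() error {
		calls++
		if calls < 3 {
			return errors.New("node not ready")
		}
		return nil
	}
	fmt.Println("err:", retryFixed(10, 200*time.Millisecond, fn), "calls:", calls)
}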
@ -2,29 +2,20 @@ package queryservice
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
nodeclient "github.com/zilliztech/milvus-distributed/internal/distributed/querynode/client"
|
||||
"github.com/zilliztech/milvus-distributed/internal/log"
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/datapb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/querypb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/types"
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/retry"
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
|
||||
)
|
||||
|
||||
type Timestamp = typeutil.Timestamp
|
||||
|
||||
type queryChannelInfo struct {
|
||||
requestChannel string
|
||||
responseChannel string
|
||||
|
@ -36,6 +27,7 @@ type QueryService struct {
|
|||
|
||||
queryServiceID uint64
|
||||
replica Replica
|
||||
sched *TaskScheduler
|
||||
|
||||
dataServiceClient types.DataService
|
||||
masterServiceClient types.MasterService
|
||||
|
@ -55,11 +47,15 @@ func (qs *QueryService) Init() error {
|
|||
}
|
||||
|
||||
func (qs *QueryService) Start() error {
|
||||
qs.sched.Start()
|
||||
log.Debug("start scheduler ...")
|
||||
qs.UpdateStateCode(internalpb.StateCode_Healthy)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) Stop() error {
|
||||
qs.sched.Close()
|
||||
log.Debug("close scheduler ...")
|
||||
qs.loopCancel()
|
||||
qs.UpdateStateCode(internalpb.StateCode_Abnormal)
|
||||
return nil
|
||||
|
@ -69,610 +65,18 @@ func (qs *QueryService) UpdateStateCode(code internalpb.StateCode) {
|
|||
qs.stateCode.Store(code)
|
||||
}
|
||||
|
||||
func (qs *QueryService) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
|
||||
serviceComponentInfo := &internalpb.ComponentInfo{
|
||||
NodeID: Params.QueryServiceID,
|
||||
StateCode: qs.stateCode.Load().(internalpb.StateCode),
|
||||
}
|
||||
subComponentInfos := make([]*internalpb.ComponentInfo, 0)
|
||||
for nodeID, node := range qs.queryNodes {
|
||||
componentStates, err := node.GetComponentStates(ctx)
|
||||
if err != nil {
|
||||
subComponentInfos = append(subComponentInfos, &internalpb.ComponentInfo{
|
||||
NodeID: nodeID,
|
||||
StateCode: internalpb.StateCode_Abnormal,
|
||||
})
|
||||
continue
|
||||
}
|
||||
subComponentInfos = append(subComponentInfos, componentStates.State)
|
||||
}
|
||||
return &internalpb.ComponentStates{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
State: serviceComponentInfo,
|
||||
SubcomponentStates: subComponentInfos,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
|
||||
return &milvuspb.StringResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
},
|
||||
Value: Params.TimeTickChannelName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
|
||||
return &milvuspb.StringResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
},
|
||||
Value: Params.StatsChannelName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) RegisterNode(ctx context.Context, req *querypb.RegisterNodeRequest) (*querypb.RegisterNodeResponse, error) {
|
||||
log.Debug("register query node", zap.String("address", req.Address.String()))
|
||||
// TODO:: add mutex
|
||||
nodeID := req.Base.SourceID
|
||||
if _, ok := qs.queryNodes[nodeID]; ok {
|
||||
err := errors.New("nodeID already exists")
|
||||
return &querypb.RegisterNodeResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, err
|
||||
}
|
||||
|
||||
registerNodeAddress := req.Address.Ip + ":" + strconv.FormatInt(req.Address.Port, 10)
|
||||
client := nodeclient.NewClient(registerNodeAddress)
|
||||
if err := client.Init(); err != nil {
|
||||
return &querypb.RegisterNodeResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
InitParams: new(internalpb.InitParams),
|
||||
}, err
|
||||
}
|
||||
if err := client.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qs.queryNodes[nodeID] = newQueryNodeInfo(client)
|
||||
|
||||
//TODO::return init params to queryNode
|
||||
startParams := []*commonpb.KeyValuePair{
|
||||
{Key: "StatsChannelName", Value: Params.StatsChannelName},
|
||||
{Key: "TimeTickChannelName", Value: Params.TimeTickChannelName},
|
||||
}
|
||||
qs.qcMutex.Lock()
|
||||
for _, queryChannel := range qs.queryChannels {
|
||||
startParams = append(startParams, &commonpb.KeyValuePair{
|
||||
Key: "QueryChannelName",
|
||||
Value: queryChannel.requestChannel,
|
||||
})
|
||||
startParams = append(startParams, &commonpb.KeyValuePair{
|
||||
Key: "QueryResultChannelName",
|
||||
Value: queryChannel.responseChannel,
|
||||
})
|
||||
}
|
||||
qs.qcMutex.Unlock()
|
||||
|
||||
return &querypb.RegisterNodeResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
InitParams: &internalpb.InitParams{
|
||||
NodeID: nodeID,
|
||||
StartParams: startParams,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||
dbID := req.DbID
|
||||
log.Debug("show collection start, dbID = ", zap.String("dbID", strconv.FormatInt(dbID, 10)))
|
||||
collections, err := qs.replica.getCollections(dbID)
|
||||
collectionIDs := make([]UniqueID, 0)
|
||||
for _, collection := range collections {
|
||||
collectionIDs = append(collectionIDs, collection.id)
|
||||
}
|
||||
if err != nil {
|
||||
return &querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, err
|
||||
}
|
||||
log.Debug("show collection end")
|
||||
return &querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
CollectionIDs: collectionIDs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) {
|
||||
log.Debug("LoadCollectionRequest received", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", req.CollectionID),
|
||||
zap.Stringer("schema", req.Schema))
|
||||
dbID := req.DbID
|
||||
collectionID := req.CollectionID
|
||||
schema := req.Schema
|
||||
fn := func(err error) *commonpb.Status {
|
||||
if err != nil {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
}
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("load collection start", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", collectionID))
|
||||
|
||||
_, err := qs.replica.getCollectionByID(dbID, collectionID)
|
||||
if err != nil {
|
||||
err = qs.replica.addCollection(dbID, collectionID, schema)
|
||||
if err != nil {
|
||||
return fn(err), err
|
||||
}
|
||||
}
|
||||
|
||||
// get partitionIDs
|
||||
showPartitionRequest := &milvuspb.ShowPartitionsRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_ShowPartitions,
|
||||
MsgID: req.Base.MsgID,
|
||||
},
|
||||
CollectionID: collectionID,
|
||||
}
|
||||
|
||||
showPartitionResponse, err := qs.masterServiceClient.ShowPartitions(ctx, showPartitionRequest)
|
||||
if err != nil {
|
||||
return fn(err), fmt.Errorf("call master ShowPartitions: %s", err)
|
||||
}
|
||||
log.Debug("ShowPartitions returned from Master", zap.String("role", Params.RoleName), zap.Int64("msgID", showPartitionRequest.Base.MsgID))
|
||||
if showPartitionResponse.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
return showPartitionResponse.Status, err
|
||||
}
|
||||
partitionIDs := showPartitionResponse.PartitionIDs
|
||||
|
||||
partitionIDsToLoad := make([]UniqueID, 0)
|
||||
partitionsInReplica, err := qs.replica.getPartitions(dbID, collectionID)
|
||||
if err != nil {
|
||||
return fn(err), err
|
||||
}
|
||||
for _, id := range partitionIDs {
|
||||
cached := false
|
||||
for _, partition := range partitionsInReplica {
|
||||
if id == partition.id {
|
||||
cached = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !cached {
|
||||
partitionIDsToLoad = append(partitionIDsToLoad, id)
|
||||
}
|
||||
}
|
||||
|
||||
if len(partitionIDsToLoad) == 0 {
|
||||
log.Debug("LoadCollectionRequest completed", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.String("collectionID", fmt.Sprintln(collectionID)))
|
||||
return &commonpb.Status{
|
||||
Reason: "Partitions has been already loaded!",
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}, nil
|
||||
}
|
||||
|
||||
loadPartitionsRequest := &querypb.LoadPartitionsRequest{
|
||||
Base: req.Base,
|
||||
DbID: dbID,
|
||||
CollectionID: collectionID,
|
||||
PartitionIDs: partitionIDsToLoad,
|
||||
Schema: schema,
|
||||
}
|
||||
|
||||
status, err := qs.LoadPartitions(ctx, loadPartitionsRequest)
|
||||
if err != nil {
|
||||
log.Error("LoadCollectionRequest failed", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Error(err))
|
||||
return status, fmt.Errorf("load partitions: %s", err)
|
||||
}
|
||||
|
||||
err = qs.watchDmChannels(ctx, dbID, collectionID)
|
||||
if err != nil {
|
||||
log.Error("LoadCollectionRequest failed", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID), zap.Error(err))
|
||||
return fn(err), err
|
||||
}
|
||||
|
||||
log.Debug("LoadCollectionRequest completed", zap.String("role", Params.RoleName), zap.Int64("msgID", req.Base.MsgID))
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
|
||||
dbID := req.DbID
|
||||
collectionID := req.CollectionID
|
||||
log.Debug("release collection start", zap.String("collectionID", fmt.Sprintln(collectionID)))
|
||||
_, err := qs.replica.getCollectionByID(dbID, collectionID)
|
||||
if err != nil {
|
||||
log.Error("release collection end, query service don't have the log of", zap.String("collectionID", fmt.Sprintln(collectionID)))
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}, nil
|
||||
}
|
||||
|
||||
for nodeID, node := range qs.queryNodes {
|
||||
status, err := node.ReleaseCollection(ctx, req)
|
||||
if err != nil {
|
||||
log.Error("release collection end, node occur error", zap.String("nodeID", fmt.Sprintln(nodeID)))
|
||||
return status, err
|
||||
}
|
||||
}
|
||||
|
||||
err = qs.replica.releaseCollection(dbID, collectionID)
|
||||
if err != nil {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}, err
|
||||
}
|
||||
|
||||
log.Debug("release collection end", zap.Int64("collectionID", collectionID))
|
||||
//TODO:: queryNode cancel subscribe dmChannels
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
|
||||
dbID := req.DbID
|
||||
collectionID := req.CollectionID
|
||||
partitions, err := qs.replica.getPartitions(dbID, collectionID)
|
||||
partitionIDs := make([]UniqueID, 0)
|
||||
for _, partition := range partitions {
|
||||
partitionIDs = append(partitionIDs, partition.id)
|
||||
}
|
||||
if err != nil {
|
||||
return &querypb.ShowPartitionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, err
|
||||
}
|
||||
return &querypb.ShowPartitionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
PartitionIDs: partitionIDs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
|
||||
//TODO::suggest different partitions have different dm channel
|
||||
log.Debug("LoadPartitionRequest received", zap.Int64("msgID", req.Base.MsgID), zap.Int64("collectionID", req.CollectionID),
|
||||
zap.Stringer("schema", req.Schema))
|
||||
dbID := req.DbID
|
||||
collectionID := req.CollectionID
|
||||
partitionIDs := req.PartitionIDs
|
||||
schema := req.Schema
|
||||
|
||||
fn := func(err error) *commonpb.Status {
|
||||
if err != nil {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
}
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}
|
||||
}
|
||||
log.Debug("load partitions start", zap.String("partitionIDs", fmt.Sprintln(partitionIDs)))
|
||||
|
||||
if len(partitionIDs) == 0 {
|
||||
err := errors.New("partitionIDs are empty")
|
||||
return fn(err), err
|
||||
}
|
||||
|
||||
watchNeeded := false
|
||||
_, err := qs.replica.getCollectionByID(dbID, collectionID)
|
||||
if err != nil {
|
||||
err = qs.replica.addCollection(dbID, collectionID, schema)
|
||||
if err != nil {
|
||||
return fn(err), err
|
||||
}
|
||||
watchNeeded = true
|
||||
}
|
||||
|
||||
for _, partitionID := range partitionIDs {
|
||||
_, err = qs.replica.getPartitionByID(dbID, collectionID, partitionID)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
err = qs.replica.addPartition(dbID, collectionID, partitionID)
|
||||
if err != nil {
|
||||
return fn(err), err
|
||||
}
|
||||
|
||||
showSegmentRequest := &milvuspb.ShowSegmentsRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_ShowSegments,
|
||||
},
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
}
|
||||
showSegmentResponse, err := qs.masterServiceClient.ShowSegments(ctx, showSegmentRequest)
|
||||
if err != nil {
|
||||
return fn(err), err
|
||||
}
|
||||
segmentIDs := showSegmentResponse.SegmentIDs
|
||||
if len(segmentIDs) == 0 {
|
||||
loadSegmentRequest := &querypb.LoadSegmentsRequest{
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
Schema: schema,
|
||||
}
|
||||
for _, node := range qs.queryNodes {
|
||||
_, err := node.LoadSegments(ctx, loadSegmentRequest)
|
||||
if err != nil {
|
||||
return fn(err), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
qs.replica.updatePartitionState(dbID, collectionID, partitionID, querypb.PartitionState_PartialInMemory)
|
||||
|
||||
segmentStates := make(map[UniqueID]*datapb.SegmentStateInfo)
|
||||
channel2segs := make(map[string][]UniqueID)
|
||||
resp, err := qs.dataServiceClient.GetSegmentStates(ctx, &datapb.GetSegmentStatesRequest{
|
||||
SegmentIDs: segmentIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return fn(err), err
|
||||
}
|
||||
log.Debug("getSegmentStates result ", zap.Any("segment states", resp.States), zap.Any("result status", resp.Status))
|
||||
for _, state := range resp.States {
|
||||
log.Debug("segment ", zap.String("state.SegmentID", fmt.Sprintln(state.SegmentID)), zap.String("state", fmt.Sprintln(state.StartPosition)))
|
||||
segmentID := state.SegmentID
|
||||
segmentStates[segmentID] = state
|
||||
channelName := state.StartPosition.ChannelName
|
||||
if _, ok := channel2segs[channelName]; !ok {
|
||||
segments := make([]UniqueID, 0)
|
||||
segments = append(segments, segmentID)
|
||||
channel2segs[channelName] = segments
|
||||
} else {
|
||||
channel2segs[channelName] = append(channel2segs[channelName], segmentID)
|
||||
}
|
||||
}
|
||||
|
||||
excludeSegment := make([]UniqueID, 0)
|
||||
for id, state := range segmentStates {
|
||||
if state.State > commonpb.SegmentState_Growing {
|
||||
excludeSegment = append(excludeSegment, id)
|
||||
}
|
||||
}
|
||||
for channel, segmentIDs := range channel2segs {
|
||||
sort.Slice(segmentIDs, func(i, j int) bool {
|
||||
return segmentStates[segmentIDs[i]].StartPosition.Timestamp < segmentStates[segmentIDs[j]].StartPosition.Timestamp
|
||||
})
|
||||
toLoadSegmentIDs := make([]UniqueID, 0)
|
||||
var watchedStartPos *internalpb.MsgPosition = nil
|
||||
var startPosition *internalpb.MsgPosition = nil
|
||||
for index, id := range segmentIDs {
|
||||
if segmentStates[id].State <= commonpb.SegmentState_Growing {
|
||||
if index > 0 {
|
||||
pos := segmentStates[id].StartPosition
|
||||
if len(pos.MsgID) == 0 {
|
||||
watchedStartPos = startPosition
|
||||
break
|
||||
}
|
||||
}
|
||||
watchedStartPos = segmentStates[id].StartPosition
|
||||
break
|
||||
}
|
||||
toLoadSegmentIDs = append(toLoadSegmentIDs, id)
|
||||
watchedStartPos = segmentStates[id].EndPosition
|
||||
startPosition = segmentStates[id].StartPosition
|
||||
}
|
||||
if watchedStartPos == nil {
|
||||
watchedStartPos = &internalpb.MsgPosition{
|
||||
ChannelName: channel,
|
||||
}
|
||||
}
|
||||
|
||||
err = qs.replica.addDmChannel(dbID, collectionID, channel, watchedStartPos)
|
||||
if err != nil {
|
||||
return fn(err), err
|
||||
}
|
||||
err = qs.replica.addExcludeSegmentIDs(dbID, collectionID, toLoadSegmentIDs)
|
||||
if err != nil {
|
||||
return fn(err), err
|
||||
}
|
||||
|
||||
segment2Node := qs.shuffleSegmentsToQueryNode(toLoadSegmentIDs)
|
||||
for nodeID, assignedSegmentIDs := range segment2Node {
|
||||
loadSegmentRequest := &querypb.LoadSegmentsRequest{
|
||||
// TODO: use unique id allocator to assign reqID
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgID: rand.Int63n(10000000000),
|
||||
},
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
SegmentIDs: assignedSegmentIDs,
|
||||
Schema: schema,
|
||||
}
|
||||
|
||||
queryNode := qs.queryNodes[nodeID]
|
||||
status, err := queryNode.LoadSegments(ctx, loadSegmentRequest)
|
||||
if err != nil {
|
||||
return status, err
|
||||
}
|
||||
queryNode.AddSegments(assignedSegmentIDs, collectionID)
|
||||
}
|
||||
}
|
||||
|
||||
qs.replica.updatePartitionState(dbID, collectionID, partitionID, querypb.PartitionState_InMemory)
|
||||
}
|
||||
|
||||
if watchNeeded {
|
||||
err = qs.watchDmChannels(ctx, dbID, collectionID)
|
||||
if err != nil {
|
||||
log.Debug("LoadPartitionRequest completed", zap.Int64("msgID", req.Base.MsgID), zap.Int64s("partitionIDs", partitionIDs), zap.Error(err))
|
||||
return fn(err), err
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("LoadPartitionRequest completed", zap.Int64("msgID", req.Base.MsgID), zap.Int64s("partitionIDs", partitionIDs))
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
|
||||
dbID := req.DbID
|
||||
collectionID := req.CollectionID
|
||||
partitionIDs := req.PartitionIDs
|
||||
log.Debug("start release partitions start", zap.String("partitionIDs", fmt.Sprintln(partitionIDs)))
|
||||
toReleasedPartitionID := make([]UniqueID, 0)
|
||||
for _, partitionID := range partitionIDs {
|
||||
_, err := qs.replica.getPartitionByID(dbID, collectionID, partitionID)
|
||||
if err == nil {
|
||||
toReleasedPartitionID = append(toReleasedPartitionID, partitionID)
|
||||
}
|
||||
}
|
||||
|
||||
req.PartitionIDs = toReleasedPartitionID
|
||||
|
||||
for _, node := range qs.queryNodes {
|
||||
status, err := node.client.ReleasePartitions(ctx, req)
|
||||
if err != nil {
|
||||
return status, err
|
||||
}
|
||||
}
|
||||
|
||||
for _, partitionID := range toReleasedPartitionID {
|
||||
err := qs.replica.releasePartition(dbID, collectionID, partitionID)
|
||||
if err != nil {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}, err
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("start release partitions end", zap.String("partitionIDs", fmt.Sprintln(partitionIDs)))
|
||||
//TODO:: queryNode cancel subscribe dmChannels
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) CreateQueryChannel(ctx context.Context) (*querypb.CreateQueryChannelResponse, error) {
|
||||
channelID := len(qs.queryChannels)
|
||||
allocatedQueryChannel := "query-" + strconv.FormatInt(int64(channelID), 10)
|
||||
allocatedQueryResultChannel := "queryResult-" + strconv.FormatInt(int64(channelID), 10)
|
||||
|
||||
qs.qcMutex.Lock()
|
||||
qs.queryChannels = append(qs.queryChannels, &queryChannelInfo{
|
||||
requestChannel: allocatedQueryChannel,
|
||||
responseChannel: allocatedQueryResultChannel,
|
||||
})
|
||||
|
||||
addQueryChannelsRequest := &querypb.AddQueryChannelRequest{
|
||||
RequestChannelID: allocatedQueryChannel,
|
||||
ResultChannelID: allocatedQueryResultChannel,
|
||||
}
|
||||
log.Debug("query service create query channel", zap.String("queryChannelName", allocatedQueryChannel))
|
||||
for nodeID, node := range qs.queryNodes {
|
||||
log.Debug("node watch query channel", zap.String("nodeID", fmt.Sprintln(nodeID)))
|
||||
fn := func() error {
|
||||
_, err := node.AddQueryChannel(ctx, addQueryChannelsRequest)
|
||||
return err
|
||||
}
|
||||
err := retry.Retry(10, time.Millisecond*200, fn)
|
||||
if err != nil {
|
||||
qs.qcMutex.Unlock()
|
||||
return &querypb.CreateQueryChannelResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, err
|
||||
}
|
||||
}
|
||||
qs.qcMutex.Unlock()
|
||||
|
||||
return &querypb.CreateQueryChannelResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
RequestChannel: allocatedQueryChannel,
|
||||
ResultChannel: allocatedQueryResultChannel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) {
|
||||
states, err := qs.replica.getPartitionStates(req.DbID, req.CollectionID, req.PartitionIDs)
|
||||
if err != nil {
|
||||
return &querypb.GetPartitionStatesResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
PartitionDescriptions: states,
|
||||
}, err
|
||||
}
|
||||
return &querypb.GetPartitionStatesResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
PartitionDescriptions: states,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
|
||||
segmentInfos := make([]*querypb.SegmentInfo, 0)
|
||||
totalMemSize := int64(0)
|
||||
for nodeID, node := range qs.queryNodes {
|
||||
segmentInfo, err := node.client.GetSegmentInfo(ctx, req)
|
||||
if err != nil {
|
||||
return &querypb.GetSegmentInfoResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, err
|
||||
}
|
||||
segmentInfos = append(segmentInfos, segmentInfo.Infos...)
|
||||
for _, info := range segmentInfo.Infos {
|
||||
totalMemSize = totalMemSize + info.MemSize
|
||||
}
|
||||
log.Debug("getSegmentInfo", zap.Int64("nodeID", nodeID), zap.Int64("memory size", totalMemSize))
|
||||
}
|
||||
return &querypb.GetSegmentInfoResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
Infos: segmentInfos,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewQueryService(ctx context.Context, factory msgstream.Factory) (*QueryService, error) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
nodes := make(map[int64]*queryNodeInfo)
|
||||
queryChannels := make([]*queryChannelInfo, 0)
|
||||
ctx1, cancel := context.WithCancel(ctx)
|
||||
replica := newMetaReplica()
|
||||
scheduler := NewTaskScheduler(ctx1)
|
||||
service := &QueryService{
|
||||
loopCtx: ctx1,
|
||||
loopCancel: cancel,
|
||||
replica: replica,
|
||||
sched: scheduler,
|
||||
queryNodes: nodes,
|
||||
queryChannels: queryChannels,
|
||||
qcMutex: &sync.Mutex{},
|
||||
|
@ -690,173 +94,3 @@ func (qs *QueryService) SetMasterService(masterService types.MasterService) {
|
|||
func (qs *QueryService) SetDataService(dataService types.DataService) {
|
||||
qs.dataServiceClient = dataService
|
||||
}
|
||||
|
||||
func (qs *QueryService) watchDmChannels(ctx context.Context, dbID UniqueID, collectionID UniqueID) error {
|
||||
collection, err := qs.replica.getCollectionByID(0, collectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
channelRequest := datapb.GetInsertChannelsRequest{
|
||||
DbID: dbID,
|
||||
CollectionID: collectionID,
|
||||
}
|
||||
resp, err := qs.dataServiceClient.GetInsertChannels(ctx, &channelRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.Values) == 0 {
|
||||
err = errors.New("haven't assign dm channel to collection")
|
||||
return err
|
||||
}
|
||||
|
||||
dmChannels := resp.Values
|
||||
channelsWithoutPos := make([]string, 0)
|
||||
for _, channel := range dmChannels {
|
||||
findChannel := false
|
||||
ChannelsWithPos := collection.dmChannels
|
||||
for _, ch := range ChannelsWithPos {
|
||||
if channel == ch {
|
||||
findChannel = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !findChannel {
|
||||
channelsWithoutPos = append(channelsWithoutPos, channel)
|
||||
}
|
||||
}
|
||||
for _, ch := range channelsWithoutPos {
|
||||
pos := &internalpb.MsgPosition{
|
||||
ChannelName: ch,
|
||||
}
|
||||
err = qs.replica.addDmChannel(dbID, collectionID, ch, pos)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
channels2NodeID := qs.shuffleChannelsToQueryNode(dmChannels)
|
||||
for nodeID, channels := range channels2NodeID {
|
||||
node := qs.queryNodes[nodeID]
|
||||
watchDmChannelsInfo := make([]*querypb.WatchDmChannelInfo, 0)
|
||||
for _, ch := range channels {
|
||||
info := &querypb.WatchDmChannelInfo{
|
||||
ChannelID: ch,
|
||||
Pos: collection.dmChannels2Pos[ch],
|
||||
ExcludedSegments: collection.excludeSegmentIds,
|
||||
}
|
||||
watchDmChannelsInfo = append(watchDmChannelsInfo, info)
|
||||
}
|
||||
request := &querypb.WatchDmChannelsRequest{
|
||||
CollectionID: collectionID,
|
||||
ChannelIDs: channels,
|
||||
Infos: watchDmChannelsInfo,
|
||||
}
|
||||
_, err := node.WatchDmChannels(ctx, request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
node.AddDmChannels(channels, collectionID)
|
||||
log.Debug("query node ", zap.String("nodeID", strconv.FormatInt(nodeID, 10)), zap.String("watch channels", fmt.Sprintln(channels)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (qs *QueryService) shuffleChannelsToQueryNode(dmChannels []string) map[int64][]string {
|
||||
maxNumChannels := 0
|
||||
for _, node := range qs.queryNodes {
|
||||
numChannels := node.getNumChannels()
|
||||
if numChannels > maxNumChannels {
|
||||
maxNumChannels = numChannels
|
||||
}
|
||||
}
|
||||
res := make(map[int64][]string)
|
||||
offset := 0
|
||||
loopAll := false
|
||||
for {
|
||||
lastOffset := offset
|
||||
if !loopAll {
|
||||
for id, node := range qs.queryNodes {
|
||||
if node.getSegmentsLength() >= maxNumChannels {
|
||||
continue
|
||||
}
|
||||
if _, ok := res[id]; !ok {
|
||||
res[id] = make([]string, 0)
|
||||
}
|
||||
res[id] = append(res[id], dmChannels[offset])
|
||||
offset++
|
||||
if offset == len(dmChannels) {
|
||||
return res
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for id := range qs.queryNodes {
|
||||
if _, ok := res[id]; !ok {
|
||||
res[id] = make([]string, 0)
|
||||
}
|
||||
res[id] = append(res[id], dmChannels[offset])
|
||||
offset++
|
||||
if offset == len(dmChannels) {
|
||||
return res
|
||||
}
|
||||
}
|
||||
}
|
||||
if lastOffset == offset {
|
||||
loopAll = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (qs *QueryService) shuffleSegmentsToQueryNode(segmentIDs []UniqueID) map[int64][]UniqueID {
|
||||
maxNumSegments := 0
|
||||
for _, node := range qs.queryNodes {
|
||||
numSegments := node.getNumSegments()
|
||||
if numSegments > maxNumSegments {
|
||||
maxNumSegments = numSegments
|
||||
}
|
||||
}
|
||||
res := make(map[int64][]UniqueID)
|
||||
for nodeID := range qs.queryNodes {
|
||||
segments := make([]UniqueID, 0)
|
||||
res[nodeID] = segments
|
||||
}
|
||||
|
||||
if len(segmentIDs) == 0 {
|
||||
return res
|
||||
}
|
||||
|
||||
offset := 0
|
||||
loopAll := false
|
||||
for {
|
||||
lastOffset := offset
|
||||
if !loopAll {
|
||||
for id, node := range qs.queryNodes {
|
||||
if node.getSegmentsLength() >= maxNumSegments {
|
||||
continue
|
||||
}
|
||||
if _, ok := res[id]; !ok {
|
||||
res[id] = make([]UniqueID, 0)
|
||||
}
|
||||
res[id] = append(res[id], segmentIDs[offset])
|
||||
offset++
|
||||
if offset == len(segmentIDs) {
|
||||
return res
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for id := range qs.queryNodes {
|
||||
if _, ok := res[id]; !ok {
|
||||
res[id] = make([]UniqueID, 0)
|
||||
}
|
||||
res[id] = append(res[id], segmentIDs[offset])
|
||||
offset++
|
||||
if offset == len(segmentIDs) {
|
||||
return res
|
||||
}
|
||||
}
|
||||
}
|
||||
if lastOffset == offset {
|
||||
loopAll = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
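The shuffleChannelsToQueryNode and shuffleSegmentsToQueryNode helpers above spread work across the registered query nodes: nodes below the current maximum load are filled first, and if a full pass assigns nothing, the remainder is distributed plain round-robin. The simplified sketch below restates that assignment loop with stand-in types; shuffleSegments and the load map are illustrative, not the repository's API.

package main

import "fmt"

// shuffleSegments assigns segmentIDs to nodes, skipping nodes already at the
// heaviest observed load; once a full pass assigns nothing, it falls back to
// plain round-robin over all nodes.
func shuffleSegments(segmentIDs []int64, load map[int64]int) map[int64][]int64 {
	res := make(map[int64][]int64)
	for nodeID := range load {
		res[nodeID] = nil
	}
	if len(segmentIDs) == 0 || len(load) == 0 {
		return res
	}
	maxLoad := 0
	for _, n := range load {
		if n > maxLoad {
			maxLoad = n
		}
	}
	offset, loopAll := 0, false
	for {
		last := offset
		for nodeID, n := range load {
			if !loopAll && n >= maxLoad {
				continue // node is already at the heaviest observed load
			}
			res[nodeID] = append(res[nodeID], segmentIDs[offset])
			offset++
			if offset == len(segmentIDs) {
				return res
			}
		}
		if last == offset {
			loopAll = true // every node is at max load: distribute round-robin
		}
	}
}

func main() {
	fmt.Println(shuffleSegments([]int64{1, 2, 3, 4, 5}, map[int64]int{100: 0, 101: 2}))
}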
@ -0,0 +1,757 @@
|
|||
package queryservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/log"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/datapb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/querypb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/types"
|
||||
)
|
||||
|
||||
const (
|
||||
LoadCollectionTaskName = "LoadCollectionTask"
|
||||
LoadPartitionTaskName = "LoadPartitionTask"
|
||||
ReleaseCollectionTaskName = "ReleaseCollection"
|
||||
ReleasePartitionTaskName = "ReleasePartition"
|
||||
)
|
||||
|
||||
type task interface {
|
||||
TraceCtx() context.Context
|
||||
ID() UniqueID // return ReqId
|
||||
Name() string
|
||||
Type() commonpb.MsgType
|
||||
Timestamp() Timestamp
|
||||
PreExecute(ctx context.Context) error
|
||||
Execute(ctx context.Context) error
|
||||
PostExecute(ctx context.Context) error
|
||||
WaitToFinish() error
|
||||
Notify(err error)
|
||||
}
|
||||
|
||||
type BaseTask struct {
|
||||
Condition
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
result *commonpb.Status
|
||||
}
|
||||
|
||||
func (bt *BaseTask) TraceCtx() context.Context {
|
||||
return bt.ctx
|
||||
}
|
||||
|
||||
type LoadCollectionTask struct {
|
||||
BaseTask
|
||||
*querypb.LoadCollectionRequest
|
||||
masterService types.MasterService
|
||||
dataService types.DataService
|
||||
queryNodes map[int64]*queryNodeInfo
|
||||
meta Replica
|
||||
watchNeeded bool
|
||||
}
|
||||
|
||||
func (lct *LoadCollectionTask) ID() UniqueID {
|
||||
return lct.Base.MsgID
|
||||
}
|
||||
|
||||
func (lct *LoadCollectionTask) Type() commonpb.MsgType {
|
||||
return lct.Base.MsgType
|
||||
}
|
||||
|
||||
func (lct *LoadCollectionTask) Timestamp() Timestamp {
|
||||
return lct.Base.Timestamp
|
||||
}
|
||||
|
||||
func (lct *LoadCollectionTask) Name() string {
|
||||
return LoadCollectionTaskName
|
||||
}
|
||||
|
||||
func (lct *LoadCollectionTask) PreExecute(ctx context.Context) error {
|
||||
collectionID := lct.CollectionID
|
||||
schema := lct.Schema
|
||||
log.Debug("start do LoadCollectionTask",
|
||||
zap.Int64("msgID", lct.ID()),
|
||||
zap.Int64("collectionID", collectionID),
|
||||
zap.Stringer("schema", schema))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lct *LoadCollectionTask) Execute(ctx context.Context) error {
|
||||
dbID := lct.DbID
|
||||
collectionID := lct.CollectionID
|
||||
schema := lct.LoadCollectionRequest.Schema
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
}
|
||||
|
||||
// get partitionIDs
|
||||
showPartitionRequest := &milvuspb.ShowPartitionsRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_ShowPartitions,
|
||||
MsgID: lct.Base.MsgID,
|
||||
},
|
||||
CollectionID: collectionID,
|
||||
}
|
||||
|
||||
showPartitionResponse, err := lct.masterService.ShowPartitions(ctx, showPartitionRequest)
|
||||
if err != nil {
|
||||
status.Reason = err.Error()
|
||||
lct.result = status
|
||||
return fmt.Errorf("call master ShowPartitions: %s", err)
|
||||
}
|
||||
log.Debug("ShowPartitions returned from Master", zap.String("role", Params.RoleName), zap.Int64("msgID", showPartitionRequest.Base.MsgID))
|
||||
if showPartitionResponse.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
lct.result = showPartitionResponse.Status
|
||||
return err
|
||||
}
|
||||
partitionIDs := showPartitionResponse.PartitionIDs
|
||||
|
||||
partitionIDsToLoad := make([]UniqueID, 0)
|
||||
partitionsInReplica, err := lct.meta.getPartitions(dbID, collectionID)
|
||||
if err != nil {
|
||||
status.Reason = err.Error()
|
||||
lct.result = status
|
||||
return err
|
||||
}
|
||||
for _, id := range partitionIDs {
|
||||
cached := false
|
||||
for _, partition := range partitionsInReplica {
|
||||
if id == partition.id {
|
||||
cached = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !cached {
|
||||
partitionIDsToLoad = append(partitionIDsToLoad, id)
|
||||
}
|
||||
}
|
||||
|
||||
if len(partitionIDsToLoad) == 0 {
|
||||
log.Debug("load collection done", zap.String("role", Params.RoleName), zap.Int64("msgID", lct.ID()), zap.String("collectionID", fmt.Sprintln(collectionID)))
|
||||
status.ErrorCode = commonpb.ErrorCode_Success
|
||||
status.Reason = "Partitions has been already loaded!"
|
||||
lct.result = status
|
||||
return nil
|
||||
}
|
||||
|
||||
loadPartitionsRequest := &querypb.LoadPartitionsRequest{
|
||||
Base: lct.Base,
|
||||
DbID: dbID,
|
||||
CollectionID: collectionID,
|
||||
PartitionIDs: partitionIDsToLoad,
|
||||
Schema: schema,
|
||||
}
|
||||
|
||||
status, _, err = LoadPartitionMetaCheck(lct.meta, loadPartitionsRequest)
|
||||
if err != nil {
|
||||
lct.result = status
|
||||
return err
|
||||
}
|
||||
|
||||
loadPartitionTask := &LoadPartitionTask{
|
||||
BaseTask: BaseTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
},
|
||||
LoadPartitionsRequest: loadPartitionsRequest,
|
||||
masterService: lct.masterService,
|
||||
dataService: lct.dataService,
|
||||
queryNodes: lct.queryNodes,
|
||||
meta: lct.meta,
|
||||
watchNeeded: false,
|
||||
}
|
||||
|
||||
err = loadPartitionTask.PreExecute(ctx)
|
||||
if err != nil {
|
||||
status.Reason = err.Error()
|
||||
lct.result = status
|
||||
return err
|
||||
}
|
||||
err = loadPartitionTask.Execute(ctx)
|
||||
if err != nil {
|
||||
status.Reason = err.Error()
|
||||
lct.result = status
|
||||
return err
|
||||
}
|
||||
log.Debug("LoadCollection execute done",
|
||||
zap.Int64("msgID", lct.ID()),
|
||||
zap.Int64("collectionID", collectionID),
|
||||
zap.Stringer("schema", schema))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lct *LoadCollectionTask) PostExecute(ctx context.Context) error {
|
||||
dbID := lct.DbID
|
||||
collectionID := lct.CollectionID
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}
|
||||
if lct.watchNeeded {
|
||||
err := watchDmChannels(ctx, lct.dataService, lct.queryNodes, lct.meta, dbID, collectionID, lct.Base)
|
||||
if err != nil {
|
||||
log.Debug("watchDmChannels failed", zap.Int64("msgID", lct.ID()), zap.Int64("collectionID", collectionID), zap.Error(err))
|
||||
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
status.Reason = err.Error()
|
||||
lct.result = status
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("LoadCollectionTask postExecute done",
|
||||
zap.Int64("msgID", lct.ID()),
|
||||
zap.Int64("collectionID", collectionID))
|
||||
lct.result = status
|
||||
//lct.cancel()
|
||||
return nil
|
||||
}

type ReleaseCollectionTask struct {
    BaseTask
    *querypb.ReleaseCollectionRequest
    queryNodes map[int64]*queryNodeInfo
    meta       Replica
}

func (rct *ReleaseCollectionTask) ID() UniqueID {
    return rct.Base.MsgID
}

func (rct *ReleaseCollectionTask) Type() commonpb.MsgType {
    return rct.Base.MsgType
}

func (rct *ReleaseCollectionTask) Timestamp() Timestamp {
    return rct.Base.Timestamp
}

func (rct *ReleaseCollectionTask) Name() string {
    return ReleaseCollectionTaskName
}

func (rct *ReleaseCollectionTask) PreExecute(ctx context.Context) error {
    collectionID := rct.CollectionID
    log.Debug("start do ReleaseCollectionTask",
        zap.Int64("msgID", rct.ID()),
        zap.Int64("collectionID", collectionID))
    return nil
}

func (rct *ReleaseCollectionTask) Execute(ctx context.Context) error {
    collectionID := rct.CollectionID
    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    for nodeID, node := range rct.queryNodes {
        _, err := node.ReleaseCollection(ctx, rct.ReleaseCollectionRequest)
        if err != nil {
            log.Error("release collection failed on query node", zap.Int64("nodeID", nodeID), zap.Error(err))
            status.ErrorCode = commonpb.ErrorCode_UnexpectedError
            status.Reason = err.Error()
            rct.result = status
            return err
        }
    }

    rct.result = status
    log.Debug("ReleaseCollectionTask Execute done",
        zap.Int64("msgID", rct.ID()),
        zap.Int64("collectionID", collectionID))
    return nil
}

func (rct *ReleaseCollectionTask) PostExecute(ctx context.Context) error {
    dbID := rct.DbID
    collectionID := rct.CollectionID
    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    err := rct.meta.releaseCollection(dbID, collectionID)
    if err != nil {
        status.ErrorCode = commonpb.ErrorCode_UnexpectedError
        status.Reason = err.Error()
        rct.result = status
        return err
    }

    rct.result = status
    log.Debug("ReleaseCollectionTask postExecute done",
        zap.Int64("msgID", rct.ID()),
        zap.Int64("collectionID", collectionID))
    return nil
}

type LoadPartitionTask struct {
    BaseTask
    *querypb.LoadPartitionsRequest
    masterService types.MasterService
    dataService   types.DataService
    queryNodes    map[int64]*queryNodeInfo
    meta          Replica
    watchNeeded   bool
}

func (lpt *LoadPartitionTask) ID() UniqueID {
    return lpt.Base.MsgID
}

func (lpt *LoadPartitionTask) Type() commonpb.MsgType {
    return lpt.Base.MsgType
}

func (lpt *LoadPartitionTask) Timestamp() Timestamp {
    return lpt.Base.Timestamp
}

func (lpt *LoadPartitionTask) Name() string {
    return LoadPartitionTaskName
}

func (lpt *LoadPartitionTask) PreExecute(ctx context.Context) error {
    collectionID := lpt.CollectionID
    log.Debug("start do LoadPartitionTask",
        zap.Int64("msgID", lpt.ID()),
        zap.Int64("collectionID", collectionID))
    return nil
}

func (lpt *LoadPartitionTask) Execute(ctx context.Context) error {
    //TODO::suggest different partitions have different dm channel
    dbID := lpt.DbID
    collectionID := lpt.CollectionID
    partitionIDs := lpt.PartitionIDs
    schema := lpt.Schema

    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_UnexpectedError,
    }

    for _, partitionID := range partitionIDs {
        showSegmentRequest := &milvuspb.ShowSegmentsRequest{
            Base: &commonpb.MsgBase{
                MsgType: commonpb.MsgType_ShowSegments,
            },
            CollectionID: collectionID,
            PartitionID:  partitionID,
        }
        showSegmentResponse, err := lpt.masterService.ShowSegments(ctx, showSegmentRequest)
        if err != nil {
            status.Reason = err.Error()
            lpt.result = status
            return err
        }
        segmentIDs := showSegmentResponse.SegmentIDs
        if len(segmentIDs) == 0 {
            loadSegmentRequest := &querypb.LoadSegmentsRequest{
                CollectionID: collectionID,
                PartitionID:  partitionID,
                Schema:       schema,
            }
            for _, node := range lpt.queryNodes {
                _, err := node.LoadSegments(ctx, loadSegmentRequest)
                if err != nil {
                    status.Reason = err.Error()
                    lpt.result = status
                    return err
                }
            }
        }

        lpt.meta.updatePartitionState(dbID, collectionID, partitionID, querypb.PartitionState_PartialInMemory)

        segmentStates := make(map[UniqueID]*datapb.SegmentStateInfo)
        channel2segs := make(map[string][]UniqueID)
        resp, err := lpt.dataService.GetSegmentStates(ctx, &datapb.GetSegmentStatesRequest{
            SegmentIDs: segmentIDs,
        })
        if err != nil {
            status.Reason = err.Error()
            lpt.result = status
            return err
        }
        log.Debug("getSegmentStates result ", zap.Any("segment states", resp.States), zap.Any("result status", resp.Status))
        for _, state := range resp.States {
            log.Debug("segment state", zap.Int64("segmentID", state.SegmentID), zap.Any("start position", state.StartPosition))
            segmentID := state.SegmentID
            segmentStates[segmentID] = state
            channelName := state.StartPosition.ChannelName
            if _, ok := channel2segs[channelName]; !ok {
                segments := make([]UniqueID, 0)
                segments = append(segments, segmentID)
                channel2segs[channelName] = segments
            } else {
                channel2segs[channelName] = append(channel2segs[channelName], segmentID)
            }
        }

        excludeSegment := make([]UniqueID, 0)
        for id, state := range segmentStates {
            if state.State > commonpb.SegmentState_Growing {
                excludeSegment = append(excludeSegment, id)
            }
        }
        for channel, segmentIDs := range channel2segs {
            sort.Slice(segmentIDs, func(i, j int) bool {
                return segmentStates[segmentIDs[i]].StartPosition.Timestamp < segmentStates[segmentIDs[j]].StartPosition.Timestamp
            })
            toLoadSegmentIDs := make([]UniqueID, 0)
            var watchedStartPos *internalpb.MsgPosition = nil
            var startPosition *internalpb.MsgPosition = nil
            for index, id := range segmentIDs {
                if segmentStates[id].State <= commonpb.SegmentState_Growing {
                    if index > 0 {
                        pos := segmentStates[id].StartPosition
                        if len(pos.MsgID) == 0 {
                            watchedStartPos = startPosition
                            break
                        }
                    }
                    watchedStartPos = segmentStates[id].StartPosition
                    break
                }
                toLoadSegmentIDs = append(toLoadSegmentIDs, id)
                watchedStartPos = segmentStates[id].EndPosition
                startPosition = segmentStates[id].StartPosition
            }
            if watchedStartPos == nil {
                watchedStartPos = &internalpb.MsgPosition{
                    ChannelName: channel,
                }
            }

            err = lpt.meta.addDmChannel(dbID, collectionID, channel, watchedStartPos)
            if err != nil {
                status.Reason = err.Error()
                lpt.result = status
                return err
            }
            err = lpt.meta.addExcludeSegmentIDs(dbID, collectionID, toLoadSegmentIDs)
            if err != nil {
                status.Reason = err.Error()
                lpt.result = status
                return err
            }

            segment2Node := shuffleSegmentsToQueryNode(toLoadSegmentIDs, lpt.queryNodes)
            for nodeID, assignedSegmentIDs := range segment2Node {
                loadSegmentRequest := &querypb.LoadSegmentsRequest{
                    // TODO: use unique id allocator to assign reqID
                    Base: &commonpb.MsgBase{
                        MsgID: rand.Int63n(10000000000),
                    },
                    CollectionID: collectionID,
                    PartitionID:  partitionID,
                    SegmentIDs:   assignedSegmentIDs,
                    Schema:       schema,
                }

                queryNode := lpt.queryNodes[nodeID]
                status, err := queryNode.LoadSegments(ctx, loadSegmentRequest)
                if err != nil {
                    lpt.result = status
                    return err
                }
                queryNode.AddSegments(assignedSegmentIDs, collectionID)
            }
        }

        lpt.meta.updatePartitionState(dbID, collectionID, partitionID, querypb.PartitionState_InMemory)
    }

    log.Debug("LoadPartitionTask Execute done",
        zap.Int64("msgID", lpt.ID()),
        zap.Int64("collectionID", collectionID),
        zap.Int64s("partitionIDs", partitionIDs))
    status.ErrorCode = commonpb.ErrorCode_Success
    lpt.result = status
    return nil
}
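
// Illustrative walkthrough (segment states and timestamps below are assumptions,
// not taken from the original source): Execute groups the partition's segments by
// the DM channel recorded in their start position, sorts each group by
// start-position timestamp, and then walks the sorted list: sealed segments are
// accumulated into toLoadSegmentIDs, while the first still-growing segment decides
// the position from which the channel must be watched. For example, with three
// segments on channel "dml-ch-0":
//
//      seg 1: state Flushed, start ts 10, end ts 20
//      seg 2: state Flushed, start ts 20, end ts 30
//      seg 3: state Growing, start ts 30
//
// segments 1 and 2 are loaded onto query nodes and the channel is watched from
// segment 3's start position (ts 30), so insert data is neither replayed twice nor
// skipped.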

func (lpt *LoadPartitionTask) PostExecute(ctx context.Context) error {
    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    dbID := lpt.DbID
    collectionID := lpt.CollectionID
    partitionIDs := lpt.PartitionIDs
    if lpt.watchNeeded {
        err := watchDmChannels(ctx, lpt.dataService, lpt.queryNodes, lpt.meta, dbID, collectionID, lpt.Base)
        if err != nil {
            log.Debug("watchDmChannels failed", zap.Int64("msgID", lpt.ID()), zap.Int64s("partitionIDs", partitionIDs), zap.Error(err))
            status.ErrorCode = commonpb.ErrorCode_UnexpectedError
            status.Reason = err.Error()
            lpt.result = status
            return err
        }
        log.Debug("watchDmChannels completed", zap.Int64("msgID", lpt.ID()), zap.Int64s("partitionIDs", partitionIDs))
    }
    lpt.result = status
    log.Debug("LoadPartitionTask postExecute done",
        zap.Int64("msgID", lpt.ID()),
        zap.Int64("collectionID", collectionID))
    //lpt.cancel()
    return nil
}

type ReleasePartitionTask struct {
    BaseTask
    *querypb.ReleasePartitionsRequest
    queryNodes map[int64]*queryNodeInfo
    meta       Replica
}

func (rpt *ReleasePartitionTask) ID() UniqueID {
    return rpt.Base.MsgID
}

func (rpt *ReleasePartitionTask) Type() commonpb.MsgType {
    return rpt.Base.MsgType
}

func (rpt *ReleasePartitionTask) Timestamp() Timestamp {
    return rpt.Base.Timestamp
}

func (rpt *ReleasePartitionTask) Name() string {
    return ReleasePartitionTaskName
}

func (rpt *ReleasePartitionTask) PreExecute(ctx context.Context) error {
    collectionID := rpt.CollectionID
    log.Debug("start do releasePartitionTask",
        zap.Int64("msgID", rpt.ID()),
        zap.Int64("collectionID", collectionID))
    return nil
}

func (rpt *ReleasePartitionTask) Execute(ctx context.Context) error {
    collectionID := rpt.CollectionID
    partitionIDs := rpt.PartitionIDs
    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    for _, node := range rpt.queryNodes {
        _, err := node.client.ReleasePartitions(ctx, rpt.ReleasePartitionsRequest)
        if err != nil {
            status.ErrorCode = commonpb.ErrorCode_UnexpectedError
            status.Reason = err.Error()
            rpt.result = status
            return err
        }
    }

    rpt.result = status
    log.Debug("ReleasePartitionTask Execute done",
        zap.Int64("msgID", rpt.ID()),
        zap.Int64("collectionID", collectionID),
        zap.Int64s("partitionIDs", partitionIDs))
    return nil
}

func (rpt *ReleasePartitionTask) PostExecute(ctx context.Context) error {
    dbID := rpt.DbID
    collectionID := rpt.CollectionID
    partitionIDs := rpt.PartitionIDs
    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    for _, partitionID := range partitionIDs {
        err := rpt.meta.releasePartition(dbID, collectionID, partitionID)
        if err != nil {
            status.ErrorCode = commonpb.ErrorCode_UnexpectedError
            status.Reason = err.Error()
            rpt.result = status
            return err
        }
    }
    rpt.result = status
    log.Debug("ReleasePartitionTask postExecute done",
        zap.Int64("msgID", rpt.ID()),
        zap.Int64("collectionID", collectionID))
    return nil
}

func watchDmChannels(ctx context.Context,
    dataService types.DataService,
    queryNodes map[int64]*queryNodeInfo,
    meta Replica,
    dbID UniqueID,
    collectionID UniqueID,
    msgBase *commonpb.MsgBase) error {
    collection, err := meta.getCollectionByID(0, collectionID)
    if err != nil {
        return err
    }
    channelRequest := datapb.GetInsertChannelsRequest{
        DbID:         dbID,
        CollectionID: collectionID,
    }
    resp, err := dataService.GetInsertChannels(ctx, &channelRequest)
    if err != nil {
        return err
    }
    if len(resp.Values) == 0 {
        err = errors.New("no dm channel has been assigned to this collection")
        return err
    }

    dmChannels := resp.Values
    channelsWithoutPos := make([]string, 0)
    for _, channel := range dmChannels {
        findChannel := false
        channelsWithPos := collection.dmChannels
        for _, ch := range channelsWithPos {
            if channel == ch {
                findChannel = true
                break
            }
        }
        if !findChannel {
            channelsWithoutPos = append(channelsWithoutPos, channel)
        }
    }
    for _, ch := range channelsWithoutPos {
        pos := &internalpb.MsgPosition{
            ChannelName: ch,
        }
        err = meta.addDmChannel(dbID, collectionID, ch, pos)
        if err != nil {
            return err
        }
    }

    channels2NodeID := shuffleChannelsToQueryNode(dmChannels, queryNodes)
    for nodeID, channels := range channels2NodeID {
        node := queryNodes[nodeID]
        watchDmChannelsInfo := make([]*querypb.WatchDmChannelInfo, 0)
        for _, ch := range channels {
            info := &querypb.WatchDmChannelInfo{
                ChannelID:        ch,
                Pos:              collection.dmChannels2Pos[ch],
                ExcludedSegments: collection.excludeSegmentIds,
            }
            watchDmChannelsInfo = append(watchDmChannelsInfo, info)
        }
        request := &querypb.WatchDmChannelsRequest{
            Base:         msgBase,
            CollectionID: collectionID,
            ChannelIDs:   channels,
            Infos:        watchDmChannelsInfo,
        }
        _, err := node.WatchDmChannels(ctx, request)
        if err != nil {
            return err
        }
        node.AddDmChannels(channels, collectionID)
        log.Debug("query node ", zap.String("nodeID", strconv.FormatInt(nodeID, 10)), zap.String("watch channels", fmt.Sprintln(channels)))
    }

    return nil
}
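
// Illustrative sketch (channel names and node IDs are assumptions, not from the
// original source): with DM channels ["dml-0", "dml-1", "dml-2"] returned by the
// data service and two registered query nodes, shuffleChannelsToQueryNode might
// produce a plan such as
//
//      channels2NodeID = map[int64][]string{
//          1: {"dml-0", "dml-2"},
//          2: {"dml-1"},
//      }
//
// (the exact split depends on how many channels each node already watches), after
// which watchDmChannels sends one WatchDmChannelsRequest per node carrying that
// node's channel list plus the last known position and excluded segments for each
// channel, so a node resumes consuming exactly where the loaded segments end.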

func shuffleChannelsToQueryNode(dmChannels []string, queryNodes map[int64]*queryNodeInfo) map[int64][]string {
    maxNumChannels := 0
    for _, node := range queryNodes {
        numChannels := node.getNumChannels()
        if numChannels > maxNumChannels {
            maxNumChannels = numChannels
        }
    }
    res := make(map[int64][]string)
    offset := 0
    loopAll := false
    for {
        lastOffset := offset
        if !loopAll {
            for id, node := range queryNodes {
                if node.getNumChannels() >= maxNumChannels {
                    continue
                }
                if _, ok := res[id]; !ok {
                    res[id] = make([]string, 0)
                }
                res[id] = append(res[id], dmChannels[offset])
                offset++
                if offset == len(dmChannels) {
                    return res
                }
            }
        } else {
            for id := range queryNodes {
                if _, ok := res[id]; !ok {
                    res[id] = make([]string, 0)
                }
                res[id] = append(res[id], dmChannels[offset])
                offset++
                if offset == len(dmChannels) {
                    return res
                }
            }
        }
        if lastOffset == offset {
            loopAll = true
        }
    }
}
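
// Illustrative sketch (node states below are assumptions, not from the original
// source): the shuffle runs in two phases. While any node still watches fewer
// channels than the current maximum, only those lagging nodes receive new channels;
// once a full pass assigns nothing (lastOffset == offset), loopAll flips to true and
// the remaining channels are dealt out round-robin across all nodes. So with node 1
// already watching two channels and node 2 watching none, node 2 (below the maximum)
// absorbs the new channels in this call, while node 1 only starts receiving channels
// once every node has reached the maximum and the round-robin fallback kicks in.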

func shuffleSegmentsToQueryNode(segmentIDs []UniqueID, queryNodes map[int64]*queryNodeInfo) map[int64][]UniqueID {
    maxNumSegments := 0
    for _, node := range queryNodes {
        numSegments := node.getNumSegments()
        if numSegments > maxNumSegments {
            maxNumSegments = numSegments
        }
    }
    res := make(map[int64][]UniqueID)
    for nodeID := range queryNodes {
        segments := make([]UniqueID, 0)
        res[nodeID] = segments
    }

    if len(segmentIDs) == 0 {
        return res
    }

    offset := 0
    loopAll := false
    for {
        lastOffset := offset
        if !loopAll {
            for id, node := range queryNodes {
                if node.getSegmentsLength() >= maxNumSegments {
                    continue
                }
                if _, ok := res[id]; !ok {
                    res[id] = make([]UniqueID, 0)
                }
                res[id] = append(res[id], segmentIDs[offset])
                offset++
                if offset == len(segmentIDs) {
                    return res
                }
            }
        } else {
            for id := range queryNodes {
                if _, ok := res[id]; !ok {
                    res[id] = make([]UniqueID, 0)
                }
                res[id] = append(res[id], segmentIDs[offset])
                offset++
                if offset == len(segmentIDs) {
                    return res
                }
            }
        }
        if lastOffset == offset {
            loopAll = true
        }
    }
}

@ -0,0 +1,252 @@

package queryservice

import (
    "container/list"
    "context"
    "errors"
    "sync"

    "github.com/opentracing/opentracing-go"
    "go.uber.org/zap"

    oplog "github.com/opentracing/opentracing-go/log"
    "github.com/zilliztech/milvus-distributed/internal/log"
    "github.com/zilliztech/milvus-distributed/internal/util/trace"
)

type TaskQueue interface {
    utChan() <-chan int
    utEmpty() bool
    utFull() bool
    addUnissuedTask(t task) error
    FrontUnissuedTask() task
    PopUnissuedTask() task
    AddActiveTask(t task)
    PopActiveTask(ts Timestamp) task
    Enqueue(t task) error
}

type BaseTaskQueue struct {
    unissuedTasks *list.List
    activeTasks   map[Timestamp]task
    utLock        sync.Mutex
    atLock        sync.Mutex

    maxTaskNum int64

    utBufChan chan int // to block scheduler

    sched *TaskScheduler
}

func (queue *BaseTaskQueue) utChan() <-chan int {
    return queue.utBufChan
}

func (queue *BaseTaskQueue) utEmpty() bool {
    queue.utLock.Lock()
    defer queue.utLock.Unlock()
    return queue.unissuedTasks.Len() == 0
}

func (queue *BaseTaskQueue) utFull() bool {
    return int64(queue.unissuedTasks.Len()) >= queue.maxTaskNum
}

func (queue *BaseTaskQueue) addUnissuedTask(t task) error {
    queue.utLock.Lock()
    defer queue.utLock.Unlock()

    if queue.utFull() {
        return errors.New("task queue is full")
    }
    if queue.unissuedTasks.Len() <= 0 {
        queue.unissuedTasks.PushBack(t)
        queue.utBufChan <- 1
        return nil
    }

    if t.Timestamp() >= queue.unissuedTasks.Back().Value.(task).Timestamp() {
        queue.unissuedTasks.PushBack(t)
        queue.utBufChan <- 1
        return nil
    }

    for e := queue.unissuedTasks.Front(); e != nil; e = e.Next() {
        if t.Timestamp() <= e.Value.(task).Timestamp() {
            queue.unissuedTasks.InsertBefore(t, e)
            queue.utBufChan <- 1
            return nil
        }
    }
    return errors.New("unexpected error in addUnissuedTask")
}
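
// Illustrative sketch (timestamps are assumptions, not from the original source):
// addUnissuedTask keeps the unissued list sorted by task timestamp. If the list
// currently holds tasks with timestamps [10, 20, 30]:
//
//      a new task with ts 35 is appended at the back (fast path),
//      a new task with ts 25 is inserted between 20 and 30,
//      a new task with ts 20 is inserted before the existing 20 (<= comparison).
//
// Each successful insert also pushes a token into utBufChan so the scheduler loop
// wakes up and drains the queue.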

func (queue *BaseTaskQueue) FrontUnissuedTask() task {
    queue.utLock.Lock()
    defer queue.utLock.Unlock()

    if queue.unissuedTasks.Len() <= 0 {
        log.Warn("FrontUnissuedTask: the unissued task list is empty")
        return nil
    }

    return queue.unissuedTasks.Front().Value.(task)
}

func (queue *BaseTaskQueue) PopUnissuedTask() task {
    queue.utLock.Lock()
    defer queue.utLock.Unlock()

    if queue.unissuedTasks.Len() <= 0 {
        log.Warn("PopUnissuedTask: the unissued task list is empty")
        return nil
    }

    ft := queue.unissuedTasks.Front()
    queue.unissuedTasks.Remove(ft)

    return ft.Value.(task)
}

func (queue *BaseTaskQueue) AddActiveTask(t task) {
    queue.atLock.Lock()
    defer queue.atLock.Unlock()

    ts := t.Timestamp()
    _, ok := queue.activeTasks[ts]
    if ok {
        log.Debug("queryService", zap.Uint64("a task with the same timestamp is already in the active task list", ts))
    }

    queue.activeTasks[ts] = t
}

func (queue *BaseTaskQueue) PopActiveTask(ts Timestamp) task {
    queue.atLock.Lock()
    defer queue.atLock.Unlock()

    t, ok := queue.activeTasks[ts]
    if ok {
        log.Debug("queryService", zap.Uint64("remove task with timestamp from the active task list", ts))
        delete(queue.activeTasks, ts)
        return t
    }

    return nil
}

func (queue *BaseTaskQueue) Enqueue(t task) error {
    return queue.addUnissuedTask(t)
}

type DdTaskQueue struct {
    BaseTaskQueue
    lock sync.Mutex
}

func (queue *DdTaskQueue) Enqueue(t task) error {
    queue.lock.Lock()
    defer queue.lock.Unlock()
    return queue.BaseTaskQueue.Enqueue(t)
}

func NewDdTaskQueue(sched *TaskScheduler) *DdTaskQueue {
    return &DdTaskQueue{
        BaseTaskQueue: BaseTaskQueue{
            unissuedTasks: list.New(),
            activeTasks:   make(map[Timestamp]task),
            maxTaskNum:    1024,
            utBufChan:     make(chan int, 1024),
            sched:         sched,
        },
    }
}

type TaskScheduler struct {
    DdQueue TaskQueue

    wg     sync.WaitGroup
    ctx    context.Context
    cancel context.CancelFunc
}

func NewTaskScheduler(ctx context.Context) *TaskScheduler {
    ctx1, cancel := context.WithCancel(ctx)
    s := &TaskScheduler{
        ctx:    ctx1,
        cancel: cancel,
    }
    s.DdQueue = NewDdTaskQueue(s)

    return s
}

func (sched *TaskScheduler) scheduleDdTask() task {
    return sched.DdQueue.PopUnissuedTask()
}

func (sched *TaskScheduler) processTask(t task, q TaskQueue) {
    span, ctx := trace.StartSpanFromContext(t.TraceCtx(),
        opentracing.Tags{
            "Type": t.Name(),
            "ID":   t.ID(),
        })
    defer span.Finish()
    span.LogFields(oplog.Int64("scheduler process PreExecute", t.ID()))
    err := t.PreExecute(ctx)

    // the deferred Notify reports whatever error is left in err when processTask
    // returns, unblocking the caller waiting in WaitToFinish
    defer func() {
        t.Notify(err)
    }()
    if err != nil {
        log.Debug("preExecute err", zap.String("reason", err.Error()))
        trace.LogError(span, err)
        return
    }

    span.LogFields(oplog.Int64("scheduler process AddActiveTask", t.ID()))
    q.AddActiveTask(t)

    defer func() {
        span.LogFields(oplog.Int64("scheduler process PopActiveTask", t.ID()))
        q.PopActiveTask(t.Timestamp())
    }()
    span.LogFields(oplog.Int64("scheduler process Execute", t.ID()))
    err = t.Execute(ctx)
    if err != nil {
        log.Debug("execute err", zap.String("reason", err.Error()))
        trace.LogError(span, err)
        return
    }
    span.LogFields(oplog.Int64("scheduler process PostExecute", t.ID()))
    // any PostExecute error is surfaced through the deferred Notify above
    err = t.PostExecute(ctx)
}
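
// Lifecycle sketch (hedged; the concrete task type and variable names are
// assumptions chosen for illustration): a task enqueued via DdQueue.Enqueue blocks
// its producer in WaitToFinish until processTask runs PreExecute/Execute/PostExecute
// and the deferred Notify delivers the final error:
//
//      releaseTask := &ReleaseCollectionTask{
//          BaseTask:                 BaseTask{ctx: ctx, Condition: NewTaskCondition(ctx)},
//          ReleaseCollectionRequest: req, // assumed *querypb.ReleaseCollectionRequest
//          queryNodes:               nodes,
//          meta:                     meta,
//      }
//      if err := scheduler.DdQueue.Enqueue(releaseTask); err != nil {
//          return err
//      }
//      return releaseTask.WaitToFinish() // resolved by Notify inside processTask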

func (sched *TaskScheduler) definitionLoop() {
    defer sched.wg.Done()
    for {
        select {
        case <-sched.ctx.Done():
            return
        case <-sched.DdQueue.utChan():
            if !sched.DdQueue.utEmpty() {
                t := sched.scheduleDdTask()
                sched.processTask(t, sched.DdQueue)
            }
        }
    }
}

func (sched *TaskScheduler) Start() error {
    sched.wg.Add(1)
    go sched.definitionLoop()

    return nil
}

func (sched *TaskScheduler) Close() {
    sched.cancel()
    sched.wg.Wait()
}
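
// End-to-end usage sketch (hedged; how the QueryService is expected to wire the
// scheduler together, with names chosen for illustration):
//
//      ctx := context.Background()
//      scheduler := NewTaskScheduler(ctx)
//      if err := scheduler.Start(); err != nil {
//          panic(err)
//      }
//      defer scheduler.Close() // cancels the context and waits for definitionLoop to exit
//
//      // tasks built from incoming gRPC requests are pushed onto the DDL queue;
//      // the single definitionLoop goroutine executes them one at a time, which keeps
//      // load/release operations on the same collection strictly ordered.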