mirror of https://github.com/milvus-io/milvus.git
Support load segments/channels in parallel (#20036)
Signed-off-by: xiaofan-luan <xiaofan.luan@zilliz.com> Signed-off-by: xiaofan-luan <xiaofan.luan@zilliz.com>pull/20200/head
parent
bf5fb3cd99
commit
4136009a9a
|
@ -50,6 +50,7 @@ type Collection struct {
|
|||
partitionIDs []UniqueID
|
||||
schema *schemapb.CollectionSchema
|
||||
|
||||
// TODO, remove delta channels
|
||||
channelMu sync.RWMutex
|
||||
vChannels []Channel
|
||||
pChannels []Channel
|
||||
|
@ -225,6 +226,41 @@ func (c *Collection) getVDeltaChannels() []Channel {
|
|||
return tmpChannels
|
||||
}
|
||||
|
||||
func (c *Collection) AddChannels(toLoadChannels []Channel, VPChannels map[string]string) []Channel {
|
||||
c.channelMu.Lock()
|
||||
defer c.channelMu.Unlock()
|
||||
|
||||
retVChannels := []Channel{}
|
||||
for _, toLoadChannel := range toLoadChannels {
|
||||
if !c.isVChannelExist(toLoadChannel) {
|
||||
retVChannels = append(retVChannels, toLoadChannel)
|
||||
c.vChannels = append(c.vChannels, toLoadChannel)
|
||||
if !c.isPChannelExist(VPChannels[toLoadChannel]) {
|
||||
c.pChannels = append(c.pChannels, VPChannels[toLoadChannel])
|
||||
}
|
||||
}
|
||||
}
|
||||
return retVChannels
|
||||
}
|
||||
|
||||
func (c *Collection) isVChannelExist(channel string) bool {
|
||||
for _, vChannel := range c.vChannels {
|
||||
if vChannel == channel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Collection) isPChannelExist(channel string) bool {
|
||||
for _, vChannel := range c.pChannels {
|
||||
if vChannel == channel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// addVChannels add virtual channels to collection
|
||||
func (c *Collection) addVDeltaChannels(channels []Channel) {
|
||||
c.channelMu.Lock()
|
||||
|
@ -268,6 +304,41 @@ func (c *Collection) removeVDeltaChannel(channel Channel) {
|
|||
metrics.QueryNodeNumDeltaChannels.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Sub(float64(len(c.vDeltaChannels)))
|
||||
}
|
||||
|
||||
func (c *Collection) AddVDeltaChannels(toLoadChannels []Channel, VPChannels map[string]string) []Channel {
|
||||
c.channelMu.Lock()
|
||||
defer c.channelMu.Unlock()
|
||||
|
||||
retVDeltaChannels := []Channel{}
|
||||
for _, toLoadChannel := range toLoadChannels {
|
||||
if !c.isVDeltaChannelExist(toLoadChannel) {
|
||||
retVDeltaChannels = append(retVDeltaChannels, toLoadChannel)
|
||||
c.vDeltaChannels = append(c.vDeltaChannels, toLoadChannel)
|
||||
if !c.isPDeltaChannelExist(VPChannels[toLoadChannel]) {
|
||||
c.pDeltaChannels = append(c.pDeltaChannels, VPChannels[toLoadChannel])
|
||||
}
|
||||
}
|
||||
}
|
||||
return retVDeltaChannels
|
||||
}
|
||||
|
||||
func (c *Collection) isVDeltaChannelExist(channel string) bool {
|
||||
for _, vDeltaChanel := range c.vDeltaChannels {
|
||||
if vDeltaChanel == channel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Collection) isPDeltaChannelExist(channel string) bool {
|
||||
for _, vChannel := range c.pDeltaChannels {
|
||||
if vChannel == channel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// setReleaseTime records when collection is released
|
||||
func (c *Collection) setReleaseTime(t Timestamp, released bool) {
|
||||
c.releaseMu.Lock()
|
||||
|
|
|
@ -49,32 +49,13 @@ func (dsService *dataSyncService) getFlowGraphNum() int {
|
|||
return len(dsService.dmlChannel2FlowGraph) + len(dsService.deltaChannel2FlowGraph)
|
||||
}
|
||||
|
||||
// checkReplica used to check replica info before init flow graph, it's a private method of dataSyncService
|
||||
func (dsService *dataSyncService) checkReplica(collectionID UniqueID) error {
|
||||
// check if the collection exists
|
||||
coll, err := dsService.metaReplica.getCollectionByID(collectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, channel := range coll.getVChannels() {
|
||||
if _, err := dsService.tSafeReplica.getTSafe(channel); err != nil {
|
||||
return fmt.Errorf("getTSafe failed, err = %s", err)
|
||||
}
|
||||
}
|
||||
for _, channel := range coll.getVDeltaChannels() {
|
||||
if _, err := dsService.tSafeReplica.getTSafe(channel); err != nil {
|
||||
return fmt.Errorf("getTSafe failed, err = %s", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addFlowGraphsForDMLChannels add flowGraphs to dmlChannel2FlowGraph
|
||||
func (dsService *dataSyncService) addFlowGraphsForDMLChannels(collectionID UniqueID, dmlChannels []string) (map[string]*queryNodeFlowGraph, error) {
|
||||
dsService.mu.Lock()
|
||||
defer dsService.mu.Unlock()
|
||||
|
||||
if err := dsService.checkReplica(collectionID); err != nil {
|
||||
_, err := dsService.metaReplica.getCollectionByID(collectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -118,7 +99,8 @@ func (dsService *dataSyncService) addFlowGraphsForDeltaChannels(collectionID Uni
|
|||
dsService.mu.Lock()
|
||||
defer dsService.mu.Unlock()
|
||||
|
||||
if err := dsService.checkReplica(collectionID); err != nil {
|
||||
_, err := dsService.metaReplica.getCollectionByID(collectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -153,57 +153,6 @@ func TestDataSyncService_DeltaFlowGraphs(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestDataSyncService_checkReplica(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
replica, err := genSimpleReplica()
|
||||
assert.NoError(t, err)
|
||||
|
||||
fac := genFactory()
|
||||
assert.NoError(t, err)
|
||||
|
||||
tSafe := newTSafeReplica()
|
||||
dataSyncService := newDataSyncService(ctx, replica, tSafe, fac)
|
||||
assert.NotNil(t, dataSyncService)
|
||||
defer dataSyncService.close()
|
||||
|
||||
t.Run("test checkReplica", func(t *testing.T) {
|
||||
err = dataSyncService.checkReplica(defaultCollectionID)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test collection doesn't exist", func(t *testing.T) {
|
||||
err = dataSyncService.metaReplica.removeCollection(defaultCollectionID)
|
||||
assert.NoError(t, err)
|
||||
err = dataSyncService.checkReplica(defaultCollectionID)
|
||||
assert.Error(t, err)
|
||||
coll := dataSyncService.metaReplica.addCollection(defaultCollectionID, genTestCollectionSchema())
|
||||
assert.NotNil(t, coll)
|
||||
})
|
||||
|
||||
t.Run("test cannot find tSafe", func(t *testing.T) {
|
||||
coll, err := dataSyncService.metaReplica.getCollectionByID(defaultCollectionID)
|
||||
assert.NoError(t, err)
|
||||
coll.addVDeltaChannels([]Channel{defaultDeltaChannel})
|
||||
coll.addVChannels([]Channel{defaultDMLChannel})
|
||||
|
||||
dataSyncService.tSafeReplica.addTSafe(defaultDeltaChannel)
|
||||
dataSyncService.tSafeReplica.addTSafe(defaultDMLChannel)
|
||||
|
||||
dataSyncService.tSafeReplica.removeTSafe(defaultDeltaChannel)
|
||||
err = dataSyncService.checkReplica(defaultCollectionID)
|
||||
assert.Error(t, err)
|
||||
|
||||
dataSyncService.tSafeReplica.removeTSafe(defaultDMLChannel)
|
||||
err = dataSyncService.checkReplica(defaultCollectionID)
|
||||
assert.Error(t, err)
|
||||
|
||||
dataSyncService.tSafeReplica.addTSafe(defaultDeltaChannel)
|
||||
dataSyncService.tSafeReplica.addTSafe(defaultDMLChannel)
|
||||
})
|
||||
}
|
||||
|
||||
type DataSyncServiceSuite struct {
|
||||
suite.Suite
|
||||
factory dependency.Factory
|
||||
|
|
|
@ -20,8 +20,10 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"go.uber.org/zap"
|
||||
|
@ -308,37 +310,57 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
|
|||
node: node,
|
||||
}
|
||||
|
||||
err := node.scheduler.queue.Enqueue(task)
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn(err.Error())
|
||||
return status, nil
|
||||
}
|
||||
log.Info("watchDmChannelsTask Enqueue done", zap.Int64("collectionID", in.CollectionID), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()), zap.Int64("replicaID", in.GetReplicaID()))
|
||||
waitFunc := func() (*commonpb.Status, error) {
|
||||
err = task.WaitToFinish()
|
||||
startTs := time.Now()
|
||||
log.Info("watchDmChannels init", zap.Int64("collectionID", in.CollectionID),
|
||||
zap.String("channelName", in.Infos[0].GetChannelName()),
|
||||
zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
|
||||
// currently we only support load one channel as a time
|
||||
node.taskLock.RLock(strconv.FormatInt(in.Infos[0].CollectionID, 10))
|
||||
defer node.taskLock.RUnlock(strconv.FormatInt(in.Infos[0].CollectionID, 10))
|
||||
future := node.taskPool.Submit(func() (interface{}, error) {
|
||||
log.Info("watchDmChannels start ", zap.Int64("collectionID", in.CollectionID),
|
||||
zap.String("channelName", in.Infos[0].GetChannelName()),
|
||||
zap.Duration("timeInQueue", time.Since(startTs)))
|
||||
err := task.PreExecute(ctx)
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn(err.Error())
|
||||
log.Warn("failed to subscribe channel on preExecute ", zap.Error(err))
|
||||
return status, nil
|
||||
}
|
||||
|
||||
err = task.Execute(ctx)
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn("failed to subscribe channel ", zap.Error(err))
|
||||
return status, nil
|
||||
}
|
||||
|
||||
err = task.PostExecute(ctx)
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn("failed to unsubscribe channel on postExecute ", zap.Error(err))
|
||||
return status, nil
|
||||
}
|
||||
|
||||
sc, _ := node.ShardClusterService.getShardCluster(in.Infos[0].GetChannelName())
|
||||
sc.SetupFirstVersion()
|
||||
|
||||
log.Info("watchDmChannelsTask WaitToFinish done", zap.Int64("collectionID", in.CollectionID), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
|
||||
log.Info("successfully watchDmChannelsTask", zap.Int64("collectionID", in.CollectionID),
|
||||
zap.String("channelName", in.Infos[0].GetChannelName()), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return waitFunc()
|
||||
})
|
||||
ret, _ := future.Await()
|
||||
return ret.(*commonpb.Status), nil
|
||||
}
|
||||
|
||||
func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) {
|
||||
|
@ -375,13 +397,15 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
|
|||
node: node,
|
||||
}
|
||||
|
||||
node.taskLock.Lock(strconv.FormatInt(dct.req.CollectionID, 10))
|
||||
defer node.taskLock.Unlock(strconv.FormatInt(dct.req.CollectionID, 10))
|
||||
err := node.scheduler.queue.Enqueue(dct)
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn(err.Error())
|
||||
log.Warn("failed to enqueue subscribe channel task", zap.Error(err))
|
||||
return status, nil
|
||||
}
|
||||
log.Info("unsubDmChannel(ReleaseCollection) enqueue done", zap.Int64("collectionID", req.GetCollectionID()))
|
||||
|
@ -389,7 +413,7 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
|
|||
func() {
|
||||
err = dct.WaitToFinish()
|
||||
if err != nil {
|
||||
log.Warn(err.Error())
|
||||
log.Warn("failed to do subscribe channel task successfully", zap.Error(err))
|
||||
return
|
||||
}
|
||||
log.Info("unsubDmChannel(ReleaseCollection) WaitToFinish done", zap.Int64("collectionID", req.GetCollectionID()))
|
||||
|
@ -439,35 +463,66 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
|
|||
for _, info := range in.Infos {
|
||||
segmentIDs = append(segmentIDs, info.SegmentID)
|
||||
}
|
||||
err := node.scheduler.queue.Enqueue(task)
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn(err.Error())
|
||||
return status, nil
|
||||
sort.SliceStable(segmentIDs, func(i, j int) bool {
|
||||
return segmentIDs[i] < segmentIDs[j]
|
||||
})
|
||||
|
||||
startTs := time.Now()
|
||||
log.Info("loadSegmentsTask init", zap.Int64("collectionID", in.CollectionID),
|
||||
zap.Int64s("segmentIDs", segmentIDs),
|
||||
zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
|
||||
|
||||
node.taskLock.RLock(strconv.FormatInt(in.CollectionID, 10))
|
||||
for _, segmentID := range segmentIDs {
|
||||
node.taskLock.Lock(strconv.FormatInt(segmentID, 10))
|
||||
}
|
||||
|
||||
log.Info("loadSegmentsTask Enqueue done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
|
||||
|
||||
waitFunc := func() (*commonpb.Status, error) {
|
||||
err = task.WaitToFinish()
|
||||
// release all task locks
|
||||
defer func() {
|
||||
node.taskLock.RUnlock(strconv.FormatInt(in.CollectionID, 10))
|
||||
for _, id := range segmentIDs {
|
||||
node.taskLock.Unlock(strconv.FormatInt(id, 10))
|
||||
}
|
||||
}()
|
||||
future := node.taskPool.Submit(func() (interface{}, error) {
|
||||
log.Info("loadSegmentsTask start ", zap.Int64("collectionID", in.CollectionID),
|
||||
zap.Int64s("segmentIDs", segmentIDs),
|
||||
zap.Duration("timeInQueue", time.Since(startTs)))
|
||||
err := task.PreExecute(ctx)
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn(err.Error())
|
||||
log.Warn("failed to load segments on preExecute ", zap.Error(err))
|
||||
return status, nil
|
||||
}
|
||||
log.Info("loadSegmentsTask WaitToFinish done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
|
||||
err = task.Execute(ctx)
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn("failed to load segment", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Error(err))
|
||||
return status, nil
|
||||
}
|
||||
|
||||
err = task.PostExecute(ctx)
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn("failed to load segments on postExecute ", zap.Error(err))
|
||||
return status, nil
|
||||
}
|
||||
log.Info("loadSegmentsTask done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return waitFunc()
|
||||
})
|
||||
ret, _ := future.Await()
|
||||
return ret.(*commonpb.Status), nil
|
||||
}
|
||||
|
||||
// ReleaseCollection clears all data related to this collection on the querynode
|
||||
|
@ -490,6 +545,8 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.Releas
|
|||
node: node,
|
||||
}
|
||||
|
||||
node.taskLock.Lock(strconv.FormatInt(dct.req.CollectionID, 10))
|
||||
defer node.taskLock.Unlock(strconv.FormatInt(dct.req.CollectionID, 10))
|
||||
err := node.scheduler.queue.Enqueue(dct)
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
|
@ -536,6 +593,8 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.Releas
|
|||
node: node,
|
||||
}
|
||||
|
||||
node.taskLock.Lock(strconv.FormatInt(dct.req.CollectionID, 10))
|
||||
defer node.taskLock.Unlock(strconv.FormatInt(dct.req.CollectionID, 10))
|
||||
err := node.scheduler.queue.Enqueue(dct)
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
|
@ -587,6 +646,23 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseS
|
|||
return node.TransferRelease(ctx, in)
|
||||
}
|
||||
|
||||
log.Info("start to release segments", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", in.SegmentIDs))
|
||||
node.taskLock.RLock(strconv.FormatInt(in.CollectionID, 10))
|
||||
sort.SliceStable(in.SegmentIDs, func(i, j int) bool {
|
||||
return in.SegmentIDs[i] < in.SegmentIDs[j]
|
||||
})
|
||||
|
||||
for _, segmentID := range in.SegmentIDs {
|
||||
node.taskLock.Lock(strconv.FormatInt(segmentID, 10))
|
||||
}
|
||||
|
||||
// release all task locks
|
||||
defer func() {
|
||||
node.taskLock.RUnlock(strconv.FormatInt(in.CollectionID, 10))
|
||||
for _, id := range in.SegmentIDs {
|
||||
node.taskLock.Unlock(strconv.FormatInt(id, 10))
|
||||
}
|
||||
}()
|
||||
for _, id := range in.SegmentIDs {
|
||||
switch in.GetScope() {
|
||||
case querypb.DataScope_Streaming:
|
||||
|
|
|
@ -0,0 +1,226 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
type loadSegmentsTask struct {
|
||||
baseTask
|
||||
req *queryPb.LoadSegmentsRequest
|
||||
node *QueryNode
|
||||
}
|
||||
|
||||
// loadSegmentsTask
|
||||
func (l *loadSegmentsTask) PreExecute(ctx context.Context) error {
|
||||
log.Info("LoadSegmentTask PreExecute start", zap.Int64("msgID", l.req.Base.MsgID))
|
||||
var err error
|
||||
// init meta
|
||||
collectionID := l.req.GetCollectionID()
|
||||
l.node.metaReplica.addCollection(collectionID, l.req.GetSchema())
|
||||
for _, partitionID := range l.req.GetLoadMeta().GetPartitionIDs() {
|
||||
err = l.node.metaReplica.addPartition(collectionID, partitionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// filter segments that are already loaded in this querynode
|
||||
var filteredInfos []*queryPb.SegmentLoadInfo
|
||||
for _, info := range l.req.Infos {
|
||||
has, err := l.node.metaReplica.hasSegment(info.SegmentID, segmentTypeSealed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !has {
|
||||
filteredInfos = append(filteredInfos, info)
|
||||
} else {
|
||||
log.Info("ignore segment that is already loaded", zap.Int64("collectionID", info.CollectionID), zap.Int64("segmentID", info.SegmentID))
|
||||
}
|
||||
}
|
||||
l.req.Infos = filteredInfos
|
||||
log.Info("LoadSegmentTask PreExecute done", zap.Int64("msgID", l.req.Base.MsgID))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *loadSegmentsTask) Execute(ctx context.Context) error {
|
||||
log.Info("LoadSegmentTask Execute start", zap.Int64("msgID", l.req.Base.MsgID))
|
||||
|
||||
if len(l.req.Infos) == 0 {
|
||||
log.Info("all segments loaded", zap.Int64("msgID", l.req.GetBase().GetMsgID()))
|
||||
return nil
|
||||
}
|
||||
|
||||
segmentIDs := lo.Map(l.req.Infos, func(info *queryPb.SegmentLoadInfo, idx int) UniqueID { return info.SegmentID })
|
||||
l.node.metaReplica.addSegmentsLoadingList(segmentIDs)
|
||||
defer l.node.metaReplica.removeSegmentsLoadingList(segmentIDs)
|
||||
err := l.node.loader.LoadSegment(l.ctx, l.req, segmentTypeSealed)
|
||||
if err != nil {
|
||||
log.Warn("failed to load segment", zap.Int64("collectionID", l.req.CollectionID),
|
||||
zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
vchanName := make([]string, 0)
|
||||
for _, deltaPosition := range l.req.DeltaPositions {
|
||||
vchanName = append(vchanName, deltaPosition.ChannelName)
|
||||
}
|
||||
|
||||
// TODO delta channel need to released 1. if other watchDeltaChannel fail 2. when segment release
|
||||
err = l.watchDeltaChannel(vchanName)
|
||||
if err != nil {
|
||||
// roll back
|
||||
for _, segment := range l.req.Infos {
|
||||
l.node.metaReplica.removeSegment(segment.SegmentID, segmentTypeSealed)
|
||||
}
|
||||
log.Warn("failed to watch Delta channel while load segment", zap.Int64("collectionID", l.req.CollectionID),
|
||||
zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
runningGroup, groupCtx := errgroup.WithContext(l.ctx)
|
||||
for _, deltaPosition := range l.req.DeltaPositions {
|
||||
pos := deltaPosition
|
||||
runningGroup.Go(func() error {
|
||||
// reload data from dml channel
|
||||
return l.node.loader.FromDmlCPLoadDelete(groupCtx, l.req.CollectionID, pos)
|
||||
})
|
||||
}
|
||||
err = runningGroup.Wait()
|
||||
if err != nil {
|
||||
for _, segment := range l.req.Infos {
|
||||
l.node.metaReplica.removeSegment(segment.SegmentID, segmentTypeSealed)
|
||||
}
|
||||
log.Warn("failed to load delete data while load segment", zap.Int64("collectionID", l.req.CollectionID),
|
||||
zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("LoadSegmentTask Execute done", zap.Int64("collectionID", l.req.CollectionID),
|
||||
zap.Int64("replicaID", l.req.ReplicaID), zap.Int64("msgID", l.req.Base.MsgID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// internal helper function to subscribe delta channel
|
||||
func (l *loadSegmentsTask) watchDeltaChannel(vchanName []string) error {
|
||||
collectionID := l.req.CollectionID
|
||||
var vDeltaChannels []string
|
||||
VPDeltaChannels := make(map[string]string)
|
||||
for _, v := range vchanName {
|
||||
dc, err := funcutil.ConvertChannelName(v, Params.CommonCfg.RootCoordDml, Params.CommonCfg.RootCoordDelta)
|
||||
if err != nil {
|
||||
log.Warn("watchDeltaChannels, failed to convert deltaChannel from dmlChannel", zap.String("DmlChannel", v), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
p := funcutil.ToPhysicalChannel(dc)
|
||||
vDeltaChannels = append(vDeltaChannels, dc)
|
||||
VPDeltaChannels[dc] = p
|
||||
}
|
||||
log.Info("Starting WatchDeltaChannels ...",
|
||||
zap.Int64("collectionID", collectionID),
|
||||
zap.Any("channels", VPDeltaChannels),
|
||||
)
|
||||
|
||||
coll, err := l.node.metaReplica.getCollectionByID(collectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// filter out duplicated channels
|
||||
vDeltaChannels = coll.AddVDeltaChannels(vDeltaChannels, VPDeltaChannels)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
for _, vDeltaChannel := range vDeltaChannels {
|
||||
coll.removeVDeltaChannel(vDeltaChannel)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if len(vDeltaChannels) == 0 {
|
||||
log.Warn("all delta channels has be added before, ignore watch delta requests")
|
||||
return nil
|
||||
}
|
||||
|
||||
channel2FlowGraph, err := l.node.dataSyncService.addFlowGraphsForDeltaChannels(collectionID, vDeltaChannels)
|
||||
if err != nil {
|
||||
log.Warn("watchDeltaChannel, add flowGraph for deltaChannel failed", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
consumeSubName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName, collectionID, Params.QueryNodeCfg.GetNodeID())
|
||||
|
||||
// channels as consumer
|
||||
for channel, fg := range channel2FlowGraph {
|
||||
pchannel := VPDeltaChannels[channel]
|
||||
// use pChannel to consume
|
||||
err = fg.consumeFlowGraphFromLatest(pchannel, consumeSubName)
|
||||
if err != nil {
|
||||
log.Error("msgStream as consumer failed for deltaChannels", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Warn("watchDeltaChannel, add flowGraph for deltaChannel failed", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels), zap.Error(err))
|
||||
for _, fg := range channel2FlowGraph {
|
||||
fg.flowGraph.Close()
|
||||
}
|
||||
gcChannels := make([]Channel, 0)
|
||||
for channel := range channel2FlowGraph {
|
||||
gcChannels = append(gcChannels, channel)
|
||||
}
|
||||
l.node.dataSyncService.removeFlowGraphsByDeltaChannels(gcChannels)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("watchDeltaChannel, add flowGraph for deltaChannel success", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels))
|
||||
|
||||
// create tSafe
|
||||
for _, channel := range vDeltaChannels {
|
||||
l.node.tSafeReplica.addTSafe(channel)
|
||||
}
|
||||
|
||||
// add tsafe watch in query shard if exists, we find no way to handle it if query shard not exist
|
||||
for _, channel := range vDeltaChannels {
|
||||
dmlChannel, err := funcutil.ConvertChannelName(channel, Params.CommonCfg.RootCoordDelta, Params.CommonCfg.RootCoordDml)
|
||||
if err != nil {
|
||||
log.Error("failed to convert delta channel to dml", zap.String("channel", channel), zap.Error(err))
|
||||
panic(err)
|
||||
}
|
||||
err = l.node.queryShardService.addQueryShard(collectionID, dmlChannel, l.req.GetReplicaID())
|
||||
if err != nil {
|
||||
log.Error("failed to add shard Service to query shard", zap.String("channel", channel), zap.Error(err))
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// start flow graphs
|
||||
for _, fg := range channel2FlowGraph {
|
||||
fg.flowGraph.Start()
|
||||
}
|
||||
|
||||
log.Info("WatchDeltaChannels done", zap.Int64("collectionID", collectionID), zap.String("ChannelIDs", fmt.Sprintln(vDeltaChannels)))
|
||||
return nil
|
||||
}
|
|
@ -417,9 +417,8 @@ func (replica *metaReplica) addPartitionPrivate(collection *Collection, partitio
|
|||
collection.addPartitionID(partitionID)
|
||||
var newPartition = newPartition(collection.ID(), partitionID)
|
||||
replica.partitions[partitionID] = newPartition
|
||||
metrics.QueryNodeNumPartitions.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Set(float64(len(replica.partitions)))
|
||||
}
|
||||
|
||||
metrics.QueryNodeNumPartitions.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Set(float64(len(replica.partitions)))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -49,7 +49,9 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/util/etcd"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/indexcgowrapper"
|
||||
"github.com/milvus-io/milvus/internal/util/lock"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"github.com/panjf2000/ants/v2"
|
||||
)
|
||||
|
||||
// ---------- unittest util functions ----------
|
||||
|
@ -1262,13 +1264,10 @@ func genSimpleReplicaWithSealSegment(ctx context.Context) (ReplicaInterface, err
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
col, err := r.getCollectionByID(defaultCollectionID)
|
||||
_, err = r.getCollectionByID(defaultCollectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
col.addVChannels([]Channel{
|
||||
defaultDeltaChannel,
|
||||
})
|
||||
return r, nil
|
||||
}
|
||||
|
||||
|
@ -1661,6 +1660,12 @@ func genSimpleQueryNodeWithMQFactory(ctx context.Context, fac dependency.Factory
|
|||
node.etcdCli = etcdCli
|
||||
node.initSession()
|
||||
|
||||
node.taskPool, err = concurrency.NewPool(2, ants.WithPreAlloc(true))
|
||||
if err != nil {
|
||||
log.Error("QueryNode init channel pool failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
node.taskLock = lock.NewKeyLock()
|
||||
etcdKV := etcdkv.NewEtcdKV(etcdCli, Params.EtcdCfg.MetaRootPath)
|
||||
node.etcdKV = etcdKV
|
||||
|
||||
|
|
|
@ -54,6 +54,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/util/concurrency"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/initcore"
|
||||
"github.com/milvus-io/milvus/internal/util/lock"
|
||||
"github.com/milvus-io/milvus/internal/util/metricsinfo"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
|
@ -120,6 +121,10 @@ type QueryNode struct {
|
|||
|
||||
// cgoPool is the worker pool to control concurrency of cgo call
|
||||
cgoPool *concurrency.Pool
|
||||
// pool for load/release channel
|
||||
taskPool *concurrency.Pool
|
||||
// lock to avoid same chanel/channel run multiple times
|
||||
taskLock *lock.KeyLock
|
||||
}
|
||||
|
||||
// NewQueryNode will return a QueryNode with abnormal state.
|
||||
|
@ -258,6 +263,15 @@ func (node *QueryNode) Init() error {
|
|||
return
|
||||
}
|
||||
|
||||
node.taskPool, err = concurrency.NewPool(cpuNum, ants.WithPreAlloc(true))
|
||||
if err != nil {
|
||||
log.Error("QueryNode init channel pool failed", zap.Error(err))
|
||||
initError = err
|
||||
return
|
||||
}
|
||||
|
||||
node.taskLock = lock.NewKeyLock()
|
||||
|
||||
// ensure every cgopool go routine is locked with a OS thread
|
||||
// so openmp in knowhere won't create too much request
|
||||
sig := make(chan struct{})
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TODO, remove queryShardService, it's not used any more.
|
||||
type queryShardService struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
|
|
@ -20,12 +20,13 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"path"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/panjf2000/ants/v2"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
|
@ -46,6 +47,8 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/util/hardware"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
"github.com/milvus-io/milvus/internal/util/timerecord"
|
||||
"github.com/milvus-io/milvus/internal/util/tsoutil"
|
||||
"github.com/panjf2000/ants/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -703,7 +706,7 @@ func (loader *segmentLoader) loadDeltaLogs(ctx context.Context, segment *Segment
|
|||
}
|
||||
|
||||
func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collectionID int64, position *internalpb.MsgPosition) error {
|
||||
log.Info("from dml check point load delete", zap.Any("position", position))
|
||||
startTs := time.Now()
|
||||
stream, err := loader.factory.NewMsgStream(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -717,7 +720,12 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection
|
|||
pChannelName := funcutil.ToPhysicalChannel(position.ChannelName)
|
||||
position.ChannelName = pChannelName
|
||||
|
||||
stream.AsConsumer([]string{pChannelName}, fmt.Sprintf("querynode-%d-%d", Params.QueryNodeCfg.GetNodeID(), collectionID), mqwrapper.SubscriptionPositionUnknown)
|
||||
ts, _ := tsoutil.ParseTS(position.Timestamp)
|
||||
|
||||
// Random the subname in case we trying to load same delta at the same time
|
||||
subName := fmt.Sprintf("querynode-delta-loader-%d-%d-%d", Params.QueryNodeCfg.GetNodeID(), collectionID, rand.Int())
|
||||
log.Info("from dml check point load delete", zap.Any("position", position), zap.String("subName", subName), zap.Time("positionTs", ts))
|
||||
stream.AsConsumer([]string{pChannelName}, subName, mqwrapper.SubscriptionPositionUnknown)
|
||||
// make sure seek position is earlier than
|
||||
lastMsgID, err := stream.GetLatestMsgID(pChannelName)
|
||||
if err != nil {
|
||||
|
@ -730,7 +738,7 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection
|
|||
}
|
||||
|
||||
if reachLatest || lastMsgID.AtEarliestPosition() {
|
||||
log.Info("there is no more delta msg", zap.Int64("Collection ID", collectionID), zap.String("channel", pChannelName))
|
||||
log.Info("there is no more delta msg", zap.Int64("collectionID", collectionID), zap.String("channel", pChannelName))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -748,7 +756,7 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection
|
|||
}
|
||||
|
||||
log.Info("start read delta msg from seek position to last position",
|
||||
zap.Int64("Collection ID", collectionID), zap.String("channel", pChannelName), zap.Any("seek pos", position), zap.Any("last msg", lastMsgID))
|
||||
zap.Int64("collectionID", collectionID), zap.String("channel", pChannelName), zap.Any("seekPos", position), zap.Any("lastMsg", lastMsgID))
|
||||
hasMore := true
|
||||
for hasMore {
|
||||
select {
|
||||
|
@ -791,7 +799,7 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection
|
|||
ret, err := lastMsgID.LessOrEqualThan(tsMsg.Position().MsgID)
|
||||
if err != nil {
|
||||
log.Warn("check whether current MsgID less than last MsgID failed",
|
||||
zap.Int64("Collection ID", collectionID), zap.String("channel", pChannelName), zap.Error(err))
|
||||
zap.Int64("collectionID", collectionID), zap.String("channel", pChannelName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -803,8 +811,8 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection
|
|||
}
|
||||
}
|
||||
|
||||
log.Info("All data has been read, there is no more data", zap.Int64("Collection ID", collectionID),
|
||||
zap.String("channel", pChannelName), zap.Any("msg id", position.GetMsgID()))
|
||||
log.Info("All data has been read, there is no more data", zap.Int64("collectionID", collectionID),
|
||||
zap.String("channel", pChannelName), zap.Any("msgID", position.GetMsgID()))
|
||||
for segmentID, pks := range delData.deleteIDs {
|
||||
segment, err := loader.metaReplica.getSegmentByID(segmentID, segmentTypeSealed)
|
||||
if err != nil {
|
||||
|
@ -821,7 +829,7 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection
|
|||
}
|
||||
}
|
||||
|
||||
log.Info("from dml check point load done", zap.Any("msg id", position.GetMsgID()))
|
||||
log.Info("from dml check point load done", zap.String("subName", subName), zap.Any("timeTake", time.Since(startTs)))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -19,23 +19,13 @@ package querynode
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"runtime/debug"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
type task interface {
|
||||
|
@ -90,18 +80,6 @@ func (b *baseTask) Notify(err error) {
|
|||
b.done <- err
|
||||
}
|
||||
|
||||
type watchDmChannelsTask struct {
|
||||
baseTask
|
||||
req *queryPb.WatchDmChannelsRequest
|
||||
node *QueryNode
|
||||
}
|
||||
|
||||
type loadSegmentsTask struct {
|
||||
baseTask
|
||||
req *queryPb.LoadSegmentsRequest
|
||||
node *QueryNode
|
||||
}
|
||||
|
||||
type releaseCollectionTask struct {
|
||||
baseTask
|
||||
req *queryPb.ReleaseCollectionRequest
|
||||
|
@ -114,493 +92,6 @@ type releasePartitionsTask struct {
|
|||
node *QueryNode
|
||||
}
|
||||
|
||||
// watchDmChannelsTask
|
||||
func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
|
||||
collectionID := w.req.CollectionID
|
||||
partitionIDs := w.req.GetPartitionIDs()
|
||||
|
||||
lType := w.req.GetLoadMeta().GetLoadType()
|
||||
if lType == queryPb.LoadType_UnKnownType {
|
||||
// if no partitionID is specified, load type is load collection
|
||||
if len(partitionIDs) != 0 {
|
||||
lType = queryPb.LoadType_LoadPartition
|
||||
} else {
|
||||
lType = queryPb.LoadType_LoadCollection
|
||||
}
|
||||
}
|
||||
|
||||
// get all vChannels
|
||||
var vChannels, pChannels []Channel
|
||||
VPChannels := make(map[string]string) // map[vChannel]pChannel
|
||||
for _, info := range w.req.Infos {
|
||||
v := info.ChannelName
|
||||
p := funcutil.ToPhysicalChannel(info.ChannelName)
|
||||
vChannels = append(vChannels, v)
|
||||
pChannels = append(pChannels, p)
|
||||
VPChannels[v] = p
|
||||
}
|
||||
|
||||
if len(VPChannels) != len(vChannels) {
|
||||
return errors.New("get physical channels failed, illegal channel length, collectionID = " + fmt.Sprintln(collectionID))
|
||||
}
|
||||
|
||||
log.Info("Starting WatchDmChannels ...",
|
||||
zap.String("collectionName", w.req.Schema.Name),
|
||||
zap.Int64("collectionID", collectionID),
|
||||
zap.Int64("replicaID", w.req.GetReplicaID()),
|
||||
zap.Any("load type", lType),
|
||||
zap.Strings("vChannels", vChannels),
|
||||
zap.Strings("pChannels", pChannels),
|
||||
)
|
||||
|
||||
// init collection meta
|
||||
coll := w.node.metaReplica.addCollection(collectionID, w.req.Schema)
|
||||
|
||||
loadedChannelCounter := 0
|
||||
for _, toLoadChannel := range vChannels {
|
||||
for _, loadedChannel := range coll.vChannels {
|
||||
if toLoadChannel == loadedChannel {
|
||||
loadedChannelCounter++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check if all channels has been loaded, if YES, should do nothing and return
|
||||
// in case of query coord trigger same watchDmChannelTask on multi
|
||||
if len(vChannels) == loadedChannelCounter {
|
||||
log.Warn("All channel has been loaded, skip this watchDmChannelsTask")
|
||||
return nil
|
||||
}
|
||||
|
||||
//add shard cluster
|
||||
for _, vchannel := range vChannels {
|
||||
w.node.ShardClusterService.addShardCluster(w.req.GetCollectionID(), w.req.GetReplicaID(), vchannel)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
for _, vchannel := range vChannels {
|
||||
w.node.ShardClusterService.releaseShardCluster(vchannel)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// load growing segments
|
||||
unFlushedSegments := make([]*queryPb.SegmentLoadInfo, 0)
|
||||
unFlushedSegmentIDs := make([]UniqueID, 0)
|
||||
for _, info := range w.req.Infos {
|
||||
for _, ufInfoID := range info.GetUnflushedSegmentIds() {
|
||||
// unFlushed segment may not have binLogs, skip loading
|
||||
ufInfo := w.req.GetSegmentInfos()[ufInfoID]
|
||||
if ufInfo == nil {
|
||||
log.Warn("an unflushed segment is not found in segment infos", zap.Int64("segment ID", ufInfoID))
|
||||
continue
|
||||
}
|
||||
if len(ufInfo.GetBinlogs()) > 0 {
|
||||
unFlushedSegments = append(unFlushedSegments, &queryPb.SegmentLoadInfo{
|
||||
SegmentID: ufInfo.ID,
|
||||
PartitionID: ufInfo.PartitionID,
|
||||
CollectionID: ufInfo.CollectionID,
|
||||
BinlogPaths: ufInfo.Binlogs,
|
||||
NumOfRows: ufInfo.NumOfRows,
|
||||
Statslogs: ufInfo.Statslogs,
|
||||
Deltalogs: ufInfo.Deltalogs,
|
||||
InsertChannel: ufInfo.InsertChannel,
|
||||
})
|
||||
unFlushedSegmentIDs = append(unFlushedSegmentIDs, ufInfo.GetID())
|
||||
} else {
|
||||
log.Info("skip segment which binlog is empty", zap.Int64("segmentID", ufInfo.ID))
|
||||
}
|
||||
}
|
||||
}
|
||||
req := &queryPb.LoadSegmentsRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_LoadSegments),
|
||||
commonpbutil.WithMsgID(w.req.Base.MsgID), // use parent task's msgID
|
||||
),
|
||||
Infos: unFlushedSegments,
|
||||
CollectionID: collectionID,
|
||||
Schema: w.req.GetSchema(),
|
||||
LoadMeta: w.req.GetLoadMeta(),
|
||||
}
|
||||
|
||||
// update partition info from unFlushedSegments and loadMeta
|
||||
for _, info := range req.Infos {
|
||||
err = w.node.metaReplica.addPartition(collectionID, info.PartitionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, partitionID := range req.GetLoadMeta().GetPartitionIDs() {
|
||||
err = w.node.metaReplica.addPartition(collectionID, partitionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("loading growing segments in WatchDmChannels...",
|
||||
zap.Int64("collectionID", collectionID),
|
||||
zap.Int64s("unFlushedSegmentIDs", unFlushedSegmentIDs),
|
||||
)
|
||||
err = w.node.loader.LoadSegment(w.ctx, req, segmentTypeGrowing)
|
||||
if err != nil {
|
||||
log.Warn(err.Error())
|
||||
return err
|
||||
}
|
||||
log.Info("successfully load growing segments done in WatchDmChannels",
|
||||
zap.Int64("collectionID", collectionID),
|
||||
zap.Int64s("unFlushedSegmentIDs", unFlushedSegmentIDs),
|
||||
)
|
||||
|
||||
// remove growing segment if watch dmChannels failed
|
||||
defer func() {
|
||||
if err != nil {
|
||||
for _, segmentID := range unFlushedSegmentIDs {
|
||||
w.node.metaReplica.removeSegment(segmentID, segmentTypeGrowing)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// So far, we don't support to enable each node with two different channel
|
||||
consumeSubName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName, collectionID, Params.QueryNodeCfg.GetNodeID())
|
||||
|
||||
// group channels by to seeking or consuming
|
||||
channel2SeekPosition := make(map[string]*internalpb.MsgPosition)
|
||||
|
||||
// for channel with no position
|
||||
channel2AsConsumerPosition := make(map[string]*internalpb.MsgPosition)
|
||||
for _, info := range w.req.Infos {
|
||||
if info.SeekPosition == nil || len(info.SeekPosition.MsgID) == 0 {
|
||||
channel2AsConsumerPosition[info.ChannelName] = info.SeekPosition
|
||||
continue
|
||||
}
|
||||
info.SeekPosition.MsgGroup = consumeSubName
|
||||
channel2SeekPosition[info.ChannelName] = info.SeekPosition
|
||||
}
|
||||
log.Info("watchDMChannel, group channels done", zap.Int64("collectionID", collectionID))
|
||||
|
||||
// add excluded segments for unFlushed segments,
|
||||
// unFlushed segments before check point should be filtered out.
|
||||
unFlushedCheckPointInfos := make([]*datapb.SegmentInfo, 0)
|
||||
for _, info := range w.req.Infos {
|
||||
for _, ufsID := range info.GetUnflushedSegmentIds() {
|
||||
unFlushedCheckPointInfos = append(unFlushedCheckPointInfos, w.req.SegmentInfos[ufsID])
|
||||
}
|
||||
}
|
||||
w.node.metaReplica.addExcludedSegments(collectionID, unFlushedCheckPointInfos)
|
||||
unflushedSegmentIDs := make([]UniqueID, len(unFlushedCheckPointInfos))
|
||||
for i, segInfo := range unFlushedCheckPointInfos {
|
||||
unflushedSegmentIDs[i] = segInfo.GetID()
|
||||
}
|
||||
log.Info("watchDMChannel, add check points info for unflushed segments done",
|
||||
zap.Int64("collectionID", collectionID),
|
||||
zap.Any("unflushedSegmentIDs", unflushedSegmentIDs),
|
||||
)
|
||||
|
||||
// add excluded segments for flushed segments,
|
||||
// flushed segments with later check point than seekPosition should be filtered out.
|
||||
flushedCheckPointInfos := make([]*datapb.SegmentInfo, 0)
|
||||
for _, info := range w.req.Infos {
|
||||
for _, flushedSegmentID := range info.GetFlushedSegmentIds() {
|
||||
flushedSegment := w.req.SegmentInfos[flushedSegmentID]
|
||||
for _, position := range channel2SeekPosition {
|
||||
if flushedSegment.DmlPosition != nil &&
|
||||
flushedSegment.DmlPosition.ChannelName == position.ChannelName &&
|
||||
flushedSegment.DmlPosition.Timestamp > position.Timestamp {
|
||||
flushedCheckPointInfos = append(flushedCheckPointInfos, flushedSegment)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
w.node.metaReplica.addExcludedSegments(collectionID, flushedCheckPointInfos)
|
||||
flushedSegmentIDs := make([]UniqueID, len(flushedCheckPointInfos))
|
||||
for i, segInfo := range flushedCheckPointInfos {
|
||||
flushedSegmentIDs[i] = segInfo.GetID()
|
||||
}
|
||||
log.Info("watchDMChannel, add check points info for flushed segments done",
|
||||
zap.Int64("collectionID", collectionID),
|
||||
zap.Any("flushedSegmentIDs", flushedSegmentIDs),
|
||||
)
|
||||
|
||||
// add excluded segments for dropped segments,
|
||||
// exclude all msgs with dropped segment id
|
||||
// DO NOT refer to dropped segment info, see issue https://github.com/milvus-io/milvus/issues/19704
|
||||
var droppedCheckPointInfos []*datapb.SegmentInfo
|
||||
for _, info := range w.req.Infos {
|
||||
for _, droppedSegmentID := range info.GetDroppedSegmentIds() {
|
||||
droppedCheckPointInfos = append(droppedCheckPointInfos, &datapb.SegmentInfo{
|
||||
ID: droppedSegmentID,
|
||||
CollectionID: collectionID,
|
||||
InsertChannel: info.GetChannelName(),
|
||||
DmlPosition: &internalpb.MsgPosition{
|
||||
ChannelName: info.GetChannelName(),
|
||||
Timestamp: math.MaxUint64,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
w.node.metaReplica.addExcludedSegments(collectionID, droppedCheckPointInfos)
|
||||
droppedSegmentIDs := make([]UniqueID, len(droppedCheckPointInfos))
|
||||
for i, segInfo := range droppedCheckPointInfos {
|
||||
droppedSegmentIDs[i] = segInfo.GetID()
|
||||
}
|
||||
log.Info("watchDMChannel, add check points info for dropped segments done",
|
||||
zap.Int64("collectionID", collectionID),
|
||||
zap.Any("droppedSegmentIDs", droppedSegmentIDs),
|
||||
)
|
||||
|
||||
// add flow graph
|
||||
channel2FlowGraph, err := w.node.dataSyncService.addFlowGraphsForDMLChannels(collectionID, vChannels)
|
||||
if err != nil {
|
||||
log.Warn("watchDMChannel, add flowGraph for dmChannels failed", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
log.Info("Query node add DML flow graphs", zap.Int64("collectionID", collectionID), zap.Any("channels", vChannels))
|
||||
|
||||
// channels as consumer
|
||||
for channel, fg := range channel2FlowGraph {
|
||||
if _, ok := channel2AsConsumerPosition[channel]; ok {
|
||||
// use pChannel to consume
|
||||
err = fg.consumeFlowGraph(VPChannels[channel], consumeSubName)
|
||||
if err != nil {
|
||||
log.Error("msgStream as consumer failed for dmChannels", zap.Int64("collectionID", collectionID), zap.String("vChannel", channel))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if pos, ok := channel2SeekPosition[channel]; ok {
|
||||
pos.MsgGroup = consumeSubName
|
||||
// use pChannel to seek
|
||||
pos.ChannelName = VPChannels[channel]
|
||||
err = fg.consumeFlowGraphFromPosition(pos)
|
||||
if err != nil {
|
||||
log.Error("msgStream seek failed for dmChannels", zap.Int64("collectionID", collectionID), zap.String("vChannel", channel))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Warn("watchDMChannel, add flowGraph for dmChannels failed", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels), zap.Error(err))
|
||||
for _, fg := range channel2FlowGraph {
|
||||
fg.flowGraph.Close()
|
||||
}
|
||||
gcChannels := make([]Channel, 0)
|
||||
for channel := range channel2FlowGraph {
|
||||
gcChannels = append(gcChannels, channel)
|
||||
}
|
||||
w.node.dataSyncService.removeFlowGraphsByDMLChannels(gcChannels)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("watchDMChannel, add flowGraph for dmChannels success", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels))
|
||||
|
||||
coll.addVChannels(vChannels)
|
||||
coll.addPChannels(pChannels)
|
||||
coll.setLoadType(lType)
|
||||
|
||||
log.Info("watchDMChannel, init replica done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels))
|
||||
|
||||
// create tSafe
|
||||
for _, channel := range vChannels {
|
||||
w.node.tSafeReplica.addTSafe(channel)
|
||||
}
|
||||
|
||||
// add tsafe watch in query shard if exists
|
||||
for _, dmlChannel := range vChannels {
|
||||
w.node.queryShardService.addQueryShard(collectionID, dmlChannel, w.req.GetReplicaID())
|
||||
}
|
||||
|
||||
// start flow graphs
|
||||
for _, fg := range channel2FlowGraph {
|
||||
fg.flowGraph.Start()
|
||||
}
|
||||
|
||||
log.Info("WatchDmChannels done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels))
|
||||
return nil
|
||||
}
|
||||
|
||||
// internal helper function to subscribe delta channel
|
||||
func (l *loadSegmentsTask) watchDeltaChannel(vchanName []string) error {
|
||||
collectionID := l.req.CollectionID
|
||||
var vDeltaChannels, pDeltaChannels []string
|
||||
VPDeltaChannels := make(map[string]string)
|
||||
for _, v := range vchanName {
|
||||
dc, err := funcutil.ConvertChannelName(v, Params.CommonCfg.RootCoordDml, Params.CommonCfg.RootCoordDelta)
|
||||
if err != nil {
|
||||
log.Warn("watchDeltaChannels, failed to convert deltaChannel from dmlChannel", zap.String("DmlChannel", v), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
p := funcutil.ToPhysicalChannel(dc)
|
||||
vDeltaChannels = append(vDeltaChannels, dc)
|
||||
pDeltaChannels = append(pDeltaChannels, p)
|
||||
VPDeltaChannels[dc] = p
|
||||
}
|
||||
log.Info("Starting WatchDeltaChannels ...",
|
||||
zap.Int64("collectionID", collectionID),
|
||||
zap.Any("channels", VPDeltaChannels),
|
||||
)
|
||||
|
||||
coll, err := l.node.metaReplica.getCollectionByID(collectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
channel2FlowGraph, err := l.node.dataSyncService.addFlowGraphsForDeltaChannels(collectionID, vDeltaChannels)
|
||||
if err != nil {
|
||||
log.Warn("watchDeltaChannel, add flowGraph for deltaChannel failed", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
consumeSubName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName, collectionID, Params.QueryNodeCfg.GetNodeID())
|
||||
|
||||
// channels as consumer
|
||||
for channel, fg := range channel2FlowGraph {
|
||||
pchannel := VPDeltaChannels[channel]
|
||||
// use pChannel to consume
|
||||
err = fg.consumeFlowGraphFromLatest(pchannel, consumeSubName)
|
||||
if err != nil {
|
||||
log.Error("msgStream as consumer failed for deltaChannels", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Warn("watchDeltaChannel, add flowGraph for deltaChannel failed", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels), zap.Error(err))
|
||||
for _, fg := range channel2FlowGraph {
|
||||
fg.flowGraph.Close()
|
||||
}
|
||||
gcChannels := make([]Channel, 0)
|
||||
for channel := range channel2FlowGraph {
|
||||
gcChannels = append(gcChannels, channel)
|
||||
}
|
||||
l.node.dataSyncService.removeFlowGraphsByDeltaChannels(gcChannels)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("watchDeltaChannel, add flowGraph for deltaChannel success", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels))
|
||||
|
||||
//set collection replica
|
||||
coll.addVDeltaChannels(vDeltaChannels)
|
||||
coll.addPDeltaChannels(pDeltaChannels)
|
||||
|
||||
// create tSafe
|
||||
for _, channel := range vDeltaChannels {
|
||||
l.node.tSafeReplica.addTSafe(channel)
|
||||
}
|
||||
|
||||
// add tsafe watch in query shard if exists, we find no way to handle it if query shard not exist
|
||||
for _, channel := range vDeltaChannels {
|
||||
dmlChannel, err := funcutil.ConvertChannelName(channel, Params.CommonCfg.RootCoordDelta, Params.CommonCfg.RootCoordDml)
|
||||
if err != nil {
|
||||
log.Error("failed to convert delta channel to dml", zap.String("channel", channel), zap.Error(err))
|
||||
panic(err)
|
||||
}
|
||||
err = l.node.queryShardService.addQueryShard(collectionID, dmlChannel, l.req.GetReplicaID())
|
||||
if err != nil {
|
||||
log.Error("failed to add shard Service to query shard", zap.String("channel", channel), zap.Error(err))
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// start flow graphs
|
||||
for _, fg := range channel2FlowGraph {
|
||||
fg.flowGraph.Start()
|
||||
}
|
||||
|
||||
log.Info("WatchDeltaChannels done", zap.Int64("collectionID", collectionID), zap.String("ChannelIDs", fmt.Sprintln(vDeltaChannels)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadSegmentsTask
|
||||
func (l *loadSegmentsTask) PreExecute(ctx context.Context) error {
|
||||
log.Info("LoadSegmentTask PreExecute start", zap.Int64("msgID", l.req.Base.MsgID))
|
||||
var err error
|
||||
// init meta
|
||||
collectionID := l.req.GetCollectionID()
|
||||
l.node.metaReplica.addCollection(collectionID, l.req.GetSchema())
|
||||
for _, partitionID := range l.req.GetLoadMeta().GetPartitionIDs() {
|
||||
err = l.node.metaReplica.addPartition(collectionID, partitionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// filter segments that are already loaded in this querynode
|
||||
var filteredInfos []*queryPb.SegmentLoadInfo
|
||||
for _, info := range l.req.Infos {
|
||||
has, err := l.node.metaReplica.hasSegment(info.SegmentID, segmentTypeSealed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !has {
|
||||
filteredInfos = append(filteredInfos, info)
|
||||
} else {
|
||||
log.Debug("ignore segment that is already loaded", zap.Int64("collectionID", info.SegmentID), zap.Int64("segmentID", info.SegmentID))
|
||||
}
|
||||
}
|
||||
l.req.Infos = filteredInfos
|
||||
log.Info("LoadSegmentTask PreExecute done", zap.Int64("msgID", l.req.Base.MsgID))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *loadSegmentsTask) Execute(ctx context.Context) error {
|
||||
log.Info("LoadSegmentTask Execute start", zap.Int64("msgID", l.req.Base.MsgID))
|
||||
|
||||
if len(l.req.Infos) == 0 {
|
||||
log.Info("all segments loaded",
|
||||
zap.Int64("msgID", l.req.GetBase().GetMsgID()))
|
||||
return nil
|
||||
}
|
||||
|
||||
segmentIDs := lo.Map(l.req.Infos, func(info *queryPb.SegmentLoadInfo, idx int) UniqueID { return info.SegmentID })
|
||||
l.node.metaReplica.addSegmentsLoadingList(segmentIDs)
|
||||
defer l.node.metaReplica.removeSegmentsLoadingList(segmentIDs)
|
||||
err := l.node.loader.LoadSegment(l.ctx, l.req, segmentTypeSealed)
|
||||
if err != nil {
|
||||
log.Warn("failed to load segment", zap.Int64("collectionID", l.req.CollectionID),
|
||||
zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
vchanName := make([]string, 0)
|
||||
for _, deltaPosition := range l.req.DeltaPositions {
|
||||
vchanName = append(vchanName, deltaPosition.ChannelName)
|
||||
}
|
||||
// TODO delta channel need to released 1. if other watchDeltaChannel fail 2. when segment release
|
||||
err = l.watchDeltaChannel(vchanName)
|
||||
if err != nil {
|
||||
// roll back
|
||||
for _, segment := range l.req.Infos {
|
||||
l.node.metaReplica.removeSegment(segment.SegmentID, segmentTypeSealed)
|
||||
}
|
||||
log.Warn("failed to watch Delta channel while load segment", zap.Int64("collectionID", l.req.CollectionID),
|
||||
zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
	runningGroup, groupCtx := errgroup.WithContext(l.ctx)
	for _, deltaPosition := range l.req.DeltaPositions {
		pos := deltaPosition
		runningGroup.Go(func() error {
			// reload data from dml channel
			return l.node.loader.FromDmlCPLoadDelete(groupCtx, l.req.CollectionID, pos)
		})
	}
	err = runningGroup.Wait()
	if err != nil {
		for _, segment := range l.req.Infos {
			l.node.metaReplica.removeSegment(segment.SegmentID, segmentTypeSealed)
		}
		log.Warn("failed to load delete data while load segment", zap.Int64("collectionID", l.req.CollectionID),
			zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err))
		return err
	}

	log.Info("LoadSegmentTask Execute done", zap.Int64("collectionID", l.req.CollectionID),
		zap.Int64("replicaID", l.req.ReplicaID), zap.Int64("msgID", l.req.Base.MsgID))
	return nil
}

func (r *releaseCollectionTask) Execute(ctx context.Context) error {
	log.Info("Execute release collection task", zap.Any("collectionID", r.req.CollectionID))

@ -0,0 +1,354 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package querynode

import (
	"context"
	"errors"
	"fmt"
	"math"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus-proto/go-api/commonpb"
	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/proto/datapb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/util/commonpbutil"
	"github.com/milvus-io/milvus/internal/util/funcutil"
)

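// watchDmChannelsTask subscribes the dml channels of a collection on this
// querynode and loads the corresponding growing segments.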
type watchDmChannelsTask struct {
	baseTask
	req  *queryPb.WatchDmChannelsRequest
	node *QueryNode
}

// Execute runs the watchDmChannelsTask: it registers channels and shard
// clusters, loads growing segments, and starts the dml flow graphs.
func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
	collectionID := w.req.CollectionID
	partitionIDs := w.req.GetPartitionIDs()

	lType := w.req.GetLoadMeta().GetLoadType()
	if lType == queryPb.LoadType_UnKnownType {
		// if no partitionID is specified, the load type is load collection
		if len(partitionIDs) != 0 {
			lType = queryPb.LoadType_LoadPartition
		} else {
			lType = queryPb.LoadType_LoadCollection
		}
	}

	// get all vChannels
	var vChannels []Channel
	VPChannels := make(map[string]string) // map[vChannel]pChannel
	for _, info := range w.req.Infos {
		v := info.ChannelName
		p := funcutil.ToPhysicalChannel(info.ChannelName)
		vChannels = append(vChannels, v)
		VPChannels[v] = p
	}

	if len(VPChannels) != len(vChannels) {
		return errors.New("get physical channels failed, illegal channel length, collectionID = " + fmt.Sprintln(collectionID))
	}

	log.Info("Starting WatchDmChannels ...",
		zap.String("collectionName", w.req.Schema.Name),
		zap.Int64("collectionID", collectionID),
		zap.Int64("replicaID", w.req.GetReplicaID()),
		zap.String("load type", lType.String()),
		zap.Strings("vChannels", vChannels),
	)

	// init collection meta
	coll := w.node.metaReplica.addCollection(collectionID, w.req.Schema)

	// filter out the channels that already exist
	vChannels = coll.AddChannels(vChannels, VPChannels)
	defer func() {
		if err != nil {
			for _, vChannel := range vChannels {
				coll.removeVChannel(vChannel)
			}
		}
	}()

	if len(vChannels) == 0 {
		log.Warn("all channels have been added before, ignore the watch dml request")
		return nil
	}

	// add shard clusters
	for _, vchannel := range vChannels {
		w.node.ShardClusterService.addShardCluster(w.req.GetCollectionID(), w.req.GetReplicaID(), vchannel)
	}

	defer func() {
		if err != nil {
			for _, vchannel := range vChannels {
				w.node.ShardClusterService.releaseShardCluster(vchannel)
			}
		}
	}()

	unFlushedSegmentIDs, err := w.LoadGrowingSegments(ctx, collectionID)
	if err != nil {
		return err
	}

	// remove growing segments if watching the dmChannels failed
	defer func() {
		if err != nil {
			for _, segmentID := range unFlushedSegmentIDs {
				w.node.metaReplica.removeSegment(segmentID, segmentTypeGrowing)
			}
		}
	}()

	channel2FlowGraph, err := w.initFlowGraph(ctx, collectionID, vChannels, VPChannels)
	if err != nil {
		return err
	}

	coll.setLoadType(lType)

	log.Info("watchDMChannel, init replica done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels))

	// create tSafe
	for _, channel := range vChannels {
		w.node.tSafeReplica.addTSafe(channel)
	}

	// add tsafe watch in query shard if exists
	for _, dmlChannel := range vChannels {
		w.node.queryShardService.addQueryShard(collectionID, dmlChannel, w.req.GetReplicaID())
	}

	// start flow graphs
	for _, fg := range channel2FlowGraph {
		fg.flowGraph.Start()
	}

	log.Info("WatchDmChannels done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels))
	return nil
}

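// LoadGrowingSegments loads the unflushed segments that already have binlogs
// as growing segments and returns their IDs.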
func (w *watchDmChannelsTask) LoadGrowingSegments(ctx context.Context, collectionID UniqueID) ([]UniqueID, error) {
	// load growing segments
	unFlushedSegments := make([]*queryPb.SegmentLoadInfo, 0)
	unFlushedSegmentIDs := make([]UniqueID, 0)
	for _, info := range w.req.Infos {
		for _, ufInfoID := range info.GetUnflushedSegmentIds() {
			// an unflushed segment may not have binlogs yet, skip loading it
			ufInfo := w.req.GetSegmentInfos()[ufInfoID]
			if ufInfo == nil {
				log.Warn("an unflushed segment is not found in segment infos", zap.Int64("segment ID", ufInfoID))
				continue
			}
			if len(ufInfo.GetBinlogs()) > 0 {
				unFlushedSegments = append(unFlushedSegments, &queryPb.SegmentLoadInfo{
					SegmentID:     ufInfo.ID,
					PartitionID:   ufInfo.PartitionID,
					CollectionID:  ufInfo.CollectionID,
					BinlogPaths:   ufInfo.Binlogs,
					NumOfRows:     ufInfo.NumOfRows,
					Statslogs:     ufInfo.Statslogs,
					Deltalogs:     ufInfo.Deltalogs,
					InsertChannel: ufInfo.InsertChannel,
				})
				unFlushedSegmentIDs = append(unFlushedSegmentIDs, ufInfo.GetID())
			} else {
				log.Info("skip segment whose binlog is empty", zap.Int64("segmentID", ufInfo.ID))
			}
		}
	}
	req := &queryPb.LoadSegmentsRequest{
		Base: commonpbutil.NewMsgBase(
			commonpbutil.WithMsgType(commonpb.MsgType_LoadSegments),
			commonpbutil.WithMsgID(w.req.Base.MsgID), // use the parent task's msgID
		),
		Infos:        unFlushedSegments,
		CollectionID: collectionID,
		Schema:       w.req.GetSchema(),
		LoadMeta:     w.req.GetLoadMeta(),
	}

	// update partition info from unFlushedSegments and loadMeta
	for _, info := range req.Infos {
		err := w.node.metaReplica.addPartition(collectionID, info.PartitionID)
		if err != nil {
			return nil, err
		}
	}
	for _, partitionID := range req.GetLoadMeta().GetPartitionIDs() {
		err := w.node.metaReplica.addPartition(collectionID, partitionID)
		if err != nil {
			return nil, err
		}
	}

	log.Info("loading growing segments in WatchDmChannels...",
		zap.Int64("collectionID", collectionID),
		zap.Int64s("unFlushedSegmentIDs", unFlushedSegmentIDs),
	)
	err := w.node.loader.LoadSegment(w.ctx, req, segmentTypeGrowing)
	if err != nil {
		log.Warn("failed to load segment", zap.Int64("collection", collectionID), zap.Error(err))
		return nil, err
	}
	log.Info("load growing segments done in WatchDmChannels",
		zap.Int64("collectionID", collectionID),
		zap.Int64s("unFlushedSegmentIDs", unFlushedSegmentIDs),
	)
	return unFlushedSegmentIDs, nil
}

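// initFlowGraph creates a flow graph per dml channel, registers excluded
// segments, and sets each flow graph to consume or seek on its physical channel.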
func (w *watchDmChannelsTask) initFlowGraph(ctx context.Context, collectionID UniqueID, vChannels []Channel, VPChannels map[string]string) (map[string]*queryNodeFlowGraph, error) {
	// So far, we don't support enabling each node with two different channel subscriptions
	consumeSubName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName, collectionID, Params.QueryNodeCfg.GetNodeID())

	// group channels by whether they need to seek or just consume
	channel2SeekPosition := make(map[string]*internalpb.MsgPosition)

	// for channels with no seek position
	channel2AsConsumerPosition := make(map[string]*internalpb.MsgPosition)
	for _, info := range w.req.Infos {
		if info.SeekPosition == nil || len(info.SeekPosition.MsgID) == 0 {
			channel2AsConsumerPosition[info.ChannelName] = info.SeekPosition
			continue
		}
		info.SeekPosition.MsgGroup = consumeSubName
		channel2SeekPosition[info.ChannelName] = info.SeekPosition
	}
	log.Info("watchDMChannel, group channels done", zap.Int64("collectionID", collectionID))

	// add excluded segments for unflushed segments;
	// unflushed segments before the check point should be filtered out.
	unFlushedCheckPointInfos := make([]*datapb.SegmentInfo, 0)
	for _, info := range w.req.Infos {
		for _, ufsID := range info.GetUnflushedSegmentIds() {
			unFlushedCheckPointInfos = append(unFlushedCheckPointInfos, w.req.SegmentInfos[ufsID])
		}
	}
	w.node.metaReplica.addExcludedSegments(collectionID, unFlushedCheckPointInfos)
	unflushedSegmentIDs := make([]UniqueID, len(unFlushedCheckPointInfos))
	for i, segInfo := range unFlushedCheckPointInfos {
		unflushedSegmentIDs[i] = segInfo.GetID()
	}
	log.Info("watchDMChannel, add check points info for unflushed segments done",
		zap.Int64("collectionID", collectionID),
		zap.Any("unflushedSegmentIDs", unflushedSegmentIDs),
	)

	// add excluded segments for flushed segments;
	// flushed segments with a later check point than the seekPosition should be filtered out.
	flushedCheckPointInfos := make([]*datapb.SegmentInfo, 0)
	for _, info := range w.req.Infos {
		for _, flushedSegmentID := range info.GetFlushedSegmentIds() {
			flushedSegment := w.req.SegmentInfos[flushedSegmentID]
			for _, position := range channel2SeekPosition {
				if flushedSegment.DmlPosition != nil &&
					flushedSegment.DmlPosition.ChannelName == position.ChannelName &&
					flushedSegment.DmlPosition.Timestamp > position.Timestamp {
					flushedCheckPointInfos = append(flushedCheckPointInfos, flushedSegment)
				}
			}
		}
	}
	w.node.metaReplica.addExcludedSegments(collectionID, flushedCheckPointInfos)
	flushedSegmentIDs := make([]UniqueID, len(flushedCheckPointInfos))
	for i, segInfo := range flushedCheckPointInfos {
		flushedSegmentIDs[i] = segInfo.GetID()
	}
	log.Info("watchDMChannel, add check points info for flushed segments done",
		zap.Int64("collectionID", collectionID),
		zap.Any("flushedSegmentIDs", flushedSegmentIDs),
	)

	// add excluded segments for dropped segments;
	// exclude all msgs with a dropped segment id.
	// DO NOT refer to dropped segment info, see issue https://github.com/milvus-io/milvus/issues/19704
	var droppedCheckPointInfos []*datapb.SegmentInfo
	for _, info := range w.req.Infos {
		for _, droppedSegmentID := range info.GetDroppedSegmentIds() {
			droppedCheckPointInfos = append(droppedCheckPointInfos, &datapb.SegmentInfo{
				ID:            droppedSegmentID,
				CollectionID:  collectionID,
				InsertChannel: info.GetChannelName(),
				DmlPosition: &internalpb.MsgPosition{
					ChannelName: info.GetChannelName(),
					Timestamp:   math.MaxUint64,
				},
			})
		}
	}
	w.node.metaReplica.addExcludedSegments(collectionID, droppedCheckPointInfos)
	droppedSegmentIDs := make([]UniqueID, len(droppedCheckPointInfos))
	for i, segInfo := range droppedCheckPointInfos {
		droppedSegmentIDs[i] = segInfo.GetID()
	}
	log.Info("watchDMChannel, add check points info for dropped segments done",
		zap.Int64("collectionID", collectionID),
		zap.Any("droppedSegmentIDs", droppedSegmentIDs),
	)

	// add flow graphs
	channel2FlowGraph, err := w.node.dataSyncService.addFlowGraphsForDMLChannels(collectionID, vChannels)
	if err != nil {
		log.Warn("watchDMChannel, add flowGraph for dmChannels failed", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels), zap.Error(err))
		return nil, err
	}
	log.Info("Query node add DML flow graphs", zap.Int64("collectionID", collectionID), zap.Any("channels", vChannels))

	// channels as consumers
	for channel, fg := range channel2FlowGraph {
		if _, ok := channel2AsConsumerPosition[channel]; ok {
			// use pChannel to consume
			err = fg.consumeFlowGraph(VPChannels[channel], consumeSubName)
			if err != nil {
				log.Error("msgStream as consumer failed for dmChannels", zap.Int64("collectionID", collectionID), zap.String("vChannel", channel))
				break
			}
		}

		if pos, ok := channel2SeekPosition[channel]; ok {
			pos.MsgGroup = consumeSubName
			// use pChannel to seek
			pos.ChannelName = VPChannels[channel]
			err = fg.consumeFlowGraphFromPosition(pos)
			if err != nil {
				log.Error("msgStream seek failed for dmChannels", zap.Int64("collectionID", collectionID), zap.String("vChannel", channel))
				break
			}
		}
	}

	if err != nil {
		log.Warn("watchDMChannel, add flowGraph for dmChannels failed", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels), zap.Error(err))
		for _, fg := range channel2FlowGraph {
			fg.flowGraph.Close()
		}
		gcChannels := make([]Channel, 0)
		for channel := range channel2FlowGraph {
			gcChannels = append(gcChannels, channel)
		}
		w.node.dataSyncService.removeFlowGraphsByDMLChannels(gcChannels)
		return nil, err
	}

	log.Info("watchDMChannel, add flowGraph for dmChannels success", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels))
	return channel2FlowGraph, nil
}

@ -0,0 +1,131 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package lock

import (
	"sync"

	"github.com/milvus-io/milvus/internal/log"
	"go.uber.org/zap"
)

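// RefLock is a reference-counted RWMutex; the counter tracks how many callers
// are holding or waiting on the lock so KeyLock knows when the entry for a key
// can be dropped.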
type RefLock struct {
	mutex      sync.RWMutex
	refCounter int
}

func (m *RefLock) ref() {
	m.refCounter++
}

func (m *RefLock) unref() {
	m.refCounter--
}

func newRefLock() *RefLock {
	c := RefLock{
		sync.RWMutex{},
		0,
	}
	return &c
}

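// KeyLock provides fine-grained locking by string key: locks are created on
// demand and removed again once their reference count drops to zero.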
type KeyLock struct {
	keyLocksMutex sync.Mutex
	refLocks      map[string]*RefLock
}

func NewKeyLock() *KeyLock {
	keyLock := KeyLock{
		refLocks: make(map[string]*RefLock),
	}
	return &keyLock
}

func (k *KeyLock) Lock(key string) {
	k.keyLocksMutex.Lock()
	// update the key map
	if keyLock, ok := k.refLocks[key]; ok {
		keyLock.ref()

		// release the map mutex before blocking on this key's lock,
		// so operations on other keys are not blocked
		k.keyLocksMutex.Unlock()
		keyLock.mutex.Lock()
	} else {
		newKLock := newRefLock()
		newKLock.mutex.Lock()
		k.refLocks[key] = newKLock
		newKLock.ref()

		k.keyLocksMutex.Unlock()
		return
	}
}

func (k *KeyLock) Unlock(lockedKey string) {
	k.keyLocksMutex.Lock()
	defer k.keyLocksMutex.Unlock()
	keyLock, ok := k.refLocks[lockedKey]
	if !ok {
		log.Warn("Unlocking non-existing key", zap.String("key", lockedKey))
		return
	}
	keyLock.unref()
	if keyLock.refCounter == 0 {
		delete(k.refLocks, lockedKey)
	}
	keyLock.mutex.Unlock()
}

func (k *KeyLock) RLock(key string) {
	k.keyLocksMutex.Lock()
	// update the key map
	if keyLock, ok := k.refLocks[key]; ok {
		keyLock.ref()

		k.keyLocksMutex.Unlock()
		keyLock.mutex.RLock()
	} else {
		newKLock := newRefLock()
		newKLock.mutex.RLock()
		k.refLocks[key] = newKLock
		newKLock.ref()

		k.keyLocksMutex.Unlock()
		return
	}
}

func (k *KeyLock) RUnlock(lockedKey string) {
	k.keyLocksMutex.Lock()
	defer k.keyLocksMutex.Unlock()
	keyLock, ok := k.refLocks[lockedKey]
	if !ok {
		log.Warn("RUnlocking non-existing key", zap.String("key", lockedKey))
		return
	}
	keyLock.unref()
	if keyLock.refCounter == 0 {
		delete(k.refLocks, lockedKey)
	}
	keyLock.mutex.RUnlock()
}

func (k *KeyLock) size() int {
	k.keyLocksMutex.Lock()
	defer k.keyLocksMutex.Unlock()
	return len(k.refLocks)
}

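A minimal usage sketch, not part of this diff: it shows how a caller could use KeyLock to serialize concurrent loads of the same segment while loads of different segments run in parallel. Only NewKeyLock, Lock and Unlock come from the code above; the import path and the segment IDs are assumptions for illustration.

package main

import (
	"fmt"
	"sync"

	// assumed location of the lock package added in this change
	"github.com/milvus-io/milvus/internal/util/lock"
)

func main() {
	keyLock := lock.NewKeyLock()
	// two goroutines target the same segment, one targets a different segment
	segmentIDs := []string{"segment-100", "segment-100", "segment-200"}

	var wg sync.WaitGroup
	for _, segID := range segmentIDs {
		wg.Add(1)
		go func(id string) {
			defer wg.Done()
			keyLock.Lock(id) // loads of the same id are serialized
			defer keyLock.Unlock(id)
			fmt.Println("loading", id) // placeholder for the real load work
		}(segID)
	}
	wg.Wait()
}
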
@ -0,0 +1,69 @@
package lock

import (
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func TestKeyLock(t *testing.T) {
	keys := []string{"Milvus", "Blazing", "Fast"}

	keyLock := NewKeyLock()

	keyLock.Lock(keys[0])
	keyLock.Lock(keys[1])
	keyLock.Lock(keys[2])

	// should work
	wg := sync.WaitGroup{}
	wg.Add(2)
	go func() {
		keyLock.Lock(keys[0])
		keyLock.Unlock(keys[0])
		wg.Done()
	}()

	go func() {
		keyLock.Lock(keys[0])
		keyLock.Unlock(keys[0])
		wg.Done()
	}()

	assert.Equal(t, keyLock.size(), 3)

	time.Sleep(10 * time.Millisecond)
	keyLock.Unlock(keys[0])
	keyLock.Unlock(keys[1])
	keyLock.Unlock(keys[2])
	wg.Wait()

	assert.Equal(t, keyLock.size(), 0)
}

func TestKeyRLock(t *testing.T) {
	keys := []string{"Milvus", "Blazing", "Fast"}

	keyLock := NewKeyLock()

	keyLock.RLock(keys[0])
	keyLock.RLock(keys[0])

	// should work
	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		keyLock.Lock(keys[0])
		keyLock.Unlock(keys[0])
		wg.Done()
	}()

	time.Sleep(10 * time.Millisecond)
	keyLock.RUnlock(keys[0])
	keyLock.RUnlock(keys[0])

	wg.Wait()
	assert.Equal(t, keyLock.size(), 0)
}