Dynamic add tSafe watcher, use sync.Cond instead of selectCase, add ref count (#8050)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/8478/head
bigsheeper 2021-09-24 13:57:54 +08:00 committed by GitHub
parent 43432f47d7
commit 52126f2d5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 368 additions and 112 deletions

View File

@ -179,9 +179,15 @@ func (dsService *dataSyncService) removePartitionFlowGraph(partitionID UniqueID)
defer dsService.mu.Unlock() defer dsService.mu.Unlock()
if _, ok := dsService.partitionFlowGraphs[partitionID]; ok { if _, ok := dsService.partitionFlowGraphs[partitionID]; ok {
for _, nodeFG := range dsService.partitionFlowGraphs[partitionID] { for channel, nodeFG := range dsService.partitionFlowGraphs[partitionID] {
// close flow graph // close flow graph
nodeFG.close() nodeFG.close()
// remove tSafe record
// no tSafe in tSafeReplica, don't return error
err := dsService.tSafeReplica.removeRecord(channel, partitionID)
if err != nil {
log.Warn(err.Error())
}
} }
dsService.partitionFlowGraphs[partitionID] = nil dsService.partitionFlowGraphs[partitionID] = nil
} }

View File

@ -213,3 +213,25 @@ func TestDataSyncService_partitionFlowGraphs(t *testing.T) {
assert.Nil(t, fg) assert.Nil(t, fg)
assert.Error(t, err) assert.Error(t, err)
} }
// TestDataSyncService_removePartitionFlowGraphs checks that removing a
// partition flow graph succeeds even when the channel's tSafe has already
// been removed from the replica (removal must tolerate a missing record).
func TestDataSyncService_removePartitionFlowGraphs(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	t.Run("test no tSafe", func(t *testing.T) {
		streaming, err := genSimpleStreaming(ctx)
		assert.NoError(t, err)

		msgFactory, err := genFactory()
		assert.NoError(t, err)

		dss := newDataSyncService(ctx, streaming.replica, streaming.tSafeReplica, msgFactory)
		assert.NotNil(t, dss)

		// Build a partition flow graph, then drop its tSafe before removal.
		dss.addPartitionFlowGraph(defaultPartitionID, defaultPartitionID, []Channel{defaultVChannel})
		assert.NoError(t, dss.tSafeReplica.removeTSafe(defaultVChannel))
		dss.removePartitionFlowGraph(defaultPartitionID)
	})
}

View File

@ -65,7 +65,10 @@ func (stNode *serviceTimeNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
} else { } else {
id = stNode.collectionID id = stNode.collectionID
} }
stNode.tSafeReplica.setTSafe(stNode.vChannel, id, serviceTimeMsg.timeRange.timestampMax) err := stNode.tSafeReplica.setTSafe(stNode.vChannel, id, serviceTimeMsg.timeRange.timestampMax)
if err != nil {
log.Warn(err.Error())
}
//log.Debug("update tSafe:", //log.Debug("update tSafe:",
// zap.Int64("tSafe", int64(serviceTimeMsg.timeRange.timestampMax)), // zap.Int64("tSafe", int64(serviceTimeMsg.timeRange.timestampMax)),
// zap.Any("collectionID", stNode.collectionID), // zap.Any("collectionID", stNode.collectionID),

View File

@ -77,4 +77,18 @@ func TestServiceTimeNode_Operate(t *testing.T) {
in := []flowgraph.Msg{msg, msg} in := []flowgraph.Msg{msg, msg}
node.Operate(in) node.Operate(in)
}) })
t.Run("test no tSafe", func(t *testing.T) {
node := genServiceTimeNode()
err := node.tSafeReplica.removeTSafe(defaultVChannel)
assert.NoError(t, err)
msg := &serviceTimeMsg{
timeRange: TimeRange{
timestampMin: 0,
timestampMax: 1000,
},
}
in := []flowgraph.Msg{msg, msg}
node.Operate(in)
})
} }

View File

@ -106,12 +106,26 @@ func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQuery
// add search collection // add search collection
if !node.queryService.hasQueryCollection(collectionID) { if !node.queryService.hasQueryCollection(collectionID) {
node.queryService.addQueryCollection(collectionID) err := node.queryService.addQueryCollection(collectionID)
if err != nil {
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}
return status, err
}
log.Debug("add query collection", zap.Any("collectionID", collectionID)) log.Debug("add query collection", zap.Any("collectionID", collectionID))
} }
// add request channel // add request channel
sc := node.queryService.queryCollections[in.CollectionID] sc, err := node.queryService.getQueryCollection(in.CollectionID)
if err != nil {
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}
return status, err
}
consumeChannels := []string{in.RequestChannelID} consumeChannels := []string{in.RequestChannelID}
//consumeSubName := Params.MsgChannelSubName //consumeSubName := Params.MsgChannelSubName
consumeSubName := Params.MsgChannelSubName + "-" + strconv.FormatInt(collectionID, 10) + "-" + strconv.Itoa(rand.Int()) consumeSubName := Params.MsgChannelSubName + "-" + strconv.FormatInt(collectionID, 10) + "-" + strconv.Itoa(rand.Int())

View File

@ -17,13 +17,14 @@ import (
"math/rand" "math/rand"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/milvuspb"
queryPb "github.com/milvus-io/milvus/internal/proto/querypb" queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/stretchr/testify/assert"
) )
func TestImpl_GetComponentStates(t *testing.T) { func TestImpl_GetComponentStates(t *testing.T) {
@ -108,6 +109,26 @@ func TestImpl_AddQueryChannel(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
}) })
t.Run("test add query collection failed", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
err = node.streaming.replica.removeCollection(defaultCollectionID)
assert.NoError(t, err)
req := &queryPb.AddQueryChannelRequest{
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
NodeID: 0,
CollectionID: defaultCollectionID,
RequestChannelID: genQueryChannel(),
ResultChannelID: genQueryResultChannel(),
}
status, err := node.AddQueryChannel(ctx, req)
assert.Error(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
})
} }
func TestImpl_RemoveQueryChannel(t *testing.T) { func TestImpl_RemoveQueryChannel(t *testing.T) {

View File

@ -14,9 +14,9 @@ package querynode
import ( import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"math" "math"
"reflect"
"sync" "sync"
"unsafe" "unsafe"
@ -50,15 +50,16 @@ type queryCollection struct {
cancel context.CancelFunc cancel context.CancelFunc
collectionID UniqueID collectionID UniqueID
collection *Collection
historical *historical historical *historical
streaming *streaming streaming *streaming
unsolvedMsgMu sync.Mutex // guards unsolvedMsg unsolvedMsgMu sync.Mutex // guards unsolvedMsg
unsolvedMsg []queryMsg unsolvedMsg []queryMsg
tSafeWatchersMu sync.Mutex // guards tSafeWatchers
tSafeWatchers map[Channel]*tSafeWatcher tSafeWatchers map[Channel]*tSafeWatcher
watcherSelectCase []reflect.SelectCase tSafeUpdate bool
watcherCond *sync.Cond
serviceableTimeMutex sync.Mutex // guards serviceableTime serviceableTimeMutex sync.Mutex // guards serviceableTime
serviceableTime Timestamp serviceableTime Timestamp
@ -83,25 +84,26 @@ func newQueryCollection(releaseCtx context.Context,
localChunkManager storage.ChunkManager, localChunkManager storage.ChunkManager,
remoteChunkManager storage.ChunkManager, remoteChunkManager storage.ChunkManager,
localCacheEnabled bool, localCacheEnabled bool,
) *queryCollection { ) (*queryCollection, error) {
unsolvedMsg := make([]queryMsg, 0) unsolvedMsg := make([]queryMsg, 0)
queryStream, _ := factory.NewQueryMsgStream(releaseCtx) queryStream, _ := factory.NewQueryMsgStream(releaseCtx)
queryResultStream, _ := factory.NewQueryMsgStream(releaseCtx) queryResultStream, _ := factory.NewQueryMsgStream(releaseCtx)
collection, _ := streaming.replica.getCollectionByID(collectionID) condMu := sync.Mutex{}
qc := &queryCollection{ qc := &queryCollection{
releaseCtx: releaseCtx, releaseCtx: releaseCtx,
cancel: cancel, cancel: cancel,
collectionID: collectionID, collectionID: collectionID,
collection: collection,
historical: historical, historical: historical,
streaming: streaming, streaming: streaming,
tSafeWatchers: make(map[Channel]*tSafeWatcher), tSafeWatchers: make(map[Channel]*tSafeWatcher),
tSafeUpdate: false,
watcherCond: sync.NewCond(&condMu),
unsolvedMsg: unsolvedMsg, unsolvedMsg: unsolvedMsg,
@ -113,8 +115,11 @@ func newQueryCollection(releaseCtx context.Context,
localCacheEnabled: localCacheEnabled, localCacheEnabled: localCacheEnabled,
} }
qc.register() err := qc.registerCollectionTSafe()
return qc if err != nil {
return nil, err
}
return qc, nil
} }
func (q *queryCollection) start() { func (q *queryCollection) start() {
@ -133,26 +138,61 @@ func (q *queryCollection) close() {
} }
} }
func (q *queryCollection) register() { // registerCollectionTSafe registers tSafe watcher if vChannels exists
func (q *queryCollection) registerCollectionTSafe() error {
collection, err := q.streaming.replica.getCollectionByID(q.collectionID) collection, err := q.streaming.replica.getCollectionByID(q.collectionID)
if err != nil { if err != nil {
log.Warn(err.Error()) return err
return
} }
//TODO:: can't add new vChannel to selectCase
q.watcherSelectCase = make([]reflect.SelectCase, 0)
log.Debug("register tSafe watcher and init watcher select case", log.Debug("register tSafe watcher and init watcher select case",
zap.Any("collectionID", collection.ID()), zap.Any("collectionID", collection.ID()),
zap.Any("dml channels", collection.getVChannels()), zap.Any("dml channels", collection.getVChannels()),
) )
for _, channel := range collection.getVChannels() { for _, channel := range collection.getVChannels() {
q.tSafeWatchers[channel] = newTSafeWatcher() err = q.addTSafeWatcher(channel)
q.streaming.tSafeReplica.registerTSafeWatcher(channel, q.tSafeWatchers[channel]) if err != nil {
q.watcherSelectCase = append(q.watcherSelectCase, reflect.SelectCase{ return err
Dir: reflect.SelectRecv, }
Chan: reflect.ValueOf(q.tSafeWatchers[channel].watcherChan()), }
}) return nil
}
// addTSafeWatcher registers a tSafe watcher for vChannel on the streaming
// tSafeReplica and spawns a goroutine that forwards tSafe update
// notifications to the queryCollection's watcherCond.
// It returns an error if a watcher for the channel already exists or if the
// replica rejects the registration (e.g. the channel has no tSafe).
func (q *queryCollection) addTSafeWatcher(vChannel Channel) error {
	q.tSafeWatchersMu.Lock()
	defer q.tSafeWatchersMu.Unlock()
	if _, ok := q.tSafeWatchers[vChannel]; ok {
		// Registering twice would leak a second watcher goroutine.
		return fmt.Errorf("tSafeWatcher of queryCollection already exists, collectionID = %d, channel = %s",
			q.collectionID, vChannel)
	}
	// Register first, and only record the watcher on success: storing it in
	// the map before registration would leave a stale entry behind if
	// registerTSafeWatcher failed, blocking any later retry for this channel.
	watcher := newTSafeWatcher()
	if err := q.streaming.tSafeReplica.registerTSafeWatcher(vChannel, watcher); err != nil {
		return err
	}
	q.tSafeWatchers[vChannel] = watcher
	log.Debug("add tSafeWatcher to queryCollection",
		zap.Any("collectionID", q.collectionID),
		zap.Any("channel", vChannel),
	)
	// The goroutine exits when releaseCtx is cancelled (see startWatcher).
	go q.startWatcher(watcher.watcherChan())
	return nil
}
// TODO: add stopWatcher(), add close() to tSafeWatcher
func (q *queryCollection) startWatcher(channel <-chan bool) {
for {
select {
case <-q.releaseCtx.Done():
return
case <-channel:
// TODO: check if channel is closed
q.watcherCond.L.Lock()
q.tSafeUpdate = true
q.watcherCond.Broadcast()
q.watcherCond.L.Unlock()
}
} }
} }
@ -171,22 +211,24 @@ func (q *queryCollection) popAllUnsolvedMsg() []queryMsg {
return ret return ret
} }
func (q *queryCollection) waitNewTSafe() Timestamp { func (q *queryCollection) waitNewTSafe() (Timestamp, error) {
// block until any vChannel updating tSafe q.watcherCond.L.Lock()
_, _, recvOK := reflect.Select(q.watcherSelectCase) for !q.tSafeUpdate {
if !recvOK { q.watcherCond.Wait()
//log.Warn("tSafe has been closed", zap.Any("collectionID", q.collectionID))
return Timestamp(math.MaxInt64)
} }
q.watcherCond.L.Unlock()
//log.Debug("wait new tSafe", zap.Any("collectionID", s.collectionID)) //log.Debug("wait new tSafe", zap.Any("collectionID", s.collectionID))
t := Timestamp(math.MaxInt64) t := Timestamp(math.MaxInt64)
for channel := range q.tSafeWatchers { for channel := range q.tSafeWatchers {
ts := q.streaming.tSafeReplica.getTSafe(channel) ts, err := q.streaming.tSafeReplica.getTSafe(channel)
if err != nil {
return 0, err
}
if ts <= t { if ts <= t {
t = ts t = ts
} }
} }
return t return t, nil
} }
func (q *queryCollection) getServiceableTime() Timestamp { func (q *queryCollection) getServiceableTime() Timestamp {
@ -397,7 +439,11 @@ func (q *queryCollection) doUnsolvedQueryMsg() {
return return
default: default:
//time.Sleep(10 * time.Millisecond) //time.Sleep(10 * time.Millisecond)
serviceTime := q.waitNewTSafe() serviceTime, err := q.waitNewTSafe()
if err != nil {
log.Error(err.Error())
return
}
//st, _ := tsoutil.ParseTS(serviceTime) //st, _ := tsoutil.ParseTS(serviceTime)
//log.Debug("get tSafe from flow graph", //log.Debug("get tSafe from flow graph",
// zap.Int64("collectionID", q.collectionID), // zap.Int64("collectionID", q.collectionID),
@ -769,7 +815,12 @@ func (q *queryCollection) search(msg queryMsg) error {
searchTimestamp := searchMsg.BeginTs() searchTimestamp := searchMsg.BeginTs()
travelTimestamp := searchMsg.TravelTimestamp travelTimestamp := searchMsg.TravelTimestamp
schema, err := typeutil.CreateSchemaHelper(q.collection.schema) collection, err := q.streaming.replica.getCollectionByID(searchMsg.CollectionID)
if err != nil {
return err
}
schema, err := typeutil.CreateSchemaHelper(collection.schema)
if err != nil { if err != nil {
return err return err
} }
@ -777,13 +828,13 @@ func (q *queryCollection) search(msg queryMsg) error {
var plan *SearchPlan var plan *SearchPlan
if searchMsg.GetDslType() == commonpb.DslType_BoolExprV1 { if searchMsg.GetDslType() == commonpb.DslType_BoolExprV1 {
expr := searchMsg.SerializedExprPlan expr := searchMsg.SerializedExprPlan
plan, err = createSearchPlanByExpr(q.collection, expr) plan, err = createSearchPlanByExpr(collection, expr)
if err != nil { if err != nil {
return err return err
} }
} else { } else {
dsl := searchMsg.Dsl dsl := searchMsg.Dsl
plan, err = createSearchPlan(q.collection, dsl) plan, err = createSearchPlan(collection, dsl)
if err != nil { if err != nil {
return err return err
} }
@ -821,13 +872,13 @@ func (q *queryCollection) search(msg queryMsg) error {
if len(searchMsg.PartitionIDs) > 0 { if len(searchMsg.PartitionIDs) > 0 {
globalSealedSegments = q.historical.getGlobalSegmentIDsByPartitionIds(searchMsg.PartitionIDs) globalSealedSegments = q.historical.getGlobalSegmentIDsByPartitionIds(searchMsg.PartitionIDs)
} else { } else {
globalSealedSegments = q.historical.getGlobalSegmentIDsByCollectionID(q.collection.id) globalSealedSegments = q.historical.getGlobalSegmentIDsByCollectionID(collection.id)
} }
searchResults := make([]*SearchResult, 0) searchResults := make([]*SearchResult, 0)
// historical search // historical search
hisSearchResults, sealedSegmentSearched, err1 := q.historical.search(searchRequests, q.collection.id, searchMsg.PartitionIDs, plan, travelTimestamp) hisSearchResults, sealedSegmentSearched, err1 := q.historical.search(searchRequests, collection.id, searchMsg.PartitionIDs, plan, travelTimestamp)
if err1 != nil { if err1 != nil {
log.Warn(err1.Error()) log.Warn(err1.Error())
return err1 return err1
@ -837,9 +888,9 @@ func (q *queryCollection) search(msg queryMsg) error {
// streaming search // streaming search
var err2 error var err2 error
for _, channel := range q.collection.getVChannels() { for _, channel := range collection.getVChannels() {
var strSearchResults []*SearchResult var strSearchResults []*SearchResult
strSearchResults, err2 = q.streaming.search(searchRequests, q.collection.id, searchMsg.PartitionIDs, channel, plan, travelTimestamp) strSearchResults, err2 = q.streaming.search(searchRequests, collection.id, searchMsg.PartitionIDs, channel, plan, travelTimestamp)
if err2 != nil { if err2 != nil {
log.Warn(err2.Error()) log.Warn(err2.Error())
return err2 return err2
@ -870,14 +921,14 @@ func (q *queryCollection) search(msg queryMsg) error {
SlicedOffset: 1, SlicedOffset: 1,
SlicedNumCount: 1, SlicedNumCount: 1,
SealedSegmentIDsSearched: sealedSegmentSearched, SealedSegmentIDsSearched: sealedSegmentSearched,
ChannelIDsSearched: q.collection.getVChannels(), ChannelIDsSearched: collection.getVChannels(),
GlobalSealedSegmentIDs: globalSealedSegments, GlobalSealedSegmentIDs: globalSealedSegments,
}, },
} }
log.Debug("QueryNode Empty SearchResultMsg", log.Debug("QueryNode Empty SearchResultMsg",
zap.Any("collectionID", q.collection.id), zap.Any("collectionID", collection.id),
zap.Any("msgID", searchMsg.ID()), zap.Any("msgID", searchMsg.ID()),
zap.Any("vChannels", q.collection.getVChannels()), zap.Any("vChannels", collection.getVChannels()),
zap.Any("sealedSegmentSearched", sealedSegmentSearched), zap.Any("sealedSegmentSearched", sealedSegmentSearched),
) )
err = q.publishQueryResult(searchResultMsg, searchMsg.CollectionID) err = q.publishQueryResult(searchResultMsg, searchMsg.CollectionID)
@ -962,14 +1013,14 @@ func (q *queryCollection) search(msg queryMsg) error {
SlicedOffset: 1, SlicedOffset: 1,
SlicedNumCount: 1, SlicedNumCount: 1,
SealedSegmentIDsSearched: sealedSegmentSearched, SealedSegmentIDsSearched: sealedSegmentSearched,
ChannelIDsSearched: q.collection.getVChannels(), ChannelIDsSearched: collection.getVChannels(),
GlobalSealedSegmentIDs: globalSealedSegments, GlobalSealedSegmentIDs: globalSealedSegments,
}, },
} }
log.Debug("QueryNode SearchResultMsg", log.Debug("QueryNode SearchResultMsg",
zap.Any("collectionID", q.collection.id), zap.Any("collectionID", collection.id),
zap.Any("msgID", searchMsg.ID()), zap.Any("msgID", searchMsg.ID()),
zap.Any("vChannels", q.collection.getVChannels()), zap.Any("vChannels", collection.getVChannels()),
zap.Any("sealedSegmentSearched", sealedSegmentSearched), zap.Any("sealedSegmentSearched", sealedSegmentSearched),
) )

View File

@ -4,10 +4,8 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"math" "math"
"math/rand" "math/rand"
"reflect"
"testing" "testing"
"time" "time"
@ -52,7 +50,7 @@ func genSimpleQueryCollection(ctx context.Context, cancel context.CancelFunc) (*
return nil, err return nil, err
} }
queryCollection := newQueryCollection(ctx, cancel, queryCollection, err := newQueryCollection(ctx, cancel,
defaultCollectionID, defaultCollectionID,
historical, historical,
streaming, streaming,
@ -60,22 +58,15 @@ func genSimpleQueryCollection(ctx context.Context, cancel context.CancelFunc) (*
localCM, localCM,
remoteCM, remoteCM,
false) false)
if queryCollection == nil { return queryCollection, err
return nil, errors.New("nil simple query collection")
}
return queryCollection, nil
} }
func updateTSafe(queryCollection *queryCollection, timestamp Timestamp) { func updateTSafe(queryCollection *queryCollection, timestamp Timestamp) {
// register // register
queryCollection.watcherSelectCase = make([]reflect.SelectCase, 0)
queryCollection.tSafeWatchers[defaultVChannel] = newTSafeWatcher() queryCollection.tSafeWatchers[defaultVChannel] = newTSafeWatcher()
queryCollection.streaming.tSafeReplica.addTSafe(defaultVChannel) queryCollection.streaming.tSafeReplica.addTSafe(defaultVChannel)
queryCollection.streaming.tSafeReplica.registerTSafeWatcher(defaultVChannel, queryCollection.tSafeWatchers[defaultVChannel]) queryCollection.streaming.tSafeReplica.registerTSafeWatcher(defaultVChannel, queryCollection.tSafeWatchers[defaultVChannel])
queryCollection.watcherSelectCase = append(queryCollection.watcherSelectCase, reflect.SelectCase{ queryCollection.addTSafeWatcher(defaultVChannel)
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(queryCollection.tSafeWatchers[defaultVChannel].watcherChan()),
})
queryCollection.streaming.tSafeReplica.setTSafe(defaultVChannel, defaultCollectionID, timestamp) queryCollection.streaming.tSafeReplica.setTSafe(defaultVChannel, defaultCollectionID, timestamp)
} }
@ -125,7 +116,8 @@ func TestQueryCollection_withoutVChannel(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
queryCollection := newQueryCollection(ctx, cancel, 0, historical, streaming, factory, nil, nil, false) queryCollection, err := newQueryCollection(ctx, cancel, 0, historical, streaming, factory, nil, nil, false)
assert.NoError(t, err)
producerChannels := []string{"testResultChannel"} producerChannels := []string{"testResultChannel"}
queryCollection.queryResultMsgStream.AsProducer(producerChannels) queryCollection.queryResultMsgStream.AsProducer(producerChannels)
@ -484,6 +476,15 @@ func TestQueryCollection_serviceableTime(t *testing.T) {
assert.Equal(t, st+gracefulTime, resST) assert.Equal(t, st+gracefulTime, resST)
} }
// TestQueryCollection_addTSafeWatcher exercises addTSafeWatcher on a freshly
// built simple queryCollection.
func TestQueryCollection_addTSafeWatcher(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	queryCollection, err := genSimpleQueryCollection(ctx, cancel)
	assert.NoError(t, err)
	// NOTE(review): the error returned by addTSafeWatcher is ignored. If the
	// simple query collection already registered a watcher for
	// defaultVChannel during construction, this call presumably returns an
	// "already exists" error — confirm and assert on the result.
	queryCollection.addTSafeWatcher(defaultVChannel)
}
func TestQueryCollection_waitNewTSafe(t *testing.T) { func TestQueryCollection_waitNewTSafe(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -493,7 +494,8 @@ func TestQueryCollection_waitNewTSafe(t *testing.T) {
timestamp := Timestamp(1000) timestamp := Timestamp(1000)
updateTSafe(queryCollection, timestamp) updateTSafe(queryCollection, timestamp)
resTimestamp := queryCollection.waitNewTSafe() resTimestamp, err := queryCollection.waitNewTSafe()
assert.NoError(t, err)
assert.Equal(t, timestamp, resTimestamp) assert.Equal(t, timestamp, resTimestamp)
} }

View File

@ -14,7 +14,10 @@ package querynode
import "C" import "C"
import ( import (
"context" "context"
"errors"
"fmt"
"strconv" "strconv"
"sync"
"go.uber.org/zap" "go.uber.org/zap"
@ -31,6 +34,7 @@ type queryService struct {
historical *historical historical *historical
streaming *streaming streaming *streaming
queryCollectionMu sync.Mutex // guards queryCollections
queryCollections map[UniqueID]*queryCollection queryCollections map[UniqueID]*queryCollection
factory msgstream.Factory factory msgstream.Factory
@ -94,17 +98,22 @@ func (q *queryService) close() {
for collectionID := range q.queryCollections { for collectionID := range q.queryCollections {
q.stopQueryCollection(collectionID) q.stopQueryCollection(collectionID)
} }
q.queryCollectionMu.Lock()
q.queryCollections = make(map[UniqueID]*queryCollection) q.queryCollections = make(map[UniqueID]*queryCollection)
q.queryCollectionMu.Unlock()
q.cancel() q.cancel()
} }
func (q *queryService) addQueryCollection(collectionID UniqueID) { func (q *queryService) addQueryCollection(collectionID UniqueID) error {
q.queryCollectionMu.Lock()
defer q.queryCollectionMu.Unlock()
if _, ok := q.queryCollections[collectionID]; ok { if _, ok := q.queryCollections[collectionID]; ok {
log.Warn("query collection already exists", zap.Any("collectionID", collectionID)) log.Warn("query collection already exists", zap.Any("collectionID", collectionID))
return err := errors.New(fmt.Sprintln("query collection already exists, collectionID = ", collectionID))
return err
} }
ctx1, cancel := context.WithCancel(q.ctx) ctx1, cancel := context.WithCancel(q.ctx)
qc := newQueryCollection(ctx1, qc, err := newQueryCollection(ctx1,
cancel, cancel,
collectionID, collectionID,
q.historical, q.historical,
@ -114,15 +123,33 @@ func (q *queryService) addQueryCollection(collectionID UniqueID) {
q.remoteChunkManager, q.remoteChunkManager,
q.localCacheEnabled, q.localCacheEnabled,
) )
if err != nil {
return err
}
q.queryCollections[collectionID] = qc q.queryCollections[collectionID] = qc
return nil
} }
func (q *queryService) hasQueryCollection(collectionID UniqueID) bool { func (q *queryService) hasQueryCollection(collectionID UniqueID) bool {
q.queryCollectionMu.Lock()
defer q.queryCollectionMu.Unlock()
_, ok := q.queryCollections[collectionID] _, ok := q.queryCollections[collectionID]
return ok return ok
} }
// getQueryCollection returns the queryCollection registered for collectionID,
// or an error if no queryCollection has been added for that collection.
func (q *queryService) getQueryCollection(collectionID UniqueID) (*queryCollection, error) {
	q.queryCollectionMu.Lock()
	defer q.queryCollectionMu.Unlock()
	// Single comma-ok lookup instead of an existence check plus a re-read.
	if qc, ok := q.queryCollections[collectionID]; ok {
		return qc, nil
	}
	return nil, fmt.Errorf("queryCollection not exists, collectionID = %d", collectionID)
}
func (q *queryService) stopQueryCollection(collectionID UniqueID) { func (q *queryService) stopQueryCollection(collectionID UniqueID) {
q.queryCollectionMu.Lock()
defer q.queryCollectionMu.Unlock()
sc, ok := q.queryCollections[collectionID] sc, ok := q.queryCollections[collectionID]
if !ok { if !ok {
log.Warn("stopQueryCollection failed, collection doesn't exist", zap.Int64("collectionID", collectionID)) log.Warn("stopQueryCollection failed, collection doesn't exist", zap.Int64("collectionID", collectionID))

View File

@ -154,7 +154,8 @@ func TestSearch_Search(t *testing.T) {
err = loadFields(segment, DIM, N) err = loadFields(segment, DIM, N)
assert.NoError(t, err) assert.NoError(t, err)
node.queryService.addQueryCollection(collectionID) err = node.queryService.addQueryCollection(collectionID)
assert.Error(t, err)
err = sendSearchRequest(node.queryNodeLoopCtx, DIM) err = sendSearchRequest(node.queryNodeLoopCtx, DIM)
assert.NoError(t, err) assert.NoError(t, err)
@ -184,7 +185,8 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
node.historical, node.historical,
node.streaming, node.streaming,
msFactory) msFactory)
node.queryService.addQueryCollection(collectionID) err = node.queryService.addQueryCollection(collectionID)
assert.Error(t, err)
// load segments // load segments
err = node.historical.replica.addSegment(segmentID1, defaultPartitionID, collectionID, "", segmentTypeSealed, true) err = node.historical.replica.addSegment(segmentID1, defaultPartitionID, collectionID, "", segmentTypeSealed, true)
@ -227,13 +229,16 @@ func TestQueryService_addQueryCollection(t *testing.T) {
qs := newQueryService(ctx, his, str, fac) qs := newQueryService(ctx, his, str, fac)
assert.NotNil(t, qs) assert.NotNil(t, qs)
qs.addQueryCollection(defaultCollectionID) err = qs.addQueryCollection(defaultCollectionID)
assert.NoError(t, err)
assert.Len(t, qs.queryCollections, 1) assert.Len(t, qs.queryCollections, 1)
qs.addQueryCollection(defaultCollectionID) err = qs.addQueryCollection(defaultCollectionID)
assert.Error(t, err)
assert.Len(t, qs.queryCollections, 1) assert.Len(t, qs.queryCollections, 1)
const invalidCollectionID = 10000 const invalidCollectionID = 10000
qs.addQueryCollection(invalidCollectionID) err = qs.addQueryCollection(invalidCollectionID)
assert.Len(t, qs.queryCollections, 2) assert.Error(t, err)
assert.Len(t, qs.queryCollections, 1)
} }

View File

@ -236,6 +236,18 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) error {
log.Debug("query node add collection flow graphs", zap.Any("channels", vChannels)) log.Debug("query node add collection flow graphs", zap.Any("channels", vChannels))
} }
// add tSafe watcher if queryCollection exists
qc, err := w.node.queryService.getQueryCollection(collectionID)
if err == nil {
for _, channel := range vChannels {
err = qc.addTSafeWatcher(channel)
if err != nil {
// tSafe have been exist, not error
log.Warn(err.Error())
}
}
}
// channels as consumer // channels as consumer
var nodeFGs map[Channel]*queryNodeFlowGraph var nodeFGs map[Channel]*queryNodeFlowGraph
if loadPartition { if loadPartition {
@ -467,7 +479,11 @@ func (r *releaseCollectionTask) Execute(ctx context.Context) error {
zap.Any("collectionID", r.req.CollectionID), zap.Any("collectionID", r.req.CollectionID),
zap.Any("vChannel", channel), zap.Any("vChannel", channel),
) )
r.node.streaming.tSafeReplica.removeTSafe(channel) // no tSafe in tSafeReplica, don't return error
err = r.node.streaming.tSafeReplica.removeTSafe(channel)
if err != nil {
log.Warn(err.Error())
}
} }
// remove excludedSegments record // remove excludedSegments record
@ -561,7 +577,11 @@ func (r *releasePartitionsTask) Execute(ctx context.Context) error {
zap.Any("partitionID", id), zap.Any("partitionID", id),
zap.Any("vChannel", channel), zap.Any("vChannel", channel),
) )
r.node.streaming.tSafeReplica.removeTSafe(channel) // no tSafe in tSafeReplica, don't return error
err = r.node.streaming.tSafeReplica.removeTSafe(channel)
if err != nil {
log.Warn(err.Error())
}
} }
} }

View File

@ -16,6 +16,8 @@ import (
"math" "math"
"sync" "sync"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/log"
) )
@ -51,6 +53,7 @@ type tSafer interface {
registerTSafeWatcher(t *tSafeWatcher) registerTSafeWatcher(t *tSafeWatcher)
start() start()
close() close()
removeRecord(partitionID UniqueID)
} }
type tSafeMsg struct { type tSafeMsg struct {
@ -89,7 +92,9 @@ func (ts *tSafe) start() {
for { for {
select { select {
case <-ts.ctx.Done(): case <-ts.ctx.Done():
log.Debug("tSafe context done") log.Debug("tSafe context done",
zap.Any("channel", ts.channel),
)
return return
case m := <-ts.tSafeChan: case m := <-ts.tSafeChan:
ts.tSafeMu.Lock() ts.tSafeMu.Lock()
@ -116,6 +121,21 @@ func (ts *tSafe) start() {
}() }()
} }
// removeRecord drops the tSafe timestamp entry of a released partition.
// Without this cleanup the channel's tSafe would stay pinned to the minimum
// timestamp contributed by the old partition's flow graph, which is closed
// and will never advance it again. Call this whenever the partition's flow
// graph is removed.
func (ts *tSafe) removeRecord(partitionID UniqueID) {
	ts.tSafeMu.Lock()
	defer ts.tSafeMu.Unlock()

	log.Debug("remove tSafeRecord", zap.Any("partitionID", partitionID))
	delete(ts.tSafeRecord, partitionID)
}
func (ts *tSafe) registerTSafeWatcher(t *tSafeWatcher) { func (ts *tSafe) registerTSafeWatcher(t *tSafeWatcher) {
ts.tSafeMu.Lock() ts.tSafeMu.Lock()
defer ts.tSafeMu.Unlock() defer ts.tSafeMu.Unlock()

View File

@ -23,38 +23,48 @@ import (
// TSafeReplicaInterface is the interface wrapper of tSafeReplica // TSafeReplicaInterface is the interface wrapper of tSafeReplica
type TSafeReplicaInterface interface { type TSafeReplicaInterface interface {
getTSafe(vChannel Channel) Timestamp getTSafe(vChannel Channel) (Timestamp, error)
setTSafe(vChannel Channel, id UniqueID, timestamp Timestamp) setTSafe(vChannel Channel, id UniqueID, timestamp Timestamp) error
addTSafe(vChannel Channel) addTSafe(vChannel Channel)
removeTSafe(vChannel Channel) removeTSafe(vChannel Channel) error
registerTSafeWatcher(vChannel Channel, watcher *tSafeWatcher) registerTSafeWatcher(vChannel Channel, watcher *tSafeWatcher) error
removeRecord(vChannel Channel, partitionID UniqueID) error
}
type tSafeRef struct {
tSafer tSafer
ref int
} }
type tSafeReplica struct { type tSafeReplica struct {
mu sync.Mutex // guards tSafes mu sync.Mutex // guards tSafes
tSafes map[string]tSafer // map[vChannel]tSafer tSafes map[Channel]*tSafeRef // map[vChannel]tSafeRef
} }
func (t *tSafeReplica) getTSafe(vChannel Channel) Timestamp { func (t *tSafeReplica) getTSafe(vChannel Channel) (Timestamp, error) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
safer, err := t.getTSaferPrivate(vChannel) safer, err := t.getTSaferPrivate(vChannel)
if err != nil { if err != nil {
log.Warn("get tSafe failed", zap.Error(err)) //log.Warn("get tSafe failed",
return 0 // zap.Any("channel", vChannel),
// zap.Error(err),
//)
return 0, err
} }
return safer.get() return safer.get(), nil
} }
func (t *tSafeReplica) setTSafe(vChannel Channel, id UniqueID, timestamp Timestamp) { func (t *tSafeReplica) setTSafe(vChannel Channel, id UniqueID, timestamp Timestamp) error {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
safer, err := t.getTSaferPrivate(vChannel) safer, err := t.getTSaferPrivate(vChannel)
if err != nil { if err != nil {
log.Warn("set tSafe failed", zap.Error(err)) //log.Warn("set tSafe failed", zap.Error(err))
return return err
} }
safer.set(id, timestamp) safer.set(id, timestamp)
return nil
} }
func (t *tSafeReplica) getTSaferPrivate(vChannel Channel) (tSafer, error) { func (t *tSafeReplica) getTSaferPrivate(vChannel Channel) (tSafer, error) {
@ -63,7 +73,7 @@ func (t *tSafeReplica) getTSaferPrivate(vChannel Channel) (tSafer, error) {
//log.Warn(err.Error()) //log.Warn(err.Error())
return nil, err return nil, err
} }
return t.tSafes[vChannel], nil return t.tSafes[vChannel].tSafer, nil
} }
func (t *tSafeReplica) addTSafe(vChannel Channel) { func (t *tSafeReplica) addTSafe(vChannel Channel) {
@ -71,42 +81,74 @@ func (t *tSafeReplica) addTSafe(vChannel Channel) {
defer t.mu.Unlock() defer t.mu.Unlock()
ctx := context.Background() ctx := context.Background()
if _, ok := t.tSafes[vChannel]; !ok { if _, ok := t.tSafes[vChannel]; !ok {
t.tSafes[vChannel] = newTSafe(ctx, vChannel) t.tSafes[vChannel] = &tSafeRef{
t.tSafes[vChannel].start() tSafer: newTSafe(ctx, vChannel),
log.Debug("add tSafe done", zap.Any("channel", vChannel)) ref: 1,
}
t.tSafes[vChannel].tSafer.start()
log.Debug("add tSafe done",
zap.Any("channel", vChannel),
zap.Any("count", t.tSafes[vChannel].ref),
)
} else { } else {
log.Warn("tSafe has been existed", zap.Any("channel", vChannel)) t.tSafes[vChannel].ref++
log.Debug("tSafe has been existed",
zap.Any("channel", vChannel),
zap.Any("count", t.tSafes[vChannel].ref),
)
} }
} }
func (t *tSafeReplica) removeTSafe(vChannel Channel) { func (t *tSafeReplica) removeTSafe(vChannel Channel) error {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
if _, ok := t.tSafes[vChannel]; !ok {
return errors.New("tSafe not exist, vChannel = " + vChannel)
}
t.tSafes[vChannel].ref--
log.Debug("reduce tSafe reference count",
zap.Any("vChannel", vChannel),
zap.Any("count", t.tSafes[vChannel].ref),
)
if t.tSafes[vChannel].ref == 0 {
safer, err := t.getTSaferPrivate(vChannel) safer, err := t.getTSaferPrivate(vChannel)
if err != nil { if err != nil {
return return err
} }
log.Debug("remove tSafe replica", log.Debug("remove tSafe replica",
zap.Any("vChannel", vChannel), zap.Any("vChannel", vChannel),
) )
safer.close() safer.close()
delete(t.tSafes, vChannel) delete(t.tSafes, vChannel)
}
return nil
} }
func (t *tSafeReplica) registerTSafeWatcher(vChannel Channel, watcher *tSafeWatcher) { func (t *tSafeReplica) removeRecord(vChannel Channel, partitionID UniqueID) error {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
safer, err := t.getTSaferPrivate(vChannel) safer, err := t.getTSaferPrivate(vChannel)
if err != nil { if err != nil {
log.Warn("register tSafe watcher failed", zap.Error(err)) return err
return }
safer.removeRecord(partitionID)
return nil
}
func (t *tSafeReplica) registerTSafeWatcher(vChannel Channel, watcher *tSafeWatcher) error {
t.mu.Lock()
defer t.mu.Unlock()
safer, err := t.getTSaferPrivate(vChannel)
if err != nil {
return err
} }
safer.registerTSafeWatcher(watcher) safer.registerTSafeWatcher(watcher)
return nil
} }
func newTSafeReplica() TSafeReplicaInterface { func newTSafeReplica() TSafeReplicaInterface {
var replica TSafeReplicaInterface = &tSafeReplica{ var replica TSafeReplicaInterface = &tSafeReplica{
tSafes: make(map[string]tSafer), tSafes: make(map[string]*tSafeRef),
} }
return replica return replica
} }

View File

@ -23,30 +23,39 @@ func TestTSafeReplica_valid(t *testing.T) {
replica.addTSafe(defaultVChannel) replica.addTSafe(defaultVChannel)
watcher := newTSafeWatcher() watcher := newTSafeWatcher()
replica.registerTSafeWatcher(defaultVChannel, watcher) err := replica.registerTSafeWatcher(defaultVChannel, watcher)
assert.NoError(t, err)
timestamp := Timestamp(1000) timestamp := Timestamp(1000)
replica.setTSafe(defaultVChannel, defaultCollectionID, timestamp) err = replica.setTSafe(defaultVChannel, defaultCollectionID, timestamp)
assert.NoError(t, err)
time.Sleep(20 * time.Millisecond) time.Sleep(20 * time.Millisecond)
resT := replica.getTSafe(defaultVChannel) resT, err := replica.getTSafe(defaultVChannel)
assert.NoError(t, err)
assert.Equal(t, timestamp, resT) assert.Equal(t, timestamp, resT)
replica.removeTSafe(defaultVChannel) err = replica.removeTSafe(defaultVChannel)
assert.NoError(t, err)
} }
func TestTSafeReplica_invalid(t *testing.T) { func TestTSafeReplica_invalid(t *testing.T) {
replica := newTSafeReplica() replica := newTSafeReplica()
replica.addTSafe(defaultVChannel)
watcher := newTSafeWatcher() watcher := newTSafeWatcher()
replica.registerTSafeWatcher(defaultVChannel, watcher) err := replica.registerTSafeWatcher(defaultVChannel, watcher)
assert.NoError(t, err)
timestamp := Timestamp(1000) timestamp := Timestamp(1000)
replica.setTSafe(defaultVChannel, defaultCollectionID, timestamp) err = replica.setTSafe(defaultVChannel, defaultCollectionID, timestamp)
assert.NoError(t, err)
time.Sleep(20 * time.Millisecond) time.Sleep(20 * time.Millisecond)
resT := replica.getTSafe(defaultVChannel) resT, err := replica.getTSafe(defaultVChannel)
assert.Equal(t, Timestamp(0), resT) assert.NoError(t, err)
assert.Equal(t, timestamp, resT)
replica.removeTSafe(defaultVChannel) err = replica.removeTSafe(defaultVChannel)
assert.NoError(t, err)
replica.addTSafe(defaultVChannel) replica.addTSafe(defaultVChannel)
replica.addTSafe(defaultVChannel) replica.addTSafe(defaultVChannel)