Fix timetick block caused by dml task pop failed (#23291)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/23361/head
yihao.dai 2023-04-11 14:18:30 +08:00 committed by GitHub
parent f2f7d8ed53
commit cf321afe29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 189 additions and 48 deletions

View File

@ -5,10 +5,18 @@ type removeDMLStreamFuncType = func(collectionID UniqueID) error
// mockChannelsMgr is a test double built around an embedded channelsMgr.
// Individual operations are overridden through optional function hooks; a hook
// left nil is simply not used by the corresponding method.
type mockChannelsMgr struct {
	channelsMgr
	// getChannelsFunc, when non-nil, is what getChannels delegates to.
	getChannelsFunc func(collectionID UniqueID) ([]pChan, error)
	// getVChannelsFuncType, when non-nil, is what getVChannels delegates to.
	getVChannelsFuncType
	// removeDMLStreamFuncType is a func(collectionID UniqueID) error hook.
	// NOTE(review): presumably backs removeDMLStream — its use is not visible here.
	removeDMLStreamFuncType
}
// getChannels forwards to the getChannelsFunc hook when one is installed.
// Without a hook it reports no channels and no error.
func (m *mockChannelsMgr) getChannels(collectionID UniqueID) ([]pChan, error) {
	if m.getChannelsFunc == nil {
		return nil, nil
	}
	return m.getChannelsFunc(collectionID)
}
func (m *mockChannelsMgr) getVChannels(collectionID UniqueID) ([]vChan, error) {
if m.getVChannelsFuncType != nil {
return m.getVChannelsFuncType(collectionID)

View File

@ -78,11 +78,19 @@ func (dt *deleteTask) OnEnqueue() error {
}
// getChannels returns the physical channels (pChannels) this delete task
// writes to.
//
// The first successful, non-empty lookup is stored in dt.pChannels. Later calls
// return that stored slice directly and touch neither the meta cache nor the
// channels manager, so they cannot fail even if those lookups would now error.
// An empty lookup result is returned as-is but is not treated as stored: the
// next call resolves the channels again.
//
// It returns an error when the collection ID cannot be resolved from the
// collection name, or when the channels manager fails to provide the channels.
func (dt *deleteTask) getChannels() ([]pChan, error) {
	if len(dt.pChannels) != 0 {
		return dt.pChannels, nil
	}
	collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.CollectionName)
	if err != nil {
		return nil, err
	}
	channels, err := dt.chMgr.getChannels(collID)
	if err != nil {
		return nil, err
	}
	// Store the result on the task so every later call sees the same channels.
	dt.pChannels = channels
	return channels, nil
}
func getPrimaryKeysFromExpr(schema *schemapb.CollectionSchema, expr string) (res *schemapb.IDs, rowNum int64, err error) {

View File

@ -1,12 +1,16 @@
package proxy
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/internalpb"
)
func Test_getPrimaryKeysFromExpr(t *testing.T) {
@ -37,3 +41,41 @@ func Test_getPrimaryKeysFromExpr(t *testing.T) {
assert.Error(t, err)
})
}
// TestDeleteTask exercises deleteTask against a mocked meta cache and a mocked
// channels manager.
func TestDeleteTask(t *testing.T) {
	t.Run("test getChannels", func(t *testing.T) {
		const name = "col-0"
		id := UniqueID(0)
		want := []pChan{"mock-chan-0", "mock-chan-1"}

		// Every collection name resolves to the same fixed collection ID.
		metaCache := newMockCache()
		metaCache.setGetIDFunc(func(ctx context.Context, collectionName string) (UniqueID, error) {
			return id, nil
		})
		globalMetaCache = metaCache

		mgr := newMockChannelsMgr()
		mgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
			return want, nil
		}

		task := &deleteTask{
			ctx: context.Background(),
			BaseDeleteTask: msgstream.DeleteMsg{
				DeleteRequest: internalpb.DeleteRequest{
					CollectionName: name,
				},
			},
			chMgr: mgr,
		}

		// First call goes through the channels manager and fills task.pChannels.
		got, err := task.getChannels()
		assert.NoError(t, err)
		assert.ElementsMatch(t, want, got)
		assert.ElementsMatch(t, want, task.pChannels)

		// From now on the channels manager always fails.
		mgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
			return nil, fmt.Errorf("mock err")
		}

		// The second call must be served from task.pChannels, so the failing
		// getChannelsFunc must not be invoked again.
		got, err = task.getChannels()
		assert.NoError(t, err)
		assert.ElementsMatch(t, want, got)
	})
}

View File

@ -72,11 +72,19 @@ func (it *insertTask) EndTs() Timestamp {
}
// getChannels returns the physical channels (pChannels) this insert task
// writes to.
//
// The first successful, non-empty lookup is stored in it.pChannels. Later calls
// return that stored slice directly and touch neither the meta cache nor the
// channels manager, so they cannot fail even if those lookups would now error.
// An empty lookup result is returned as-is but is not treated as stored: the
// next call resolves the channels again.
//
// It returns an error when the collection ID cannot be resolved from the
// collection name, or when the channels manager fails to provide the channels.
func (it *insertTask) getChannels() ([]pChan, error) {
	if len(it.pChannels) != 0 {
		return it.pChannels, nil
	}
	collID, err := globalMetaCache.GetCollectionID(it.ctx, it.CollectionName)
	if err != nil {
		return nil, err
	}
	channels, err := it.chMgr.getChannels(collID)
	if err != nil {
		return nil, err
	}
	// Store the result on the task so every later call sees the same channels.
	it.pChannels = channels
	return channels, nil
}
func (it *insertTask) OnEnqueue() error {

View File

@ -1,10 +1,13 @@
package proxy
import (
"context"
"fmt"
"testing"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/stretchr/testify/assert"
)
@ -345,3 +348,41 @@ func TestInsertTask_CheckAligned(t *testing.T) {
err = case2.CheckAligned()
assert.NoError(t, err)
}
// TestInsertTask exercises insertTask against a mocked meta cache and a mocked
// channels manager.
func TestInsertTask(t *testing.T) {
	t.Run("test getChannels", func(t *testing.T) {
		const name = "col-0"
		id := UniqueID(0)
		want := []pChan{"mock-chan-0", "mock-chan-1"}

		// Every collection name resolves to the same fixed collection ID.
		metaCache := newMockCache()
		metaCache.setGetIDFunc(func(ctx context.Context, collectionName string) (UniqueID, error) {
			return id, nil
		})
		globalMetaCache = metaCache

		mgr := newMockChannelsMgr()
		mgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
			return want, nil
		}

		task := &insertTask{
			ctx: context.Background(),
			BaseInsertTask: msgstream.InsertMsg{
				InsertRequest: internalpb.InsertRequest{
					CollectionName: name,
				},
			},
			chMgr: mgr,
		}

		// First call goes through the channels manager and fills task.pChannels.
		got, err := task.getChannels()
		assert.NoError(t, err)
		assert.ElementsMatch(t, want, got)
		assert.ElementsMatch(t, want, task.pChannels)

		// From now on the channels manager always fails.
		mgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
			return nil, fmt.Errorf("mock err")
		}

		// The second call must be served from task.pChannels, so the failing
		// getChannelsFunc must not be invoked again.
		got, err = task.getChannels()
		assert.NoError(t, err)
		assert.ElementsMatch(t, want, got)
	})
}

View File

@ -228,15 +228,19 @@ func (queue *dmTaskQueue) Enqueue(t task) error {
// This statsLock has two functions:
// 1) Protect member pChanStatisticsInfos
// 2) Serialize the timestamp allocation for dml tasks
queue.statsLock.Lock()
defer queue.statsLock.Unlock()
//1. preAdd will check whether provided task is valid or addable
//and get the current pChannels for this dmTask
pChannels, dmt, err := queue.preAddPChanStats(t)
dmt := t.(dmlTask)
pChannels, err := dmt.getChannels()
if err != nil {
log.Warn("getChannels failed when Enqueue", zap.Any("tID", t.ID()), zap.Error(err))
return err
}
//2. enqueue dml task
queue.statsLock.Lock()
defer queue.statsLock.Unlock()
err = queue.baseTaskQueue.Enqueue(t)
if err != nil {
return err
@ -266,19 +270,6 @@ func (queue *dmTaskQueue) PopActiveTask(taskID UniqueID) task {
return t
}
// preAddPChanStats checks that t is a dmlTask and resolves the pChannels it
// writes to.
//
// On success it returns the channels together with t viewed as a dmlTask. It
// returns an error when t does not implement dmlTask, or when the task's
// getChannels call fails (that failure is also logged as a warning).
func (queue *dmTaskQueue) preAddPChanStats(t task) ([]pChan, dmlTask, error) {
	dmT, ok := t.(dmlTask)
	if !ok {
		return nil, nil, fmt.Errorf("proxy preAddPChanStats reflect to dmlTask failed, tID:%v", t.ID())
	}

	channels, err := dmT.getChannels()
	if err != nil {
		log.Warn("Proxy dmTaskQueue preAddPChanStats getChannels failed", zap.Any("tID", t.ID()),
			zap.Error(err))
		return nil, nil, err
	}
	return channels, dmT, nil
}
func (queue *dmTaskQueue) commitPChanStats(dmt dmlTask, pChannels []pChan) {
//1. prepare new stat for all pChannels
newStats := make(map[pChan]pChanStatistics)
@ -313,34 +304,31 @@ func (queue *dmTaskQueue) commitPChanStats(dmt dmlTask, pChannels []pChan) {
}
}
func (queue *dmTaskQueue) popPChanStats(t task) error {
if dmT, ok := t.(dmlTask); ok {
channels, err := dmT.getChannels()
if err != nil {
return err
}
taskTs := t.BeginTs()
for _, cName := range channels {
info, ok := queue.pChanStatisticsInfos[cName]
if ok {
delete(info.tsSet, taskTs)
if len(info.tsSet) <= 0 {
delete(queue.pChanStatisticsInfos, cName)
} else {
newMinTs := info.maxTs
for ts := range info.tsSet {
if newMinTs > ts {
newMinTs = ts
}
func (queue *dmTaskQueue) popPChanStats(t task) {
channels, err := t.(dmlTask).getChannels()
if err != nil {
err = fmt.Errorf("get channels failed when popPChanStats, err=%w", err)
log.Error(err.Error())
panic(err)
}
taskTs := t.BeginTs()
for _, cName := range channels {
info, ok := queue.pChanStatisticsInfos[cName]
if ok {
delete(info.tsSet, taskTs)
if len(info.tsSet) <= 0 {
delete(queue.pChanStatisticsInfos, cName)
} else {
newMinTs := info.maxTs
for ts := range info.tsSet {
if newMinTs > ts {
newMinTs = ts
}
info.minTs = newMinTs
}
info.minTs = newMinTs
}
}
} else {
return fmt.Errorf("proxy dmTaskQueue popPChanStats reflect to dmlTask failed, tID:%v", t.ID())
}
return nil
}
func (queue *dmTaskQueue) getPChanStatsInfo() (map[pChan]*pChanStatistics, error) {

View File

@ -24,9 +24,12 @@ import (
"testing"
"time"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/funcutil"
)
func TestBaseTaskQueue(t *testing.T) {
@ -200,11 +203,6 @@ func TestDmTaskQueue_Basic(t *testing.T) {
assert.True(t, queue.utEmpty())
assert.False(t, queue.utFull())
//test wrong task type
dqlTask := newDefaultMockDqlTask()
err = queue.Enqueue(dqlTask)
assert.NotNil(t, err)
st := newDefaultMockDmlTask()
stID := st.ID()
@ -570,3 +568,51 @@ func TestTaskScheduler(t *testing.T) {
wg.Wait()
}
// TestTaskScheduler_concurrentPushAndPop enqueues, schedules and pops insert
// tasks from many goroutines at once. Each goroutine makes its channels manager
// fail before popping; the pop must still complete without panicking.
func TestTaskScheduler_concurrentPushAndPop(t *testing.T) {
	const name = "col-0"
	id := UniqueID(0)
	chans := []pChan{"mock-chan-0", "mock-chan-1"}

	// Every collection name resolves to the same fixed collection ID.
	metaCache := newMockCache()
	metaCache.setGetIDFunc(func(ctx context.Context, collectionName string) (UniqueID, error) {
		return id, nil
	})
	globalMetaCache = metaCache

	tso := newMockTsoAllocator()
	streamFactory := newSimpleMockMsgStreamFactory()
	scheduler, err := newTaskScheduler(context.Background(), tso, streamFactory)
	assert.NoError(t, err)

	var wg sync.WaitGroup
	worker := func() {
		defer wg.Done()

		mgr := newMockChannelsMgr()
		mgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
			return chans, nil
		}
		insert := &insertTask{
			ctx: context.Background(),
			BaseInsertTask: msgstream.InsertMsg{
				InsertRequest: internalpb.InsertRequest{
					Base:           &commonpb.MsgBase{},
					CollectionName: name,
				},
			},
			chMgr: mgr,
		}

		enqueueErr := scheduler.dmQueue.Enqueue(insert)
		assert.NoError(t, enqueueErr)

		scheduled := scheduler.scheduleDmTask()
		scheduler.dmQueue.AddActiveTask(scheduled)

		// Break the channels manager before popping.
		mgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
			return nil, fmt.Errorf("mock err")
		}
		scheduler.dmQueue.PopActiveTask(scheduled.ID()) // assert no panic
	}

	const workers = 100
	wg.Add(workers)
	for i := 0; i < workers; i++ {
		go worker()
	}
	wg.Wait()
}