mirror of https://github.com/milvus-io/milvus.git
Fix timetick block caused by dml task pop failed (#23291)
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>pull/23361/head
parent
f2f7d8ed53
commit
cf321afe29
|
@ -5,10 +5,18 @@ type removeDMLStreamFuncType = func(collectionID UniqueID) error
|
|||
|
||||
type mockChannelsMgr struct {
|
||||
channelsMgr
|
||||
getChannelsFunc func(collectionID UniqueID) ([]pChan, error)
|
||||
getVChannelsFuncType
|
||||
removeDMLStreamFuncType
|
||||
}
|
||||
|
||||
func (m *mockChannelsMgr) getChannels(collectionID UniqueID) ([]pChan, error) {
|
||||
if m.getChannelsFunc != nil {
|
||||
return m.getChannelsFunc(collectionID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockChannelsMgr) getVChannels(collectionID UniqueID) ([]vChan, error) {
|
||||
if m.getVChannelsFuncType != nil {
|
||||
return m.getVChannelsFuncType(collectionID)
|
||||
|
|
|
@ -78,11 +78,19 @@ func (dt *deleteTask) OnEnqueue() error {
|
|||
}
|
||||
|
||||
func (dt *deleteTask) getChannels() ([]pChan, error) {
|
||||
if len(dt.pChannels) != 0 {
|
||||
return dt.pChannels, nil
|
||||
}
|
||||
collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.CollectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dt.chMgr.getChannels(collID)
|
||||
channels, err := dt.chMgr.getChannels(collID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dt.pChannels = channels
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
func getPrimaryKeysFromExpr(schema *schemapb.CollectionSchema, expr string) (res *schemapb.IDs, rowNum int64, err error) {
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
)
|
||||
|
||||
func Test_getPrimaryKeysFromExpr(t *testing.T) {
|
||||
|
@ -37,3 +41,41 @@ func Test_getPrimaryKeysFromExpr(t *testing.T) {
|
|||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteTask(t *testing.T) {
|
||||
t.Run("test getChannels", func(t *testing.T) {
|
||||
collectionID := UniqueID(0)
|
||||
collectionName := "col-0"
|
||||
channels := []pChan{"mock-chan-0", "mock-chan-1"}
|
||||
cache := newMockCache()
|
||||
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (UniqueID, error) {
|
||||
return collectionID, nil
|
||||
})
|
||||
globalMetaCache = cache
|
||||
chMgr := newMockChannelsMgr()
|
||||
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
|
||||
return channels, nil
|
||||
}
|
||||
dt := deleteTask{
|
||||
ctx: context.Background(),
|
||||
BaseDeleteTask: msgstream.DeleteMsg{
|
||||
DeleteRequest: internalpb.DeleteRequest{
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
},
|
||||
chMgr: chMgr,
|
||||
}
|
||||
resChannels, err := dt.getChannels()
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, channels, resChannels)
|
||||
assert.ElementsMatch(t, channels, dt.pChannels)
|
||||
|
||||
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
|
||||
return nil, fmt.Errorf("mock err")
|
||||
}
|
||||
// get channels again, should return task's pChannels, so getChannelsFunc should not invoke again
|
||||
resChannels, err = dt.getChannels()
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, channels, resChannels)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -72,11 +72,19 @@ func (it *insertTask) EndTs() Timestamp {
|
|||
}
|
||||
|
||||
func (it *insertTask) getChannels() ([]pChan, error) {
|
||||
if len(it.pChannels) != 0 {
|
||||
return it.pChannels, nil
|
||||
}
|
||||
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.CollectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return it.chMgr.getChannels(collID)
|
||||
channels, err := it.chMgr.getChannels(collID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
it.pChannels = channels
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
func (it *insertTask) OnEnqueue() error {
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -345,3 +348,41 @@ func TestInsertTask_CheckAligned(t *testing.T) {
|
|||
err = case2.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInsertTask(t *testing.T) {
|
||||
t.Run("test getChannels", func(t *testing.T) {
|
||||
collectionID := UniqueID(0)
|
||||
collectionName := "col-0"
|
||||
channels := []pChan{"mock-chan-0", "mock-chan-1"}
|
||||
cache := newMockCache()
|
||||
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (UniqueID, error) {
|
||||
return collectionID, nil
|
||||
})
|
||||
globalMetaCache = cache
|
||||
chMgr := newMockChannelsMgr()
|
||||
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
|
||||
return channels, nil
|
||||
}
|
||||
it := insertTask{
|
||||
ctx: context.Background(),
|
||||
BaseInsertTask: msgstream.InsertMsg{
|
||||
InsertRequest: internalpb.InsertRequest{
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
},
|
||||
chMgr: chMgr,
|
||||
}
|
||||
resChannels, err := it.getChannels()
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, channels, resChannels)
|
||||
assert.ElementsMatch(t, channels, it.pChannels)
|
||||
|
||||
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
|
||||
return nil, fmt.Errorf("mock err")
|
||||
}
|
||||
// get channels again, should return task's pChannels, so getChannelsFunc should not invoke again
|
||||
resChannels, err = it.getChannels()
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, channels, resChannels)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -228,15 +228,19 @@ func (queue *dmTaskQueue) Enqueue(t task) error {
|
|||
// This statsLock has two functions:
|
||||
// 1) Protect member pChanStatisticsInfos
|
||||
// 2) Serialize the timestamp allocation for dml tasks
|
||||
queue.statsLock.Lock()
|
||||
defer queue.statsLock.Unlock()
|
||||
|
||||
//1. preAdd will check whether provided task is valid or addable
|
||||
//and get the current pChannels for this dmTask
|
||||
pChannels, dmt, err := queue.preAddPChanStats(t)
|
||||
dmt := t.(dmlTask)
|
||||
pChannels, err := dmt.getChannels()
|
||||
if err != nil {
|
||||
log.Warn("getChannels failed when Enqueue", zap.Any("tID", t.ID()), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
//2. enqueue dml task
|
||||
queue.statsLock.Lock()
|
||||
defer queue.statsLock.Unlock()
|
||||
err = queue.baseTaskQueue.Enqueue(t)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -266,19 +270,6 @@ func (queue *dmTaskQueue) PopActiveTask(taskID UniqueID) task {
|
|||
return t
|
||||
}
|
||||
|
||||
func (queue *dmTaskQueue) preAddPChanStats(t task) ([]pChan, dmlTask, error) {
|
||||
if dmT, ok := t.(dmlTask); ok {
|
||||
channels, err := dmT.getChannels()
|
||||
if err != nil {
|
||||
log.Warn("Proxy dmTaskQueue preAddPChanStats getChannels failed", zap.Any("tID", t.ID()),
|
||||
zap.Error(err))
|
||||
return nil, nil, err
|
||||
}
|
||||
return channels, dmT, nil
|
||||
}
|
||||
return nil, nil, fmt.Errorf("proxy preAddPChanStats reflect to dmlTask failed, tID:%v", t.ID())
|
||||
}
|
||||
|
||||
func (queue *dmTaskQueue) commitPChanStats(dmt dmlTask, pChannels []pChan) {
|
||||
//1. prepare new stat for all pChannels
|
||||
newStats := make(map[pChan]pChanStatistics)
|
||||
|
@ -313,34 +304,31 @@ func (queue *dmTaskQueue) commitPChanStats(dmt dmlTask, pChannels []pChan) {
|
|||
}
|
||||
}
|
||||
|
||||
func (queue *dmTaskQueue) popPChanStats(t task) error {
|
||||
if dmT, ok := t.(dmlTask); ok {
|
||||
channels, err := dmT.getChannels()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
taskTs := t.BeginTs()
|
||||
for _, cName := range channels {
|
||||
info, ok := queue.pChanStatisticsInfos[cName]
|
||||
if ok {
|
||||
delete(info.tsSet, taskTs)
|
||||
if len(info.tsSet) <= 0 {
|
||||
delete(queue.pChanStatisticsInfos, cName)
|
||||
} else {
|
||||
newMinTs := info.maxTs
|
||||
for ts := range info.tsSet {
|
||||
if newMinTs > ts {
|
||||
newMinTs = ts
|
||||
}
|
||||
func (queue *dmTaskQueue) popPChanStats(t task) {
|
||||
channels, err := t.(dmlTask).getChannels()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("get channels failed when popPChanStats, err=%w", err)
|
||||
log.Error(err.Error())
|
||||
panic(err)
|
||||
}
|
||||
taskTs := t.BeginTs()
|
||||
for _, cName := range channels {
|
||||
info, ok := queue.pChanStatisticsInfos[cName]
|
||||
if ok {
|
||||
delete(info.tsSet, taskTs)
|
||||
if len(info.tsSet) <= 0 {
|
||||
delete(queue.pChanStatisticsInfos, cName)
|
||||
} else {
|
||||
newMinTs := info.maxTs
|
||||
for ts := range info.tsSet {
|
||||
if newMinTs > ts {
|
||||
newMinTs = ts
|
||||
}
|
||||
info.minTs = newMinTs
|
||||
}
|
||||
info.minTs = newMinTs
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("proxy dmTaskQueue popPChanStats reflect to dmlTask failed, tID:%v", t.ID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (queue *dmTaskQueue) getPChanStatsInfo() (map[pChan]*pChanStatistics, error) {
|
||||
|
|
|
@ -24,9 +24,12 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
)
|
||||
|
||||
func TestBaseTaskQueue(t *testing.T) {
|
||||
|
@ -200,11 +203,6 @@ func TestDmTaskQueue_Basic(t *testing.T) {
|
|||
assert.True(t, queue.utEmpty())
|
||||
assert.False(t, queue.utFull())
|
||||
|
||||
//test wrong task type
|
||||
dqlTask := newDefaultMockDqlTask()
|
||||
err = queue.Enqueue(dqlTask)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
st := newDefaultMockDmlTask()
|
||||
stID := st.ID()
|
||||
|
||||
|
@ -570,3 +568,51 @@ func TestTaskScheduler(t *testing.T) {
|
|||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestTaskScheduler_concurrentPushAndPop(t *testing.T) {
|
||||
collectionID := UniqueID(0)
|
||||
collectionName := "col-0"
|
||||
channels := []pChan{"mock-chan-0", "mock-chan-1"}
|
||||
cache := newMockCache()
|
||||
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (UniqueID, error) {
|
||||
return collectionID, nil
|
||||
})
|
||||
globalMetaCache = cache
|
||||
tsoAllocatorIns := newMockTsoAllocator()
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
scheduler, err := newTaskScheduler(context.Background(), tsoAllocatorIns, factory)
|
||||
assert.NoError(t, err)
|
||||
|
||||
run := func(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
chMgr := newMockChannelsMgr()
|
||||
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
|
||||
return channels, nil
|
||||
}
|
||||
it := &insertTask{
|
||||
ctx: context.Background(),
|
||||
BaseInsertTask: msgstream.InsertMsg{
|
||||
InsertRequest: internalpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
},
|
||||
chMgr: chMgr,
|
||||
}
|
||||
err := scheduler.dmQueue.Enqueue(it)
|
||||
assert.NoError(t, err)
|
||||
task := scheduler.scheduleDmTask()
|
||||
scheduler.dmQueue.AddActiveTask(task)
|
||||
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
|
||||
return nil, fmt.Errorf("mock err")
|
||||
}
|
||||
scheduler.dmQueue.PopActiveTask(task.ID()) // assert no panic
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go run(wg)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue