mirror of https://github.com/milvus-io/milvus.git
Fix timetick block caused by dml task pop failed (#23277)
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>pull/23356/head
parent
5276529524
commit
0b354cbab1
|
@ -5,10 +5,18 @@ type removeDMLStreamFuncType = func(collectionID UniqueID) error
|
|||
|
||||
type mockChannelsMgr struct {
|
||||
channelsMgr
|
||||
getChannelsFunc func(collectionID UniqueID) ([]pChan, error)
|
||||
getVChannelsFuncType
|
||||
removeDMLStreamFuncType
|
||||
}
|
||||
|
||||
func (m *mockChannelsMgr) getChannels(collectionID UniqueID) ([]pChan, error) {
|
||||
if m.getChannelsFunc != nil {
|
||||
return m.getChannelsFunc(collectionID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockChannelsMgr) getVChannels(collectionID UniqueID) ([]vChan, error) {
|
||||
if m.getVChannelsFuncType != nil {
|
||||
return m.getVChannelsFuncType(collectionID)
|
||||
|
|
|
@ -82,11 +82,19 @@ func (dt *deleteTask) OnEnqueue() error {
|
|||
}
|
||||
|
||||
func (dt *deleteTask) getChannels() ([]pChan, error) {
|
||||
if len(dt.pChannels) != 0 {
|
||||
return dt.pChannels, nil
|
||||
}
|
||||
collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.deleteMsg.CollectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dt.chMgr.getChannels(collID)
|
||||
channels, err := dt.chMgr.getChannels(collID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dt.pChannels = channels
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
func getPrimaryKeysFromExpr(schema *schemapb.CollectionSchema, expr string) (res *schemapb.IDs, rowNum int64, err error) {
|
||||
|
|
|
@ -1,12 +1,17 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
func Test_getPrimaryKeysFromExpr(t *testing.T) {
|
||||
|
@ -37,3 +42,41 @@ func Test_getPrimaryKeysFromExpr(t *testing.T) {
|
|||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteTask(t *testing.T) {
|
||||
t.Run("test getChannels", func(t *testing.T) {
|
||||
collectionID := UniqueID(0)
|
||||
collectionName := "col-0"
|
||||
channels := []pChan{"mock-chan-0", "mock-chan-1"}
|
||||
cache := newMockCache()
|
||||
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
|
||||
return collectionID, nil
|
||||
})
|
||||
globalMetaCache = cache
|
||||
chMgr := newMockChannelsMgr()
|
||||
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
|
||||
return channels, nil
|
||||
}
|
||||
dt := deleteTask{
|
||||
ctx: context.Background(),
|
||||
deleteMsg: &msgstream.DeleteMsg{
|
||||
DeleteRequest: msgpb.DeleteRequest{
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
},
|
||||
chMgr: chMgr,
|
||||
}
|
||||
resChannels, err := dt.getChannels()
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, channels, resChannels)
|
||||
assert.ElementsMatch(t, channels, dt.pChannels)
|
||||
|
||||
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
|
||||
return nil, fmt.Errorf("mock err")
|
||||
}
|
||||
// get channels again, should return task's pChannels, so getChannelsFunc should not invoke again
|
||||
resChannels, err = dt.getChannels()
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, channels, resChannels)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -71,11 +71,19 @@ func (it *insertTask) EndTs() Timestamp {
|
|||
}
|
||||
|
||||
func (it *insertTask) getChannels() ([]pChan, error) {
|
||||
if len(it.pChannels) != 0 {
|
||||
return it.pChannels, nil
|
||||
}
|
||||
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.insertMsg.CollectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return it.chMgr.getChannels(collID)
|
||||
channels, err := it.chMgr.getChannels(collID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
it.pChannels = channels
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
func (it *insertTask) OnEnqueue() error {
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -8,6 +10,8 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
func TestInsertTask_CheckAligned(t *testing.T) {
|
||||
|
@ -219,3 +223,41 @@ func TestInsertTask_CheckAligned(t *testing.T) {
|
|||
err = case2.insertMsg.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInsertTask(t *testing.T) {
|
||||
t.Run("test getChannels", func(t *testing.T) {
|
||||
collectionID := UniqueID(0)
|
||||
collectionName := "col-0"
|
||||
channels := []pChan{"mock-chan-0", "mock-chan-1"}
|
||||
cache := newMockCache()
|
||||
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
|
||||
return collectionID, nil
|
||||
})
|
||||
globalMetaCache = cache
|
||||
chMgr := newMockChannelsMgr()
|
||||
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
|
||||
return channels, nil
|
||||
}
|
||||
it := insertTask{
|
||||
ctx: context.Background(),
|
||||
insertMsg: &msgstream.InsertMsg{
|
||||
InsertRequest: msgpb.InsertRequest{
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
},
|
||||
chMgr: chMgr,
|
||||
}
|
||||
resChannels, err := it.getChannels()
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, channels, resChannels)
|
||||
assert.ElementsMatch(t, channels, it.pChannels)
|
||||
|
||||
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
|
||||
return nil, fmt.Errorf("mock err")
|
||||
}
|
||||
// get channels again, should return task's pChannels, so getChannelsFunc should not invoke again
|
||||
resChannels, err = it.getChannels()
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, channels, resChannels)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -227,15 +227,19 @@ func (queue *dmTaskQueue) Enqueue(t task) error {
|
|||
// This statsLock has two functions:
|
||||
// 1) Protect member pChanStatisticsInfos
|
||||
// 2) Serialize the timestamp allocation for dml tasks
|
||||
queue.statsLock.Lock()
|
||||
defer queue.statsLock.Unlock()
|
||||
|
||||
//1. preAdd will check whether provided task is valid or addable
|
||||
//and get the current pChannels for this dmTask
|
||||
pChannels, dmt, err := queue.preAddPChanStats(t)
|
||||
dmt := t.(dmlTask)
|
||||
pChannels, err := dmt.getChannels()
|
||||
if err != nil {
|
||||
log.Warn("getChannels failed when Enqueue", zap.Any("tID", t.ID()), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
//2. enqueue dml task
|
||||
queue.statsLock.Lock()
|
||||
defer queue.statsLock.Unlock()
|
||||
err = queue.baseTaskQueue.Enqueue(t)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -265,19 +269,6 @@ func (queue *dmTaskQueue) PopActiveTask(taskID UniqueID) task {
|
|||
return t
|
||||
}
|
||||
|
||||
func (queue *dmTaskQueue) preAddPChanStats(t task) ([]pChan, dmlTask, error) {
|
||||
if dmT, ok := t.(dmlTask); ok {
|
||||
channels, err := dmT.getChannels()
|
||||
if err != nil {
|
||||
log.Warn("Proxy dmTaskQueue preAddPChanStats getChannels failed", zap.Any("tID", t.ID()),
|
||||
zap.Error(err))
|
||||
return nil, nil, err
|
||||
}
|
||||
return channels, dmT, nil
|
||||
}
|
||||
return nil, nil, fmt.Errorf("proxy preAddPChanStats reflect to dmlTask failed, tID:%v", t.ID())
|
||||
}
|
||||
|
||||
func (queue *dmTaskQueue) commitPChanStats(dmt dmlTask, pChannels []pChan) {
|
||||
//1. prepare new stat for all pChannels
|
||||
newStats := make(map[pChan]pChanStatistics)
|
||||
|
@ -312,34 +303,31 @@ func (queue *dmTaskQueue) commitPChanStats(dmt dmlTask, pChannels []pChan) {
|
|||
}
|
||||
}
|
||||
|
||||
func (queue *dmTaskQueue) popPChanStats(t task) error {
|
||||
if dmT, ok := t.(dmlTask); ok {
|
||||
channels, err := dmT.getChannels()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
taskTs := t.BeginTs()
|
||||
for _, cName := range channels {
|
||||
info, ok := queue.pChanStatisticsInfos[cName]
|
||||
if ok {
|
||||
delete(info.tsSet, taskTs)
|
||||
if len(info.tsSet) <= 0 {
|
||||
delete(queue.pChanStatisticsInfos, cName)
|
||||
} else {
|
||||
newMinTs := info.maxTs
|
||||
for ts := range info.tsSet {
|
||||
if newMinTs > ts {
|
||||
newMinTs = ts
|
||||
}
|
||||
func (queue *dmTaskQueue) popPChanStats(t task) {
|
||||
channels, err := t.(dmlTask).getChannels()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("get channels failed when popPChanStats, err=%w", err)
|
||||
log.Error(err.Error())
|
||||
panic(err)
|
||||
}
|
||||
taskTs := t.BeginTs()
|
||||
for _, cName := range channels {
|
||||
info, ok := queue.pChanStatisticsInfos[cName]
|
||||
if ok {
|
||||
delete(info.tsSet, taskTs)
|
||||
if len(info.tsSet) <= 0 {
|
||||
delete(queue.pChanStatisticsInfos, cName)
|
||||
} else {
|
||||
newMinTs := info.maxTs
|
||||
for ts := range info.tsSet {
|
||||
if newMinTs > ts {
|
||||
newMinTs = ts
|
||||
}
|
||||
info.minTs = newMinTs
|
||||
}
|
||||
info.minTs = newMinTs
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("proxy dmTaskQueue popPChanStats reflect to dmlTask failed, tID:%v", t.ID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (queue *dmTaskQueue) getPChanStatsInfo() (map[pChan]*pChanStatistics, error) {
|
||||
|
|
|
@ -26,7 +26,11 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/msgpb"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
func TestBaseTaskQueue(t *testing.T) {
|
||||
|
@ -197,11 +201,6 @@ func TestDmTaskQueue_Basic(t *testing.T) {
|
|||
assert.True(t, queue.utEmpty())
|
||||
assert.False(t, queue.utFull())
|
||||
|
||||
//test wrong task type
|
||||
dqlTask := newDefaultMockDqlTask()
|
||||
err = queue.Enqueue(dqlTask)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
st := newDefaultMockDmlTask()
|
||||
stID := st.ID()
|
||||
|
||||
|
@ -563,3 +562,51 @@ func TestTaskScheduler(t *testing.T) {
|
|||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestTaskScheduler_concurrentPushAndPop(t *testing.T) {
|
||||
collectionID := UniqueID(0)
|
||||
collectionName := "col-0"
|
||||
channels := []pChan{"mock-chan-0", "mock-chan-1"}
|
||||
cache := newMockCache()
|
||||
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
|
||||
return collectionID, nil
|
||||
})
|
||||
globalMetaCache = cache
|
||||
tsoAllocatorIns := newMockTsoAllocator()
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
scheduler, err := newTaskScheduler(context.Background(), tsoAllocatorIns, factory)
|
||||
assert.NoError(t, err)
|
||||
|
||||
run := func(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
chMgr := newMockChannelsMgr()
|
||||
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
|
||||
return channels, nil
|
||||
}
|
||||
it := &insertTask{
|
||||
ctx: context.Background(),
|
||||
insertMsg: &msgstream.InsertMsg{
|
||||
InsertRequest: msgpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
},
|
||||
chMgr: chMgr,
|
||||
}
|
||||
err := scheduler.dmQueue.Enqueue(it)
|
||||
assert.NoError(t, err)
|
||||
task := scheduler.scheduleDmTask()
|
||||
scheduler.dmQueue.AddActiveTask(task)
|
||||
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
|
||||
return nil, fmt.Errorf("mock err")
|
||||
}
|
||||
scheduler.dmQueue.PopActiveTask(task.ID()) // assert no panic
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go run(wg)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
|
|
@ -116,11 +116,19 @@ func (it *upsertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
|
|||
}
|
||||
|
||||
func (it *upsertTask) getChannels() ([]pChan, error) {
|
||||
if len(it.pChannels) != 0 {
|
||||
return it.pChannels, nil
|
||||
}
|
||||
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.req.CollectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return it.chMgr.getChannels(collID)
|
||||
channels, err := it.chMgr.getChannels(collID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
it.pChannels = channels
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
func (it *upsertTask) OnEnqueue() error {
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -26,6 +28,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
func TestUpsertTask_CheckAligned(t *testing.T) {
|
||||
|
@ -290,3 +293,39 @@ func TestUpsertTask_CheckAligned(t *testing.T) {
|
|||
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestUpsertTask(t *testing.T) {
|
||||
t.Run("test getChannels", func(t *testing.T) {
|
||||
collectionID := UniqueID(0)
|
||||
collectionName := "col-0"
|
||||
channels := []pChan{"mock-chan-0", "mock-chan-1"}
|
||||
cache := newMockCache()
|
||||
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
|
||||
return collectionID, nil
|
||||
})
|
||||
globalMetaCache = cache
|
||||
chMgr := newMockChannelsMgr()
|
||||
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
|
||||
return channels, nil
|
||||
}
|
||||
ut := upsertTask{
|
||||
ctx: context.Background(),
|
||||
req: &milvuspb.UpsertRequest{
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
chMgr: chMgr,
|
||||
}
|
||||
resChannels, err := ut.getChannels()
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, channels, resChannels)
|
||||
assert.ElementsMatch(t, channels, ut.pChannels)
|
||||
|
||||
chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) {
|
||||
return nil, fmt.Errorf("mock err")
|
||||
}
|
||||
// get channels again, should return task's pChannels, so getChannelsFunc should not invoke again
|
||||
resChannels, err = ut.getChannels()
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, channels, resChannels)
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue