Set pchannels before dml enqueue to prevent panic (#24828)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/24921/head
yihao.dai 2023-06-16 16:36:40 +08:00 committed by GitHub
parent 969517f910
commit b62429070c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 48 additions and 49 deletions

View File

@ -186,8 +186,12 @@ type mockDmlTask struct {
pchans []pChan
}
func (m *mockDmlTask) getChannels() ([]vChan, error) {
return m.vchans, nil
func (m *mockDmlTask) setChannels() error {
return nil
}
func (m *mockDmlTask) getChannels() []vChan {
return m.vchans
}
func newMockDmlTask(ctx context.Context) *mockDmlTask {

View File

@ -105,7 +105,8 @@ type task interface {
type dmlTask interface {
task
getChannels() ([]pChan, error)
setChannels() error
getChannels() []pChan
}
type BaseInsertTask = msgstream.InsertMsg

View File

@ -81,20 +81,21 @@ func (dt *deleteTask) OnEnqueue() error {
return nil
}
func (dt *deleteTask) getChannels() ([]pChan, error) {
if len(dt.pChannels) != 0 {
return dt.pChannels, nil
}
func (dt *deleteTask) setChannels() error {
collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.deleteMsg.CollectionName)
if err != nil {
return nil, err
return err
}
channels, err := dt.chMgr.getChannels(collID)
if err != nil {
return nil, err
return err
}
dt.pChannels = channels
return channels, nil
return nil
}
func (dt *deleteTask) getChannels() []pChan {
return dt.pChannels
}
func getPrimaryKeysFromExpr(schema *schemapb.CollectionSchema, expr string) (res *schemapb.IDs, rowNum int64, err error) {

View File

@ -66,8 +66,9 @@ func TestDeleteTask(t *testing.T) {
},
chMgr: chMgr,
}
resChannels, err := dt.getChannels()
err := dt.setChannels()
assert.NoError(t, err)
resChannels := dt.getChannels()
assert.ElementsMatch(t, channels, resChannels)
assert.ElementsMatch(t, channels, dt.pChannels)
@ -75,8 +76,7 @@ func TestDeleteTask(t *testing.T) {
return nil, fmt.Errorf("mock err")
}
// get channels again, should return task's pChannels, so getChannelsFunc should not invoke again
resChannels, err = dt.getChannels()
assert.NoError(t, err)
resChannels = dt.getChannels()
assert.ElementsMatch(t, channels, resChannels)
})
}

View File

@ -73,20 +73,21 @@ func (it *insertTask) EndTs() Timestamp {
return it.insertMsg.EndTimestamp
}
func (it *insertTask) getChannels() ([]pChan, error) {
if len(it.pChannels) != 0 {
return it.pChannels, nil
}
func (it *insertTask) setChannels() error {
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.insertMsg.CollectionName)
if err != nil {
return nil, err
return err
}
channels, err := it.chMgr.getChannels(collID)
if err != nil {
return nil, err
return err
}
it.pChannels = channels
return channels, nil
return nil
}
func (it *insertTask) getChannels() []pChan {
return it.pChannels
}
func (it *insertTask) OnEnqueue() error {

View File

@ -248,8 +248,9 @@ func TestInsertTask(t *testing.T) {
},
chMgr: chMgr,
}
resChannels, err := it.getChannels()
err := it.setChannels()
assert.NoError(t, err)
resChannels := it.getChannels()
assert.ElementsMatch(t, channels, resChannels)
assert.ElementsMatch(t, channels, it.pChannels)
@ -257,8 +258,7 @@ func TestInsertTask(t *testing.T) {
return nil, fmt.Errorf("mock err")
}
// get channels again, should return task's pChannels, so getChannelsFunc should not invoke again
resChannels, err = it.getChannels()
assert.NoError(t, err)
resChannels = it.getChannels()
assert.ElementsMatch(t, channels, resChannels)
})
}

View File

@ -19,7 +19,6 @@ package proxy
import (
"container/list"
"context"
"fmt"
"sync"
"github.com/cockroachdb/errors"
@ -228,12 +227,11 @@ func (queue *dmTaskQueue) Enqueue(t task) error {
// 1) Protect member pChanStatisticsInfos
// 2) Serialize the timestamp allocation for dml tasks
//1. preAdd will check whether provided task is valid or addable
//and get the current pChannels for this dmTask
//1. set the current pChannels for this dmTask
dmt := t.(dmlTask)
pChannels, err := dmt.getChannels()
err := dmt.setChannels()
if err != nil {
log.Warn("getChannels failed when Enqueue", zap.Any("tID", t.ID()), zap.Error(err))
log.Warn("setChannels failed when Enqueue", zap.Int64("taskID", t.ID()), zap.Error(err))
return err
}
@ -244,7 +242,8 @@ func (queue *dmTaskQueue) Enqueue(t task) error {
if err != nil {
return err
}
//3. if preAdd succeed, commit will use pChannels got previously when preAdding and will definitely succeed
//3. commit will use pChannels got previously when preAdding and will definitely succeed
pChannels := dmt.getChannels()
queue.commitPChanStats(dmt, pChannels)
//there's indeed a possibility that the collection info cache was expired after preAddPChanStats
//but considering root coord knows everything about meta modification, invalid stats appended after the meta changed
@ -304,12 +303,7 @@ func (queue *dmTaskQueue) commitPChanStats(dmt dmlTask, pChannels []pChan) {
}
func (queue *dmTaskQueue) popPChanStats(t task) {
channels, err := t.(dmlTask).getChannels()
if err != nil {
err = fmt.Errorf("get channels failed when popPChanStats, err=%w", err)
log.Error(err.Error())
panic(err)
}
channels := t.(dmlTask).getChannels()
taskTs := t.BeginTs()
for _, cName := range channels {
info, ok := queue.pChanStatisticsInfos[cName]

View File

@ -100,10 +100,7 @@ func (it *upsertTask) EndTs() Timestamp {
func (it *upsertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
ret := make(map[pChan]pChanStatistics)
channels, err := it.getChannels()
if err != nil {
return ret, err
}
channels := it.getChannels()
beginTs := it.BeginTs()
endTs := it.EndTs()
@ -117,20 +114,21 @@ func (it *upsertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
return ret, nil
}
func (it *upsertTask) getChannels() ([]pChan, error) {
if len(it.pChannels) != 0 {
return it.pChannels, nil
}
func (it *upsertTask) setChannels() error {
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.req.CollectionName)
if err != nil {
return nil, err
return err
}
channels, err := it.chMgr.getChannels(collID)
if err != nil {
return nil, err
return err
}
it.pChannels = channels
return channels, nil
return nil
}
func (it *upsertTask) getChannels() []pChan {
return it.pChannels
}
func (it *upsertTask) OnEnqueue() error {

View File

@ -315,8 +315,9 @@ func TestUpsertTask(t *testing.T) {
},
chMgr: chMgr,
}
resChannels, err := ut.getChannels()
err := ut.setChannels()
assert.NoError(t, err)
resChannels := ut.getChannels()
assert.ElementsMatch(t, channels, resChannels)
assert.ElementsMatch(t, channels, ut.pChannels)
@ -324,8 +325,7 @@ func TestUpsertTask(t *testing.T) {
return nil, fmt.Errorf("mock err")
}
// get channels again, should return task's pChannels, so getChannelsFunc should not invoke again
resChannels, err = ut.getChannels()
assert.NoError(t, err)
resChannels = ut.getChannels()
assert.ElementsMatch(t, channels, resChannels)
})
}