mirror of https://github.com/milvus-io/milvus.git
Set pchannels before dml enqueue to prevent panic (#24828)
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>pull/24921/head
parent
969517f910
commit
b62429070c
|
@ -186,8 +186,12 @@ type mockDmlTask struct {
|
|||
pchans []pChan
|
||||
}
|
||||
|
||||
func (m *mockDmlTask) getChannels() ([]vChan, error) {
|
||||
return m.vchans, nil
|
||||
func (m *mockDmlTask) setChannels() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDmlTask) getChannels() []vChan {
|
||||
return m.vchans
|
||||
}
|
||||
|
||||
func newMockDmlTask(ctx context.Context) *mockDmlTask {
|
||||
|
|
|
@ -105,7 +105,8 @@ type task interface {
|
|||
|
||||
type dmlTask interface {
|
||||
task
|
||||
getChannels() ([]pChan, error)
|
||||
setChannels() error
|
||||
getChannels() []pChan
|
||||
}
|
||||
|
||||
type BaseInsertTask = msgstream.InsertMsg
|
||||
|
|
|
@ -81,20 +81,21 @@ func (dt *deleteTask) OnEnqueue() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (dt *deleteTask) getChannels() ([]pChan, error) {
|
||||
if len(dt.pChannels) != 0 {
|
||||
return dt.pChannels, nil
|
||||
}
|
||||
func (dt *deleteTask) setChannels() error {
|
||||
collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.deleteMsg.CollectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
channels, err := dt.chMgr.getChannels(collID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
dt.pChannels = channels
|
||||
return channels, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dt *deleteTask) getChannels() []pChan {
|
||||
return dt.pChannels
|
||||
}
|
||||
|
||||
func getPrimaryKeysFromExpr(schema *schemapb.CollectionSchema, expr string) (res *schemapb.IDs, rowNum int64, err error) {
|
||||
|
|
|
@ -66,8 +66,9 @@ func TestDeleteTask(t *testing.T) {
|
|||
},
|
||||
chMgr: chMgr,
|
||||
}
|
||||
resChannels, err := dt.getChannels()
|
||||
err := dt.setChannels()
|
||||
assert.NoError(t, err)
|
||||
resChannels := dt.getChannels()
|
||||
assert.ElementsMatch(t, channels, resChannels)
|
||||
assert.ElementsMatch(t, channels, dt.pChannels)
|
||||
|
||||
|
@ -75,8 +76,7 @@ func TestDeleteTask(t *testing.T) {
|
|||
return nil, fmt.Errorf("mock err")
|
||||
}
|
||||
// get channels again, should return task's pChannels, so getChannelsFunc should not invoke again
|
||||
resChannels, err = dt.getChannels()
|
||||
assert.NoError(t, err)
|
||||
resChannels = dt.getChannels()
|
||||
assert.ElementsMatch(t, channels, resChannels)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -73,20 +73,21 @@ func (it *insertTask) EndTs() Timestamp {
|
|||
return it.insertMsg.EndTimestamp
|
||||
}
|
||||
|
||||
func (it *insertTask) getChannels() ([]pChan, error) {
|
||||
if len(it.pChannels) != 0 {
|
||||
return it.pChannels, nil
|
||||
}
|
||||
func (it *insertTask) setChannels() error {
|
||||
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.insertMsg.CollectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
channels, err := it.chMgr.getChannels(collID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
it.pChannels = channels
|
||||
return channels, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *insertTask) getChannels() []pChan {
|
||||
return it.pChannels
|
||||
}
|
||||
|
||||
func (it *insertTask) OnEnqueue() error {
|
||||
|
|
|
@ -248,8 +248,9 @@ func TestInsertTask(t *testing.T) {
|
|||
},
|
||||
chMgr: chMgr,
|
||||
}
|
||||
resChannels, err := it.getChannels()
|
||||
err := it.setChannels()
|
||||
assert.NoError(t, err)
|
||||
resChannels := it.getChannels()
|
||||
assert.ElementsMatch(t, channels, resChannels)
|
||||
assert.ElementsMatch(t, channels, it.pChannels)
|
||||
|
||||
|
@ -257,8 +258,7 @@ func TestInsertTask(t *testing.T) {
|
|||
return nil, fmt.Errorf("mock err")
|
||||
}
|
||||
// get channels again, should return task's pChannels, so getChannelsFunc should not invoke again
|
||||
resChannels, err = it.getChannels()
|
||||
assert.NoError(t, err)
|
||||
resChannels = it.getChannels()
|
||||
assert.ElementsMatch(t, channels, resChannels)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -19,7 +19,6 @@ package proxy
|
|||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
@ -228,12 +227,11 @@ func (queue *dmTaskQueue) Enqueue(t task) error {
|
|||
// 1) Protect member pChanStatisticsInfos
|
||||
// 2) Serialize the timestamp allocation for dml tasks
|
||||
|
||||
//1. preAdd will check whether provided task is valid or addable
|
||||
//and get the current pChannels for this dmTask
|
||||
//1. set the current pChannels for this dmTask
|
||||
dmt := t.(dmlTask)
|
||||
pChannels, err := dmt.getChannels()
|
||||
err := dmt.setChannels()
|
||||
if err != nil {
|
||||
log.Warn("getChannels failed when Enqueue", zap.Any("tID", t.ID()), zap.Error(err))
|
||||
log.Warn("setChannels failed when Enqueue", zap.Int64("taskID", t.ID()), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -244,7 +242,8 @@ func (queue *dmTaskQueue) Enqueue(t task) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//3. if preAdd succeed, commit will use pChannels got previously when preAdding and will definitely succeed
|
||||
//3. commit will use pChannels got previously when preAdding and will definitely succeed
|
||||
pChannels := dmt.getChannels()
|
||||
queue.commitPChanStats(dmt, pChannels)
|
||||
//there's indeed a possibility that the collection info cache was expired after preAddPChanStats
|
||||
//but considering root coord knows everything about meta modification, invalid stats appended after the meta changed
|
||||
|
@ -304,12 +303,7 @@ func (queue *dmTaskQueue) commitPChanStats(dmt dmlTask, pChannels []pChan) {
|
|||
}
|
||||
|
||||
func (queue *dmTaskQueue) popPChanStats(t task) {
|
||||
channels, err := t.(dmlTask).getChannels()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("get channels failed when popPChanStats, err=%w", err)
|
||||
log.Error(err.Error())
|
||||
panic(err)
|
||||
}
|
||||
channels := t.(dmlTask).getChannels()
|
||||
taskTs := t.BeginTs()
|
||||
for _, cName := range channels {
|
||||
info, ok := queue.pChanStatisticsInfos[cName]
|
||||
|
|
|
@ -100,10 +100,7 @@ func (it *upsertTask) EndTs() Timestamp {
|
|||
func (it *upsertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
|
||||
ret := make(map[pChan]pChanStatistics)
|
||||
|
||||
channels, err := it.getChannels()
|
||||
if err != nil {
|
||||
return ret, err
|
||||
}
|
||||
channels := it.getChannels()
|
||||
|
||||
beginTs := it.BeginTs()
|
||||
endTs := it.EndTs()
|
||||
|
@ -117,20 +114,21 @@ func (it *upsertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func (it *upsertTask) getChannels() ([]pChan, error) {
|
||||
if len(it.pChannels) != 0 {
|
||||
return it.pChannels, nil
|
||||
}
|
||||
func (it *upsertTask) setChannels() error {
|
||||
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.req.CollectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
channels, err := it.chMgr.getChannels(collID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
it.pChannels = channels
|
||||
return channels, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *upsertTask) getChannels() []pChan {
|
||||
return it.pChannels
|
||||
}
|
||||
|
||||
func (it *upsertTask) OnEnqueue() error {
|
||||
|
|
|
@ -315,8 +315,9 @@ func TestUpsertTask(t *testing.T) {
|
|||
},
|
||||
chMgr: chMgr,
|
||||
}
|
||||
resChannels, err := ut.getChannels()
|
||||
err := ut.setChannels()
|
||||
assert.NoError(t, err)
|
||||
resChannels := ut.getChannels()
|
||||
assert.ElementsMatch(t, channels, resChannels)
|
||||
assert.ElementsMatch(t, channels, ut.pChannels)
|
||||
|
||||
|
@ -324,8 +325,7 @@ func TestUpsertTask(t *testing.T) {
|
|||
return nil, fmt.Errorf("mock err")
|
||||
}
|
||||
// get channels again, should return task's pChannels, so getChannelsFunc should not invoke again
|
||||
resChannels, err = ut.getChannels()
|
||||
assert.NoError(t, err)
|
||||
resChannels = ut.getChannels()
|
||||
assert.ElementsMatch(t, channels, resChannels)
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue