Reduce lock operations when get dml stream (#17468)

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
pull/17452/head
Jiquan Long 2022-06-09 17:34:09 +08:00 committed by GitHub
parent 9210299706
commit 2ca81620ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 43 additions and 82 deletions

View File

@ -38,8 +38,7 @@ import (
type channelsMgr interface {
getChannels(collectionID UniqueID) ([]pChan, error)
getVChannels(collectionID UniqueID) ([]vChan, error)
createDMLMsgStream(collectionID UniqueID) error
getDMLStream(collectionID UniqueID) (msgstream.MsgStream, error)
getOrCreateDmlStream(collectionID UniqueID) (msgstream.MsgStream, error)
removeDMLStream(collectionID UniqueID) error
removeAllDMLStream() error
}
@ -182,12 +181,6 @@ func (mgr *singleTypeChannelsMgr) streamExistPrivate(collectionID UniqueID) bool
return ok && streamInfos.stream != nil
}
func (mgr *singleTypeChannelsMgr) streamExist(collectionID UniqueID) bool {
mgr.mu.RLock()
defer mgr.mu.RUnlock()
return mgr.streamExistPrivate(collectionID)
}
func createStream(factory msgstream.Factory, streamType streamType, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) {
var stream msgstream.MsgStream
var err error
@ -213,14 +206,6 @@ func createStream(factory msgstream.Factory, streamType streamType, pchans []pCh
return stream, nil
}
func (mgr *singleTypeChannelsMgr) updateCollection(collectionID UniqueID, channelInfos channelInfos, stream msgstream.MsgStream) {
mgr.mu.Lock()
defer mgr.mu.Unlock()
if !mgr.streamExistPrivate(collectionID) {
mgr.infos[collectionID] = streamInfos{channelInfos: channelInfos, stream: stream}
}
}
func incPChansMetrics(pchans []pChan) {
for _, pc := range pchans {
metrics.ProxyMsgStreamObjectsForPChan.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), pc).Inc()
@ -234,35 +219,42 @@ func decPChanMetrics(pchans []pChan) {
}
// createMsgStream create message stream for specified collection. Idempotent.
// If stream already exists, directly return nil.
func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) error {
if mgr.streamExist(collectionID) {
log.Info("stream already exist, no need to re-create", zap.Int64("collection_id", collectionID))
return nil
// If stream already exists, directly return it and no error will be returned.
func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) (msgstream.MsgStream, error) {
mgr.mu.RLock()
infos, ok := mgr.infos[collectionID]
if ok && infos.stream != nil {
// already exist.
mgr.mu.RUnlock()
return infos.stream, nil
}
mgr.mu.RUnlock()
channelInfos, err := mgr.getChannelsFunc(collectionID)
if err != nil {
// What if stream created by other goroutines?
log.Error("failed to get channels", zap.Error(err), zap.Int64("collection", collectionID))
return err
return nil, err
}
stream, err := createStream(mgr.msgStreamFactory, mgr.singleStreamType, channelInfos.pchans, mgr.repackFunc)
if err != nil {
// What if stream created by other goroutines?
log.Error("failed to create message stream", zap.Error(err), zap.Int64("collection", collectionID))
return err
return nil, err
}
mgr.updateCollection(collectionID, channelInfos, stream)
mgr.mu.Lock()
defer mgr.mu.Unlock()
if !mgr.streamExistPrivate(collectionID) {
log.Info("create message stream", zap.Int64("collection", collectionID),
zap.Strings("virtual_channels", channelInfos.vchans),
zap.Strings("physical_channels", channelInfos.pchans))
mgr.infos[collectionID] = streamInfos{channelInfos: channelInfos, stream: stream}
incPChansMetrics(channelInfos.pchans)
}
log.Info("create message stream",
zap.Int64("collection_id", collectionID),
zap.Strings("virtual_channels", channelInfos.vchans),
zap.Strings("physical_channels", channelInfos.pchans))
incPChansMetrics(channelInfos.pchans)
return nil
return mgr.infos[collectionID].stream, nil
}
func (mgr *singleTypeChannelsMgr) lockGetStream(collectionID UniqueID) (msgstream.MsgStream, error) {
@ -275,18 +267,14 @@ func (mgr *singleTypeChannelsMgr) lockGetStream(collectionID UniqueID) (msgstrea
return nil, fmt.Errorf("collection not found: %d", collectionID)
}
// getStream get message stream of specified collection.
// getOrCreateStream get message stream of specified collection.
// If stream don't exists, call createMsgStream to create for it.
func (mgr *singleTypeChannelsMgr) getStream(collectionID UniqueID) (msgstream.MsgStream, error) {
func (mgr *singleTypeChannelsMgr) getOrCreateStream(collectionID UniqueID) (msgstream.MsgStream, error) {
if stream, err := mgr.lockGetStream(collectionID); err == nil {
return stream, nil
}
if err := mgr.createMsgStream(collectionID); err != nil {
return nil, err
}
return mgr.lockGetStream(collectionID)
return mgr.createMsgStream(collectionID)
}
// removeStream remove the corresponding stream of the specified collection. Idempotent.
@ -343,12 +331,8 @@ func (mgr *channelsMgrImpl) getVChannels(collectionID UniqueID) ([]vChan, error)
return mgr.dmlChannelsMgr.getVChannels(collectionID)
}
func (mgr *channelsMgrImpl) createDMLMsgStream(collectionID UniqueID) error {
return mgr.dmlChannelsMgr.createMsgStream(collectionID)
}
func (mgr *channelsMgrImpl) getDMLStream(collectionID UniqueID) (msgstream.MsgStream, error) {
return mgr.dmlChannelsMgr.getStream(collectionID)
func (mgr *channelsMgrImpl) getOrCreateDmlStream(collectionID UniqueID) (msgstream.MsgStream, error) {
return mgr.dmlChannelsMgr.getOrCreateStream(collectionID)
}
func (mgr *channelsMgrImpl) removeDMLStream(collectionID UniqueID) error {

View File

@ -205,31 +205,6 @@ func Test_singleTypeChannelsMgr_getVChannels(t *testing.T) {
})
}
func Test_singleTypeChannelsMgr_streamExist(t *testing.T) {
t.Run("exist", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {stream: newSimpleMockMsgStream()},
},
}
exist := m.streamExist(100)
assert.True(t, exist)
})
t.Run("not exist", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {stream: nil},
},
}
exist := m.streamExist(100)
assert.False(t, exist)
m.infos = make(map[UniqueID]streamInfos)
exist = m.streamExist(100)
assert.False(t, exist)
})
}
func Test_createStream(t *testing.T) {
t.Run("failed to create msgstream", func(t *testing.T) {
factory := newMockMsgStreamFactory()
@ -268,8 +243,9 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
100: {stream: newMockMsgStream()},
},
}
err := m.createMsgStream(100)
stream, err := m.createMsgStream(100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
t.Run("failed to get channels", func(t *testing.T) {
@ -278,7 +254,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
return channelInfos{}, errors.New("mock")
},
}
err := m.createMsgStream(100)
_, err := m.createMsgStream(100)
assert.Error(t, err)
})
@ -295,7 +271,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
singleStreamType: dmlStreamType,
repackFunc: nil,
}
err := m.createMsgStream(100)
_, err := m.createMsgStream(100)
assert.Error(t, err)
})
@ -313,9 +289,10 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
singleStreamType: dmlStreamType,
repackFunc: nil,
}
err := m.createMsgStream(100)
stream, err := m.createMsgStream(100)
assert.NoError(t, err)
stream, err := m.getStream(100)
assert.NotNil(t, stream)
stream, err = m.getOrCreateStream(100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
@ -349,7 +326,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
100: {stream: newMockMsgStream()},
},
}
stream, err := m.getStream(100)
stream, err := m.getOrCreateStream(100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
@ -361,7 +338,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
return channelInfos{}, errors.New("mock")
},
}
_, err := m.getStream(100)
_, err := m.getOrCreateStream(100)
assert.Error(t, err)
})
@ -379,7 +356,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
singleStreamType: dmlStreamType,
repackFunc: nil,
}
stream, err := m.getStream(100)
stream, err := m.getOrCreateStream(100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})

View File

@ -497,7 +497,7 @@ func (it *insertTask) Execute(ctx context.Context) error {
it.PartitionID = partitionID
tr.Record("get collection id & partition id from cache")
stream, err := it.chMgr.getDMLStream(collID)
stream, err := it.chMgr.getOrCreateDmlStream(collID)
if err != nil {
return err
}
@ -3260,7 +3260,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID()))
collID := dt.DeleteRequest.CollectionID
stream, err := dt.chMgr.getDMLStream(collID)
stream, err := dt.chMgr.getOrCreateDmlStream(collID)
if err != nil {
return err
}

View File

@ -1716,7 +1716,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream()
err = chMgr.createDMLMsgStream(collectionID)
_, err = chMgr.getOrCreateDmlStream(collectionID)
assert.NoError(t, err)
pchans, err := chMgr.getChannels(collectionID)
assert.NoError(t, err)
@ -1971,7 +1971,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream()
err = chMgr.createDMLMsgStream(collectionID)
_, err = chMgr.getOrCreateDmlStream(collectionID)
assert.NoError(t, err)
pchans, err := chMgr.getChannels(collectionID)
assert.NoError(t, err)