mirror of https://github.com/milvus-io/milvus.git
Fix datarace between GetComponentStates and Register (#11935)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>pull/11935/merge
parent
074687e32e
commit
ee0f753f7a
|
@ -49,6 +49,9 @@ const (
|
|||
|
||||
// InvalidFieldID indicates that the field does not exist . It will be set when the field is not found.
|
||||
InvalidFieldID = int64(-1)
|
||||
|
||||
// NotRegisteredID means node is not registered into etcd.
|
||||
NotRegisteredID = int64(-1)
|
||||
)
|
||||
|
||||
// Endian is type alias of binary.LittleEndian.
|
||||
|
|
|
@ -27,6 +27,8 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
|
||||
memkv "github.com/milvus-io/milvus/internal/kv/mem"
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
|
||||
|
@ -491,6 +493,12 @@ func TestGetSegmentInfo(t *testing.T) {
|
|||
|
||||
func TestGetComponentStates(t *testing.T) {
|
||||
svr := &Server{}
|
||||
resp, err := svr.GetComponentStates(context.Background())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
|
||||
assert.Equal(t, common.NotRegisteredID, resp.State.NodeID)
|
||||
svr.session = &sessionutil.Session{}
|
||||
svr.session.UpdateRegistered(true)
|
||||
type testCase struct {
|
||||
state ServerState
|
||||
code internalpb.StateCode
|
||||
|
|
|
@ -23,6 +23,8 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/trace"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
|
@ -384,9 +386,15 @@ func (s *Server) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPath
|
|||
|
||||
// GetComponentStates returns DataCoord's current state
|
||||
func (s *Server) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
|
||||
nodeID := common.NotRegisteredID
|
||||
if s.session != nil && s.session.Registered() {
|
||||
nodeID = s.session.ServerID // or Params.NodeID
|
||||
}
|
||||
|
||||
resp := &internalpb.ComponentStates{
|
||||
State: &internalpb.ComponentInfo{
|
||||
NodeID: Params.NodeID,
|
||||
// NodeID: Params.NodeID, // will race with Server.Register()
|
||||
NodeID: nodeID,
|
||||
Role: "datacoord",
|
||||
StateCode: 0,
|
||||
},
|
||||
|
|
|
@ -32,6 +32,8 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
|
||||
v3rpc "go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
|
||||
|
@ -473,9 +475,14 @@ func (node *DataNode) WatchDmChannels(ctx context.Context, in *datapb.WatchDmCha
|
|||
// GetComponentStates will return current state of DataNode
|
||||
func (node *DataNode) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
|
||||
log.Debug("DataNode current state", zap.Any("State", node.State.Load()))
|
||||
nodeID := common.NotRegisteredID
|
||||
if node.session != nil && node.session.Registered() {
|
||||
nodeID = node.session.ServerID
|
||||
}
|
||||
states := &internalpb.ComponentStates{
|
||||
State: &internalpb.ComponentInfo{
|
||||
NodeID: Params.NodeID,
|
||||
// NodeID: Params.NodeID, // will race with DataNode.Register()
|
||||
NodeID: nodeID,
|
||||
Role: node.Role,
|
||||
StateCode: node.State.Load().(internalpb.StateCode),
|
||||
},
|
||||
|
|
|
@ -28,6 +28,8 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
|
@ -584,3 +586,17 @@ func TestWatchChannel(t *testing.T) {
|
|||
|
||||
})
|
||||
}
|
||||
|
||||
func TestDataNode_GetComponentStates(t *testing.T) {
|
||||
n := &DataNode{}
|
||||
n.State.Store(internalpb.StateCode_Healthy)
|
||||
resp, err := n.GetComponentStates(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
assert.Equal(t, common.NotRegisteredID, resp.State.NodeID)
|
||||
n.session = &sessionutil.Session{}
|
||||
n.session.UpdateRegistered(true)
|
||||
resp, err = n.GetComponentStates(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
}
|
||||
|
|
|
@ -26,6 +26,8 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
|
||||
"go.etcd.io/etcd/api/v3/mvccpb"
|
||||
"go.uber.org/zap"
|
||||
|
||||
|
@ -59,8 +61,6 @@ var _ types.IndexCoord = (*IndexCoord)(nil)
|
|||
type IndexCoord struct {
|
||||
stateCode atomic.Value
|
||||
|
||||
ID UniqueID
|
||||
|
||||
loopCtx context.Context
|
||||
loopCancel func()
|
||||
loopWg sync.WaitGroup
|
||||
|
@ -193,13 +193,6 @@ func (i *IndexCoord) Init() error {
|
|||
return
|
||||
}
|
||||
|
||||
i.ID, err = i.idAllocator.AllocOne()
|
||||
if err != nil {
|
||||
log.Error("IndexCoord idAllocator allocOne failed", zap.Error(err))
|
||||
initErr = err
|
||||
return
|
||||
}
|
||||
|
||||
option := &miniokv.Option{
|
||||
Address: Params.MinIOAddress,
|
||||
AccessKeyID: Params.MinIOAccessKeyID,
|
||||
|
@ -302,8 +295,14 @@ func (i *IndexCoord) isHealthy() bool {
|
|||
// GetComponentStates gets the component states of IndexCoord.
|
||||
func (i *IndexCoord) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
|
||||
log.Debug("get IndexCoord component states ...")
|
||||
|
||||
nodeID := common.NotRegisteredID
|
||||
if i.session != nil && i.session.Registered() {
|
||||
nodeID = i.session.ServerID
|
||||
}
|
||||
|
||||
stateInfo := &internalpb.ComponentInfo{
|
||||
NodeID: i.ID,
|
||||
NodeID: nodeID,
|
||||
Role: "IndexCoord",
|
||||
StateCode: i.stateCode.Load().(internalpb.StateCode),
|
||||
}
|
||||
|
@ -515,19 +514,19 @@ func (i *IndexCoord) GetIndexFilePaths(ctx context.Context, req *indexpb.GetInde
|
|||
// GetMetrics gets the metrics info of IndexCoord.
|
||||
func (i *IndexCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
|
||||
log.Debug("IndexCoord.GetMetrics",
|
||||
zap.Int64("node_id", i.ID),
|
||||
zap.Int64("node_id", i.session.ServerID),
|
||||
zap.String("req", req.Request))
|
||||
|
||||
if !i.isHealthy() {
|
||||
log.Warn("IndexCoord.GetMetrics failed",
|
||||
zap.Int64("node_id", i.ID),
|
||||
zap.Int64("node_id", i.session.ServerID),
|
||||
zap.String("req", req.Request),
|
||||
zap.Error(errIndexCoordIsUnhealthy(i.ID)))
|
||||
zap.Error(errIndexCoordIsUnhealthy(i.session.ServerID)))
|
||||
|
||||
return &milvuspb.GetMetricsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: msgIndexCoordIsUnhealthy(i.ID),
|
||||
Reason: msgIndexCoordIsUnhealthy(i.session.ServerID),
|
||||
},
|
||||
Response: "",
|
||||
}, nil
|
||||
|
@ -536,7 +535,7 @@ func (i *IndexCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsReq
|
|||
metricType, err := metricsinfo.ParseMetricType(req.Request)
|
||||
if err != nil {
|
||||
log.Error("IndexCoord.GetMetrics failed to parse metric type",
|
||||
zap.Int64("node_id", i.ID),
|
||||
zap.Int64("node_id", i.session.ServerID),
|
||||
zap.String("req", req.Request),
|
||||
zap.Error(err))
|
||||
|
||||
|
@ -563,7 +562,7 @@ func (i *IndexCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsReq
|
|||
metrics, err := getSystemInfoMetrics(ctx, req, i)
|
||||
|
||||
log.Debug("IndexCoord.GetMetrics",
|
||||
zap.Int64("node_id", i.ID),
|
||||
zap.Int64("node_id", i.session.ServerID),
|
||||
zap.String("req", req.Request),
|
||||
zap.String("metric_type", metricType),
|
||||
zap.Any("metrics", metrics), // TODO(dragondriver): necessary? may be very large
|
||||
|
@ -575,7 +574,7 @@ func (i *IndexCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsReq
|
|||
}
|
||||
|
||||
log.Debug("IndexCoord.GetMetrics failed, request metric type is not implemented yet",
|
||||
zap.Int64("node_id", i.ID),
|
||||
zap.Int64("node_id", i.session.ServerID),
|
||||
zap.String("req", req.Request),
|
||||
zap.String("metric_type", metricType))
|
||||
|
||||
|
|
|
@ -23,6 +23,8 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
|
||||
grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode"
|
||||
|
@ -231,3 +233,17 @@ func TestIndexCoord_watchNodeLoop(t *testing.T) {
|
|||
assert.True(t, flag)
|
||||
|
||||
}
|
||||
|
||||
func TestIndexCoord_GetComponentStates(t *testing.T) {
|
||||
n := &IndexCoord{}
|
||||
n.stateCode.Store(internalpb.StateCode_Healthy)
|
||||
resp, err := n.GetComponentStates(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
assert.Equal(t, common.NotRegisteredID, resp.State.NodeID)
|
||||
n.session = &sessionutil.Session{}
|
||||
n.session.UpdateRegistered(true)
|
||||
resp, err = n.GetComponentStates(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
}
|
||||
|
|
|
@ -38,6 +38,8 @@ import (
|
|||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/metricsinfo"
|
||||
|
||||
|
@ -285,8 +287,13 @@ func (i *IndexNode) CreateIndex(ctx context.Context, request *indexpb.CreateInde
|
|||
// GetComponentStates gets the component states of IndexNode.
|
||||
func (i *IndexNode) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
|
||||
log.Debug("get IndexNode components states ...")
|
||||
nodeID := common.NotRegisteredID
|
||||
if i.session != nil && i.session.Registered() {
|
||||
nodeID = i.session.ServerID
|
||||
}
|
||||
stateInfo := &internalpb.ComponentInfo{
|
||||
NodeID: Params.NodeID,
|
||||
// NodeID: Params.NodeID, // will race with i.Register()
|
||||
NodeID: nodeID,
|
||||
Role: "NodeImpl",
|
||||
StateCode: i.stateCode.Load().(internalpb.StateCode),
|
||||
}
|
||||
|
|
|
@ -24,12 +24,15 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/metricsinfo"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
|
||||
|
@ -811,3 +814,17 @@ func TestIndexNode_InitError(t *testing.T) {
|
|||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIndexNode_GetComponentStates(t *testing.T) {
|
||||
n := &IndexNode{}
|
||||
n.stateCode.Store(internalpb.StateCode_Healthy)
|
||||
resp, err := n.GetComponentStates(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
assert.Equal(t, common.NotRegisteredID, resp.State.NodeID)
|
||||
n.session = &sessionutil.Session{}
|
||||
n.session.UpdateRegistered(true)
|
||||
resp, err = n.GetComponentStates(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
}
|
||||
|
|
|
@ -23,6 +23,8 @@ import (
|
|||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/trace"
|
||||
|
@ -65,8 +67,13 @@ func (node *Proxy) GetComponentStates(ctx context.Context) (*internalpb.Componen
|
|||
}
|
||||
return stats, errors.New(errMsg)
|
||||
}
|
||||
nodeID := common.NotRegisteredID
|
||||
if node.session != nil && node.session.Registered() {
|
||||
nodeID = node.session.ServerID
|
||||
}
|
||||
info := &internalpb.ComponentInfo{
|
||||
NodeID: Params.ProxyID,
|
||||
// NodeID: Params.ProxyID, // will race with Proxy.Register()
|
||||
NodeID: nodeID,
|
||||
Role: typeutil.ProxyRole,
|
||||
StateCode: code,
|
||||
}
|
||||
|
|
|
@ -29,6 +29,8 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
|
@ -2480,3 +2482,17 @@ func Test_GetCompactionStateWithPlans(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxy_GetComponentStates(t *testing.T) {
|
||||
n := &Proxy{}
|
||||
n.stateCode.Store(internalpb.StateCode_Healthy)
|
||||
resp, err := n.GetComponentStates(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
assert.Equal(t, common.NotRegisteredID, resp.State.NodeID)
|
||||
n.session = &sessionutil.Session{}
|
||||
n.session.UpdateRegistered(true)
|
||||
resp, err = n.GetComponentStates(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
}
|
||||
|
|
|
@ -21,6 +21,8 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
|
@ -33,8 +35,13 @@ import (
|
|||
|
||||
// GetComponentStates return information about whether the coord is healthy
|
||||
func (qc *QueryCoord) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
|
||||
nodeID := common.NotRegisteredID
|
||||
if qc.session != nil && qc.session.Registered() {
|
||||
nodeID = qc.session.ServerID
|
||||
}
|
||||
serviceComponentInfo := &internalpb.ComponentInfo{
|
||||
NodeID: Params.QueryCoordID,
|
||||
// NodeID: Params.QueryCoordID, // will race with QueryCoord.Register()
|
||||
NodeID: nodeID,
|
||||
StateCode: qc.stateCode.Load().(internalpb.StateCode),
|
||||
}
|
||||
|
||||
|
|
|
@ -23,6 +23,9 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
|
@ -694,3 +697,17 @@ func Test_GrpcGetQueryChannelFail(t *testing.T) {
|
|||
assert.NotNil(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, res.Status.ErrorCode)
|
||||
}
|
||||
|
||||
func TestQueryCoord_GetComponentStates(t *testing.T) {
|
||||
n := &QueryCoord{}
|
||||
n.stateCode.Store(internalpb.StateCode_Healthy)
|
||||
resp, err := n.GetComponentStates(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
assert.Equal(t, common.NotRegisteredID, resp.State.NodeID)
|
||||
n.session = &sessionutil.Session{}
|
||||
n.session.UpdateRegistered(true)
|
||||
resp, err = n.GetComponentStates(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
}
|
||||
|
|
|
@ -21,6 +21,8 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/kv"
|
||||
|
@ -1194,9 +1196,15 @@ func (c *Core) GetComponentStates(ctx context.Context) (*internalpb.ComponentSta
|
|||
code := c.stateCode.Load().(internalpb.StateCode)
|
||||
log.Debug("GetComponentStates", zap.String("State Code", internalpb.StateCode_name[int32(code)]))
|
||||
|
||||
nodeID := common.NotRegisteredID
|
||||
if c.session != nil && c.session.Registered() {
|
||||
nodeID = c.session.ServerID
|
||||
}
|
||||
|
||||
return &internalpb.ComponentStates{
|
||||
State: &internalpb.ComponentInfo{
|
||||
NodeID: c.session.ServerID,
|
||||
// NodeID: c.session.ServerID, // will race with Core.Register()
|
||||
NodeID: nodeID,
|
||||
Role: typeutil.RootCoordRole,
|
||||
StateCode: code,
|
||||
ExtraInfo: nil,
|
||||
|
@ -1207,7 +1215,7 @@ func (c *Core) GetComponentStates(ctx context.Context) (*internalpb.ComponentSta
|
|||
},
|
||||
SubcomponentStates: []*internalpb.ComponentInfo{
|
||||
{
|
||||
NodeID: c.session.ServerID,
|
||||
NodeID: nodeID,
|
||||
Role: typeutil.RootCoordRole,
|
||||
StateCode: code,
|
||||
ExtraInfo: nil,
|
||||
|
|
|
@ -2715,3 +2715,17 @@ func TestRootCoord_CheckZeroShardsNum(t *testing.T) {
|
|||
err = core.Stop()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestCore_GetComponentStates(t *testing.T) {
|
||||
n := &Core{}
|
||||
n.stateCode.Store(internalpb.StateCode_Healthy)
|
||||
resp, err := n.GetComponentStates(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
assert.Equal(t, common.NotRegisteredID, resp.State.NodeID)
|
||||
n.session = &sessionutil.Session{}
|
||||
n.session.UpdateRegistered(true)
|
||||
resp, err = n.GetComponentStates(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"fmt"
|
||||
"path"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
|
@ -58,6 +59,8 @@ type Session struct {
|
|||
leaseID *clientv3.LeaseID
|
||||
|
||||
metaRoot string
|
||||
|
||||
registered atomic.Value
|
||||
}
|
||||
|
||||
// NewSession is a helper to build Session object.
|
||||
|
@ -70,6 +73,8 @@ func NewSession(ctx context.Context, metaRoot string, etcdEndpoints []string) *S
|
|||
metaRoot: metaRoot,
|
||||
}
|
||||
|
||||
session.UpdateRegistered(false)
|
||||
|
||||
connectEtcdFn := func() error {
|
||||
log.Debug("Session try to connect to etcd")
|
||||
etcdCli, err := clientv3.New(clientv3.Config{Endpoints: etcdEndpoints, DialTimeout: 5 * time.Second})
|
||||
|
@ -112,6 +117,7 @@ func (s *Session) Init(serverName, address string, exclusive bool) {
|
|||
panic(err)
|
||||
}
|
||||
s.liveCh = s.processKeepAliveResponse(ch)
|
||||
s.UpdateRegistered(true)
|
||||
}
|
||||
|
||||
func (s *Session) getServerID() (int64, error) {
|
||||
|
@ -403,3 +409,17 @@ func (s *Session) Revoke(timeout time.Duration) {
|
|||
// ignores resp & error, just do best effort to revoke
|
||||
_, _ = s.etcdCli.Revoke(ctx, *s.leaseID)
|
||||
}
|
||||
|
||||
// UpdateRegistered update the state of registered.
|
||||
func (s *Session) UpdateRegistered(b bool) {
|
||||
s.registered.Store(b)
|
||||
}
|
||||
|
||||
// Registered check if session was registered into etcd.
|
||||
func (s *Session) Registered() bool {
|
||||
b, ok := s.registered.Load().(bool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
|
|
@ -238,3 +238,11 @@ func TestSessionRevoke(t *testing.T) {
|
|||
s.Revoke(time.Second)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSession_Registered(t *testing.T) {
|
||||
session := &Session{}
|
||||
session.UpdateRegistered(false)
|
||||
assert.False(t, session.Registered())
|
||||
session.UpdateRegistered(true)
|
||||
assert.True(t, session.Registered())
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue