Merge pull request #11062 from influxdata/session_renew
feat(http): add renew sessionpull/11148/head
commit
7f20f3a1e6
|
@ -22,6 +22,27 @@ func (c *Client) initializeSessions(ctx context.Context, tx *bolt.Tx) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// RenewSession extends the expire time to newExpiration.
|
||||
func (c *Client) RenewSession(ctx context.Context, session *platform.Session, newExpiration time.Time) error {
|
||||
op := getOp(platform.OpRenewSession)
|
||||
if session == nil {
|
||||
return &platform.Error{
|
||||
Op: op,
|
||||
Msg: "session is nil",
|
||||
}
|
||||
}
|
||||
return c.db.Update(func(tx *bolt.Tx) error {
|
||||
session.ExpiresAt = newExpiration
|
||||
if err := c.putSession(ctx, tx, session); err != nil {
|
||||
return &platform.Error{
|
||||
Op: op,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// FindSession retrieves the session found at the provided key.
|
||||
func (c *Client) FindSession(ctx context.Context, key string) (*platform.Session, error) {
|
||||
op := getOp(platform.OpFindSession)
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
platform "github.com/influxdata/influxdb"
|
||||
platcontext "github.com/influxdata/influxdb/context"
|
||||
|
@ -122,5 +123,11 @@ func (h *AuthenticationHandler) extractSession(ctx context.Context, r *http.Requ
|
|||
return ctx, e
|
||||
}
|
||||
|
||||
// if the session is not expired, renew the session
|
||||
e = h.SessionService.RenewSession(ctx, s, time.Now().Add(platform.RenewSessionTime))
|
||||
if e != nil {
|
||||
return ctx, e
|
||||
}
|
||||
|
||||
return platcontext.SetAuthorizer(ctx, s), nil
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
platform "github.com/influxdata/influxdb"
|
||||
platformhttp "github.com/influxdata/influxdb/http"
|
||||
|
@ -39,6 +40,9 @@ func TestAuthenticationHandler(t *testing.T) {
|
|||
FindSessionFn: func(ctx context.Context, key string) (*platform.Session, error) {
|
||||
return &platform.Session{}, nil
|
||||
},
|
||||
RenewSessionFn: func(ctx context.Context, session *platform.Session, expiredAt time.Time) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
|
|
|
@ -3,6 +3,7 @@ package mock
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
platform "github.com/influxdata/influxdb"
|
||||
)
|
||||
|
@ -13,6 +14,7 @@ type SessionService struct {
|
|||
FindSessionFn func(context.Context, string) (*platform.Session, error)
|
||||
ExpireSessionFn func(context.Context, string) error
|
||||
CreateSessionFn func(context.Context, string) (*platform.Session, error)
|
||||
RenewSessionFn func(ctx context.Context, session *platform.Session, newExpiration time.Time) error
|
||||
}
|
||||
|
||||
// NewSessionService returns a mock SessionService where its methods will return
|
||||
|
@ -22,6 +24,9 @@ func NewSessionService() *SessionService {
|
|||
FindSessionFn: func(context.Context, string) (*platform.Session, error) { return nil, fmt.Errorf("mock session") },
|
||||
CreateSessionFn: func(context.Context, string) (*platform.Session, error) { return nil, fmt.Errorf("mock session") },
|
||||
ExpireSessionFn: func(context.Context, string) error { return fmt.Errorf("mock session") },
|
||||
RenewSessionFn: func(ctx context.Context, session *platform.Session, expiredAt time.Time) error {
|
||||
return fmt.Errorf("mock session")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -39,3 +44,8 @@ func (s *SessionService) CreateSession(ctx context.Context, user string) (*platf
|
|||
func (s *SessionService) ExpireSession(ctx context.Context, key string) error {
|
||||
return s.ExpireSessionFn(ctx, key)
|
||||
}
|
||||
|
||||
// RenewSession extends the expire time to newExpiration.
|
||||
func (s *SessionService) RenewSession(ctx context.Context, session *platform.Session, expiredAt time.Time) error {
|
||||
return s.RenewSessionFn(ctx, session, expiredAt)
|
||||
}
|
||||
|
|
12
session.go
12
session.go
|
@ -2,7 +2,6 @@ package influxdb
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -12,6 +11,9 @@ const ErrSessionNotFound = "session not found"
|
|||
// ErrSessionExpired is the error message for expired sessions.
|
||||
const ErrSessionExpired = "session has expired"
|
||||
|
||||
// RenewSessionTime is the the time to extend session, currently set to 5min.
|
||||
var RenewSessionTime = time.Duration(time.Second * 300)
|
||||
|
||||
var (
|
||||
// OpFindSession represents the operation that looks for sessions.
|
||||
OpFindSession = "FindSession"
|
||||
|
@ -19,6 +21,8 @@ var (
|
|||
OpExpireSession = "ExpireSession"
|
||||
// OpCreateSession represents the operation that creates a session for a given user.
|
||||
OpCreateSession = "CreateSession"
|
||||
// OpRenewSession = "RenewSession"
|
||||
OpRenewSession = "RenewSession"
|
||||
)
|
||||
|
||||
// Session is a user session.
|
||||
|
@ -35,7 +39,10 @@ type Session struct {
|
|||
// Expired returns an error if the session is expired.
|
||||
func (s *Session) Expired() error {
|
||||
if time.Now().After(s.ExpiresAt) {
|
||||
return fmt.Errorf(ErrSessionExpired)
|
||||
return &Error{
|
||||
Code: EForbidden,
|
||||
Msg: ErrSessionExpired,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -67,4 +74,5 @@ type SessionService interface {
|
|||
FindSession(ctx context.Context, key string) (*Session, error)
|
||||
ExpireSession(ctx context.Context, key string) error
|
||||
CreateSession(ctx context.Context, user string) (*Session, error)
|
||||
RenewSession(ctx context.Context, session *Session, newExpiration time.Time) error
|
||||
}
|
||||
|
|
|
@ -66,6 +66,10 @@ func SessionService(
|
|||
name: "ExpireSession",
|
||||
fn: ExpireSession,
|
||||
},
|
||||
{
|
||||
name: "RenewSession",
|
||||
fn: RenewSession,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -276,3 +280,113 @@ func ExpireSession(
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RenewSession testing
|
||||
func RenewSession(
|
||||
init func(SessionFields, *testing.T) (platform.SessionService, string, func()),
|
||||
t *testing.T,
|
||||
) {
|
||||
type args struct {
|
||||
session *platform.Session
|
||||
key string
|
||||
expireAt time.Time
|
||||
}
|
||||
|
||||
type wants struct {
|
||||
err error
|
||||
session *platform.Session
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fields SessionFields
|
||||
args args
|
||||
wants wants
|
||||
}{
|
||||
{
|
||||
name: "basic renew session",
|
||||
fields: SessionFields{
|
||||
IDGenerator: mock.NewIDGenerator(sessionTwoID, t),
|
||||
TokenGenerator: mock.NewTokenGenerator("abc123xyz", nil),
|
||||
Sessions: []*platform.Session{
|
||||
{
|
||||
ID: MustIDBase16(sessionOneID),
|
||||
UserID: MustIDBase16(sessionTwoID),
|
||||
Key: "abc123xyz",
|
||||
ExpiresAt: time.Date(2030, 9, 26, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
session: &platform.Session{
|
||||
ID: MustIDBase16(sessionOneID),
|
||||
UserID: MustIDBase16(sessionTwoID),
|
||||
Key: "abc123xyz",
|
||||
ExpiresAt: time.Date(2030, 9, 26, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
key: "abc123xyz",
|
||||
expireAt: time.Date(2031, 9, 26, 0, 0, 10, 0, time.UTC),
|
||||
},
|
||||
wants: wants{
|
||||
session: &platform.Session{
|
||||
ID: MustIDBase16(sessionOneID),
|
||||
UserID: MustIDBase16(sessionTwoID),
|
||||
Key: "abc123xyz",
|
||||
ExpiresAt: time.Date(2031, 9, 26, 0, 0, 10, 0, time.UTC),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "renew nil session",
|
||||
fields: SessionFields{
|
||||
IDGenerator: mock.NewIDGenerator(sessionTwoID, t),
|
||||
TokenGenerator: mock.NewTokenGenerator("abc123xyz", nil),
|
||||
Sessions: []*platform.Session{
|
||||
{
|
||||
ID: MustIDBase16(sessionOneID),
|
||||
UserID: MustIDBase16(sessionTwoID),
|
||||
Key: "abc123xyz",
|
||||
ExpiresAt: time.Date(2030, 9, 26, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
key: "abc123xyz",
|
||||
expireAt: time.Date(2031, 9, 26, 0, 0, 10, 0, time.UTC),
|
||||
},
|
||||
wants: wants{
|
||||
err: &platform.Error{
|
||||
Code: platform.EInternal,
|
||||
Msg: "session is nil",
|
||||
Op: platform.OpRenewSession,
|
||||
},
|
||||
session: &platform.Session{
|
||||
ID: MustIDBase16(sessionOneID),
|
||||
UserID: MustIDBase16(sessionTwoID),
|
||||
Key: "abc123xyz",
|
||||
ExpiresAt: time.Date(2031, 9, 26, 0, 0, 10, 0, time.UTC),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s, opPrefix, done := init(tt.fields, t)
|
||||
defer done()
|
||||
ctx := context.Background()
|
||||
|
||||
err := s.RenewSession(ctx, tt.args.session, tt.args.expireAt)
|
||||
diffPlatformErrors(tt.name, err, tt.wants.err, opPrefix, t)
|
||||
|
||||
session, err := s.FindSession(ctx, tt.args.key)
|
||||
if err != nil {
|
||||
t.Errorf("err in find session %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(session, tt.wants.session, sessionCmpOptions...); diff != "" {
|
||||
t.Errorf("session is different -got/+want\ndiff %s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue