Merge pull request #11062 from influxdata/session_renew

feat(http): add renew session
pull/11148/head
kelwang 2019-01-16 13:29:51 -05:00 committed by GitHub
commit 7f20f3a1e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 166 additions and 2 deletions

View File

@ -22,6 +22,27 @@ func (c *Client) initializeSessions(ctx context.Context, tx *bolt.Tx) error {
return nil
}
// RenewSession extends the expire time to newExpiration.
func (c *Client) RenewSession(ctx context.Context, session *platform.Session, newExpiration time.Time) error {
op := getOp(platform.OpRenewSession)
if session == nil {
return &platform.Error{
Op: op,
Msg: "session is nil",
}
}
return c.db.Update(func(tx *bolt.Tx) error {
session.ExpiresAt = newExpiration
if err := c.putSession(ctx, tx, session); err != nil {
return &platform.Error{
Op: op,
Err: err,
}
}
return nil
})
}
// FindSession retrieves the session found at the provided key.
func (c *Client) FindSession(ctx context.Context, key string) (*platform.Session, error) {
op := getOp(platform.OpFindSession)

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"time"
platform "github.com/influxdata/influxdb"
platcontext "github.com/influxdata/influxdb/context"
@ -122,5 +123,11 @@ func (h *AuthenticationHandler) extractSession(ctx context.Context, r *http.Requ
return ctx, e
}
// if the session is not expired, renew the session
e = h.SessionService.RenewSession(ctx, s, time.Now().Add(platform.RenewSessionTime))
if e != nil {
return ctx, e
}
return platcontext.SetAuthorizer(ctx, s), nil
}

View File

@ -6,6 +6,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
platform "github.com/influxdata/influxdb"
platformhttp "github.com/influxdata/influxdb/http"
@ -39,6 +40,9 @@ func TestAuthenticationHandler(t *testing.T) {
FindSessionFn: func(ctx context.Context, key string) (*platform.Session, error) {
return &platform.Session{}, nil
},
RenewSessionFn: func(ctx context.Context, session *platform.Session, expiredAt time.Time) error {
return nil
},
},
},
args: args{

View File

@ -3,6 +3,7 @@ package mock
import (
"context"
"fmt"
"time"
platform "github.com/influxdata/influxdb"
)
@ -13,6 +14,7 @@ type SessionService struct {
FindSessionFn func(context.Context, string) (*platform.Session, error)
ExpireSessionFn func(context.Context, string) error
CreateSessionFn func(context.Context, string) (*platform.Session, error)
RenewSessionFn func(ctx context.Context, session *platform.Session, newExpiration time.Time) error
}
// NewSessionService returns a mock SessionService where its methods will return
@ -22,6 +24,9 @@ func NewSessionService() *SessionService {
FindSessionFn: func(context.Context, string) (*platform.Session, error) { return nil, fmt.Errorf("mock session") },
CreateSessionFn: func(context.Context, string) (*platform.Session, error) { return nil, fmt.Errorf("mock session") },
ExpireSessionFn: func(context.Context, string) error { return fmt.Errorf("mock session") },
RenewSessionFn: func(ctx context.Context, session *platform.Session, expiredAt time.Time) error {
return fmt.Errorf("mock session")
},
}
}
@ -39,3 +44,8 @@ func (s *SessionService) CreateSession(ctx context.Context, user string) (*platf
func (s *SessionService) ExpireSession(ctx context.Context, key string) error {
return s.ExpireSessionFn(ctx, key)
}
// RenewSession extends the expire time to newExpiration.
func (s *SessionService) RenewSession(ctx context.Context, session *platform.Session, expiredAt time.Time) error {
return s.RenewSessionFn(ctx, session, expiredAt)
}

View File

@ -2,7 +2,6 @@ package influxdb
import (
"context"
"fmt"
"time"
)
@ -12,6 +11,9 @@ const ErrSessionNotFound = "session not found"
// ErrSessionExpired is the error message for expired sessions.
const ErrSessionExpired = "session has expired"
// RenewSessionTime is the the time to extend session, currently set to 5min.
var RenewSessionTime = time.Duration(time.Second * 300)
var (
// OpFindSession represents the operation that looks for sessions.
OpFindSession = "FindSession"
@ -19,6 +21,8 @@ var (
OpExpireSession = "ExpireSession"
// OpCreateSession represents the operation that creates a session for a given user.
OpCreateSession = "CreateSession"
// OpRenewSession = "RenewSession"
OpRenewSession = "RenewSession"
)
// Session is a user session.
@ -35,7 +39,10 @@ type Session struct {
// Expired returns an error if the session is expired.
func (s *Session) Expired() error {
if time.Now().After(s.ExpiresAt) {
return fmt.Errorf(ErrSessionExpired)
return &Error{
Code: EForbidden,
Msg: ErrSessionExpired,
}
}
return nil
@ -67,4 +74,5 @@ type SessionService interface {
FindSession(ctx context.Context, key string) (*Session, error)
ExpireSession(ctx context.Context, key string) error
CreateSession(ctx context.Context, user string) (*Session, error)
RenewSession(ctx context.Context, session *Session, newExpiration time.Time) error
}

View File

@ -66,6 +66,10 @@ func SessionService(
name: "ExpireSession",
fn: ExpireSession,
},
{
name: "RenewSession",
fn: RenewSession,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -276,3 +280,113 @@ func ExpireSession(
})
}
}
// RenewSession testing
func RenewSession(
init func(SessionFields, *testing.T) (platform.SessionService, string, func()),
t *testing.T,
) {
type args struct {
session *platform.Session
key string
expireAt time.Time
}
type wants struct {
err error
session *platform.Session
}
tests := []struct {
name string
fields SessionFields
args args
wants wants
}{
{
name: "basic renew session",
fields: SessionFields{
IDGenerator: mock.NewIDGenerator(sessionTwoID, t),
TokenGenerator: mock.NewTokenGenerator("abc123xyz", nil),
Sessions: []*platform.Session{
{
ID: MustIDBase16(sessionOneID),
UserID: MustIDBase16(sessionTwoID),
Key: "abc123xyz",
ExpiresAt: time.Date(2030, 9, 26, 0, 0, 0, 0, time.UTC),
},
},
},
args: args{
session: &platform.Session{
ID: MustIDBase16(sessionOneID),
UserID: MustIDBase16(sessionTwoID),
Key: "abc123xyz",
ExpiresAt: time.Date(2030, 9, 26, 0, 0, 0, 0, time.UTC),
},
key: "abc123xyz",
expireAt: time.Date(2031, 9, 26, 0, 0, 10, 0, time.UTC),
},
wants: wants{
session: &platform.Session{
ID: MustIDBase16(sessionOneID),
UserID: MustIDBase16(sessionTwoID),
Key: "abc123xyz",
ExpiresAt: time.Date(2031, 9, 26, 0, 0, 10, 0, time.UTC),
},
},
},
{
name: "renew nil session",
fields: SessionFields{
IDGenerator: mock.NewIDGenerator(sessionTwoID, t),
TokenGenerator: mock.NewTokenGenerator("abc123xyz", nil),
Sessions: []*platform.Session{
{
ID: MustIDBase16(sessionOneID),
UserID: MustIDBase16(sessionTwoID),
Key: "abc123xyz",
ExpiresAt: time.Date(2030, 9, 26, 0, 0, 0, 0, time.UTC),
},
},
},
args: args{
key: "abc123xyz",
expireAt: time.Date(2031, 9, 26, 0, 0, 10, 0, time.UTC),
},
wants: wants{
err: &platform.Error{
Code: platform.EInternal,
Msg: "session is nil",
Op: platform.OpRenewSession,
},
session: &platform.Session{
ID: MustIDBase16(sessionOneID),
UserID: MustIDBase16(sessionTwoID),
Key: "abc123xyz",
ExpiresAt: time.Date(2031, 9, 26, 0, 0, 10, 0, time.UTC),
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s, opPrefix, done := init(tt.fields, t)
defer done()
ctx := context.Background()
err := s.RenewSession(ctx, tt.args.session, tt.args.expireAt)
diffPlatformErrors(tt.name, err, tt.wants.err, opPrefix, t)
session, err := s.FindSession(ctx, tt.args.key)
if err != nil {
t.Errorf("err in find session %v", err)
}
if diff := cmp.Diff(session, tt.wants.session, sessionCmpOptions...); diff != "" {
t.Errorf("session is different -got/+want\ndiff %s", diff)
}
})
}
}