diff --git a/bolt/session.go b/bolt/session.go index f164e35a10..ab60bc6e74 100644 --- a/bolt/session.go +++ b/bolt/session.go @@ -22,6 +22,27 @@ func (c *Client) initializeSessions(ctx context.Context, tx *bolt.Tx) error { return nil } +// RenewSession extends the expire time to newExpiration. +func (c *Client) RenewSession(ctx context.Context, session *platform.Session, newExpiration time.Time) error { + op := getOp(platform.OpRenewSession) + if session == nil { + return &platform.Error{ + Op: op, + Msg: "session is nil", + } + } + return c.db.Update(func(tx *bolt.Tx) error { + session.ExpiresAt = newExpiration + if err := c.putSession(ctx, tx, session); err != nil { + return &platform.Error{ + Op: op, + Err: err, + } + } + return nil + }) +} + // FindSession retrieves the session found at the provided key. func (c *Client) FindSession(ctx context.Context, key string) (*platform.Session, error) { op := getOp(platform.OpFindSession) diff --git a/http/authentication_middleware.go b/http/authentication_middleware.go index 2f54dd2982..083e0c6481 100644 --- a/http/authentication_middleware.go +++ b/http/authentication_middleware.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "time" platform "github.com/influxdata/influxdb" platcontext "github.com/influxdata/influxdb/context" @@ -122,5 +123,11 @@ func (h *AuthenticationHandler) extractSession(ctx context.Context, r *http.Requ return ctx, e } + // if the session is not expired, renew the session + e = h.SessionService.RenewSession(ctx, s, time.Now().Add(platform.RenewSessionTime)) + if e != nil { + return ctx, e + } + return platcontext.SetAuthorizer(ctx, s), nil } diff --git a/http/authentication_test.go b/http/authentication_test.go index 2d6499f595..43be11c184 100644 --- a/http/authentication_test.go +++ b/http/authentication_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" platform "github.com/influxdata/influxdb" platformhttp "github.com/influxdata/influxdb/http" @@ -39,6 +40,9 @@ func TestAuthenticationHandler(t *testing.T) { FindSessionFn: func(ctx context.Context, key string) (*platform.Session, error) { return &platform.Session{}, nil }, + RenewSessionFn: func(ctx context.Context, session *platform.Session, expiredAt time.Time) error { + return nil + }, }, }, args: args{ diff --git a/mock/session_service.go b/mock/session_service.go index 3463400693..d1ae3440a2 100644 --- a/mock/session_service.go +++ b/mock/session_service.go @@ -3,6 +3,7 @@ package mock import ( "context" "fmt" + "time" platform "github.com/influxdata/influxdb" ) @@ -13,6 +14,7 @@ type SessionService struct { FindSessionFn func(context.Context, string) (*platform.Session, error) ExpireSessionFn func(context.Context, string) error CreateSessionFn func(context.Context, string) (*platform.Session, error) + RenewSessionFn func(ctx context.Context, session *platform.Session, newExpiration time.Time) error } // NewSessionService returns a mock SessionService where its methods will return @@ -22,6 +24,9 @@ func NewSessionService() *SessionService { FindSessionFn: func(context.Context, string) (*platform.Session, error) { return nil, fmt.Errorf("mock session") }, CreateSessionFn: func(context.Context, string) (*platform.Session, error) { return nil, fmt.Errorf("mock session") }, ExpireSessionFn: func(context.Context, string) error { return fmt.Errorf("mock session") }, + RenewSessionFn: func(ctx context.Context, session *platform.Session, expiredAt time.Time) error { + return fmt.Errorf("mock session") + }, } } @@ -39,3 +44,8 @@ func (s *SessionService) CreateSession(ctx context.Context, user string) (*platf func (s *SessionService) ExpireSession(ctx context.Context, key string) error { return s.ExpireSessionFn(ctx, key) } + +// RenewSession extends the expire time to newExpiration. +func (s *SessionService) RenewSession(ctx context.Context, session *platform.Session, expiredAt time.Time) error { + return s.RenewSessionFn(ctx, session, expiredAt) +} diff --git a/session.go b/session.go index 470d08841a..6b6cbf1849 100644 --- a/session.go +++ b/session.go @@ -2,7 +2,6 @@ package influxdb import ( "context" - "fmt" "time" ) @@ -12,6 +11,9 @@ const ErrSessionNotFound = "session not found" // ErrSessionExpired is the error message for expired sessions. const ErrSessionExpired = "session has expired" +// RenewSessionTime is the the time to extend session, currently set to 5min. +var RenewSessionTime = time.Duration(time.Second * 300) + var ( // OpFindSession represents the operation that looks for sessions. OpFindSession = "FindSession" @@ -19,6 +21,8 @@ var ( OpExpireSession = "ExpireSession" // OpCreateSession represents the operation that creates a session for a given user. OpCreateSession = "CreateSession" + // OpRenewSession = "RenewSession" + OpRenewSession = "RenewSession" ) // Session is a user session. @@ -35,7 +39,10 @@ type Session struct { // Expired returns an error if the session is expired. func (s *Session) Expired() error { if time.Now().After(s.ExpiresAt) { - return fmt.Errorf(ErrSessionExpired) + return &Error{ + Code: EForbidden, + Msg: ErrSessionExpired, + } } return nil @@ -67,4 +74,5 @@ type SessionService interface { FindSession(ctx context.Context, key string) (*Session, error) ExpireSession(ctx context.Context, key string) error CreateSession(ctx context.Context, user string) (*Session, error) + RenewSession(ctx context.Context, session *Session, newExpiration time.Time) error } diff --git a/testing/session.go b/testing/session.go index eb7fc9c698..3024b4316c 100644 --- a/testing/session.go +++ b/testing/session.go @@ -66,6 +66,10 @@ func SessionService( name: "ExpireSession", fn: ExpireSession, }, + { + name: "RenewSession", + fn: RenewSession, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -276,3 +280,113 @@ func ExpireSession( }) } } + +// RenewSession testing +func RenewSession( + init func(SessionFields, *testing.T) (platform.SessionService, string, func()), + t *testing.T, +) { + type args struct { + session *platform.Session + key string + expireAt time.Time + } + + type wants struct { + err error + session *platform.Session + } + + tests := []struct { + name string + fields SessionFields + args args + wants wants + }{ + { + name: "basic renew session", + fields: SessionFields{ + IDGenerator: mock.NewIDGenerator(sessionTwoID, t), + TokenGenerator: mock.NewTokenGenerator("abc123xyz", nil), + Sessions: []*platform.Session{ + { + ID: MustIDBase16(sessionOneID), + UserID: MustIDBase16(sessionTwoID), + Key: "abc123xyz", + ExpiresAt: time.Date(2030, 9, 26, 0, 0, 0, 0, time.UTC), + }, + }, + }, + args: args{ + session: &platform.Session{ + ID: MustIDBase16(sessionOneID), + UserID: MustIDBase16(sessionTwoID), + Key: "abc123xyz", + ExpiresAt: time.Date(2030, 9, 26, 0, 0, 0, 0, time.UTC), + }, + key: "abc123xyz", + expireAt: time.Date(2031, 9, 26, 0, 0, 10, 0, time.UTC), + }, + wants: wants{ + session: &platform.Session{ + ID: MustIDBase16(sessionOneID), + UserID: MustIDBase16(sessionTwoID), + Key: "abc123xyz", + ExpiresAt: time.Date(2031, 9, 26, 0, 0, 10, 0, time.UTC), + }, + }, + }, + { + name: "renew nil session", + fields: SessionFields{ + IDGenerator: mock.NewIDGenerator(sessionTwoID, t), + TokenGenerator: mock.NewTokenGenerator("abc123xyz", nil), + Sessions: []*platform.Session{ + { + ID: MustIDBase16(sessionOneID), + UserID: MustIDBase16(sessionTwoID), + Key: "abc123xyz", + ExpiresAt: time.Date(2030, 9, 26, 0, 0, 0, 0, time.UTC), + }, + }, + }, + args: args{ + key: "abc123xyz", + expireAt: time.Date(2031, 9, 26, 0, 0, 10, 0, time.UTC), + }, + wants: wants{ + err: &platform.Error{ + Code: platform.EInternal, + Msg: "session is nil", + Op: platform.OpRenewSession, + }, + session: &platform.Session{ + ID: MustIDBase16(sessionOneID), + UserID: MustIDBase16(sessionTwoID), + Key: "abc123xyz", + ExpiresAt: time.Date(2031, 9, 26, 0, 0, 10, 0, time.UTC), + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s, opPrefix, done := init(tt.fields, t) + defer done() + ctx := context.Background() + + err := s.RenewSession(ctx, tt.args.session, tt.args.expireAt) + diffPlatformErrors(tt.name, err, tt.wants.err, opPrefix, t) + + session, err := s.FindSession(ctx, tt.args.key) + if err != nil { + t.Errorf("err in find session %v", err) + } + + if diff := cmp.Diff(session, tt.wants.session, sessionCmpOptions...); diff != "" { + t.Errorf("session is different -got/+want\ndiff %s", diff) + } + }) + } +}