Refactored session handling into a separate class

pull/134/head
jamespcole 2015-05-19 03:57:35 +10:00
parent 8431fd822f
commit 80f0c42844
2 changed files with 77 additions and 67 deletions

View File

@ -19,7 +19,6 @@ http:
api_password: mypass
# Set to 1 to enable development mode
# development: 1
# sessions_enabled: True
light:
# platform: hue

View File

@ -112,7 +112,6 @@ DATA_API_PASSWORD = 'api_password'
# Throttling time in seconds for expired sessions check
MIN_SEC_SESSION_CLEARING = timedelta(seconds=20)
SESSION_TIMEOUT_SECONDS = 1800
SESSION_LOCK = threading.RLock()
SESSION_KEY = 'sessionId'
_LOGGER = logging.getLogger(__name__)
@ -138,7 +137,7 @@ def setup(hass, config=None):
development = str(config[DOMAIN].get(CONF_DEVELOPMENT, "")) == "1"
sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, False)
sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, True)
server = HomeAssistantHTTPServer(
(server_host, server_port), RequestHandler, hass, api_password,
@ -175,8 +174,7 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
self.development = development
self.no_password_set = no_password_set
self.paths = []
self.sessions_enabled = sessions_enabled
self._sessions = {}
self.sessions = SessionStore(sessions_enabled)
# We will lazy init this one if needed
self.event_forwarder = None
@ -204,51 +202,6 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
""" Regitsters a path wit the server. """
self.paths.append((method, url, callback, require_auth))
@Throttle(MIN_SEC_SESSION_CLEARING)
def remove_expired_sessions(self):
""" Reemove any expired sessions. """
if SESSION_LOCK.acquire(False):
try:
keys = []
for key in self._sessions.keys():
keys.append(key)
for key in keys:
if self._sessions[key].is_expired:
del self._sessions[key]
_LOGGER.info("Cleared expired session %s", key)
finally:
SESSION_LOCK.release()
def add_session(self, key, session):
""" Add a new session to the list of tracked sessions """
self.remove_expired_sessions()
try:
SESSION_LOCK.acquire()
self._sessions[key] = session
finally:
SESSION_LOCK.release()
def get_session(self, key):
""" get a session by key """
self.remove_expired_sessions()
session = self._sessions.get(key, None)
if session is not None and session.is_expired:
return None
return session
def create_session(self, api_password):
""" Creates a new session and adds it to the sessions """
if self.sessions_enabled is not True:
return None
chars = string.ascii_letters + string.digits
session_id = ''.join([random.choice(chars) for i in range(20)])
session = ServerSession(session_id)
session.cookie_values['api_password'] = api_password
self.add_session(session_id, session)
return session
# pylint: disable=too-many-public-methods,too-many-locals
class RequestHandler(SimpleHTTPRequestHandler):
@ -304,7 +257,8 @@ class RequestHandler(SimpleHTTPRequestHandler):
api_password = data[DATA_API_PASSWORD]
if not api_password and self._session is not None:
api_password = self._session.cookie_values.get('api_password')
api_password = self._session.cookie_values.get(
CONF_API_PASSWORD)
if '_METHOD' in data:
method = data.pop('_METHOD')
@ -345,7 +299,8 @@ class RequestHandler(SimpleHTTPRequestHandler):
else:
if self._session is None and require_auth:
self._session = self.server.create_session(api_password)
self._session = self.server.sessions.create_session(
api_password)
handle_request_method(self, path_match, data)
@ -459,14 +414,8 @@ class RequestHandler(SimpleHTTPRequestHandler):
def set_session_cookie_header(self):
""" Add the header for the session cookie """
if self.server.sessions_enabled and self._session is not None:
cookie = cookies.SimpleCookie()
existing_sess_id = None
if self.headers.get('Cookie', None) is not None:
cookie.load(self.headers.get('Cookie'))
if cookie.get(SESSION_KEY, False):
existing_sess_id = cookie[SESSION_KEY].value
if self.server.sessions.enabled and self._session is not None:
existing_sess_id = self.get_current_session_id()
if existing_sess_id != self._session.session_id:
self.send_header(
@ -475,21 +424,32 @@ class RequestHandler(SimpleHTTPRequestHandler):
def get_session(self):
""" Get the requested session object from cookie value """
if self.server.sessions_enabled is not True:
if self.server.sessions.enabled is not True:
return None
session_id = self.get_current_session_id()
if session_id is not None:
session = self.server.sessions.get_session(session_id)
if session is not None:
session.reset_expiry()
return session
else:
return None
def get_current_session_id(self):
"""
Extracts the current session id from the
cookie or returns None if not set
"""
cookie = cookies.SimpleCookie()
if self.headers.get('Cookie', None) is not None:
cookie.load(self.headers.get("Cookie"))
if cookie.get(SESSION_KEY, False):
session = self.server.get_session(cookie[SESSION_KEY].value)
if session is not None:
session.reset_expiry()
return session
else:
return None
return cookie[SESSION_KEY].value
return None
class ServerSession:
@ -510,3 +470,54 @@ class ServerSession:
def is_expired(self):
""" Return true if the session is expired based on the expiry time """
return self._expiry < date_util.utcnow()
class SessionStore:
""" Responsible for storing and retrieving http sessions """
def __init__(self, enabled=True):
""" Set up the session store """
self._sessions = {}
self.enabled = enabled
self.session_lock = threading.RLock()
@Throttle(MIN_SEC_SESSION_CLEARING)
def remove_expired_sessions(self):
""" Remove any expired sessions. """
if self.session_lock.acquire(False):
try:
keys = []
for key in self._sessions.keys():
keys.append(key)
for key in keys:
if self._sessions[key].is_expired:
del self._sessions[key]
_LOGGER.info("Cleared expired session %s", key)
finally:
self.session_lock.release()
def add_session(self, key, session):
""" Add a new session to the list of tracked sessions """
self.remove_expired_sessions()
with self.session_lock:
self._sessions[key] = session
def get_session(self, key):
""" get a session by key """
self.remove_expired_sessions()
session = self._sessions.get(key, None)
if session is not None and session.is_expired:
return None
return session
def create_session(self, api_password):
""" Creates a new session and adds it to the sessions """
if self.enabled is not True:
return None
chars = string.ascii_letters + string.digits
session_id = ''.join([random.choice(chars) for i in range(20)])
session = ServerSession(session_id)
session.cookie_values[CONF_API_PASSWORD] = api_password
self.add_session(session_id, session)
return session