Refactored session handling into a separate class
parent
8431fd822f
commit
80f0c42844
|
@ -19,7 +19,6 @@ http:
|
|||
api_password: mypass
|
||||
# Set to 1 to enable development mode
|
||||
# development: 1
|
||||
# sessions_enabled: True
|
||||
|
||||
light:
|
||||
# platform: hue
|
||||
|
|
|
@ -112,7 +112,6 @@ DATA_API_PASSWORD = 'api_password'
|
|||
# Throttling time in seconds for expired sessions check
|
||||
MIN_SEC_SESSION_CLEARING = timedelta(seconds=20)
|
||||
SESSION_TIMEOUT_SECONDS = 1800
|
||||
SESSION_LOCK = threading.RLock()
|
||||
SESSION_KEY = 'sessionId'
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -138,7 +137,7 @@ def setup(hass, config=None):
|
|||
|
||||
development = str(config[DOMAIN].get(CONF_DEVELOPMENT, "")) == "1"
|
||||
|
||||
sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, False)
|
||||
sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, True)
|
||||
|
||||
server = HomeAssistantHTTPServer(
|
||||
(server_host, server_port), RequestHandler, hass, api_password,
|
||||
|
@ -175,8 +174,7 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
|||
self.development = development
|
||||
self.no_password_set = no_password_set
|
||||
self.paths = []
|
||||
self.sessions_enabled = sessions_enabled
|
||||
self._sessions = {}
|
||||
self.sessions = SessionStore(sessions_enabled)
|
||||
|
||||
# We will lazy init this one if needed
|
||||
self.event_forwarder = None
|
||||
|
@ -204,51 +202,6 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
|||
""" Regitsters a path wit the server. """
|
||||
self.paths.append((method, url, callback, require_auth))
|
||||
|
||||
@Throttle(MIN_SEC_SESSION_CLEARING)
|
||||
def remove_expired_sessions(self):
|
||||
""" Reemove any expired sessions. """
|
||||
if SESSION_LOCK.acquire(False):
|
||||
try:
|
||||
keys = []
|
||||
for key in self._sessions.keys():
|
||||
keys.append(key)
|
||||
|
||||
for key in keys:
|
||||
if self._sessions[key].is_expired:
|
||||
del self._sessions[key]
|
||||
_LOGGER.info("Cleared expired session %s", key)
|
||||
finally:
|
||||
SESSION_LOCK.release()
|
||||
|
||||
def add_session(self, key, session):
|
||||
""" Add a new session to the list of tracked sessions """
|
||||
self.remove_expired_sessions()
|
||||
try:
|
||||
SESSION_LOCK.acquire()
|
||||
self._sessions[key] = session
|
||||
finally:
|
||||
SESSION_LOCK.release()
|
||||
|
||||
def get_session(self, key):
|
||||
""" get a session by key """
|
||||
self.remove_expired_sessions()
|
||||
session = self._sessions.get(key, None)
|
||||
if session is not None and session.is_expired:
|
||||
return None
|
||||
return session
|
||||
|
||||
def create_session(self, api_password):
|
||||
""" Creates a new session and adds it to the sessions """
|
||||
if self.sessions_enabled is not True:
|
||||
return None
|
||||
|
||||
chars = string.ascii_letters + string.digits
|
||||
session_id = ''.join([random.choice(chars) for i in range(20)])
|
||||
session = ServerSession(session_id)
|
||||
session.cookie_values['api_password'] = api_password
|
||||
self.add_session(session_id, session)
|
||||
return session
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods,too-many-locals
|
||||
class RequestHandler(SimpleHTTPRequestHandler):
|
||||
|
@ -304,7 +257,8 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||
api_password = data[DATA_API_PASSWORD]
|
||||
|
||||
if not api_password and self._session is not None:
|
||||
api_password = self._session.cookie_values.get('api_password')
|
||||
api_password = self._session.cookie_values.get(
|
||||
CONF_API_PASSWORD)
|
||||
|
||||
if '_METHOD' in data:
|
||||
method = data.pop('_METHOD')
|
||||
|
@ -345,7 +299,8 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||
|
||||
else:
|
||||
if self._session is None and require_auth:
|
||||
self._session = self.server.create_session(api_password)
|
||||
self._session = self.server.sessions.create_session(
|
||||
api_password)
|
||||
|
||||
handle_request_method(self, path_match, data)
|
||||
|
||||
|
@ -459,14 +414,8 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||
|
||||
def set_session_cookie_header(self):
|
||||
""" Add the header for the session cookie """
|
||||
if self.server.sessions_enabled and self._session is not None:
|
||||
cookie = cookies.SimpleCookie()
|
||||
existing_sess_id = None
|
||||
|
||||
if self.headers.get('Cookie', None) is not None:
|
||||
cookie.load(self.headers.get('Cookie'))
|
||||
if cookie.get(SESSION_KEY, False):
|
||||
existing_sess_id = cookie[SESSION_KEY].value
|
||||
if self.server.sessions.enabled and self._session is not None:
|
||||
existing_sess_id = self.get_current_session_id()
|
||||
|
||||
if existing_sess_id != self._session.session_id:
|
||||
self.send_header(
|
||||
|
@ -475,21 +424,32 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||
|
||||
def get_session(self):
|
||||
""" Get the requested session object from cookie value """
|
||||
if self.server.sessions_enabled is not True:
|
||||
if self.server.sessions.enabled is not True:
|
||||
return None
|
||||
|
||||
session_id = self.get_current_session_id()
|
||||
if session_id is not None:
|
||||
session = self.server.sessions.get_session(session_id)
|
||||
if session is not None:
|
||||
session.reset_expiry()
|
||||
return session
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_current_session_id(self):
|
||||
"""
|
||||
Extracts the current session id from the
|
||||
cookie or returns None if not set
|
||||
"""
|
||||
cookie = cookies.SimpleCookie()
|
||||
|
||||
if self.headers.get('Cookie', None) is not None:
|
||||
cookie.load(self.headers.get("Cookie"))
|
||||
|
||||
if cookie.get(SESSION_KEY, False):
|
||||
session = self.server.get_session(cookie[SESSION_KEY].value)
|
||||
if session is not None:
|
||||
session.reset_expiry()
|
||||
return session
|
||||
else:
|
||||
return None
|
||||
return cookie[SESSION_KEY].value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ServerSession:
|
||||
|
@ -510,3 +470,54 @@ class ServerSession:
|
|||
def is_expired(self):
|
||||
""" Return true if the session is expired based on the expiry time """
|
||||
return self._expiry < date_util.utcnow()
|
||||
|
||||
|
||||
class SessionStore:
|
||||
""" Responsible for storing and retrieving http sessions """
|
||||
def __init__(self, enabled=True):
|
||||
""" Set up the session store """
|
||||
self._sessions = {}
|
||||
self.enabled = enabled
|
||||
self.session_lock = threading.RLock()
|
||||
|
||||
@Throttle(MIN_SEC_SESSION_CLEARING)
|
||||
def remove_expired_sessions(self):
|
||||
""" Remove any expired sessions. """
|
||||
if self.session_lock.acquire(False):
|
||||
try:
|
||||
keys = []
|
||||
for key in self._sessions.keys():
|
||||
keys.append(key)
|
||||
|
||||
for key in keys:
|
||||
if self._sessions[key].is_expired:
|
||||
del self._sessions[key]
|
||||
_LOGGER.info("Cleared expired session %s", key)
|
||||
finally:
|
||||
self.session_lock.release()
|
||||
|
||||
def add_session(self, key, session):
|
||||
""" Add a new session to the list of tracked sessions """
|
||||
self.remove_expired_sessions()
|
||||
with self.session_lock:
|
||||
self._sessions[key] = session
|
||||
|
||||
def get_session(self, key):
|
||||
""" get a session by key """
|
||||
self.remove_expired_sessions()
|
||||
session = self._sessions.get(key, None)
|
||||
if session is not None and session.is_expired:
|
||||
return None
|
||||
return session
|
||||
|
||||
def create_session(self, api_password):
|
||||
""" Creates a new session and adds it to the sessions """
|
||||
if self.enabled is not True:
|
||||
return None
|
||||
|
||||
chars = string.ascii_letters + string.digits
|
||||
session_id = ''.join([random.choice(chars) for i in range(20)])
|
||||
session = ServerSession(session_id)
|
||||
session.cookie_values[CONF_API_PASSWORD] = api_password
|
||||
self.add_session(session_id, session)
|
||||
return session
|
||||
|
|
Loading…
Reference in New Issue