From 80f0c42844adc8147791ebffb54758d7a3995e60 Mon Sep 17 00:00:00 2001 From: jamespcole Date: Tue, 19 May 2015 03:57:35 +1000 Subject: [PATCH] Refactored session handling into a separate class --- config/configuration.yaml.example | 1 - homeassistant/components/http.py | 143 ++++++++++++++++-------------- 2 files changed, 77 insertions(+), 67 deletions(-) diff --git a/config/configuration.yaml.example b/config/configuration.yaml.example index b69dd3fd203..c99f760f21f 100644 --- a/config/configuration.yaml.example +++ b/config/configuration.yaml.example @@ -19,7 +19,6 @@ http: api_password: mypass # Set to 1 to enable development mode # development: 1 - # sessions_enabled: True light: # platform: hue diff --git a/homeassistant/components/http.py b/homeassistant/components/http.py index 2412ff5875f..951608d4a8f 100644 --- a/homeassistant/components/http.py +++ b/homeassistant/components/http.py @@ -112,7 +112,6 @@ DATA_API_PASSWORD = 'api_password' # Throttling time in seconds for expired sessions check MIN_SEC_SESSION_CLEARING = timedelta(seconds=20) SESSION_TIMEOUT_SECONDS = 1800 -SESSION_LOCK = threading.RLock() SESSION_KEY = 'sessionId' _LOGGER = logging.getLogger(__name__) @@ -138,7 +137,7 @@ def setup(hass, config=None): development = str(config[DOMAIN].get(CONF_DEVELOPMENT, "")) == "1" - sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, False) + sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, True) server = HomeAssistantHTTPServer( (server_host, server_port), RequestHandler, hass, api_password, @@ -175,8 +174,7 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer): self.development = development self.no_password_set = no_password_set self.paths = [] - self.sessions_enabled = sessions_enabled - self._sessions = {} + self.sessions = SessionStore(sessions_enabled) # We will lazy init this one if needed self.event_forwarder = None @@ -204,51 +202,6 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer): """ Regitsters a path wit the server. """ self.paths.append((method, url, callback, require_auth)) - @Throttle(MIN_SEC_SESSION_CLEARING) - def remove_expired_sessions(self): - """ Reemove any expired sessions. """ - if SESSION_LOCK.acquire(False): - try: - keys = [] - for key in self._sessions.keys(): - keys.append(key) - - for key in keys: - if self._sessions[key].is_expired: - del self._sessions[key] - _LOGGER.info("Cleared expired session %s", key) - finally: - SESSION_LOCK.release() - - def add_session(self, key, session): - """ Add a new session to the list of tracked sessions """ - self.remove_expired_sessions() - try: - SESSION_LOCK.acquire() - self._sessions[key] = session - finally: - SESSION_LOCK.release() - - def get_session(self, key): - """ get a session by key """ - self.remove_expired_sessions() - session = self._sessions.get(key, None) - if session is not None and session.is_expired: - return None - return session - - def create_session(self, api_password): - """ Creates a new session and adds it to the sessions """ - if self.sessions_enabled is not True: - return None - - chars = string.ascii_letters + string.digits - session_id = ''.join([random.choice(chars) for i in range(20)]) - session = ServerSession(session_id) - session.cookie_values['api_password'] = api_password - self.add_session(session_id, session) - return session - # pylint: disable=too-many-public-methods,too-many-locals class RequestHandler(SimpleHTTPRequestHandler): @@ -304,7 +257,8 @@ class RequestHandler(SimpleHTTPRequestHandler): api_password = data[DATA_API_PASSWORD] if not api_password and self._session is not None: - api_password = self._session.cookie_values.get('api_password') + api_password = self._session.cookie_values.get( + CONF_API_PASSWORD) if '_METHOD' in data: method = data.pop('_METHOD') @@ -345,7 +299,8 @@ class RequestHandler(SimpleHTTPRequestHandler): else: if self._session is None and require_auth: - self._session = self.server.create_session(api_password) + self._session = self.server.sessions.create_session( + api_password) handle_request_method(self, path_match, data) @@ -459,14 +414,8 @@ class RequestHandler(SimpleHTTPRequestHandler): def set_session_cookie_header(self): """ Add the header for the session cookie """ - if self.server.sessions_enabled and self._session is not None: - cookie = cookies.SimpleCookie() - existing_sess_id = None - - if self.headers.get('Cookie', None) is not None: - cookie.load(self.headers.get('Cookie')) - if cookie.get(SESSION_KEY, False): - existing_sess_id = cookie[SESSION_KEY].value + if self.server.sessions.enabled and self._session is not None: + existing_sess_id = self.get_current_session_id() if existing_sess_id != self._session.session_id: self.send_header( @@ -475,21 +424,32 @@ class RequestHandler(SimpleHTTPRequestHandler): def get_session(self): """ Get the requested session object from cookie value """ - if self.server.sessions_enabled is not True: + if self.server.sessions.enabled is not True: return None + session_id = self.get_current_session_id() + if session_id is not None: + session = self.server.sessions.get_session(session_id) + if session is not None: + session.reset_expiry() + return session + else: + return None + + def get_current_session_id(self): + """ + Extracts the current session id from the + cookie or returns None if not set + """ cookie = cookies.SimpleCookie() if self.headers.get('Cookie', None) is not None: cookie.load(self.headers.get("Cookie")) if cookie.get(SESSION_KEY, False): - session = self.server.get_session(cookie[SESSION_KEY].value) - if session is not None: - session.reset_expiry() - return session - else: - return None + return cookie[SESSION_KEY].value + + return None class ServerSession: @@ -510,3 +470,54 @@ class ServerSession: def is_expired(self): """ Return true if the session is expired based on the expiry time """ return self._expiry < date_util.utcnow() + + +class SessionStore: + """ Responsible for storing and retrieving http sessions """ + def __init__(self, enabled=True): + """ Set up the session store """ + self._sessions = {} + self.enabled = enabled + self.session_lock = threading.RLock() + + @Throttle(MIN_SEC_SESSION_CLEARING) + def remove_expired_sessions(self): + """ Remove any expired sessions. """ + if self.session_lock.acquire(False): + try: + keys = [] + for key in self._sessions.keys(): + keys.append(key) + + for key in keys: + if self._sessions[key].is_expired: + del self._sessions[key] + _LOGGER.info("Cleared expired session %s", key) + finally: + self.session_lock.release() + + def add_session(self, key, session): + """ Add a new session to the list of tracked sessions """ + self.remove_expired_sessions() + with self.session_lock: + self._sessions[key] = session + + def get_session(self, key): + """ get a session by key """ + self.remove_expired_sessions() + session = self._sessions.get(key, None) + if session is not None and session.is_expired: + return None + return session + + def create_session(self, api_password): + """ Creates a new session and adds it to the sessions """ + if self.enabled is not True: + return None + + chars = string.ascii_letters + string.digits + session_id = ''.join([random.choice(chars) for i in range(20)]) + session = ServerSession(session_id) + session.cookie_values[CONF_API_PASSWORD] = api_password + self.add_session(session_id, session) + return session