206 lines
6.8 KiB
Python
206 lines
6.8 KiB
Python
"""Session auth module."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime, timedelta
|
|
import secrets
|
|
from typing import TYPE_CHECKING, Final, TypedDict
|
|
|
|
from aiohttp.web import Request
|
|
from aiohttp_session import Session, get_session, new_session
|
|
from cryptography.fernet import Fernet
|
|
|
|
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
|
from homeassistant.helpers.event import async_call_later
|
|
from homeassistant.helpers.storage import Store
|
|
from homeassistant.util import dt as dt_util
|
|
|
|
from .models import RefreshToken
|
|
|
|
if TYPE_CHECKING:
|
|
from . import AuthManager
|
|
|
|
|
|
# Lifetime of a temporary session; the float form is what gets passed to
# async_call_later and (truncated to int) used as the session max_age.
TEMP_TIMEOUT: Final = timedelta(minutes=5)
TEMP_TIMEOUT_SECONDS: Final = TEMP_TIMEOUT.total_seconds()

# Key under which the generated session id is kept inside the aiohttp session.
SESSION_ID: Final = "id"
# Version and key handed to the Store that persists the session data.
STORAGE_VERSION: Final = 1
STORAGE_KEY: Final = "auth.session"
|
|
|
|
|
|
class StrictConnectionTempSessionData:
    """Bookkeeping for a short-lived session that may access unauthorized resources.

    Holds the callable that cancels the scheduled removal of the session and
    the point in time after which the session is no longer accepted.
    """

    __slots__ = ("cancel_remove", "absolute_expiry")

    def __init__(self, cancel_remove: CALLBACK_TYPE) -> None:
        """Store the cancel callable and fix the expiry to now + TEMP_TIMEOUT."""
        created_at = dt_util.utcnow()
        self.cancel_remove: Final[CALLBACK_TYPE] = cancel_remove
        # The expiry is computed once at creation time and never extended.
        self.absolute_expiry: Final[datetime] = created_at + TEMP_TIMEOUT
|
|
|
|
|
|
class StoreData(TypedDict):
    """Shape of the payload written to and read from the session store."""

    # Mapping of session id -> refresh token id.
    unauthorized_sessions: dict[str, str]
    # Encryption key, stored as text.
    key: str
|
|
|
|
|
|
class SessionManager:
    """Session manager.

    Tracks two kinds of sessions, both keyed by a random session id:

    * strict connection sessions, each bound to a refresh token id and
      persisted through the store together with the encryption key;
    * temporary sessions, kept in memory only and accepted for at most
      ``TEMP_TIMEOUT`` after creation.
    """

    def __init__(self, hass: HomeAssistant, auth: AuthManager) -> None:
        """Initialize the strict connection manager."""
        self._auth = auth
        self._hass = hass
        # session id -> temporary session data (never persisted)
        self._temp_sessions: dict[str, StrictConnectionTempSessionData] = {}
        # session id -> refresh token id (persisted via _data_to_save)
        self._strict_connection_sessions: dict[str, str] = {}
        self._store = Store[StoreData](
            hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
        )
        # Generated on first access of the ``key`` property or loaded in
        # async_setup.
        self._key: str | None = None
        # refresh token id -> value returned by
        # AuthManager.async_register_revoke_token_callback
        self._refresh_token_revoke_callbacks: dict[str, CALLBACK_TYPE] = {}

    @property
    def key(self) -> str:
        """Return the encryption key.

        If no key is set yet, a new Fernet key is generated and a save of
        the store is scheduled so the key gets persisted.
        """
        if self._key is None:
            self._key = Fernet.generate_key().decode()
            self._async_schedule_save()
        return self._key

    async def async_validate_request_for_strict_connection_session(
        self,
        request: Request,
    ) -> bool:
        """Check if a request has a valid strict connection session.

        A new or empty session is rejected as is. A session that carries
        data but fails validation is invalidated before returning False.
        """
        session = await get_session(request)
        if session.new or session.empty:
            return False
        is_valid = self.async_validate_strict_connection_session(session)
        if not is_valid:
            session.invalidate()
        return is_valid

    @callback
    def async_validate_strict_connection_session(
        self,
        session: Session,
    ) -> bool:
        """Validate a strict connection session.

        Returns True when the session id maps to a refresh token the auth
        manager still knows, or to a temporary session that has not passed
        its expiry. Stale entries found along the way are removed.
        """
        if not (session_id := session.get(SESSION_ID)):
            return False

        if token_id := self._strict_connection_sessions.get(session_id):
            if self._auth.async_get_refresh_token(token_id):
                return True
            # refresh token is invalid, delete entry
            self._strict_connection_sessions.pop(session_id)
            self._async_schedule_save()

        if data := self._temp_sessions.get(session_id):
            if dt_util.utcnow() <= data.absolute_expiry:
                return True
            # session expired, delete entry and cancel its scheduled removal
            self._temp_sessions.pop(session_id).cancel_remove()

        return False

    @callback
    def _async_remove_sessions_for_token(self, refresh_token_id: str) -> None:
        """Drop every strict connection session bound to the refresh token.

        The mapping is rebuilt rather than mutated in place; callers are
        responsible for scheduling a save.
        """
        self._strict_connection_sessions = {
            session_id: token_id
            for session_id, token_id in self._strict_connection_sessions.items()
            if token_id != refresh_token_id
        }

    @callback
    def _async_register_revoke_token_callback(self, refresh_token_id: str) -> None:
        """Register a callback to revoke all sessions for a refresh token.

        Does nothing when a callback is already registered for the token.
        """
        if refresh_token_id in self._refresh_token_revoke_callbacks:
            return

        @callback
        def async_invalidate_auth_sessions() -> None:
            """Invalidate all sessions for a refresh token."""
            self._async_remove_sessions_for_token(refresh_token_id)
            # Forget the stored handle as well, otherwise this mapping keeps
            # one entry per revoked token for the lifetime of the process.
            # NOTE(review): the handle is only dropped, not called — confirm
            # the auth manager discards its own registration on revoke.
            self._refresh_token_revoke_callbacks.pop(refresh_token_id, None)
            self._async_schedule_save()

        self._refresh_token_revoke_callbacks[refresh_token_id] = (
            self._auth.async_register_revoke_token_callback(
                refresh_token_id, async_invalidate_auth_sessions
            )
        )

    async def async_create_session(
        self,
        request: Request,
        refresh_token: RefreshToken,
    ) -> None:
        """Create new session for given refresh token.

        Caller needs to make sure that the refresh token is valid.
        By creating a session, we are implicitly revoking all other
        sessions for the given refresh token as there is one refresh
        token per device/user case.
        """
        self._async_remove_sessions_for_token(refresh_token.id)

        self._async_register_revoke_token_callback(refresh_token.id)
        session_id = await self._async_create_new_session(request)
        self._strict_connection_sessions[session_id] = refresh_token.id
        self._async_schedule_save()

    async def async_create_temp_unauthorized_session(self, request: Request) -> None:
        """Create a temporary unauthorized session.

        The session is kept in memory only and a removal is scheduled
        ``TEMP_TIMEOUT_SECONDS`` from now.
        """
        session_id = await self._async_create_new_session(
            request, max_age=int(TEMP_TIMEOUT_SECONDS)
        )

        @callback
        def remove(_: datetime) -> None:
            # The entry may already be gone if validation found it expired.
            self._temp_sessions.pop(session_id, None)

        self._temp_sessions[session_id] = StrictConnectionTempSessionData(
            async_call_later(self._hass, TEMP_TIMEOUT_SECONDS, remove)
        )

    async def _async_create_new_session(
        self,
        request: Request,
        *,
        max_age: int | None = None,
    ) -> str:
        """Start a new session on the request and return its generated id.

        The id is a random hex string stored under ``SESSION_ID`` in the
        session. ``max_age`` is applied to the session only when given.
        """
        session_id = secrets.token_hex(64)

        session = await new_session(request)
        session[SESSION_ID] = session_id
        if max_age is not None:
            session.max_age = max_age
        return session_id

    @callback
    def _async_schedule_save(self, delay: float = 1) -> None:
        """Save sessions."""
        self._store.async_delay_save(self._data_to_save, delay)

    @callback
    def _data_to_save(self) -> StoreData:
        """Return the data to store."""
        return StoreData(
            unauthorized_sessions=self._strict_connection_sessions,
            key=self.key,
        )

    async def async_setup(self) -> None:
        """Set up session manager.

        Loads the persisted key and sessions, if any, and registers a revoke
        callback for every refresh token id found in the stored sessions.
        """
        data = await self._store.async_load()
        if data is None:
            return

        self._key = data["key"]
        self._strict_connection_sessions = data["unauthorized_sessions"]
        for token_id in self._strict_connection_sessions.values():
            self._async_register_revoke_token_callback(token_id)
|