"""Handle the auth of a connection."""

from __future__ import annotations

from collections.abc import Callable, Coroutine
from typing import TYPE_CHECKING, Any, Final

from aiohttp.web import Request
import voluptuous as vol
from voluptuous.humanize import humanize_error

from homeassistant.components.http.ban import process_success_login, process_wrong_login
from homeassistant.const import __version__
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.helpers.json import json_bytes
from homeassistant.util.json import JsonValueType

from .connection import ActiveConnection
from .error import Disconnect

if TYPE_CHECKING:
    from .http import WebSocketAdapter


# Values of the "type" field used by the messages of the auth handshake.
TYPE_AUTH: Final = "auth"
TYPE_AUTH_INVALID: Final = "auth_invalid"
TYPE_AUTH_OK: Final = "auth_ok"
TYPE_AUTH_REQUIRED: Final = "auth_required"

# Shape of the client's auth message. "api_password" and "access_token" share
# the exclusion group "auth", so a message may carry at most one of them.
AUTH_MESSAGE_SCHEMA: Final = vol.Schema(
    {
        vol.Required("type"): TYPE_AUTH,
        vol.Exclusive("api_password", "auth"): str,
        vol.Exclusive("access_token", "auth"): str,
    }
)

# These two replies never vary, so they are serialized once at import time.
AUTH_OK_MESSAGE = json_bytes({"type": TYPE_AUTH_OK, "ha_version": __version__})
AUTH_REQUIRED_MESSAGE = json_bytes(
    {"type": TYPE_AUTH_REQUIRED, "ha_version": __version__}
)


def auth_invalid_message(message: str) -> bytes:
    """Serialize an auth_invalid reply that carries *message* as its reason."""
    payload = {"type": TYPE_AUTH_INVALID, "message": message}
    return json_bytes(payload)


class AuthPhase:
    """Connection that requires client to authenticate first."""

    def __init__(
        self,
        logger: WebSocketAdapter,
        hass: HomeAssistant,
        send_message: Callable[[bytes | str | dict[str, Any]], None],
        cancel_ws: CALLBACK_TYPE,
        request: Request,
        send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]],
    ) -> None:
        """Store the collaborators needed to run the auth phase."""
        self._logger = logger
        self._hass = hass
        self._request = request
        self._cancel_ws = cancel_ws
        # send_message will send a message to the client via the queue.
        self._send_message = send_message
        # send_bytes_text will directly send a message to the client.
        self._send_bytes_text = send_bytes_text

    async def async_handle(self, msg: JsonValueType) -> ActiveConnection:
        """Authenticate the client from its first message.

        The message is checked against AUTH_MESSAGE_SCHEMA and its
        "access_token" is validated through ``hass.auth``. Only the access
        token is consulted here; an "api_password" passes the schema but is
        never checked.

        Returns the ActiveConnection for the authenticated user, after an
        auth_ok reply has been sent.

        Raises Disconnect when the message is malformed, or when the token is
        missing, empty, or rejected. An auth_invalid reply is sent before
        raising in every one of those cases.
        """
        try:
            auth_msg = AUTH_MESSAGE_SCHEMA(msg)
        except vol.Invalid as err:
            reason = f"Auth message incorrectly formatted: {humanize_error(msg, err)}"
            self._logger.warning(reason)
            await self._send_bytes_text(auth_invalid_message(reason))
            raise Disconnect from err

        # A missing or empty token is never handed to the validator.
        refresh_token = None
        if access_token := auth_msg.get("access_token"):
            refresh_token = self._hass.auth.async_validate_access_token(access_token)

        if not refresh_token:
            await self._send_bytes_text(
                auth_invalid_message("Invalid access token or password")
            )
            await process_wrong_login(self._request)
            raise Disconnect

        connection = ActiveConnection(
            self._logger,
            self._hass,
            self._send_message,
            refresh_token.user,
            refresh_token,
        )
        # Register cancel_ws as the revoke callback for this refresh token and
        # keep the returned handle under the connection's subscriptions.
        revoke_handle = self._hass.auth.async_register_revoke_token_callback(
            refresh_token.id, self._cancel_ws
        )
        connection.subscriptions["auth"] = revoke_handle

        await self._send_bytes_text(AUTH_OK_MESSAGE)
        self._logger.debug("Auth OK")
        process_success_login(self._request)
        return connection