108 lines
3.5 KiB
Python
108 lines
3.5 KiB
Python
"""Handle the auth of a connection."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable, Coroutine
|
|
from typing import TYPE_CHECKING, Any, Final
|
|
|
|
from aiohttp.web import Request
|
|
import voluptuous as vol
|
|
from voluptuous.humanize import humanize_error
|
|
|
|
from homeassistant.components.http.ban import process_success_login, process_wrong_login
|
|
from homeassistant.const import __version__
|
|
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
|
|
from homeassistant.helpers.json import json_bytes
|
|
from homeassistant.util.json import JsonValueType
|
|
|
|
from .connection import ActiveConnection
|
|
from .error import Disconnect
|
|
|
|
if TYPE_CHECKING:
|
|
from .http import WebSocketAdapter
|
|
|
|
|
|
# Values of the "type" field used by the messages of the auth exchange.
TYPE_AUTH: Final = "auth"
TYPE_AUTH_INVALID: Final = "auth_invalid"
TYPE_AUTH_OK: Final = "auth_ok"
TYPE_AUTH_REQUIRED: Final = "auth_required"

# Validator for the auth message received from the client: "type" must be
# present and equal to TYPE_AUTH; "api_password" and "access_token" share the
# exclusion group "auth", so a message may carry at most one of them.
AUTH_MESSAGE_SCHEMA: Final = vol.Schema(
    {
        vol.Required("type"): TYPE_AUTH,
        vol.Exclusive("api_password", "auth"): str,
        vol.Exclusive("access_token", "auth"): str,
    }
)

# Both payloads only depend on module-level values, so they are serialized
# once here instead of on every use.
AUTH_REQUIRED_MESSAGE = json_bytes(
    {"type": TYPE_AUTH_REQUIRED, "ha_version": __version__}
)
AUTH_OK_MESSAGE = json_bytes({"type": TYPE_AUTH_OK, "ha_version": __version__})
|
|
|
|
|
|
def auth_invalid_message(message: str) -> bytes:
    """Serialize an ``auth_invalid`` message carrying *message* as its text."""
    payload = {"type": TYPE_AUTH_INVALID, "message": message}
    return json_bytes(payload)
|
|
|
|
|
|
class AuthPhase:
    """State of a connection whose client has not authenticated yet.

    Holds what is needed to check the client's auth message and, when it is
    accepted, to build the ``ActiveConnection`` that takes over.
    """

    def __init__(
        self,
        logger: WebSocketAdapter,
        hass: HomeAssistant,
        send_message: Callable[[bytes | str | dict[str, Any]], None],
        cancel_ws: CALLBACK_TYPE,
        request: Request,
        send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]],
    ) -> None:
        """Keep references to the collaborators used while authenticating."""
        self._logger = logger
        self._hass = hass
        self._request = request
        self._cancel_ws = cancel_ws
        # Queued delivery: send_message sends a message to the client via
        # the queue.
        self._send_message = send_message
        # Direct delivery: send_bytes_text sends a message to the client
        # directly.
        self._send_bytes_text = send_bytes_text

    async def async_handle(self, msg: JsonValueType) -> ActiveConnection:
        """Check the client's auth message and return the live connection.

        A message that fails ``AUTH_MESSAGE_SCHEMA`` is logged as a warning
        and answered with an ``auth_invalid`` message. A message without an
        access token, or with one that does not validate, is answered with
        ``auth_invalid`` and passed to ``process_wrong_login``. Either way
        ``Disconnect`` is raised.

        On success an ``ActiveConnection`` is created for the token's user,
        ``cancel_ws`` is registered as the token's revoke callback, the
        ``auth_ok`` message is sent, and the login is reported through
        ``process_success_login``.
        """
        try:
            auth_msg = AUTH_MESSAGE_SCHEMA(msg)
        except vol.Invalid as err:
            reason = f"Auth message incorrectly formatted: {humanize_error(msg, err)}"
            self._logger.warning(reason)
            await self._send_bytes_text(auth_invalid_message(reason))
            raise Disconnect from err

        # Only "access_token" is honoured; a message that carries just
        # "api_password" passes the schema but ends up in the rejection path.
        token = auth_msg.get("access_token")
        refresh_token = (
            self._hass.auth.async_validate_access_token(token) if token else None
        )

        if not refresh_token:
            rejection = auth_invalid_message("Invalid access token or password")
            await self._send_bytes_text(rejection)
            await process_wrong_login(self._request)
            raise Disconnect

        connection = ActiveConnection(
            self._logger,
            self._hass,
            self._send_message,
            refresh_token.user,
            refresh_token,
        )
        # The value returned by the registration is stored under the "auth"
        # subscription key of the new connection.
        unregister_revoke = self._hass.auth.async_register_revoke_token_callback(
            refresh_token.id, self._cancel_ws
        )
        connection.subscriptions["auth"] = unregister_revoke

        await self._send_bytes_text(AUTH_OK_MESSAGE)
        self._logger.debug("Auth OK")
        process_success_login(self._request)
        return connection
|